Model
- class poutyne.Model(network, optimizer, loss_function, *, batch_metrics=None, epoch_metrics=None, device=None)
The Model class encapsulates a PyTorch network, a PyTorch optimizer, a loss function and metric functions. It allows the user to train a neural network without hand-coding the epoch/step logic.
- Parameters:
network (torch.nn.Module) – A PyTorch network.
optimizer (Union[torch.optim.Optimizer, str, dict]) – If torch.optim.Optimizer, an initialized PyTorch optimizer. If str, the name of the optimizer in PyTorch (e.g. ‘Adam’ for torch.optim.Adam). If dict, it should contain a key 'optim' whose value is the name of the optimizer; the other entries are passed to the optimizer as keyword arguments. (Default value = None)
loss_function (Union[Callable, str]) – The loss function. It can also be a string with the same name as a PyTorch loss function (either the functional or the object name). The loss function must have the signature loss_function(input, target), where input is the prediction of the network and target is the ground truth. (Default value = None)
batch_metrics (list) – List of functions with the same signature as a loss function, or objects with the same interface as either Metric or torchmetrics.Metric. A metric can also be a string with the same name as a PyTorch loss function (either the functional or the object name). Some metrics, such as ‘accuracy’ (or just ‘acc’), are also available as strings. See Metrics and the TorchMetrics documentation for the available metrics. Batch metrics are computed for each batch. (Default value = None)
Warning
When using this argument, the metrics are computed for each batch. Depending on the metrics used, this can significantly slow down the computations. This mostly happens with non-decomposable metrics such as torchmetrics.AUROC, where the elements must be ordered to compute the metric. In such cases, we advise using them as epoch metrics instead.
epoch_metrics (list) – List of functions with the same signature as a loss function, or objects with the same interface as either Metric or torchmetrics.Metric. A metric can also be a string with the same name as a PyTorch loss function (either the functional or the object name). Some metrics, such as ‘accuracy’ (or just ‘acc’), are also available as strings. See Metrics and the TorchMetrics documentation for the available metrics. Epoch metrics are computed only at the end of the epoch. (Default value = None)
device (Union[torch.device, List[torch.device]]) – The device, or the list of devices, to which the network is sent. See to() for details. (Default value = None)
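The Warning under batch_metrics concerns non-decomposable metrics. AUROC depends on how all the samples of the epoch are ranked, so averaging per-batch values generally gives a different result than computing the metric once per epoch. A batch containing a single class does not even have a defined AUROC. The pure-Python sketch below makes the difference concrete with made-up scores. binary_auroc is a hypothetical rank-based helper written for illustration, not the TorchMetrics implementation.

```python
from itertools import product


def binary_auroc(scores, labels):
    """Rank-based AUROC: share of (positive, negative) pairs ranked correctly.

    Ties count as half a correct pair. Hypothetical helper for illustration.
    """
    positives = [s for s, y in zip(scores, labels) if y == 1]
    negatives = [s for s, y in zip(scores, labels) if y == 0]
    if not positives or not negatives:
        raise ValueError("AUROC needs at least one positive and one negative sample")
    correct = sum(
        1.0 if p > n else 0.5 if p == n else 0.0
        for p, n in product(positives, negatives)
    )
    return correct / (len(positives) * len(negatives))


# Made-up scores and binary labels, split into two batches of four samples.
batch1 = ([0.9, 0.8, 0.3, 0.1], [1, 1, 0, 0])
batch2 = ([0.7, 0.6, 0.5, 0.4], [1, 0, 1, 0])

# "Batch metric" style: compute per batch, then average the batch values.
per_batch_mean = (binary_auroc(*batch1) + binary_auroc(*batch2)) / 2

# "Epoch metric" style: compute once on all the samples of the epoch.
epoch_level = binary_auroc(batch1[0] + batch2[0], batch1[1] + batch2[1])

print(per_batch_mean, epoch_level)  # 0.875 0.9375
```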
Note
The name of each batch and epoch metric can be changed by passing a tuple (name, metric) instead of simply the metric function or object, where name is the alternative name of the metric. Batch and epoch metrics can return multiple metrics (e.g. an epoch metric could return an F1-score with the associated precision and recall). See Computing Multiple Metrics at Once for more details.
- network
The associated PyTorch network.
- Type: torch.nn.Module
- optimizer
The associated PyTorch optimizer.
- Type: torch.optim.Optimizer
- loss_function
The associated loss function.
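The (name, metric) convention from the Note above can be pictured with a small pure-Python sketch. name_metrics and accuracy are hypothetical names. The default names Poutyne itself assigns come from the library (for instance acc for ‘accuracy’, as in the logs below) and may differ from the __name__-based default used here.

```python
def accuracy(pred, target):
    """Hypothetical batch metric: fraction of exact matches."""
    return sum(p == t for p, t in zip(pred, target)) / len(target)


def name_metrics(metrics):
    """Hypothetical helper: turn metric entries into (name, callable) pairs.

    An entry is either a callable, named here after its __name__, or a
    (name, callable) tuple whose first element overrides that name.
    """
    named = []
    for entry in metrics:
        if isinstance(entry, tuple):
            name, fn = entry
        else:
            name, fn = entry.__name__, entry
        named.append((name, fn))
    return named


named = name_metrics([accuracy, ("my_acc", accuracy)])
logs = {name: fn([1, 2, 3, 4], [1, 2, 0, 4]) for name, fn in named}
print(logs)  # {'accuracy': 0.75, 'my_acc': 0.75}
```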
Examples
Using Numpy arrays (or tensors) dataset:
```python
from poutyne import Model
import torch
import numpy as np
import torchmetrics

num_features = 20
num_classes = 5

# Our training dataset with 800 samples.
num_train_samples = 800
train_x = np.random.randn(num_train_samples, num_features).astype('float32')
train_y = np.random.randint(num_classes, size=num_train_samples).astype('int64')

# Our validation dataset with 200 samples.
num_valid_samples = 200
valid_x = np.random.randn(num_valid_samples, num_features).astype('float32')
valid_y = np.random.randint(num_classes, size=num_valid_samples).astype('int64')

pytorch_network = torch.nn.Linear(num_features, num_classes)  # Our network

# We create and optimize our model
model = Model(
    pytorch_network,
    'sgd',
    'cross_entropy',
    batch_metrics=['accuracy'],
    epoch_metrics=[torchmetrics.AUROC(num_classes=num_classes, task="multiclass")],
)
model.fit(train_x, train_y, validation_data=(valid_x, valid_y), epochs=5, batch_size=32)
```
```
Epoch: 1/5 Train steps: 25 Val steps: 7 0.51s loss: 1.757784 acc: 20.750000 auroc: 0.494891 val_loss: 1.756639 val_acc: 18.500000 val_auroc: 0.499404
Epoch: 2/5 Train steps: 25 Val steps: 7 0.03s loss: 1.749623 acc: 20.375000 auroc: 0.496878 val_loss: 1.748795 val_acc: 19.000000 val_auroc: 0.499723
Epoch: 3/5 Train steps: 25 Val steps: 7 0.03s loss: 1.742070 acc: 20.250000 auroc: 0.499461 val_loss: 1.741379 val_acc: 19.000000 val_auroc: 0.498577
...
```
Using PyTorch DataLoader:
```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from poutyne import Model
import torchmetrics

num_features = 20
num_classes = 5

# Our training dataset with 800 samples.
num_train_samples = 800
train_x = torch.rand(num_train_samples, num_features)
train_y = torch.randint(num_classes, (num_train_samples,), dtype=torch.long)
train_dataset = TensorDataset(train_x, train_y)
train_generator = DataLoader(train_dataset, batch_size=32)

# Our validation dataset with 200 samples.
num_valid_samples = 200
valid_x = torch.rand(num_valid_samples, num_features)
valid_y = torch.randint(num_classes, (num_valid_samples,), dtype=torch.long)
valid_dataset = TensorDataset(valid_x, valid_y)
valid_generator = DataLoader(valid_dataset, batch_size=32)

pytorch_network = torch.nn.Linear(num_features, num_classes)

model = Model(
    pytorch_network,
    'sgd',
    'cross_entropy',
    batch_metrics=['accuracy'],
    epoch_metrics=[torchmetrics.AUROC(num_classes=num_classes, task="multiclass")],
)
model.fit_generator(train_generator, valid_generator, epochs=5)
```
```
Epoch: 1/5 Train steps: 25 Val steps: 7 0.07s loss: 1.614473 acc: 20.500000 auroc: 0.516850 val_loss: 1.617141 val_acc: 21.500000 val_auroc: 0.522068
Epoch: 2/5 Train steps: 25 Val steps: 7 0.03s loss: 1.614454 acc: 20.125000 auroc: 0.517618 val_loss: 1.615585 val_acc: 22.000000 val_auroc: 0.521051
Epoch: 3/5 Train steps: 25 Val steps: 7 0.03s loss: 1.613709 acc: 20.125000 auroc: 0.518307 val_loss: 1.614440 val_acc: 22.000000 val_auroc: 0.520762
...
```
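The step counts in the logs above ("Train steps: 25 Val steps: 7") follow from the default of one full pass over each dataset per epoch. This assumes the last, smaller batch is kept, which is what a PyTorch DataLoader does by default (drop_last=False). A small arithmetic check, where num_steps is a hypothetical helper:

```python
import math


def num_steps(num_samples, batch_size):
    """Number of batches needed to see every sample once (last batch may be partial)."""
    return math.ceil(num_samples / batch_size)


train_steps = num_steps(800, 32)  # 800 training samples, batch_size=32
val_steps = num_steps(200, 32)    # 200 validation samples: 6 full batches + 1 batch of 8
print(train_steps, val_steps)     # 25 7
```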
- fit(x, y, validation_data=None, *, batch_size=32, epochs=1000, steps_per_epoch=None, validation_steps=None, batches_per_step=1, initial_epoch=1, verbose=True, progress_options: dict | None = None, callbacks=None, dataloader_kwargs=None)
Trains the network on a dataset. This method creates generators and calls the fit_generator() method.
- Parameters:
x (Union[Tensor, ndarray] or Union[tuple, list] of Union[Tensor, ndarray]) – Training dataset. Union[Tensor, ndarray] if the model has a single input. Union[tuple, list] of Union[Tensor, ndarray] if the model has multiple inputs.
y (Union[Tensor, ndarray] or Union[tuple, list] of Union[Tensor, ndarray]) – Target. Union[Tensor, ndarray] if the model has a single output. Union[tuple, list] of Union[Tensor, ndarray] if the model has multiple outputs.
validation_data (Tuple[x_val, y_val]) – Same format as x and y previously described. Validation dataset on which to evaluate the loss and any model metrics at the end of each epoch. The model will not be trained on this data. (Default value = None)
batch_size (int) – Number of samples given to the network at one time. (Default value = 32)
epochs (int) – Number of times the entire training dataset is seen. (Default value = 1000)
steps_per_epoch (int, optional) – Number of batches used during one epoch. Using this argument may cause one epoch not to see the entire training dataset, or to see it multiple times. (Defaults to the number of steps needed to see the entire training dataset)
validation_steps (int, optional) – Same as for steps_per_epoch but for the validation dataset. (Defaults to the number of steps needed to see the entire validation dataset)
batches_per_step (int) – Number of batches on which to compute the running loss before backpropagating it through the network. Note that the total loss used for backpropagation is the mean of the batches_per_step batch losses. (Default value = 1)
initial_epoch (int, optional) – Epoch at which to start training (useful for resuming a previous training run). (Default value = 1)
verbose (bool) – Whether to display the progress of the training. (Default value = True)
progress_options (dict, optional) – Keyword arguments to pass to the default progression callback used in Poutyne (See ProgressionCallback for the available arguments). (Default value = None)
callbacks (List[Callback]) – List of callbacks that will be called during training. (Default value = None)
dataloader_kwargs (dict, optional) – Keyword arguments to pass to the PyTorch dataloaders created internally. By default, shuffle=True is passed for the training dataloader, but this can be overridden by using this argument.
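Accumulating the loss over batches_per_step batches before backpropagating is a form of gradient accumulation. The pure-Python sketch below uses a one-weight linear model with a hand-derived squared-error gradient. All names are hypothetical and this is not Poutyne's training loop. It also assumes the number of batches is a multiple of batches_per_step. Because each batch gradient is divided by batches_per_step, the single update uses the gradient of the mean of the batch losses. With equally sized batches, this matches one step on the merged batch.

```python
def batch_grad(w, batch):
    """Gradient of the mean squared error of y_hat = w * x on one batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def sgd_with_accumulation(w, batches, lr, batches_per_step):
    """One pass over `batches`, stepping once every `batches_per_step` batches.

    Each batch gradient is divided by batches_per_step before being added, so
    the update uses the gradient of the mean of those batch losses.
    Assumes len(batches) is a multiple of batches_per_step.
    """
    assert len(batches) % batches_per_step == 0
    accumulated = 0.0
    for i, batch in enumerate(batches, start=1):
        accumulated += batch_grad(w, batch) / batches_per_step
        if i % batches_per_step == 0:
            w -= lr * accumulated  # one optimizer step
            accumulated = 0.0
    return w


# Two equally sized batches drawn from y = 2 * x.
batches = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0), (4.0, 8.0)]]

w_accumulated = sgd_with_accumulation(0.0, batches, lr=0.01, batches_per_step=2)
# Reference: a single step on the merged batch of four samples.
w_merged = 0.0 - 0.01 * batch_grad(0.0, batches[0] + batches[1])
print(w_accumulated, w_merged)
```

With batches of unequal size, the mean of the batch losses and the loss on the merged batch differ, so the two updates would no longer coincide.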
- Returns:
List of dict containing the history of each epoch.
Example
```python
model = Model(pytorch_network, optimizer, loss_function)
history = model.fit(
    train_x,
    train_y,
    validation_data=(valid_x, valid_y),
    epochs=num_epochs,
    batch_size=batch_size,
    verbose=False,
)
print(*history, sep="\n")
```
```
{'epoch': 1, 'loss': 1.7198852968215943, 'time': 0.019999928001197986, 'acc': 19.375, 'val_loss': 1.6674459838867188, 'val_acc': 22.0}
{'epoch': 2, 'loss': 1.7054892110824584, 'time': 0.015421080999658443, 'acc': 19.75, 'val_loss': 1.660806336402893, 'val_acc': 22.0}
{'epoch': 3, 'loss': 1.6923445892333984, 'time': 0.01363091799794347, 'acc': 19.625, 'val_loss': 1.6550078630447387, 'val_acc': 22.5}
...
```
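Because the returned history is a plain list of dicts, it can be post-processed with ordinary Python. The minimal sketch below reuses the first three epochs of the output above and finds the epoch with the lowest validation loss:

```python
# The first three entries of the history printed above ('time' and 'acc' omitted).
history = [
    {'epoch': 1, 'loss': 1.7198852968215943, 'val_loss': 1.6674459838867188, 'val_acc': 22.0},
    {'epoch': 2, 'loss': 1.7054892110824584, 'val_loss': 1.660806336402893, 'val_acc': 22.0},
    {'epoch': 3, 'loss': 1.6923445892333984, 'val_loss': 1.6550078630447387, 'val_acc': 22.5},
]

# Each entry is a plain dict keyed by the logged names.
best = min(history, key=lambda logs: logs['val_loss'])
val_losses = [logs['val_loss'] for logs in history]
print(best['epoch'])  # 3
```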
- fit_dataset(train_dataset, valid_dataset=None, *, batch_size=32, epochs=1000, steps_per_epoch=None, validation_steps=None, batches_per_step=1, initial_epoch=1, verbose=True, progress_options=None, callbacks=None, num_workers=0, collate_fn=None, dataloader_kwargs=None)
Trains the network on a dataset. This method creates dataloaders and calls the fit_generator() method.
- Parameters:
train_dataset (Dataset) – Training dataset.
valid_dataset (Dataset) – Validation dataset.
batch_size (int) – Number of samples given to the network at one time. (Default value = 32)
epochs (int) – Number of times the entire training dataset is seen. (Default value = 1000)
steps_per_epoch (int, optional) – Number of batches used during one epoch. Using this argument may cause one epoch not to see the entire training dataset, or to see it multiple times. (Defaults to the number of steps needed to see the entire training dataset)
validation_steps (int, optional) – Same as for steps_per_epoch but for the validation dataset. (Defaults to the number of steps needed to see the entire validation dataset)
batches_per_step (int) – Number of batches on which to compute the running loss before backpropagating it through the network. Note that the total loss used for backpropagation is the mean of the batches_per_step batch losses. (Default value = 1)
initial_epoch (int, optional) – Epoch at which to start training (useful for resuming a previous training run). (Default value = 1)
verbose (bool) – Whether to display the progress of the training. (Default value = True)
progress_options (dict, optional) – Keyword arguments to pass to the default progression callback used in Poutyne (See ProgressionCallback for the available arguments). (Default value = None)
callbacks (List[Callback]) – List of callbacks that will be called during training. (Default value = None)
dataloader_kwargs (dict, optional) – Keyword arguments to pass to the PyTorch dataloaders created internally. By default, shuffle=True is passed for the training dataloader, but this can be overridden by using this argument.
num_workers (int, optional) – How many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (Default value = 0)
collate_fn (Callable, optional) – Merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
- Returns:
List of dict containing the history of each epoch.
- See:
DataLoader for details on batch_size, num_workers and collate_fn.
Example
```python
model = Model(pytorch_network, optimizer, loss_function)
history = model.fit_dataset(
    train_dataset,
    valid_dataset,
    epochs=num_epochs,
    batch_size=batch_size,
    verbose=False,
)
print(*history, sep="\n")
```
```
{'epoch': 1, 'loss': 1.7198852968215943, 'time': 0.019999928001197986, 'acc': 19.375, 'val_loss': 1.6674459838867188, 'val_acc': 22.0}
{'epoch': 2, 'loss': 1.7054892110824584, 'time': 0.015421080999658443, 'acc': 19.75, 'val_loss': 1.660806336402893, 'val_acc': 22.0}
{'epoch': 3, 'loss': 1.6923445892333984, 'time': 0.01363091799794347, 'acc': 19.625, 'val_loss': 1.6550078630447387, 'val_acc': 22.5}
...
```
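A collate_fn receives the list of samples fetched for one batch and returns the batch. The sketch below uses plain lists in place of tensors so that it runs without PyTorch. With tensors, a real collate function would typically stack them (for example with torch.stack) and would be passed as fit_dataset(train_dataset, valid_dataset, collate_fn=collate_pairs). The name collate_pairs is hypothetical.

```python
def collate_pairs(samples):
    """Hypothetical collate_fn: [(x0, y0), (x1, y1), ...] -> ([x0, x1, ...], [y0, y1, ...])."""
    xs, ys = zip(*samples)
    return list(xs), list(ys)


# What a map-style dataset might return for three indices: (features, label) pairs.
samples = [([0.1, 0.2], 1), ([0.3, 0.4], 0), ([0.5, 0.6], 1)]
x_batch, y_batch = collate_pairs(samples)
print(x_batch)  # [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
print(y_batch)  # [1, 0, 1]
```

The (x, y) return shape matches the batch format that fit_generator() expects from its generator.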
- fit_generator(train_generator, valid_generator=None, *, epochs=1000, steps_per_epoch=None, validation_steps=None, batches_per_step=1, initial_epoch=1, verbose=True, progress_options: dict | None = None, callbacks=None)
Trains the network on a dataset using a generator.
- Parameters:
train_generator – Generator-like object for the training dataset. The generator must yield a batch in the form of a tuple (x, y), where x is the input and y is the target. The batch size is inferred from x and y. See get_batch_size() for details on the inferring algorithm. The loss and the metrics are averaged using this batch size. If the batch size cannot be inferred, a warning is raised and the “batch size” defaults to 1.
If the generator does not have a method __len__(), either the steps_per_epoch argument must be provided, or the iterator returned must raise a StopIteration exception at the end of the training dataset. PyTorch DataLoader objects do provide a __len__() method.
Before each epoch, the method __iter__() is called on the generator, and the method __next__() is called for each step on the resulting object returned by __iter__(). Notice that a call to __iter__() on a generator made using the Python keyword yield returns the generator itself.
valid_generator (optional) – Generator-like object for the validation dataset. This generator is optional. It is used the same way as train_generator. If it does not have a method __len__(), either the validation_steps or the steps_per_epoch argument must be provided, or the iterator returned must raise a StopIteration exception at the end of the validation dataset. (Default value = None)
epochs (int) – Number of times the entire training dataset is seen. (Default value = 1000)
steps_per_epoch (int, optional) – Number of batches used during one epoch. Using this argument may cause one epoch not to see the entire training dataset, or to see it multiple times. See the arguments train_generator and valid_generator for more details on how steps_per_epoch is used.
validation_steps (int, optional) – Same as for steps_per_epoch but for the validation dataset. See the argument valid_generator for more details on how validation_steps is used.
batches_per_step (int) – Number of batches on which to compute the running loss before backpropagating it through the network. Note that the total loss used for backpropagation is the mean of the batches_per_step batch losses. (Default value = 1)
initial_epoch (int, optional) – Epoch at which to start training (useful for resuming a previous training run). (Default value = 1)
verbose (bool) – Whether to display the progress of the training. (Default value = True)
progress_options (dict, optional) – Keyword arguments to pass to the default progression callback used in Poutyne (See ProgressionCallback for the available arguments). (Default value = None, meaning default color setting and progress bar)
callbacks (List[Callback]) – List of callbacks that will be called during training. (Default value = None)
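The __iter__() behaviour described for train_generator matters across epochs. An object whose __iter__() returns a fresh iterator can be replayed every epoch. A generator made with yield returns itself from __iter__() and is exhausted after the first pass. The pure-Python sketch below simulates two epochs by iterating twice. ReIterableBatches and one_shot_batches are hypothetical names.

```python
class ReIterableBatches:
    """Hypothetical batch source: a fresh iterator on every __iter__() call, plus __len__()."""

    def __init__(self, xs, ys, batch_size):
        self.xs, self.ys, self.batch_size = xs, ys, batch_size

    def __len__(self):
        # Number of batches, counting a final partial batch.
        return (len(self.xs) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        # A generator method: each call returns a brand-new generator.
        for start in range(0, len(self.xs), self.batch_size):
            end = start + self.batch_size
            yield self.xs[start:end], self.ys[start:end]


def one_shot_batches(xs, ys, batch_size):
    """Generator function: iter() on its result returns the same, single-use generator."""
    for start in range(0, len(xs), batch_size):
        yield xs[start:start + batch_size], ys[start:start + batch_size]


xs = list(range(10))
ys = [x % 2 for x in xs]

reiterable = ReIterableBatches(xs, ys, batch_size=4)
# Two "epochs": each list() call invokes __iter__() and gets all the batches again.
reiterable_epochs = [len(list(reiterable)) for _ in range(2)]
print(len(reiterable), reiterable_epochs)  # 3 [3, 3]

one_shot = one_shot_batches(xs, ys, batch_size=4)
# iter() on a generator returns the generator itself, so the second pass is empty.
one_shot_epochs = [len(list(one_shot)) for _ in range(2)]
print(one_shot_epochs)  # [3, 0]
```

PyTorch DataLoader objects behave like ReIterableBatches here: they define __len__() and return a new iterator on each call to __iter__().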
- Returns:
List of dict containing the history of each epoch.
Example
```python
model = Model(pytorch_network, optimizer, loss_function)
history = model.fit_generator(train_generator, valid_generator, epochs=num_epochs, verbose=False)
print(*history, sep="\n")
```
```
{'epoch': 1, 'loss': 1.7198852968215943, 'time': 0.019999928001197986, 'acc': 19.375, 'val_loss': 1.6674459838867188, 'val_acc': 22.0}
{'epoch': 2, 'loss': 1.7054892110824584, 'time': 0.015421080999658443, 'acc': 19.75, 'val_loss': 1.660806336402893, 'val_acc': 22.0}
{'epoch': 3, 'loss': 1.6923445892333984, 'time': 0.01363091799794347, 'acc': 19.625, 'val_loss': 1.6550078630447387, 'val_acc': 22.5}
...
```
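Averaging the loss and the metrics "using this batch size" means each batch is weighted by its number of samples. This matters when the last batch is smaller. The minimal sketch below computes such a weighted average with made-up per-batch losses. epoch_average is hypothetical, and Poutyne's internal bookkeeping may differ in detail.

```python
def epoch_average(batch_values, batch_sizes):
    """Average per-batch values, weighting each batch by its number of samples."""
    total = sum(value * size for value, size in zip(batch_values, batch_sizes))
    return total / sum(batch_sizes)


# 200 samples in batches of 32: six full batches and a final batch of 8.
sizes = [32] * 6 + [8]
losses = [1.0] * 6 + [2.0]  # made-up per-batch mean losses

weighted = epoch_average(losses, sizes)  # (192 + 16) / 200
naive = sum(losses) / len(losses)        # treats the small batch like a full one
print(weighted, naive)
```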
- train_on_batch(x, y, *, return_pred=False, return_dict_format=False, convert_to_numpy=True)
Trains the network for the batch (x, y), computes the loss and the metrics, and optionally returns the predictions.
- Parameters:
x – Input data as a batch.
y – Target data as a batch.
return_pred (bool, optional) – Whether to return the predictions. (Default value = False)
return_dict_format (bool, optional) – Whether to return the loss and metrics in a dict format or not. (Default value = False)
convert_to_numpy (bool, optional) – Whether to convert the predictions into Numpy arrays when return_pred is true. (Default value = True)
- Returns:
Float loss if no metrics were specified and return_pred is false.
Otherwise, tuple (loss, metrics) if return_pred is false. If n > 1, metrics is a Numpy array of size n, where n is the number of metrics. If n == 1, then metrics is a float. If n == 0, metrics is omitted.
Tuple (loss, metrics, pred_y) if return_pred is true, where pred_y is the predictions with tensors converted into Numpy arrays.
If return_dict_format is True, then loss and metrics are replaced by a dictionary.
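The return rules above can be summarized by a small pure-Python function. pack_results is a hypothetical sketch of the documented behaviour, not Poutyne's code. It uses a list where Poutyne returns a Numpy array.

```python
def pack_results(loss, metrics, pred_y=None, return_pred=False):
    """Hypothetical helper mirroring the documented return rules of train_on_batch.

    `metrics` is a sequence of metric values; a list stands in for the Numpy
    array that is returned when there are several metrics.
    """
    if len(metrics) == 0:
        packed = (loss,)                 # n == 0: metrics omitted
    elif len(metrics) == 1:
        packed = (loss, metrics[0])      # n == 1: a single float
    else:
        packed = (loss, list(metrics))   # n > 1: all values together
    if return_pred:
        packed += (pred_y,)
    # A lone loss is returned as a bare float rather than a 1-tuple.
    return packed[0] if len(packed) == 1 else packed


print(pack_results(0.5, []))                            # 0.5
print(pack_results(0.5, [0.9]))                         # (0.5, 0.9)
print(pack_results(0.5, [0.9, 0.1]))                    # (0.5, [0.9, 0.1])
print(pack_results(0.5, [], [1, 0], return_pred=True))  # (0.5, [1, 0])
```

These shapes are what make unpacking patterns such as loss, (metric1, metric2) = ... work when there are several metrics.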
- predict(x, *, batch_size=32, convert_to_numpy=True, verbose=True, progress_options: dict | None = None, callbacks=None, dataloader_kwargs=None) → Any
Returns the predictions of the network given a dataset x, where the tensors are converted into Numpy arrays.
- Parameters:
x (Union[Tensor, ndarray] or Union[tuple, list] of Union[Tensor, ndarray]) – Input to the model. Union[Tensor, ndarray] if the model has a single input. Union[tuple, list] of Union[Tensor, ndarray] if the model has multiple inputs.
batch_size (int) – Number of samples given to the network at one time. (Default value = 32)
concatenate_returns (bool, optional) – Whether to concatenate the predictions when returning them. (Default value = True)
verbose (bool) – Whether to display the progress of the evaluation. (Default value = True)
progress_options (dict, optional) – Keyword arguments to pass to the default progression callback used in Poutyne (See ProgressionCallback for the available arguments). (Default value = None, meaning default color setting and progress bar)
callbacks (List[Callback]) – List of callbacks that will be called during testing. (Default value = None)
dataloader_kwargs (dict, optional) – Keyword arguments to pass to the PyTorch dataloaders created internally.
- Returns:
Return the predictions in the format outputted by the model.
- predict_dataset(dataset, *, batch_size=32, steps=None, has_ground_truth=False, return_ground_truth=False, concatenate_returns=True, convert_to_numpy=True, num_workers=0, collate_fn=None, verbose=True, progress_options: dict | None = None, callbacks=None, dataloader_kwargs=None) → Any
Returns the predictions of the network given a dataset, where the tensors are converted into Numpy arrays.
- Parameters:
dataset (Dataset) – Dataset. Must not return y, just x, unless has_ground_truth is true.
batch_size (int) – Number of samples given to the network at one time. (Default value = 32)
steps (int, optional) – Number of batches used for prediction. (Defaults to the number of steps needed to see the entire dataset)
has_ground_truth (bool, optional) – Whether the dataset returns the target y. Automatically set to true if return_ground_truth is true. (Default value = False)
return_ground_truth (bool, optional) – Whether to return the ground truths. If true, has_ground_truth is automatically set to true. (Default value = False)
concatenate_returns (bool, optional) – Whether to concatenate the predictions or the ground truths when returning them. See predict_generator() for details. (Default value = True)
num_workers (int, optional) – How many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (Default value = 0)
collate_fn (Callable, optional) – Merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
verbose (bool) – Whether to display the progress of the evaluation. (Default value = True)
progress_options (dict, optional) – Keyword arguments to pass to the default progression callback used in Poutyne (See ProgressionCallback for the available arguments). (Default value = None, meaning default color setting and progress bar)
callbacks (List[Callback]) – List of callbacks that will be called during testing. (Default value = None)
dataloader_kwargs (dict, optional) – Keyword arguments to pass to the PyTorch dataloaders created internally.
- Returns:
Depends on the value of concatenate_returns. By default (concatenate_returns is true), the data structures (tensor, tuple, list, dict) returned as predictions for the batches are merged together. In the merge, the tensors are converted into Numpy arrays and then concatenated together. If concatenate_returns is false, a list of the predictions for the batches is returned, with tensors converted into Numpy arrays.
- See:
DataLoader for details on batch_size, num_workers and collate_fn.
- predict_generator(generator, *, steps=None, has_ground_truth=False, return_ground_truth=False, concatenate_returns=True, convert_to_numpy=True, verbose=True, progress_options: dict | None = None, callbacks=None) → Any
Returns the predictions of the network given batches of samples x, where the tensors are converted into Numpy arrays.
- Parameters:
generator – Generator-like object for the dataset. The generator must yield a batch of samples. See the fit_generator() method for details on the types of generators supported. It should only yield the input data x and NOT the target y, unless has_ground_truth is true.
steps (int, optional) – Number of iterations done on generator. (Defaults to the number of steps needed to see the entire dataset)
has_ground_truth (bool, optional) – Whether the generator yields the target y. Automatically set to true if return_ground_truth is true. (Default value = False)
return_ground_truth (bool, optional) – Whether to return the ground truths. If true, has_ground_truth is automatically set to true. (Default value = False)
concatenate_returns (bool, optional) – Whether to concatenate the predictions or the ground truths when returning them. (Default value = True)
convert_to_numpy (bool, optional) – Whether to convert the predictions or ground truths into Numpy Arrays. (Default value = True)
verbose (bool) – Whether to display the progress of the evaluation. (Default value = True)
progress_options (dict, optional) – Keyword arguments to pass to the default progression callback used in Poutyne (See ProgressionCallback for the available arguments). (Default value = None, meaning default color setting and progress bar)
callbacks (List[Callback]) – List of callbacks that will be called during testing. (Default value = None)
- Returns:
Depends on the value of concatenate_returns. By default (concatenate_returns is true), the data structures (tensor, tuple, list, dict) returned as predictions for the batches are merged together. In the merge, the tensors are converted into Numpy arrays and then concatenated together. If concatenate_returns is false, a list of the predictions for the batches is returned, with tensors converted into Numpy arrays.
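The merge performed when concatenate_returns is true can be sketched in pure Python. The sketch assumes array.array from the standard library stands in for a Numpy array, so that lists, tuples and dicts can act as containers. concat_batches is a hypothetical name and this is not Poutyne's implementation. With concatenate_returns set to false, the result would simply be the list batch_preds itself.

```python
from array import array


def concat_batches(batches):
    """Merge per-batch predictions that share one nested structure.

    array.array leaves are concatenated; lists and tuples are merged
    position by position and dicts key by key, recursively.
    """
    first = batches[0]
    if isinstance(first, array):
        merged = array(first.typecode)
        for batch in batches:
            merged.extend(batch)
        return merged
    if isinstance(first, (list, tuple)):
        return type(first)(concat_batches([b[i] for b in batches]) for i in range(len(first)))
    if isinstance(first, dict):
        return {key: concat_batches([b[key] for b in batches]) for key in first}
    raise TypeError(f"unsupported prediction type: {type(first).__name__}")


# Two batches; the (hypothetical) network returns a dict with a tuple inside.
batch_preds = [
    {"logits": array("d", [0.1, 0.9]), "aux": (array("d", [1.0]), array("d", [2.0]))},
    {"logits": array("d", [0.7]), "aux": (array("d", [3.0]), array("d", [4.0]))},
]
merged = concat_batches(batch_preds)
print(merged["logits"].tolist())               # [0.1, 0.9, 0.7]
print([a.tolist() for a in merged["aux"]])     # [[1.0, 3.0], [2.0, 4.0]]
```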
- predict_on_batch(x, *, convert_to_numpy=True) → Any
Returns the predictions of the network given a batch x, where the tensors are converted into Numpy arrays.
- Parameters:
x – Input data as a batch.
convert_to_numpy (bool, optional) – Whether to convert the predictions into Numpy Arrays. (Default value = True)
- Returns:
Return the predictions in the format outputted by the model.
- evaluate(x, y, *, batch_size=32, return_pred=False, return_dict_format=False, convert_to_numpy=True, callbacks=None, verbose=True, progress_options: dict | None = None, dataloader_kwargs=None) → Tuple
Computes the loss and the metrics of the network on batches of samples and optionally returns the predictions.
- Parameters:
x (Union[Tensor, ndarray] or Union[tuple, list] of Union[Tensor, ndarray]) – Input to the model. Union[Tensor, ndarray] if the model has a single input. Union[tuple, list] of Union[Tensor, ndarray] if the model has multiple inputs.
y (Union[Tensor, ndarray] or Union[tuple, list] of Union[Tensor, ndarray]) – Target, corresponding ground truth. Union[Tensor, ndarray] if the model has a single output. Union[tuple, list] of Union[Tensor, ndarray] if the model has multiple outputs.
batch_size (int) – Number of samples given to the network at one time. (Default value = 32)
return_pred (bool, optional) – Whether to return the predictions. (Default value = False)
return_dict_format (bool, optional) – Whether to return the loss and metrics in a dict format or not. (Default value = False)
convert_to_numpy (bool, optional) – Whether to convert the predictions into Numpy arrays when return_pred is true. (Default value = True)
callbacks (List[Callback]) – List of callbacks that will be called during testing. (Default value = None)
verbose (bool) – Whether to display the progress of the evaluation. (Default value = True)
progress_options (dict, optional) – Keyword arguments to pass to the default progression callback used in Poutyne (See ProgressionCallback for the available arguments). (Default value = None, meaning default color setting and progress bar)
dataloader_kwargs (dict, optional) – Keyword arguments to pass to the PyTorch dataloaders created internally.
- Returns:
Tuple (loss, metrics, pred_y), where specific elements are omitted if not applicable. If only loss is applicable, then it is returned as a float.
If n > 1, metrics is a Numpy array of size n, where n is the number of batch metrics plus the number of epoch metrics. If n == 1, then metrics is a float. If n == 0, metrics is omitted. The first elements of metrics are the batch metrics, followed by the epoch metrics. See the evaluate_generator() method for examples with batch metrics and epoch metrics.
If return_pred is True, pred_y is the list of the predictions of each batch with tensors converted into Numpy arrays. It is otherwise omitted.
If return_dict_format is True, then loss and metrics are replaced by a dictionary as passed to on_test_end().
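Without return_dict_format, several metric values come back positionally: batch metrics first, then epoch metrics. The hypothetical pure-Python helper below labels them under that ordering, with the names supplied by the caller. It assumes n > 1, since a single metric is returned as a bare float.

```python
def metrics_to_dict(batch_metric_names, epoch_metric_names, metrics):
    """Hypothetical helper: label positional metric values.

    Follows the documented order: batch metrics first, then epoch metrics.
    """
    names = list(batch_metric_names) + list(epoch_metric_names)
    values = list(metrics)
    if len(names) != len(values):
        raise ValueError(f"expected {len(names)} metric values, got {len(values)}")
    return dict(zip(names, values))


# Made-up values for batch metrics 'acc' and 'f1' and epoch metric 'auroc'.
labelled = metrics_to_dict(["acc", "f1"], ["auroc"], [91.5, 0.88, 0.95])
print(labelled)  # {'acc': 91.5, 'f1': 0.88, 'auroc': 0.95}
```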
- evaluate_dataset(dataset, *, batch_size=32, steps=None, return_pred=False, return_ground_truth=False, return_dict_format=False, concatenate_returns=True, convert_to_numpy=True, callbacks=None, num_workers=0, collate_fn=None, dataloader_kwargs=None, verbose=True, progress_options: dict | None = None) → Tuple
Computes the loss and the metrics of the network on batches of samples and optionally returns the predictions.
- Parameters:
dataset (Dataset) – Dataset.
batch_size (int) – Number of samples given to the network at one time. (Default value = 32)
steps (int, optional) – Number of batches used for evaluation. (Defaults to the number of steps needed to see the entire dataset)
return_pred (bool, optional) – Whether to return the predictions. (Default value = False)
return_ground_truth (bool, optional) – Whether to return the ground truths. (Default value = False)
return_dict_format (bool, optional) – Whether to return the loss and metrics in a dict format or not. (Default value = False)
concatenate_returns (bool, optional) – Whether to concatenate the predictions or the ground truths when returning them. (Default value = True)
convert_to_numpy (bool, optional) – Whether to convert the predictions or ground truths into Numpy arrays when return_pred or return_ground_truth is true. (Default value = True)
callbacks (List[Callback]) – List of callbacks that will be called during testing. (Default value = None)
num_workers (int, optional) – How many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (Default value = 0)
collate_fn (Callable, optional) – Merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
dataloader_kwargs (dict, optional) – Keyword arguments to pass to the PyTorch dataloaders created internally.
verbose (bool) – Whether to display the progress of the evaluation. (Default value = True)
progress_options (dict, optional) – Keyword arguments to pass to the default progression callback used in Poutyne (See ProgressionCallback for the available arguments). (Default value = None, meaning default color setting and progress bar)
- Returns:
Tuple (loss, metrics, pred_y), where specific elements are omitted if not applicable. If only loss is applicable, then it is returned as a float.
If n > 1, metrics is a Numpy array of size n, where n is the number of batch metrics plus the number of epoch metrics. If n == 1, then metrics is a float. If n == 0, metrics is omitted. The first elements of metrics are the batch metrics, followed by the epoch metrics. See the evaluate_generator() method for examples with batch metrics and epoch metrics.
If return_pred is True, pred_y is the list of the predictions of each batch with tensors converted into Numpy arrays. It is otherwise omitted.
If return_dict_format is True, then loss and metrics are replaced by a dictionary as passed to on_test_end().
- See:
DataLoader for details on batch_size, num_workers and collate_fn.
- evaluate_generator(generator, *, steps=None, return_pred=False, return_ground_truth=False, return_dict_format=False, concatenate_returns=True, convert_to_numpy=True, verbose=True, progress_options: dict | None = None, callbacks=None) → Tuple
Computes the loss and the metrics of the network on batches of samples and optionally returns the predictions.
- Parameters:
generator – Generator-like object for the dataset. See the fit_generator() method for details on the types of generators supported.
steps (int, optional) – Number of iterations done on generator. (Defaults to the number of steps needed to see the entire dataset)
return_pred (bool, optional) – Whether to return the predictions. (Default value = False)
return_ground_truth (bool, optional) – Whether to return the ground truths. (Default value = False)
return_dict_format (bool, optional) – Whether to return the loss and metrics in a dict format or not. (Default value = False)
convert_to_numpy (bool, optional) – Whether to convert the predictions or ground truths into Numpy arrays when return_pred or return_ground_truth is true. (Default value = True)
concatenate_returns (bool, optional) – Whether to concatenate the predictions or the ground truths when returning them. (Default value = True)
verbose (bool) – Whether to display the progress of the evaluation. (Default value = True)
progress_options (dict, optional) – Keyword arguments to pass to the default progression callback used in Poutyne (See ProgressionCallback for the available arguments). (Default value = None, meaning default color setting and progress bar)
callbacks (List[Callback]) – List of callbacks that will be called during testing. (Default value = None)
- Returns:
Tuple (loss, metrics, pred_y, true_y), where specific elements are omitted if not applicable. If only loss is applicable, then it is returned as a float.
If n > 1, metrics is a Numpy array of size n, where n is the number of batch metrics plus the number of epoch metrics. If n == 1, then metrics is a float. If n == 0, metrics is omitted. The first elements of metrics are the batch metrics, followed by the epoch metrics.
If return_pred is True, pred_y is the predictions returned as in the predict_generator() method. It is otherwise omitted.
If return_ground_truth is True, true_y is the ground truths returned as in the predict_generator() method. It is otherwise omitted.
If return_dict_format is True, then loss and metrics are replaced by a dictionary as passed to on_test_end().
Example
With no metrics:
    model = Model(pytorch_network, optimizer, loss_function, batch_metrics=None)
    loss = model.evaluate_generator(test_generator)
With only one batch metric:
    model = Model(pytorch_network, optimizer, loss_function, batch_metrics=[my_metric_fn])
    loss, my_metric = model.evaluate_generator(test_generator)
With several batch metrics:
    model = Model(pytorch_network, optimizer, loss_function,
                  batch_metrics=[my_metric1_fn, my_metric2_fn])
    loss, (my_metric1, my_metric2) = model.evaluate_generator(test_generator)
With one batch metric and one epoch metric:
    model = Model(pytorch_network, optimizer, loss_function,
                  batch_metrics=[my_metric_fn], epoch_metrics=[MyMetricClass()])
    loss, (my_batch_metric, my_epoch_metric) = model.evaluate_generator(test_generator)
With batch metrics and the return_pred flag:
    model = Model(pytorch_network, optimizer, loss_function,
                  batch_metrics=[my_metric1_fn, my_metric2_fn])
    loss, (my_metric1, my_metric2), pred_y = model.evaluate_generator(
        test_generator, return_pred=True
    )
With batch metrics and the return_pred and return_ground_truth flags:
    model = Model(pytorch_network, optimizer, loss_function,
                  batch_metrics=[my_metric1_fn, my_metric2_fn])
    loss, (my_metric1, my_metric2), pred_y, true_y = model.evaluate_generator(
        test_generator, return_pred=True, return_ground_truth=True
    )
With return_dict_format:
    model = Model(pytorch_network, optimizer, loss_function, batch_metrics=[my_metric_fn])
    logs = model.evaluate_generator(test_generator, return_dict_format=True)
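The return layout described above follows a few fixed rules (batch metrics before epoch metrics, a single metric collapsed to a float, omitted elements dropped). The following is a minimal pure-Python sketch of those rules only. pack_evaluation_returns is a hypothetical helper, not part of Poutyne and not its implementation; a plain list stands in for the Numpy array Poutyne returns, and passing pred_y or true_y as None is assumed to mean the corresponding return flag is False.

```python
from typing import Any, List, Optional, Sequence, Union


def pack_evaluation_returns(
    loss: float,
    batch_metrics: Sequence[float],
    epoch_metrics: Sequence[float],
    pred_y: Optional[Any] = None,
    true_y: Optional[Any] = None,
) -> Union[float, tuple]:
    """Hypothetical helper mirroring the documented return layout.

    A plain list stands in for the Numpy array that Poutyne returns.
    None for pred_y/true_y is assumed to mean the flag was False.
    """
    # Batch metrics come first, followed by epoch metrics.
    metrics: List[float] = list(batch_metrics) + list(epoch_metrics)
    returns: List[Any] = [loss]
    if len(metrics) > 1:
        returns.append(metrics)      # n > 1: a sequence of size n
    elif len(metrics) == 1:
        returns.append(metrics[0])   # n == 1: a single float
    # n == 0: metrics are omitted entirely.
    if pred_y is not None:
        returns.append(pred_y)
    if true_y is not None:
        returns.append(true_y)
    # If only the loss is applicable, it is returned on its own.
    return returns[0] if len(returns) == 1 else tuple(returns)


print(pack_evaluation_returns(0.5, [], []))        # 0.5
print(pack_evaluation_returns(0.5, [0.9], []))     # (0.5, 0.9)
print(pack_evaluation_returns(0.5, [0.9], [0.8]))  # (0.5, [0.9, 0.8])
```

This is why the examples unpack a bare loss with no metrics, a single name with one metric, and a nested tuple with several.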
- evaluate_on_batch(x, y, *, return_pred=False, return_dict_format=False, convert_to_numpy=True) → Tuple [source]
Computes the loss and the metrics of the network on a single batch of samples and optionally returns the predictions.
- Parameters:
x – Input data as a batch.
y – Target data as a batch.
return_pred (bool, optional) – Whether to return the predictions for the batch. (Default value = False)
return_dict_format (bool, optional) – Whether to return the loss and metrics in a dict format or not. (Default value = False)
convert_to_numpy (bool, optional) – Whether to convert the predictions into Numpy arrays when return_pred is true. (Default value = True)
- Returns:
Tuple (loss, metrics, pred_y) where specific elements are omitted if not applicable. If only loss is applicable, then it is returned as a float.
metrics is a Numpy array of size n, where n is the number of metrics, if n > 1. If n == 1, then metrics is a float. If n == 0, then metrics is omitted.
If return_pred is True, pred_y is the predictions for the batch, with tensors converted into Numpy arrays when convert_to_numpy is True. It is otherwise omitted.
If return_dict_format is True, then loss and metrics are replaced by a dictionary.
- load_weights(f, strict=True)[source]
Loads the weights saved using the torch.save() method or the save_weights() method of this class. Contrary to torch.load(), the weights are not transferred to the device they were saved from. In other words, the PyTorch module stays on the device it is already on.
- Parameters:
f – File-like object (has to implement fileno that returns a file descriptor) or string containing a file name.
- Returns:
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
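To illustrate what the two fields report, here is a minimal pure-Python sketch that compares the parameter names a network expects with those found in a set of loaded weights. IncompatibleKeys and compare_keys are hypothetical names used only for illustration; the field names match the ones documented above (and PyTorch's own torch.nn.Module.load_state_dict reports the same two fields), but this is not Poutyne's code.

```python
from typing import Iterable, List, NamedTuple


class IncompatibleKeys(NamedTuple):
    """Hypothetical stand-in for the NamedTuple returned by load_weights()."""

    missing_keys: List[str]
    unexpected_keys: List[str]


def compare_keys(network_keys: Iterable[str], loaded_keys: Iterable[str]) -> IncompatibleKeys:
    """Hypothetical helper: report key mismatches, preserving input order."""
    network = list(network_keys)
    loaded = list(loaded_keys)
    network_set, loaded_set = set(network), set(loaded)
    # Keys the network expects but the loaded weights do not provide.
    missing = [k for k in network if k not in loaded_set]
    # Keys present in the loaded weights that the network does not have.
    unexpected = [k for k in loaded if k not in network_set]
    return IncompatibleKeys(missing_keys=missing, unexpected_keys=unexpected)


result = compare_keys(["fc.weight", "fc.bias"], ["fc.weight", "head.weight"])
print(result.missing_keys)     # ['fc.bias']
print(result.unexpected_keys)  # ['head.weight']
```

Both lists being empty means every key matched.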
- save_weights(f)[source]
Saves the weights of the current network.
- Parameters:
f – File-like object (has to implement fileno that returns a file descriptor) or string containing a file name.
- load_optimizer_state(f)[source]
Loads the optimizer state saved using the torch.save() method or the save_optimizer_state() method of this class.
- Parameters:
f – File-like object (has to implement fileno that returns a file descriptor) or string containing a file name.
- save_optimizer_state(f)[source]
Saves the state of the current optimizer.
- Parameters:
f – File-like object (has to implement fileno that returns a file descriptor) or string containing a file name.
- get_weights()[source]
Returns a dictionary containing the parameters of the network. The tensors are just references to the parameters. To get copies of the weights, see the get_weight_copies() method.
- get_weight_copies()[source]
Returns a dictionary containing copies of the parameters of the network.
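The distinction between references and copies matters when keeping a snapshot (for example, the best weights seen so far) while training continues. The sketch below is a pure-Python analogy, not Poutyne code: TinyNetwork is a hypothetical class in which lists play the role of parameter tensors and train_step mimics an in-place optimizer update.

```python
import copy
from typing import Dict, List


class TinyNetwork:
    """Hypothetical stand-in for a network; lists play the role of tensors."""

    def __init__(self) -> None:
        self.params: Dict[str, List[float]] = {"weight": [1.0, 2.0]}

    def get_weights(self) -> Dict[str, List[float]]:
        # New dict, but the values are the very same objects as the parameters.
        return dict(self.params)

    def get_weight_copies(self) -> Dict[str, List[float]]:
        # Independent copies that later updates will not touch.
        return copy.deepcopy(self.params)

    def train_step(self) -> None:
        # In-place update, like an optimizer step.
        for values in self.params.values():
            for i in range(len(values)):
                values[i] -= 0.5


net = TinyNetwork()
refs = net.get_weights()
snapshot = net.get_weight_copies()
net.train_step()
print(refs["weight"])      # [0.5, 1.5]
print(snapshot["weight"])  # [1.0, 2.0]
```

The references follow every update, whereas the copy keeps the values from the moment it was taken.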
- set_weights(weights, strict=True)[source]
Modifies the weights of the network with the given weights.
- Parameters:
weights (dict) – Weights returned by either get_weights() or get_weight_copies().
- Returns:
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
- cuda(*args, **kwargs)[source]
Transfers the network to the GPU. The arguments are passed to the torch.nn.Module.cuda() method. Note that the device is saved so that the batches can be sent to the right device before being passed to the network.
Note
PyTorch optimizers assume that the parameters have been transferred to the right device before the optimizers are created. Furthermore, future versions of PyTorch will no longer modify the parameters of a PyTorch module in-place when transferring them to another device. See this issue and this pull request for details.
Since Poutyne assumes that the optimizer has been initialized before the Poutyne Model, the parameters held by the optimizer are not guaranteed to be in sync with those of the PyTorch module once the module is transferred to another device. This method therefore takes care of this inconsistency by updating the parameters inside the optimizer.
- Returns:
self.
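The synchronization problem described in the note can be sketched in pure Python: an optimizer keeps references to parameter objects, so if a transfer creates new parameter objects instead of modifying them in place, the optimizer is left pointing at the stale ones until they are swapped. Param, transfer and sync_optimizer are hypothetical names for illustration only; the param_groups layout (a list of dicts with a "params" key) mirrors torch.optim.Optimizer.param_groups, but this is not Poutyne's implementation.

```python
from typing import Dict, List


class Param:
    """Hypothetical parameter object; `device` is just a label here."""

    def __init__(self, name: str, device: str) -> None:
        self.name = name
        self.device = device


def transfer(params: Dict[str, Param], device: str) -> Dict[str, Param]:
    # Non-in-place transfer: every parameter becomes a brand-new object.
    return {name: Param(name, device) for name in params}


def sync_optimizer(param_groups: List[Dict], old_to_new: Dict[int, Param]) -> None:
    # Replace each stale reference held by the optimizer with the new object.
    for group in param_groups:
        group["params"] = [old_to_new.get(id(p), p) for p in group["params"]]


module_params = {"weight": Param("weight", "cpu")}
param_groups = [{"params": list(module_params.values()), "lr": 0.1}]

new_params = transfer(module_params, "cuda:0")
old_to_new = {id(module_params[n]): new_params[n] for n in module_params}
print(param_groups[0]["params"][0].device)  # cpu (out of sync)
sync_optimizer(param_groups, old_to_new)
print(param_groups[0]["params"][0].device)  # cuda:0
```

Other entries of each group (here lr) are left untouched; only the parameter references are swapped.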
- cpu(*args, **kwargs)[source]
Transfers the network to the CPU. The arguments are passed to the torch.nn.Module.cpu() method. Note that the device is saved so that the batches can be sent to the right device before being passed to the network.
Note
PyTorch optimizers assume that the parameters have been transferred to the right device before the optimizers are created. Furthermore, future versions of PyTorch will no longer modify the parameters of a PyTorch module in-place when transferring them to another device. See this issue and this pull request for details.
Since Poutyne assumes that the optimizer has been initialized before the Poutyne Model, the parameters held by the optimizer are not guaranteed to be in sync with those of the PyTorch module once the module is transferred to another device. This method therefore takes care of this inconsistency by updating the parameters inside the optimizer.
- Returns:
self.
- to(device)[source]
Transfers the network to the specified device. The device is saved so that the batches can be sent to the right device before being passed to the network. Multiple GPUs can also be used by passing either a list of devices or “all” to use all the available devices. In both cases, the training loop uses the torch.nn.parallel.data_parallel() function for single-node multi-GPU parallelism, and the main device is the first device.
Note
PyTorch optimizers assume that the parameters have been transferred to the right device before the optimizers are created. Furthermore, future versions of PyTorch will no longer modify the parameters of a PyTorch module in-place when transferring them to another device. See this issue and this pull request for details.
Since Poutyne assumes that the optimizer has been initialized before the Poutyne Model, the parameters held by the optimizer are not guaranteed to be in sync with those of the PyTorch module once the module is transferred to another device. This method therefore takes care of this inconsistency by updating the parameters inside the optimizer.
- Parameters:
device (Union[torch.device, List[torch.device]]) – The device to which the network is sent, or the list of devices to which the network is sent.
- Returns:
self.
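The device rules for to() (a single device, an explicit list, or “all”, with the first device acting as the main device) can be summarized by a small pure-Python sketch. resolve_devices is a hypothetical helper and plain strings stand in for torch.device objects; Poutyne's actual handling, including how available devices are discovered, may differ.

```python
from typing import List, Sequence, Tuple, Union


def resolve_devices(
    device: Union[str, Sequence[str]], available: Sequence[str]
) -> Tuple[str, List[str]]:
    """Hypothetical helper: return (main_device, devices) per the rules above."""
    if isinstance(device, str):
        if device == "all":
            devices = list(available)  # every available device
        else:
            devices = [device]         # a single device
    else:
        devices = list(device)         # an explicit list of devices
    if not devices:
        raise ValueError("at least one device is required")
    # The first device acts as the main device.
    return devices[0], devices


print(resolve_devices("all", ["cuda:0", "cuda:1"]))  # ('cuda:0', ['cuda:0', 'cuda:1'])
print(resolve_devices(["cuda:1", "cuda:0"], []))     # ('cuda:1', ['cuda:1', 'cuda:0'])
```

With more than one resolved device, the training loop would then dispatch through data parallelism, as described above; with one, it runs on that device alone.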