Basic class for handling the training loop
from nbdev.showdoc import *

We'll use the following for testing purposes (a basic linear regression problem):

from torch.utils.data import TensorDataset

def synth_dbunch(a=2, b=3, bs=16, n_train=10, n_valid=2, cuda=False):
    "Build `DataLoaders` for a synthetic linear-regression problem `y = a*x + b + noise`."
    def _make_ds(n_batches):
        # n_batches batches of size bs -> bs*n_batches scalar samples
        xs = torch.randn(int(bs * n_batches))
        ys = a * xs + b + 0.1 * torch.randn(int(bs * n_batches))
        return TensorDataset(xs, ys)
    train_ds, valid_ds = _make_ds(n_train), _make_ds(n_valid)
    dev = default_device() if cuda else None
    dls = [TfmdDL(train_ds, bs=bs, shuffle=True, num_workers=0),
           TfmdDL(valid_ds, bs=bs, num_workers=0)]
    return DataLoaders(*dls, device=dev)

class RegModel(Module):
    "Minimal scalar linear model `y = a*x + b` with learnable `a` and `b`."
    def __init__(self):
        self.a = nn.Parameter(torch.randn(1))
        self.b = nn.Parameter(torch.randn(1))

    def forward(self, x):
        return self.a * x + self.b


replacing_yield(o, attr, val)

Context manager to temporarily replace an attribute



Convert m to an AvgMetric, unless it's already a Metric


save_model(file, model, opt, with_opt=True)

Save model to file along with opt (if available, and if with_opt)


load_model(file, model, opt, with_opt=None, device=None, strict=True)

Load model from file along with opt (if available, and if with_opt)

class Learner[source]

Learner(dls, model, loss_func=None, opt_func='Adam', lr=0.001, splitter='trainable_params', cbs=None, metrics=None, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95, 0.85, 0.95))

Group together a model, some dls and a loss_func to handle training

opt_func will be used to create an optimizer when Learner.fit is called, with lr as a learning rate. splitter is a function that takes self.model and returns a list of parameter groups (or just one parameter group if there are no different parameter groups). The default is trainable_params, which returns all trainable parameters of the model.

cbs is one or a list of Callbacks to pass to the Learner. Each Callback is registered as an attribute of Learner (with camel case). At creation, all the callbacks in defaults.callbacks (TrainEvalCallback and Recorder) are associated to the Learner.

metrics is an optional list of metrics, that can be either functions or Metrics (see below).

Training loop

#Test init with callbacks
class TstCallback(Callback):
    "Increment `learn.a` by one at the beginning of every batch."
    def batch_begin(self):
        self.learn.a = self.a + 1

def synth_learner(n_train=10, n_valid=2, cuda=False, lr=1e-3, **kwargs):
    "Create a `Learner` on synthetic linear-regression data for testing."
    # Original line had a doubled comma (syntax error) and used `lr` without
    # declaring it; the fastai source declares `lr=1e-3` here.
    data = synth_dbunch(n_train=n_train, n_valid=n_valid, cuda=cuda)
    return Learner(data, RegModel(), loss_func=MSELossFlat(), lr=lr, **kwargs)

tst_learn = synth_learner()
test_eq(len(tst_learn.cbs), 1)
assert isinstance(tst_learn.cbs[0], TrainEvalCallback)
assert hasattr(tst_learn, ('train_eval'))

tst_learn = synth_learner(cbs=TstCallback())
test_eq(len(tst_learn.cbs), 2)
assert isinstance(tst_learn.cbs[1], TstCallback)
assert hasattr(tst_learn, ('tst'))

tst_learn = synth_learner(cbs=TstCallback)
test_eq(len(tst_learn.cbs), 2)
assert isinstance(tst_learn.cbs[1], TstCallback)
assert hasattr(tst_learn, ('tst'))

#A name that becomes an existing attribute of the Learner will throw an exception (here add_cb)
class AddCbCallback(Callback):
    "Callback whose snake-cased name clashes with `Learner.add_cb`, used to test the name-collision check."
test_fail(lambda: synth_learner(cbs=AddCbCallback()))

Learner.fit[source]

Learner.fit(n_epoch, lr=None, wd=None, cbs=None, reset_opt=False)

Fit self.model for n_epoch using cbs. Optionally reset_opt.

#Training a few epochs should make the model better
learn = synth_learner(cbs=TstCallback, lr=1e-2)
learn.model = learn.model.cpu()
xb,yb = learn.dls.one_batch()
init_loss = learn.loss_func(learn.model(xb), yb)
learn.fit(6)
assert learn.loss < init_loss


Learner.one_batch(i, b)

Train or evaluate self.model on batch (xb,yb)

This is an internal method called by Learner.fit. If passed, i is the index of this iteration in the epoch. In training mode, this does a full training step on the batch (compute predictions, loss, gradients, update the model parameters and zero the gradients). In validation mode, it stops at the loss computation.

class VerboseCallback[source]

VerboseCallback() :: Callback

Callback that prints the name of each event called



Train or evaluate self.model on all batches of self.dl

Serializing

Learner.save(file, with_opt=True)

Save model and optimizer state (if with_opt) to self.path/self.model_dir/file

file can be a Path, a string or a buffer.


Learner.load(file, with_opt=None, device=None, strict=True)

Load model and optimizer state (if with_opt) from self.path/self.model_dir/file using device

file can be a Path, a string or a buffer. Use device to load the model/optimizer state on a device different from the one it was saved.

learn = synth_learner(cbs=TstCallback, opt_func=partial(SGD, mom=0.9))
xb,yb = learn.dls.one_batch()
init_loss = learn.loss_func(learn.model(xb), yb)
learn.save('tmp')
assert (Path.cwd()/'models/tmp.pth').exists()

learn1 = synth_learner(cbs=TstCallback, opt_func=partial(SGD, mom=0.9))
learn1 = learn1.load('tmp')
test_eq(learn.model.a, learn1.model.a)
test_eq(learn.model.b, learn1.model.b)
test_eq(learn.opt.state_dict(), learn1.opt.state_dict())

learn.save('tmp1', with_opt=False)
learn1 = synth_learner(cbs=TstCallback, opt_func=partial(SGD, mom=0.9))
learn1 = learn1.load('tmp1')
test_eq(learn.model.a, learn1.model.a)
test_eq(learn.model.b, learn1.model.b)
test_ne(learn.opt.state_dict(), learn1.opt.state_dict())


Callback handling



Call self as a function.



Add cb to the list of Callback and register self as their learner

learn = synth_learner()
learn.add_cb(TestTrainEvalCallback())
test_eq(len(learn.cbs), 2)
assert isinstance(learn.cbs[1], TestTrainEvalCallback)
test_eq(learn.train_eval.learn, learn)



Add cbs to the list of Callback and register self as their learner

learn.add_cbs([TestTrainEvalCallback(), TestTrainEvalCallback()])
test_eq(len(learn.cbs), 4)



Add cb from the list of Callback and deregister self as their learner

cb = learn.cbs[1]
learn.remove_cb(learn.cbs[1])
test_eq(len(learn.cbs), 3)
assert cb.learn is None
assert not getattr(learn,'test_train_eval',None)



Remove cbs from the list of Callback and deregister self as their learner

cb = learn.cbs[1]
learn.remove_cbs(learn.cbs[1:])
test_eq(len(learn.cbs), 1)

When writing a callback, the following attributes of Learner are available:

  • model: the model used for training/validation
  • data: the underlying DataLoaders
  • loss_func: the loss function used
  • opt: the optimizer used to update the model parameters
  • opt_func: the function used to create the optimizer
  • cbs: the list containing all Callbacks
  • dl: current DataLoader used for iteration
  • x/xb: last input drawn from self.dl (potentially modified by callbacks). xb is always a tuple (potentially with one element) and x is detuplified. You can only assign to xb.
  • y/yb: last target drawn from self.dl (potentially modified by callbacks). yb is always a tuple (potentially with one element) and y is detuplified. You can only assign to yb.
  • pred: last predictions from self.model (potentially modified by callbacks)
  • loss: last computed loss (potentially modified by callbacks)
  • n_epoch: the number of epochs in this training
  • n_iter: the number of iterations in the current self.dl
  • epoch: the current epoch index (from 0 to n_epoch-1)
  • iter: the current iteration index in self.dl (from 0 to n_iter-1)

The following attributes are added by TrainEvalCallback and should be available unless you went out of your way to remove that callback:

  • train_iter: the number of training iterations done since the beginning of this training
  • pct_train: from 0. to 1., the percentage of training iterations completed
  • training: flag to indicate if we're in training mode or not

The following attribute is added by Recorder and should be available unless you went out of your way to remove that callback:

  • smooth_loss: an exponentially-averaged version of the training loss

Control flow testing

class Metric[source]


Blueprint for defining a metric

Metrics can be simple averages (like accuracy) but sometimes their computation is a little bit more complex and can't be averaged over batches (like precision or recall), which is why we need a special class for them. For simple functions that can be computed as averages over batches, we can use the class AvgMetric, otherwise you'll need to implement the following methods. {% include note.html content='If your Metric has state depending on tensors, don’t forget to store it on the CPU to avoid any potential memory leaks.' %}



Reset inner state to prepare for new computation



Use learn to update the state with new results


The value of the metric

Metric.name[source]

Name of the Metric, camel-cased and with Metric removed

class AvgMetric[source]

AvgMetric(func) :: Metric

Average the values of func taking into account potential different batch sizes

learn = synth_learner()
tst = AvgMetric(lambda x,y: (x-y).abs().mean())
t,u = torch.randn(100),torch.randn(100)
for i in range(0,100,25): 
    learn.pred,learn.yb = t[i:i+25],(u[i:i+25],)
    tst.accumulate(learn)
test_close(tst.value, (t-u).abs().mean())

class AvgLoss[source]

AvgLoss() :: Metric

Average the losses taking into account potential different batch sizes

tst = AvgLoss()
t = torch.randn(100)
for i in range(0,100,25): 
    learn.yb,learn.loss = t[i:i+25],t[i:i+25].mean()
    tst.accumulate(learn)
test_close(tst.value, t.mean())

class AvgSmoothLoss[source]

AvgSmoothLoss(beta=0.98) :: Metric

Smooth average of the losses (exponentially weighted with beta)

tst = AvgSmoothLoss()
t = torch.randn(100)
val = tensor(0.)
for i in range(4): 
    learn.loss = t[i*25:(i+1)*25].mean()
    tst.accumulate(learn)
    val = val*0.98 + t[i*25:(i+1)*25].mean()*(1-0.98)
    test_close(val/(1-0.98**(i+1)), tst.value)

class Recorder[source]

Recorder(add_time=True, train_metrics=False, valid_metrics=True, beta=0.98) :: Callback

Callback that registers statistics (lr, loss and metrics) during training

By default, metrics are computed on the validation set only, although that can be changed with train_metrics=True. beta is the weight used to compute the exponentially weighted average of the losses (which gives the smooth_loss attribute to Learner).

#Test printed output
def tst_metric(out, targ): return F.mse_loss(out, targ)
learn = synth_learner(n_train=5, metrics=tst_metric)
pat = r"[tensor\(\d.\d*\), tensor\(\d.\d*\), tensor\(\d.\d*\), '\d\d:\d\d']"
test_stdout(lambda: learn.fit(1), pat, regex=True)

Callback internals



Prepare state for training



Set timer if self.add_time=True



Reset loss and metrics state



Update all metrics and records lr and smooth loss in training



Store and log the loss/metric values

Plotting tools


Recorder.plot_loss(skip_start=5, with_valid=True)

Plot the losses from skip_start and onward

class FetchPreds[source]

FetchPreds(ds_idx=1, dl=None, with_input=False, with_decoded=False) :: Callback

A callback to fetch predictions during the training loop

Inference functions



Context manager to temporarily remove logger

learn = synth_learner(n_train=5, metrics=tst_metric)
with learn.no_logging():
    test_stdout(lambda: learn.fit(1), '')
test_eq(learn.logger, print)


Learner.validate(ds_idx=1, dl=None, cbs=None)

Validate on dl with potential new cbs.

#Test result
learn = synth_learner(n_train=5, metrics=tst_metric)
res = learn.validate()
test_eq(res[0], res[1])
x,y = learn.dls.valid_ds.tensors
test_close(res[0], F.mse_loss(learn.model(x), y))



A context manager to evaluate loss_func with reduction set to none.


Learner.get_preds(ds_idx=1, dl=None, with_input=False, with_decoded=False, with_loss=False, act=None, inner=False, save_preds=None, save_targs=None, concat_dim=0)

Get the predictions and targets on the ds_idx-th dbunchset or dl, optionally with_input and with_loss

Depending on the loss_func attribute of Learner, an activation function will be picked automatically so that the predictions make sense. For instance if the loss is a case of cross-entropy, a softmax will be applied, or if the loss is binary cross entropy with logits, a sigmoid will be applied. If you want to make sure a certain activation function is applied, you can pass it with act.

{% include note.html content='If you want to use the option with_loss=True on a custom loss function, make sure you have implemented a reduction attribute that supports ’none’ ' %}

#Test result
learn = synth_learner(n_train=5, metrics=tst_metric)
preds,targs = learn.get_preds()
x,y = learn.dls.valid_ds.tensors
test_eq(targs, y)
test_close(preds, learn.model(x))

preds,targs = learn.get_preds(act = torch.sigmoid)
test_eq(targs, y)
test_close(preds, torch.sigmoid(learn.model(x)))
#Test get_preds work with ds not evenly divisible by bs
learn = synth_learner(n_train=2.5, metrics=tst_metric)
preds,targs = learn.get_preds(ds_idx=0)
inps,preds,targs = learn.get_preds(ds_idx=0, with_input=True)
tst = learn.get_preds(ds_idx=0, with_input=True, with_decoded=True)


Learner.predict(item, rm_type_tfms=None, with_input=False)

Return the prediction on item, fully decoded, loss function decoded and probabilities

It returns a tuple of three elements with, in reverse order,

  • the prediction from the model, potentially passed through the activation of the loss function (if it has one)
  • the decoded prediction, using the potential decodes method from it
  • the fully decoded prediction, using the transforms used to build the Datasets/DataLoaders
class _FakeLossFunc(Module):
    "Stub loss exposing `activation` and `decodes` hooks to test `Learner.predict` decoding."
    reduction = 'none'

    def forward(self, x, y):
        return F.mse_loss(x, y)

    def activation(self, x):
        return x + 1

    def decodes(self, x):
        return 2 * x

class _Add1(Transform):
    "Reversible transform: encoding adds 1, decoding subtracts it back."
    def encodes(self, x):
        return x + 1

    def decodes(self, x):
        return x - 1
learn = synth_learner(n_train=5)
dl = TfmdDL(Datasets(torch.arange(50), tfms = [L(), [_Add1()]]))
learn.dls = DataLoaders(dl, dl)
learn.loss_func = _FakeLossFunc()

inp = tensor([2.])
out = learn.model(inp).detach()+1  #applying model + activation
dec = 2*out                        #decodes from loss function
full_dec = dec-1                   #decodes from _Add1
test_eq(learn.predict(inp), [full_dec,dec,out])
test_eq(learn.predict(inp, with_input=True), [inp,full_dec,dec,out])

Transfer learning



Freeze parameter groups up to n



Freeze up to last parameter group



Unfreeze the entire model

Exporting a Learner



Export the content of self without the items and the optimizer state for inference


load_learner(fname, cpu=True)

Load a Learner object in fname, optionally putting it on the cpu



Learner.tta(ds_idx=1, dl=None, n=4, item_tfms=None, batch_tfms=None, beta=0.25, use_max=False)

Return predictions on the ds_idx dataset or dl using Test Time Augmentation

In practice, we get the predictions n times with the transforms of the training set and average those. The final predictions are (1-beta) multiplied by this average + beta multiplied by the predictions obtained with the transforms of the dataset. Set beta to None to get a tuple of the predictions and tta results.