```python
from nbdev.showdoc import *
```
The training loop is defined in `Learner` a bit below and consists of a minimal set of instructions. Looping through the data, we:
- compute the output of the model from the input
- calculate a loss between this output and the desired target
- compute the gradients of this loss with respect to all the model parameters
- update the parameters accordingly
- zero all the gradients
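As a rough sketch, those five steps look like the loop below. This assumes plain PyTorch-style objects; `model`, `loss_func`, `opt` and `dl` are illustrative names, not fastai API:

```python
import torch
from torch import nn

# Illustrative setup (not fastai API): a tiny model, loss, optimizer and
# a stand-in for a DataLoader.
model = nn.Linear(10, 1)
loss_func = nn.MSELoss()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
dl = [(torch.randn(4, 10), torch.randn(4, 1)) for _ in range(3)]

for xb, yb in dl:
    pred = model(xb)            # 1. compute the output of the model from the input
    loss = loss_func(pred, yb)  # 2. calculate a loss between output and target
    loss.backward()             # 3. compute the gradients w.r.t. the parameters
    opt.step()                  # 4. update the parameters
    opt.zero_grad()             # 5. zero all the gradients
```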
Any tweak of this training loop is defined in a `Callback` to avoid over-complicating the code of the training loop, and to make it easy to mix and match different techniques (since they'll be defined in different callbacks). A callback can implement actions on the following events:
- `begin_fit`: called before doing anything; ideal for initial setup.
- `begin_epoch`: called at the beginning of each epoch; useful for any behavior you need to reset at each epoch.
- `begin_train`: called at the beginning of the training part of an epoch.
- `begin_batch`: called at the beginning of each batch, just after drawing said batch. It can be used to do any setup necessary for the batch (like hyper-parameter scheduling) or to change the input/target before it goes into the model (changing the input with techniques like mixup, for instance).
- `after_pred`: called after computing the output of the model on the batch. It can be used to change that output before it's fed to the loss.
- `after_loss`: called after the loss has been computed, but before the backward pass. It can be used to add any penalty to the loss (AR or TAR in RNN training, for instance).
- `after_backward`: called after the backward pass, but before the update of the parameters. It can be used to make any change to the gradients before said update (gradient clipping, for instance).
- `after_step`: called after the step and before the gradients are zeroed.
- `after_batch`: called at the end of a batch, for any clean-up before the next one.
- `after_train`: called at the end of the training phase of an epoch.
- `begin_validate`: called at the beginning of the validation phase of an epoch; useful for any setup needed specifically for validation.
- `after_validate`: called at the end of the validation part of an epoch.
- `after_epoch`: called at the end of an epoch, for any clean-up before the next one.
- `after_fit`: called at the end of training, for final clean-up.
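For instance, here is a hypothetical callback using a few of these events to time each epoch. `TimerCallback` is an illustrative name, not part of the library:

```python
import time

# Sketch of a callback: record how long each epoch takes, using the
# begin_fit/begin_epoch/after_epoch events described above.
class TimerCallback(Callback):
    def begin_fit(self):   self.epoch_times = []
    def begin_epoch(self): self._start = time.time()
    def after_epoch(self):
        self.epoch_times.append(time.time() - self._start)
```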
tst_cb = Callback() tst_cb.call_me = lambda: print("maybe") test_stdout(lambda: tst_cb("call_me"), "maybe")
This is a shortcut to avoid having to write `self.learn.bla` for any `bla` attribute we seek; we can just write `self.bla`:
```python
mk_class('TstLearner', 'a')

class TstCallback(Callback):
    def batch_begin(self): print(self.a)

learn,cb = TstLearner(1),TstCallback()
cb.learn = learn
test_stdout(lambda: cb('batch_begin'), "1")
```
Note that this shortcut only works for getting the value of the attribute; if you want to change it, you have to access it manually with `self.learn.bla`. In the example below, `self.a += 1` creates an `a` attribute of 2 in the callback instead of setting the `a` of the learner to 2. It also issues a warning that something is probably wrong:
```python
class TstCallback(Callback):
    def batch_begin(self): self.a += 1

learn,cb = TstLearner(1),TstCallback()
cb.learn = learn
cb('batch_begin')
test_eq(cb.a, 2)
test_eq(cb.learn.a, 1)
```
```
/home/sgugger/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:16: UserWarning: You are setting an attribute (a) that also exists in the learner. Please be advised that you're not setting it in the learner but in the callback. Use `self.learn.a` if you would like to change it in the learner.
  app.launch_new_instance()
```
A proper version needs to write `self.learn.a = self.a + 1`:
```python
class TstCallback(Callback):
    def batch_begin(self): self.learn.a = self.a + 1

learn,cb = TstLearner(1),TstCallback()
cb.learn = learn
cb('batch_begin')
test_eq(cb.learn.a, 2)
```
The `name` of a callback is the snake-cased version of its class name, with the "Callback" suffix removed:

```python
test_eq(TstCallback().name, 'tst')

class ComplicatedNameCallback(Callback): pass

test_eq(ComplicatedNameCallback().name, 'complicated_name')
```
When writing a callback, the following attributes of `Learner` are available:
- `model`: the model used for training/validation
- `data`: the underlying `DataBunch`
- `loss_func`: the loss function used
- `opt`: the optimizer used to update the model parameters
- `opt_func`: the function used to create the optimizer
- `cbs`: the list containing all the `Callback`s
- `dl`: the current `DataLoader` used for iteration
- `x`/`xb`: last input drawn from `self.dl` (potentially modified by callbacks). `xb` is always a tuple (potentially with one element) and `x` is detuplified. You can only assign to `xb`.
- `y`/`yb`: last target drawn from `self.dl` (potentially modified by callbacks). `yb` is always a tuple (potentially with one element) and `y` is detuplified. You can only assign to `yb`.
- `pred`: last predictions from `self.model` (potentially modified by callbacks)
- `loss`: last computed loss (potentially modified by callbacks)
- `n_epoch`: the number of epochs in this training
- `n_iter`: the number of iterations in the current `self.dl`
- `epoch`: the current epoch index (from 0 to `n_epoch-1`)
- `iter`: the current iteration index in `self.dl` (from 0 to `n_iter-1`)
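For instance, a gradient-clipping callback only needs `self.model` from the list above and the `after_backward` event. This is a hedged sketch, not the library's implementation; `ClipGradSketch` and `max_norm` are illustrative names:

```python
import torch.nn as nn

# after_backward runs once the gradients exist but before the optimizer
# step, so it's the right place to clip them.
class ClipGradSketch(Callback):
    def __init__(self, max_norm=1.): self.max_norm = max_norm
    def after_backward(self):
        nn.utils.clip_grad_norm_(self.model.parameters(), self.max_norm)
```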
The following attributes are added by `TrainEvalCallback` and should be available unless you went out of your way to remove that callback:
- `train_iter`: the number of training iterations done since the beginning of this training
- `pct_train`: the percentage of training iterations completed, from 0. to 1.
- `training`: flag to indicate whether we're in training mode or not
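As an example, `pct_train` is enough to write a simple hyper-parameter schedule. This sketch assumes a plain PyTorch-style optimizer exposing `param_groups`; fastai's own scheduling callbacks work differently, and `LinearDecaySketch` is an illustrative name:

```python
# Linearly decay the learning rate over training using pct_train, set in
# begin_batch as described in the events list above.
class LinearDecaySketch(Callback):
    def __init__(self, lr_max=1e-3): self.lr_max = lr_max
    def begin_batch(self):
        if not self.training: return          # only schedule during training
        lr = self.lr_max * (1 - self.pct_train)
        for pg in self.opt.param_groups: pg['lr'] = lr
```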
The following attribute is added by `Recorder` and should be available unless you went out of your way to remove that callback:

- `smooth_loss`: an exponentially-averaged version of the training loss
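As a quick illustration (assuming `Recorder` is in the callback list, as it is by default), a callback can read `smooth_loss` directly; `ReportSmoothLoss` is an illustrative name:

```python
# Tiny sketch: print the smoothed training loss at the end of each epoch.
class ReportSmoothLoss(Callback):
    def after_epoch(self):
        print(f"epoch {self.epoch}: smooth_loss {float(self.smooth_loss):.4f}")
```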
It happens that we may want to skip some of the steps of the training loop: in gradient accumulation, for instance, we don't always want to do the step/zeroing of the grads. During an LR finder test, we don't want to do the validation phase of an epoch. Or if we're training with an early-stopping strategy, we want to be able to completely interrupt the training loop.
This is made possible by raising specific exceptions that the training loop will look for (and properly catch).
You can detect that one of those exceptions has occurred and add code that executes right after it with the following events:
- `after_cancel_batch`: reached immediately after a `CancelBatchException` before proceeding to `after_batch`
- `after_cancel_train`: reached immediately after a `CancelTrainException` before proceeding to `after_epoch`
- `after_cancel_validate`: reached immediately after a `CancelValidException` before proceeding to `after_epoch`
- `after_cancel_epoch`: reached immediately after a `CancelEpochException` before proceeding to `after_epoch`
- `after_cancel_fit`: reached immediately after a `CancelFitException` before proceeding to `after_fit`
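For instance, a minimal early-stopping sketch (illustrative, not the library's implementation) can raise `CancelFitException` when the smoothed loss stops improving; it assumes `CancelFitException` and `Recorder` are in scope, as defined in this module:

```python
# Hedged sketch: interrupt training when smooth_loss has not improved for
# `patience` epochs in a row (both name and logic are illustrative).
class EarlyStopSketch(Callback):
    def __init__(self, patience=2): self.patience = patience
    def begin_fit(self): self.best, self.wait = float('inf'), 0
    def after_epoch(self):
        loss = float(self.smooth_loss)
        if loss < self.best: self.best, self.wait = loss, 0
        else:
            self.wait += 1
            if self.wait >= self.patience: raise CancelFitException()
```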
Here's the full list: `begin_fit`, `begin_epoch`, `begin_train`, `begin_batch`, `after_pred`, `after_loss`, `after_backward`, `after_step`, `after_cancel_batch`, `after_batch`, `after_cancel_train`, `after_train`, `begin_validate`, `after_cancel_validate`, `after_validate`, `after_cancel_epoch`, `after_epoch`, `after_cancel_fit`, `after_fit`.