class _A:
    def __init__(self, a): self.a = a
    @contextmanager
    def a_changed(self, v): return replacing_yield(self, 'a', v)

a = _A(42)
with a.a_changed(32):
    test_eq(a.a, 32)
test_eq(a.a, 42)
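For readers unfamiliar with the pattern, here is a minimal sketch of how a helper like replacing_yield could work: it temporarily swaps an attribute and restores the old value on exit. The names replacing_yield_sketch and Box are hypothetical, and the real fastai helper may differ in detail.

```python
from contextlib import contextmanager

def replacing_yield_sketch(obj, attr, val):
    """Generator that sets obj.attr to val, yields, then restores the old value."""
    old = getattr(obj, attr)
    setattr(obj, attr, val)
    try:
        yield
    finally:
        # Restore the original value even if the `with` body raises.
        setattr(obj, attr, old)

class Box:
    def __init__(self, a): self.a = a

    @contextmanager
    def a_changed(self, v):
        # contextmanager drives whatever generator the decorated function
        # returns, so returning the helper's generator is enough.
        return replacing_yield_sketch(self, 'a', v)

box = Box(42)
with box.a_changed(32):
    inside = box.a   # value seen inside the block
after = box.a        # value seen once the block has exited
```

The try/finally is the important design choice: without it, an exception inside the block would leave the temporary value in place.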
file can be a Path object, a string or an opened file object.
pickle_protocol is passed along to
file can be a Path object, a string or an opened file object. If a device is passed, the model is loaded on it, otherwise it's loaded on the CPU.
True, the file must exactly contain weights for every parameter key in
False, only the keys that are in the saved model are loaded in
Group together a dls and a loss_func to handle training
opt_func will be used to create an optimizer when Learner.fit is called, with lr as a default learning rate.
splitter is a function that takes self.model and returns a list of parameter groups (or just one parameter group if there are no different parameter groups). The default is trainable_params, which returns all trainable parameters of the model.
cbs is one or a list of Callbacks to pass to the
Callbacks are used for every tweak of the training loop. Each Callback is registered as an attribute of Learner (in snake case, for example train_eval for TrainEvalCallback). At creation, all the callbacks in
ProgressCallback) are associated to the
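The attribute name under which a callback is registered is derived from its class name. Judging from the tests further down this page (train_eval for TrainEvalCallback, tst for TstCallback, test_train_eval for TestTrainEvalCallback), the conversion can be sketched as below. The helper callback_attr_name is hypothetical and only approximates what the library does.

```python
import re

def callback_attr_name(cls_name):
    """Hypothetical helper: 'TrainEvalCallback' -> 'train_eval'."""
    # Drop a trailing 'Callback' suffix, if present.
    base = re.sub(r'Callback$', '', cls_name)
    # Insert underscores at word boundaries, then lowercase everything.
    s = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', base)
    s = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', s)
    return s.lower()

names = [callback_attr_name(n)
         for n in ('TrainEvalCallback', 'TstCallback', 'TestTrainEvalCallback')]
```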
model_dir are used to save and/or load models. Often path will be inferred from dls, but you can override it or pass a Path object to model_dir. Make sure you can write in
train_bn controls if BatchNorm layers are trained even when they are supposed to be frozen according to the splitter. Our empirical experiments have shown that it's the best behavior for those layers in transfer learning.
You can use regular PyTorch functionality for most of the arguments of the Learner, although the experience will be smoother with pure fastai objects and you will be able to use the full functionality of the library. The expectation is that the training loop will work smoothly even if you did not use fastai end to end. What you might lose are interpretation objects or showing functionality. The list below explains how to use plain PyTorch objects for all the arguments and what you might lose.
The most important is opt_func. If you are not using a fastai optimizer, you will need to write a function that wraps your PyTorch optimizer in an OptimWrapper. See the optimizer module for more details. This is to ensure the library's schedulers/freeze API work with your code.
DataLoaders object, that you can create from standard PyTorch dataloaders. By doing so, you will lose all showing functionality like show_results. You can check the data block API or the mid-level data API tutorial to learn how to use fastai to gather your data!
model is a standard PyTorch model. You can use any one you like, just make sure it accepts the number of inputs you have in your DataLoaders and returns as many outputs as you have targets.
loss_func can be any loss function you like. It needs to be one of fastai's if you want to use Learner.get_preds, or you will have to implement special methods (see more details after the
wd if they are provided, otherwise use the default values given by the
wd attributes of
learn = synth_learner(lr=1e-2)
learn.model = learn.model.cpu()
xb,yb = learn.dls.one_batch()
init_loss = learn.loss_func(learn.model(xb), yb)
learn.fit(6)
xb,yb = learn.dls.one_batch()
final_loss = learn.loss_func(learn.model(xb), yb)
assert final_loss < init_loss
This is an internal method called by Learner.fit. If passed, i is the index of this iteration in the epoch. In training mode, this does a full training step on the batch (compute predictions, loss, gradients, update the model parameters and zero the gradients). In validation mode, it stops at the loss computation. Training or validation is controlled internally by the TrainEvalCallback through the
Nothing is returned, but the attributes
loss of the Learner are set with the proper values:
b = learn.dls.one_batch()  # b is an (input, target) pair
learn.one_batch(0, b)
test_eq(learn.x, b[0])
test_eq(learn.y, b[1])
out = learn.model(learn.x)
test_eq(learn.pred, out)
test_eq(learn.loss, learn.loss_func(out, b[1]))
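The sequence of steps in a training batch (predictions, loss, gradients, parameter update) and the early exit in validation mode can be illustrated without fastai or PyTorch. The sketch below is a simplified stand-in with hypothetical names (one_batch_sketch, a two-parameter linear model, hand-written gradients for a mean-squared-error loss). It is not the Learner implementation.

```python
def one_batch_sketch(params, xb, yb, lr=0.1, training=True):
    """Run one step on a batch for the model pred = a*x + b with an MSE loss."""
    a, b = params['a'], params['b']
    n = len(xb)
    # 1. predictions
    pred = [a * x + b for x in xb]
    # 2. loss
    loss = sum((p - y) ** 2 for p, y in zip(pred, yb)) / n
    if not training:
        # Validation mode stops once the loss is computed.
        return pred, loss
    # 3. gradients of the loss with respect to a and b
    grad_a = sum(2 * (p - y) * x for p, y, x in zip(pred, yb, xb)) / n
    grad_b = sum(2 * (p - y) for p, y in zip(pred, yb)) / n
    # 4. parameter update (plain SGD). The gradients are local variables here,
    #    so there is nothing to zero, unlike with an accumulating autograd engine.
    params['a'] = a - lr * grad_a
    params['b'] = b - lr * grad_b
    return pred, loss

params = {'a': 0.0, 'b': 0.0}
xb, yb = [0.0, 1.0, 2.0, 3.0], [1.0, 3.0, 5.0, 7.0]  # generated from y = 2x + 1
_, first_loss = one_batch_sketch(params, xb, yb)
for _ in range(200):
    one_batch_sketch(params, xb, yb)
_, final_loss = one_batch_sketch(params, xb, yb, training=False)
```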
More generally, the following attributes of Learner are available and updated during the training loop:
model: the model used for training/validation
data: the underlying
loss_func: the loss function used
opt: the optimizer used to update the model parameters
opt_func: the function used to create the optimizer
cbs: the list containing all
DataLoader used for iteration
xb: last input drawn from self.dl (potentially modified by callbacks). xb is always a tuple (potentially with one element) and x is detuplified. You can only assign to
yb: last target drawn from self.dl (potentially modified by callbacks). yb is always a tuple (potentially with one element) and y is detuplified. You can only assign to
pred: last predictions from self.model (potentially modified by callbacks)
loss: last computed loss (potentially modified by callbacks)
n_epoch: the number of epochs in this training
n_iter: the number of iterations in the current
epoch: the current epoch index (from 0 to
iter: the current iteration index in self.dl (from 0 to
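As a small illustration of the difference between the tuple form (xb, yb) and the detuplified form (x, y) mentioned in the list above, here is a sketch using plain lists in place of tensors. The function detuplify_sketch is a hypothetical stand-in for whatever the library uses internally.

```python
def detuplify_sketch(t):
    """Return the lone element of a 1-tuple, otherwise the tuple unchanged."""
    return t[0] if len(t) == 1 else t

xb_single = ([1, 2, 3],)              # one input: a tuple with one element
x_single = detuplify_sketch(xb_single)  # the input itself

xb_pair = ([1, 2, 3], [4, 5, 6])      # two inputs
x_pair = detuplify_sketch(xb_pair)    # stays a tuple
```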
The following attributes are added by TrainEvalCallback and should be available unless you went out of your way to remove that callback:
train_iter: the number of training iterations done since the beginning of this training
pct_train: from 0. to 1., the percentage of training iterations completed
training: flag to indicate if we're in training mode or not
The following attribute is added by Recorder and should be available unless you went out of your way to remove that callback:
smooth_loss: an exponentially-averaged version of the training loss
learn = synth_learner(n_train=5, cbs=VerboseCallback())
assert learn.opt is None
learn.create_opt()
assert learn.opt is not None
test_eq(learn.opt.hypers[0]['lr'], learn.lr)  # hypers holds one dict per parameter group
file can be a string or a buffer.
pickle_protocol is passed along to
file can be a string or a buffer. Use device to load the model/optimizer state on a device different from the one it was saved on.
with tempfile.TemporaryDirectory() as d:
    learn = synth_learner(path=d)
    learn.fit(1)

    #Test save created a file
    learn.save('tmp')
    assert (Path(d)/'models/tmp.pth').exists()

    #Test load did load the model
    learn1 = synth_learner(path=d)
    learn1 = learn1.load('tmp')
    test_eq(learn.model.a, learn1.model.a)
    test_eq(learn.model.b, learn1.model.b)
    test_eq(learn.opt.state_dict(), learn1.opt.state_dict())
class TstCallback(Callback):
    def batch_begin(self): self.learn.a = self.a + 1

tst_learn = synth_learner()
test_eq(len(tst_learn.cbs), 1)
assert isinstance(tst_learn.cbs[0], TrainEvalCallback)
assert hasattr(tst_learn, ('train_eval'))

tst_learn = synth_learner(cbs=TstCallback())
test_eq(len(tst_learn.cbs), 2)
assert isinstance(tst_learn.cbs[1], TstCallback)
assert hasattr(tst_learn, ('tst'))

class AddCbCallback(Callback): pass
test_fail(lambda: synth_learner(cbs=AddCbCallback()))

learn = synth_learner(cbs=VerboseCallback())
learn('after_fit')

learn = synth_learner()
learn.add_cb(TestTrainEvalCallback())
test_eq(len(learn.cbs), 2)
assert isinstance(learn.cbs[1], TestTrainEvalCallback)
test_eq(learn.train_eval.learn, learn)

learn.add_cbs([TestTrainEvalCallback(), TestTrainEvalCallback()])
test_eq(len(learn.cbs), 4)

learn = synth_learner()
test_eq(len(learn.cbs), 1)
with learn.added_cbs(TestTrainEvalCallback()):
    test_eq(len(learn.cbs), 2)
By order, we mean using the internal ordering of the
callback.core for more information on how it works).
learn = synth_learner()
learn.add_cb(TestTrainEvalCallback())
learn.ordered_cbs('begin_fit')

learn = synth_learner()
learn.add_cb(TestTrainEvalCallback())
cb = learn.cbs[1]
learn.remove_cb(learn.cbs[1])
test_eq(len(learn.cbs), 1)
assert cb.learn is None
assert not getattr(learn,'test_train_eval',None)
cb can simply be the class of the Callback we want to remove (in which case all instances of that callback are removed).
learn = synth_learner()
learn.add_cbs([TestTrainEvalCallback(), TestTrainEvalCallback()])
learn.remove_cb(TestTrainEvalCallback)
test_eq(len(learn.cbs), 1)
assert not getattr(learn,'test_train_eval',None)

learn = synth_learner()
learn.add_cbs([TestTrainEvalCallback() for _ in range(3)])
cb = learn.cbs[1]
learn.remove_cbs(learn.cbs[1:])
test_eq(len(learn.cbs), 1)

learn = synth_learner()
learn.add_cb(TestTrainEvalCallback())
with learn.removed_cbs(learn.cbs[1]):
    test_eq(len(learn.cbs), 1)
test_eq(len(learn.cbs), 2)
At each step, callbacks are shown in order, which can help debugging.
learn = synth_learner()
learn.show_training_loop()
Start Fit
   - begin_fit      : [TrainEvalCallback]
  Start Epoch Loop
     - begin_epoch    : []
    Start Train
       - begin_train    : [TrainEvalCallback]
      Start Batch Loop
         - begin_batch    : []
         - after_pred     : []
         - after_loss     : []
         - after_backward : []
         - after_step     : []
         - after_cancel_batch: []
         - after_batch    : [TrainEvalCallback]
      End Batch Loop
    End Train
     - after_cancel_train: []
     - after_train    : []
    Start Valid
       - begin_validate : [TrainEvalCallback]
      Start Batch Loop
         - **CBs same as train batch**: []
      End Batch Loop
    End Valid
     - after_cancel_validate: []
     - after_validate : []
  End Epoch Loop
   - after_cancel_epoch: []
   - after_epoch    : []
End Fit
 - after_cancel_fit: []
 - after_fit      : []
Metrics can be simple averages (like accuracy) but sometimes their computation is a little bit more complex and can't be averaged over batches (like precision or recall), which is why we need a special class for them. For simple functions that can be computed as averages over batches, we can use the class AvgMetric, otherwise you'll need to implement the following methods.
Metric has state depending on tensors, don't forget to store it on the CPU to avoid any potential memory leaks.
learn = synth_learner()
tst = AvgMetric(lambda x,y: (x-y).abs().mean())
t,u = torch.randn(100),torch.randn(100)
tst.reset()
for i in range(0,100,25):
    learn.pred,learn.yb = t[i:i+25],(u[i:i+25],)
    tst.accumulate(learn)
test_close(tst.value, (t-u).abs().mean())
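To see why weighting each batch by its size recovers the full-dataset value of a metric such as mean absolute error, here is a framework-free sketch. AvgMetricSketch is a hypothetical stand-in exposing reset, accumulate and value. Its accumulate takes predictions and targets directly, unlike the test above, which passes a Learner.

```python
class AvgMetricSketch:
    """Average a per-batch function over a dataset, weighting by batch size."""
    def __init__(self, func): self.func = func

    def reset(self):
        self.total, self.count = 0.0, 0

    def accumulate(self, preds, targs):
        bs = len(preds)
        # Multiply the batch mean by the batch size so uneven batches
        # contribute in proportion to the number of items they contain.
        self.total += self.func(preds, targs) * bs
        self.count += bs

    @property
    def value(self):
        return self.total / self.count if self.count else None

def mae(preds, targs):
    return sum(abs(p - t) for p, t in zip(preds, targs)) / len(preds)

preds = [0.0, 1.5, 2.0, 4.0, 3.0]
targs = [1.0, 1.0, 2.5, 3.0, 3.0]
m = AvgMetricSketch(mae)
m.reset()
for i, j in ((0, 3), (3, 5)):  # uneven batches: 3 items, then 2
    m.accumulate(preds[i:j], targs[i:j])
full_value = mae(preds, targs)
naive_value = (mae(preds[:3], targs[:3]) + mae(preds[3:], targs[3:])) / 2
```

With uneven batches, the unweighted mean of the two batch values (naive_value) differs from the value computed on the whole dataset, while the weighted accumulation matches it.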
tst = AvgLoss()
t = torch.randn(100)
tst.reset()
for i in range(0,100,25):
    learn.yb,learn.loss = t[i:i+25],t[i:i+25].mean()
    tst.accumulate(learn)
test_close(tst.value, t.mean())
tst = AvgSmoothLoss()
t = torch.randn(100)
tst.reset()
val = tensor(0.)
for i in range(4):
    learn.loss = t[i*25:(i+1)*25].mean()
    tst.accumulate(learn)
    val = val*0.98 + t[i*25:(i+1)*25].mean()*(1-0.98)
    test_close(val/(1-0.98**(i+1)), tst.value)
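The test above checks a debiased exponential moving average: the running value is updated as val*beta + loss*(1-beta) and then divided by 1-beta**count. The same arithmetic in plain Python, as a sketch with the hypothetical name smooth_losses_sketch:

```python
def smooth_losses_sketch(losses, beta=0.98):
    """Return the debiased exponentially weighted average after each loss."""
    val, out = 0.0, []
    for count, loss in enumerate(losses, start=1):
        val = val * beta + loss * (1 - beta)
        # Dividing by (1 - beta**count) corrects the bias toward the
        # initial value of 0 during the first iterations.
        out.append(val / (1 - beta ** count))
    return out

smoothed = smooth_losses_sketch([4.0, 2.0, 1.0])
constant = smooth_losses_sketch([3.0] * 10)
```

Thanks to the correction, the first smoothed value equals the first loss, and a constant stream of losses yields that same constant at every step.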
def metric_value_fn(): return 5e-3

vm = ValueMetric(metric_value_fn, 'custom_value_metric')
test_eq(vm.value, 5e-3)
test_eq(vm.name, 'custom_value_metric')

vm = ValueMetric(metric_value_fn)
test_eq(vm.name, 'metric_value_fn')
By default, metrics are computed on the validation set only, although that can be changed by adjusting
beta is the weight used to compute the exponentially weighted average of the losses (which gives the smooth_loss attribute to
logger attribute of a Learner determines what happens to those metrics. By default, it just prints them:
def tst_metric(out, targ): return F.mse_loss(out, targ)
learn = synth_learner(n_train=5, metrics=tst_metric)
pat = r"[tensor\(\d.\d*\), tensor\(\d.\d*\), tensor\(\d.\d*\), 'dd:dd']"
test_stdout(lambda: learn.fit(1), pat, regex=True)

learn = synth_learner(n_train=5, metrics=tst_metric)
res = learn.validate()
# res holds the validation loss followed by the metric (both are MSE here)
test_eq(res[0], res[1])
x,y = learn.dls.valid_ds.tensors
test_close(res[0], F.mse_loss(learn.model(x), y))
Get the predictions and targets on the ds_idx-th dbunchset or
with_decoded will also return the decoded predictions using the decodes function of the loss function (if it exists). For instance, fastai's CrossEntropyFlat takes the argmax of predictions in its decodes.
Depending on the loss_func attribute of Learner, an activation function will be picked automatically so that the predictions make sense. For instance if the loss is a case of cross-entropy, a softmax will be applied, or if the loss is binary cross entropy with logits, a sigmoid will be applied. If you want to make sure a certain activation function is applied, you can pass it with
save_targs should be used when your predictions are too big to fit all in memory. Give a Path object that points to a folder where the predictions and targets will be saved.
concat_dim is the batch dimension, where all the tensors will be concatenated.
inner is an internal attribute that tells get_preds it's called internally, inside another training loop, to avoid recursion errors.
with_loss=True on a custom loss function, make sure you have implemented a reduction attribute that supports 'none'
learn = synth_learner(n_train=5, metrics=tst_metric)
preds,targs = learn.get_preds()
x,y = learn.dls.valid_ds.tensors
test_eq(targs, y)
test_close(preds, learn.model(x))

preds,targs = learn.get_preds(act = torch.sigmoid)
test_eq(targs, y)
test_close(preds, torch.sigmoid(learn.model(x)))
It returns a tuple of three elements with, in reverse order,
rm_type_tfms is a deprecated argument that should not be used and will be removed in a future version.
with_input will add the decoded inputs to the result.
class _FakeLossFunc(Module):
    reduction = 'none'
    def forward(self, x, y): return F.mse_loss(x,y)
    def activation(self, x): return x+1
    def decodes(self, x): return 2*x

class _Add1(Transform):
    def encodes(self, x): return x+1
    def decodes(self, x): return x-1

learn = synth_learner(n_train=5)
dl = TfmdDL(Datasets(torch.arange(50), tfms = [L(), [_Add1()]]))
learn.dls = DataLoaders(dl, dl)
learn.loss_func = _FakeLossFunc()

inp = tensor([2.])
out = learn.model(inp).detach()+1  #applying model + activation
dec = 2*out                        #decodes from loss function
full_dec = dec-1                   #decodes from _Add1
test_eq(learn.predict(inp), [full_dec,dec,out])
test_eq(learn.predict(inp, with_input=True), [inp,full_dec,dec,out])
max_n samples (unless the batch size of dl is less than max_n, in which case it will show as many samples) and shuffle the data unless you pass False to that flag. kwargs are application-dependent.
We can't show an example on our synthetic Learner, but check all the beginner tutorials which will show you how that method works across applications.
The last functions in this section are used internally for inference, but should be less useful to you.
learn = synth_learner(n_train=5, metrics=tst_metric)
with learn.no_logging():
    test_stdout(lambda: learn.fit(1), '')
test_eq(learn.logger, print)
This requires your loss function to either have a reduction attribute or a reduction argument (like all fastai and PyTorch loss functions).
Learner is saved in
pickle_protocol. Note that serialization in Python saves the names of functions, not the code itself. Therefore, any custom code you have for models, data transformations, loss functions, etc. should be put in a module that you import in your training environment before exporting, and in your deployment environment before loading it.
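The point about names versus code can be checked with the standard library alone. In the sketch below, json.dumps merely stands in for any importable function: the pickled payload carries the module and function names, and unpickling looks the function up by importing that module.

```python
import json
import pickle

# Pickle an importable, module-level function.
payload = pickle.dumps(json.dumps)

# The payload stores a reference (module name + function name), not the body.
has_names = b'json' in payload and b'dumps' in payload

# Unpickling resolves the reference by importing the module again,
# so it returns the very same function object.
restored = pickle.loads(payload)
same_object = restored is json.dumps
```

This is why a function defined only in a notebook or script in the training environment cannot be resolved at load time unless the same definition is importable in the deployment environment.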
In practice, we get the predictions n times with the transforms of the training set and average those. The final predictions are (1-beta) multiplied by this average + beta multiplied by the predictions obtained with the transforms of the dataset. Set
None to get a tuple of the predictions and tta results. You can also use the maximum of all predictions instead of an average by setting
If you want to use new transforms, you can pass them with
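The blending rule for test-time augmentation described above can be written out on plain lists. This is a sketch under assumptions: tta_blend_sketch, its default beta, the use_max flag name, and the order of the returned tuple are choices made for illustration, not the library's signature.

```python
def tta_blend_sketch(aug_preds, preds, beta=0.25, use_max=False):
    """aug_preds: n lists of augmented predictions; preds: plain predictions."""
    if use_max:
        # Element-wise maximum over every set of predictions, plain ones included.
        return [max(vals) for vals in zip(preds, *aug_preds)]
    n = len(aug_preds)
    # Average of the n augmented predictions.
    avg = [sum(vals) / n for vals in zip(*aug_preds)]
    if beta is None:
        # No blending: hand back both pieces (ordering is this sketch's choice).
        return preds, avg
    # (1-beta) * augmented average + beta * plain predictions
    return [(1 - beta) * a + beta * p for a, p in zip(avg, preds)]

aug = [[0.2, 0.8], [0.4, 0.6]]
plain = [0.5, 0.5]
blended = tta_blend_sketch(aug, plain, beta=0.25)
both = tta_blend_sketch(aug, plain, beta=None)
maxed = tta_blend_sketch(aug, plain, use_max=True)
```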
learn = synth_learner(lr=1e-2)
test_eq(learn.init_args['Learner.__init__.lr'], 0.01)