Callbacks that make decisions depending on how a monitored metric/loss behaves

class ShortEpochCallback

ShortEpochCallback(pct=0.01, short_valid=True) :: Callback

Fit just pct of an epoch, then stop

learn = synth_learner()
learn.fit(1, cbs=ShortEpochCallback())
learn = synth_learner()
learn.fit(1, cbs=ShortEpochCallback(short_valid=False))
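To make the two parameters concrete, here is a plain-Python sketch of the idea, not fastai's implementation. The helper short_epoch is hypothetical, and it assumes the cut keeps the first ceil(pct * n) batches (at least one); fastai's exact rounding may differ. It also assumes, as the name short_valid suggests, that short_valid controls whether the validation pass is shortened too.

```python
import math

def short_epoch(train_batches, valid_batches, pct=0.01, short_valid=True):
    """Hypothetical sketch: return the batches a shortened epoch would visit.

    Training keeps only the first ceil(pct * n) batches (at least one).
    Validation is shortened the same way only when short_valid is True.
    """
    def head(batches):
        n_keep = max(1, math.ceil(pct * len(batches)))  # never drop to zero batches
        return batches[:n_keep]

    train_run = head(train_batches)
    valid_run = head(valid_batches) if short_valid else list(valid_batches)
    return train_run, valid_run

train_run, valid_run = short_epoch(list(range(200)), list(range(50)), pct=0.25)
print(len(train_run), len(valid_run))  # 50 13

train_run, valid_run = short_epoch(list(range(200)), list(range(50)), pct=0.25, short_valid=False)
print(len(train_run), len(valid_run))  # 50 50
```

Such a shortened epoch is handy for smoke-testing a training setup end to end without paying for a full pass over the data.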

class GradientAccumulation

GradientAccumulation(n_acc=32) :: Callback

Accumulate gradients before updating weights

learn = synth_learner()
learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_acc=2*learn.dls.bs))
# ensure train_loss decreased
assert learn.recorder.values[-1][0] < learn.recorder.values[0][0]

learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_acc=1e6))
# n_acc is never reached, so no step is taken: ensure valid_loss didn't change (same weights)
assert learn.recorder.values[-1][1] == learn.recorder.values[0][1]
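The mechanism can be sketched without fastai: gradients from several mini-batches are summed, and the weights are updated only once enough samples have been seen. The function train_with_accumulation below is a hypothetical plain-Python illustration on a one-parameter model y = w * x with squared error. It assumes n_acc counts samples rather than batches, and it averages the accumulated gradient over those samples before stepping. With a threshold that is never reached, no step happens at all, which is why the second test above expects an unchanged validation loss.

```python
def train_with_accumulation(batches, w=0.0, lr=0.01, n_acc=32):
    """Hypothetical sketch: fit y ~ w * x with plain SGD, but only update w
    after gradients from at least n_acc samples have been accumulated."""
    grad_sum, count = 0.0, 0
    for xs, ys in batches:
        # Gradient of sum((w*x - y)**2) with respect to w over this mini-batch.
        grad_sum += sum(2 * (w * x - y) * x for x, y in zip(xs, ys))
        count += len(xs)
        if count >= n_acc:
            w -= lr * grad_sum / count  # one step with the mean accumulated gradient
            grad_sum, count = 0.0, 0    # reset the accumulator (the "zero_grad")
    # Gradients left over at the end without reaching n_acc are discarded.
    return w

# 50 identical mini-batches of 2 samples each, drawn from y = 3x.
batches = [([1.0, 2.0], [3.0, 6.0])] * 50

w_acc = train_with_accumulation(batches, n_acc=4)      # one step every 2 batches
w_never = train_with_accumulation(batches, n_acc=1e6)  # threshold never reached
print(w_never)  # 0.0
```

Averaging over the accumulated samples keeps the step size comparable to an ordinary update on one large batch of n_acc samples, which is the usual reason to accumulate gradients when a large batch does not fit in memory.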

Set BatchNorm layers in eval mode for all recursive children of m.
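The description above can be sketched in plain PyTorch as a recursive walk over the module tree. The function name set_bn_eval_sketch and the BN_TYPES tuple are hypothetical, chosen for illustration; this is not fastai's implementation, and it does not check whether a layer's parameters are trainable.

```python
import torch.nn as nn

# Hypothetical tuple of the BatchNorm classes this sketch handles.
BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def set_bn_eval_sketch(m: nn.Module, use_eval: bool = True) -> None:
    """Put every BatchNorm layer among the recursive children of m in eval
    mode (or back in train mode if use_eval=False). Other layers are untouched."""
    for child in m.children():
        if isinstance(child, BN_TYPES):
            child.train(not use_eval)  # child.train(False) is the same as child.eval()
        set_bn_eval_sketch(child, use_eval)  # recurse into nested containers

# Usage: a small model with a BatchNorm layer nested inside a Sequential.
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3),
    nn.Sequential(nn.BatchNorm2d(4), nn.ReLU()),
)
model.train()
set_bn_eval_sketch(model)
print(model[0].training, model[1][0].training)  # True False
```

Because model.train() switches every submodule back to train mode, a call like this has to be repeated after each model.train() for the BatchNorm layers to stay in eval mode.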

class BnFreeze

BnFreeze() :: Callback

Freeze moving average statistics in all non-trainable batchnorm layers.

BnFreeze is useful when you'd like to train two separate models that have a common feature extractor / body. The only part of the model that's different is the head that you attach for transfer learning.

Learner.freeze() doesn't suffice here: the BatchNorm layers are trainable by default, and even when they are not, their running mean and variance are still updated from every training batch. For the feature extractors to match exactly, you need to set train_bn=False and also freeze these statistics, which is precisely what BnFreeze does.
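The point about running statistics can be checked directly in PyTorch: freezing a BatchNorm layer's weight and bias, as train_bn=False does for the frozen body, does not stop its running_mean buffer from being updated in train mode; only eval mode does. A minimal sketch, using an arbitrary synthetic batch and a hypothetical helper running_mean_changes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)
for p in bn.parameters():
    p.requires_grad_(False)  # freeze the affine weight and bias

# A batch whose mean is far from the initial running_mean of 0.
x = torch.randn(8, 3) * 5 + 10

def running_mean_changes(layer: nn.Module, batch: torch.Tensor) -> bool:
    """Run one forward pass and report whether running_mean was modified."""
    before = layer.running_mean.clone()
    with torch.no_grad():
        layer(batch)
    return not torch.equal(layer.running_mean, before)

bn.train()
changed_in_train = running_mean_changes(bn, x)  # stats are still tracked
bn.eval()
changed_in_eval = running_mean_changes(bn, x)   # stats are left alone
print(changed_in_train, changed_in_eval)  # True False
```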

from fastai.vision.all import *

path = untar_data(URLs.MNIST_TINY)
dls  = ImageDataLoaders.from_folder(path, valid_pct=0.2)

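We first demonstrate the mismatch of the running stats when using only train_bn=False: two learners built with identical settings are trained, and their weights are compared afterwards.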

learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False)
learn2 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False)
learn1.fit(1, lr=0.02)
learn2.fit(1, lr=0.02)
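def models_equal(model_1, model_2, verbose=False):
    "Return True if both models have identical state dicts (weights and buffers such as BatchNorm running stats)."
    models_differ = 0
    for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
        if torch.equal(key_item_1[1], key_item_2[1]):
            continue
        models_differ += 1
        if key_item_1[0] == key_item_2[0]:
            if verbose: print(f'Mismatch found at {key_item_1[0]}')
        else:
            if verbose: print('Models being compared have different architectures')
            raise Exception('Models being compared have different architectures')
    if models_differ == 0:
        if verbose: print('Models match perfectly')
        return True
    return False

# Compare the bodies (model[0]): the heads are freshly initialised and differ anyway.
# The BatchNorm running stats were updated during training, so the bodies should no longer match.
models_equal(learn1.model[0], learn2.model[0])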
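With BnFreeze, the running statistics of the non-trainable BatchNorm layers are not updated during training and the rest of the body is frozen too, so the two bodies stay identical:

learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False, cbs=BnFreeze)
learn2 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False, cbs=BnFreeze)
learn1.fit(1, lr=0.02)
learn2.fit(1, lr=0.02)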

assert models_equal(learn1.model[0], learn2.model[0])