Callbacks and helper functions to train in parallel or use distributed training

Parallel

Patch the parallel models so they work with RNNs

DataParallel.reset[source]

DataParallel.reset()

class ParallelTrainer[source]

ParallelTrainer(device_ids) :: Callback

Basic class handling tweaks of the training loop by changing a Learner in various events

Learner.to_parallel[source]

Learner.to_parallel(device_ids=None)

Add ParallelTrainer callback to a Learner.

Learner.detach_parallel[source]

Learner.detach_parallel()

Remove ParallelTrainer callback from Learner.

Learner.parallel_ctx[source]

Learner.parallel_ctx(device_ids=None)

A context manager to adapt a learner to train in data parallel mode.
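The mechanics of these three helpers can be sketched with a minimal stub. Note that `Learner`, `ParallelTrainer`, and the functions below are simplified stand-ins for illustration, not the real fastai implementations: the context manager attaches the callback on entry and removes it on exit.

```python
from contextlib import contextmanager

class ParallelTrainer:
    """Stand-in for the fastai callback (hypothetical, for illustration)."""
    def __init__(self, device_ids): self.device_ids = device_ids

class Learner:
    """Minimal Learner stub with just enough callback plumbing."""
    def __init__(self): self.cbs = []
    def add_cb(self, cb):
        self.cbs.append(cb)
        return self
    def remove_cb(self, cb_type):
        self.cbs = [cb for cb in self.cbs if not isinstance(cb, cb_type)]
        return self

def to_parallel(learn, device_ids=None):
    # Attach a ParallelTrainer callback; the real method also wraps the model.
    return learn.add_cb(ParallelTrainer(device_ids))

def detach_parallel(learn):
    # Remove the ParallelTrainer callback from the Learner.
    return learn.remove_cb(ParallelTrainer)

@contextmanager
def parallel_ctx(learn, device_ids=None):
    # Attach on entry, detach on exit, so the learner trains in data
    # parallel mode only inside the `with` block.
    try:
        yield to_parallel(learn, device_ids)
    finally:
        detach_parallel(learn)
```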

Distributed

Patch the parallel models so they work with RNNs

DistributedDataParallel.reset[source]

DistributedDataParallel.reset()

Convenience functions to set up/tear down torch distributed data parallel mode.

setup_distrib[source]

setup_distrib(gpu=None)

teardown_distrib[source]

teardown_distrib()
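A rough sketch of what these helpers are responsible for is below. This is a simplified stand-in, assuming `world_size` is passed explicitly for illustration (the real helpers read the distributed environment and use torch.cuda / torch.distributed):

```python
def setup_distrib(gpu=None, world_size=1):
    # Select the GPU for this process and report it back.
    if gpu is None:
        return None
    gpu = int(gpu)
    if world_size > 1:
        # A real implementation would initialize the process group here,
        # e.g. torch.distributed.init_process_group(backend='nccl',
        # init_method='env://').
        pass
    return gpu

def teardown_distrib():
    # A real implementation would call
    # torch.distributed.destroy_process_group() if a group was initialized.
    pass
```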

DataLoader

We need to change the dataloaders so that they only get one part of the batch each (otherwise there is no point in using distributed training).

class DistributedDL[source]

DistributedDL(dataset, rank, world_size, bs=64, shuffle=False, num_workers=None, verbose=False, do_setup=True, pin_memory=False, timeout=0, batch_size=None, drop_last=False, indexed=None, n=None, device=None, wif=None, before_iter=None, after_item=None, before_batch=None, after_batch=None, after_iter=None, create_batches=None, create_item=None, create_batch=None, retain=None, get_idxs=None, sample=None, shuffle_fn=None, do_batch=None) :: TfmdDL

Transformed DataLoader

dl = TfmdDL(list(range(50)), bs=16, num_workers=2)
for i in range(4):
    dl1 = DistributedDL.from_dl(dl, i, 4)
    test_eq(list(dl1)[0], torch.arange(i, 52, 4)%50)
dl = TfmdDL(list(range(50)), bs=16, num_workers=2, shuffle=True)
res = []
for i in range(4):
    dl1 = DistributedDL.from_dl(dl, i, 4)
    dl1.set_epoch(0)
    res += list(dl1)[0].tolist()
#All items should only be accessed once (except 0 and 1 for final cycle) with seeded shuffle
test_eq(sorted(res), [0,0,1,1] + list(range(2, 50)))
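The index assignment exercised by the tests above can be sketched in plain Python: the indices are padded to a multiple of `world_size` by wrapping around to the start of the dataset, then each rank takes a strided slice. Here `partition_indices` is an illustrative helper, not part of the library:

```python
import math

def partition_indices(n, rank, world_size):
    # Pad to a multiple of world_size by wrapping to the start of the
    # dataset, then give each rank a strided slice: rank r gets items
    # r, r+world_size, r+2*world_size, ...
    per_rank = math.ceil(n / world_size)
    padded = [i % n for i in range(per_rank * world_size)]
    return padded[rank::world_size]
```

For n=50 and world_size=4 this reproduces the 52-item padded cycle above: rank 2, for example, receives 2, 6, ..., 46, 0 (the wrapped item), and across all four ranks every item appears once except 0 and 1, which appear twice.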

class DistributedTrainer[source]

DistributedTrainer(cuda_id=0) :: Callback

Basic class handling tweaks of the training loop by changing a Learner in various events

Attach or remove a callback that adapts the model to use DistributedDL, so it can train in distributed data parallel mode.

Learner.to_distributed[source]

Learner.to_distributed(cuda_id)

Learner.detach_distributed[source]

Learner.detach_distributed()

Learner.distrib_ctx[source]

Learner.distrib_ctx(cuda_id=None)

A context manager to adapt a learner to train in distributed data parallel mode.

distrib_ctx context manager

distrib_ctx(cuda_id) prepares a learner to train in distributed data parallel mode. It assumes the required environment variables have all been set up properly, as is the case for scripts launched by python -m fastai2.launch.

Typical usage:

with learn.distrib_ctx(): learn.fit(.....)

It attaches a DistributedTrainer callback and DistributedDL data loader to the learner, then executes learn.fit(.....). Upon exiting the context, it removes the DistributedTrainer and DistributedDL, and destroys any locally created distributed process group. The process is still attached to the GPU though.

rank0_first[source]

rank0_first(func)

Execute func in the Rank-0 process first, then in other ranks in parallel.

rank0_first(f) calls f() in rank-0 process first, then in parallel on the rest, in distributed training mode. In single process, non-distributed training mode, f() is called only once as expected.
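Its control flow can be sketched as follows. Here rank and barrier are injected as parameters for illustration; the real function obtains them from the distributed environment:

```python
def rank0_first(func, rank=0, barrier=lambda: None):
    # Rank 0 runs func before the barrier; every other rank waits at the
    # barrier and runs func afterwards, so e.g. a download happens exactly
    # once before the remaining ranks read the (now cached) result.
    if rank == 0:
        res = func()
    barrier()  # all ranks synchronize here
    if rank != 0:
        res = func()
    return res
```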

One application of rank0_first() is to make fresh downloads via untar_data() safe in distributed training scripts launched by python -m fastai2.launch <script>:

path = untar_data(URLs.IMDB)

becomes:

path = rank0_first(lambda: untar_data(URLs.IMDB))

Some learner factory methods may use untar_data() to download pretrained models by default:

learn = text_classifier_learner(dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy)

becomes:

learn = rank0_first(lambda: text_classifier_learner(dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy))

Otherwise, multiple processes will download at the same time and corrupt the data.