Regular training with square images and rectangular training
from fastai2.basics import *
from fastai2.callback.all import *
from fastai2.vision.all import *
from nbdev.showdoc import *

Square training

Loading the data with Datasets

To load the data with the medium-level API Datasets, we need to gather all the images and define some way to split them between training and validation sets.

source = untar_data(URLs.IMAGENETTE_160)
items = get_image_files(source)
split_idx = GrandparentSplitter(valid_name='val')(items)
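GrandparentSplitter assigns each file to a split based on the name of its grandparent folder. Here that folder is train or val, since the paths look like .../val/n03425413/file.JPEG. The sketch below shows the idea in plain Python. grandparent_split is a hypothetical helper standing in for the fastai2 class, and its default folder names are assumptions of the sketch.

```python
from pathlib import PurePosixPath


def grandparent_split(items, train_name="train", valid_name="valid"):
    """Hypothetical stand-in for GrandparentSplitter.

    Returns (train_indices, valid_indices) based on the grandparent folder
    of each path, i.e. the <split> part of .../<split>/<class>/<file>.
    """
    train_idx, valid_idx = [], []
    for i, item in enumerate(items):
        grandparent = PurePosixPath(item).parent.parent.name
        if grandparent == train_name:
            train_idx.append(i)
        elif grandparent == valid_name:
            valid_idx.append(i)
    return train_idx, valid_idx


paths = [
    "imagenette2-160/train/n03425413/a.JPEG",
    "imagenette2-160/val/n03425413/b.JPEG",
    "imagenette2-160/train/n03445777/c.JPEG",
]
print(grandparent_split(paths, valid_name="val"))
```

In this sketch, files whose grandparent folder matches neither name end up in neither list.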

Then we detail the type transforms (applied to each item to form an (input, target) tuple) and the item transforms. For our inputs we use PILImage.create. For our targets, we use the parent_label function to convert a filename to its class (the name of the folder containing the file), followed by Categorize. We'll also map the WordNet category ids used in Imagenette to readable names.

The item transforms convert the image to a tensor and add data augmentation using PIL: a random flip and a random resized crop, which also resizes each image to 128x128 pixels.

lbl_dict = dict(
    n01440764='tench',
    n02102040='English springer',
    n02979186='cassette player',
    n03000684='chain saw',
    n03028079='church',
    n03394916='French horn',
    n03417042='garbage truck',
    n03425413='gas pump',
    n03445777='golf ball',
    n03888605='parachute'
)
tfms = [[PILImage.create], [parent_label, lbl_dict.__getitem__, Categorize]]
item_img_tfms = [ToTensor, FlipItem(0.5), RandomResizedCrop(128, min_scale=0.35)]
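RandomResizedCrop picks a random region of the image and resizes it to 128 pixels. As we understand it, min_scale=0.35 is the smallest fraction of the image area that region may cover. The sketch below shows one common way to sample such a crop box: a random area fraction combined with a log-uniform aspect ratio, with a centered square crop as the fallback. The function name, the (3/4, 4/3) ratio range, the retry count and the fallback are all assumptions. This is not fastai2's exact implementation.

```python
import math
import random


def sample_crop_box(width, height, min_scale=0.35, ratio=(3 / 4, 4 / 3),
                    tries=10, rng=random):
    """Hypothetical sketch: return (left, top, crop_w, crop_h) for a random resized crop."""
    area = width * height
    log_lo, log_hi = math.log(ratio[0]), math.log(ratio[1])
    for _ in range(tries):
        # Fraction of the image area to keep, then a random aspect ratio.
        target_area = rng.uniform(min_scale, 1.0) * area
        aspect = math.exp(rng.uniform(log_lo, log_hi))
        crop_w = int(round(math.sqrt(target_area * aspect)))
        crop_h = int(round(math.sqrt(target_area / aspect)))
        if 0 < crop_w <= width and 0 < crop_h <= height:
            left = rng.randint(0, width - crop_w)
            top = rng.randint(0, height - crop_h)
            return left, top, crop_w, crop_h
    # Fallback (a choice of this sketch): centered square crop.
    side = min(width, height)
    return (width - side) // 2, (height - side) // 2, side, side


# The sampled box would then be cropped out and resized to 128x128.
print(sample_crop_box(213, 160, rng=random.Random(0)))
```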

We can then pass all of this information to Datasets.

dsets = Datasets(items, tfms, splits=split_idx)
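Conceptually, a Datasets applies each list of type transforms to the same item, producing one element of the resulting tuple per list. It also keeps the split indices so it can expose training and validation subsets. The toy version below illustrates this in plain Python. MiniDatasets, MiniCategorize, compose, the two-entry LBL_DICT and the filename stand-in for PILImage.create are all hypothetical. Building the vocabulary from sorted training labels reflects our understanding of how Categorize behaves.

```python
from pathlib import PurePosixPath

# Hypothetical two-entry subset of lbl_dict, for illustration only.
LBL_DICT = {"n03425413": "gas pump", "n03445777": "golf ball"}


def parent_label(path):
    # The class is the name of the folder containing the file.
    return PurePosixPath(path).parent.name


def compose(*funcs):
    # Apply funcs left to right, like the transforms in one pipeline.
    def run(x):
        for f in funcs:
            x = f(x)
        return x
    return run


class MiniCategorize:
    """Hypothetical stand-in for Categorize: label -> index in a sorted vocab."""

    def __init__(self, train_labels):
        self.vocab = sorted(set(train_labels))
        self.o2i = {v: i for i, v in enumerate(self.vocab)}

    def __call__(self, label):
        return self.o2i[label]


class MiniDatasets:
    """Hypothetical sketch: one pipeline per tuple element, plus split indices."""

    def __init__(self, items, pipelines, splits):
        self.items, self.pipelines, self.splits = items, pipelines, splits

    def __getitem__(self, i):
        return tuple(pipe(self.items[i]) for pipe in self.pipelines)

    def subset(self, k):
        return [self[i] for i in self.splits[k]]

    @property
    def train(self):
        return self.subset(0)

    @property
    def valid(self):
        return self.subset(1)


def x_pipe(path):
    # Stand-in for PILImage.create: keep the filename instead of opening the image.
    return PurePosixPath(path).name


items = [
    "imagenette2-160/train/n03425413/a.JPEG",
    "imagenette2-160/train/n03445777/b.JPEG",
    "imagenette2-160/val/n03445777/c.JPEG",
]
splits = ([0, 1], [2])
to_word = compose(parent_label, LBL_DICT.__getitem__)
# Build the vocabulary from the training labels only.
categorize = MiniCategorize([to_word(items[i]) for i in splits[0]])
y_pipe = compose(to_word, categorize)
dsets = MiniDatasets(items, [x_pipe, y_pipe], splits)
print(dsets.valid)
```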

To convert our Datasets to a DataLoaders, we need to indicate the transforms we want to apply at the batch level. Here, we convert the tensors of bytes to floats, then normalize them using the traditional ImageNet statistics.

batch_tfms = [IntToFloatTensor, Normalize.from_stats(*imagenet_stats)]
dls = dsets.dataloaders(after_item=item_img_tfms, after_batch=batch_tfms, bs=64, num_workers=0)
(#13394) [Path('/home/sgugger/.fastai/data/imagenette2-160/val/n03425413/n03425413_12962.JPEG'),Path('/home/sgugger/.fastai/data/imagenette2-160/val/n03425413/ILSVRC2012_val_00035211.JPEG'),Path('/home/sgugger/.fastai/data/imagenette2-160/val/n03425413/n03425413_1381.JPEG'),Path('/home/sgugger/.fastai/data/imagenette2-160/val/n03425413/n03425413_13752.JPEG'),Path('/home/sgugger/.fastai/data/imagenette2-160/val/n03425413/n03425413_12701.JPEG'),Path('/home/sgugger/.fastai/data/imagenette2-160/val/n03425413/n03425413_10450.JPEG'),Path('/home/sgugger/.fastai/data/imagenette2-160/val/n03425413/n03425413_7022.JPEG'),Path('/home/sgugger/.fastai/data/imagenette2-160/val/n03425413/n03425413_8661.JPEG'),Path('/home/sgugger/.fastai/data/imagenette2-160/val/n03425413/n03425413_14891.JPEG'),Path('/home/sgugger/.fastai/data/imagenette2-160/val/n03425413/n03425413_21202.JPEG')...]
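The two batch transforms are simple per-channel arithmetic. IntToFloatTensor divides byte values by 255, and Normalize subtracts each channel's mean and divides by its standard deviation. The sketch below applies both to a single RGB pixel. The helper names are hypothetical, and the constants are the values usually given for imagenet_stats.

```python
# Per-channel (R, G, B) ImageNet statistics.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def to_float(pixel):
    # Like IntToFloatTensor: byte values in [0, 255] -> floats in [0, 1].
    return tuple(c / 255.0 for c in pixel)


def normalize(pixel, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    # Like Normalize.encodes: (x - mean) / std, channel by channel.
    return tuple((c - m) / s for c, m, s in zip(pixel, mean, std))


def denormalize(pixel, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    # The inverse, used when decoding a batch for display.
    return tuple(c * s + m for c, m, s in zip(pixel, mean, std))


rgb = (255, 128, 0)
x = normalize(to_float(rgb))
print([round(c, 4) for c in x])
```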

Loading the data with DataBlock

An easier way is to use the higher-level DataBlock API. To build an Imagenette datablock, we just need to specify the types, how to get the items, how to split them and how to label them.

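imagenette = DataBlock(blocks=(ImageBlock, CategoryBlock),
                       get_items=get_image_files,
                       splitter=GrandparentSplitter(valid_name='val'),
                       get_y=Pipeline([parent_label, lbl_dict.__getitem__]))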

We can then directly call the dataloaders method, specifying a source (where the items are) and the non-default item and batch transforms. To see which transforms are included by default (inferred from the blocks passed), we can inspect (and potentially modify) the attributes default_type_tfms, default_item_tfms and default_batch_tfms of the imagenette object.

((#2) [(#1) [<bound method PILBase.create of <class 'fastai2.vision.core.PILImage'>>],(#1) [Categorize: (object,object) -> encodes (object,object) -> decodes]],
 (#3) [ToTensor: (PILMask,object) -> encodes
 (PILBase,object) -> encodes ,FlipItem: (TensorBBox,object) -> encodes
 (TensorPoint,object) -> encodes
 (TensorMask,object) -> encodes
 (TensorImage,object) -> encodes
 (Image,object) -> encodes ,RandomResizedCrop: (TensorBBox,object) -> encodes
 (TensorPoint,object) -> encodes
 (Image,object) -> encodes ],
 (#2) [IntToFloatTensor: (TensorMask,object) -> encodes
 (TensorImage,object) -> encodes (TensorImage,object) -> decodes,Normalize: (TensorImage,object) -> encodes (TensorImage,object) -> decodes])

Here we need to add the data augmentation and resize, as well as the normalization.

dls = imagenette.dataloaders(source, bs=64, num_workers=8)
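Default and added transforms end up in a single pipeline. fastai2 pipelines run their transforms sorted by each transform's order attribute, not in list order, so the order in the printed lists above is not necessarily the order in which they run. The sketch below illustrates that merge-and-sort step. Tfm, merge_tfms, the de-duplication by name and the order values are all hypothetical and do not reflect fastai2's actual numbers.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Tfm:
    name: str
    order: int  # hypothetical order values, chosen only for illustration


def merge_tfms(defaults, extra):
    """Hypothetical helper: keep one transform per name (extra wins), then sort by order."""
    by_name = {t.name: t for t in defaults}
    by_name.update({t.name: t for t in extra})
    return sorted(by_name.values(), key=lambda t: t.order)


default_item = [Tfm("ToTensor", 5)]
extra_item = [Tfm("FlipItem", 0), Tfm("RandomResizedCrop", 1)]
print([t.name for t in merge_tfms(default_item, extra_item)])
```

With these assumed order values, the PIL-based augmentations would run before ToTensor even though ToTensor is listed first.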


The following code gives us a Learner to train an xresnet18 model on Imagenette.

learn = Learner(dls, xresnet18(), lr=1e-2, metrics=accuracy,
                opt_func=partial(Adam, wd=0.01, eps=1e-3))
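The optimizer is Adam with weight decay wd=0.01 and a fairly large eps=1e-3. A larger eps keeps the denominator from getting very small, so parameters with tiny gradients do not receive very large rescaled steps. The sketch below shows one update on a scalar parameter. It assumes decoupled ("true") weight decay, which is fastai2's Adam default as far as we know, and betas of (0.9, 0.99). adam_step is a hypothetical helper, not the library's code.

```python
import math


def adam_step(p, grad, state, lr=1e-2, wd=0.01, eps=1e-3, beta1=0.9, beta2=0.99):
    """Hypothetical sketch: one Adam update on a scalar with decoupled weight decay."""
    state["step"] = state.get("step", 0) + 1
    t = state["step"]
    # Exponential moving averages of the gradient and the squared gradient.
    state["m"] = beta1 * state.get("m", 0.0) + (1 - beta1) * grad
    state["v"] = beta2 * state.get("v", 0.0) + (1 - beta2) * grad * grad
    # Bias correction for the zero-initialized averages.
    m_hat = state["m"] / (1 - beta1 ** t)
    v_hat = state["v"] / (1 - beta2 ** t)
    # Decoupled weight decay: shrink the weight directly, not through the gradient.
    p = p * (1 - lr * wd)
    return p - lr * m_hat / (math.sqrt(v_hat) + eps)


state = {}
p = adam_step(1.0, 0.5, state)
print(p)
```

On the first step the bias-corrected averages equal the gradient and its square, so the update reduces to lr * grad / (|grad| + eps) on top of the weight decay.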

Then we can train our model.

epoch train_loss valid_loss accuracy time
0 1.508685 1.212527 0.623185 00:10
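The call that produced this table is not shown. Given the lr=1e-2 stored in the Learner and the single epoch reported, something like learn.fit_one_cycle(1) is plausible, but that is an assumption. For reference, the sketch below shows a one-cycle learning-rate schedule. It assumes the defaults we believe fastai2 uses (pct_start=0.25, div=25, div_final=1e5) and cosine interpolation in both phases. The function names are hypothetical.

```python
import math


def cos_sched(start, end, pos):
    # Cosine interpolation: start at pos=0, end at pos=1.
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2


def one_cycle_lr(pos, lr_max=1e-2, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at training progress pos in [0, 1] (hypothetical sketch)."""
    if pos < pct_start:
        # Warm-up: lr_max / div -> lr_max.
        return cos_sched(lr_max / div, lr_max, pos / pct_start)
    # Annealing: lr_max -> lr_max / div_final.
    return cos_sched(lr_max, lr_max / div_final, (pos - pct_start) / (1 - pct_start))


print([one_cycle_lr(p) for p in (0.0, 0.25, 0.5, 1.0)])
```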

Showing results

To get predictions for a single item, we use Learner.predict.

tst_item = items[0]
t = learn.predict(tst_item)
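For a single-label classifier, Learner.predict returns, roughly, the decoded label, the predicted class index and the class probabilities. The decoding step amounts to a softmax followed by an argmax and a vocabulary lookup. The sketch below shows this in plain Python, using a hypothetical decode_prediction helper and a toy three-class vocabulary.

```python
import math


def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode_prediction(logits, vocab):
    """Hypothetical sketch: raw model outputs -> (label, index, probabilities)."""
    probs = softmax(logits)
    idx = max(range(len(probs)), key=probs.__getitem__)
    return vocab[idx], idx, probs


vocab = ["cassette player", "gas pump", "golf ball"]  # toy vocabulary
label, idx, probs = decode_prediction([0.1, 2.0, -1.0], vocab)
print(label, idx, [round(q, 3) for q in probs])
```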