The function to immediately get a `Learner` ready to train for tabular data
from nbdev.showdoc import *
from fastai2.tabular.data import *

The main function you probably want to use in this module is tabular_learner. It automatically creates a TabularModel suited to your data and infers the right loss function. See the tabular tutorial for an example of its use in context.

Main functions

class TabularLearner

TabularLearner(dls, model, loss_func=None, opt_func=Adam, lr=0.001, splitter=trainable_params, cbs=None, metrics=None, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95, 0.85, 0.95)) :: Learner

Learner for tabular data

It works exactly like a normal Learner; the only difference is that it implements a predict method that works on a single row of data.
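To illustrate what predicting on a single row involves, here is a framework-free sketch under simplifying assumptions: categorical columns are encoded by vocabulary index (0 reserved for unknown values), continuous columns are normalized with stored statistics, and a toy model returns one score per class. ToyProcs and predict_row are hypothetical names, not fastai API, the statistics are made-up illustrative numbers, and the sketch does not reproduce the return value of the real predict method.

```python
# Hypothetical sketch: ToyProcs and predict_row are illustrative names, not fastai API.
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple


@dataclass
class ToyProcs:
    vocab: Dict[str, List[str]]             # categorical column -> classes; index 0 = unknown
    stats: Dict[str, Tuple[float, float]]   # continuous column -> (mean, std)

    def encode(self, row: dict) -> Tuple[List[int], List[float]]:
        # Unseen categories fall back to index 0, mimicking a reserved "#na#" slot
        cats = [self.vocab[c].index(row[c]) if row[c] in self.vocab[c] else 0
                for c in self.vocab]
        # Normalize continuous values with the statistics stored at "training" time
        conts = [(row[c] - mean) / std for c, (mean, std) in self.stats.items()]
        return cats, conts


def predict_row(procs: ToyProcs, model: Callable, classes: List[str], row: dict):
    """Encode one raw row like the training data, run the model, return the top label."""
    cats, conts = procs.encode(row)
    scores = model(cats, conts)
    best = max(range(len(scores)), key=scores.__getitem__)
    return classes[best], scores


# Usage: illustrative statistics (not real dataset values) and a toy "model"
# whose score for '>=50k' grows with normalized age
procs = ToyProcs(vocab={'workclass': ['#na#', 'Private', 'State-gov']},
                 stats={'age': (38.0, 13.0)})
toy_model = lambda cats, conts: [-conts[0], conts[0]]
label, scores = predict_row(procs, toy_model, ['<50k', '>=50k'],
                            {'workclass': 'Private', 'age': 51.0})
```

The design point is that a raw row must go through the same encoding that was fitted on the training data before it reaches the model.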

tabular_learner

tabular_learner(dls, layers=None, emb_szs=None, config=None, n_out=None, y_range=None, loss_func=None, opt_func=Adam, lr=0.001, splitter=trainable_params, cbs=None, metrics=None, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95, 0.85, 0.95))

Get a Learner using dls, with metrics, including a TabularModel created using the remaining params.

If your data was built with fastai, you probably won't need to pass anything to emb_szs unless you want to change the library's defaults (produced by get_emb_sz); the same goes for n_out, which should be inferred automatically. layers defaults to [200,100] and is passed to TabularModel along with the config.
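A sketch of the default embedding-size heuristic, re-implemented in plain Python so it runs without fastai. The formula min(600, round(1.6 * n_cat**0.56)) is, as far as we know, the rule of thumb fastai defines as emb_sz_rule and applies through get_emb_sz; default_emb_szs and its overrides argument are hypothetical stand-ins showing how a per-column size, in the same spirit as passing emb_szs yourself, would replace the default.

```python
# Plain-Python sketch of the default embedding-size rule (no fastai import needed).
from typing import Dict, List, Optional, Tuple


def emb_sz_rule(n_cat: int) -> int:
    """Rule of thumb: grow sub-linearly with cardinality, capped at 600."""
    return min(600, round(1.6 * n_cat ** 0.56))


def default_emb_szs(classes: Dict[str, list],
                    overrides: Optional[Dict[str, int]] = None) -> List[Tuple[int, int]]:
    """Hypothetical helper: (cardinality, embedding size) for each categorical column.

    `classes` maps a column name to its list of categories (including any '#na#' slot);
    `overrides` supplies a custom size for selected columns instead of the rule.
    """
    overrides = overrides or {}
    return [(len(cats), overrides.get(name, emb_sz_rule(len(cats))))
            for name, cats in classes.items()]
```

For example, a column with 10 categories gets an embedding of size 6, while very high-cardinality columns are capped at 600.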

Use tabular_config to create a config and customize the model used. y_range gets its own argument for easy access because it is used often.
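When y_range=(low, high) is given, our understanding is that the model's final output is squashed into that interval with a scaled sigmoid (fastai's SigmoidRange layer). A minimal scalar sketch of that transform in plain Python, assuming this formulation; the function name here is ours and operates on floats rather than tensors:

```python
import math


def sigmoid_range(x: float, low: float, high: float) -> float:
    """Map a raw output x into (low, high): a sigmoid scaled and shifted to that interval."""
    return low + (high - low) / (1.0 + math.exp(-x))
```

A raw output of 0 lands in the middle of the range, and large positive or negative outputs approach, but never reach, high or low. Passing, for example, y_range=(0, 5) therefore keeps regression predictions strictly between 0 and 5.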

All the other arguments are passed to Learner.

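from fastai2.tabular.all import *
import pandas as pd

# Download (if needed) and load a small sample of the Adult census dataset
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
# Preprocessing: encode categories, fill missing values, normalize continuous columns
procs = [Categorify, FillMissing, Normalize]
# Rows 800-999 form the validation set; "salary" is the target
dls = TabularDataLoaders.from_df(df, path, procs=procs, cat_names=cat_names, cont_names=cont_names,
                                 y_names="salary", valid_idx=list(range(800,1000)), bs=64)
# Model, loss function and layer sizes are all inferred or defaulted from dls
learn = tabular_learner(dls)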
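The three procs used in this example perform preprocessing that can be sketched roughly in plain Python. These functions are hypothetical simplifications for illustration, not the library's implementations: fill_missing fills gaps with the median (FillMissing's default strategy, as we understand it) and returns a separate missing-value flag, categorify reserves index 0 for '#na#', and normalize uses the population standard deviation. The library computes its statistics during setup and may differ in details such as the standard-deviation estimator.

```python
import statistics


def fill_missing(values):
    """FillMissing-style: replace None with the median of the observed values
    and return an indicator list flagging which entries were missing."""
    observed = [v for v in values if v is not None]
    med = statistics.median(observed)
    return [med if v is None else v for v in values], [v is None for v in values]


def categorify(values):
    """Categorify-style: map each category to an integer, reserving 0 for '#na#'."""
    vocab = ['#na#'] + sorted({v for v in values if v is not None})
    idx = {c: i for i, c in enumerate(vocab)}
    # Missing or unseen values map to the reserved index 0
    return [idx.get(v, 0) for v in values], vocab


def normalize(values):
    """Normalize-style: subtract the mean and divide by the standard deviation."""
    mean, std = statistics.mean(values), statistics.pstdev(values)
    return [(v - mean) / std for v in values]
```

After these steps every categorical column is a column of integer codes ready for an embedding, and every continuous column is numeric with roughly zero mean and unit variance.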