FastAI’s callbacks for better CNN training — meet SaveModelCallback.

FastAI has a very flexible callback system that lets you greatly customize your training process. However, some of the most useful pre-built callbacks are hard to find without a deep dive into the documentation and, to my knowledge, aren’t covered in the regular courses.

A common question is thus: “How do I automatically save my best model if it shows up in the middle of a training run?” The answer is SaveModelCallback.


As the name implies, this callback automatically saves the model whenever it reaches a new best score on the monitored metric during training. At the end of training it then loads that best model, so it’s ready for you to continue with.
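To make the mechanics concrete, here’s a rough, framework-free sketch of the default behaviour for a lower-is-better metric: at the end of each epoch, compare the monitored value with the best so far, save on improvement, and reload the best checkpoint when training ends. `FakeLearner`, `SaveBestSketch` and the error rates are made up for illustration; this is not fastai’s implementation, just the same idea in plain Python.

```python
import operator


class FakeLearner:
    """Hypothetical stand-in for a Learner: 'saving' stores the current weights in a dict."""

    def __init__(self):
        self.weights = None
        self.saved = {}

    def save(self, name):
        self.saved[name] = self.weights

    def load(self, name):
        self.weights = self.saved[name]


class SaveBestSketch:
    """Simplified save-on-improvement logic, minimising the monitored metric."""

    def __init__(self, learn, monitor="error_rate", name="bestmodel"):
        self.learn, self.monitor, self.name = learn, monitor, name
        self.operator = operator.lt  # 'min' mode: lower is better
        self.best = float("inf")     # worst possible starting value for 'min'

    def on_epoch_end(self, epoch, metrics):
        current = metrics[self.monitor]
        if self.operator(current, self.best):
            print(f"Better model found at epoch {epoch} with {self.monitor} value: {current}.")
            self.best = current
            self.learn.save(self.name)  # overwrite the single 'best' checkpoint

    def on_train_end(self):
        if self.name in self.learn.saved:
            self.learn.load(self.name)  # leave the best weights in the learner


learn = FakeLearner()
cb = SaveBestSketch(learn)
history = [0.12, 0.09, 0.10, 0.08, 0.11]  # made-up error rates, one per epoch
for epoch, err in enumerate(history):
    learn.weights = f"weights after epoch {epoch}"
    cb.on_epoch_end(epoch, {"error_rate": err})
cb.on_train_end()
print(learn.weights)  # → weights after epoch 3
```

Even though the last epoch ends with worse weights, the learner finishes holding the epoch-3 weights, which is exactly the convenience the callback provides.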

You can specify what criterion defines the ‘best model’, but per a forum post by Jeremy Howard, the rule of thumb is to use ‘error_rate’ as the criterion.

To access the callbacks, use the following:

from fastai.callbacks import *

To use it, add it to your learner’s list of callbacks. This can be done at fit time or when the learner is created:

learn.fit_one_cycle(4, max_lr=4e-3,
                    callbacks=[SaveModelCallback(learn,
                                                 monitor='error_rate',
                                                 mode='min',
                                                 name='marker-625AM-80')])

There are a few main parameters to adjust:

1 — monitor='error_rate': error rate is the preferred option, but you can also monitor any other metric your learner reports, or a loss such as 'valid_loss' (the default).

2 — mode = 'auto' | 'min' | 'max': sets whether a lower ('min') or higher ('max') value of the monitored metric counts as better. 'auto' is the default and tries to infer the right direction on its own. I prefer to set 'min' or 'max' explicitly to be safe.

3 — name: the file name the best model is saved under. The default is simply 'bestmodel', but I like to add details such as the image resolution or the time of the training run to make the name more informative.

During the run, the callback prints a line each time it saves a new model, along with the monitored score.


4 — every: controls when a model is saved. The default, 'improvement', saves at the end of an epoch only if the metric reaches a new best value, and I usually leave it there. The other option, 'epoch', saves a separately tagged model at the end of every epoch regardless of the metric (as name_epoch, e.g. 'bestmodel_3' for epoch 3).
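The difference between the two `every` settings is easiest to see by listing which checkpoint names get written. `simulate_saves` below is a hypothetical plain-Python illustration that assumes a lower-is-better metric; fastai itself writes these checkpoints as .pth files under the learner’s models directory.

```python
def simulate_saves(history, every="improvement", name="bestmodel"):
    """Hypothetical simulation of SaveModel-style checkpointing.

    Returns (names written, in order; name reloaded at the end or None).
    Assumes a lower-is-better metric such as error_rate (mode='min').
    """
    writes, best = [], float("inf")
    for epoch, value in enumerate(history):
        if every == "epoch":
            writes.append(f"{name}_{epoch}")  # one tagged file per epoch, metric ignored
        elif value < best:
            best = value
            writes.append(name)  # the same file is overwritten on each improvement
    # Only the 'improvement' mode reloads a best model when training ends.
    reloaded = name if every == "improvement" and writes else None
    return writes, reloaded


history = [0.12, 0.09, 0.10, 0.08]  # made-up metric values
print(simulate_saves(history, every="epoch"))
print(simulate_saves(history, every="improvement"))
```

With 'epoch' you end up with four separate files and nothing reloaded; with 'improvement' a single 'bestmodel' file is rewritten three times and then reloaded.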

Note that the best model is only auto-loaded at the end of training under the default every='improvement'.
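On the mode parameter: as far as I can tell from fastai v1’s tracker.py, 'auto' picks 'min' only when the monitored name contains 'loss' and 'max' otherwise, so with monitor='error_rate' it would treat a higher error rate as better. That’s a good reason to hard-set the mode. `resolve_mode` and `initial_best` below are a hypothetical re-creation of that rule, not the library’s code, so check the source of your installed version before relying on it.

```python
import operator


def resolve_mode(monitor: str, mode: str = "auto"):
    """Hypothetical re-creation of how a tracker picks its comparison operator."""
    if mode not in ("auto", "min", "max"):
        mode = "auto"  # in this sketch, invalid values fall back to auto
    if mode == "auto":
        # Name-based heuristic: anything containing 'loss' is minimised, the rest maximised.
        mode = "min" if "loss" in monitor else "max"
    return operator.lt if mode == "min" else operator.gt


def initial_best(op):
    """Start from the worst possible value for the chosen direction."""
    return float("inf") if op is operator.lt else float("-inf")


print(resolve_mode("valid_loss") is operator.lt)         # True: losses are minimised
print(resolve_mode("error_rate") is operator.gt)         # True: auto guesses 'max', wrong for an error rate
print(resolve_mode("error_rate", "min") is operator.lt)  # True: an explicit mode wins
```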

You can also attach callbacks to the learner as a group, which may be easier than passing them to each fit call:

callbacks = [
    # dice is a higher-is-better metric, so the best model is the one with the max value
    SaveModelCallback(learn, monitor='dice', mode='max', name='res50-uf1'),
    ShowGraph(learn),
    EarlyStoppingCallback(learn, min_delta=1e-5, patience=3),
]

learn.callbacks = callbacks
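The EarlyStoppingCallback in that list stops training once the monitored value stops improving: min_delta is the smallest change that counts as an improvement, and patience is how many epochs without one to tolerate. `early_stop_epoch` below is a hypothetical plain-Python sketch of that rule for a lower-is-better metric; fastai’s exact bookkeeping (for example, the off-by-one on patience) may differ.

```python
def early_stop_epoch(history, min_delta=1e-5, patience=3):
    """Hypothetical sketch of patience-based early stopping (lower is better).

    Returns the epoch at which training would stop, or None if it runs to the end.
    """
    best, wait = float("inf"), 0
    for epoch, value in enumerate(history):
        if value < best - min_delta:  # must beat the best by at least min_delta
            best, wait = value, 0
        else:
            wait += 1
            if wait > patience:  # more than `patience` epochs in a row without improvement
                return epoch
    return None


losses = [0.50, 0.40, 0.41, 0.42, 0.40, 0.43, 0.39]  # made-up validation losses
print(early_stop_epoch(losses))               # 5
print(early_stop_epoch(losses, patience=10))  # None
```

Note that the repeated 0.40 at epoch 4 doesn’t reset the counter, because it doesn’t beat the previous best by min_delta.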

Customizing SaveModelCallback: after using SaveModelCallback extensively, I’ve modified it to take a threshold below which it doesn’t save at all (e.g. min_to_save = .70), and to save each new best model as filename + “_metric_score” so that I can browse my models and tell at a glance what is what. This produces several saved models rather than one, but gives some flexibility. I may make it delete the previous best model in the future, but I like having the score in the model name to keep things clear.
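The modified callback itself isn’t shown here, so the following is only a hypothetical sketch of the two changes described: a `min_to_save` threshold below which nothing is saved, and a file name that carries the metric score. It’s plain Python, assumes a higher-is-better metric such as dice (a threshold of .70 only makes sense that way), and the `save_fn` hook is an assumed stand-in for something like `learn.save`. The dice values are made up.

```python
from typing import Callable, List, Optional


class ThresholdSaver:
    """Hypothetical save-on-improvement variant with a minimum score and score-tagged names.

    Assumes a higher-is-better metric (mode='max'), e.g. dice or accuracy.
    """

    def __init__(self, name: str, min_to_save: float, save_fn: Callable[[str], None]):
        self.name, self.min_to_save, self.save_fn = name, min_to_save, save_fn
        self.best = float("-inf")

    def on_epoch_end(self, value: float) -> Optional[str]:
        # Skip until the metric clears the threshold, then save only on a new best.
        if value < self.min_to_save or value <= self.best:
            return None
        self.best = value
        fname = f"{self.name}_{value:.4f}"  # e.g. res50-uf1_0.7312
        self.save_fn(fname)  # each improvement gets its own file; older bests are kept
        return fname


saved: List[str] = []
saver = ThresholdSaver("res50-uf1", min_to_save=0.70, save_fn=saved.append)
for dice in [0.55, 0.68, 0.71, 0.70, 0.74]:
    saver.on_epoch_end(dice)
print(saved)  # ['res50-uf1_0.7100', 'res50-uf1_0.7400']
```

Because every improvement above the threshold gets its own file, this trades disk space for a browsable history, which is the trade-off described above.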

Summary: if you aren’t already using SaveModelCallback, I think you’ll find it invaluable for automating your CNN training and speeding up your workflow, since the best model is saved automatically and already loaded at the end of each training cycle.

The code for it is in /fastai/callbacks/tracker.py if you want to customize it, and here’s a link to the documentation for it since it’s relatively buried otherwise:

https://docs.fast.ai/callbacks.tracker.html

