
Part2 lesson 10, 5ab, closure&decorator, early stopping

Aug 25, 2021 · 7 mins read


Callbacks

: functions or objects that manage a fastai model's behavior during training

📝 Q1. Roughly explain the concept of a callback

Here we will create various types of callbacks: as a plain function, a lambda, a closure, and a class.

👩‍💻 Q2. Assume we want to print out the progress of each iteration in a pre-defined format. slow_calculation reports its progress, and cb handles that information.

from time import sleep

def slow_calculation(cb=None):
    res = 0
    for i in range(5):
        res += i*i
        sleep(1)
        if cb: cb(i)
    return res
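
For reference, calling slow_calculation() with no callback just returns the sum of squares 0 + 1 + 4 + 9 + 16 = 30 after about five seconds; the callback only adds progress reporting on top of that:

print(slow_calculation())  # 30, after ~5 seconds and with no progress output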

(a) Implement the callback as a basic function
(b) Implement the callback as a lambda

👩‍💻 Q3. Now we want to add one more argument to the callback function, with the new argument placed before the existing one.

(c) Implement the callback using a lambda (d) Implement the callback using a closure (e) Implement the callback using partial

hint: in this case, you can implement the callback as a closure

extra questions

👩‍💻 Q4. Make a test callback that cancels training when it meets the following conditions

(a) Cancel the training session when you've finished the 10th iteration of training. (b) Cancel the training session when you've finished the 2nd epoch. (c) Cancel the training session when you've finished the 3rd batch of inference on the 2nd epoch.

hint: one iteration processes one batch; all the batches make up one training pass; one training pass plus one inference (validation) pass makes one epoch.
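
A rough sketch of what the hint describes (purely illustrative names and placeholder data, not the actual Runner code from the course):

epochs, train_dl, valid_dl = 2, [0, 1, 2], [0, 1]  # placeholder "data loaders"
for epoch in range(epochs):                        # one epoch
    for batch in train_dl:                         # training pass: one iteration per batch
        pass                                       # forward, loss, backward, optimizer step
    for batch in valid_dl:                         # inference (validation) pass
        pass                                       # forward and loss only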

👩‍💻 Q5. Make a callback class which stops training when the model's current loss exceeds 10 times its best loss.

(a) The learning rate ranges from 1e-06 to 10. You can test the loss on a batch and then change the learning rate. Confine the maximum number of iterations to 150. (b) Plot the loss against the learning rate.


A1. A callback is a function (or callable object) that you pass into another routine; that routine calls it back when a specific action (i.e. an event) occurs, so your code can handle the response.
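
Callbacks show up in plain Python too; for example, the key argument of the built-in sorted is a callback that sorted calls once per element:

words = ["callback", "lambda", "closure"]
print(sorted(words, key=len))  # len is the callback -> ['lambda', 'closure', 'callback']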

A2.

# a: pass a named function as the callback
def show_progress(epoch):
    print(f"Awesome! We've finished epoch {epoch}!")

slow_calculation(show_progress)

# b: pass a lambda directly as the callback
slow_calculation(lambda o: print(f"Awesome! We've finished epoch {o}!"))

A3.

# c: a lambda that fixes the extra argument
def show_progress(exclamation, epoch):
    print(f"{exclamation}! We've finished epoch {epoch}!")

slow_calculation(lambda o: show_progress("OK I guess", o))

# d: a closure that captures `exclamation`
def make_show_progress(exclamation):
    _inner = lambda epoch: print(f"{exclamation}! We've finished epoch {epoch}!")
    return _inner

slow_calculation(make_show_progress("Nice!"))

# the same closure, written with an inner def instead of a lambda
def make_show_progress(exclamation):
    # Leading "_" is generally understood to mean "private"
    def _inner(epoch): print(f"{exclamation}! We've finished epoch {epoch}!")
    return _inner

slow_calculation(make_show_progress("Nice!"))

# e: functools.partial pre-binds the first argument
from functools import partial

def show_progress(exclamation, epoch):
    print(f"{exclamation}! We've finished epoch {epoch}!")

slow_calculation(partial(show_progress, "OK I guess"))
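
The post title also mentions decorators. A decorator is just a closure that takes a function and returns a wrapped version of it, so the same behavior can be sketched that way too (an extra illustration, not part of the course notebook; add_exclamation is a made-up name):

from functools import wraps

def add_exclamation(exclamation):
    # decorator factory: returns a decorator that prepends `exclamation` to the message
    def decorator(f):
        @wraps(f)
        def _inner(epoch):
            print(f"{exclamation}! ", end="")
            return f(epoch)
        return _inner
    return decorator

@add_exclamation("Nice")
def show_progress(epoch):
    print(f"We've finished epoch {epoch}!")

slow_calculation(show_progress)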

A4.


# These assume the course-style Runner/Callback setup, where CancelTrainException
# stops training and callbacks can read self.n_iter, self.epoch and self.in_train.

class EarlyStopIterCallback(Callback):
    _order = 1  # run after TrainEvalCallback, so n_iter has already been updated
    def after_batch(self):
        # (a) cancel once the 10th training iteration has finished
        if not self.in_train: return
        if self.n_iter >= 10: raise CancelTrainException()

class EarlyStopEpochCallback(Callback):
    def after_epoch(self):
        # (b) cancel once the 2nd epoch has finished (self.epoch is 0-indexed)
        if self.epoch >= 1: raise CancelTrainException()

class EarlyStopCallback(Callback):
    """(c) cancel after the 3rd batch of inference (validation) on the 2nd epoch"""
    def begin_validate(self):
        self.n_valid_batches = 0  # count validation batches ourselves

    def after_batch(self):
        if self.in_train: return  # only look at inference batches
        self.n_valid_batches += 1
        if self.epoch >= 1 and self.n_valid_batches >= 3:
            raise CancelTrainException()
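
As a quick usage sketch (assuming the course-style Runner and the learner built in A5 below), any of these can be passed to the Runner like the other callbacks:

learn = create_learner(get_model, loss_func, data)
run = Runner(cb_funcs=[EarlyStopIterCallback])
run.fit(3, learn)  # should stop early, after the 10th training iteration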

A5.


import matplotlib.pyplot as plt

class Recorder(Callback):
    def begin_fit(self):
        # one learning-rate list per parameter group (e.g. per layer group)
        self.lrs = [[] for _ in self.opt.param_groups]
        self.losses = []

    def after_batch(self):
        if not self.in_train: return
        for pg, lr in zip(self.opt.param_groups, self.lrs):
            lr.append(pg['lr'])
        self.losses.append(self.loss.detach().cpu())  # record the training loss as well

    def plot_lr(self, pgid=-1):
        plt.plot(self.lrs[pgid])

    def plot_loss(self, skip_last=0):
        n = len(self.losses) - skip_last
        plt.plot(self.losses[:n])

    def plot(self, skip_last=0, pgid=-1):
        losses = [tmp.item() for tmp in self.losses]
        lrs = self.lrs[pgid]
        n = len(losses) - skip_last
        plt.xscale('log')
        plt.plot(lrs[:n], losses[:n])

class LR_Find(Callback):
    _order = 1
    def __init__(self, max_iter = 150, min_lr = 1e-06, max_lr = 10):
        self.max_iter, self.min_lr, self.max_lr = max_iter, min_lr, max_lr
        self.best_loss = 1e10

    def begin_batch(self):
        if not self.in_train: return  # only adjust the lr during training, not inference
        pos = self.n_iter / self.max_iter  # how far we are through max_iter, from 0 to 1
        # exponentially interpolate the lr between min_lr and max_lr
        lr = self.min_lr * (self.max_lr / self.min_lr) ** pos
        for pg in self.opt.param_groups:
            pg['lr'] = lr

    def after_step(self):
        if self.n_iter > self.max_iter or self.loss > self.best_loss * 10:
            raise CancelTrainException()

        if self.loss < self.best_loss: self.best_loss = self.loss
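
Note that the schedule is exponential rather than linear: with min_lr = 1e-06 and max_lr = 10, halfway through (pos = 0.5) the learning rate is 1e-06 * (10 / 1e-06) ** 0.5 ≈ 3.2e-3, so the sweep spends the same number of iterations on each order of magnitude. A standalone sanity check of the formula:

min_lr, max_lr, max_iter = 1e-06, 10, 150
for n_iter in (0, 75, 150):
    pos = n_iter / max_iter
    print(n_iter, min_lr * (max_lr / min_lr) ** pos)  # 1e-06, ~0.00316, 10.0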

learn = create_learner(get_model, loss_func, data)
run = Runner(cb_funcs=[LR_Find, Recorder])  # the Runner instantiates each callback class for us
run.fit(2, learn)
# each callback is exposed on the runner under its snake_case name, hence run.recorder
run.recorder.plot(skip_last=5)