Tensorflow Callback 사용하기

1. Callback 이란?

모델을 훈련시키는 일은 대체적으로 오랜 시간이 걸립니다. 짧게는 몇분에 끝나기도 하지만 보통은 몇 시간에서 몇 일이 걸리기도 합니다. Tensorflow가 기본적인 Log를 출력해주기는 하지만 훈련이 끝날 때까지 기도하고 있기에는 너무나 불안한 일입니다. 그래서 Tensorflow는 모델을 훈련시키는 동안 어떤 이벤트들이 발생하면 개발자가 원하는 동작을 수행할 수 있는 방법을 제공하고 있습니다. 개발자가 원하는 동작Callback이라고 부릅니다.

Callback 자체는 새로운 개념은 아닙니다. Event-driven 개발 방식에서 쓰이고요. GUI 개발을 해보신 분들은 익숙하실 겁니다. 기본 방식은 이렇습니다.

  • Tensorflow가 미리 이벤트들을 정의합니다. 이벤트는 개발자들이 관심있을 법한 일들입니다. 예를 들어 epoch이 시작한다, 끝났다, batch가 시작한다, 끝난다와 같은 것들입니다.
  • 개발자는 관심있는 이벤트에 대해서 Callback 함수를 만들고, Tensorflow에게 그 이벤트가 발생하면 제공한 Callback 함수를 호출해달라고 합니다. 예를 들어, "매번 epoch이 시작하면 my_callback_for_epoch_begin() 함수를 호출해달라"와 같은 방식입니다.
  • 특정 이벤트가 발생할 때마다 Tensorflow는 개발자가 등록해 둔 Callback 함수들을 호출합니다. 예를 들어, 매 epoch이 시작되면 on_epoch_begin에 등록된 Callback 함수들을 차례로 호출합니다.

여기 제 주소가 있으니 필요할 때 불러주세요. 잘 부탁드립니다. 같은 느낌이랄까요?

2. Callback 구현 방법

Tensorflow는 자주 사용되는 Callback들을 미리 구현해서 제공하고 있습니다. 전체 epoch이 끝나기 전에 성능 향상이 없다면 훈련을 멈춘다거나, Learning Rate을 동적으로 조절한다거나, 중간중간 Prameter를 저장한다거나 할 수 있습니다.

미리 제공되는 기능들을 알아보기 전에 Callback 자체를 어떻게 만드는지 알아보겠습니다. 작동 방식을 알고나면 제공되는 기능도 더 잘 사용할 수 있고, 입맛에 맞게 만들어 쓸 수도 있으니까요.

2.1. tf.keras.callbacks.Callback

새로운 Callback을 만들기 위해서는 tf.keras.callbacks.Callback을 상속해서 새로운 클래스를 만들면 됩니다. 앞에서 각 이벤트에 Callback 함수를 등록한다고 설명했는데, 실제 Tensorflow의 구현은 조금 다릅니다. Callback 클래스에 각 이벤트에 해당하는 메소드들이 정의되어있고, 이 중 필요한 메소드를 오버라이딩합니다. Callback 클래스의 메소드 중 관심있을만한 메소드를 알아보겠습니다.

  • on_epoch_begin(epoch, logs=None): 매번 epoch이 시작될 때 호출됩니다. epoch는 0으로 시작하는 현재 epoch의 index입니다. v2.2.0 기준으로 logs dict는 빈 값입니다.
  • on_epoch_end(epoch, logs=None): 매번 epoch이 끝날 때 호출됩니다. epoch는 0으로 시작하는 현재 epoch의 index입니다. logs['loss']는 현재 epoch 기준 loss가 저장됩니다. 이 외에 model.fit() 인자에 따라서 acc, val_loss 등도 저장됩니다.
  • on_train_batch_begin(batch, logs=None): 매번 훈련 batch가 시작될 때 호출됩니다. batch는 0으로 시작하는 현재 batch의 index입니다.
  • on_train_batch_end(batch, logs=None): 매번 훈련 batch가 끝날 때 호출됩니다. batch는 0으로 시작하는 현재 batch의 index입니다.
  • on_train_begin(logs=None): 훈련이 시작될 때 호출됩니다.
  • on_train_end(logs=None): 훈련이 끝날 때 호출됩니다.

실제로 간단한 Callback 클래스를 하나 만들어 보겠습니다.

import tensorflow as tf

class MyCallback(tf.keras.callbacks.Callback):
    def __init__(self, name):
        super().__init__()
        self.name = name
        self.previous_loss = None

    def on_epoch_begin(self, epoch, logs=None):
        print('\nFrom {}: Epoch {} is starting'.format(self.name, epoch + 1))

    def on_epoch_end(self, epoch, logs=None):
        print('\nFrom {}: Epoch {} ended.'.format(self.name, epoch + 1))        

        if epoch > 0:            
            if (logs['loss'] < self.previous_loss):
                print('From {}: loss got better! {:.4f} -> {:.4f}'.format(self.name, self.previous_loss, logs['loss']))            

        self.previous_loss = logs['loss']                    

    def on_train_batch_begin(self, batch, logs=None):
        print('\nFrom {}: Batch {} is starting.'.format(self.name, batch + 1))

    def on_train_batch_end(self, batch, logs=None):
        print('\nFrom {}: Batch {} ended'.format(self.name, batch + 1))

first_callback = MyCallback('1st callback')
  • 일반적인 Python 클래스이기 때문에 생성자에 값을 원하는대로 줄 수 있습니다. 이 예시에서는 각 Instance를 구분하기 위해서 이름을 전달했습니다.
  • on_epoch_begin은 매 epoch이 시작될 때 호출됩니다. 여기에서는 단순히 Callback의 이름과 몇번째 epoch인지를 출력하고 있습니다.
  • on_epoch_end는 매 epoch이 끝났을 때 호출됩니다. on_epoch_begin과 하는 일이 유사한데, 추가로 loss를 추적하면서 현재 epoch이 이전 epoch보다 loss가 좋아졌는지 출력합니다.
  • on_train_batch_beginon_train_batch_end는 각 batch의 시작과 끝에 호출되고, 하는 일은 Callback의 이름과 현재 batch가 몇번째 인지 출력합니다.
  • 마지막으로 "1st callback"이라는 이름으로 Instance를 하나 만듭니다.

2.2. Callback 등록하기

간단한 모델을 하나 만들고 위에서 만든 Callback을 등록해보겠습니다. Tensorflow에서 제공하는 Iris Dataset을 사용하겠습니다. tensorflow_datasets 에 대한 설명은 다음 기회에 :)

import tensorflow_datasets as tfds

dataset = tfds.load('iris', as_supervised=True)['train']
dataset = dataset.batch(16)

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(64, input_shape=(4,)))
model.add(tf.keras.layers.Dense(16))
model.add(tf.keras.layers.Dense(3, activation='softmax'))

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
model.fit(dataset, epochs=10, callbacks=[first_callback, MyCallback('2nd callback')])     

등록하려고 하는 Callback들은 model.fitcallbacks 인자로 전달합니다. callbacks은 list이기 때문에 여러개의 Callback을 전달할 수 있습니다. 이 예시에서는 앞에서 만든 '1st callback'과 호출 시에 바로 만든 '2nd callback'을 전달합니다. 이벤트가 발생하면 이 list에 들어있는 순서대로 Callback이 호출됩니다.

Callback은 model.fit 뿐 아니라 model.evaluate, model.predict 등에도 쓸 수 있습니다. 이와 관련된 이벤트들(on_predict_batch_begin, on_predict_batch_end, on_test_batch_begin, on_test_batch_end)도 있고요.

위 파일을 실행해보시면 기존 Tensorflow의 출력 앞뒤로 MyCallback의 결과들이 출력되는 것을 보실 수 있습니다.

Callback을 이용하면 재미있는 일들을 할 수 있습니다. 매 epoch이 끝날 때마다 결과를 이메일이나 메신저로 보낼 수도 있고요. learning rate을 원격으로 제어할 수도 있습니다.

3. Keras가 제공하는 Callback 들

3.1. EarlyStopping

몇 epoch이나 훈련을 시킬지는 중요한 hyper paramer 중 하나입니다. 너무 과하면 overfitting의 위험도 있고 시간도 필요이상으로 걸리고요. 한가지 해결 방법은 metric을 지켜보다가 더이상 성능이 나아지지 않는다 싶으면 훈련을 먼저 끝내는 것입니다. 이런 일이 매우 빈번하기 때문에 Tensorflow는 tf.keras.callbacks.EarlyStopping을 제공하고 있습니다.

callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)

위 예시에서는 loss가 3번 동안 나아지지 않으면 훈련을 종료합니다. val_lossval_accuracy를 monitor하고 있다면 loss대신 val_lossval_accuracy를 사용할 수 있습니다.

3.2. LambdaCallback

Callback을 새롭게 만들어 쓰고는 싶은데 전체 클래스를 새로 만들기에는 왠지 부담스러울 때가 있습니다. 이 때 Callback을 함수 형태로 전달할 수 있게 해주는 것이 tf.keras.callbacks.LambdaCallback입니다.

callback = tf.keras.callbacks.LambdaCallback(on_epoch_begin=lambda epoch, logs: print('We are starting epoch {}'.format(epoch +1 )))
model.fit(dataset, epochs=10, callbacks=[callback]) 

위와 같이 하면 간단하게 Callback을 만들 수 있습니다. LambdaCallback의 인자는 https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/LambdaCallback#arguments_1에서 확인하실 수 있습니다.

3.3. LearningRateScheduler

Leraning rate 역시 중요한 Hyper parameter인데요. Learning rate을 고정값으로 할 수도 있지만, 시간이 지남에 따라 변하게 하거나 동적으로 결정할 수도 있습니다. 이 때 사용할 수 있는 Callback이 tf.keras.callbacks.LearningRateScheduler입니다. 아래 예제는 Tensorflow 홈페이지에서 가져왔습니다.

def scheduler(epoch):
    if epoch < 10:
        return 0.001
    else:
       return 0.001 * tf.math.exp(0.1 * (10 - epoch))
callback = tf.keras.callbacks.LearningRateScheduler(scheduler)

model.fit(dataset, epochs=100, callbacks=[callback])  

위와 같이 하면 10번째 epoch까지는 0.001을 사용하지만 (epoch은 0부터 시작합니다), 그 뒤부터는 epoch에 따라 결정되는 값을 사용합니다.

참고로 LearningRateSchedule의 실제 구현은 아래와 같이 비교적 간단합니다. (주석 등은 제거했습니다)

class LearningRateScheduler(Callback):
  def __init__(self, schedule, verbose=0):
    super(LearningRateScheduler, self).__init__()
    self.schedule = schedule
    self.verbose = verbose

  def on_epoch_begin(self, epoch, logs=None):
    if not hasattr(self.model.optimizer, 'lr'):
      raise ValueError('Optimizer must have a "lr" attribute.')
    try:  # new API
      lr = float(K.get_value(self.model.optimizer.lr))
      lr = self.schedule(epoch, lr)
    except TypeError:  # Support for old API for backward compatibility
      lr = self.schedule(epoch)
    if not isinstance(lr, (ops.Tensor, float, np.float32, np.float64)):
      raise ValueError('The output of the "schedule" function '
                       'should be float.')
    if isinstance(lr, ops.Tensor) and not lr.dtype.is_floating:
      raise ValueError('The dtype of Tensor should be float')
    K.set_value(self.model.optimizer.lr, K.get_value(lr))
    if self.verbose > 0:
      print('\nEpoch %05d: LearningRateScheduler reducing learning '
            'rate to %s.' % (epoch + 1, lr))

코드의 대부분은 Error 확인을 위한 부분이고, 실제로 Learning rate을 update하는 부분은 아래 정도입니다.

lr = self.schedule(epoch)
K.set_value(self.model.optimizer.lr, K.get_value(lr))

3.4. ModelCheckpoint

훈련 중간 중간 현재 Parameter의 값들을 저장하고 싶을 때도 있습니다. 혹시나 무엇인가가 잘못됐을 때 중간부터 시작하기 위해서일 수도 있고, valdiation 결과가 최적일 때의 값을 사용하고 싶을 수도 있고요. 이런 목적으로 Tensorflow는 tf.keras.callbacks.ModelCheckpoint를 제공합니다.

callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='./model/checkpoint.{epoch:02d}.hdf5',
    save_weights_only=True,    
    save_freq='epoch')

위와 같이 Callback을 만들면 매 epoch마다 결과가 ./model/checkpoint.01.hdf5와 같이 저장됩니다. 다른 인자를 이용하면 지금까지 결과 중 가장 성능이 좋은 결과만을 저장한다거나, 결과를 저장하는 주기를 조절할 수도 있습니다.

4. 마무리

Callback은 Tensorflow를 처음 배우기 시작할 때는 필요성이 크게 와닿지 않지만 실제 모델링을 하다보면 필수적인 요소입니다. Tensorflow가 이미 다양한 Callback을 제공하기도 하지만, 필요하다면 각자 입맛에 맞는 Callback을 만들어서 사용하는 것도 좋은 방법입니다.

model.fit을 쓰지 않고 Custom Training을 하신다면 Callback 프레임워크를 사용하지 않고, 직접 훈련 루프에 기능을 넣어도 됩니다.