Pytorch로 훈련 이어서하기 (checkpoint)

모델을 훈련시키는데 오랜 시간이 걸리다보면 여러가지 이유로 처음에 생각했던 Epoch만큼 훈련을 시키지 못하고 중간에 끝나는 경우가 있습니다. 누군가 실수로 Ctrl+C를 눌러버린다거나, 모종의 이유로 훈련 프로세스가 죽어버린다거나, 아니면 GPU를 다른 곳에 쓰기 위해서 눈물을 머금고 중간에 멈추는 경우도 있습니다. 이럴 때 유용한 방법이 중간 중간 모델을 저장해두고 나중에 그 시점부터 훈련을 이어서 하는 것입니다. 오늘은 중간중간 모델을 저장하는 방법과 나중에 이어서 훈련을 시작하는 방법을 알아보겠습니다.

import torch
import torch.nn as nn
import torch.optim as optim

import argparse

class MyModel(nn.Module):
    def __init__(self, input_dim):
        super(MyModel, self).__init__()

        self.input_dim = input_dim

        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x):
        return self.linear(x)

def generate_data(input_dim, num_samples, num_batches):
    x = torch.rand((num_samples, input_dim)).reshape(-1, num_batches, input_dim)
    y = torch.randint(2, (num_samples,)).reshape(-1, num_batches).float()

    return x, y

먼저 오늘 사용할 간단한 Model을 하나 정의했습니다. 그리고 가상의 훈련 데이터를 만들어 줄 generate_data()를 만듭니다. y0 또는 1이라고 가정했습니다. generate_data()를 호출하면, (num_samples / num_batches, num_batches, input_dim) 차원을 가지는 x(num_sample / num_batches, num_batches) 차원을 가지는 y가 만들어집니다.

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--output-prefix')
    parser.add_argument('--resume-from')
    parser.add_argument('--input-dim', type=int)

    args = parser.parse_args()

    x, y = generate_data(args.input_dim, 3200, 32)

다음으로 몇가지 명령행 인자를 정의합니다.

  • --output-prefix: 저장할 중간 모델과 최종 모델의 Prefix입니다. 예를 들어 이 값으로 model이라고 줬다면 중간 모델은 매 epoch마다 model.0001.pth와 같이 저장되고, 최종 모델은 model.final.pth와 같이 저장되게 하려고 합니다.
  • --resume-from: 여기에 중간 모델의 경로를 적어주면 이 모델이 저장된 값들을 불러들이고, 거기에서부터 훈련을 시작합니다.
  • --input-dim: 모델의 파라미터입니다.

그리고 generate_data()를 호출해서 훈련 데이터를 준비합니다.

    if args.resume_from:
        # 저장했던 중간 모델 정보를 읽습니다.
        model_data = torch.load(args.resume_from)

        model = MyModel(model_data['input_dim'])
        # 저장했던 모델의 값들로 치환합니다.
        model.load_state_dict(model_data['model_state_dict'])

        optimizer = optim.Adam(model.parameters())
        # optimizer도 중간에 저장했던 값들로 치환합니다.

        optimizer.load_state_dict(model_data['optimizer_state_dict'])

        # 지금 시작할 epoch은 기존 epoch + 1 즉 다음 epoch입니다.
        start_epoch = model_data['epoch'] + 1
    else:
        model = MyModel(args.input_dim)

        optimizer = optim.Adam(model.parameters())

        start_epoch = 1

args.resume_from의 값이 있다는 것은 모델을 처음부터 훈련시키는 것이 아니고 이미 저장된 중간 모델부터 실행을 하고 싶다는 의미입니다. 뒤에 중간 모델을 저장하는 코드를 보면 의미가 더 명확해질 겁니다. 여기에서 가장 중요한 것은 model의 state_dict 뿐 아니라 optimizer의 state_dict도 같이 복원을 해줘야 한다는 점입니다.

args.resume_from이 없으면 처음부터 훈련을 시킵니다.

    loss_fn = nn.BCEWithLogitsLoss()

    EPOCH = 1000

    for epoch in range(start_epoch, EPOCH):
        total_loss = 0

        for x_batch, y_batch in zip(x, y):
            logits = model(x_batch)

            optimizer.zero_grad()

            loss = loss_fn(logits.reshape(-1), y_batch)
            loss.backward()

            optimizer.step()

            total_loss += loss.item()

        print(f'Epoch: {epoch:>4d}\tLoss: {total_loss / len(x):.5f}')

        # 나중에 읽을 수 있도록 필요한 정보를 저장합니다.
        torch.save({
            'epoch': epoch,
            'optimizer_state_dict': optimizer.state_dict(),
            'model_state_dict': model.state_dict(),
            'input_dim': args.input_dim
        }, f'{args.output_prefix}.{epoch:04d}.pth')

일반적인 pytorch 훈련 부분입니다. 매 epoch이 끝날 때 마다 중간 결과를 저장합니다. 앞에 설명했듯이 model의 state_dict 뿐 아니라 optimizer의 state_dict도 저장합니다.

    torch.save({
        'model_state_dict': model.state_dict(),
        'input_dim': args.input_dim
    }, args.output_prefix + '.final.pth')

마지막으로 모든 훈련이 끝나면 최종 모델을 저장합니다. 이때는 optimizer의 state_dict는 필요하지 않기 때문에 저장하지 않습니다.

최초에 훈련을 시작할 때는 아래와 같이 실행합니다.

train.py --input-dim 4 --output-prefix model

그럼 매 epoch마다 model.0001.pth와 같이 중간 정보들이 저장됩니다.

만약 18번째 epoch부터 이어서 다시 훈련을 하고 싶다면 아래와 같이 실행합니다.

train.py --input-dim 4 --output-prefix model --resume-from model.0018.pth

사실 여기에서 --input-dim은 불필요합니다. 재실행시에는 이 값을 저장된 중간 정보에서 가져오기 때문입니다.

이렇게 중간 중간 정보를 저장할 때는 한가지 주의할 점이 있습니다. model과 optimizer의 state_dict를 저장하다보면 중간 파일의 크기가 꽤 커집니다. 파일 하나당 수기가 될 수도 있습니다. 매 epoch 마다 저장한다면 아주 많은 저장 공간을 차지할 수도 있습니다. 이 때문에 훈련 특성에 따라서 매 epoch 마다 저장을 하는 대신에 마지막 epoch만 저장을 할 수도 있습니다.

이제 훈련 중간에 프로세스가 죽는 공포에서 조금이나마 벗어나 보아요~!