Deep Learning

Tensorflow의 Embedding Layer vs fastText

둔진 2020. 4. 30. 07:21

  지난 번 한국어 토큰 테스트(한국어 토큰의 단위는 뭐가 좋을까?)를 해보고나니, Tensorflow(Keras)의 Embedding은 어떨지 궁금했습니다. 그래서 지난번 테스트 중 가장 성능이 좋았던 "형태소+글자" 단위 토큰으로 비교를 해보았습니다.

 

  Embedding dimension은 fastText의 기본값 100으로 그대로 하고요. OOV를 위한 토큰을 추가합니다.

import tensorflow as tf
import numpy as np
from tqdm import tqdm
from konlpy.tag import Mecab

OOV_TOKEN = '<UNK>'
EMBEDDING_DIM = 100

  Toknizer와 코퍼스 읽는 건 앞에서 그대로 가져오고요.

tagger = Mecab()

def tokenize_by_morpheme_char(s):
    return tagger.morphs(s)

def parse_corpus_line(line):
    _, sentence, label = line.strip().split('\t')
    label = int(label)
    return sentence, label

def read_corpus(path):
    sentences = []
    labels = []

    for line in open(path):
        sentence, label = parse_corpus_line(line)
        sentences.append(sentence)
        labels.append(label)

    return np.array(sentences), np.array(labels)

  문자열을 index로 변환하는 것은  tf.keras.preprocessing.text.Tokenizer 를 써도 되지만 간단히 직접 구현했습니다.

def build_index(sentences):
    token_to_index = {}
    index_to_token = {}
    
    index = 0
    for tokens in sentences:
        for token in tokens:
            if token in token_to_index:
                continue
            token_to_index[token] = index
            index_to_token[index] = token
            index += 1
    
    token_to_index[OOV_TOKEN] = index
    index_to_token[index] = OOV_TOKEN
    
    return token_to_index, index_to_token
        
def convert_to_index(tokens, token_to_index):
    result = []
    for token in tokens:
        if token in token_to_index:
            result.append(token_to_index[token])
        else:
            result.append(token_to_index[OOV_TOKEN])
            
    return np.array(result)

 

  기존 실험과 최대한 똑같이 모델을 구성해주고 돌립니다.

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()               
    parser.add_argument('train')
    parser.add_argument('test')    
    parser.add_argument('--batch-size', type=int, default=128)   
    parser.add_argument('--epochs', type=int, default=10)    
    args = parser.parse_args()
   
    sentences, labels = read_corpus(args.train)    
    tokenized_sentences = [tokenize_by_morpheme_char(sentence) for sentence in tqdm(sentences, desc='Tokenizing')]    
    
    token_to_index, index_to_token = build_index(tokenized_sentences)
    
    train_data = [convert_to_index(s, token_to_index) for s in tokenized_sentences]
    train_data = tf.keras.preprocessing.sequence.pad_sequences(train_data, padding='post')
    train_data = tf.data.Dataset.from_tensor_slices((train_data, labels)).batch(args.batch_size)    
    
    test_sentences, test_labels = read_corpus(args.test)
    tokenized_sentences = [tokenize_by_morpheme_char(sentence) for sentence in tqdm(test_sentences, desc='Tokenizing')]    
    test_data = [convert_to_index(s, token_to_index) for s in tokenized_sentences]    
    test_data = tf.keras.preprocessing.sequence.pad_sequences(test_data, padding='post')
    test_data = tf.data.Dataset.from_tensor_slices((test_data, test_labels)).batch(args.batch_size)
    
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Embedding(len(token_to_index), EMBEDDING_DIM))
    model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(128, return_sequences=True)))
    model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(128)))
    model.add(tf.keras.layers.Dense(64))
    model.add(tf.keras.layers.Dense(1, activation='sigmoid'))    
     
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])    
    model.fit(train_data, epochs=args.epochs)
    
    loss, accuracy = model.evaluate(test_data)
    print(loss)
    print(accuracy)

성능

  Accuracy
fastText 87.07%
Keras Embedding 84.05%

  생각보다 차이가 나는군요.

 

  성능 차이는 아무래도,

  • fastText는 네이버 훈련 데이터 외에 한국어 Wikipedia 데이터도 사용했기 때문에 훨씬 커버리지가 큽니다.
  • fastText의 알고리즘 특성상 OOV에 대해서 성능이 더 좋습니다. 네이버 코퍼스가 구어체가 많기 때문에 이 점도 영향을 주었을 것 같고요.

  예전 식으로 이만 총총이라고 쓰려다가 궁금해서 찾았봤더니 정말 "총총"이 맞는 말이군요;; https://ko.dict.naver.com/#/entry/koko/255f2166b1ba4ad4aa0ccbf4556fd5fb

  그럼 이만 총총.