본문 바로가기

논문 리뷰/Language Model

Mask-Predict: Parallel Decoding of Conditional Masked Language Models

[Github]

[arXiv](2019/09/04 version v2)

 

 

 

Abstract

Masked token을 병렬로 디코딩하여 텍스트 생성

 

 

Conditional Masked Language Models

X, Yobs가 주어지면 Ymask에 속한 토큰들의 개별 확률을 예측해야 한다.

 

Architecture

Causal mask를 제거한 양방향 transformer.

 

Training Objective

Ymask에 대한 cross-entropy.

 

Predicting Target Sequence Length

전체 시퀀스를 병렬로 예측하기 때문에 AR 모델과 같이 동적으로 시퀀스의 끝을 결정할 수 없다.

BERT의 cls token과 비슷하게 length token을 입력하고 목표 시퀀스의 길이 N을 예측하도록 훈련되었다.

 

 

 

Decoding with Mask-Predict

N이 결정되면 시퀀스 (y1, . . . , yN), 각 토큰 확률 (p1, . . . , pN)을 정의한다.

 

첫 번째 반복에서 모든 토큰을 마스킹하고 예측하며, 확률이 낮은 토큰을 다시 마스킹한다. (노란색)

 

이후 미리 정해진 T 시간 동안 토큰을 예측하고, 토큰 확률을 조정하고, 확률이 낮은 토큰을 다시 마스킹하고를 반복한다.

 

Deciding Target Sequence Length

Length token에 대해 l개의 길이 후보를 선택하고 각 길이에 대해 병렬 디코딩 후 평균 로그 확률이 가장 높은 시퀀스를 선택한다.

 

 

 

Analysis

Model Distillation

기계 번역에 대한 이전 연구들과 마찬가지로 AR transformer로 생성된 번역에 대해 모델을 훈련한다.

 

Why Are Multiple Iterations Necessary?

토큰들을 병렬로 디코딩하기 위해 개별 토큰 예측이 서로 독립이라는 가정을 하는데, 이로 인해 t = 0과 같이 같은 단어가 반복되는 현상이 나타난다. 

 

Multiple iteration은 각 토큰의 다중 모드 분포를 선명한 단일 모드 분포로 축소함으로써 문제를 완화할 수 있다.

 

T가 커질 수록 반복 토큰의 비율이 줄어듦을 볼 수 있다.