[Github]
[arXiv](2019/09/04 version v2)
Abstract
Masked token을 병렬로 디코딩하여 텍스트 생성
Conditional Masked Language Models
X, Yobs가 주어지면 Ymask에 속한 토큰들의 개별 확률을 예측해야 한다.
Architecture
Causal mask를 제거한 양방향 transformer.
Training Objective
Ymask에 대한 cross-entropy.
Predicting Target Sequence Length
전체 시퀀스를 병렬로 예측하기 때문에 AR 모델과 같이 동적으로 시퀀스의 끝을 결정할 수 없다.
BERT의 cls token과 비슷하게 length token을 입력하고 목표 시퀀스의 길이 N을 예측하도록 훈련되었다.
Decoding with Mask-Predict
N이 결정되면 시퀀스 (y1, . . . , yN), 각 토큰 확률 (p1, . . . , pN)을 정의한다.
첫 번째 반복에서 모든 토큰을 마스킹하고 예측하며, 확률이 낮은 토큰을 다시 마스킹한다. (노란색)
이후 미리 정해진 T 시간 동안 토큰을 예측하고, 토큰 확률을 조정하고, 확률이 낮은 토큰을 다시 마스킹하고를 반복한다.
Deciding Target Sequence Length
Length token에 대해 l개의 길이 후보를 선택하고 각 길이에 대해 병렬 디코딩 후 평균 로그 확률이 가장 높은 시퀀스를 선택한다.
Analysis
Model Distillation
기계 번역에 대한 이전 연구들과 마찬가지로 AR transformer로 생성된 번역에 대해 모델을 훈련한다.
Why Are Multiple Iterations Necessary?
토큰들을 병렬로 디코딩하기 위해 개별 토큰 예측이 서로 독립이라는 가정을 하는데, 이로 인해 t = 0과 같이 같은 단어가 반복되는 현상이 나타난다.
Multiple iteration은 각 토큰의 다중 모드 분포를 선명한 단일 모드 분포로 축소함으로써 문제를 완화할 수 있다.
T가 커질 수록 반복 토큰의 비율이 줄어듦을 볼 수 있다.