본문 바로가기

논문 리뷰/Language Model

Fast Inference from Transformers via Speculative Decoding

[arXiv](2023/05/18 version v2)

 

 

Abstract

여러 개의 토큰을 병렬로 계산하여 더 빠르게 샘플링하는 Speculative Decoding 제안

 

 

 

Speculative Decoding

  1. 효율적인 모델 Mq가 토큰 시퀀스를 생성하고
  2. 목표 모델 Mp가 해당 시퀀스를 평가하여
  3. 토큰을 수용하거나 거부하고, 대안을 생성한다.

각 라인은 한 번의 decoding step이다.

 

Standardized Sampling

Argmax, top-k, nucleus, temperature 등 다양한 샘플링 설정이 있지만 본문에서는 생략하고 일반적인 경우만 가정.

 

Speculative Sampling

준비물: 각 모델, 토큰 시퀀스

 

γ개의 예측 생성

 

Mp를 병렬로 실행하여 γ개의 예측을 각각 생성

 

q(x)가 p(x) 보다 클 때, 일정 확률로 거부

 

분포를 조정한 후 조정된 분포 p'(x)에서 토큰을 샘플링하고 decoding step 종료.

 

 

과정을 보면 이게 빠른가? 싶지만 K개의 토큰을 생성할 때, K번의 큰 모델을 호출하지 않고 작은 모델을 K번 호출, 큰 모델은 한 번만 호출하여 병렬로 평가를 수행할 수 있기 때문에 빠르다.