본문 바로가기

논문 리뷰/Language Model

Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads

메두사 같은 병렬 헤드를 통한 빠른 생성

 

[Github]

[arXiv](2024/01/19 version v1)

 

 


 

본 논문의 대략적인 맥락, 결과에 대해서는 Medusa Homepage에 잘 설명되어 있다.

 

Homepage

Tianle Cai*, Yuhong Li*, Zhengyang Geng, Hongwu Peng, Tri Dao (* Equal contribution)

sites.google.com

 


참고: Speculative Decoding

 


 

Key Components

Medusa Heads

Original head가 t번째 토큰을 예측할 때, k번째 medusa head는 t+k번째 토큰을 예측하도록 훈련된다.

 

Medusa head는 이전 연구와 똑같이 잔차 연결이 있는 feedforward network로 구현되며,

초기 상태가 원래 모델과 같도록 W1은 original head의 가중치로, W2는 0으로 초기화된다.

 

Tree Attention

다음 그림은 이전 head에서 2개의 토큰 후보를 출력했을 때, 다음 head에서 3개의 토큰 후보를 출력하기 위한 tree attention을 보여준다. k개의 후보에 대한 attention을 동시에 처리하여 속도를 높인다.

 

 

 

Training Strategies

Medusa-1: Frozen Backbone

일반적인 negative log likelihood를 사용하여 medusa head를 훈련한다.

k가 클수록 불확실성이 증가함에 따라 손실이 커지기 때문에 1보다 작은 값의 k-제곱인 λ를 사용하여 스케일링.

 

Medusa-2: Joint Training

정확도를 더욱 향상하기 위해 백본과 함께 훈련할 수 있다. 하지만 원본 모델의 품질을 보존하기 위해 특별한 전략이 필요하다.

 

Combined loss

 

Differential learning rates

백본 모델은 이미 잘 훈련되어 있으므로 낮은 학습률을 사용.

 

Heads warmup

훈련 초기 medusa head의 loss가 크며 이로 인한 큰 gradient가 백본 모델을 압도하여 백본의 피라미터가 왜곡될 수 있다.

 

LP-FT에 따라 Medusa-1 loss를 통해 medusa head를 먼저 훈련한 뒤 전체 모델을 훈련하는 2-stage 전략을 사용한다.

아니면 λ0를 작은 값에서 시작하여 점진적으로 증가시키는 방법을 사용할 수도 있다. 두 방법 모두 잘 작동한다.

더보기

논문의 현재 version-1에는 backbone을 먼저 훈련하라고 적혀 있다. 근데 아무리 생각해도 medusa head를 먼저 훈련하는 것이 맞는 것 같아서 github에 직접 물어봤다.

 

그랬더니 논문이 잘못 표기되어 있으며, 다음 version에서 수정한다고 한다!!!

뭔가 뿌듯했다...

 

https://ostin.tistory.com/399

 

논문 수정에 기여해버렸다 ㅎ

Medusa 논문을 보고 있었는데 HomepageTianle Cai*, Yuhong Li*, Zhengyang Geng, Hongwu Peng, Tri Dao (* Equal contribution)sites.google.com 이 Heads warmup 부분이 아무리 생각해도 이해가 안 되었다. First stage에서 backbone을 훈

ostin.tistory.com

 

Extensions

Typical Acceptance

Speculative Decoding에 따라 생성되는 시퀀스의 길이 결정에 수용-거부 메커니즘을 사용한다.

 

하지만 분포의 일치가 굳이 필요하지 않다는 이후 논문들에 따라 해당 논문에서 사용하는 모델 평가 방법 대신 η-sampling을 통해 거부할 토큰을 결정한다.

 

최소한 1개의 토큰 생성을 보장하기 위해 첫 번째 토큰에는 greedy decoding(무조건 수용)을 적용, 현재 decoding step의 최종 예측은 허용되는 가장 긴 prefix에 의해 결정된다.

 

Self-Distillation

목표 모델의 출력 분포와 일치하는 데이터셋이 없을 수도 있다. 따라서 간단한 Self-Distillation pipeline 제안.

 

먼저 유사한 도메인의 공개 데이터셋을 가져오고 해당 prompt에 대해 모델과 대화를 진행하여 출력 분포의 데이터를 얻을 수 있다. 해당 데이터셋에 대해 증류 손실 최적화:

 

여기서 2가지 문제가 발생한다:

  • 훈련 중에 2개의 모델을 유지해야 하므로 메모리 요구 사항이 증가
  • Medusa head는 이러한 데이터셋으로도 충분하지만, 증류 데이터로 백본을 훈련하면 성능이 저하됨

 

따라서 LoRA를 사용하여 메모리 요구 사항을 줄이고 원본 모델의 성능을 유지한다.

 

Searching for the Optimized Tree Construction

이전 섹션에 설명한 데카르트 곱을 통해 초기 tree를 형성하고 이후 Alpaca-eval 데이터셋에서 측정된 각 medusa head의 top-k 예측의 통계적 기댓값에 기반한 가지치기 알고리즘을 통해 수용 길이 기댓값을 최대화하는 sparse tree를 구축한다.