본문 바로가기

논문 리뷰/Language Model

RL with KL penalties is better viewed as Bayesian inference

LM fine-tuning을 강화 학습으로 보는 관점과 베이즈 추론으로 보는 관점 비교

 

[arXiv](Current version v2)

 

Introduction

Reinforcement learning(RL) from human feedback은 현재 Language Model(LM)의 정렬에 매우 많이 쓰이는 방식이다.

또한 그중에서도 흔하게 쓰이는 방식인 KL-regularised RL에 대한 분석을 제공하고 이를 베이즈 추론으로 보는 대안적인 관점을 제시한다.

 

결론적으로 RL은 LM fine-tuning 같은 문제에 대한 적절한 프레임워크가 아니라고 한다.

 

 

 

Fine-tuning language models using standard RL and distribution collapse

 

X를 토큰의 시퀀스의 집합이라고 하면 LM π는 X에 대한 확률분포로 볼 수 있다.

π(x)는 시퀀스 x ∈ X의 확률이고 보상 함수 r은 x에 보상을 할당한다. 

r(x)는 π를 정렬(e.g. fine-tuning) 하기 위한 인간의 선호도를 나타낼 수 있다.

 

πθ는 피라미터 θ를 가진 LM이고, 이를 보상 함수 r을 통해 fine-tuning 하기 위한 강화 학습의 목표는 LM 분포 하에서 예상되는 보상이다.

직관적으로 이를 최대화한다는 것은 좋은 시퀀스에 좋은 보상을, 나쁜 시퀀스에 나쁜 보상을 주는 것을 의미한다. (좋은 보상을 받을 확률을 최대로, 나쁜 보상을 받을 확률을 최소로)

 

RL 목표의 문제점은 LM을 정책으로 취급한다는 것이다. 이로 인해 샘플 분포를 포착하는 것이 아니라 최적의 단일 시퀀스 x*을 찾는 목표로 퇴보한다. (LM은 보통 자동회귀적으로 다음 문장을 예측하거나 질문에 대한 답을 얻기 위해서 사용된다. 어떤 질문에 대한 완벽한 대답을 x*, 모든 질문에 대한 x*의 매핑, 즉 LM 자체는 상황에 대한 최적의 매핑 모음=정책인 π*에 대응된다.)

 

보상 최대화로 인한 이러한 분포 붕괴는 LM의 다양성 감소로 나타나며, 실제로 r이 인간 선호도를 완벽하게 포착하고, 정말로 x*이 최적의 시퀀스인 경우에도, 우리는 LM이 x*만 생성하는 것을 원하지 않는다. (과연 그런가? 질문에 대한 챗봇의 응답이 다양할 필요가 있는가?)

 

 

 

Fine-tuning language models via KL-regularised RL

분포를 보존하기 위해 주로 쓰이는 해결책은 원래의 사전 훈련된 모델과의 KL-divergence를 통해 페널티를 주는 것이다.

 

그러나 이러한 문제를 RL로 취급하는 것이 과연 필요한가? 본문에서는 이 문제를 베이즈 추론으로 보는 관점을 제시한다.

 

 

 

KL-regularised RL as variational inference

보상 함수 r에 의해 사전 훈련된 π를 fine-tuning 하는 것은 새로운 증거에 알맞게 사전 분포를 업데이트하는 베이즈 추론 문제로 볼 수 있다. 

 

보상 함수는 낮은 보상 시퀀스보다 높은 보상 시퀀스를 더 많이 만드는 X에 대한 분포로 나타낼 수 있고, 이는 보상 r을 지수화하고 다시 정규화하는 간단한 방법으로 표현 가능하며, 사후 확률은 다음과 같다.

 

이는 이전에 설명한 '최적의 정책'과 일치하고

 

KL 정규화된 RL 목표는 LM 분포와 목표 분포 사이의 KL-divergence를 최소화하는 형태로 표현할 수 있다.

 

이러한 관점이 의미 있는 이유는 RL에서 잃어버렸던 LM의 분포적인 성격을 다시 되찾아왔다는 것에 있다.

 

 

 

Separation of modelling and inference

베이즈 관점의 더 근본적인 이점은 모델링과 추론의 분리이다.

베이즈 관점에서 모델링은 보상을 증거로 사후 분포를 구하는 것, 추론은 사후 분포에서 샘플링하는 것이다.

 

하지만 강화 학습의 관점에서는, 정책은 이미 있으며, 가장 많은 보상을 얻을 수 있는 선택지를 고르는 것이 전부다.

 

예를 들어, 이 동영상을 예로 들어보면

 

이 상황에서 왼쪽 키를 누를 것인가? 오른쪽 키를 누를 것인가? 에 대한 선택지가 있다고 하자.

 

강화학습에서는 현재 상태(state)를 고려하여 정책(policy)에서 어느 쪽으로 갈지 결정한다.

 

베이즈 관점에서 보면 해당 car game에 대한 정책은 보상 함수로 미리 모델링해 놓은 분포에서 샘플링한 선택지 중 확률이 높은(많은 보상) 선택지이다. 

 

그러니까, RL은 보상을 참고하여 정책을 수립하고, 실전에서 해당 정책을 사용한다.

e.g. 길이 오른쪽으로 꺾이네? 정책에서 그럴 때는 오른쪽으로 가라고 적혀 있네? 오른쪽으로 가야지. 

 

베이즈 관점은

e.g. 길이 오른쪽으로 꺾일 때 오른쪽으로 가면 안 부딪힐 확률 90%, 왼쪽으로 가면 15%? 그러면 오른쪽으로 가야지.

 

RL은 실전에서 정책에 없는 선택지를 고려하지 않는다.

사실 이것 또한 분포적인 성격에서 오는 차이이다.

 

 

(뇌피셜 좀 첨가했는데 설명 좀 괜찮지 않나? ㅋㅋ 아님말고.)