본문 바로가기

논문 리뷰/Diffusion Model

Improved Vector Quantized Diffusion Models

이산 확산 모델의 샘플링 전략 개선

 

Arxiv

Github

 

Abstract

VQ-Diffusion에서 때때로 낮은 품질의 샘플이나 약한 상관관계의 이미지를 생성했는데, 주요한 원인 샘플링 전략 때문이라는 것을 발견하고 두 가지 중요한 기술을 제안한다.

  •  이산 확산 모델에 대한 classifier-free guidance를 탐구하고 보다 일반적이고 효과적인 구현을 제안
  • VQ-Diffusion의 joint distribution 문제를 완화하기 위한 추론 전략 제안

 

 

 

Introduction

VQ-Diffusion의 주요 장점 중 하나는 각 이산 토큰에 대한 확률을 추정할 수 있으므로 상대적으로 적은 추론 단계로 고품질 이미지를 생성한다는 것이다. 이를 바탕으로 VQ-Diffusion을 개선하기 위한 몇 가지 기술을 소개한다.

 

Discrete classifier-free guidance

조건부 이미지 생성의 경우 보통 확산 모델은 사전 확률 p(xㅣy)를 최대화하며 사후 확률 p(yㅣx)의 제약 조건을 만족할 것이라고 가정한다.

 

그러나, 연구진은 대부분의 경우 사후 확률을 무시할 수 있다는 것을 발견했다. 이것을 사후 문제로 명명하며, 이 문제를 해결하기 위해 사전과 사후를 동시에 고려할 것을 제안한다. Classifier-free guidance의 개선과 사후 제약으로 크게 개선된 이미지를 생성할 수 있다.


High-quality inference strategy

각 노이즈 제거 단계에서 각 토큰들은 독립적으로 동시에 샘플링되기 때문에 종속성이 무시될 수 있으며, 이것을 공동 분포 문제(joint distribution issue)라고 명명한다.

 

이 문제를 완화하기 위해 두 가지 핵심 설계를 기반으로 하는 추론 전략을 도입한다.

  • 토큰이 많을수록 공동 분포로 인해 더 많은 어려움을 겪는다. 따라서 샘플링되는 토큰 수를 줄인다.
  • 신뢰도가 높은 토큰이 더 정확한 경향이 있다는 것을 발견하여 신뢰도가 높은 토큰에 purity prior를 도입한다.

 

 

 

Background: VQ-Diffusion

 

 

 

Method

Discrete Classifier-free Guidance

조건부 이미지 생성 작업의 경우 손상된 입력과 조건 정보를 통해 이미지를 복구하는데, 손상된 입력이 조건 정보보다 훨씬 더 많은 정보를 갖고 있기 때문에 조건이 무시될 수 있다. 실제로도 VQ-Diffusion에서 입력 텍스트와의 상관관계가 적은 경우를 발견했다.

 

직접적인 해결책은 사전 확률과 사후 확률을 동시에 최적화하는 것. 다음과 같이 목표를 도출할 수 있다.


무조건(unconditional) 이미지 로짓 p(x)를 예측하기 위해 classifier-free guidance에서는 조건으로 null을 입력하지만 이 논문에서는 학습 가능한 벡터를 사용한다. 이쪽이 더 성능이 좋았다고.

 

다음 단계 예측 :


연속 도메인에서의 classifier-free guidance와 사후 확률 제약의 차이점 :

  • VQ-Diffusion은 재매개변수화를 이용하여 노이즈가 없는 p(xㅣy)를 예측하기 때문에 노이즈가 없는 상태에서 위의 argmax 등식을 적용할 수 있으므로 빠른 추론, 고품질 추론과 같은 다른 기술과 호환된다.
  • 연속적인 환경에서 확산 모델은 확률분포를 직접 예측하지 않지만, 이산 확산 모델에서는 직접 추정한다.
  • Null 대신 학습 가능한 벡터를 사용했을 때 성능이 더욱 향상되었다.

 

 

High-quality Inference Strategy

Fewer tokens sampling

논문에서는 꽤 헷갈리게 설명되어 있는데, 쉽게 설명하자면

예를 들어 t-1 단계에서 t 단계로의 순방향 과정에서 20개의 마스크가 추가되었다고 하자. 하지만 역방향에서 20개를 한번에 샘플링하면 공동 분포 문제가 생긴다. 그래서 이 20개를 5개씩 4단계로 나누어서 샘플링하자는 것이다.

 

우리는 이미 순방향 과정에서 마스킹될 확률 γt를 알고 있으니 샘플링 전에 단계를 얼만큼 나눌지 정할 수 있다.

 

마스크를 1개씩 복구하면 AR(자동회귀) 방식과 추론 속도가 같게 된다. 하지만 확산 모델은 토큰들을 왼쪽 위부터 순서대로 추론함으로써 생기는 편향이 생기지 않기 때문에 다르다.


Purity prior sampling

모든 토큰은 위치 독립적으로, 모든 위치에서 [MASK]가 될 확률은 같다.

하지만 위치에 따라 신뢰도가 다르며, purity가 높은 위치는 신뢰도도 높았다고 한다.

 

Purity와 정확도의 상관 관계 :

 

따라서 랜덤 샘플링 대신 purity에 의존하여 importance sampling을 수행한다.

 

뭔가 어렴풋이는 알겠는데 정확하게 설명은 못하겠는... 좀 그런 느낌이다.

내 해석대로라면 역확산 과정에서 일관성이 높은 자리의 마스크를 우선으로 복구한다는 뜻인 것 같다.

 

또한 purity가 높은 위치의 첨도를 높이기 위해 softmax로 확률을 조정한다.