본문 바로가기

논문 리뷰/Diffusion Model

One Transformer Fits All Distributions in Multi-Modal Diffusion at Scale (UniDiffuser)

각 multi-modal 데이터를 공동 훈련하여 modality 확장성이 뛰어난 모델

 

Github

arXiv

 

 

Abstract

하나의 모델에서 multi-modal 데이터셋과 관련된 모든 분포를 맞추는 통합 확산 프레임워크 UniDiffuser 제안.

Unified view에서 영감을 얻은 UniDiffuser는 원래 확산 모델에 대한 최소한의 수정으로 모든 분포를 동시에 학습한다.

그렇다고 한다 1
그렇다고 한다 2

 

 

Introduction

Multi-modal 생성 작업은 확률적 모델링의 관점에서 해당 분포를 맞추는 것으로 볼 수 있다.

(e.g. text-to-image 생성은 조건부 분포 p(ImageㅣText)를 학습하는 것)

 

본 논문에서는 추가 학습 또는 오버헤드 없이 하나의 모델에서 모든 관련 분포를 명시적으로 맞추는 확산 기반 프레임워크 UniDiffuser를 소개한다.

 

핵심 통찰은 모든 modality의 데이터를 교란(i.e. timesteps)시키고, 서로 다른 modality의 개별 timestep을 입력하고, 모든 modality의 noise를 예측하는 것이다.

 

UniDiffuser는 추가 오버헤드 없이 적절한 시간 단계를 설정하여 image, text, T2I, I2T, image와 text 쌍 생성 등을 수행할 수 있다.

 

 

 

Method

본 논문에서는 두 가지 데이터에만 초점을 두지만 더 많은 modality로 쉽게 확장할 수 있다.

 

UniDiffuser: One Diffusion Fits All Distributions

두 가지 modality의 데이터 x, y가 있다고 할 때 q(x), q(y), q(xㅣy), q(yㅣx), q(x,y) 등 모든 관련 분포를 포착할 수 있는 모델을 설계하는 것이 목표.

 

확산 모델의 학습은 노이즈에 대한 조건부 기댓값을 추정하는 것과 같다.

주변 분포 q(x0)를 모델링 하는 것은 xt에서 기댓값 E[εxㅣxt]를, 조건부 분포는 E[εxㅣxt,y0]를 추정하는 것으로 볼 수 있다.

 

중요한 것은 기댓값을 일반형 E[εxyㅣxtx,yty]으로 통합할 수 있다는 것이다.

(ty = 0일 경우 E[εxㅣxt,y0], T일 경우 E[εxㅣxt])

 

네트워크 손실

 

자세한 학습 알고리즘은 부록 B.

두 개의 timestep으로 인해 약간 더 높은 분산을 갖지만 수렴에 어려움은 없음.

더 다양한 modality에 대한 확장성을 위해 transformer 기반 네트워크 사용.

 

Classifier-Free Guidance for Free

Classifier-free guidance

 

UniDiffuser의 형태로 수정:

 

x, y를 공동 샘플링 하는 경우:

 

 

 

UniDiffuser on Images and Texts

Encoding Images and Texts into Latent Space

Image encoder-decoder

이미지 인코딩에는 재구성을 위한 Stable Diffusion의 auto-encoder, 의미론을 위한 CLIP을 사용.

 

Text encoder-decoder

인코더로 CLIP, 디코더로 GPT-2 사용.

 

Transformer as Joint Noise Prediction Network

확산 모델을 위해 최근 제안된 transformer인 U-ViT 사용.

U-ViT는 모든 입력을 토큰으로 처리하고 잔차 연결을 사용함.

또한 U-ViT의 pre-layer 정규화가 혼합 정밀도에서 overflow를 유발하기 때문에 post-layer 정규화로 변경하고 잔차 연결 이후에 layer 정규화 추가.

 

 

 

Experiments

다른 multi-modal 확산 모델인 Versatile Diffusion과 비교

 

Text-to-image 모델들과 비교