DiffStyler 코드 리뷰, DiffStyler 써보기
Text guided stylization + dual architecture
Abstract
Text-guided stylization 확산 모델 DiffStyler
Dual diffusion architecture를 사용하여 content와 style 사이의 균형을 제어
Content 이미지 기반의 학습 가능한 노이즈로 content의 구조를 보존
(모델을 학습시키는 게 아니라 sampling process의 입력인 학습 가능한 노이즈를 최적화하는 형태임)
Introduction
예제 이미지를 이용한 stylization은 content와 style을 분리하는 과정을 거쳐야 하지만 텍스트는 해당 스타일 자체에 대한 의미 정보만을 가지고 있기 때문에 stylization에 더 적합하다.
DiffStyler의 세 가지 개선점
- Content 이미지의 free diffusion process에서 무작위 노이즈를 학습 가능한 노이즈로 대체하여 content 구조 보존
- Content와 style의 조절을 위해 dual diffusion architecture 채택
- 수치적 방법의 측면에서 확산 과정의 시뮬레이션을 최적화하여 샘플링 속도를 높이면서 품질을 보장
(Free diffusion process는 노이즈 예측 네트워크를 이용해 0 → T로의 reverse sampling을 진행하는 것을 말한다.)
Method
DiffStlyer는 3가지 단계로 구성된다.
- 이미지 x0을 입력하고 free diffusion process T1(>T) 단계 진행
- Free diffusion의 T 단계에서의 결과 x̂T는 샘플링을 위해 dual diffusion model의 입력 x'T로 사용됨
- 샘플링 동안 최적화 수행
Dual Diffusion Models
Pseudo numerical methods for diffusion
(Plms 같은 샘플링 방법을 사용했다는 것이고 이후 논문 내용과는 관련 없는 부분.)
확산 과정과 수치적 방법 사이의 이론적 관계를 제공하기 위해 확산 모델의 관련 미분 방정식을 도출할 수 있다.
확산 모델 ~ SDE ~ ODE까지. (수학 심화적인 부분이 궁금하면 읽고 오기)
어쨌든 역확산 과정은 ODE를 푸는 문제로 볼 수 있고 여기에 다양한 고전적인 방법을 사용할 수 있다.
사소한 해결책의 경우 다음의 공식을 이용해 ODE를 푸는 Forward Euler Method가 있다.
(f = 식(1))
다음의 공식을 이용해 더 근사할 수 있는 Runge-Kutta method. 좀 더 현대적인 비슷한 방법으로 DPM-solver가 있음.
또 다른 공식을 이용해 ODE를 푸는 Linear Multi-Step Method 등이 있다.
Learnable noise
일반적인 확산 과정의 노이즈가 중첩된 이미지는 입력 이미지의 content 구조를 보존할 수 없다.
연구진들은 T 단계의 free diffusion 결과 xT를 입력 노이즈로 사용함으로써 guidance 없이도 어느 정도의 content를 보존할 수 있다는 것을 발견했다.
Free diffusion 단계 T1이 T보다 더 클수록 content가 더 잘 보존됨.
Basic architecture of DiffStyler
확산 모델이 natural 이미지에서만 훈련하면 실제와 가깝고, 예술 이미지에서만 훈련하면 매우 추상적인 결과가 나온다.
따라서 Conceptual 12M과 WikiArt에서 각각 훈련한 경량 네트워크 두 개의 선형 조합을 사용하여 샘플링 수행.
Network Optimization
Instruction loss
Stylization 이미지 xt와 텍스트 프롬프트 T의 CLIP 임베딩간의 cosine distance
Content perceptual loss
사전 훈련된 VGG19의 각 계층에서 추출된 feature map을 이용하여 content loss 계산
또한 contrast learning을 이용하여 특정 위치(VGG 계층)에서 입출력 패치를 일치시킨다. 입력 내의 다른 패치를 negative로 활용할 수 있다.
구체적으로, 동일 공간 위치의 두 패치 x'0, x0에 간단한 linear projection을 하고(v=F(x'0), v+=F(x0)) 나머지 N개의 패치들을 negative patch로 하여 (N+1) 분류 문제로 정한다.
Cross-entropy로 계산하면 :
s()는 두 패치 signal의 유사성에 대한 dot product이고, 이하 vls로 표기.
VGG의 특정 계층들과 모든 공간 패치에 대해 patchwise contrastive content loss 정의 :
Aesthetic loss
미적 손실을 사용하여 스타일이 인간 선호와 더 일치하도록 한다.
Simulacra Aesthetic Captions에서 훈련된 CLIP aesthetic regressions R에 대한 inference code를 맞춘다.
다 합치면 : (tv loss는 total variation loss로 지저분한 느낌 없앰)
Experiments