
Classifier-Guidance Diffusion

Paper Review

Diffusion Models Beat GANs on Image Synthesis paper review (ostin.tistory.com)

Code

 

GitHub - openai/guided-diffusion: https://github.com/openai/guided-diffusion

 


Classifier

https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/unet.py#L683

 


The classifier is an EncoderUNetModel: the downsampling half of the UNet, with a pooling head on top.

It maps an image (at a given timestep) to a length-k tensor of class logits, where k is the number of classes.
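For orientation, here is a minimal sketch of building and calling the classifier. The hyperparameters below are illustrative stand-ins (the repo's script_util.py picks the real values per image size), not the exact training configuration:

import torch as th
from guided_diffusion.unet import EncoderUNetModel

# Illustrative hyperparameters; create_classifier in script_util.py
# chooses the real values depending on the image size.
classifier = EncoderUNetModel(
    image_size=64,
    in_channels=3,
    model_channels=128,
    out_channels=1000,                # k = number of classes (ImageNet)
    num_res_blocks=2,
    attention_resolutions=(2, 4, 8),  # downsample rates that get attention
    num_head_channels=64,
    pool="attention",
)

x_t = th.randn(4, 3, 64, 64)    # a batch of noisy images
t = th.randint(0, 1000, (4,))   # their diffusion timesteps
logits = classifier(x_t, t)     # -> shape (4, 1000)

Note that the classifier is evaluated on noisy x_t during guided sampling, which is why it is trained on noised images and takes the timestep as an input.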


Guidance Process

Looking at the p_sample method:

def p_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
):
    # Mean/variance of p(x_{t-1} | x_t) predicted by the diffusion model.
    out = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    noise = th.randn_like(x)
    nonzero_mask = (
        (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
    )  # no noise when t == 0
    # Classifier guidance: shift the mean toward higher log p(y | x).
    if cond_fn is not None:
        out["mean"] = self.condition_mean(
            cond_fn, out, x, t, model_kwargs=model_kwargs
        )
    sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
    return {"sample": sample, "pred_xstart": out["pred_xstart"]}

The p_mean_variance method computes the mean and (log-)variance of p(x_{t-1} | x_t); when a cond_fn is supplied, condition_mean shifts that mean before the sample is drawn as mean + exp(0.5 · log_variance) · noise, with the noise suppressed at t = 0.
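For context, this is roughly how scripts/classifier_sample.py threads cond_fn into sampling via p_sample_loop, which calls p_sample at every step. The batch shape and labels here are illustrative, and the construction of model, diffusion, and device is elided:

# model, diffusion, cond_fn, and device are assumed to be set up
# as in scripts/classifier_sample.py; shapes below are illustrative.
model_kwargs = {"y": th.randint(0, 1000, (4,), device=device)}  # target classes
sample = diffusion.p_sample_loop(
    model,                 # the diffusion UNet
    (4, 3, 64, 64),        # batch shape
    clip_denoised=True,
    model_kwargs=model_kwargs,
    cond_fn=cond_fn,       # classifier-gradient hook, shown below
    device=device,
)

condition_mean itself is short: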

 

def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    """
    Compute the mean for the previous step, given a function cond_fn that
    computes the gradient of a conditional log probability with respect to
    x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
    condition on y.
    This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
    """
    gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
    new_mean = (
        p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
    )
    return new_mean
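
This is exactly the guided mean from the paper: μ is shifted by Σ · ∇_x log p_φ(y|x_t), with the guidance scale s folded into the gradient. The cond_fn itself is defined in scripts/classifier_sample.py:
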
def cond_fn(x, t, y=None):
    assert y is not None
    with th.enable_grad():
        x_in = x.detach().requires_grad_(True)
        logits = classifier(x_in, t)
        log_probs = F.log_softmax(logits, dim=-1)
        # log p(y | x_t) for each sample's target class y.
        selected = log_probs[range(len(logits)), y.view(-1)]
        # Gradient w.r.t. the noisy input, scaled by the guidance strength.
        return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale
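
Two details are easy to miss. th.enable_grad() together with detach().requires_grad_(True) builds a fresh gradient graph through the classifier only, even if the surrounding sampling loop runs under torch.no_grad(). And summing selected before autograd.grad is just a batching trick: each image's log-probability depends only on its own input, so the summed gradient equals the per-sample gradients. args.classifier_scale is the guidance strength s; the paper shows that larger values trade sample diversity for fidelity.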
