
DiffStyler Code Review

DiffStyler paper review, trying out DiffStyler


GitHub: haha-lisa/Diffstyler (DiffStyler: Controllable Dual Diffusion for Text-Driven Image Stylization)
https://github.com/haha-lisa/Diffstyler


Before diving in: main.py is not exactly written in a modern, functional style; quite a few functions reach into global variables...


Of the two lightweight diffusion models, cc12m_1.py is the conditional model that takes a CLIP embedding, and wikiart_256.py is the unconditional model.

The cc12m model is the one used for free diffusion. Both are hand-coded, so the model architecture is easy to read at a glance.
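Just to fix the call signatures in mind, here is a stub sketch of my own (not the repo's loading code): the conditional backbone takes a CLIP embedding, the unconditional one does not, and both predict the velocity v.

import torch
import torch.nn as nn

# Stub modules (assumptions for illustration only; the real backbones live in
# cc12m_1.py and wikiart_256.py).
class CondStub(nn.Module):        # stands in for the cc12m_1 model
    def forward(self, x, t, clip_embed):
        return torch.zeros_like(x)          # the real model returns the predicted v

class UncondStub(nn.Module):      # stands in for the wikiart_256 model
    def forward(self, x, t):
        return torch.zeros_like(x)

x, t = torch.randn(1, 3, 256, 256), torch.ones(1)
clip_embed = torch.randn(1, 512)            # CLIP ViT-B/16 embedding
v_free  = CondStub()(x, t, clip_embed)      # free-diffusion branch
v_style = UncondStub()(x, t)                # style (wikiart) branch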

 

A notable detail, probably because these are lightweight models: conditioning is injected through a Modulation block rather than cross-attention.

class Modulation2d(nn.Module):
    def __init__(self, state, feats_in, c_out):
        super().__init__()
        self.state = state                                   # shared dict; 'cond' holds the conditioning vector
        self.layer = nn.Linear(feats_in, c_out * 2, bias=False)

    def forward(self, input):
        # Project the conditioning vector to a per-channel scale and shift (FiLM-style)
        scales, shifts = self.layer(self.state['cond']).chunk(2, dim=-1)
        return torch.addcmul(shifts[..., None, None], input, scales[..., None, None] + 1)
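A minimal sketch of how such a modulation layer gets fed (the mapping network below is my own stand-in; in the real model the conditioning vector comes from the timestep and CLIP embeddings): the conditioning is written into a shared state dict once per forward pass, and every Modulation2d reads it to produce a per-channel scale and shift.

import torch
import torch.nn as nn

# Stand-in usage sketch; assumes the Modulation2d class above is in scope.
state = {}
mapping = nn.Sequential(nn.Linear(512, 256), nn.GELU(), nn.Linear(256, 256))  # hypothetical
mod = Modulation2d(state, feats_in=256, c_out=64)

feats = torch.randn(2, 64, 32, 32)          # feature map inside a block
clip_embed = torch.randn(2, 512)            # conditioning input
state['cond'] = mapping(clip_embed)         # written once, read by every Modulation2d
out = mod(feats)                            # per-channel scale and shift (FiLM-style)
print(out.shape)                            # torch.Size([2, 64, 32, 32])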

main.py

The interval from 0 to 1 is divided into args.steps timesteps, truncated at max_timestep, and then reverse sampling and sampling are run over the remaining T steps.

def run():
    t = torch.linspace(0, 1, args.steps + 1, device=device)
    steps = utils.get_spliced_ddpm_cosine_schedule(t)
    steps = steps[steps <= args.max_timestep]     # keep only noise levels up to max_timestep

    if args.method == 'ddim':
        # DDIM-invert the init image to a partially noised latent, then denoise it
        x = sampling.reverse_sample(model, init, steps, {'clip_embed': zero_embed})
        out = sampling.sample(cfg_model_fn, x, steps.flip(0)[:-1], 0, {})
    if args.method == 'prk':
        x = sampling.prk_sample(model, init, steps, {'clip_embed': zero_embed}, is_reverse=True)
        out = sampling.prk_sample(cfg_model_fn, x, steps.flip(0)[:-1], {})
    if args.method == 'plms':
        ...
    if args.method == 'pie':
        ...
    if args.method == 'plms2':
        ...
    if args.method == 'iplms':
        ...

    utils.to_pil_image(out[0]).save(args.output)
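To make the truncation concrete, here is a stand-in illustration of my own (a plain linspace instead of the actual spliced cosine schedule, so the numbers are only indicative):

import torch

# Stand-in illustration of what max_timestep does to the schedule.
steps = torch.linspace(0, 1, 50 + 1)      # 51 candidate timesteps in [0, 1]
steps = steps[steps <= 0.6]               # keep only the low-noise part
print(steps.shape)                        # roughly 31 of the 51 steps survive the cut
# reverse_sample walks these upward (image -> partially noised latent),
# then sample(..., steps.flip(0)[:-1], ...) denoises back down from t ~ 0.6.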

 

Reverse sampling takes the cc12m model, the zero embedding, the init image, and all of the timesteps as input.

@torch.no_grad()
def reverse_sample(model, x, steps, extra_args, callback=None):
    """Finds a starting latent that would produce the given image with DDIM
    (eta=0) sampling."""
    ts = x.new_ones([x.shape[0]])

    # Create the noise schedule
    alphas, sigmas = utils.t_to_alpha_sigma(steps)

    # The sampling loop
    for i in trange(len(steps) - 1, disable=None):

        # Get the model output (v, the predicted velocity)
        with torch.cuda.amp.autocast():
            v = model(x, ts * steps[i], **extra_args).float()

        # Predict the noise and the denoised image
        pred = x * alphas[i] - v * sigmas[i]
        eps = x * sigmas[i] + v * alphas[i]

        # Recombine the predicted noise and predicted denoised image in the
        # correct proportions for the next step
        x = pred * alphas[i + 1] + eps * sigmas[i + 1]

    return x
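As a sanity check on the v-parameterization used above (assuming t_to_alpha_sigma returns alpha = cos(t*pi/2), sigma = sin(t*pi/2), as in v-diffusion), the predicted image and predicted noise recombine exactly to the current x at the same step; the sampler simply recombines them at the next step instead.

import math
import torch

# Assumption: alpha = cos(t*pi/2), sigma = sin(t*pi/2). With pred = alpha*x - sigma*v
# and eps = sigma*x + alpha*v, recombining at the same step recovers x exactly.
t = torch.tensor(0.3)
alpha, sigma = torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
x = torch.randn(1, 3, 8, 8)
v = torch.randn_like(x)
pred = x * alpha - v * sigma
eps = x * sigma + v * alpha
print(torch.allclose(pred * alpha + eps * sigma, x, atol=1e-6))  # True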

In any case, the x obtained from reverse sampling is then fed back into sample():

 out = sampling.sample(cfg_model_fn, x, steps.flip(0)[:-1], 0, {})

 

def cfg_model_fn(x, t):
    # Weighted combination of the cc12m model's velocities over the target CLIP embeddings
    n = x.shape[0]
    n_conds = len(target_embeds)
    x_in = x.repeat([n_conds, 1, 1, 1])
    t_in = t.repeat([n_conds])
    clip_embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
    vs = model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
    v = vs.mul(weights[:, None, None, None, None]).sum(0)

    # Blend in the guided velocity from the wikiart (unconditional) model
    v1 = cond_model_fn(x, t)
    v = args.free_scale * v + args.wikiart_scale * v1
    return v


def cond_model_fn(x, t, **extra_args):
    with torch.enable_grad():
        x = x.detach().requires_grad_()
        v = model1(x, t, **extra_args)
        alphas, sigmas = utils.t_to_alpha_sigma(t)
        pred = x * alphas[:, None, None, None] - v * sigmas[:, None, None, None]
        # Classifier guidance: shift v by the loss gradient computed in cond_fn
        cond_grad = cond_fn(x, t, pred, **extra_args).detach()
        v = v.detach() - cond_grad * (sigmas[:, None, None, None] / alphas[:, None, None, None])
    return v
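In other words, cond_model_fn is classifier guidance written in the v-parameterization: subtracting the gradient from the predicted noise eps is equivalent to subtracting it, scaled by sigma/alpha, from the velocity v. A quick numerical check of that equivalence (my own, assuming the usual alpha^2 + sigma^2 = 1 convention from t_to_alpha_sigma):

import math
import torch

# Check: with eps = sigma*x + alpha*v, guiding the noise (eps' = eps - sigma*g)
# is the same as guiding the velocity (v' = v - (sigma/alpha)*g), which is
# exactly the update in cond_model_fn above. Here g = -grad(total_loss).
t = torch.tensor(0.4)
alpha, sigma = torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
x, v, g = torch.randn(3, 8), torch.randn(3, 8), torch.randn(3, 8)
eps = sigma * x + alpha * v
v_from_eps = (eps - sigma * g - sigma * x) / alpha
v_direct = v - (sigma / alpha) * g
print(torch.allclose(v_from_eps, v_direct, atol=1e-6))               # True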

The goal here is to find the x that drives the various loss terms below down.

 

Total loss

def cond_fn(x, t, pred):
    # CLIP guidance: distance between cutout embeddings of pred and the target text embedding
    clip_embed = F.normalize(target_embeds1.mul(weights1[:, None]).sum(0, keepdim=True), dim=-1)
    clip_embed = clip_embed.repeat([args.n, 1])
    if min(pred.shape[2:4]) < 256:
        pred = F.interpolate(pred, scale_factor=2, mode='bilinear', align_corners=False)
    clip_in = normalize(make_cutouts((pred + 1) / 2))
    image_embeds = clip_model.encode_image(clip_in).view([args.cutn, x.shape[0], -1])
    losses = spherical_dist_loss(image_embeds, clip_embed[None])
    loss_NCE = calculate_NCE_loss(init.detach().to(device), pred.detach().to(device))
    clip_loss = losses.mean(0).sum()

    # VGG content loss against the original content image
    content_image = load_image2(args.init, 256, 256)
    content_image = content_image.to(device)
    content_features = get_features(img_normalize(content_image), VGG)
    target_features = get_features(img_normalize(pred), VGG)
    content_loss = 0
    content_loss += torch.mean((target_features['conv4_2'] - content_features['conv4_2']) ** 2)
    content_loss += torch.mean((target_features['conv5_2'] - content_features['conv5_2']) ** 2)

    # Smoothness, similarity to the init image, and aesthetic score
    tv_losses = tv_loss(pred)
    init_losses = lpips_model(pred, init)
    aes_loss = (aesthetic_model_16(F.normalize(image_embeds, dim=-1))).mean()

    total_loss = clip_loss * args.clip_guidance_scale + loss_NCE * args.nce_scale \
                 + content_loss * args.lambda_c + tv_losses.sum() * args.tv_scale \
                 + init_losses.sum() * args.init_scale + aes_loss * args.aes_scale
    grad = -torch.autograd.grad(total_loss, x)[0]
    return grad

 

CLIP loss

def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
    
clip_embed = F.normalize(target_embeds1.mul(weights1[:, None]).sum(0, keepdim=True), dim=-1)
clip_embed = clip_embed.repeat([args.n, 1])
if min(pred.shape[2:4]) < 256:
    pred = F.interpolate(pred, scale_factor=2, mode='bilinear', align_corners=False)
clip_in = normalize(make_cutouts((pred + 1) / 2))
image_embeds = clip_model.encode_image(clip_in).view([args.cutn, x.shape[0], -1])
losses = spherical_dist_loss(image_embeds, clip_embed[None])
clip_loss = losses.mean(0).sum()
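As a standalone check (mine, not from the repo), spherical_dist_loss is half the squared geodesic distance on the unit sphere, so identical directions score 0 and opposite directions score pi^2/2:

import torch
import torch.nn.functional as F

def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)

e = F.normalize(torch.randn(1, 512), dim=-1)
print(spherical_dist_loss(e, e))        # tensor([0.])
print(spherical_dist_loss(e, -e))       # tensor([4.9348]) == pi**2 / 2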

NCE loss

netAE = net.ADAIN_Encoder(vgg, args.gpu_ids).to(device)
netF = networks.define_F(args.input_nc, 'mlp_sample', args.normG, not args.no_dropout, args.init_type, args.init_gain, args.no_antialias, args.gpu_ids).to(device)

class PatchNCELoss(nn.Module):
    def __init__(self):
        super().__init__()
        # self.opt = opt
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
        self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool
        self.similarity_function = self._get_similarity_function()
        self.cos = torch.nn.CosineSimilarity(dim=-1)

criterionNCE = []
for nce_layer in args.content_nce_layers:
    criterionNCE.append(PatchNCELoss().to(device))

def calculate_NCE_loss(src, tgt):
    content_nce_layers = [int(i) for i in args.content_nce_layers.split(',')]
    n_layers = len(content_nce_layers)
    feat_q, feat_k = netAE(tgt, src, encoded_only=True)
    feat_k_pool, sample_ids = netF(feat_k, args.num_patches, None)
    feat_q_pool, _ = netF(feat_q, args.num_patches, sample_ids)
    total_nce_loss = 0.0
    for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, criterionNCE, args.content_nce_layers):
        loss = crit(f_q, f_k)
        total_nce_loss += loss.mean()
    return total_nce_loss / n_layers
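Since netAE and netF come from the CUT codebase and are not shown here, here is a heavily simplified PatchNCE sketch (my own, not the repo's class) to show the core idea: each query patch feature must match the key feature from the same spatial location, with all other patches acting as negatives.

import torch
import torch.nn.functional as F

# Simplified PatchNCE (assumption: boiled down from CUT's PatchNCELoss).
def patch_nce_loss(f_q, f_k, tau=0.07):
    f_q = F.normalize(f_q, dim=1)
    f_k = F.normalize(f_k, dim=1).detach()
    logits = f_q @ f_k.t() / tau                 # [num_patches, num_patches] similarities
    labels = torch.arange(f_q.shape[0], device=f_q.device)   # positive = same location (diagonal)
    return F.cross_entropy(logits, labels)

# Toy usage on random patch features
f_q, f_k = torch.randn(256, 128), torch.randn(256, 128)
print(patch_nce_loss(f_q, f_k))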

Content loss

content_image = load_image2(args.init, 256, 256)
content_image = content_image.to(device)
content_features = get_features(img_normalize(content_image), VGG)
target_features = get_features(img_normalize(pred), VGG)
content_loss = 0
content_loss += torch.mean((target_features['conv4_2'] - content_features['conv4_2']) ** 2)
content_loss += torch.mean((target_features['conv5_2'] - content_features['conv5_2']) ** 2)
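get_features itself is not shown in the snippet; a minimal version under the usual torchvision VGG19 indexing (conv4_2 at features[21], conv5_2 at features[30]; an assumption, not necessarily the repo's exact mapping) could look like this:

import torch
from torchvision import models

# Hypothetical minimal get_features: run VGG19 and cache the two content layers.
vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.eval()

def get_features(image, model, layers=None):
    # Assumed indices in torchvision's vgg19().features: conv4_2 -> 21, conv5_2 -> 30
    layers = layers or {21: 'conv4_2', 30: 'conv5_2'}
    feats, x = {}, image
    for idx, layer in enumerate(model):
        x = layer(x)
        if idx in layers:
            feats[layers[idx]] = x
    return feats

feats = get_features(torch.randn(1, 3, 256, 256), vgg)
print({k: tuple(v.shape) for k, v in feats.items()})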

TV loss

def tv_loss(input):
    """L2 total variation loss, as in Mahendran et al."""
    input = F.pad(input, (0, 1, 0, 1), 'replicate')
    x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
    y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
    return (x_diff**2 + y_diff**2).mean([1, 2, 3])
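A quick sanity check (mine, reusing the tv_loss above): a constant image has zero total variation while noise does not, so this term pushes the output toward spatially smooth images.

import torch
import torch.nn.functional as F

# Assumes the tv_loss definition above is in scope.
flat, noisy = torch.zeros(1, 3, 64, 64), torch.randn(1, 3, 64, 64)
print(tv_loss(flat))    # tensor([0.])
print(tv_loss(noisy))   # a much larger value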

Init loss

import lpips

lpips_model = lpips.LPIPS(net='vgg').to(device)
init_losses = lpips_model(pred, init)

Aes loss

aesthetic_model_16 = torch.nn.Linear(512,1).cuda()
aesthetic_model_16.load_state_dict(torch.load("./checkpoints/ava_vit_b_16_linear.pth"))

aes_loss = (aesthetic_model_16(F.normalize(image_embeds, dim=-1))).mean()
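For reference, aesthetic_model_16 is just a linear probe on the L2-normalized CLIP image embedding (512-dim, i.e. ViT-B/16); the checkpoint name suggests it was trained on an aesthetics dataset such as AVA, and its scalar output is averaged over the cutouts and folded into total_loss with weight aes_scale.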
