들어가기 전에: main.py는 순수 함수형 스타일로 짜인 코드가 아니라, 함수 내부에서 전역 변수(args, model, device 등)를 그대로 참조하는 경우가 많다.
두 경량 확산 모델 중 cc12m_1.py는 CLIP 임베딩을 조건으로 받는 조건부(conditional) 모델이고, wikiart_256.py는 무조건부(unconditional) 모델이다.
Free diffusion에 사용되는 것은 cc12m 모델이다. 두 모델 모두 레이어를 직접 구현해 두어서 모델 구조가 한눈에 보인다.
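특징적인 점은, 경량 모델이라 그런지 조건을 cross attention이 아닌 Modulation block으로 받는다는 것이다. 조건 벡터(`state['cond']`)를 Linear 층으로 채널별 scale과 shift로 투영한 뒤, feature map에 `input * (scale + 1) + shift` 형태로 적용한다(FiLM 방식).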
class Modulation2d(nn.Module):
    """FiLM-style conditioning: scale and shift a feature map by a projected condition.

    ``state['cond']`` is read at forward time (so it can be filled in after the
    module is built) and projected to ``2 * c_out`` values, which are split into
    per-channel scales and shifts.  The output is
    ``input * (scales + 1) + shifts``, so an all-zero projection returns
    ``input`` unchanged.

    Args:
        state: mutable mapping shared with the caller; must hold ``'cond'``
            when ``forward`` runs.
        feats_in: size of the last dimension of ``state['cond']``.
        c_out: number of channels of the modulated feature map.
    """

    def __init__(self, state, feats_in, c_out):
        # nn.Module.__init__ must run before a submodule (self.layer) is
        # assigned; without it nn.Module.__setattr__ raises AttributeError.
        super().__init__()
        self.state = state
        self.layer = nn.Linear(feats_in, c_out * 2, bias=False)

    def forward(self, input):
        scales, shifts = self.layer(self.state['cond']).chunk(2, dim=-1)
        # Two trailing singleton dims let the per-channel scales/shifts
        # broadcast over the last two (presumably spatial) dims of input.
        return torch.addcmul(shifts[..., None, None], input, scales[..., None, None] + 1)
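[0, 1] 구간을 `args.steps`개 간격으로 나누고 spliced DDPM-cosine schedule로 변환한 뒤, `args.max_timestep` 이하의 timestep만 남긴다. 이 잘린 schedule을 따라 reverse sampling(이미지 → latent)을 하고, 다시 역순으로 sampling(latent → 이미지)을 한다.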
# Entry point (excerpt): invert the init image to a latent with the selected
# sampler, then sample back from that latent using cfg_model_fn as the model.
def run():
# Timesteps evenly spaced on [0, 1] (args.steps intervals), mapped through the
# spliced DDPM/cosine schedule, then truncated to those <= args.max_timestep.
t = torch.linspace(0, 1, args.steps + 1, device=device)
steps = utils.get_spliced_ddpm_cosine_schedule(t)
steps = steps[steps <= args.max_timestep]
# The method checks are independent `if`s (not elif); only one is expected to match.
if args.method == 'ddim':
# Invert with `model` given a zero CLIP embedding, then sample back over the
# flipped schedule; [:-1] drops the last entry of the flipped schedule.
# NOTE(review): the literal 0 is presumably eta (deterministic DDIM) — confirm.
x = sampling.reverse_sample(model, init, steps, {'clip_embed': zero_embed})
out = sampling.sample(cfg_model_fn, x, steps.flip(0)[:-1], 0, {})
if args.method == 'prk':
# Same inversion / re-sampling structure using the PRK sampler.
x = sampling.prk_sample(model, init, steps, {'clip_embed': zero_embed}, is_reverse=True)
out = sampling.prk_sample(cfg_model_fn, x, steps.flip(0)[:-1], {})
# NOTE(review): the bodies of the 'plms', 'pie', 'plms2' and 'iplms' branches
# are elided in this excerpt.
if args.method == 'plms':
if args.method == 'pie':
if args.method == 'plms2':
if args.method == 'iplms':
Reverse sampling은 cc12m 모델, zero embedding(빈 CLIP 조건), 입력 이미지(`init`), 그리고 잘라낸 timestep 전체를 입력으로 받는다.
def reverse_sample(model, x, steps, extra_args, callback=None):
    """Find a starting latent that would produce the given image with DDIM (eta=0) sampling.

    Applies the deterministic DDIM update along ``steps`` in the given order,
    starting from the image ``x``.

    Args:
        model: v-prediction model, called as ``model(x, t, **extra_args)``.
        x: the image batch to invert.
        steps: timestep schedule, walked from ``steps[0]`` to ``steps[-1]``.
        extra_args: extra keyword arguments forwarded to ``model``.
        callback: optional callable, invoked once per step with a dict holding
            that step's ``x``, ``i``, ``t``, ``v`` and ``pred``.

    Returns:
        The latent reached at ``steps[-1]``.
    """
    ts = x.new_ones([x.shape[0]])

    # Create the noise schedule
    alphas, sigmas = utils.t_to_alpha_sigma(steps)

    # The sampling loop
    for i in trange(len(steps) - 1, disable=None):
        # Get the model output (v, the predicted velocity)
        with torch.cuda.amp.autocast():
            v = model(x, ts * steps[i], **extra_args).float()

        # Predict the noise and the denoised image
        pred = x * alphas[i] - v * sigmas[i]
        eps = x * sigmas[i] + v * alphas[i]

        # Report this step to the optional callback before advancing.
        if callback is not None:
            callback({'x': x, 'i': i, 't': steps[i], 'v': v, 'pred': pred})

        # Recombine the predicted noise and predicted denoised image in the
        # correct proportions for the next step
        x = pred * alphas[i + 1] + eps * sigmas[i + 1]
    return x
이렇게 reverse sampling으로 얻은 latent `x`를 다시 sampling에 넣는데, 이때 model 함수로는 `cfg_model_fn`을 쓴다.
# Sample back from the inverted latent x with cfg_model_fn as the model function,
# over the flipped schedule (the same call as the 'ddim' branch of run()).
out = sampling.sample(cfg_model_fn, x, steps.flip(0)[:-1], 0, {})
def cfg_model_fn(x, t):
    """Mix the CLIP-conditioned prediction of ``model`` with ``cond_model_fn``'s output.

    ``model`` is evaluated once for every entry of ``target_embeds`` in a single
    batched call.  The per-embedding velocities are combined with ``weights``,
    and the result is blended with ``cond_model_fn(x, t)`` as
    ``args.free_scale * free_v + args.wikiart_scale * guided_v``.
    """
    batch = x.shape[0]
    num_targets = len(target_embeds)

    # Tile the batch once per target embedding so one forward pass covers them all.
    tiled_x = x.repeat([num_targets, 1, 1, 1])
    tiled_t = t.repeat([num_targets])
    tiled_embeds = torch.cat([*target_embeds]).repeat_interleave(batch, 0)

    per_target = model(tiled_x, tiled_t, tiled_embeds).view([num_targets, batch, *x.shape[1:]])
    free_v = per_target.mul(weights[:, None, None, None, None]).sum(0)

    guided_v = cond_model_fn(x, t)
    return args.free_scale * free_v + args.wikiart_scale * guided_v
def cond_model_fn(x, t, **extra_args):
    """Return ``model1``'s velocity with the ``cond_fn`` guidance gradient folded in.

    A detached copy of ``x`` that requires grad is built so ``cond_fn`` can
    differentiate its loss with respect to it.  The returned gradient is scaled
    by ``sigma / alpha`` and subtracted from the (detached) velocity.
    """
    with torch.enable_grad():
        x_leaf = x.detach().requires_grad_()
        raw_v = model1(x_leaf, t, **extra_args)

        alphas, sigmas = utils.t_to_alpha_sigma(t)
        alpha_b = alphas[:, None, None, None]
        sigma_b = sigmas[:, None, None, None]

        # Denoised estimate implied by the velocity prediction.
        denoised = x_leaf * alpha_b - raw_v * sigma_b
        guidance = cond_fn(x_leaf, t, denoised, **extra_args).detach()
        guided_v = raw_v.detach() - guidance * (sigma_b / alpha_b)
    return guided_v
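여기서의 목적은 각종 loss의 가중합(total loss)을 줄이는 방향으로 `x`를 유도하는 것이다. `cond_fn`은 total loss의 `x`에 대한 gradient의 음수를 반환하고, `cond_model_fn`은 이 값으로 예측 velocity `v`를 보정한다.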
Total loss :
# VGG features of the content image, keyed by image path, so the image is loaded
# and encoded once instead of at every sampling step.
_CONTENT_FEATURES_CACHE = {}


def _content_features(path):
    """Return VGG features of the content image at ``path``, computed once per path.

    They are computed under ``torch.no_grad()``: the features are a fixed target
    that does not depend on ``x``, so they never contribute to the gradient
    taken in ``cond_fn``.
    """
    if path not in _CONTENT_FEATURES_CACHE:
        with torch.no_grad():
            content_image = load_image2(path, 256, 256).to(device)
            _CONTENT_FEATURES_CACHE[path] = get_features(img_normalize(content_image), VGG)
    return _CONTENT_FEATURES_CACHE[path]


def cond_fn(x, t, pred):
    """Return the negative gradient of the weighted guidance loss with respect to ``x``.

    The total loss is a weighted sum of six terms:

    - a CLIP spherical-distance loss to the weighted target embedding;
    - a patch NCE loss against ``init``;
    - a VGG content loss against the image at ``args.init``;
    - a TV loss;
    - an LPIPS loss against ``init``;
    - an aesthetic-head loss.

    Args:
        x: the noisy sample; must require grad (``cond_model_fn`` passes a
            detached copy with ``requires_grad_()``).
        t: timestep; unused here, kept for the ``cond_fn`` interface.
        pred: denoised prediction computed from ``x``.

    Returns:
        ``-d(total_loss)/dx``, with the same shape as ``x``.
    """
    batch = x.shape[0]

    # Weighted, renormalized target CLIP embedding, one row per sample. Tiling
    # to the actual batch size keeps it aligned with image_embeds below.
    clip_embed = F.normalize(target_embeds1.mul(weights1[:, None]).sum(0, keepdim=True), dim=-1)
    clip_embed = clip_embed.repeat([batch, 1])

    if min(pred.shape[2:4]) < 256:
        pred = F.interpolate(pred, scale_factor=2, mode='bilinear', align_corners=False)

    # CLIP loss: average over cutouts, sum over the batch.
    clip_in = normalize(make_cutouts((pred + 1) / 2))
    image_embeds = clip_model.encode_image(clip_in).view([args.cutn, batch, -1])
    losses = spherical_dist_loss(image_embeds, clip_embed[None])
    clip_loss = losses.mean(0).sum()

    # NCE loss: pred must stay attached to the graph. Detaching it leaves this
    # term with no path back to x, so it adds nothing to the returned gradient.
    loss_NCE = calculate_NCE_loss(init.detach().to(device), pred.to(device))

    # Content loss on VGG conv4_2 / conv5_2 features.
    # NOTE(review): pred is fed to img_normalize as-is, whereas CLIP above gets
    # (pred + 1) / 2; confirm load_image2 yields images in the same range as pred.
    content_features = _content_features(args.init)
    target_features = get_features(img_normalize(pred), VGG)
    content_loss = 0
    for layer in ('conv4_2', 'conv5_2'):
        content_loss += torch.mean((target_features[layer] - content_features[layer]) ** 2)

    tv_losses = tv_loss(pred)
    init_losses = lpips_model(pred, init)
    # NOTE(review): aes_loss enters total_loss with +args.aes_scale, so the
    # returned gradient lowers the head's output. If the head scores aesthetics
    # to be maximized, args.aes_scale must be negative -- confirm.
    aes_loss = aesthetic_model_16(F.normalize(image_embeds, dim=-1)).mean()

    total_loss = (clip_loss * args.clip_guidance_scale
                  + loss_NCE * args.nce_scale
                  + content_loss * args.lambda_c
                  + tv_losses.sum() * args.tv_scale
                  + init_losses.sum() * args.init_scale
                  + aes_loss * args.aes_scale)
    return -torch.autograd.grad(total_loss, x)[0]
CLIP loss
def spherical_dist_loss(x, y):
    """Half the squared angle between the directions of ``x`` and ``y``.

    Both inputs are L2-normalized along the last dimension.  For unit vectors
    at angle ``theta``, the chord length is ``2 * sin(theta / 2)``, so
    ``arcsin(chord / 2)`` recovers ``theta / 2``.  The result is
    ``2 * (theta / 2) ** 2``, reduced over the last dimension.
    """
    unit_x = F.normalize(x, dim=-1)
    unit_y = F.normalize(y, dim=-1)
    half_chord = (unit_x - unit_y).norm(dim=-1) / 2
    return 2 * torch.arcsin(half_chord) ** 2
# CLIP guidance term (same code as in cond_fn): spherical distance between the
# CLIP embeddings of cutouts of pred and the weighted target embedding.
# Weighted sum of the target embeddings, renormalized, tiled to args.n rows.
clip_embed = F.normalize(target_embeds1.mul(weights1[:, None]).sum(0, keepdim=True), dim=-1)
clip_embed = clip_embed.repeat([args.n, 1])
# Upsample pred 2x when its smaller spatial side is below 256.
if min(pred.shape[2:4]) < 256:
pred = F.interpolate(pred, scale_factor=2, mode='bilinear', align_corners=False)
# (pred + 1) / 2 rescales pred, presumably from [-1, 1] to [0, 1], before cutouts.
clip_in = normalize(make_cutouts((pred + 1) / 2))
image_embeds = clip_model.encode_image(clip_in).view([args.cutn, x.shape[0], -1])
losses = spherical_dist_loss(image_embeds, clip_embed[None])
# Average over cutouts (dim 0), sum over the batch.
clip_loss = losses.mean(0).sum()
NCE loss
# Feature extractor for the NCE loss (an ADAIN_Encoder built on vgg) and the
# feature head netF ('mlp_sample' -- presumably a patch-sampling MLP); both on device.
netAE = net.ADAIN_Encoder(vgg, args.gpu_ids).to(device)
netF = networks.define_F(args.input_nc, 'mlp_sample', args.normG, not args.no_dropout, args.init_type, args.init_gain, args.no_antialias, args.gpu_ids).to(device)
class PatchNCELoss(nn.Module):
    """Patch-wise contrastive (NCE) loss module; only the constructor appears in this excerpt.

    NOTE(review): ``forward`` and ``_get_similarity_function`` are not part of
    this excerpt; the constructor requires ``_get_similarity_function`` to be
    defined on the class.
    """

    def __init__(self):
        # nn.Module.__init__ must run before a submodule (the loss module
        # below) is assigned; otherwise nn.Module.__setattr__ raises AttributeError.
        super().__init__()
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
        # uint8 masks on torch < 1.2.0, bool masks otherwise.
        self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool
        self.similarity_function = self._get_similarity_function()
        self.cos = torch.nn.CosineSimilarity(dim=-1)
# One NCE criterion per entry of args.content_nce_layers (loop body elided in this excerpt).
# NOTE(review): calculate_NCE_loss parses args.content_nce_layers with .split(','),
# so it appears to be a string; iterating it here walks its characters, commas included.
criterionNCE = []
for nce_layer in args.content_nce_layers:
def calculate_NCE_loss(src, tgt):
    """Average patch-NCE loss between encoder features of ``tgt`` and ``src``.

    ``tgt`` supplies the query features and ``src`` the key features.  Patches
    are pooled by ``netF``, one ``criterionNCE`` entry is applied per layer,
    and the per-layer mean losses are averaged over the configured layers.

    Args:
        src: reference image batch (keys).
        tgt: image batch being optimized (queries).

    Returns:
        Sum of per-layer mean losses divided by the number of layers listed in
        ``args.content_nce_layers``.
    """
    content_nce_layers = [int(i) for i in args.content_nce_layers.split(',')]
    n_layers = len(content_nce_layers)
    feat_q, feat_k = netAE(tgt, src, encoded_only=True)
    # The patch ids drawn for the keys are reused for the queries, presumably so
    # both pools sample the same locations.
    feat_k_pool, sample_ids = netF(feat_k, args.num_patches, None)
    feat_q_pool, _ = netF(feat_q, args.num_patches, sample_ids)
    total_nce_loss = 0.0
    # Zip with the parsed layer list, not the raw option string: iterating the
    # string walks its characters (commas included), so the number of summed
    # terms would depend on the string length instead of n_layers.
    for f_q, f_k, crit, _nce_layer in zip(feat_q_pool, feat_k_pool, criterionNCE, content_nce_layers):
        total_nce_loss += crit(f_q, f_k).mean()
    return total_nce_loss / n_layers
Content loss
# Content loss (same code as in cond_fn): MSE between VGG features of pred and
# of the content image at layers conv4_2 and conv5_2.
# The content image is loaded from args.init with size arguments 256, 256.
content_image = load_image2(args.init, 256,256)
content_image = content_image.to(device)
content_features = get_features(img_normalize(content_image), VGG)
# NOTE(review): pred is fed to img_normalize as-is, while the CLIP term uses
# (pred + 1) / 2; confirm load_image2 yields images in the same range as pred.
target_features = get_features(img_normalize(pred), VGG)
content_loss = 0
content_loss += torch.mean((target_features['conv4_2'] - content_features['conv4_2']) ** 2)
content_loss += torch.mean((target_features['conv5_2'] - content_features['conv5_2']) ** 2)
Tv loss
def tv_loss(input):
    """L2 total variation loss (Mahendran et al.), one value per batch element.

    The last two dims are right/bottom padded by replication, so differences at
    the border are zero.  Squared horizontal and vertical neighbor differences
    are summed and averaged over dims 1-3.
    """
    padded = F.pad(input, (0, 1, 0, 1), 'replicate')
    origin = padded[..., :-1, :-1]
    dx = padded[..., :-1, 1:] - origin
    dy = padded[..., 1:, :-1] - origin
    squared = dx.pow(2) + dy.pow(2)
    return squared.mean(dim=(1, 2, 3))
Init loss
import lpips
# LPIPS perceptual distance (VGG backbone) between pred and the init image.
lpips_model = lpips.LPIPS(net='vgg').to(device)
# NOTE(review): lpips.LPIPS expects inputs in [-1, 1] by default; pred and init
# are passed here without rescaling -- confirm both are in that range.
init_losses = lpips_model(pred, init)
Aes loss
# Linear head mapping 512-d L2-normalized CLIP image embeddings to one score;
# the mean score over all cutouts and samples is the aesthetic loss.
# NOTE(review): as shown, this Linear is freshly (randomly) initialized; unless
# trained weights are loaded elsewhere, the aesthetic term is meaningless.
# It is also hard-wired to .cuda() rather than `device`.
aesthetic_model_16 = torch.nn.Linear(512,1).cuda()
aes_loss = (aesthetic_model_16(F.normalize(image_embeds, dim=-1))).mean()
