
DAAM Code Review

See also: DAAM paper review, Trying out DAAM


The DAAM GitHub README provides usage instructions.

Let's look at how that usage code actually works.


 

The DAAM package index shows which code in the GitHub repository each DAAM library call points to.

 

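First, set_seed:

It fixes the seed of each library (Python random, NumPy, and PyTorch on CPU and all GPUs) and returns a torch.Generator seeded with the same value.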

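import random

import numpy as np
import torch

from daam.utils import auto_device  # DAAM helper: picks the CUDA device if available, else CPU


def set_seed(seed: int) -> torch.Generator:
    # Fix the global RNGs of Python, NumPy and PyTorch (CPU and every GPU)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # A separately seeded generator, meant to be passed to the pipeline as generator=gen
    gen = torch.Generator(device=auto_device())
    gen.manual_seed(seed)

    return gen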
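Returning a dedicated seeded generator, rather than relying only on the global seeds, matters because the pipeline draws its initial latent noise from the generator it is given, so the same seed reproduces the same image. A minimal sketch of the idea, using the standard library's random.Random as a stand-in for torch.Generator; make_generator and sample_noise are hypothetical helpers for illustration only:

```python
import random
from typing import List


def make_generator(seed: int) -> random.Random:
    # Stand-in for torch.Generator(...).manual_seed(seed): an RNG object with its own state
    return random.Random(seed)


def sample_noise(gen: random.Random, n: int) -> List[float]:
    # Hypothetical stand-in for drawing the initial latent noise from `gen`
    return [gen.gauss(0.0, 1.0) for _ in range(n)]


a = sample_noise(make_generator(42), 4)
random.random()  # touches only the global RNG state
b = sample_noise(make_generator(42), 4)
assert a == b  # same seed -> same "latents"
```

Because the generator owns its state, other code that draws from the global RNG between two runs does not change the sampled noise.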

 

trace

 

trace takes the pipeline (pipe) as input and creates a DiffusionHeatMapHooker object.

with trace(pipe) as tc:

 

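The DiffusionHeatMapHooker object stores the heat maps in a RawHeatMapCollection. Because the cross-attention modules are nested inside the U-Net's transformer blocks rather than exposed at its top level, it uses UNetCrossAttentionLocator to find them, then registers one UNetCrossAttentionHooker per located module plus a PipelineHooker for the pipeline itself.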

class DiffusionHeatMapHooker(AggregateHooker):
    def __init__(
            self,
            pipeline: StableDiffusionPipeline,
            low_memory: bool = False,
            load_heads: bool = False,
            save_heads: bool = False,
            data_dir: str = None
    ):
        self.all_heat_maps = RawHeatMapCollection()
        h = (pipeline.unet.config.sample_size * pipeline.vae_scale_factor)
        self.latent_hw = 4096 if h == 512 else 9216  # 512 px -> 64x64 latent (4096); otherwise 768 px (SD 2.0-v) -> 96x96 (9216)
        locate_middle = load_heads or save_heads
        self.locator = UNetCrossAttentionLocator(restrict={0} if low_memory else None, locate_middle_block=locate_middle)
        self.last_prompt: str = ''
        self.last_image: Image = None
        self.time_idx = 0
        self._gen_idx = 0

        modules = [
            UNetCrossAttentionHooker(
                x,
                self,
                layer_idx=idx,
                latent_hw=self.latent_hw,
                load_heads=load_heads,
                save_heads=save_heads,
                data_dir=data_dir
            ) for idx, x in enumerate(self.locator.locate(pipeline.unet))
        ]

        modules.append(PipelineHooker(pipeline, self))

        super().__init__(modules)
        self.pipe = pipeline
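To make latent_hw concrete: sample_size * vae_scale_factor is the image side (for example 64 * 8 = 512), and a 512 px image corresponds to a 64 x 64 = 4096-position latent. Each cross-attention layer computes, for every spatial query position, a softmax over the text tokens; reading one token's column back as an h x w grid gives that token's heat map for the layer. The sketch below shows this on a toy 2 x 2 layer with plain lists. The helper names are hypothetical, and the downsample_factor formula is an assumption about how a layer's resolution relates to latent_hw; the real UNetCrossAttentionHooker works on batched tensors across heads and denoising steps.

```python
import math
from typing import List

Grid = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    # Numerically stable softmax over one row of scores
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def token_heat_maps(scores: Grid, h: int, w: int) -> List[Grid]:
    """scores[q][t]: pre-softmax score between spatial query q and text token t.

    Queries are the layer's h*w positions in row-major order (q = i * w + j).
    Returns one h x w map per token: how strongly each position attends to that token.
    """
    probs = [softmax(row) for row in scores]  # softmax over text tokens, per position
    num_tokens = len(scores[0])
    return [[[probs[i * w + j][t] for j in range(w)] for i in range(h)] for t in range(num_tokens)]


def downsample_factor(latent_hw: int, layer_hw: int) -> int:
    # Assumption: factor = how many times smaller (per side) the layer's grid is than the latent
    return int(math.sqrt(latent_hw // layer_hw))


# A 32x32 layer in a model with a 64x64 latent is downsampled by 2 per side
assert downsample_factor(4096, 32 * 32) == 2
```

Per layer and head, this yields exactly the kind of (num_tokens, h, w) stack that compute_global_heat_map later upsamples and averages.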

 

The parent classes, ObjectHooker and AggregateHooker:

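import functools
from typing import Generic, List, TypeVar

ModuleType = TypeVar('ModuleType')
ModuleListType = TypeVar('ModuleListType', bound=List)


class ObjectHooker(Generic[ModuleType]):
    def __init__(self, module: ModuleType):
        self.module: ModuleType = module
        self.hooked = False
        self.old_state = dict()

    def __enter__(self):
        self.hook()
        return self

    # __exit__(), unhook() and _unhook_impl() are omitted from this excerpt;
    # leaving the `with` block unhooks the module again.

    def hook(self):
        if self.hooked:
            raise RuntimeError('Already hooked module')

        self.old_state = dict()
        self.hooked = True
        self._hook_impl()  # implemented by subclasses

        return self

    def monkey_patch(self, fn_name, fn):
        # Save the original function, then replace it with fn, with the module pre-filled as fn's first argument
        self.old_state[f'old_fn_{fn_name}'] = getattr(self.module, fn_name)
        setattr(self.module, fn_name, functools.partial(fn, self.module))

    def monkey_super(self, fn_name, *args, **kwargs):
        # Call the saved original function (like super() for a patched method)
        return self.old_state[f'old_fn_{fn_name}'](*args, **kwargs)
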

class AggregateHooker(ObjectHooker[ModuleListType]):
    def _hook_impl(self):
        for h in self.module:
            h.hook()

    def _unhook_impl(self):
        for h in self.module:
            h.unhook()

    def register_hook(self, hook: ObjectHooker):
        self.module.append(hook)
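The hook/unhook pattern can be reproduced in plain Python to see why with trace(pipe) patches everything on entry and leaves the objects untouched afterwards. This is a simplified stand-in, not DAAM's code: MiniHooker, MiniAggregate, Greeter and ShoutHooker are hypothetical names, and the unhook() here simply restores the saved originals, which is an assumption about what the omitted DAAM method does.

```python
import functools
from typing import Any, Callable, Dict


class MiniHooker:
    """Simplified ObjectHooker: patch methods of `module` on enter, restore them on exit."""

    def __init__(self, module: Any):
        self.module = module
        self.hooked = False
        self.old_state: Dict[str, Callable] = {}

    def __enter__(self):
        self.hook()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.unhook()

    def hook(self):
        if self.hooked:
            raise RuntimeError('Already hooked module')
        self.old_state = {}
        self.hooked = True
        self._hook_impl()
        return self

    def unhook(self):
        # Put every saved original function back on the module
        for name, fn in self.old_state.items():
            setattr(self.module, name, fn)
        self.hooked = False
        self._unhook_impl()
        return self

    def monkey_patch(self, fn_name: str, fn: Callable):
        self.old_state[fn_name] = getattr(self.module, fn_name)
        setattr(self.module, fn_name, functools.partial(fn, self.module))

    def _hook_impl(self):
        pass

    def _unhook_impl(self):
        pass


class MiniAggregate(MiniHooker):
    """Simplified AggregateHooker: `module` is a list of child hookers."""

    def _hook_impl(self):
        for h in self.module:
            h.hook()

    def _unhook_impl(self):
        for h in self.module:
            h.unhook()


class Greeter:
    def greet(self) -> str:
        return 'hello'


class ShoutHooker(MiniHooker):
    def _hook_impl(self):
        # functools.partial passes the patched object as the lambda's first argument
        self.monkey_patch('greet', lambda module: 'HELLO')


g1, g2 = Greeter(), Greeter()
with MiniAggregate([ShoutHooker(g1), ShoutHooker(g2)]):
    assert g1.greet() == 'HELLO' and g2.greet() == 'HELLO'
assert g1.greet() == 'hello'  # original restored after the block
```

The parent only forwards hook()/unhook() to its children, so the whole tree is patched and restored as one unit.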

 

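In the leaf hookers, UNetCrossAttentionHooker and PipelineHooker, the calls run in the order hook → _hook_impl → monkey_patch.

PipelineHooker patches two pipeline methods: _encode_prompt, to clear the heat-map collection and record the prompt (only a single prompt is supported), and run_safety_checker, to keep the generated image as a PIL image. UNetCrossAttentionHooker uses the same mechanism to collect the cross-attention values and compute the heat maps.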

class PipelineHooker(ObjectHooker[StableDiffusionPipeline]):
    def __init__(self, pipeline: StableDiffusionPipeline, parent_trace: 'trace'):
        super().__init__(pipeline)
        self.heat_maps = parent_trace.all_heat_maps
        self.parent_trace = parent_trace

    def _hooked_run_safety_checker(hk_self, self: StableDiffusionPipeline, image, *args, **kwargs):
        image, has_nsfw = hk_self.monkey_super('run_safety_checker', image, *args, **kwargs)
        pil_image = self.numpy_to_pil(image)
        hk_self.parent_trace.last_image = pil_image[0]

        return image, has_nsfw

    def _hooked_encode_prompt(hk_self, _: StableDiffusionPipeline, prompt: Union[str, List[str]], *args, **kwargs):
        if not isinstance(prompt, str) and len(prompt) > 1:
            raise ValueError('Only single prompt generation is supported for heat map computation.')
        elif not isinstance(prompt, str):
            last_prompt = prompt[0]
        else:
            last_prompt = prompt

        hk_self.heat_maps.clear()
        hk_self.parent_trace.last_prompt = last_prompt
        ret = hk_self.monkey_super('_encode_prompt', prompt, *args, **kwargs)

        return ret

    def _hook_impl(self):
        self.monkey_patch('run_safety_checker', self._hooked_run_safety_checker)
        self.monkey_patch('_encode_prompt', self._hooked_encode_prompt)
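The two-"self" signature _hooked_encode_prompt(hk_self, _, prompt, ...) works because monkey_patch wraps an already-bound method in functools.partial: binding supplies the hooker as hk_self, and partial supplies the pipeline as the next argument. A self-contained sketch with a toy pipeline; ToyPipeline, PromptRecorder and encode_prompt are hypothetical stand-ins, not diffusers or DAAM APIs:

```python
import functools
from typing import Callable, Dict, List, Union


class ToyPipeline:
    def encode_prompt(self, prompt: Union[str, List[str]]) -> str:
        return f'embedding({prompt})'


class PromptRecorder:
    """Mimics PipelineHooker's prompt hook: validate, record, then call the original."""

    def __init__(self, pipeline: ToyPipeline):
        self.module = pipeline
        self.old_state: Dict[str, Callable] = {}
        self.last_prompt = ''

    def monkey_patch(self, fn_name: str, fn: Callable):
        self.old_state[f'old_fn_{fn_name}'] = getattr(self.module, fn_name)
        # fn is a bound method (hk_self already filled in); partial fills in the pipeline next
        setattr(self.module, fn_name, functools.partial(fn, self.module))

    def monkey_super(self, fn_name: str, *args, **kwargs):
        # Call the saved original, which is already bound to the pipeline
        return self.old_state[f'old_fn_{fn_name}'](*args, **kwargs)

    def _hooked_encode_prompt(hk_self, _: ToyPipeline, prompt, *args, **kwargs):
        if not isinstance(prompt, str) and len(prompt) > 1:
            raise ValueError('Only single prompt generation is supported.')
        hk_self.last_prompt = prompt if isinstance(prompt, str) else prompt[0]
        return hk_self.monkey_super('encode_prompt', prompt, *args, **kwargs)

    def hook(self):
        self.monkey_patch('encode_prompt', self._hooked_encode_prompt)
        return self


pipe = ToyPipeline()
rec = PromptRecorder(pipe).hook()
assert pipe.encode_prompt(['a dog']) == "embedding(['a dog'])"
assert rec.last_prompt == 'a dog'
```

The caller keeps calling pipe.encode_prompt(...) as usual; the recorder sees the prompt first and then defers to the original.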

 

After one forward pass inside the trace block, the heat maps have been collected, and compute_global_heat_map returns a GlobalHeatMap object.

with trace(pipe) as tc:
    out = pipe(prompt, num_inference_steps=30, generator=gen)
    heat_map = tc.compute_global_heat_map()

compute_global_heat_map selects the raw maps that match the requested factors, heads and layers, upsamples each one to the latent resolution with bicubic interpolation (clamping negative values), averages them, keeps the rows for [SOS], the prompt tokens and one extra token, and optionally normalizes each pixel over the prompt tokens:

def compute_global_heat_map(self, prompt=None, factors=None, head_idx=None, layer_idx=None, normalize=False):
    # type: (str, List[float], int, int, bool) -> GlobalHeatMap
    heat_maps = self.all_heat_maps

    if prompt is None:
        prompt = self.last_prompt

    if factors is None:
        factors = {0, 1, 2, 4, 8, 16, 32, 64}
    else:
        factors = set(factors)

    all_merges = []
    x = int(np.sqrt(self.latent_hw))  # latent side length: 64 or 96

    with auto_autocast(dtype=torch.float32):
        for (factor, layer, head), heat_map in heat_maps:
            if factor in factors and (head_idx is None or head_idx == head) and (layer_idx is None or layer_idx == layer):
                heat_map = heat_map.unsqueeze(1)  # (tokens, h, w) -> (tokens, 1, h, w) for F.interpolate
                # The clamping fixes undershoot.
                all_merges.append(F.interpolate(heat_map, size=(x, x), mode='bicubic').clamp_(min=0))

        try:
            maps = torch.stack(all_merges, dim=0)  # (num_maps, tokens, 1, x, x)
        except RuntimeError:  # torch.stack fails on an empty list
            if head_idx is not None or layer_idx is not None:
                raise RuntimeError('No heat maps found for the given parameters.')
            else:
                raise RuntimeError('No heat maps found. Did you forget to call `with trace(...)` during generation?')

        maps = maps.mean(0)[:, 0]  # average over all selected maps -> (tokens, x, x)
        maps = maps[:len(self.pipe.tokenizer.tokenize(prompt)) + 2]  # 1 for SOS and 1 for padding

        if normalize:
            maps = maps / (maps[1:-1].sum(0, keepdim=True) + 1e-6)  # drop out [SOS] and [PAD] for proper probabilities

    return GlobalHeatMap(self.pipe.tokenizer, prompt, maps)
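The clamp, average, slice and normalize steps can be mirrored with plain lists. A minimal sketch, assuming all raw maps were already resized to a common resolution (DAAM does that with bicubic F.interpolate, skipped here) and stored flattened as maps[token][pixel]; merge_heat_maps is a hypothetical helper, not part of DAAM:

```python
from typing import List

# One heat map stored as maps[token][pixel] (a flattened (num_tokens, h*w) grid)
HeatMap = List[List[float]]


def merge_heat_maps(raw_maps: List[HeatMap], num_prompt_tokens: int, normalize: bool = False) -> HeatMap:
    """Hypothetical helper mirroring compute_global_heat_map's clamp, mean, slice and normalize steps."""
    n = len(raw_maps)
    num_tokens = len(raw_maps[0])
    num_pixels = len(raw_maps[0][0])

    # Clamp each map at 0 (like .clamp_(min=0)), then average over all selected maps (like .mean(0))
    merged = [
        [sum(max(m[t][p], 0.0) for m in raw_maps) / n for p in range(num_pixels)]
        for t in range(num_tokens)
    ]

    # Keep [SOS] + prompt tokens + one trailing token (the "+ 2" above)
    merged = merged[:num_prompt_tokens + 2]

    if normalize:
        # Per-pixel sum over the prompt-token rows only, as in maps[1:-1].sum(0)
        totals = [sum(row[p] for row in merged[1:-1]) + 1e-6 for p in range(num_pixels)]
        merged = [[row[p] / totals[p] for p in range(num_pixels)] for row in merged]

    return merged
```

With normalize=True the prompt-token rows at each pixel sum to (almost) 1, so each pixel becomes a distribution over words; the [SOS] and trailing rows are divided by the same totals but are not part of that distribution.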

 

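compute_word_heat_map picks out the heat map for a given word of the prompt, and plot_overlay draws it over the generated image:

heat_map = heat_map.compute_word_heat_map('dog')
heat_map.plot_overlay(out.images[0])
plt.show()  # assumes: import matplotlib.pyplot as plt

compute_token_merge_indices finds the indices of the token(s) that make up the word (a word may be split into several tokens), and the maps at those indices are averaged into a WordHeatMap:

def compute_word_heat_map(self, word: str, word_idx: int = None, offset_idx: int = 0) -> WordHeatMap:
    # Indices of the heat-map rows belonging to `word`
    merge_idxs, word_idx = compute_token_merge_indices(self.tokenizer, self.prompt, word, word_idx, offset_idx)
    # Average those rows into one map for the word
    return WordHeatMap(self.heat_maps[merge_idxs].mean(0), word, word_idx)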
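What the token-merging step has to do can be sketched as: find where the word's tokens occur in the tokenized prompt, shift by one because row 0 of the global map is [SOS], and average those rows. The functions below are a hypothetical simplification working on pre-tokenized lists; DAAM's compute_token_merge_indices uses the pipeline's CLIP tokenizer and its own word_idx/offset_idx semantics, which are not reproduced here (occurrence is an illustrative parameter).

```python
from typing import List


def token_merge_indices(prompt_tokens: List[str], word_tokens: List[str], occurrence: int = 0) -> List[int]:
    """Heat-map row indices for the `occurrence`-th match of word_tokens in prompt_tokens.

    Indices are shifted by +1 because row 0 of the global heat map is [SOS].
    """
    n = len(word_tokens)
    starts = [i for i in range(len(prompt_tokens) - n + 1) if prompt_tokens[i:i + n] == word_tokens]
    if occurrence >= len(starts):
        raise ValueError('word not found in prompt')
    start = starts[occurrence]
    return [start + k + 1 for k in range(n)]


def word_heat_map(maps: List[List[float]], idxs: List[int]) -> List[float]:
    # Average the selected token rows pixel by pixel (like heat_maps[merge_idxs].mean(0))
    return [sum(maps[i][p] for i in idxs) / len(idxs) for p in range(len(maps[0]))]


# "hot dog" split into two tokens: rows 2 and 3 (row 0 is [SOS])
idxs = token_merge_indices(['a', 'hot', 'dog', 'stand'], ['hot', 'dog'])
rows = [[0.0, 0.0], [1.0, 1.0], [2.0, 0.0], [0.0, 2.0], [3.0, 3.0], [9.0, 9.0]]
print(idxs, word_heat_map(rows, idxs))
```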


 

 

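Results can also be saved and loaded easily.

GenerationExperiment: to_experiment exports the last generation call (image, global heat maps, prompt and optional seed) as a serializable GenerationExperiment.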

def to_experiment(self, path, seed=None, id='.', subtype='.', **compute_kwargs):
    # type: (Union[Path, str], int, str, str, Dict[str, Any]) -> GenerationExperiment
    """Exports the last generation call to a serializable generation experiment."""

    return GenerationExperiment(
        self.last_image,
        self.compute_global_heat_map(**compute_kwargs).heat_maps,
        self.last_prompt,
        seed=seed,
        id=id,
        subtype=subtype,
        path=path,
        tokenizer=self.pipe.tokenizer,
    )
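GenerationExperiment has its own on-disk format, which this post does not show. As a rough, hypothetical illustration of what such an export involves, the sketch below writes the same kinds of fields (prompt, seed and per-token heat maps) to a directory as JSON and reads them back; save_experiment, load_experiment and the experiment.json file name are made up for this example, and the image is left out to keep it dependency-free.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional


def save_experiment(path: Path, prompt: str, heat_maps: List[List[float]], seed: Optional[int] = None) -> Path:
    # Hypothetical layout: one JSON file holding the prompt, the seed and the per-token maps
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    record = {'prompt': prompt, 'seed': seed, 'heat_maps': heat_maps}
    (path / 'experiment.json').write_text(json.dumps(record))
    return path


def load_experiment(path: Path) -> Dict[str, Any]:
    # Read back exactly what save_experiment wrote
    return json.loads((Path(path) / 'experiment.json').read_text())


with tempfile.TemporaryDirectory() as tmp:
    saved = save_experiment(Path(tmp) / 'experiment-dir', 'a dog', [[0.5, 0.25]], seed=0)
    assert load_experiment(saved) == {'prompt': 'a dog', 'seed': 0, 'heat_maps': [[0.5, 0.25]]}
```

Storing the seed next to the maps is what makes an exported run reproducible with set_seed later.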
