From 80a45e141abb14ba30c7b2b510a134714e0e1263 Mon Sep 17 00:00:00 2001
From: Alexandre Chapin
Date: Tue, 3 Jun 2025 16:30:36 +0200
Subject: [PATCH 1/2] Add a batch implementation of legrad_clip + add an option for feature extraction

This pull request introduces a batch version of the implementation (it can
take a batch of images as input) and provides a way to get the image features
back through an optional argument.
---
 legrad/wrapper.py | 58 +++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 56 insertions(+), 2 deletions(-)

diff --git a/legrad/wrapper.py b/legrad/wrapper.py
index 4b1a4da..75b3ae7 100644
--- a/legrad/wrapper.py
+++ b/legrad/wrapper.py
@@ -141,11 +141,11 @@ def compute_legrad(self, text_embedding, image=None, apply_correction=True):
         elif 'coca' in self.model_type:
             return self.compute_legrad_coca(text_embedding, image)
 
-    def compute_legrad_clip(self, text_embedding, image=None):
+    def compute_legrad_clip(self, text_embedding, image=None, return_img_feats=False):
         num_prompts = text_embedding.shape[0]
         if image is not None:
             image = image.repeat(num_prompts, 1, 1, 1)
-            _ = self.encode_image(image)
+            img_feats = self.encode_image(image)
 
         blocks_list = list(dict(self.visual.transformer.resblocks.named_children()).values())
 
@@ -183,8 +183,62 @@ def compute_legrad_clip(self, text_embedding, image=None):
 
         # Min-Max Norm
         accum_expl_map = min_max(accum_expl_map)
+        if return_img_feats:
+            return accum_expl_map, img_feats
         return accum_expl_map
 
+    def compute_legrad_clip_batch(self, text_embeddings, images, return_img_feats=False):
+        """
+        text_embeddings: [B, N, D]
+        images: [B, C, H, W]
+        """
+        B, N, D = text_embeddings.shape
+
+        # Expand images to match the number of prompts per image
+        images = images.unsqueeze(1).repeat(1, N, 1, 1, 1)  # [B, N, C, H, W]
+        images = images.view(B * N, *images.shape[2:])  # [B*N, C, H, W]
+        text_embeddings = text_embeddings.view(B * N, D)  # [B*N, D]
+
+        img_feats = self.encode_image(images)
+
+        blocks_list = list(dict(self.visual.transformer.resblocks.named_children()).values())
+
+        image_features_list = []
+
+        for layer in range(self.starting_depth, len(self.visual.transformer.resblocks)):
+            intermediate_feat = self.visual.transformer.resblocks[layer].feat_post_mlp  # [num_patch, B*N, dim]
+            intermediate_feat = self.visual.ln_post(intermediate_feat.mean(dim=0)) @ self.visual.proj
+            intermediate_feat = F.normalize(intermediate_feat, dim=-1)
+            image_features_list.append(intermediate_feat)
+
+        num_tokens = blocks_list[-1].feat_post_mlp.shape[0] - 1
+        w = h = int(math.sqrt(num_tokens))
+
+        accum_expl_map = 0
+        for layer, (blk, img_feat) in enumerate(zip(blocks_list[self.starting_depth:], image_features_list)):
+            self.visual.zero_grad()
+            sim = text_embeddings @ img_feat.transpose(-1, -2)  # [B*N, B*N]
+            sim_diag = sim.diagonal(dim1=0, dim2=1)  # [B*N]
+            one_hot = sim_diag.sum()
+
+            attn_map = blocks_list[self.starting_depth + layer].attn.attention_maps  # [B*N * num_heads, N, N]
+            grad = torch.autograd.grad(one_hot, [attn_map], retain_graph=True, create_graph=True)[0]
+            grad = rearrange(grad, '(bn h) n m -> bn h n m', bn=B * N)  # [B*N, H, N, N]
+            grad = torch.clamp(grad, min=0.)
+
+            image_relevance = grad.mean(dim=1).mean(dim=1)[:, 1:]  # [B*N, N_patches]
+            expl_map = rearrange(image_relevance, 'bn (w h) -> 1 bn w h', w=w, h=h)
+            expl_map = F.interpolate(expl_map, scale_factor=self.patch_size, mode='bilinear')  # [1, B*N, H, W]
+            accum_expl_map += expl_map
+
+        accum_expl_map = min_max(accum_expl_map)
+
+        # Reshape back to [B, N, H, W]
+        accum_expl_map = accum_expl_map.view(1, B, N, *accum_expl_map.shape[2:])  # [1, B, N, H, W]
+        if return_img_feats:
+            return accum_expl_map.squeeze(0), img_feats
+        return accum_expl_map.squeeze(0)  # [B, N, H, W]
+
     def compute_legrad_coca(self, text_embedding, image=None):
         if image is not None:
             _ = self.encode_image(image)
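Usage sketch (not part of the patch itself): a minimal example of how the new
compute_legrad_clip_batch method and the return_img_feats flag introduced above
are meant to be called. The model name, checkpoint tag, prompts and the dummy
image tensor are illustrative placeholders only; the LeWrapper/open_clip setup
follows the repository README.

import torch
import torch.nn.functional as F
import open_clip
from legrad import LeWrapper

model_name = 'ViT-B-16'
model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained='laion2b_s34b_b88k')
tokenizer = open_clip.get_tokenizer(model_name)
model = LeWrapper(model)  # activates the LeGrad hooks on the visual tower
model.eval()

B, N = 2, 3  # 2 images, 3 prompts per image
images = torch.randn(B, 3, 224, 224)  # stand-in for a batch of preprocessed images
prompts = ['a photo of a cat', 'a photo of a dog', 'a photo of a bird']
text_emb = F.normalize(model.encode_text(tokenizer(prompts)), dim=-1)  # [N, D]
text_emb = text_emb.unsqueeze(0).repeat(B, 1, 1)  # [B, N, D], same prompt set for every image

# One explainability map per (image, prompt) pair; the optional flag additionally
# returns the image features computed on the flattened [B*N] batch.
heatmaps, img_feats = model.compute_legrad_clip_batch(
    text_embeddings=text_emb, images=images, return_img_feats=True)
# heatmaps: [B, N, H, W]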
From 366059a8b6d1d922828aa6e594f5f3c827674549 Mon Sep 17 00:00:00 2001
From: alexcbb
Date: Fri, 6 Jun 2025 15:47:46 +0200
Subject: [PATCH 2/2] Include OpenAI CLIP integration

---
 legrad/utils.py   | 15 ++++++++++++---
 legrad/wrapper.py | 33 ++++++++++++++++++++++-----------
 2 files changed, 34 insertions(+), 14 deletions(-)

diff --git a/legrad/utils.py b/legrad/utils.py
index dc6c747..25b72f4 100644
--- a/legrad/utils.py
+++ b/legrad/utils.py
@@ -49,7 +49,6 @@ def hooked_attention_forward(self, x, x_k, x_v, attn_mask: Optional[torch.Tensor
     x = self.out_proj(x)
     return x
 
-
 # ------------ Hooked Residual Transformer Block ------------
 # from https://github.com/mlfoundations/open_clip/blob/73fa7f03a33da53653f61841eb6d69aef161e521/src/open_clip/transformer.py#L231
 def hooked_resblock_forward(
@@ -61,7 +60,7 @@ def hooked_resblock_forward(
 ):
     k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
     v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
-    
+
     x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask))
     # Hook for intermediate features post Attn
     self.feat_post_attn = x
@@ -70,6 +69,17 @@ def hooked_resblock_forward(
     self.feat_post_mlp = x
     return x
 
+def clip_hooked_resblock_forward(
+        self,
+        x: torch.Tensor
+):
+    x = x + self.ls_1(self.attention(self.ln_1(x)))
+    # Hook for intermediate features post Attn
+    self.feat_post_attn = x
+    x = x + self.ls_2(self.mlp(self.ln_2(x)))
+    # Hook for intermediate features post MLP
+    self.feat_post_mlp = x
+    return x
 
 # ------------ Hooked PyTorch's Multi-Head AttentionResidual ------------
 # modified from PyTorch Library
@@ -384,7 +394,6 @@ def hooked_torch_func_multi_head_attention_forward(query: Tensor,
     else:
         return attn_output, None
 
-
 # ------------ Hooked TimmModel's Residual Transformer Block ------------
 def hooked_resblock_timm_forward(self, x: torch.Tensor) -> torch.Tensor:
     x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
diff --git a/legrad/wrapper.py b/legrad/wrapper.py
index 75b3ae7..2a8d6ad 100644
--- a/legrad/wrapper.py
+++ b/legrad/wrapper.py
@@ -6,11 +6,12 @@
 from torchvision.transforms import Compose, Resize, InterpolationMode
 import open_clip
 from open_clip.transformer import VisionTransformer
+from clip.model import VisionTransformer as ClipVisionTransformer
 from open_clip.timm_model import TimmModel
 from einops import rearrange
 
 from .utils import hooked_resblock_forward, \
-    hooked_attention_forward, \
+    clip_hooked_resblock_forward, \
     hooked_resblock_timm_forward, \
     hooked_attentional_pooler_timm_forward, \
     vit_dynamic_size_forward, \
@@ -36,7 +37,7 @@ def __init__(self, model, layer_index=-2):
     def _activate_hooks(self, layer_index):
         # ------------ identify model's type ------------
         print('Activating necessary hooks and gradients ....')
-        if isinstance(self.visual, VisionTransformer):
+        if isinstance(self.visual, VisionTransformer) or isinstance(self.visual, ClipVisionTransformer):
             # --- Activate dynamic image size ---
             self.visual.forward = types.MethodType(vit_dynamic_size_forward, self.visual)
             # Get patch size
@@ -47,6 +48,7 @@ def _activate_hooks(self, layer_index):
 
             if self.visual.attn_pool is None:
                 self.model_type = 'clip'
+                print('Using CLIP model')
                 self._activate_self_attention_hooks()
             else:
                 self.model_type = 'coca'
@@ -67,6 +69,7 @@ def _activate_hooks(self, layer_index):
                                       self.visual.trunk.blocks) + layer_index
             self._activate_timm_attn_pool_hooks(layer_index=layer_index)
         else:
+            print(f"type(self.visual) = {type(self.visual)}")
             raise ValueError(
                 "Model currently not supported, see legrad.list_pretrained() for a list of available models")
         print('Hooks and gradients activated!')
@@ -81,15 +84,24 @@ def _activate_self_attention_hooks(self):
                 depth = int(name.split('visual.transformer.resblocks.')[-1].split('.')[0])
                 if depth >= self.starting_depth:
                     param.requires_grad = True
+            elif name.startswith('model.transformer.resblocks'):
+                depth = int(name.split('model.transformer.resblocks.')[-1].split('.')[0])
+                if depth >= self.starting_depth:
+                    param.requires_grad = True
 
         # --- Activate the hooks for the specific layers ---
         for layer in range(self.starting_depth, len(self.visual.transformer.resblocks)):
             self.visual.transformer.resblocks[layer].attn.forward = types.MethodType(hooked_torch_multi_head_attention_forward,
                                                                                       self.visual.transformer.resblocks[
                                                                                           layer].attn)
-            self.visual.transformer.resblocks[layer].forward = types.MethodType(hooked_resblock_forward,
-                                                                                self.visual.transformer.resblocks[
-                                                                                    layer])
+            if isinstance(self.visual, ClipVisionTransformer):
+                self.visual.transformer.resblocks[layer].forward = types.MethodType(clip_hooked_resblock_forward,
+                                                                                     self.visual.transformer.resblocks[
+                                                                                         layer])
+            else:
+                self.visual.transformer.resblocks[layer].forward = types.MethodType(hooked_resblock_forward,
+                                                                                    self.visual.transformer.resblocks[
+                                                                                        layer])
 
     def _activate_att_pool_hooks(self, layer_index):
         # ---------- Apply Hooks + Activate/Deactivate gradients ----------
@@ -141,11 +153,11 @@ def compute_legrad(self, text_embedding, image=None, apply_correction=True):
         elif 'coca' in self.model_type:
             return self.compute_legrad_coca(text_embedding, image)
 
-    def compute_legrad_clip(self, text_embedding, image=None, return_img_feats=False):
+    def compute_legrad_clip(self, text_embedding, image=None):
         num_prompts = text_embedding.shape[0]
         if image is not None:
             image = image.repeat(num_prompts, 1, 1, 1)
-            img_feats = self.encode_image(image)
+            _ = self(image)
 
         blocks_list = list(dict(self.visual.transformer.resblocks.named_children()).values())
 
@@ -183,10 +195,8 @@ def compute_legrad_clip(self, text_embedding, image=None, return_img_feats=False
 
         # Min-Max Norm
         accum_expl_map = min_max(accum_expl_map)
-        if return_img_feats:
-            return accum_expl_map, img_feats
         return accum_expl_map
-        
+
     def compute_legrad_clip_batch(self, text_embeddings, images, return_img_feats=False):
         """
         text_embeddings: [B, N, D]
@@ -222,7 +232,7 @@ def compute_legrad_clip_batch(self, text_embeddings, images, return_img_feats=Fa
             one_hot = sim_diag.sum()
 
             attn_map = blocks_list[self.starting_depth + layer].attn.attention_maps  # [B*N * num_heads, N, N]
-            grad = torch.autograd.grad(one_hot, [attn_map], retain_graph=True, create_graph=True)[0]
+            grad = torch.autograd.grad(one_hot, [attn_map], retain_graph=True, create_graph=True, allow_unused=True)[0]
             grad = rearrange(grad, '(bn h) n m -> bn h n m', bn=B * N)  # [B*N, H, N, N]
             grad = torch.clamp(grad, min=0.)
 
@@ -239,6 +249,7 @@ def compute_legrad_clip_batch(self, text_embeddings, images, return_img_feats=Fa
             return accum_expl_map.squeeze(0), img_feats
         return accum_expl_map.squeeze(0)  # [B, N, H, W]
 
+
     def compute_legrad_coca(self, text_embedding, image=None):
         if image is not None:
             _ = self.encode_image(image)
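Usage sketch (not part of the patch itself): how the OpenAI CLIP path enabled
above is intended to be used, assuming the openai `clip` package is installed.
The checkpoint name and prompt are illustrative, and the compute_legrad call
simply follows the existing wrapper API; treat this as a sketch rather than a
verified end-to-end run.

import clip
import torch
import torch.nn.functional as F
from legrad import LeWrapper

# Load an OpenAI CLIP checkpoint; its visual tower is a clip.model.VisionTransformer,
# which LeWrapper now recognises thanks to the isinstance check added above.
model, preprocess = clip.load('ViT-B/16', device='cpu')
model = LeWrapper(model)

tokens = clip.tokenize(['a photo of a cat'])
text_emb = F.normalize(model.encode_text(tokens), dim=-1)  # [1, D]

image = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed image
heatmap = model.compute_legrad(text_embedding=text_emb, image=image)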