From 80a45e141abb14ba30c7b2b510a134714e0e1263 Mon Sep 17 00:00:00 2001
From: Alexandre Chapin
Date: Tue, 3 Jun 2025 16:30:36 +0200
Subject: [PATCH] Add a batch implementation of legrad_clip + an option for
 feature extraction

This pull request introduces a batch version of the implementation (it can
take a batch of images as input) and provides a way to get the image
features through an optional argument.
---
 legrad/wrapper.py | 59 +++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 57 insertions(+), 2 deletions(-)

diff --git a/legrad/wrapper.py b/legrad/wrapper.py
index 4b1a4da..75b3ae7 100644
--- a/legrad/wrapper.py
+++ b/legrad/wrapper.py
@@ -141,11 +141,12 @@ def compute_legrad(self, text_embedding, image=None, apply_correction=True):
         elif 'coca' in self.model_type:
             return self.compute_legrad_coca(text_embedding, image)
 
-    def compute_legrad_clip(self, text_embedding, image=None):
+    def compute_legrad_clip(self, text_embedding, image=None, return_img_feats=False):
         num_prompts = text_embedding.shape[0]
+        img_feats = None  # stays None when image is None (features cached by an earlier call)
         if image is not None:
             image = image.repeat(num_prompts, 1, 1, 1)
-            _ = self.encode_image(image)
+            img_feats = self.encode_image(image)
 
         blocks_list = list(dict(self.visual.transformer.resblocks.named_children()).values())
 
@@ -183,8 +184,62 @@ def compute_legrad_clip(self, text_embedding, image=None):
         # Min-Max Norm
         accum_expl_map = min_max(accum_expl_map)
 
+        if return_img_feats:
+            return accum_expl_map, img_feats
         return accum_expl_map
 
+    def compute_legrad_clip_batch(self, text_embeddings, images, return_img_feats=False):
+        """
+        text_embeddings: [B, N, D] - N text embeddings (one per prompt) for each image
+        images: [B, C, H, W]
+        """
+        B, N, D = text_embeddings.shape
+
+        # Expand images to match the number of prompts per image
+        images = images.unsqueeze(1).repeat(1, N, 1, 1, 1)  # [B, N, C, H, W]
+        images = images.view(B * N, *images.shape[2:])  # [B*N, C, H, W]
+        text_embeddings = text_embeddings.view(B * N, D)  # [B*N, D]
+
+        img_feats = self.encode_image(images)
+
+        blocks_list = list(dict(self.visual.transformer.resblocks.named_children()).values())
+
+        image_features_list = []
+
+        for layer in range(self.starting_depth, len(self.visual.transformer.resblocks)):
+            intermediate_feat = self.visual.transformer.resblocks[layer].feat_post_mlp  # [num_patch, B*N, dim]
+            intermediate_feat = self.visual.ln_post(intermediate_feat.mean(dim=0)) @ self.visual.proj
+            intermediate_feat = F.normalize(intermediate_feat, dim=-1)
+            image_features_list.append(intermediate_feat)
+
+        num_tokens = blocks_list[-1].feat_post_mlp.shape[0] - 1
+        w = h = int(math.sqrt(num_tokens))
+
+        accum_expl_map = 0
+        for layer, (blk, img_feat) in enumerate(zip(blocks_list[self.starting_depth:], image_features_list)):
+            self.visual.zero_grad()
+            sim = text_embeddings @ img_feat.transpose(-1, -2)  # [B*N, B*N]
+            sim_diag = sim.diagonal(dim1=0, dim2=1)  # [B*N], each prompt paired with its own image
+            one_hot = sim_diag.sum()
+
+            attn_map = blocks_list[self.starting_depth + layer].attn.attention_maps  # [B*N * num_heads, T, T], T = num_tokens + 1
+            grad = torch.autograd.grad(one_hot, [attn_map], retain_graph=True, create_graph=True)[0]
+            grad = rearrange(grad, '(bn h) n m -> bn h n m', bn=B * N)  # [B*N, num_heads, T, T]
+            grad = torch.clamp(grad, min=0.)
+
+            image_relevance = grad.mean(dim=1).mean(dim=1)[:, 1:]  # [B*N, N_patches]
+            expl_map = rearrange(image_relevance, 'bn (w h) -> 1 bn w h', w=w, h=h)
+            expl_map = F.interpolate(expl_map, scale_factor=self.patch_size, mode='bilinear')  # [1, B*N, H, W]
+            accum_expl_map += expl_map
+
+        accum_expl_map = min_max(accum_expl_map)
+
+        # Reshape back to [B, N, H, W]
+        accum_expl_map = accum_expl_map.view(1, B, N, *accum_expl_map.shape[2:])  # [1, B, N, H, W]
+        if return_img_feats:
+            return accum_expl_map.squeeze(0), img_feats
+        return accum_expl_map.squeeze(0)  # [B, N, H, W]
+
     def compute_legrad_coca(self, text_embedding, image=None):
         if image is not None:
             _ = self.encode_image(image)
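
Usage sketch (not part of the patch): one way the new batched entry point could be called. The setup follows the LeWrapper pattern from the LeGrad README; the model name, checkpoint, prompts, and image paths below are illustrative assumptions.

```python
import torch
import torch.nn.functional as F
import open_clip
from PIL import Image
from legrad import LeWrapper

# Illustrative model/checkpoint choice; any OpenCLIP ViT supported by LeGrad should work.
model_name, pretrained = 'ViT-B-16', 'laion2b_s34b_b88k'
model, _, preprocess = open_clip.create_model_and_transforms(model_name=model_name, pretrained=pretrained)
tokenizer = open_clip.get_tokenizer(model_name)
model = LeWrapper(model)  # equips the model with the compute_legrad_* methods

prompts = ['a photo of a cat', 'a photo of a dog']                 # N prompts
paths = ['img0.jpg', 'img1.jpg', 'img2.jpg']                       # B hypothetical image files
images = torch.stack([preprocess(Image.open(p)) for p in paths])   # [B, C, H, W]

with torch.no_grad():  # gradients are only needed on the image branch, inside the method
    text_emb = F.normalize(model.encode_text(tokenizer(prompts)), dim=-1)  # [N, D]

# One copy of the prompt embeddings per image: [B, N, D].
# repeat() rather than expand(), since the method reshapes with .view().
text_embs = text_emb.unsqueeze(0).repeat(images.shape[0], 1, 1)

expl_maps, img_feats = model.compute_legrad_clip_batch(text_embs, images, return_img_feats=True)
# expl_maps: [B, N, H, W] explainability maps; img_feats: [B*N, D] image features
```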