import math
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet

from .clip import load, tokenize
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
from data.imagnet_prompts import imagenet_classes
from data.fewshot_datasets import fewshot_datasets
from data.cls_to_names import *
# from data.medclip_datasets_clsnames import *

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

_tokenizer = _Tokenizer()

DOWNLOAD_ROOT = '~/.cache/clip'


# class ClipImageEncoder(nn.Module):
#     def __init__(self, device, arch="ViT-L/14", image_resolution=224, n_class=1000):
#         super(ClipImageEncoder, self).__init__()
#         clip, embed_dim, _ = load(arch, device=device, download_root=DOWNLOAD_ROOT)
#         self.encoder = clip.visual
#         del clip.transformer
#         torch.cuda.empty_cache()
#         self.cls_head = nn.Linear(embed_dim, n_class)
#
#     @property
#     def dtype(self):
#         return self.encoder.conv1.weight.dtype
#
#     def forward(self, image):
#         x = self.encoder(image.type(self.dtype))
#         output = self.cls_head(x)
#         return output


class TextEncoder(nn.Module):
    def __init__(self, medclip_text_model):
        super().__init__()
        self.medclip_text_model = medclip_text_model

    def forward(self, prompts_embeddings, tokenized_prompts):
        # Feed the (learned) prompt embeddings directly into the BERT encoder,
        # bypassing its own embedding lookup.
        output = self.medclip_text_model.model(inputs_embeds=prompts_embeddings,
                                               attention_mask=tokenized_prompts['attention_mask'])

        # take the average of the last four layers
        # last_hidden_states = torch.stack(output['hidden_states'][-self.last_n_layer:])  # n_layer, batch, seqlen, emb_dim
        # embed = last_hidden_states.permute(1, 0, 2, 3)
        # embed = embed.mean(1).mean(1)  # pooling

        # pool hidden layers 1, 2 and the last one
        last_hidden_states = torch.stack([output['hidden_states'][1],
                                          output['hidden_states'][2],
                                          output['hidden_states'][-1]])  # n_layer, batch, seqlen, emb_dim
        embed = last_hidden_states.permute(1, 0, 2, 3).mean(2).mean(1)  # pool over seqlen, then over the 3 layers

        # alternatively, take only the pooled output of the last hidden layer
        # embed = output['pooler_output']

        embed = self.medclip_text_model.projection_head(embed)
        return embed
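# --- Illustrative sketch (not part of the original pipeline) ---------------
# A minimal shape walk-through of the pooling used in TextEncoder above:
# stack hidden layers 1, 2 and -1, average over the sequence and layer axes,
# then project. The hidden states and projection head here are random
# stand-ins (13 = embeddings + 12 BERT layers is an assumption); MedCLIP's
# real modules replace them in practice.
def _pooling_shape_demo(batch=2, seqlen=25, dim=768, proj_dim=512):
    hidden_states = [torch.randn(batch, seqlen, dim) for _ in range(13)]
    picked = torch.stack([hidden_states[1], hidden_states[2], hidden_states[-1]])  # (3, B, L, D)
    embed = picked.permute(1, 0, 2, 3).mean(2).mean(1)                             # (B, D)
    projection_head = nn.Linear(dim, proj_dim)  # stand-in for MedCLIP's projection head
    return projection_head(embed)               # (B, proj_dim)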
class PromptLearner(nn.Module):
    def __init__(self, medclip_model, classnames, device, batch_size=None, n_ctx=16,
                 ctx_init=None, ctx_position='end', learned_cls=False):
        super().__init__()
        n_cls = len(classnames)
        self.learned_cls = learned_cls
        dtype = medclip_model.dtype
        self.dtype = dtype
        ctx_dim = 768  # hardcoded for now!!! (CLIP's equivalent: medclip_model.ln_final.weight.shape[0])
        self.ctx_dim = ctx_dim
        self.batch_size = batch_size
        self.device = device
        self.medclip_model = medclip_model

        # self.ctx, prompt_prefix = self.reset_prompt(ctx_dim, ctx_init, medclip_model)

        if ctx_init:
            # use the given words to initialize the context vectors
            print("Initializing the context with given words: [{}]".format(ctx_init))
            ctx_init = ctx_init.replace("_", " ")
            if '[CLS]' in ctx_init:
                ctx_list = ctx_init.split(" ")
                split_idx = ctx_list.index("[CLS]")
                ctx_init = ctx_init.replace("[CLS] ", "")
                ctx_position = "middle"
            else:
                split_idx = None
            self.split_idx = split_idx
            n_ctx = len(ctx_init.split(" "))
            # prompt = tokenize(ctx_init).to(self.device)
            prompt = ctx_init
            tokenized_prompts = medclip_model.text_model.tokenizer(prompt, padding='max_length',
                                                                   max_length=25, truncation=True,
                                                                   return_tensors='pt').to(self.device)
            prompts_tokens = tokenized_prompts['input_ids']  # [1, 25]
            with torch.no_grad():
                embedding = medclip_model.text_model.model.embeddings.word_embeddings(prompts_tokens).type(dtype)  # [1, 25, 768]
            # embedding = medclip_model.token_embedding(prompt).type(dtype)
            ctx_vectors = embedding[0, 1:1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            print("Random initialization: initializing a generic context")
            ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        self.prompt_prefix = prompt_prefix

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        # batch-wise prompt tuning for test-time adaptation
        if self.batch_size is not None:
            ctx_vectors = ctx_vectors.repeat(batch_size, 1, 1)  # (N, L, D)
        self.ctx_init_state = ctx_vectors.detach().clone()
        self.ctx = nn.Parameter(ctx_vectors)  # to be optimized

        if not self.learned_cls:
            classnames = [name.replace("_", " ") for name in classnames]
            name_lens = [len(medclip_model.text_model.tokenizer.encode(name)) - 2 for name in classnames]  # [CLS] and [SEP] are not counted
            prompts = [prompt_prefix + " " + name + "." for name in classnames]
        else:
            print("Random initialization: initializing a learnable class token")
            cls_vectors = torch.empty(n_cls, 1, ctx_dim, dtype=dtype)  # assume each learnable cls_token is only 1 word
            nn.init.normal_(cls_vectors, std=0.02)
            cls_token = "X"
            name_lens = [1 for _ in classnames]
            prompts = [prompt_prefix + " " + cls_token + "." for _ in classnames]
            self.cls_init_state = cls_vectors.detach().clone()
            self.cls = nn.Parameter(cls_vectors)  # to be optimized

        tokenized_prompts = medclip_model.text_model.tokenizer(prompts, padding='max_length', max_length=25,
                                                               truncation=True, return_tensors='pt').to(self.device)
        prompts_tokens = tokenized_prompts['input_ids']  # [n_cls, 25]
        with torch.no_grad():
            embedding = medclip_model.text_model.model.embeddings.word_embeddings(prompts_tokens).type(dtype)  # [n_cls, 25, 768]

        # These token vectors will be saved in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed from the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS ([CLS])
        if self.learned_cls:
            self.register_buffer("token_suffix", embedding[:, 1 + n_ctx + 1:, :])  # ..., EOS ([SEP])
        else:
            self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :])  # classname tokens, EOS ([SEP])

        self.ctx_init = ctx_init
        self.tokenized_prompts = tokenized_prompts  # BatchEncoding with 'input_ids' / 'attention_mask'
        self.name_lens = name_lens
        self.class_token_position = ctx_position
        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.classnames = classnames

    def reset(self):
        ctx_vectors = self.ctx_init_state
        self.ctx.copy_(ctx_vectors)  # to be optimized
        if self.learned_cls:
            cls_vectors = self.cls_init_state
            self.cls.copy_(cls_vectors)

    def reset_classnames(self, classnames, arch):
        self.n_cls = len(classnames)
        if not self.learned_cls:
            classnames = [name.replace("_", " ") for name in classnames]
            name_lens = [len(self.medclip_model.text_model.tokenizer.encode(name)) - 2 for name in classnames]  # [CLS] and [SEP] are not counted
            prompts = [self.prompt_prefix + " " + name + "." for name in classnames]
        else:
            cls_vectors = torch.empty(self.n_cls, 1, self.ctx_dim, dtype=self.dtype)  # assume each learnable cls_token is only 1 word
            nn.init.normal_(cls_vectors, std=0.02)
            cls_token = "X"
            name_lens = [1 for _ in classnames]
            prompts = [self.prompt_prefix + " " + cls_token + "." for _ in classnames]
            self.cls_init_state = cls_vectors.detach().clone()

        tokenized_prompts = self.medclip_model.text_model.tokenizer(prompts, padding='max_length', max_length=25,
                                                                    truncation=True, return_tensors='pt').to(self.device)
        prompts_tokens = tokenized_prompts['input_ids']
        with torch.no_grad():
            embedding = self.medclip_model.text_model.model.embeddings.word_embeddings(prompts_tokens).type(self.dtype)  # [n_cls, 25, 768]

        self.token_prefix = embedding[:, :1, :]
        self.token_suffix = embedding[:, 1 + self.n_ctx:, :]  # classname tokens, EOS

        self.name_lens = name_lens
        self.tokenized_prompts = tokenized_prompts
        self.classnames = classnames

    def forward(self, init=None):
        # `init` will be used when computing the CLIP directional loss
        if init is not None:
            ctx = init
        else:
            ctx = self.ctx
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
        elif ctx.size(0) != self.n_cls:
            ctx = ctx.unsqueeze(1).expand(-1, self.n_cls, -1, -1)

        prefix = self.token_prefix
        suffix = self.token_suffix
        if self.batch_size is not None:
            # This only works in a single-GPU setting (the batch size could be passed as an argument to forward())
            prefix = prefix.repeat(self.batch_size, 1, 1, 1)
            suffix = suffix.repeat(self.batch_size, 1, 1, 1)

        if self.learned_cls:
            assert self.class_token_position == "end"

        if self.class_token_position == "end":
            if self.learned_cls:
                cls = self.cls
                prompts = torch.cat(
                    [
                        prefix,  # (n_cls, 1, dim)
                        ctx,     # (n_cls, n_ctx, dim)
                        cls,     # (n_cls, 1, dim)
                        suffix,  # (n_cls, *, dim)
                    ],
                    dim=-2,
                )
            else:
                prompts = torch.cat(
                    [
                        prefix,  # (n_cls, 1, dim)
                        ctx,     # (n_cls, n_ctx, dim)
                        suffix,  # (n_cls, *, dim)
                    ],
                    dim=-2,
                )
        elif self.class_token_position == "middle":
            # TODO: make this work with a batch of prompts
            if self.split_idx is not None:
                half_n_ctx = self.split_idx  # split the ctx at the position of [CLS] in `ctx_init`
            else:
                half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i:i + 1, :, :]
                class_i = suffix[i:i + 1, :name_len, :]
                suffix_i = suffix[i:i + 1, name_len:, :]
                ctx_i_half1 = ctx[i:i + 1, :half_n_ctx, :]
                ctx_i_half2 = ctx[i:i + 1, half_n_ctx:, :]
                prompt = torch.cat(
                    [
                        prefix_i,     # (1, 1, dim)
                        ctx_i_half1,  # (1, n_ctx//2, dim)
                        class_i,      # (1, name_len, dim)
                        ctx_i_half2,  # (1, n_ctx//2, dim)
                        suffix_i,     # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)
        elif self.class_token_position == "front":
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i:i + 1, :, :]
                class_i = suffix[i:i + 1, :name_len, :]
                suffix_i = suffix[i:i + 1, name_len:, :]
                ctx_i = ctx[i:i + 1, :, :]
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        class_i,   # (1, name_len, dim)
                        ctx_i,     # (1, n_ctx, dim)
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)
        else:
            raise ValueError(f"Unknown class_token_position: {self.class_token_position}")

        return prompts
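# --- Illustrative sketch (not part of the original pipeline) ---------------
# How PromptLearner splices the learned context back into the token
# embeddings for the default "end" position: [CLS] + n_ctx learned vectors +
# (classname tokens, [SEP], padding). All tensors are random stand-ins with
# the shapes used above (max_length=25, ctx_dim=768).
def _prompt_assembly_demo(n_cls=3, n_ctx=4, seqlen=25, dim=768):
    embedding = torch.randn(n_cls, seqlen, dim)        # embeddings of "X X X X <name>."
    ctx = torch.randn(n_ctx, dim).unsqueeze(0).expand(n_cls, -1, -1)
    prefix = embedding[:, :1, :]                       # [CLS]
    suffix = embedding[:, 1 + n_ctx:, :]               # classname, [SEP], pads
    prompts = torch.cat([prefix, ctx, suffix], dim=1)  # (n_cls, seqlen, dim)
    assert prompts.shape == embedding.shape            # same layout the tokenizer produced
    return prompts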
from MedCLIP.medclip import MedCLIPModel, MedCLIPVisionModel, MedCLIPVisionModelViT
from MedCLIP.medclip import MedCLIPProcessor


def load_medclip_to_cpu():
    model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT)
    model.from_pretrained()
    # model.from_pretrained("/l/users/asif.hanif/pre-trained-models/vlps/medclip/pretrained/medclip-vit/")
    model.from_pretrained("./MedCLIP/pretrained/medclip-vit/")  # the local checkpoint overrides the weights downloaded above

    # for ViT
    model.dtype = model.vision_model.model.embeddings.patch_embeddings.projection.weight.dtype
    # for ResNet
    # model.dtype = model.vision_model.model.conv1.weight.dtype

    model.eval()
    return model
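# --- Illustrative sketch (not part of the original pipeline) ---------------
# Minimal usage of the loader above: move the model to a device and encode a
# dummy batch. The 224x224 input size and the "cuda:0" device string are
# assumptions; the vision model is called positionally, the same way
# ClipTestTimeTuning.inference() below calls it.
def _loader_usage_demo():
    model = load_medclip_to_cpu().to("cuda:0")
    images = torch.randn(2, 3, 224, 224, device="cuda:0").type(model.dtype)
    with torch.no_grad():
        feats = model.vision_model(images)           # (2, proj_dim) projected image features
    return feats / feats.norm(dim=-1, keepdim=True)  # unit-normalized, as in inference()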
class ClipTestTimeTuning(nn.Module):
    def __init__(self, device, classnames, batch_size, criterion='cosine', arch="ViT-L/14",
                 n_ctx=16, ctx_init=None, ctx_position='end', learned_cls=False):
        super(ClipTestTimeTuning, self).__init__()
        self.device = device
        self.medclip_model = load_medclip_to_cpu()
        self.dtype = self.medclip_model.dtype
        self.medclip_model = self.medclip_model.to(self.device)
        self.image_encoder = self.medclip_model.vision_model
        self.text_encoder = TextEncoder(self.medclip_model.text_model)
        self.logit_scale = self.medclip_model.logit_scale.data
        # prompt tuning
        self.prompt_learner = PromptLearner(self.medclip_model, classnames, self.device, batch_size,
                                            n_ctx, ctx_init, ctx_position, learned_cls)
        self.criterion = criterion
        self.l2_norm_cal = False

    # @property
    # def dtype(self):
    #     return self.image_encoder.conv1.weight.dtype

    # restore the initial state of the prompt_learner (tunable prompt)
    def reset(self):
        self.prompt_learner.reset()

    def reset_classnames(self, classnames, arch):
        self.prompt_learner.reset_classnames(classnames, arch)

    def get_text_features(self):
        text_features = []
        prompts = self.prompt_learner()
        tokenized_prompts = self.prompt_learner.tokenized_prompts
        t_features = self.text_encoder(prompts, tokenized_prompts)
        text_features.append(t_features / t_features.norm(dim=-1, keepdim=True))
        text_features = torch.stack(text_features, dim=0)
        return torch.mean(text_features, dim=0)

    def inference(self, image):
        with torch.no_grad():
            image_features = self.image_encoder(image.type(self.dtype))

        text_features = self.get_text_features()
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # [c-tpt] --------------------------------------------
        if self.l2_norm_cal:
            prompt_mean = text_features.mean(0)
            feature_distance = text_features - prompt_mean
            l2_norm = torch.linalg.norm(feature_distance, dim=-1)
            l2_norm_mean = l2_norm.mean()

            # for saving to a csv file
            self.l2_norm_mean = l2_norm_mean.item()
            # for training
            self.l2_norm_mean_training = l2_norm_mean
        # -----------------------------------------------------

        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()
        return logits

    def forward(self, input):
        if isinstance(input, tuple):  # isinstance needs the builtin tuple, not typing.Tuple
            view_0, view_1, view_2 = input
            # contrast_prompt_tuning is expected to be defined elsewhere in the project
            return self.contrast_prompt_tuning(view_0, view_1, view_2)
        elif len(input.size()) == 2:
            # directional_prompt_tuning is expected to be defined elsewhere in the project
            return self.directional_prompt_tuning(input)
        else:
            return self.inference(input)
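# --- Illustrative sketch (not part of the original pipeline) ---------------
# The [c-tpt] block inside inference() measures how spread out the class text
# embeddings are: the mean L2 distance of each unit-normalized class feature
# from their centroid. A standalone recomputation with random stand-in
# features (dim=512 is an assumption):
def _ctpt_dispersion_demo(n_cls=5, dim=512):
    text_features = torch.randn(n_cls, dim)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    prompt_mean = text_features.mean(0)                                  # centroid
    l2_norm = torch.linalg.norm(text_features - prompt_mean, dim=-1)     # per-class distance
    return l2_norm.mean()  # scalar dispersion; C-TPT uses this as a regularizer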
def get_coop(clip_arch, test_set, device, n_ctx, ctx_init=None, learned_cls=False):
    classnames = eval("{}_classes".format(test_set.lower()))
    model = ClipTestTimeTuning(device, classnames, None, arch=clip_arch,
                               n_ctx=n_ctx, ctx_init=ctx_init, learned_cls=learned_cls)
    return model
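# --- Illustrative sketch (not part of the original pipeline) ---------------
# Typical construction via get_coop(). It assumes a `<test_set>_classes` list
# is in scope (e.g. imagenet_classes imported at the top of this file); the
# test set name, device string and context string below are assumptions, not
# values from the repo.
def _get_coop_usage_demo():
    model = get_coop("ViT-L/14", "ImageNet", "cuda:0", n_ctx=4, ctx_init="a_photo_of_a")
    model.reset()  # restore the initial prompt state
    images = torch.randn(8, 3, 224, 224, device="cuda:0")
    logits = model(images)  # a 4-D image batch dispatches to inference()
    return logits.argmax(dim=-1)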