razaimam45 committed
Commit a96891a · verified · 1 parent: 110f995

Upload 108 files

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +4 -0
  2. BetaMixture.py +187 -0
  3. LICENSE +21 -0
  4. README.md +137 -3
  5. baselines.py +51 -0
  6. clip/__init__.py +2 -0
  7. clip/__pycache__/__init__.cpython-310.pyc +0 -0
  8. clip/__pycache__/__init__.cpython-312.pyc +0 -0
  9. clip/__pycache__/__init__.cpython-38.pyc +0 -0
  10. clip/__pycache__/__init__.cpython-39.pyc +0 -0
  11. clip/__pycache__/clip.cpython-310.pyc +0 -0
  12. clip/__pycache__/clip.cpython-312.pyc +0 -0
  13. clip/__pycache__/clip.cpython-38.pyc +0 -0
  14. clip/__pycache__/clip.cpython-39.pyc +0 -0
  15. clip/__pycache__/cocoop.cpython-310.pyc +0 -0
  16. clip/__pycache__/cocoop.cpython-312.pyc +0 -0
  17. clip/__pycache__/cocoop.cpython-39.pyc +0 -0
  18. clip/__pycache__/custom_clip.cpython-310.pyc +0 -0
  19. clip/__pycache__/custom_clip.cpython-312.pyc +0 -0
  20. clip/__pycache__/custom_clip.cpython-39.pyc +0 -0
  21. clip/__pycache__/custom_medclip.cpython-310.pyc +0 -0
  22. clip/__pycache__/custom_medclip.cpython-312.pyc +0 -0
  23. clip/__pycache__/custom_medclip.cpython-39.pyc +0 -0
  24. clip/__pycache__/model.cpython-310.pyc +0 -0
  25. clip/__pycache__/model.cpython-312.pyc +0 -0
  26. clip/__pycache__/model.cpython-38.pyc +0 -0
  27. clip/__pycache__/model.cpython-39.pyc +0 -0
  28. clip/__pycache__/simple_tokenizer.cpython-310.pyc +0 -0
  29. clip/__pycache__/simple_tokenizer.cpython-312.pyc +0 -0
  30. clip/__pycache__/simple_tokenizer.cpython-38.pyc +0 -0
  31. clip/__pycache__/simple_tokenizer.cpython-39.pyc +0 -0
  32. clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  33. clip/clip.py +232 -0
  34. clip/cocoop.py +234 -0
  35. clip/custom_clip.py +388 -0
  36. clip/custom_medclip.py +389 -0
  37. clip/model.py +438 -0
  38. clip/simple_tokenizer.py +132 -0
  39. data/__init__.py +0 -0
  40. data/__pycache__/__init__.cpython-310.pyc +0 -0
  41. data/__pycache__/__init__.cpython-311.pyc +0 -0
  42. data/__pycache__/__init__.cpython-312.pyc +0 -0
  43. data/__pycache__/__init__.cpython-39.pyc +0 -0
  44. data/__pycache__/augmix_ops.cpython-310.pyc +0 -0
  45. data/__pycache__/augmix_ops.cpython-311.pyc +0 -0
  46. data/__pycache__/augmix_ops.cpython-312.pyc +0 -0
  47. data/__pycache__/augmix_ops.cpython-39.pyc +0 -0
  48. data/__pycache__/cls_to_names.cpython-310.pyc +0 -0
  49. data/__pycache__/cls_to_names.cpython-312.pyc +0 -0
  50. data/__pycache__/cls_to_names.cpython-39.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ figures/eval.png filter=lfs diff=lfs merge=lfs -text
37
+ figures/method.png filter=lfs diff=lfs merge=lfs -text
38
+ figures/mi_v_ent.png filter=lfs diff=lfs merge=lfs -text
39
+ figures/results.png filter=lfs diff=lfs merge=lfs -text
BetaMixture.py ADDED
@@ -0,0 +1,187 @@
1
+ import numpy as np
2
+ from scipy.special import betaln, logsumexp
3
+ from sklearn.cluster import KMeans
4
+
5
+ class BetaMixtureModel:
6
+ """
7
+ Beta Mixture Model (Multivariate version).
8
+ Each dimension is modeled independently by a Beta distribution.
9
+ """
10
+
11
+ def __init__(self, n_mixtures=3, random_seed=1):
12
+ self.n_mixtures = n_mixtures
13
+ self.random_seed = random_seed
14
+ self.convergence = False
15
+
16
+ def _init_clusters(self, data_matrix, init_round):
17
+ """
18
+ Initialize the mixture responsibilities (assignments) via k-means or uniformly at random.
19
+ """
20
+ if self.method == "kmeans":
21
+ km = KMeans(
22
+ n_clusters=self.n_mixtures,
23
+ n_init=1,
24
+ random_state=self.random_seed + init_round
25
+ ).fit(data_matrix)
26
+ resp_matrix = np.zeros((self.n_observations, self.n_mixtures))
27
+ resp_matrix[np.arange(self.n_observations), km.labels_] = 1
28
+ else:
29
+ np.random.seed(self.random_seed + init_round)
30
+ resp_matrix = np.random.rand(self.n_observations, self.n_mixtures)
31
+ resp_matrix /= resp_matrix.sum(axis=1, keepdims=True)
32
+
33
+ # Numerical stability
34
+ resp_matrix += 10 * np.finfo(resp_matrix.dtype).eps
35
+
36
+ # Initialize beta parameters (alpha/beta for each dimension)
37
+ self.beta_params_ = np.zeros((self.n_mixtures, self.n_components * 2))
38
+ self._M_step(data_matrix, np.log(resp_matrix))
39
+
40
+
41
+ def _calc_log_weights(self):
42
+ """
43
+ Return log of current mixture weights.
44
+ """
45
+ return np.log(self.mix_weights_)
46
+
47
+ def _calc_mixture_log_probs(self, data_matrix, mixture_idx):
48
+ """
49
+ Compute log-prob for a single mixture (used if parallelized).
50
+ """
51
+ alpha_vec = self.beta_params_[mixture_idx, :self.n_components]
52
+ beta_vec = self.beta_params_[mixture_idx, self.n_components:]
53
+ beta_func_log = betaln(alpha_vec, beta_vec)
54
+ return (
55
+ (alpha_vec - 1) * np.log(data_matrix)
56
+ + (beta_vec - 1) * np.log(1 - data_matrix)
57
+ - beta_func_log
58
+ ).sum(axis=1)
59
+
60
+ def _calc_log_probs_all_mixtures(self, data_matrix):
61
+ """
62
+ Return log-prob for each observation under each mixture (unnormalized).
63
+ """
64
+ log_prob = np.empty((self.n_observations, self.n_mixtures))
65
+ for mix in range(self.n_mixtures):
66
+ alpha_vec = self.beta_params_[mix, :self.n_components]
67
+ beta_vec = self.beta_params_[mix, self.n_components:]
68
+ bfn = betaln(alpha_vec, beta_vec)
69
+ log_prob[:, mix] = (
70
+ (alpha_vec - 1) * np.log(data_matrix)
71
+ + (beta_vec - 1) * np.log(1 - data_matrix)
72
+ - bfn
73
+ ).sum(axis=1)
74
+ return log_prob
75
+
76
+ def _calc_weighted_log_probs(self, data_matrix):
77
+ """
78
+ Return the sum of log-probabilities and log-weights.
79
+ """
80
+ return self._calc_log_probs_all_mixtures(data_matrix) + self._calc_log_weights()
81
+
82
+ def _calc_log_resp_and_norm(self, data_matrix):
83
+ """
84
+ Return (log_prob_norm, log_resp) for the E-step.
85
+ """
86
+ weighted_lp = self._calc_weighted_log_probs(data_matrix)
87
+ lp_norm = logsumexp(weighted_lp, axis=1)
88
+ with np.errstate(under="ignore"):
89
+ log_resp = weighted_lp - lp_norm[:, None]
90
+ return lp_norm, log_resp
91
+
92
+ def _E_step(self, data_matrix):
93
+ """
94
+ E-step: compute average log_prob_norm and log_resp.
95
+ """
96
+ lp_norm, log_resp = self._calc_log_resp_and_norm(data_matrix)
97
+ return np.mean(lp_norm), log_resp
98
+
99
+ def _compute_responsibilities(self, log_resp):
100
+ """
101
+ Exponentiate log_resp and sum across observations.
102
+ """
103
+ resp_matrix = np.exp(log_resp)
104
+ cluster_counts = resp_matrix.sum(axis=0) + 10 * np.finfo(resp_matrix.dtype).eps
105
+ return resp_matrix, cluster_counts
106
+
107
+ def _update_mixture_weights(self, cluster_counts):
108
+ """
109
+ Update mixture weights from mixture counts.
110
+ """
111
+ self.mix_weights_ = cluster_counts / cluster_counts.sum()
112
+
113
+ def _M_step(self, data_matrix, log_resp):
114
+ """
115
+ M-step: update weights and Beta distribution parameters via moment matching.
116
+ """
117
+ resp_matrix, cluster_counts = self._compute_responsibilities(log_resp)
118
+ self._update_mixture_weights(cluster_counts)
119
+
120
+ w_sums = resp_matrix.T @ data_matrix
121
+ w_sums_sq = resp_matrix.T @ (data_matrix ** 2)
122
+
123
+ for m_idx in range(self.n_mixtures):
124
+ sum_vals = w_sums[m_idx]
125
+ sum_sq_vals = w_sums_sq[m_idx]
126
+ mean_val = sum_vals / cluster_counts[m_idx]
127
+ var_val = sum_sq_vals / cluster_counts[m_idx] - mean_val ** 2
128
+
129
+ # Clip variance
130
+ variance_cap = mean_val * (1 - mean_val) / 4
131
+ var_val = np.minimum(var_val, variance_cap)
132
+ var_val += 10 * np.finfo(var_val.dtype).eps
133
+
134
+ # Method-of-moments fit: factor = mean*(1-mean)/var - 1, so alpha = factor*mean, beta = factor*(1-mean)
135
+ scaling_factor = (mean_val * (1 - mean_val)) / (var_val + 1e-10) - 1
136
+ self.beta_params_[m_idx, :self.n_components] = scaling_factor * mean_val
137
+ self.beta_params_[m_idx, self.n_components:] = scaling_factor * (1 - mean_val)
138
+
139
+ def fit(self, data_matrix, num_init=3, method="kmeans", max_iter=1000, tol=1e-4):
140
+ """
141
+ Fit BetaMixtureModel to the data using EM, possibly with multiple initializations.
142
+ """
143
+ self.n_observations, self.n_components = data_matrix.shape
144
+ self.convergence = False
145
+ self.method = method
146
+ best_lower_bound = -np.inf
147
+ optimal_params = None
148
+
149
+ for init_round in range(num_init):
150
+ # print(f"{init_round + 1}-th BMM initialization")
151
+ self._init_clusters(data_matrix, init_round)
152
+ ll_bound = -np.inf
153
+
154
+ for _ in range(max_iter):
155
+ prev_bound = ll_bound
156
+ lp_norm, log_resp = self._E_step(data_matrix)
157
+ self._M_step(data_matrix, log_resp)
158
+ ll_bound = lp_norm
159
+ delta_bound = ll_bound - prev_bound
160
+
161
+ if abs(delta_bound) < tol:
162
+ self.convergence = True
163
+ break
164
+
165
+ if ll_bound > best_lower_bound:
166
+ best_lower_bound = ll_bound
167
+ # Update final weights
168
+ _, cluster_counts = self._compute_responsibilities(log_resp)
169
+ self._update_mixture_weights(cluster_counts)
170
+ optimal_params = (self.mix_weights_.copy(), self.beta_params_.copy())
171
+
172
+ self.mix_weights_, self.beta_params_ = optimal_params
173
+ self.max_lower_bound = best_lower_bound
174
+ return self
175
+
176
+ def predict_proba(self, data_matrix):
177
+ """
178
+ Return the per-mixture membership probabilities for each sample.
179
+ """
180
+ _, log_resp = self._calc_log_resp_and_norm(data_matrix)
181
+ return np.exp(log_resp)
182
+
183
+ def predict(self, data_matrix):
184
+ """
185
+ Return the most probable mixture index for each sample.
186
+ """
187
+ return np.argmax(self.predict_proba(data_matrix), axis=1)
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 razaimam45
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,3 +1,137 @@
1
- ---
2
- license: mit
3
- ---
1
+ # T³: Test-Time Model Merging for Medical Vision-Language Models
2
+
3
+ ![T³ Workflow](figures/method.png)
4
+ *Figure 1: Dynamic test-time merging workflow of T³*
5
+
6
+ Official implementation of **T³: Test-Time Model Merging in Vision-Language Models for Zero-Shot Medical Imaging**, a method for adaptive fusion of pretrained and fine-tuned vision-language models at test time using Jensen-Shannon divergence.
7
+
8
+ ---
9
+
10
+ ## Key Features
11
+ - 🧠 **Mutual Information Guidance**: Uses JS divergence to measure model consensus.
12
+ - ⚡ **Backpropagation-Free**: No gradient updates required during inference.
13
+ - 🏥 **Medical Modality Agnostic**: Validated consistently across four medical imaging domains.
14
+ - 🚀 **Batch-Wise Efficiency**: Reduces compute cost by 32x vs sample-wise merging.
15
+ - 📈 **SOTA Performance**: Outperforms 8+ baselines in accuracy & robustness.
16
+
17
+ ---
18
+
19
+ ## Table of Contents
20
+ - [Installation](#installation)
21
+ - [Method Overview](#method-overview)
22
+ - [Folder Structure](#folder-structure)
23
+ - [Reproducing Results](#reproducing-results)
24
+ - [Pretrained Weights](#pretrained-weights)
25
+ - [Citation](#citation)
26
+
27
+ ## Installation
28
+
29
+ 1. Clone repository:
30
+ ```bash
31
+ git clone https://github.com/yourusername/T3.git
32
+ cd T3
33
+ ```
34
+
35
+ 2. Create conda environment:
36
+ ```bash
37
+ conda create -n t3 python=3.9
38
+ conda activate t3
39
+ pip install -r requirements.txt
40
+ ```
41
+
42
+ ## Method Overview
43
+
44
+ ### Adaptive Merging via Jensen-Shannon Divergence
45
+ The interpolation coefficient λ is computed dynamically for each sample using the following equation (a minimal sketch of this computation appears after the definitions below):
46
+
47
+ ```math
48
+ λ(x) = λ_{min} + (λ_{max}-λ_{min})σ(γ⋅JS(p_{pt}(x)‖p_{ft}(x)))
49
+ ```
50
+
51
+ Where:
52
+ - `JS` = Jensen-Shannon divergence between pretrained and fine-tuned model predictions.
53
+ - `σ` = Sigmoid function for smooth scaling.
54
+ - `γ` = Scaling factor (default=0.5).
55
+
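+ A minimal reference sketch of this computation is shown below (it assumes softmax probabilities from both models; the function names and the default `λ_min`/`λ_max` values are illustrative, not the repository's API):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def js_divergence(p_pt, p_ft, eps=1e-8):
+     # p_pt, p_ft: (batch, n_classes) probability vectors
+     p_bar = 0.5 * (p_pt + p_ft)
+     kl_pt = (p_pt * (torch.log(p_pt + eps) - torch.log(p_bar + eps))).sum(dim=-1)
+     kl_ft = (p_ft * (torch.log(p_ft + eps) - torch.log(p_bar + eps))).sum(dim=-1)
+     return 0.5 * (kl_pt + kl_ft)
+
+ def merging_coefficient(logits_pt, logits_ft, lam_min=0.0, lam_max=1.0, gamma=0.5):
+     # λ(x) = λ_min + (λ_max - λ_min) * σ(γ · JS(p_pt ‖ p_ft)); the default bounds are placeholders
+     p_pt = F.softmax(logits_pt, dim=-1)
+     p_ft = F.softmax(logits_ft, dim=-1)
+     js = js_divergence(p_pt, p_ft)  # shape: (batch,)
+     return lam_min + (lam_max - lam_min) * torch.sigmoid(gamma * js)
+ ```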
56
+ ### Visual Explanation of the Method
57
+ The following analyses motivate the method and illustrate its effectiveness:
58
+
59
+ ### Dynamic Weighting Based on Model Agreement
60
+
61
+ We propose using Jensen–Shannon (JS) divergence to measure mutual information between pretrained (`p_pt`) and fine-tuned (`p_ft`) model predictions, offering a more robust gauge of joint confidence than entropy-based methods like DaWin's entropy ratio:
62
+
63
+ ```math
64
+ R(x) = \frac{\mathcal{H}(p_{ft}(x))}{\mathcal{H}(p_{pt}(x)) + \mathcal{H}(p_{ft}(x))}
65
+ ```
66
+
67
+ JS divergence explicitly captures agreement vs. disagreement by comparing full predictive distributions:
68
+
69
+ ```math
70
+ I(x) = \frac{1}{2} \Bigl(\mathrm{KL}(p_{pt}(x) \Vert \bar{p}(x)) + \mathrm{KL}(p_{ft}(x) \Vert \bar{p}(x))\Bigr)
71
+ ```
72
+ where
73
+ ```math
74
+ \bar{p}(x) = 0.5 \cdot (p_{pt}(x) + p_{ft}(x))
75
+ ```
76
+
77
+ This ensures:
78
+ - \(I(x) = 0\) when models fully agree.
79
+ - \(I(x) > 0\) when confident predictions disagree.
80
+
81
+ Empirically, \(I(x)\) correlates positively with \(R(x)\), but better distinguishes disagreements, validating its use for adaptive merging.
82
+
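+ As a toy numerical illustration of this point (hypothetical two-class predictions, not results from the paper), two confident predictions yield an entropy ratio of about 0.5 whether the models agree or contradict each other, whereas JS cleanly separates the two cases:
+
+ ```python
+ import numpy as np
+
+ def entropy(p):
+     return -(p * np.log(p + 1e-12)).sum()
+
+ def js(p, q):
+     m = 0.5 * (p + q)
+     kl = lambda a, b: (a * (np.log(a + 1e-12) - np.log(b + 1e-12))).sum()
+     return 0.5 * (kl(p, m) + kl(q, m))
+
+ p_pt = np.array([0.99, 0.01])      # confident pretrained prediction
+ p_agree = np.array([0.99, 0.01])   # fine-tuned model agrees
+ p_flip = np.array([0.01, 0.99])    # fine-tuned model confidently disagrees
+
+ # Entropy ratio R(x) is ~0.5 in both cases and cannot tell them apart,
+ # while JS is ~0 under agreement and ~0.64 nats under disagreement.
+ print(entropy(p_agree) / (entropy(p_pt) + entropy(p_agree)))  # 0.5
+ print(entropy(p_flip) / (entropy(p_pt) + entropy(p_flip)))    # 0.5
+ print(js(p_pt, p_agree), js(p_pt, p_flip))                    # ~0.0, ~0.64
+ ```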
83
+ 2. **Mutual Information vs. Entropy**
84
+ ![MI vs Entropy](figures/mi_v_ent.png)
85
+ *Figure 3: Relationship between mutual information and entropy for adaptive merging.*
86
+
87
+ 3. **Performance Across Modalities**
88
+ ![Performance Comparison](figures/results.png)
89
+ *Figure 4: T³ achieves superior performance across multiple medical imaging modalities.*
90
+
91
+ ---
92
+
93
+ ## Folder Structure
94
+
95
+ ```
96
+ T3/
97
+ ├── clip/ # CLIP model adaptations
98
+ ├── data/ # Data Utilities
99
+ ├── utils/ # Helper functions
100
+ ├── baselines.py # Comparison methods
101
+ ├── t_cube.py # Core T³ implementation
102
+ ├── BetaMixture.py # Auxiliary models
103
+ └── README.md # This document
104
+ ```
105
+
106
+ ---
107
+
108
+ ## Reproducing Results
109
+
110
+ To reproduce the results from the paper, you can run the `t_cube.py` script. This script handles the evaluation of T³ and its baselines across multiple datasets and severity levels. Additional baselines are available in `baselines.py`.
111
+
112
+ To understand the script better:
113
+ - Refer to the `compute_samplewise_tcube_weights` and `compute_samplewise_tcube_weights_MI` functions for the entropy-based (DaWin baseline) and our mutual-information-based merging, respectively; a minimal sketch of the merging step follows this list.
114
+ - Check the `evaluate_on_test_set` function for how datasets and severities are processed.
115
+ - Explore the `evaluate_tcube` function for the merging and evaluation logic.
116
+
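+ The merging step itself reduces to a batch-wise linear interpolation of the two state dicts. A minimal sketch follows; the helper name, the convention that λ weights the fine-tuned model, and the usage lines are illustrative assumptions rather than the exact `t_cube.py` implementation:
+
+ ```python
+ import copy
+
+ def merge_state_dicts(sd_pt, sd_ft, lam):
+     # lam: batch-level coefficient (e.g. the mean of the per-sample λ(x) values)
+     return {k: (1 - lam) * sd_pt[k] + lam * sd_ft[k] for k in sd_pt}
+
+ # Usage sketch (model and variable names are placeholders):
+ # merged_model = copy.deepcopy(clip_pt)
+ # merged_model.load_state_dict(merge_state_dicts(sd_pt, sd_ft, lam=0.7))
+ ```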
117
+ ---
118
+
119
+ ## Pretrained Weights
120
+
121
+ We provide pretrained weights for the following models:
122
+ 1. **Generalist CLIP**: A pretrained model for general vision-language tasks.
123
+ 2. **Expert CLIPs**: Four fine-tuned models for the following medical imaging domains:
124
+ - Breast Imaging
125
+ - Fundoscopy
126
+ - Cell Microscopy
127
+ - Retinal OCT
128
+
129
+ If you would like access to these weights, please contact us directly at [Raza Imam](mailto:[email protected]).
130
+
131
+ ---
132
+
133
+ ## License
134
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
135
+
136
+ ## Contact
137
+ For questions or collaborations, contact [Raza Imam](mailto:[email protected]).
baselines.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch
2
+ import copy
3
+ import numpy as np
4
+ from scipy.stats import pearsonr
5
+ from t_cube import evaluate_model
6
+
7
+ def evaluate_slerp(clip_pt, sd_pt, sd_ft, dataloader, args, alpha=0.5):
8
+ """
9
+ SLERP (spherical linear interpolation) between pretrained (pt) and fine-tuned (ft) weights.
10
+ alpha=0 -> pt only; alpha=1 -> ft only.
11
+ """
12
+ model = copy.deepcopy(clip_pt)
13
+ merged_sd = {}
14
+ # flatten-per-key SLERP
15
+ for k in sd_pt.keys():
16
+ w1 = sd_pt[k].flatten().float()
17
+ w2 = sd_ft[k].flatten().float()
18
+ # cosine similarity
19
+ cos_val = torch.dot(w1, w2) / (w1.norm() * w2.norm() + 1e-8)
20
+ omega = torch.acos(torch.clamp(cos_val, -1+1e-6, 1-1e-6))
21
+ sin_omega = torch.sin(omega)
22
+ if sin_omega < 1e-6:
23
+ w_interp = (1-alpha)*w1 + alpha*w2
24
+ else:
25
+ w_interp = (torch.sin((1-alpha)*omega)/sin_omega)*w1 + \
26
+ (torch.sin(alpha*omega)/sin_omega)*w2
27
+ merged_sd[k] = w_interp.view_as(sd_pt[k])
28
+ model.load_state_dict(merged_sd)
29
+ return evaluate_model(model, dataloader, args)
30
+
31
+
32
+ def evaluate_m3(clip_pt, sd_pt, sd_ft, dataloader, args):
33
+ """
34
+ M^3 (Mixup Model Merge): sample lambda ~ Uniform(0,1) and do linear interpolation.
35
+ """
36
+ model = copy.deepcopy(clip_pt)
37
+ lam = np.random.rand()
38
+ merged_sd = {k: lam * sd_ft[k] + (1 - lam) * sd_pt[k]
39
+ for k in sd_pt.keys()}
40
+ model.load_state_dict(merged_sd)
41
+ return evaluate_model(model, dataloader, args)
42
+
43
+
44
+ def evaluate_task_arithmetic(clip_pt, sd_pt, sd_ft, dataloader, args):
45
+ """
46
+ Task Arithmetic: extrapolate along the ft−pt vector, i.e. 2*ft – pt.
47
+ """
48
+ model = copy.deepcopy(clip_pt)
49
+ merged_sd = {k: 2 * sd_ft[k] - sd_pt[k] for k in sd_pt.keys()}
50
+ model.load_state_dict(merged_sd)
51
+ return evaluate_model(model, dataloader, args)
clip/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .clip import *
2
+ from .custom_clip import *
clip/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (190 Bytes). View file
 
clip/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (202 Bytes). View file
 
clip/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (189 Bytes). View file
 
clip/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (188 Bytes). View file
 
clip/__pycache__/clip.cpython-310.pyc ADDED
Binary file (8.43 kB). View file
 
clip/__pycache__/clip.cpython-312.pyc ADDED
Binary file (13.6 kB). View file
 
clip/__pycache__/clip.cpython-38.pyc ADDED
Binary file (8.34 kB). View file
 
clip/__pycache__/clip.cpython-39.pyc ADDED
Binary file (8.4 kB). View file
 
clip/__pycache__/cocoop.cpython-310.pyc ADDED
Binary file (7.4 kB). View file
 
clip/__pycache__/cocoop.cpython-312.pyc ADDED
Binary file (13 kB). View file
 
clip/__pycache__/cocoop.cpython-39.pyc ADDED
Binary file (7.44 kB). View file
 
clip/__pycache__/custom_clip.cpython-310.pyc ADDED
Binary file (11.1 kB). View file
 
clip/__pycache__/custom_clip.cpython-312.pyc ADDED
Binary file (19.5 kB). View file
 
clip/__pycache__/custom_clip.cpython-39.pyc ADDED
Binary file (10.6 kB). View file
 
clip/__pycache__/custom_medclip.cpython-310.pyc ADDED
Binary file (10.2 kB). View file
 
clip/__pycache__/custom_medclip.cpython-312.pyc ADDED
Binary file (18.4 kB). View file
 
clip/__pycache__/custom_medclip.cpython-39.pyc ADDED
Binary file (10 kB). View file
 
clip/__pycache__/model.cpython-310.pyc ADDED
Binary file (15.2 kB). View file
 
clip/__pycache__/model.cpython-312.pyc ADDED
Binary file (29.8 kB). View file
 
clip/__pycache__/model.cpython-38.pyc ADDED
Binary file (15 kB). View file
 
clip/__pycache__/model.cpython-39.pyc ADDED
Binary file (15 kB). View file
 
clip/__pycache__/simple_tokenizer.cpython-310.pyc ADDED
Binary file (5.7 kB). View file
 
clip/__pycache__/simple_tokenizer.cpython-312.pyc ADDED
Binary file (8.92 kB). View file
 
clip/__pycache__/simple_tokenizer.cpython-38.pyc ADDED
Binary file (5.79 kB). View file
 
clip/__pycache__/simple_tokenizer.cpython-39.pyc ADDED
Binary file (5.75 kB). View file
 
clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
clip/clip.py ADDED
@@ -0,0 +1,232 @@
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Any, Union, List
6
+ from pkg_resources import packaging
7
+
8
+ import torch
9
+ from PIL import Image
10
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11
+ from tqdm import tqdm
12
+
13
+ from .model import build_model
14
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15
+
16
+ try:
17
+ from torchvision.transforms import InterpolationMode
18
+ BICUBIC = InterpolationMode.BICUBIC
19
+ except ImportError:
20
+ BICUBIC = Image.BICUBIC
21
+
22
+
23
+ if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25
+
26
+
27
+ __all__ = ["available_models", "load", "tokenize"]
28
+ _tokenizer = _Tokenizer()
29
+
30
+ _MODELS = {
31
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35
+ "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
36
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
37
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
38
+ "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
39
+ }
40
+
41
+
42
+ def _download(url: str, root: str):
43
+ os.makedirs(root, exist_ok=True)
44
+ filename = os.path.basename(url)
45
+
46
+ expected_sha256 = url.split("/")[-2]
47
+ download_target = os.path.join(root, filename)
48
+
49
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
50
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
51
+
52
+ if os.path.isfile(download_target):
53
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
54
+ return download_target
55
+ else:
56
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
57
+
58
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
59
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
60
+ while True:
61
+ buffer = source.read(8192)
62
+ if not buffer:
63
+ break
64
+
65
+ output.write(buffer)
66
+ loop.update(len(buffer))
67
+
68
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
69
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
70
+
71
+ return download_target
72
+
73
+
74
+ def _convert_image_to_rgb(image):
75
+ return image.convert("RGB")
76
+
77
+
78
+ def _transform(n_px):
79
+ return Compose([
80
+ Resize(n_px, interpolation=BICUBIC),
81
+ CenterCrop(n_px),
82
+ _convert_image_to_rgb,
83
+ ToTensor(),
84
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
85
+ ])
86
+
87
+
88
+ def available_models() -> List[str]:
89
+ """Returns the names of available CLIP models"""
90
+ return list(_MODELS.keys())
91
+
92
+
93
+ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
94
+ """Load a CLIP model
95
+
96
+ Parameters
97
+ ----------
98
+ name : str
99
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
100
+
101
+ device : Union[str, torch.device]
102
+ The device to put the loaded model
103
+
104
+ jit : bool
105
+ Whether to load the optimized JIT model or more hackable non-JIT model (default).
106
+
107
+ download_root: str
108
+ path to download the model files; by default, it uses "~/.cache/clip"
109
+
110
+ Returns
111
+ -------
112
+ model : torch.nn.Module
113
+ The CLIP model
114
+
115
+ preprocess : Callable[[PIL.Image], torch.Tensor]
116
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
117
+ """
118
+ if name in _MODELS:
119
+ model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
120
+ elif os.path.isfile(name):
121
+ model_path = name
122
+ else:
123
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
124
+
125
+ try:
126
+ # loading JIT archive
127
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
128
+ state_dict = None
129
+ except RuntimeError:
130
+ # loading saved state dict
131
+ if jit:
132
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
133
+ jit = False
134
+ state_dict = torch.load(model_path, map_location="cpu")
135
+
136
+ embed_dim = model.state_dict()["text_projection"].shape[1]
137
+ if not jit:
138
+ model = build_model(state_dict or model.state_dict()).to(device)
139
+ if str(device) == "cpu":
140
+ model.float()
141
+ return model, embed_dim, _transform(model.visual.input_resolution)
142
+
143
+ # patch the device names
144
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
145
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
146
+
147
+ def patch_device(module):
148
+ try:
149
+ graphs = [module.graph] if hasattr(module, "graph") else []
150
+ except RuntimeError:
151
+ graphs = []
152
+
153
+ if hasattr(module, "forward1"):
154
+ graphs.append(module.forward1.graph)
155
+
156
+ for graph in graphs:
157
+ for node in graph.findAllNodes("prim::Constant"):
158
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
159
+ node.copyAttributes(device_node)
160
+
161
+ model.apply(patch_device)
162
+ patch_device(model.encode_image)
163
+ patch_device(model.encode_text)
164
+
165
+ # patch dtype to float32 on CPU
166
+ if str(device) == "cpu":
167
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
168
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
169
+ float_node = float_input.node()
170
+
171
+ def patch_float(module):
172
+ try:
173
+ graphs = [module.graph] if hasattr(module, "graph") else []
174
+ except RuntimeError:
175
+ graphs = []
176
+
177
+ if hasattr(module, "forward1"):
178
+ graphs.append(module.forward1.graph)
179
+
180
+ for graph in graphs:
181
+ for node in graph.findAllNodes("aten::to"):
182
+ inputs = list(node.inputs())
183
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
184
+ if inputs[i].node()["value"] == 5:
185
+ inputs[i].node().copyAttributes(float_node)
186
+
187
+ model.apply(patch_float)
188
+ patch_float(model.encode_image)
189
+ patch_float(model.encode_text)
190
+
191
+ model.float()
192
+
193
+ return model, embed_dim, _transform(model.input_resolution.item())
194
+
195
+
196
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
197
+ """
198
+ Returns the tokenized representation of given input string(s)
199
+
200
+ Parameters
201
+ ----------
202
+ texts : Union[str, List[str]]
203
+ An input string or a list of input strings to tokenize
204
+
205
+ context_length : int
206
+ The context length to use; all CLIP models use 77 as the context length
207
+
208
+ truncate: bool
209
+ Whether to truncate the text in case its encoding is longer than the context length
210
+
211
+ Returns
212
+ -------
213
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
214
+ """
215
+ if isinstance(texts, str):
216
+ texts = [texts]
217
+
218
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
219
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
220
+ all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
221
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
222
+
223
+ for i, tokens in enumerate(all_tokens):
224
+ if len(tokens) > context_length:
225
+ if truncate:
226
+ tokens = tokens[:context_length]
227
+ tokens[-1] = eot_token
228
+ else:
229
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
230
+ result[i, :len(tokens)] = torch.tensor(tokens)
231
+
232
+ return result
clip/cocoop.py ADDED
@@ -0,0 +1,234 @@
1
+ from collections import OrderedDict
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from clip import load, tokenize
9
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
10
+ from .custom_clip import TextEncoder
11
+ from data.imagnet_prompts import imagenet_classes
12
+ from data.cls_to_names import *
13
+ from data.fewshot_datasets import fewshot_datasets
14
+
15
+ _tokenizer = _Tokenizer()
16
+
17
+ DOWNLOAD_ROOT='~/.cache/clip'
18
+
19
+ class CoCoOpPromptLearner(nn.Module):
20
+ def __init__(self, clip_model, classnames, n_ctx=4, ctx_init="a_photo_of_a", ctx_position='end'):
21
+ super().__init__()
22
+ n_cls = len(classnames)
23
+ dtype = clip_model.dtype
24
+ self.dtype = dtype
25
+ self.device = clip_model.visual.conv1.weight.device
26
+ ctx_dim = clip_model.ln_final.weight.shape[0]
27
+ embed_dim = clip_model.text_projection.shape[1]
28
+ self.ctx_dim = ctx_dim
29
+
30
+ if ctx_init:
31
+ # use given words to initialize context vectors
32
+ print("Initializing the context with given words: [{}]".format(ctx_init))
33
+ ctx_init = ctx_init.replace("_", " ")
34
+ n_ctx = len(ctx_init.split(" "))
35
+ prompt = tokenize(ctx_init).to(self.device)
36
+ with torch.no_grad():
37
+ embedding = clip_model.token_embedding(prompt).type(dtype)
38
+ ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
39
+ prompt_prefix = ctx_init
40
+
41
+ else:
42
+ print("Random initialization: initializing a generic context")
43
+ ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
44
+ nn.init.normal_(ctx_vectors, std=0.02)
45
+ prompt_prefix = " ".join(["X"] * n_ctx)
46
+
47
+ print(f'Initial context: "{prompt_prefix}"')
48
+ print(f"Number of context words (tokens): {n_ctx}")
49
+ self.prompt_prefix = prompt_prefix
50
+
51
+ self.ctx = nn.Parameter(ctx_vectors) # to be optimized
52
+ self.meta_net = nn.Sequential(OrderedDict([
53
+ ("linear1", nn.Linear(embed_dim, embed_dim // 16)),
54
+ ("relu", nn.ReLU(inplace=True)),
55
+ ("linear2", nn.Linear(embed_dim // 16, ctx_dim))
56
+ ]))
57
+
58
+ classnames = [name.replace("_", " ") for name in classnames]
59
+ name_lens = [len(_tokenizer.encode(name)) for name in classnames]
60
+ prompts = [prompt_prefix + " " + name + "." for name in classnames]
61
+
62
+ tokenized_prompts = torch.cat([tokenize(p) for p in prompts]).to(self.device)
63
+ with torch.no_grad():
64
+ embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
65
+
66
+ # These token vectors will be saved when in save_model(),
67
+ # but they should be ignored in load_model() as we want to use
68
+ # those computed using the current class names
69
+ self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
70
+ self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS
71
+
72
+ self.ctx_init = ctx_init
73
+ self.tokenized_prompts = tokenized_prompts # torch.Tensor
74
+ self.name_lens = name_lens
75
+ self.class_token_position = ctx_position
76
+ self.n_cls = n_cls
77
+ self.n_ctx = n_ctx
78
+
79
+ def construct_prompts(self, ctx, prefix, suffix, label=None):
80
+ # dim0 is either batch_size (during training) or n_cls (during testing)
81
+ # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim)
82
+ # prefix: the sos token, with shape of (n_cls, 1, ctx_dim)
83
+ # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim)
84
+
85
+ if label is not None:
86
+ prefix = prefix[label]
87
+ suffix = suffix[label]
88
+
89
+ prompts = torch.cat(
90
+ [
91
+ prefix, # (dim0, 1, dim)
92
+ ctx, # (dim0, n_ctx, dim)
93
+ suffix, # (dim0, *, dim)
94
+ ],
95
+ dim=1,
96
+ )
97
+
98
+ return prompts
99
+
100
+ def reset_classnames(self, classnames, arch):
101
+ self.n_cls = len(classnames)
102
+ classnames = [name.replace("_", " ") for name in classnames]
103
+ name_lens = [len(_tokenizer.encode(name)) for name in classnames]
104
+ prompts = [self.prompt_prefix + " " + name + "." for name in classnames]
105
+ tokenized_prompts = torch.cat([tokenize(p) for p in prompts]).to(self.device)
106
+
107
+ clip, _, _ = load(arch, device=self.device, download_root=DOWNLOAD_ROOT)
108
+
109
+ with torch.no_grad():
110
+ embedding = clip.token_embedding(tokenized_prompts).type(self.dtype)
111
+
112
+ self.token_prefix = embedding[:, :1, :]
113
+ self.token_suffix = embedding[:, 1 + self.n_ctx :, :] # CLS, EOS
114
+
115
+ self.name_lens = name_lens
116
+ self.tokenized_prompts = tokenized_prompts
117
+
118
+ def forward(self, im_features, ctx_only=False):
119
+ prefix = self.token_prefix
120
+ suffix = self.token_suffix
121
+ ctx = self.ctx # (n_ctx, ctx_dim)
122
+ bias = self.meta_net(im_features) # (batch, ctx_dim)
123
+ bias = bias.unsqueeze(1) # (batch, 1, ctx_dim)
124
+ ctx = ctx.unsqueeze(0) # (1, n_ctx, ctx_dim)
125
+ ctx_shifted = ctx + bias # (batch, n_ctx, ctx_dim)
126
+ if ctx_only:
127
+ return ctx_shifted # don't expand to n_cls, optimize one ctx for all classes
128
+
129
+ # Use instance-conditioned context tokens for all classes
130
+ prompts = []
131
+ for ctx_shifted_i in ctx_shifted:
132
+ ctx_i = ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1)
133
+ pts_i = self.construct_prompts(ctx_i, prefix, suffix) # (n_cls, n_tkn, ctx_dim)
134
+ prompts.append(pts_i)
135
+ prompts = torch.stack(prompts)
136
+
137
+ return prompts
138
+
139
+ class CoCoOpCLIP(nn.Module):
140
+ def __init__(self, device, classnames, criterion='cosine', arch="ViT-L/14",
141
+ n_ctx=16, ctx_init="a_photo_of_a", ctx_position='end'):
142
+ super().__init__()
143
+ clip, _, _ = load(arch, device=device, download_root=DOWNLOAD_ROOT)
144
+ self.image_encoder = clip.visual
145
+ self.text_encoder = TextEncoder(clip)
146
+ self.logit_scale = clip.logit_scale.data
147
+ # prompt tuning
148
+ self.prompt_generator = CoCoOpPromptLearner(clip, classnames, n_ctx, ctx_init, ctx_position)
149
+ self.tokenized_prompts = self.prompt_generator.tokenized_prompts
150
+ self.criterion = criterion
151
+ self.dtype = clip.dtype
152
+
153
+ def inference(self, image, label=None):
154
+ tokenized_prompts = self.prompt_generator.tokenized_prompts
155
+ logit_scale = self.logit_scale.exp()
156
+
157
+ image_features = self.image_encoder(image.type(self.dtype))
158
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
159
+
160
+ prompts = self.prompt_generator(image_features)
161
+
162
+ logits = []
163
+ for pts_i, imf_i in zip(prompts, image_features):
164
+ text_features = self.text_encoder(pts_i, tokenized_prompts)
165
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
166
+ l_i = logit_scale * imf_i @ text_features.t()
167
+ logits.append(l_i)
168
+ logits = torch.stack(logits)
169
+
170
+ return logits
171
+
172
+ def gen_ctx(self, image, aug=False):
173
+ with torch.no_grad():
174
+ with torch.cuda.amp.autocast():
175
+ image_features = self.image_encoder(image.type(self.dtype))
176
+ if aug:
177
+ image_feature_avg = image_features[0].unsqueeze(0)
178
+ else:
179
+ image_feature_avg = image_features.mean(dim=0, keepdim=True)
180
+ ctx = self.prompt_generator(image_feature_avg, ctx_only=True)
181
+
182
+ return image_features, ctx.detach().clone()
183
+
184
+ def forward_ctx(self, image_features, ctx):
185
+ N = 1
186
+
187
+ prefix = self.prompt_generator.token_prefix.expand(N, -1, -1, -1) # [N, n_cls, 1, dim]
188
+ suffix = self.prompt_generator.token_suffix.expand(N, -1, -1, -1)
189
+ # expand `ctx` n_cls times
190
+ ctx = ctx.expand(self.prompt_generator.n_cls, -1, -1, -1)
191
+ ctx = ctx.permute(1, 0, 2, 3)
192
+ # ctx = ctx.reshape(N, self.prompt_generator.n_cls, -1, self.prompt_generator.ctx_dim)
193
+
194
+ prompts = torch.cat([
195
+ prefix,
196
+ ctx,
197
+ suffix
198
+ ], dim=-2)
199
+
200
+ # full_n_ctx = prompts.size()[-2]
201
+
202
+ prompts = prompts.reshape(N * self.prompt_generator.n_cls, -1, self.prompt_generator.ctx_dim)
203
+ tokenized_prompts = self.prompt_generator.tokenized_prompts
204
+ tokenized_prompts = tokenized_prompts.repeat(N, 1)
205
+ text_features = self.text_encoder(prompts, tokenized_prompts)
206
+
207
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
208
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
209
+
210
+ text_features = text_features.reshape(N, -1, image_features.size()[-1])
211
+
212
+ logit_scale = self.logit_scale.exp()
213
+
214
+ text_features = text_features.squeeze(0)
215
+ logits = logit_scale * image_features @ text_features.t()
216
+
217
+ return logits
218
+
219
+ def forward(self, input):
220
+ if isinstance(input, Tuple):
221
+ image_features, ctx = input
222
+ return self.forward_ctx(image_features, ctx)
223
+ else:
224
+ return self.inference(input)
225
+
226
+ def get_cocoop(clip_arch, test_set, device, n_ctx):
227
+ if test_set in fewshot_datasets:
228
+ classnames = eval("{}_classes".format(test_set.lower()))
229
+ else:
230
+ classnames = imagenet_classes
231
+
232
+ model = CoCoOpCLIP(device, classnames, arch=clip_arch, n_ctx=n_ctx)
233
+
234
+ return model
clip/custom_clip.py ADDED
@@ -0,0 +1,388 @@
1
+
2
+ import math
3
+ from typing import List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from clip import load, tokenize
10
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
11
+ from data.imagnet_prompts import imagenet_classes
12
+ from data.fewshot_datasets import fewshot_datasets
13
+ from data.cls_to_names import *
14
+ from utils.ModelStock import stock_model
15
+
16
+ _tokenizer = _Tokenizer()
17
+
18
+ DOWNLOAD_ROOT='~/.cache/clip'
19
+
20
+ class ClipImageEncoder(nn.Module):
21
+ def __init__(self, device, arch="ViT-L/14", image_resolution=224, n_class=1000):
22
+ super(ClipImageEncoder, self).__init__()
23
+ clip, embed_dim, _ = load(arch, device=device, download_root=DOWNLOAD_ROOT)
24
+ self.encoder = clip.visual
25
+ del clip.transformer
26
+ torch.cuda.empty_cache()
27
+
28
+ self.cls_head = nn.Linear(embed_dim, n_class)
29
+
30
+ @property
31
+ def dtype(self):
32
+ return self.encoder.conv1.weight.dtype
33
+
34
+ def forward(self, image):
35
+ x = self.encoder(image.type(self.dtype))
36
+ output = self.cls_head(x)
37
+ return output
38
+
39
+
40
+ class TextEncoder(nn.Module):
41
+ def __init__(self, clip_model):
42
+ super().__init__()
43
+ self.transformer = clip_model.transformer
44
+ self.positional_embedding = clip_model.positional_embedding
45
+ self.ln_final = clip_model.ln_final
46
+ self.text_projection = clip_model.text_projection
47
+ self.dtype = clip_model.dtype
48
+
49
+ def forward(self, prompts, tokenized_prompts):
50
+ x = prompts + self.positional_embedding.type(self.dtype)
51
+ x = x.permute(1, 0, 2) # NLD -> LND
52
+ x = self.transformer(x)
53
+ x = x.permute(1, 0, 2) # LND -> NLD
54
+ x = self.ln_final(x).type(self.dtype)
55
+
56
+ # x.shape = [batch_size, n_ctx, transformer.width]
57
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
58
+ x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
59
+
60
+ return x
61
+
62
+
63
+ class PromptLearner(nn.Module):
64
+ def __init__(self, clip_model, classnames, batch_size=None, n_ctx=16, ctx_init=None, ctx_position='end', learned_cls=False):
65
+ super().__init__()
66
+ n_cls = len(classnames)
67
+ self.learned_cls = learned_cls
68
+ dtype = clip_model.dtype
69
+ self.dtype = dtype
70
+ self.device = clip_model.visual.conv1.weight.device
71
+ ctx_dim = clip_model.ln_final.weight.shape[0]
72
+ self.ctx_dim = ctx_dim
73
+ self.batch_size = batch_size
74
+
75
+ # self.ctx, prompt_prefix = self.reset_prompt(ctx_dim, ctx_init, clip_model)
76
+
77
+ if ctx_init:
78
+ # use given words to initialize context vectors
79
+ print("Initializing the context with given words: [{}]".format(ctx_init))
80
+ ctx_init = ctx_init.replace("_", " ")
81
+ if '[CLS]' in ctx_init:
82
+ ctx_list = ctx_init.split(" ")
83
+ split_idx = ctx_list.index("[CLS]")
84
+ ctx_init = ctx_init.replace("[CLS] ", "")
85
+ ctx_position = "middle"
86
+ else:
87
+ split_idx = None
88
+ self.split_idx = split_idx
89
+ n_ctx = len(ctx_init.split(" "))
90
+ prompt = tokenize(ctx_init).to(self.device)
91
+ with torch.no_grad():
92
+ embedding = clip_model.token_embedding(prompt).type(dtype)
93
+ ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
94
+ prompt_prefix = ctx_init
95
+ else:
96
+ print("Random initialization: initializing a generic context")
97
+ ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
98
+ nn.init.normal_(ctx_vectors, std=0.02)
99
+ prompt_prefix = " ".join(["X"] * n_ctx)
100
+
101
+ self.prompt_prefix = prompt_prefix
102
+
103
+ print(f'Initial context: "{prompt_prefix}"')
104
+ print(f"Number of context words (tokens): {n_ctx}")
105
+
106
+ # batch-wise prompt tuning for test-time adaptation
107
+ if self.batch_size is not None:
108
+ ctx_vectors = ctx_vectors.repeat(batch_size, 1, 1) #(N, L, D)
109
+ self.ctx_init_state = ctx_vectors.detach().clone()
110
+ self.ctx = nn.Parameter(ctx_vectors) # to be optimized
111
+
112
+ if not self.learned_cls:
113
+ classnames = [name.replace("_", " ") for name in classnames]
114
+ name_lens = [len(_tokenizer.encode(name)) for name in classnames]
115
+ prompts = [prompt_prefix + " " + name + "." for name in classnames]
116
+ else:
117
+ print("Random initialization: initializing a learnable class token")
118
+ cls_vectors = torch.empty(n_cls, 1, ctx_dim, dtype=dtype) # assume each learnable cls_token is only 1 word
119
+ nn.init.normal_(cls_vectors, std=0.02)
120
+ cls_token = "X"
121
+ name_lens = [1 for _ in classnames]
122
+ prompts = [prompt_prefix + " " + cls_token + "." for _ in classnames]
123
+
124
+ self.cls_init_state = cls_vectors.detach().clone()
125
+ self.cls = nn.Parameter(cls_vectors) # to be optimized
126
+
127
+ tokenized_prompts = torch.cat([tokenize(p) for p in prompts]).to(self.device)
128
+ with torch.no_grad():
129
+ embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
130
+
131
+ # These token vectors will be saved when in save_model(),
132
+ # but they should be ignored in load_model() as we want to use
133
+ # those computed using the current class names
134
+ self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
135
+ if self.learned_cls:
136
+ self.register_buffer("token_suffix", embedding[:, 1 + n_ctx + 1:, :]) # ..., EOS
137
+ else:
138
+ self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS
139
+
140
+ self.ctx_init = ctx_init
141
+ self.tokenized_prompts = tokenized_prompts # torch.Tensor
142
+ self.name_lens = name_lens
143
+ self.class_token_position = ctx_position
144
+ self.n_cls = n_cls
145
+ self.n_ctx = n_ctx
146
+ self.classnames = classnames
147
+
148
+ def reset(self):
149
+ ctx_vectors = self.ctx_init_state
150
+ self.ctx.copy_(ctx_vectors) # to be optimized
151
+ if self.learned_cls:
152
+ cls_vectors = self.cls_init_state
153
+ self.cls.copy_(cls_vectors)
154
+
155
+ def reset_classnames(self, classnames, arch):
156
+ self.n_cls = len(classnames)
157
+ if not self.learned_cls:
158
+ classnames = [name.replace("_", " ") for name in classnames]
159
+ name_lens = [len(_tokenizer.encode(name)) for name in classnames]
160
+ prompts = [self.prompt_prefix + " " + name + "." for name in classnames]
161
+ else:
162
+ cls_vectors = torch.empty(self.n_cls, 1, self.ctx_dim, dtype=self.dtype) # assume each learnable cls_token is only 1 word
163
+ nn.init.normal_(cls_vectors, std=0.02)
164
+ cls_token = "X"
165
+ name_lens = [1 for _ in classnames]
166
+ prompts = [self.prompt_prefix + " " + cls_token + "." for _ in classnames]
167
+ # TODO: re-init the cls parameters
168
+ # self.cls = nn.Parameter(cls_vectors) # to be optimized
169
+ self.cls_init_state = cls_vectors.detach().clone()
170
+ tokenized_prompts = torch.cat([tokenize(p) for p in prompts]).to(self.device)
171
+
172
+ clip, _, _ = load(arch, device=self.device, download_root=DOWNLOAD_ROOT)
173
+
174
+ with torch.no_grad():
175
+ embedding = clip.token_embedding(tokenized_prompts).type(self.dtype)
176
+
177
+ self.token_prefix = embedding[:, :1, :]
178
+ self.token_suffix = embedding[:, 1 + self.n_ctx :, :] # CLS, EOS
179
+
180
+ self.name_lens = name_lens
181
+ self.tokenized_prompts = tokenized_prompts
182
+ self.classnames = classnames
183
+
184
+ def forward(self, init=None):
185
+ # the init will be used when computing CLIP directional loss
186
+ if init is not None:
187
+ ctx = init
188
+ else:
189
+ ctx = self.ctx
190
+ if ctx.dim() == 2:
191
+ ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
192
+ elif not ctx.size()[0] == self.n_cls:
193
+ ctx = ctx.unsqueeze(1).expand(-1, self.n_cls, -1, -1)
194
+
195
+ prefix = self.token_prefix
196
+ suffix = self.token_suffix
197
+ if self.batch_size is not None:
198
+ # This way only works for single-gpu setting (could pass batch size as an argument for forward())
199
+ prefix = prefix.repeat(self.batch_size, 1, 1, 1)
200
+ suffix = suffix.repeat(self.batch_size, 1, 1, 1)
201
+
202
+ if self.learned_cls:
203
+ assert self.class_token_position == "end"
204
+ if self.class_token_position == "end":
205
+ if self.learned_cls:
206
+ cls = self.cls
207
+ prompts = torch.cat(
208
+ [
209
+ prefix, # (n_cls, 1, dim)
210
+ ctx, # (n_cls, n_ctx, dim)
211
+ cls, # (n_cls, 1, dim)
212
+ suffix, # (n_cls, *, dim)
213
+ ],
214
+ dim=-2,
215
+ )
216
+ else:
217
+ prompts = torch.cat(
218
+ [
219
+ prefix, # (n_cls, 1, dim)
220
+ ctx, # (n_cls, n_ctx, dim)
221
+ suffix, # (n_cls, *, dim)
222
+ ],
223
+ dim=-2,
224
+ )
225
+ elif self.class_token_position == "middle":
226
+ # TODO: to work with a batch of prompts
227
+ if self.split_idx is not None:
228
+ half_n_ctx = self.split_idx # split the ctx at the position of [CLS] in `ctx_init`
229
+ else:
230
+ half_n_ctx = self.n_ctx // 2
231
+ prompts = []
232
+ for i in range(self.n_cls):
233
+ name_len = self.name_lens[i]
234
+ prefix_i = prefix[i : i + 1, :, :]
235
+ class_i = suffix[i : i + 1, :name_len, :]
236
+ suffix_i = suffix[i : i + 1, name_len:, :]
237
+ ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
238
+ ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]
239
+ prompt = torch.cat(
240
+ [
241
+ prefix_i, # (1, 1, dim)
242
+ ctx_i_half1, # (1, n_ctx//2, dim)
243
+ class_i, # (1, name_len, dim)
244
+ ctx_i_half2, # (1, n_ctx//2, dim)
245
+ suffix_i, # (1, *, dim)
246
+ ],
247
+ dim=1,
248
+ )
249
+ prompts.append(prompt)
250
+ prompts = torch.cat(prompts, dim=0)
251
+
252
+ elif self.class_token_position == "front":
253
+ prompts = []
254
+ for i in range(self.n_cls):
255
+ name_len = self.name_lens[i]
256
+ prefix_i = prefix[i : i + 1, :, :]
257
+ class_i = suffix[i : i + 1, :name_len, :]
258
+ suffix_i = suffix[i : i + 1, name_len:, :]
259
+ ctx_i = ctx[i : i + 1, :, :]
260
+ prompt = torch.cat(
261
+ [
262
+ prefix_i, # (1, 1, dim)
263
+ class_i, # (1, name_len, dim)
264
+ ctx_i, # (1, n_ctx, dim)
265
+ suffix_i, # (1, *, dim)
266
+ ],
267
+ dim=1,
268
+ )
269
+ prompts.append(prompt)
270
+ prompts = torch.cat(prompts, dim=0)
271
+
272
+ else:
273
+ raise ValueError
274
+
275
+ return prompts
276
+
277
+
278
+ class ClipTestTimeTuning(nn.Module):
279
+ def __init__(self, device, classnames, batch_size, criterion='cosine', arch="ViT-L/14",
280
+ n_ctx=16, ctx_init=None, ctx_position='end', learned_cls=False, pubmedclip_path=None,
281
+ merge=False, state_dict=None):
282
+ super(ClipTestTimeTuning, self).__init__()
283
+ clip, _, _ = load(arch, device=device, download_root=DOWNLOAD_ROOT)
284
+ if pubmedclip_path is not None:
285
+ ft_dict = torch.load(pubmedclip_path, map_location=f'cuda:{device}')
286
+ if merge:
287
+ print("Merging the weights of clip and state dict using WiSE-FT approach")
288
+ # WiSE-FT approach
289
+ merged_dict = {}
290
+ alpha = 0.50 # You can adjust this value as needed
291
+ for key in clip.state_dict().keys():
292
+ merged_dict[key] = alpha * ft_dict[key] + (1 - alpha) * clip.state_dict()[key] # clip.load_state_dict(state_dict)
293
+ # Model Stock
294
+ # state_dict = stock_model(state_dict, clip.state_dict())
295
+ else:
296
+ merged_dict = ft_dict
297
+ clip.load_state_dict(merged_dict)
298
+ if state_dict is not None:
299
+ clip.load_state_dict(state_dict)
300
+ self.visual = clip.visual
301
+ self.text_encoder = TextEncoder(clip)
302
+ self.logit_scale = clip.logit_scale.data
303
+ # prompt tuning
304
+ self.prompt_learner = PromptLearner(clip, classnames, batch_size, n_ctx, ctx_init, ctx_position, learned_cls)
305
+ self.criterion = criterion
306
+ self.l2_norm_cal = False
307
+
308
+ @property
309
+ def dtype(self):
310
+ return self.visual.conv1.weight.dtype
311
+
312
+ # restore the initial state of the prompt_learner (tunable prompt)
313
+ def reset(self):
314
+ self.prompt_learner.reset()
315
+
316
+ def reset_classnames(self, classnames, arch):
317
+ self.prompt_learner.reset_classnames(classnames, arch)
318
+
319
+ def get_text_features(self, normalize=True):
320
+ text_features = []
321
+ prompts = self.prompt_learner()
322
+ tokenized_prompts = self.prompt_learner.tokenized_prompts
323
+ t_features = self.text_encoder(prompts, tokenized_prompts)
324
+ if normalize:
325
+ t_features = t_features / t_features.norm(dim=-1, keepdim=True)
326
+ text_features.append(t_features)
327
+ text_features = torch.stack(text_features, dim=0)
328
+
329
+ return torch.mean(text_features, dim=0)
330
+
331
+ def inference(self, image, return_logits=False, normalize=True):
332
+ with torch.no_grad():
333
+ image_features = self.visual(image.type(self.dtype))
334
+ # with torch.no_grad():
335
+ text_features = self.get_text_features(normalize=normalize)
336
+ if normalize:
337
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
338
+
339
+ #[c-tpt] --------------------------------------------
340
+ if self.l2_norm_cal:
341
+ prompt_mean = text_features.mean(0)
342
+ feature_distance = text_features - prompt_mean
343
+ l2_norm = torch.linalg.norm(feature_distance, dim=-1)
344
+ l2_norm_mean = l2_norm.mean()
345
+
346
+ #for saving to csv file
347
+ self.l2_norm_mean = l2_norm_mean.item()
348
+
349
+ #for training
350
+ self.l2_norm_mean_training = l2_norm_mean
351
+
352
+ #-----------------------------------------------------
353
+
354
+ logit_scale = self.logit_scale.exp()
355
+ logits = logit_scale * image_features @ text_features.t()
356
+
357
+ if return_logits:
358
+ return logits, image_features, text_features
359
+
360
+ return logits
361
+
362
+ def forward(self, input, return_logits=False, normalize=True):
363
+ if isinstance(input, Tuple):
364
+ view_0, view_1, view_2 = input
365
+ return self.contrast_prompt_tuning(view_0, view_1, view_2)
366
+ elif len(input.size()) == 2:
367
+ return self.directional_prompt_tuning(input)
368
+ else:
369
+ return self.inference(input, return_logits, normalize)
370
+
371
+
372
+ def get_coop(clip_arch, test_set, device, n_ctx, ctx_init, classnames, learned_cls=False, pubmedclip_path=None, merge=False, state_dict=None):
373
+ # if test_set in fewshot_datasets:
374
+ # classnames = eval("{}_classes".format(test_set.lower()))
375
+ # elif test_set == 'bongard':
376
+ # if learned_cls:
377
+ # classnames = ['X', 'X']
378
+ # else:
379
+ # classnames = ['True', 'False']
380
+ # else:
381
+ # classnames = imagenet_classes
382
+
383
+ model = ClipTestTimeTuning(device, classnames, None, arch=clip_arch,
384
+ n_ctx=n_ctx, ctx_init=ctx_init, learned_cls=learned_cls, pubmedclip_path=pubmedclip_path, merge=merge,
385
+ state_dict=state_dict)
386
+
387
+ return model
388
+
clip/custom_medclip.py ADDED
@@ -0,0 +1,389 @@
1
+
2
+ import math
3
+ from typing import List, Tuple
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torchvision.models import resnet50, ResNet
9
+
10
+ from .clip import load, tokenize
11
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
12
+ from data.imagnet_prompts import imagenet_classes
13
+ from data.fewshot_datasets import fewshot_datasets
14
+ from data.cls_to_names import *
15
+ # from data.medclip_datasets_clsnames import *
16
+ import os
17
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
18
+ _tokenizer = _Tokenizer()
19
+
20
+ DOWNLOAD_ROOT='~/.cache/clip'
21
+
22
+ # class ClipImageEncoder(nn.Module):
23
+ # def __init__(self, device, arch="ViT-L/14", image_resolution=224, n_class=1000):
24
+ # super(ClipImageEncoder, self).__init__()
25
+ # clip, embed_dim, _ = load(arch, device=device, download_root=DOWNLOAD_ROOT)
26
+ # self.encoder = clip.visual
27
+ # del clip.transformer
28
+ # torch.cuda.empty_cache()
29
+
30
+ # self.cls_head = nn.Linear(embed_dim, n_class)
31
+
32
+ # @property
33
+ # def dtype(self):
34
+ # return self.encoder.conv1.weight.dtype
35
+
36
+ # def forward(self, image):
37
+ # x = self.encoder(image.type(self.dtype))
38
+ # output = self.cls_head(x)
39
+ # return output
40
+
41
+
42
+ class TextEncoder(nn.Module):
43
+ def __init__(self, medclip_text_model):
44
+ super().__init__()
45
+ self.medclip_text_model = medclip_text_model
46
+
47
+ def forward(self, prompts_embeddings, tokenized_prompts):
48
+
49
+ output = self.medclip_text_model.model(inputs_embeds=prompts_embeddings, attention_mask=tokenized_prompts['attention_mask'])
50
+
51
+ # take the average of last four layers
52
+ # last_hidden_states = torch.stack(output['hidden_states'][-self.last_n_layer:]) # n_layer, batch, seqlen, emb_dim
53
+ # embed = last_hidden_states.permute(1,0,2,3)
54
+ # embed = embed.mean(1).mean(1) # pooling
55
+
56
+ # get 1+2+last layer
57
+ last_hidden_states = torch.stack([output['hidden_states'][1], output['hidden_states'][2], output['hidden_states'][-1]]) # n_layer, batch, seqlen, emb_dim
58
+ embed = last_hidden_states.permute(1,0,2,3).mean(2).mean(1) # pooling
59
+
60
+ # let's take only the last hidden layer
61
+ # embed = output['pooler_output']
62
+
63
+ embed = self.medclip_text_model.projection_head(embed)
64
+ return embed
65
+
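
Note: a standalone sketch of the hidden-state pooling that `TextEncoder.forward` performs above (averaging the embeddings from BERT layers 1, 2 and the last layer over the sequence and layer dimensions before the projection head). The tensor names are assumptions for illustration, and the text model must be run with `output_hidden_states=True` for `hidden_states` to be available.

import torch

def pool_hidden_states(hidden_states):
    # hidden_states: tuple of [batch, seqlen, 768] tensors, one per encoder layer
    picked = torch.stack([hidden_states[1], hidden_states[2], hidden_states[-1]])  # [3, batch, seqlen, 768]
    embed = picked.permute(1, 0, 2, 3)   # [batch, 3, seqlen, 768]
    embed = embed.mean(2).mean(1)        # mean over sequence, then over the three layers -> [batch, 768]
    return embed
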
66
+
67
+ class PromptLearner(nn.Module):
68
+ def __init__(self, medclip_model, classnames, device, batch_size=None, n_ctx=16, ctx_init=None, ctx_position='end', learned_cls=False):
69
+ super().__init__()
70
+ n_cls = len(classnames)
71
+ self.learned_cls = learned_cls
72
+ dtype = medclip_model.dtype
73
+ self.dtype = dtype
74
+ ctx_dim = 768 # hardcoded for now!!! medclip_model.ln_final.weight.shape[0]
75
+ self.ctx_dim = ctx_dim
76
+ self.batch_size = batch_size
77
+ self.device = device
78
+ self.medclip_model = medclip_model
79
+
80
+ # self.ctx, prompt_prefix = self.reset_prompt(ctx_dim, ctx_init, medclip_model)
81
+
82
+ if ctx_init:
83
+ # raise NotImplementedError("This part is not yet implemented.")
84
+ # use given words to initialize context vectors
85
+ print("Initializing the context with given words: [{}]".format(ctx_init))
86
+ # breakpoint()
87
+ ctx_init = ctx_init.replace("_", " ")
88
+ if '[CLS]' in ctx_init:
89
+ ctx_list = ctx_init.split(" ")
90
+ split_idx = ctx_list.index("[CLS]")
91
+ ctx_init = ctx_init.replace("[CLS] ", "")
92
+ ctx_position = "middle"
93
+ else:
94
+ split_idx = None
95
+ self.split_idx = split_idx
96
+ n_ctx = len(ctx_init.split(" "))
97
+
98
+ # prompt = tokenize(ctx_init).to(self.device)
99
+ prompt = ctx_init
100
+ tokenized_prompts = medclip_model.text_model.tokenizer(prompt, padding='max_length', max_length=25, truncation=True, return_tensors='pt').to(self.device)
101
+ prompts_tokens = tokenized_prompts['input_ids'] # [1, 25]
102
+ with torch.no_grad():
103
+ embedding = medclip_model.text_model.model.embeddings.word_embeddings(prompts_tokens).type(dtype) # [1, 25, 768]
104
+ # embedding = medclip_model.token_embedding(prompt).type(dtype)
105
+ ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
106
+ prompt_prefix = ctx_init
107
+ else:
108
+ print("Random initialization: initializing a generic context")
109
+ ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
110
+ nn.init.normal_(ctx_vectors, std=0.02)
111
+ prompt_prefix = " ".join(["X"] * n_ctx)
112
+
113
+ self.prompt_prefix = prompt_prefix
114
+
115
+ print(f'Initial context: "{prompt_prefix}"')
116
+ print(f"Number of context words (tokens): {n_ctx}")
117
+
118
+ # batch-wise prompt tuning for test-time adaptation
119
+ if self.batch_size is not None:
120
+ ctx_vectors = ctx_vectors.repeat(batch_size, 1, 1) #(N, L, D)
121
+ self.ctx_init_state = ctx_vectors.detach().clone()
122
+ self.ctx = nn.Parameter(ctx_vectors) # to be optimized
123
+
124
+ if not self.learned_cls:
125
+ classnames = [name.replace("_", " ") for name in classnames]
126
+ name_lens = [len(medclip_model.text_model.tokenizer.encode(name))-2 for name in classnames] # [CLS] and [SEP] are not counted
127
+ prompts = [prompt_prefix + " " + name + "." for name in classnames]
128
+ else:
129
+ print("Random initialization: initializing a learnable class token")
130
+ cls_vectors = torch.empty(n_cls, 1, ctx_dim, dtype=dtype) # assume each learnable cls_token is only 1 word
131
+ nn.init.normal_(cls_vectors, std=0.02)
132
+ cls_token = "X"
133
+ name_lens = [1 for _ in classnames]
134
+ prompts = [prompt_prefix + " " + cls_token + "." for _ in classnames]
135
+
136
+ self.cls_init_state = cls_vectors.detach().clone()
137
+ self.cls = nn.Parameter(cls_vectors) # to be optimized
138
+
139
+ tokenized_prompts = medclip_model.text_model.tokenizer(prompts, padding='max_length', max_length=25, truncation=True, return_tensors='pt').to(self.device)
140
+ prompts_tokens = tokenized_prompts['input_ids'] # [n_cls, 25]
141
+ with torch.no_grad():
142
+ embedding = medclip_model.text_model.model.embeddings.word_embeddings(prompts_tokens).type(dtype) # [n_cls, 25, 768]
143
+
144
+ # These token vectors will be saved when in save_model(),
145
+ # but they should be ignored in load_model() as we want to use
146
+ # those computed using the current class names
147
+ self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
148
+ if self.learned_cls:
149
+ self.register_buffer("token_suffix", embedding[:, 1 + n_ctx + 1:, :]) # ..., EOS
150
+ else:
151
+ self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS
152
+
153
+ self.ctx_init = ctx_init
154
+ self.tokenized_prompts = tokenized_prompts # torch.Tensor
155
+ self.name_lens = name_lens
156
+ self.class_token_position = ctx_position
157
+ self.n_cls = n_cls
158
+ self.n_ctx = n_ctx
159
+ self.classnames = classnames
160
+
161
+ def reset(self):
162
+ ctx_vectors = self.ctx_init_state
163
+ self.ctx.copy_(ctx_vectors) # to be optimized
164
+ if self.learned_cls:
165
+ cls_vectors = self.cls_init_state
166
+ self.cls.copy_(cls_vectors)
167
+
168
+ def reset_classnames(self, classnames, arch):
169
+ self.n_cls = len(classnames)
170
+ if not self.learned_cls:
171
+ classnames = [name.replace("_", " ") for name in classnames]
172
+ name_lens = [len(self.medclip_model.text_model.tokenizer.encode(name))-2 for name in classnames] # [CLS] and [SEP] are not counted
173
+ prompts = [self.prompt_prefix + " " + name + "." for name in classnames]
174
+ else:
175
+ cls_vectors = torch.empty(self.n_cls, 1, self.ctx_dim, dtype=self.dtype) # assume each learnable cls_token is only 1 word
176
+ nn.init.normal_(cls_vectors, std=0.02)
177
+ cls_token = "X"
178
+ name_lens = [1 for _ in classnames]
179
+ prompts = [self.prompt_prefix + " " + cls_token + "." for _ in classnames]
180
+
181
+ self.cls_init_state = cls_vectors.detach().clone()
182
+
183
+ tokenized_prompts = self.medclip_model.text_model.tokenizer(prompts, padding='max_length', max_length=25, truncation=True, return_tensors='pt').to(self.device)
184
+ prompts_tokens = tokenized_prompts['input_ids']
185
+
186
+ with torch.no_grad():
187
+ embedding = self.medclip_model.text_model.model.embeddings.word_embeddings(prompts_tokens).type(self.dtype) # [n_cls, 25, 768]
188
+
189
+ self.token_prefix = embedding[:, :1, :]
190
+ self.token_suffix = embedding[:, 1 + self.n_ctx :, :] # CLS, EOS
191
+
192
+ self.name_lens = name_lens
193
+ self.tokenized_prompts = tokenized_prompts
194
+ self.classnames = classnames
195
+
196
+ def forward(self, init=None):
197
+ # the init will be used when computing CLIP directional loss
198
+ if init is not None:
199
+ ctx = init
200
+ else:
201
+ ctx = self.ctx
202
+ if ctx.dim() == 2:
203
+ ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
204
+ elif not ctx.size()[0] == self.n_cls:
205
+ ctx = ctx.unsqueeze(1).expand(-1, self.n_cls, -1, -1)
206
+
207
+ prefix = self.token_prefix
208
+ suffix = self.token_suffix
209
+ if self.batch_size is not None:
210
+ # This way only works for single-gpu setting (could pass batch size as an argument for forward())
211
+ prefix = prefix.repeat(self.batch_size, 1, 1, 1)
212
+ suffix = suffix.repeat(self.batch_size, 1, 1, 1)
213
+
214
+ if self.learned_cls:
215
+ assert self.class_token_position == "end"
216
+ if self.class_token_position == "end":
217
+ if self.learned_cls:
218
+ cls = self.cls
219
+ prompts = torch.cat(
220
+ [
221
+ prefix, # (n_cls, 1, dim)
222
+ ctx, # (n_cls, n_ctx, dim)
223
+ cls, # (n_cls, 1, dim)
224
+ suffix, # (n_cls, *, dim)
225
+ ],
226
+ dim=-2,
227
+ )
228
+ else:
229
+ prompts = torch.cat(
230
+ [
231
+ prefix, # (n_cls, 1, dim)
232
+ ctx, # (n_cls, n_ctx, dim)
233
+ suffix, # (n_cls, *, dim)
234
+ ],
235
+ dim=-2,
236
+ )
237
+ elif self.class_token_position == "middle":
238
+ # TODO: to work with a batch of prompts
239
+ if self.split_idx is not None:
240
+ half_n_ctx = self.split_idx # split the ctx at the position of [CLS] in `ctx_init`
241
+ else:
242
+ half_n_ctx = self.n_ctx // 2
243
+ prompts = []
244
+ for i in range(self.n_cls):
245
+ name_len = self.name_lens[i]
246
+ prefix_i = prefix[i : i + 1, :, :]
247
+ class_i = suffix[i : i + 1, :name_len, :]
248
+ suffix_i = suffix[i : i + 1, name_len:, :]
249
+ ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
250
+ ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]
251
+ prompt = torch.cat(
252
+ [
253
+ prefix_i, # (1, 1, dim)
254
+ ctx_i_half1, # (1, n_ctx//2, dim)
255
+ class_i, # (1, name_len, dim)
256
+ ctx_i_half2, # (1, n_ctx//2, dim)
257
+ suffix_i, # (1, *, dim)
258
+ ],
259
+ dim=1,
260
+ )
261
+ prompts.append(prompt)
262
+ prompts = torch.cat(prompts, dim=0)
263
+
264
+ elif self.class_token_position == "front":
265
+ prompts = []
266
+ for i in range(self.n_cls):
267
+ name_len = self.name_lens[i]
268
+ prefix_i = prefix[i : i + 1, :, :]
269
+ class_i = suffix[i : i + 1, :name_len, :]
270
+ suffix_i = suffix[i : i + 1, name_len:, :]
271
+ ctx_i = ctx[i : i + 1, :, :]
272
+ prompt = torch.cat(
273
+ [
274
+ prefix_i, # (1, 1, dim)
275
+ class_i, # (1, name_len, dim)
276
+ ctx_i, # (1, n_ctx, dim)
277
+ suffix_i, # (1, *, dim)
278
+ ],
279
+ dim=1,
280
+ )
281
+ prompts.append(prompt)
282
+ prompts = torch.cat(prompts, dim=0)
283
+
284
+ else:
285
+ raise ValueError
286
+
287
+ return prompts
288
+
289
+ from MedCLIP.medclip import MedCLIPModel, MedCLIPVisionModel, MedCLIPVisionModelViT
290
+ from MedCLIP.medclip import MedCLIPProcessor
291
+
292
+ def load_medclip_to_cpu():
293
+ model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT)
294
+ model.from_pretrained()  # downloads the default hub weights; overridden by the local checkpoint loaded below
295
+ # breakpoint()
296
+ # model.from_pretrained("/l/users/asif.hanif/pre-trained-models/vlps/medclip/pretrained/medclip-vit/")
297
+ model.from_pretrained("./MedCLIP/pretrained/medclip-vit/")
298
+ # for vit
299
+ model.dtype = model.vision_model.model.embeddings.patch_embeddings.projection.weight.dtype
300
+ # for Resnet
301
+ # model.dtype = model.vision_model.model.conv1.weight.dtype
302
+
303
+
304
+ model.eval()
305
+ return model
306
+
307
+ class ClipTestTimeTuning(nn.Module):
308
+ def __init__(self, device, classnames, batch_size, criterion='cosine', arch="ViT-L/14",
309
+ n_ctx=16, ctx_init=None, ctx_position='end', learned_cls=False):
310
+ super(ClipTestTimeTuning, self).__init__()
311
+ self.device = device
312
+ self.medclip_model = load_medclip_to_cpu()
313
+ self.dtype = self.medclip_model.dtype
314
+ self.medclip_model = self.medclip_model.to(self.device)
315
+ self.image_encoder = self.medclip_model.vision_model
316
+ self.text_encoder = TextEncoder(self.medclip_model.text_model)
317
+ self.logit_scale = self.medclip_model.logit_scale.data
318
+ # prompt tuning
319
+ self.prompt_learner = PromptLearner(self.medclip_model, classnames, self.device, batch_size, n_ctx, ctx_init, ctx_position, learned_cls)
320
+ self.criterion = criterion
321
+ self.l2_norm_cal = False
322
+
323
+ # @property
324
+ # def dtype(self):
325
+ # return self.image_encoder.conv1.weight.dtype
326
+
327
+ # restore the initial state of the prompt_learner (tunable prompt)
328
+ def reset(self):
329
+ self.prompt_learner.reset()
330
+
331
+ def reset_classnames(self, classnames, arch):
332
+ self.prompt_learner.reset_classnames(classnames, arch)
333
+
334
+ def get_text_features(self):
335
+ text_features = []
336
+ prompts = self.prompt_learner()
337
+ tokenized_prompts = self.prompt_learner.tokenized_prompts
338
+ t_features = self.text_encoder(prompts, tokenized_prompts)
339
+ text_features.append(t_features / t_features.norm(dim=-1, keepdim=True))
340
+ text_features = torch.stack(text_features, dim=0)
341
+
342
+ return torch.mean(text_features, dim=0)
343
+
344
+ def inference(self, image):
345
+ with torch.no_grad():
346
+ image_features = self.image_encoder(image.type(self.dtype))
347
+
348
+ text_features = self.get_text_features()
349
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
350
+
351
+ #[c-tpt] --------------------------------------------
352
+ if self.l2_norm_cal:
353
+ prompt_mean = text_features.mean(0)
354
+ feature_distance = text_features - prompt_mean
355
+ l2_norm = torch.linalg.norm(feature_distance, dim=-1)
356
+ l2_norm_mean = l2_norm.mean()
357
+
358
+ #for saving to csv file
359
+ self.l2_norm_mean = l2_norm_mean.item()
360
+
361
+ #for training
362
+ self.l2_norm_mean_training = l2_norm_mean
363
+
364
+ #-----------------------------------------------------
365
+
366
+ logit_scale = self.logit_scale.exp()
367
+ logits = logit_scale * image_features @ text_features.t()
368
+
369
+ return logits
370
+
371
+ def forward(self, input):
372
+ # breakpoint()
373
+ if isinstance(input, Tuple):
374
+ view_0, view_1, view_2 = input
375
+ return self.contrast_prompt_tuning(view_0, view_1, view_2)
376
+ elif len(input.size()) == 2:
377
+ return self.directional_prompt_tuning(input)
378
+ else:
379
+ return self.inference(input)
380
+
381
+
382
+ def get_coop(clip_arch, test_set, device, n_ctx, ctx_init=None, learned_cls=False):
383
+ classnames = eval("{}_classes".format(test_set.lower()))
384
+
385
+ model = ClipTestTimeTuning(device, classnames, None, arch=clip_arch,
386
+ n_ctx=n_ctx, ctx_init=ctx_init, learned_cls=learned_cls)
387
+
388
+ return model
389
+
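
Note: a minimal sketch of the prompt assembly done in `PromptLearner.forward` for the default `class_token_position == 'end'` case, assuming `prefix` ([n_cls, 1, dim]), `ctx` ([n_ctx, dim] learnable context) and `suffix` ([n_cls, *, dim]) are the embeddings registered above; this is an illustration, not a drop-in replacement.

import torch

def assemble_prompts(prefix: torch.Tensor, ctx: torch.Tensor, suffix: torch.Tensor) -> torch.Tensor:
    # prefix: [n_cls, 1, dim]  (start-of-sequence / [CLS] embedding)
    # ctx:    [n_ctx, dim]     (shared learnable context vectors)
    # suffix: [n_cls, *, dim]  (class-name, [SEP]/EOS and padding embeddings)
    n_cls = prefix.shape[0]
    ctx = ctx.unsqueeze(0).expand(n_cls, -1, -1)    # broadcast the shared context to every class
    return torch.cat([prefix, ctx, suffix], dim=1)  # [n_cls, 1 + n_ctx + *, dim]
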
clip/model.py ADDED
@@ -0,0 +1,438 @@
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+
10
+ class Bottleneck(nn.Module):
11
+ expansion = 4
12
+
13
+ def __init__(self, inplanes, planes, stride=1):
14
+ super().__init__()
15
+
16
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
+ self.bn1 = nn.BatchNorm2d(planes)
19
+ self.relu1 = nn.ReLU(inplace=True)
20
+
21
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22
+ self.bn2 = nn.BatchNorm2d(planes)
23
+ self.relu2 = nn.ReLU(inplace=True)
24
+
25
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26
+
27
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29
+ self.relu3 = nn.ReLU(inplace=True)
30
+
31
+ self.downsample = None
32
+ self.stride = stride
33
+
34
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
35
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36
+ self.downsample = nn.Sequential(OrderedDict([
37
+ ("-1", nn.AvgPool2d(stride)),
38
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39
+ ("1", nn.BatchNorm2d(planes * self.expansion))
40
+ ]))
41
+
42
+ def forward(self, x: torch.Tensor):
43
+ identity = x
44
+
45
+ out = self.relu1(self.bn1(self.conv1(x)))
46
+ out = self.relu2(self.bn2(self.conv2(out)))
47
+ out = self.avgpool(out)
48
+ out = self.bn3(self.conv3(out))
49
+
50
+ if self.downsample is not None:
51
+ identity = self.downsample(x)
52
+
53
+ out += identity
54
+ out = self.relu3(out)
55
+ return out
56
+
57
+
58
+ class AttentionPool2d(nn.Module):
59
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60
+ super().__init__()
61
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
63
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
64
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
65
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66
+ self.num_heads = num_heads
67
+
68
+ def forward(self, x):
69
+ x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
70
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72
+ x, _ = F.multi_head_attention_forward(
73
+ query=x[:1], key=x, value=x,
74
+ embed_dim_to_check=x.shape[-1],
75
+ num_heads=self.num_heads,
76
+ q_proj_weight=self.q_proj.weight,
77
+ k_proj_weight=self.k_proj.weight,
78
+ v_proj_weight=self.v_proj.weight,
79
+ in_proj_weight=None,
80
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81
+ bias_k=None,
82
+ bias_v=None,
83
+ add_zero_attn=False,
84
+ dropout_p=0,
85
+ out_proj_weight=self.c_proj.weight,
86
+ out_proj_bias=self.c_proj.bias,
87
+ use_separate_proj_weight=True,
88
+ training=self.training,
89
+ need_weights=False
90
+ )
91
+ return x.squeeze(0)
92
+
93
+
94
+ class ModifiedResNet(nn.Module):
95
+ """
96
+ A ResNet class that is similar to torchvision's but contains the following changes:
97
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
98
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
99
+ - The final pooling layer is a QKV attention instead of an average pool
100
+ """
101
+
102
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
103
+ super().__init__()
104
+ self.output_dim = output_dim
105
+ self.input_resolution = input_resolution
106
+
107
+ # the 3-layer stem
108
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
109
+ self.bn1 = nn.BatchNorm2d(width // 2)
110
+ self.relu1 = nn.ReLU(inplace=True)
111
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
112
+ self.bn2 = nn.BatchNorm2d(width // 2)
113
+ self.relu2 = nn.ReLU(inplace=True)
114
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
115
+ self.bn3 = nn.BatchNorm2d(width)
116
+ self.relu3 = nn.ReLU(inplace=True)
117
+ self.avgpool = nn.AvgPool2d(2)
118
+
119
+ # residual layers
120
+ self._inplanes = width # this is a *mutable* variable used during construction
121
+ self.layer1 = self._make_layer(width, layers[0])
122
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
123
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
124
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
125
+
126
+ embed_dim = width * 32 # the ResNet feature dimension
127
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
128
+
129
+ def _make_layer(self, planes, blocks, stride=1):
130
+ layers = [Bottleneck(self._inplanes, planes, stride)]
131
+
132
+ self._inplanes = planes * Bottleneck.expansion
133
+ for _ in range(1, blocks):
134
+ layers.append(Bottleneck(self._inplanes, planes))
135
+
136
+ return nn.Sequential(*layers)
137
+
138
+ def forward(self, x):
139
+ def stem(x):
140
+ x = self.relu1(self.bn1(self.conv1(x)))
141
+ x = self.relu2(self.bn2(self.conv2(x)))
142
+ x = self.relu3(self.bn3(self.conv3(x)))
143
+ x = self.avgpool(x)
144
+ return x
145
+
146
+ x = x.type(self.conv1.weight.dtype)
147
+ x = stem(x)
148
+ x = self.layer1(x)
149
+ x = self.layer2(x)
150
+ x = self.layer3(x)
151
+ x = self.layer4(x)
152
+ x = self.attnpool(x)
153
+
154
+ return x
155
+
156
+
157
+ class LayerNorm(nn.LayerNorm):
158
+ """Subclass torch's LayerNorm to handle fp16."""
159
+
160
+ def forward(self, x: torch.Tensor):
161
+ orig_type = x.dtype
162
+ ret = super().forward(x.type(torch.float32))
163
+ return ret.type(orig_type)
164
+
165
+
166
+ class QuickGELU(nn.Module):
167
+ def forward(self, x: torch.Tensor):
168
+ return x * torch.sigmoid(1.702 * x)
169
+
170
+
171
+ class ResidualAttentionBlock(nn.Module):
172
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
173
+ super().__init__()
174
+
175
+ self.attn = nn.MultiheadAttention(d_model, n_head)
176
+ self.ln_1 = LayerNorm(d_model)
177
+ self.mlp = nn.Sequential(OrderedDict([
178
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
179
+ ("gelu", QuickGELU()),
180
+ ("c_proj", nn.Linear(d_model * 4, d_model))
181
+ ]))
182
+ self.ln_2 = LayerNorm(d_model)
183
+ self.attn_mask = attn_mask
184
+
185
+ def attention(self, x: torch.Tensor):
186
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
187
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
188
+
189
+ def forward(self, x: torch.Tensor):
190
+ x = x + self.attention(self.ln_1(x))
191
+ x = x + self.mlp(self.ln_2(x))
192
+ return x
193
+
194
+
195
+ class Transformer(nn.Module):
196
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
197
+ super().__init__()
198
+ self.width = width
199
+ self.layers = layers
200
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
201
+
202
+ def forward(self, x: torch.Tensor):
203
+ return self.resblocks(x)
204
+
205
+
206
+ class VisionTransformer(nn.Module):
207
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
208
+ super().__init__()
209
+ self.input_resolution = input_resolution
210
+ self.output_dim = output_dim
211
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
212
+
213
+ scale = width ** -0.5
214
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
215
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
216
+ self.ln_pre = LayerNorm(width)
217
+
218
+ self.transformer = Transformer(width, layers, heads)
219
+
220
+ self.ln_post = LayerNorm(width)
221
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
222
+
223
+ def forward(self, x: torch.Tensor):
224
+ x = self.conv1(x) # shape = [*, width, grid, grid]
225
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
226
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
227
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
228
+ x = x + self.positional_embedding.to(x.dtype)
229
+ x = self.ln_pre(x)
230
+
231
+ x = x.permute(1, 0, 2) # NLD -> LND
232
+ x = self.transformer(x)
233
+ x = x.permute(1, 0, 2) # LND -> NLD
234
+
235
+ x = self.ln_post(x[:, 0, :])
236
+
237
+ if self.proj is not None:
238
+ x = x @ self.proj
239
+
240
+ return x
241
+
242
+
243
+ class CLIP(nn.Module):
244
+ def __init__(self,
245
+ embed_dim: int,
246
+ # vision
247
+ image_resolution: int,
248
+ vision_layers: Union[Tuple[int, int, int, int], int],
249
+ vision_width: int,
250
+ vision_patch_size: int,
251
+ # text
252
+ context_length: int,
253
+ vocab_size: int,
254
+ transformer_width: int,
255
+ transformer_heads: int,
256
+ transformer_layers: int
257
+ ):
258
+ super().__init__()
259
+
260
+ self.context_length = context_length
261
+
262
+ if isinstance(vision_layers, (tuple, list)):
263
+ vision_heads = vision_width * 32 // 64
264
+ self.visual = ModifiedResNet(
265
+ layers=vision_layers,
266
+ output_dim=embed_dim,
267
+ heads=vision_heads,
268
+ input_resolution=image_resolution,
269
+ width=vision_width
270
+ )
271
+ else:
272
+ vision_heads = vision_width // 64
273
+ self.visual = VisionTransformer(
274
+ input_resolution=image_resolution,
275
+ patch_size=vision_patch_size,
276
+ width=vision_width,
277
+ layers=vision_layers,
278
+ heads=vision_heads,
279
+ output_dim=embed_dim
280
+ )
281
+
282
+ self.transformer = Transformer(
283
+ width=transformer_width,
284
+ layers=transformer_layers,
285
+ heads=transformer_heads,
286
+ attn_mask=self.build_attention_mask()
287
+ )
288
+
289
+ self.vocab_size = vocab_size
290
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
291
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
292
+ self.ln_final = LayerNorm(transformer_width)
293
+
294
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
295
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
296
+
297
+ self.initialize_parameters()
298
+
299
+ def initialize_parameters(self):
300
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
301
+ nn.init.normal_(self.positional_embedding, std=0.01)
302
+
303
+ if isinstance(self.visual, ModifiedResNet):
304
+ if self.visual.attnpool is not None:
305
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
306
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
307
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
308
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
309
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
310
+
311
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
312
+ for name, param in resnet_block.named_parameters():
313
+ if name.endswith("bn3.weight"):
314
+ nn.init.zeros_(param)
315
+
316
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
317
+ attn_std = self.transformer.width ** -0.5
318
+ fc_std = (2 * self.transformer.width) ** -0.5
319
+ for block in self.transformer.resblocks:
320
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
321
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
322
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
323
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
324
+
325
+ if self.text_projection is not None:
326
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
327
+
328
+ def build_attention_mask(self):
329
+ # lazily create causal attention mask, with full attention between the vision tokens
330
+ # pytorch uses additive attention mask; fill with -inf
331
+ mask = torch.empty(self.context_length, self.context_length)
332
+ mask.fill_(float("-inf"))
333
+ mask.triu_(1) # zero out the lower diagonal
334
+ return mask
335
+
336
+ @property
337
+ def dtype(self):
338
+ return self.visual.conv1.weight.dtype
339
+
340
+ def encode_image(self, image):
341
+ return self.visual(image.type(self.dtype))
342
+
343
+ def encode_text(self, text):
344
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
345
+
346
+ x = x + self.positional_embedding.type(self.dtype)
347
+ x = x.permute(1, 0, 2) # NLD -> LND
348
+ x = self.transformer(x)
349
+ x = x.permute(1, 0, 2) # LND -> NLD
350
+ x = self.ln_final(x).type(self.dtype)
351
+
352
+ # x.shape = [batch_size, n_ctx, transformer.width]
353
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
354
+ x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
355
+
356
+ return x
357
+
358
+ def forward(self, image, text):
359
+ image_features = self.encode_image(image)
360
+ text_features = self.encode_text(text)
361
+
362
+ # normalized features
363
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
364
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
365
+
366
+ # cosine similarity as logits
367
+ logit_scale = self.logit_scale.exp()
368
+ logits_per_image = logit_scale * image_features @ text_features.t()
369
+ logits_per_text = logits_per_image.t()
370
+
371
+ # shape = [global_batch_size, global_batch_size]
372
+ return logits_per_image, logits_per_text
373
+
374
+
375
+ def convert_weights(model: nn.Module):
376
+ """Convert applicable model parameters to fp16"""
377
+
378
+ def _convert_weights_to_fp16(l):
379
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
380
+ l.weight.data = l.weight.data.half()
381
+ if l.bias is not None:
382
+ l.bias.data = l.bias.data.half()
383
+
384
+ if isinstance(l, nn.MultiheadAttention):
385
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
386
+ tensor = getattr(l, attr)
387
+ if tensor is not None:
388
+ tensor.data = tensor.data.half()
389
+
390
+ for name in ["text_projection", "proj"]:
391
+ if hasattr(l, name):
392
+ attr = getattr(l, name)
393
+ if attr is not None:
394
+ attr.data = attr.data.half()
395
+
396
+ model.apply(_convert_weights_to_fp16)
397
+
398
+
399
+ def build_model(state_dict: dict):
400
+ vit = "visual.proj" in state_dict
401
+
402
+ if vit:
403
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
404
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
405
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
406
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
407
+ image_resolution = vision_patch_size * grid_size
408
+ else:
409
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
410
+ vision_layers = tuple(counts)
411
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
412
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
413
+ vision_patch_size = None
414
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
415
+ image_resolution = output_width * 32
416
+
417
+ embed_dim = state_dict["text_projection"].shape[1]
418
+ context_length = state_dict["positional_embedding"].shape[0]
419
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
420
+ transformer_width = state_dict["ln_final.weight"].shape[0]
421
+ transformer_heads = transformer_width // 64
422
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
423
+
424
+ model = CLIP(
425
+ embed_dim,
426
+ image_resolution, vision_layers, vision_width, vision_patch_size,
427
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
428
+ )
429
+
430
+ for key in ["input_resolution", "context_length", "vocab_size"]:
431
+ if key in state_dict:
432
+ del state_dict[key]
433
+
434
+ # convert_weights(model)
435
+ model.load_state_dict(state_dict)
436
+ del state_dict
437
+ torch.cuda.empty_cache()
438
+ return model.eval()
clip/simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
+ @lru_cache()
16
+ def bytes_to_unicode():
17
+ """
18
+ Returns a list of utf-8 bytes and a corresponding list of unicode strings.
19
+ The reversible bpe codes work on unicode strings.
20
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
+ This is a significant percentage of your normal, say, 32K bpe vocab.
23
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
25
+ """
26
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
+ cs = bs[:]
28
+ n = 0
29
+ for b in range(2**8):
30
+ if b not in bs:
31
+ bs.append(b)
32
+ cs.append(2**8+n)
33
+ n += 1
34
+ cs = [chr(n) for n in cs]
35
+ return dict(zip(bs, cs))
36
+
37
+
38
+ def get_pairs(word):
39
+ """Return set of symbol pairs in a word.
40
+ Word is represented as tuple of symbols (symbols being variable-length strings).
41
+ """
42
+ pairs = set()
43
+ prev_char = word[0]
44
+ for char in word[1:]:
45
+ pairs.add((prev_char, char))
46
+ prev_char = char
47
+ return pairs
48
+
49
+
50
+ def basic_clean(text):
51
+ text = ftfy.fix_text(text)
52
+ text = html.unescape(html.unescape(text))
53
+ return text.strip()
54
+
55
+
56
+ def whitespace_clean(text):
57
+ text = re.sub(r'\s+', ' ', text)
58
+ text = text.strip()
59
+ return text
60
+
61
+
62
+ class SimpleTokenizer(object):
63
+ def __init__(self, bpe_path: str = default_bpe()):
64
+ self.byte_encoder = bytes_to_unicode()
65
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
+ merges = merges[1:49152-256-2+1]
68
+ merges = [tuple(merge.split()) for merge in merges]
69
+ vocab = list(bytes_to_unicode().values())
70
+ vocab = vocab + [v+'</w>' for v in vocab]
71
+ for merge in merges:
72
+ vocab.append(''.join(merge))
73
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
+ self.encoder = dict(zip(vocab, range(len(vocab))))
75
+ self.decoder = {v: k for k, v in self.encoder.items()}
76
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
+
80
+ def bpe(self, token):
81
+ if token in self.cache:
82
+ return self.cache[token]
83
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84
+ pairs = get_pairs(word)
85
+
86
+ if not pairs:
87
+ return token+'</w>'
88
+
89
+ while True:
90
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91
+ if bigram not in self.bpe_ranks:
92
+ break
93
+ first, second = bigram
94
+ new_word = []
95
+ i = 0
96
+ while i < len(word):
97
+ try:
98
+ j = word.index(first, i)
99
+ new_word.extend(word[i:j])
100
+ i = j
101
+ except ValueError:  # 'first' not found in the remainder of the word
102
+ new_word.extend(word[i:])
103
+ break
104
+
105
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
106
+ new_word.append(first+second)
107
+ i += 2
108
+ else:
109
+ new_word.append(word[i])
110
+ i += 1
111
+ new_word = tuple(new_word)
112
+ word = new_word
113
+ if len(word) == 1:
114
+ break
115
+ else:
116
+ pairs = get_pairs(word)
117
+ word = ' '.join(word)
118
+ self.cache[token] = word
119
+ return word
120
+
121
+ def encode(self, text):
122
+ bpe_tokens = []
123
+ text = whitespace_clean(basic_clean(text)).lower()
124
+ for token in re.findall(self.pat, text):
125
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127
+ return bpe_tokens
128
+
129
+ def decode(self, tokens):
130
+ text = ''.join([self.decoder[token] for token in tokens])
131
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132
+ return text
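
Note: a short usage sketch of `SimpleTokenizer` above; it assumes the bundled `bpe_simple_vocab_16e6.txt.gz` is available and that `ftfy` and `regex` are installed. The example string is arbitrary.

from clip.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()
ids = tokenizer.encode("a photo of a chest x-ray")  # list of BPE token ids
text = tokenizer.decode(ids)                        # decoded text ('</w>' word-end markers become spaces)
print(ids)
print(text)
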
data/__init__.py ADDED
File without changes
data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (146 Bytes). View file
 
data/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (162 Bytes). View file
 
data/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (150 Bytes). View file
 
data/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (144 Bytes). View file
 
data/__pycache__/augmix_ops.cpython-310.pyc ADDED
Binary file (3.69 kB). View file
 
data/__pycache__/augmix_ops.cpython-311.pyc ADDED
Binary file (6.76 kB). View file
 
data/__pycache__/augmix_ops.cpython-312.pyc ADDED
Binary file (6.24 kB). View file
 
data/__pycache__/augmix_ops.cpython-39.pyc ADDED
Binary file (3.88 kB). View file
 
data/__pycache__/cls_to_names.cpython-310.pyc ADDED
Binary file (23.9 kB). View file
 
data/__pycache__/cls_to_names.cpython-312.pyc ADDED
Binary file (26.6 kB). View file
 
data/__pycache__/cls_to_names.cpython-39.pyc ADDED
Binary file (19.9 kB). View file