clip-vit-base-mmhs150k-fusion / modeling_clip_fusion.py
"""
Multi-Modal Fusion Classifier with Gated Fusion Mechanism.
This module implements a late-fusion classifier that combines text and image
features from pre-trained vision-language models (CLIP, SigLIP).
Example usage:

    from transformers import AutoModel

    model = AutoModel.from_pretrained(
        "Amirhossein75/clip-vit-base-mmhs150k-fusion",
        trust_remote_code=True,
    )
"""
from typing import Optional, Dict, Any, Union, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPModel, CLIPProcessor, AutoModel, PreTrainedModel, PretrainedConfig
class ClipFusionConfig(PretrainedConfig):
"""Configuration for MultiModalFusionClassifier."""
model_type = "clip-fusion"
def __init__(
self,
encoder_name: str = "openai/clip-vit-base-patch32",
num_labels: int = 5,
fusion_dim: int = 512,
backend: str = "clip",
freeze_text: bool = False,
freeze_image: bool = False,
loss_type: str = "bce",
focal_gamma: float = 1.5,
        class_names: Optional[List[str]] = None,
        thresholds: Optional[List[float]] = None,
**kwargs
):
super().__init__(**kwargs)
self.encoder_name = encoder_name
self.num_labels = num_labels
self.fusion_dim = fusion_dim
self.backend = backend
self.freeze_text = freeze_text
self.freeze_image = freeze_image
self.loss_type = loss_type
self.focal_gamma = focal_gamma
self.class_names = class_names or ["racist", "sexist", "homophobe", "religion", "otherhate"]
self.thresholds = thresholds or [0.35, 0.7, 0.75, 0.3, 0.6]
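
# A minimal sketch of building the config by hand (values are illustrative;
# the SigLIP checkpoint name below is an assumption, not the encoder this
# repo ships with):
#
#   config = ClipFusionConfig(
#       encoder_name="google/siglip-base-patch16-224",
#       backend="siglip",
#       num_labels=5,
#       fusion_dim=512,
#       loss_type="focal",
#   )
#   model = MultiModalFusionClassifier(config)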
class FocalWithLogitsLoss(nn.Module):
"""
Focal Loss with logits for handling class imbalance.
Focal loss down-weights well-classified examples and focuses on hard examples.
Args:
alpha: Weighting factor for positive class.
gamma: Focusing parameter. Higher gamma increases focus on hard examples.
reduction: Reduction method ('mean', 'sum', 'none').
"""
def __init__(
self,
alpha: Optional[torch.Tensor] = None,
gamma: float = 1.5,
reduction: str = "mean"
):
super().__init__()
self.register_buffer("alpha", alpha if alpha is not None else None)
self.gamma = gamma
self.reduction = reduction
def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
prob = torch.sigmoid(logits)
ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
p_t = prob * targets + (1 - prob) * (1 - targets)
loss = ce * ((1 - p_t) ** self.gamma)
if self.alpha is not None:
loss = loss * (self.alpha * targets + (1 - self.alpha) * (1 - targets))
if self.reduction == "mean":
return loss.mean()
elif self.reduction == "sum":
return loss.sum()
return loss
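
# Quick illustrative check of the focal loss on dummy multi-label data
# (shapes and values are arbitrary; gamma matches the config default):
#
#   criterion = FocalWithLogitsLoss(gamma=1.5)
#   logits = torch.randn(8, 5)                    # [B, num_labels]
#   targets = torch.randint(0, 2, (8, 5)).float() # multi-hot labels
#   loss = criterion(logits, targets)             # scalar, reduction="mean"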
class MultiModalFusionClassifier(PreTrainedModel):
"""
Multi-modal classifier with gated fusion and interaction features.
Supports two backends:
- "clip": Uses CLIPModel.from_pretrained()
- "siglip": Uses AutoModel.from_pretrained() for SigLIP/SigLIP2 checkpoints
Architecture:
1. Encode text and image using pre-trained encoder
2. Project embeddings to fusion dimension
3. Apply gated fusion mechanism with modality presence flags
4. Combine interaction features (difference, product)
5. Classify using MLP head
"""
config_class = ClipFusionConfig
def __init__(self, config: ClipFusionConfig):
super().__init__(config)
self.config = config
self.backend = config.backend.lower()
# Load backbone encoder
if self.backend == "clip":
self.backbone = CLIPModel.from_pretrained(config.encoder_name)
d = self.backbone.config.projection_dim
if config.freeze_text:
for p in self.backbone.text_model.parameters():
p.requires_grad = False
if config.freeze_image:
for p in self.backbone.vision_model.parameters():
p.requires_grad = False
else:
# SigLIP/SigLIP2 or any CLIP-like dual encoder via AutoModel
self.backbone = AutoModel.from_pretrained(config.encoder_name)
cfg = self.backbone.config
d = getattr(cfg, "projection_dim", None)
if d is None and hasattr(cfg, "text_config"):
d = getattr(cfg.text_config, "projection_size", None) or \
getattr(cfg.text_config, "hidden_size", None)
if d is None:
d = getattr(getattr(cfg, "vision_config", None), "hidden_size", None) or \
getattr(getattr(cfg, "text_config", None), "hidden_size", None)
            if d is None:
                raise ValueError("Could not infer projection dim for the chosen encoder.")
if config.freeze_text and hasattr(self.backbone, "text_model"):
for p in self.backbone.text_model.parameters():
p.requires_grad = False
if config.freeze_image and hasattr(self.backbone, "vision_model"):
for p in self.backbone.vision_model.parameters():
p.requires_grad = False
# Projection layers
self.proj_t = nn.Linear(d, config.fusion_dim)
self.proj_i = nn.Linear(d, config.fusion_dim)
# Gated fusion layers
self.g_t = nn.Linear(config.fusion_dim, config.fusion_dim)
self.g_i = nn.Linear(config.fusion_dim, config.fusion_dim)
self.gate = nn.Linear(config.fusion_dim * 2 + 2, config.fusion_dim) # +2 for presence flags
# Interaction-enhanced classifier: [fused, t, v, |t-v|, t*v]
cls_in = config.fusion_dim * 5
self.cls = nn.Sequential(
nn.LayerNorm(cls_in),
nn.Linear(cls_in, config.fusion_dim),
nn.GELU(),
nn.Dropout(0.2),
nn.Linear(config.fusion_dim, config.num_labels),
)
self.ln_fused = nn.LayerNorm(config.fusion_dim)
# Loss configuration
self.loss_type = config.loss_type
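        # Optional per-class positive weights for BCE; registered as None so that a
        # tensor assigned to this buffer later (e.g. by training code) is saved and
        # loaded with the state dict.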
self.register_buffer("pos_weight", None)
if config.loss_type == "focal":
self.criterion = FocalWithLogitsLoss(gamma=config.focal_gamma)
else:
self.criterion = None
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
pixel_values: torch.Tensor,
text_present: Optional[torch.Tensor] = None,
image_present: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Dict[str, Any]:
"""
Forward pass.
Args:
input_ids: Tokenized text input [B, seq_len].
attention_mask: Attention mask for text [B, seq_len].
pixel_values: Preprocessed image tensor [B, C, H, W].
text_present: Binary flag indicating text presence [B]. Defaults to all 1s.
image_present: Binary flag indicating image presence [B]. Defaults to all 1s.
labels: Ground truth labels [B, num_labels].
            return_dict: Unused; kept for interface compatibility. The output is
                always a dictionary.
Returns:
Dictionary with 'loss' (if labels provided) and 'logits'.
"""
batch_size = input_ids.shape[0]
device = input_ids.device
# Default presence flags to all present
if text_present is None:
text_present = torch.ones(batch_size, device=device)
if image_present is None:
image_present = torch.ones(batch_size, device=device)
# Extract features
t_kwargs = {"input_ids": input_ids}
if attention_mask is not None:
t_kwargs["attention_mask"] = attention_mask
tfeat = self.backbone.get_text_features(**t_kwargs)
vfeat = self.backbone.get_image_features(pixel_values=pixel_values)
# Normalize and mask by presence
tfeat = F.normalize(tfeat, dim=-1) * text_present.unsqueeze(1)
vfeat = F.normalize(vfeat, dim=-1) * image_present.unsqueeze(1)
# Project to fusion dimension
tfeat_p = self.proj_t(tfeat)
vfeat_p = self.proj_i(vfeat)
# Gated fusion
zt = torch.tanh(self.g_t(tfeat_p))
zi = torch.tanh(self.g_i(vfeat_p))
        # Cast so integer presence flags concatenate cleanly with float features.
        presence = torch.stack([text_present, image_present], dim=1).to(tfeat_p.dtype)
        g = torch.sigmoid(self.gate(torch.cat([tfeat_p, vfeat_p, presence], dim=1)))
# Conditional fusion based on modality presence
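        # Per-example behavior: if the image is absent, fall back to the text-only
        # path zt; if text is absent, fall back to the image-only path zi; when
        # both are present, use the learned mix g * zt + (1.0 - g) * zi.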
fused = torch.where(
(image_present < 0.5).unsqueeze(1), zt,
torch.where((text_present < 0.5).unsqueeze(1), zi, g * zt + (1.0 - g) * zi)
)
fused = self.ln_fused(fused)
# Interaction features
feat = torch.cat([
fused,
tfeat_p,
vfeat_p,
torch.abs(tfeat_p - vfeat_p),
tfeat_p * vfeat_p
], dim=1)
logits = self.cls(feat)
# Compute loss if labels provided
loss = None
if labels is not None:
if self.loss_type == "focal":
loss = self.criterion(logits, labels)
else:
                loss = F.binary_cross_entropy_with_logits(
                    logits, labels, pos_weight=self.pos_weight
                )
return {"loss": loss, "logits": logits}
def predict(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
pixel_values: torch.Tensor,
text_present: Optional[torch.Tensor] = None,
image_present: Optional[torch.Tensor] = None,
    ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
"""
        Run inference and binarize probabilities with the per-class thresholds
        stored in the config.

        Returns:
            A dictionary with 'predictions' (bool dict) and 'probabilities'
            (float dict) for a single example, or a list of such dictionaries
            for a batch.
        """
self.eval()
with torch.no_grad():
outputs = self.forward(
input_ids=input_ids,
attention_mask=attention_mask,
pixel_values=pixel_values,
text_present=text_present,
image_present=image_present,
)
logits = outputs["logits"]
probabilities = torch.sigmoid(logits)
# Apply per-class thresholds
thresholds = torch.tensor(self.config.thresholds, device=logits.device)
predictions = (probabilities > thresholds).int()
# Format results
batch_results = []
for i in range(logits.shape[0]):
result = {
"predictions": {name: bool(predictions[i, j]) for j, name in enumerate(self.config.class_names)},
"probabilities": {name: float(probabilities[i, j]) for j, name in enumerate(self.config.class_names)},
}
batch_results.append(result)
return batch_results[0] if len(batch_results) == 1 else batch_results
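

if __name__ == "__main__":
    # Minimal end-to-end usage sketch, assuming the CLIP backend this repo ships
    # with. Requires network access to download the checkpoint; the blank PIL
    # image and the caption below are placeholders.
    from PIL import Image

    model = MultiModalFusionClassifier.from_pretrained(
        "Amirhossein75/clip-vit-base-mmhs150k-fusion"
    )
    processor = CLIPProcessor.from_pretrained(model.config.encoder_name)

    inputs = processor(
        text=["an example caption"],
        images=Image.new("RGB", (224, 224)),
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    result = model.predict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        pixel_values=inputs["pixel_values"],
    )
    print(result["predictions"])
    print(result["probabilities"])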