Image-Text-to-Text
Transformers
Safetensors
PyTorch
English
clip-fusion
feature-extraction
multimodal
hate-speech-detection
content-moderation
clip
vision-language
image-text
classification
multi-label-classification
social-media
meme-classification
text-image
late-fusion
gated-attention
custom_code
Eval Results (legacy)
| """ | |
| Multi-Modal Fusion Classifier with Gated Fusion Mechanism. | |
| This module implements a late-fusion classifier that combines text and image | |
| features from pre-trained vision-language models (CLIP, SigLIP). | |
| To use this model: | |
| from transformers import AutoModel | |
| model = AutoModel.from_pretrained( | |
| "Amirhossein75/clip-vit-base-mmhs150k-fusion", | |
| trust_remote_code=True | |
| ) | |
| """ | |
| from typing import Optional, Dict, Any, Union, List | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from transformers import CLIPModel, CLIPProcessor, AutoModel, PreTrainedModel, PretrainedConfig | |
class ClipFusionConfig(PretrainedConfig):
    """Settings container consumed by ``MultiModalFusionClassifier``.

    Args:
        encoder_name: Checkpoint identifier of the pre-trained dual encoder.
        num_labels: Number of output labels of the classification head.
        fusion_dim: Width of the projected / fused feature space.
        backend: Encoder family selector (the model lower-cases this value
            and compares it against ``"clip"``).
        freeze_text: When true, the text tower's parameters are frozen.
        freeze_image: When true, the image tower's parameters are frozen.
        loss_type: Loss selector; ``"focal"`` enables the focal loss, any
            other value uses binary cross-entropy with logits.
        focal_gamma: Focusing parameter handed to the focal loss.
        class_names: Per-label names. A falsy value (``None`` or an empty
            list) falls back to the five built-in hate-speech categories.
        thresholds: Per-label decision thresholds used by ``predict``. A
            falsy value falls back to the five built-in thresholds.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "clip-fusion"

    def __init__(
        self,
        encoder_name: str = "openai/clip-vit-base-patch32",
        num_labels: int = 5,
        fusion_dim: int = 512,
        backend: str = "clip",
        freeze_text: bool = False,
        freeze_image: bool = False,
        loss_type: str = "bce",
        focal_gamma: float = 1.5,
        class_names: List[str] = None,
        thresholds: List[float] = None,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Resolve the label metadata first; the literals are rebuilt on every
        # call so instances never share one default list object.
        if not class_names:
            class_names = ["racist", "sexist", "homophobe", "religion", "otherhate"]
        if not thresholds:
            thresholds = [0.35, 0.7, 0.75, 0.3, 0.6]

        # Backbone and head geometry
        self.encoder_name = encoder_name
        self.num_labels = num_labels
        self.fusion_dim = fusion_dim
        self.backend = backend

        # Which encoder towers stay frozen
        self.freeze_text = freeze_text
        self.freeze_image = freeze_image

        # Training loss selection
        self.loss_type = loss_type
        self.focal_gamma = focal_gamma

        # Label names and decision thresholds
        self.class_names = class_names
        self.thresholds = thresholds
class FocalWithLogitsLoss(nn.Module):
    """
    Focal loss on raw logits for (multi-label) binary classification.

    The element-wise binary cross-entropy is scaled by ``(1 - p_t) ** gamma``,
    where ``p_t`` is the probability assigned to the true class, so
    well-classified elements contribute less than hard ones.

    Args:
        alpha: Optional weighting factor for the positive class; negatives
            are weighted by ``1 - alpha``. May be a tensor (anything that
            broadcasts against the targets) or a plain Python number.
        gamma: Focusing parameter. Higher gamma increases focus on hard examples.
        reduction: ``'mean'`` or ``'sum'``; any other value (e.g. ``'none'``)
            returns the unreduced element-wise loss.
    """

    def __init__(
        self,
        alpha: Optional[Union[torch.Tensor, float]] = None,
        gamma: float = 1.5,
        reduction: str = "mean"
    ):
        super().__init__()
        # register_buffer only accepts a Tensor or None, so a scalar alpha
        # (the common way to configure focal loss) is converted first.
        if alpha is not None and not isinstance(alpha, torch.Tensor):
            alpha = torch.as_tensor(alpha, dtype=torch.float32)
        self.register_buffer("alpha", alpha)
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the focal loss of ``logits`` against binary ``targets``.

        Args:
            logits: Raw, un-activated scores.
            targets: Binary targets with the same shape as ``logits``.
                Integer or bool targets are cast to the dtype of ``logits``.

        Returns:
            A scalar tensor for ``'mean'`` / ``'sum'`` reduction, otherwise a
            tensor with the element-wise loss.
        """
        # binary_cross_entropy_with_logits rejects integer/bool targets, which
        # is how multi-hot labels often arrive from a dataset.
        if not targets.is_floating_point():
            targets = targets.to(dtype=logits.dtype)

        prob = torch.sigmoid(logits)
        ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")

        # Probability the model assigns to the correct class of each element
        p_t = prob * targets + (1 - prob) * (1 - targets)
        loss = ce * ((1 - p_t) ** self.gamma)

        if self.alpha is not None:
            # alpha for positives, (1 - alpha) for negatives
            loss = loss * (self.alpha * targets + (1 - self.alpha) * (1 - targets))

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss
class MultiModalFusionClassifier(PreTrainedModel):
    """
    Multi-modal classifier with gated fusion and interaction features.

    Supports two backends:
    - "clip": Uses CLIPModel.from_pretrained()
    - anything else (e.g. "siglip"): Uses AutoModel.from_pretrained()

    Architecture:
    1. Encode text and image with the backbone's ``get_text_features`` /
       ``get_image_features``
    2. Normalize the embeddings, zero them for absent modalities, and
       project them to the fusion dimension
    3. Apply a gated fusion mechanism conditioned on modality presence flags
    4. Concatenate interaction features (absolute difference, product)
    5. Classify using an MLP head
    """

    config_class = ClipFusionConfig

    def __init__(self, config: ClipFusionConfig):
        """Build the backbone, projection, gating and classification layers.

        Args:
            config: Model configuration. ``config.encoder_name`` is loaded
                with ``from_pretrained`` while the model is constructed.

        Raises:
            ValueError: If the embedding width of a non-CLIP backbone cannot
                be inferred from its configuration.
        """
        super().__init__(config)
        self.config = config
        self.backend = config.backend.lower()

        # Load backbone encoder and determine its embedding width
        if self.backend == "clip":
            self.backbone = CLIPModel.from_pretrained(config.encoder_name)
            d = self.backbone.config.projection_dim
        else:
            # SigLIP/SigLIP2 or any CLIP-like dual encoder via AutoModel
            self.backbone = AutoModel.from_pretrained(config.encoder_name)
            d = self._infer_embed_dim(self.backbone.config)
            # Explicit raise instead of assert: asserts vanish under ``-O``
            # and would let ``None`` reach nn.Linear.
            if d is None:
                raise ValueError("Could not infer projection dim for the chosen encoder.")

        if config.freeze_text:
            self._freeze_tower("text_model")
        if config.freeze_image:
            self._freeze_tower("vision_model")

        # Projection layers
        self.proj_t = nn.Linear(d, config.fusion_dim)
        self.proj_i = nn.Linear(d, config.fusion_dim)

        # Gated fusion layers
        self.g_t = nn.Linear(config.fusion_dim, config.fusion_dim)
        self.g_i = nn.Linear(config.fusion_dim, config.fusion_dim)
        self.gate = nn.Linear(config.fusion_dim * 2 + 2, config.fusion_dim)  # +2 for presence flags

        # Interaction-enhanced classifier: [fused, t, v, |t-v|, t*v]
        cls_in = config.fusion_dim * 5
        self.cls = nn.Sequential(
            nn.LayerNorm(cls_in),
            nn.Linear(cls_in, config.fusion_dim),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(config.fusion_dim, config.num_labels),
        )
        self.ln_fused = nn.LayerNorm(config.fusion_dim)

        # Loss configuration
        self.loss_type = config.loss_type
        self.register_buffer("pos_weight", None)
        if config.loss_type == "focal":
            self.criterion = FocalWithLogitsLoss(gamma=config.focal_gamma)
        else:
            self.criterion = None

    @staticmethod
    def _infer_embed_dim(cfg: Any) -> Optional[int]:
        """Look up the backbone's embedding width in its config, or return None."""
        d = getattr(cfg, "projection_dim", None)
        if d is None and hasattr(cfg, "text_config"):
            d = getattr(cfg.text_config, "projection_size", None) or \
                getattr(cfg.text_config, "hidden_size", None)
        if d is None:
            d = getattr(getattr(cfg, "vision_config", None), "hidden_size", None) or \
                getattr(getattr(cfg, "text_config", None), "hidden_size", None)
        return d

    def _freeze_tower(self, attr_name: str) -> None:
        """Disable gradients for the backbone sub-module ``attr_name`` if it exists."""
        tower = getattr(self.backbone, attr_name, None)
        if tower is None:
            return
        for p in tower.parameters():
            p.requires_grad = False

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
        text_present: Optional[torch.Tensor] = None,
        image_present: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Dict[str, Any]:
        """
        Forward pass.

        Args:
            input_ids: Tokenized text input [B, seq_len].
            attention_mask: Attention mask for text [B, seq_len]; omitted
                from the text-encoder call when None.
            pixel_values: Preprocessed image tensor [B, C, H, W].
            text_present: Binary flag indicating text presence [B]. Defaults to all 1s.
            image_present: Binary flag indicating image presence [B]. Defaults to all 1s.
            labels: Ground truth labels [B, num_labels]. Integer or bool
                labels are cast to the dtype of the logits.
            return_dict: Accepted for API compatibility; currently ignored,
                a dictionary is always returned.

        Returns:
            Dictionary with 'loss' (None when no labels are given) and 'logits'.
        """
        batch_size = input_ids.shape[0]

        # Extract features
        t_kwargs = {"input_ids": input_ids}
        if attention_mask is not None:
            t_kwargs["attention_mask"] = attention_mask
        tfeat = self.backbone.get_text_features(**t_kwargs)
        vfeat = self.backbone.get_image_features(pixel_values=pixel_values)

        # Presence flags must share the features' dtype: a float32 flag
        # multiplied into half-precision features would promote them to
        # float32 and break the half-precision Linear layers that follow.
        flag_kwargs = {"device": tfeat.device, "dtype": tfeat.dtype}
        if text_present is None:
            text_present = torch.ones(batch_size, **flag_kwargs)
        else:
            text_present = text_present.to(**flag_kwargs)
        if image_present is None:
            image_present = torch.ones(batch_size, **flag_kwargs)
        else:
            image_present = image_present.to(**flag_kwargs)

        # Normalize and mask by presence (absent modalities become zeros)
        tfeat = F.normalize(tfeat, dim=-1) * text_present.unsqueeze(1)
        vfeat = F.normalize(vfeat, dim=-1) * image_present.unsqueeze(1)

        # Project to fusion dimension
        tfeat_p = self.proj_t(tfeat)
        vfeat_p = self.proj_i(vfeat)

        # Gated fusion
        zt = torch.tanh(self.g_t(tfeat_p))
        zi = torch.tanh(self.g_i(vfeat_p))
        presence = torch.stack([text_present, image_present], dim=1)
        g = torch.sigmoid(self.gate(torch.cat([tfeat_p, vfeat_p, presence], dim=1)))

        # Conditional fusion based on modality presence: text-only rows use
        # zt, image-only rows use zi, rows with both use the gated mix.
        fused = torch.where(
            (image_present < 0.5).unsqueeze(1), zt,
            torch.where((text_present < 0.5).unsqueeze(1), zi, g * zt + (1.0 - g) * zi)
        )
        fused = self.ln_fused(fused)

        # Interaction features
        feat = torch.cat([
            fused,
            tfeat_p,
            vfeat_p,
            torch.abs(tfeat_p - vfeat_p),
            tfeat_p * vfeat_p
        ], dim=1)
        logits = self.cls(feat)

        # Compute loss if labels provided
        loss = None
        if labels is not None:
            # Both loss paths go through binary_cross_entropy_with_logits,
            # which rejects integer/bool targets.
            if not labels.is_floating_point():
                labels = labels.to(dtype=logits.dtype)
            if self.loss_type == "focal":
                loss = self.criterion(logits, labels)
            else:
                loss = F.binary_cross_entropy_with_logits(
                    logits, labels, pos_weight=self.pos_weight
                )
        return {"loss": loss, "logits": logits}

    def predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
        text_present: Optional[torch.Tensor] = None,
        image_present: Optional[torch.Tensor] = None,
    ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """
        Make predictions using the per-class thresholds from the config.

        Note: this switches the module to eval mode (``self.eval()``) and
        does not switch it back.

        Returns:
            One result per batch element, each a dictionary with
            'predictions' (class name -> bool) and 'probabilities'
            (class name -> float). A single dictionary is returned when the
            batch holds exactly one element, otherwise a list of them.
        """
        self.eval()
        with torch.no_grad():
            outputs = self.forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                pixel_values=pixel_values,
                text_present=text_present,
                image_present=image_present,
            )
        logits = outputs["logits"]
        probabilities = torch.sigmoid(logits)

        # Apply per-class thresholds
        thresholds = torch.tensor(self.config.thresholds, device=logits.device)
        predictions = probabilities > thresholds

        # Convert once to nested Python lists instead of synchronising with
        # the device for every single element.
        pred_rows = predictions.tolist()
        prob_rows = probabilities.tolist()
        class_names = self.config.class_names

        # Format results
        batch_results = [
            {
                "predictions": {name: pred_row[j] for j, name in enumerate(class_names)},
                "probabilities": {name: prob_row[j] for j, name in enumerate(class_names)},
            }
            for pred_row, prob_row in zip(pred_rows, prob_rows)
        ]
        return batch_results[0] if len(batch_results) == 1 else batch_results