# ============================================================
# LOOKTHEM V7.6 FULL TRAINING + INFERENCE
# Backbone + Lite Residual Classifier
# ============================================================

import io
import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from datasets import load_dataset
from PIL import Image
from torch.utils.data import DataLoader, Dataset

# ============================================================
# CONFIG
# ============================================================

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

BATCH_SIZE_TRAIN = 96
BATCH_SIZE_VAL = 32
EPOCHS = 20
LR = 1e-3
WEIGHT_DECAY = 1e-4
NUM_WORKERS = 2  # DataLoader worker processes used for both splits

MODEL_SAVE_PATH = "LookThem_V76_Full_LiteResidual.pth"

# ============================================================
# TRANSFORM
# ============================================================

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def _to_rgb(img):
    """Convert a PIL image to RGB so every sample has the 3 channels Normalize expects.

    Defined at module level (not as a lambda) so the transform pipeline can be
    pickled into DataLoader worker processes under the ``spawn`` start method.
    """
    return img.convert("RGB")


transform_train = transforms.Compose([
    transforms.Lambda(_to_rgb),
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

transform_val = transforms.Compose([
    transforms.Lambda(_to_rgb),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# ============================================================
# DATASET
# ============================================================


class ImageNet100ParquetDataset(Dataset):
    """Map-style Dataset over a Hugging Face split with ``image`` and ``label`` columns.

    Args:
        hf_subset: an indexable split (e.g. the object returned by
            ``datasets.load_dataset(..., split=...)``) whose rows are dicts.
        transform: optional callable applied to the decoded PIL image.
    """

    def __init__(self, hf_subset, transform=None):
        self.dataset = hf_subset
        self.transform = transform

    @staticmethod
    def _load_image(img_data):
        """Return a PIL image from whichever form the ``image`` column arrived in.

        Accepts an already-decoded ``PIL.Image.Image``, a dict with encoded
        ``bytes`` (or, when ``bytes`` is missing/None, a ``path``), or raw bytes.
        """
        if isinstance(img_data, Image.Image):
            return img_data
        if isinstance(img_data, dict):
            if img_data.get("bytes") is not None:
                return Image.open(io.BytesIO(img_data["bytes"]))
            if img_data.get("path"):
                return Image.open(img_data["path"])
            raise ValueError("image dict has neither 'bytes' nor 'path'")
        return Image.open(io.BytesIO(img_data))

    def __getitem__(self, index):
        row = self.dataset[index]
        img = self._load_image(row["image"])
        label = row["label"]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.dataset)


# ============================================================
# LOOKTHEM LAYER
# ============================================================


class LookThemLayer(nn.Module):
    """Pairwise token-interaction layer with per-token weights.

    Input ``x`` is ``(batch, num_tokens, in_features)`` (fixed by the einsum
    strings below); the output has the same shape. Each token ``t`` owns two
    small MLPs (``mod1``/``mod2``: in_features -> hidden_dim -> 1) producing
    one scalar each. For every ordered pair (i, j) the scalar ratios
    ``m1[i]/m2[j]`` and ``m1[j]/m2[i]`` are squashed with tanh, scaled and
    shifted by token j's ``trans_w``/``trans_b``, and used to weight
    ``x[i]`` and ``x[j]``. The result is averaged over all j != i.
    """

    # Minimum magnitude kept for the mod2 denominator.
    _DENOM_EPS = 1e-5

    def __init__(self, num_tokens, in_features, hidden_dim):
        super().__init__()
        self.num_tokens = num_tokens

        # Per-token MLP #1 (numerator branch).
        self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        # Per-token MLP #2 (denominator branch).
        self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        # Per-token scalar affine applied to the tanh-squashed ratios.
        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))

        self._init_weights()

    def _init_weights(self):
        """Re-initialise weight tensors with the same Kaiming-uniform scheme nn.Linear uses."""
        for w in [self.mod1_w1, self.mod2_w1, self.mod1_w2, self.mod2_w2, self.trans_w]:
            nn.init.kaiming_uniform_(w, a=math.sqrt(5))

    def forward(self, x):
        n_tokens = self.num_tokens

        # Per-token scalar outputs, each (B, T, 1).
        h1 = torch.einsum("bti,tij->btj", x, self.mod1_w1) + self.mod1_b1
        out_m1 = torch.einsum("btj,tjk->btk", F.gelu(h1), self.mod1_w2) + self.mod1_b2
        h2 = torch.einsum("bti,tij->btj", x, self.mod2_w1) + self.mod2_b1
        out_m2 = torch.einsum("btj,tjk->btk", F.gelu(h2), self.mod2_w2) + self.mod2_b2

        # Push the denominator away from zero in the direction of its sign.
        # A flat "+eps" only protects non-negative values: an output near
        # -eps would still give a ~0 denominator and inf/NaN gradients.
        eps = self._DENOM_EPS
        out_m2_safe = torch.where(out_m2 >= 0, out_m2 + eps, out_m2 - eps)

        # (B, T, T, 1): compare[b, i, j] = m1[i] / m2[j]; compare2[b, i, j] = m1[j] / m2[i].
        compare = torch.tanh(out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1))
        compare2 = torch.tanh(out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2))

        # Scale/shift by the parameters of token j (the third axis).
        bias_reshaped = self.trans_b.view(1, 1, n_tokens, 1)
        trans_compare = torch.einsum("bije,jef->bijf", compare, self.trans_w) + bias_reshaped
        trans_compare2 = torch.einsum("bije,jef->bijf", compare2, self.trans_w) + bias_reshaped

        # (B, T, T, F): weight x[i] and x[j] by their pair scores, then average.
        interaction = (
            trans_compare * x.unsqueeze(2) + trans_compare2 * x.unsqueeze(1)
        ) / 2

        # Drop self-pairs (i == j) and average over the remaining T - 1 partners.
        mask = 1.0 - torch.eye(n_tokens, device=x.device)
        interaction = interaction * mask.view(1, n_tokens, n_tokens, 1)
        return interaction.sum(dim=2) / (n_tokens - 1.0)


# ============================================================
# LITE RESIDUAL BLOCK
# ============================================================


class LiteResidualBlock(nn.Module):
    """Post-norm residual MLP: ``LayerNorm(x + Linear(Dropout(GELU(Linear(x)))))``."""

    def __init__(self, dim, dropout=0.05):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        )
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.norm(x + self.block(x))


# ============================================================
# FULL MODEL
# ============================================================


def _conv_bn_gelu(in_channels, out_channels, stride):
    """Return the [Conv3x3 -> BatchNorm -> GELU] triple used by both CNN streams."""
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.GELU(),
    ]


class LookThemV76LiteResidual(nn.Module):
    """Two-stream CNN + LookThem token mixing + residual MLP classifier (100 classes).

    Both CNN streams end in a 64-channel 8x8 map, i.e. 64 tokens of 64
    features. Each stream's tokens pass through their own LookThemLayer. The
    two results are concatenated to 128 features per token and mixed by a
    shared LookThemLayer. A 1x1 Conv1d then compresses them back to 64, and
    the 64x64 = 4096 flattened values feed the classifier head.
    """

    def __init__(self):
        super().__init__()

        # STREAM A: aggressive downsampling (four stride-2 convs).
        self.stream_a = nn.Sequential(
            *_conv_bn_gelu(3, 16, stride=2),
            *_conv_bn_gelu(16, 32, stride=2),
            *_conv_bn_gelu(32, 64, stride=2),
            *_conv_bn_gelu(64, 64, stride=2),
            nn.AdaptiveMaxPool2d((8, 8)),
        )

        # STREAM B: keeps higher resolution longer (a single stride-2 conv).
        self.stream_b = nn.Sequential(
            *_conv_bn_gelu(3, 16, stride=1),
            *_conv_bn_gelu(16, 32, stride=1),
            *_conv_bn_gelu(32, 64, stride=2),
            *_conv_bn_gelu(64, 64, stride=1),
            nn.AdaptiveMaxPool2d((8, 8)),
        )

        # LOOKTHEM
        self.lookthemA = LookThemLayer(num_tokens=64, in_features=64, hidden_dim=32)
        self.lookthemB = LookThemLayer(num_tokens=64, in_features=64, hidden_dim=32)
        self.lookthem = LookThemLayer(num_tokens=64, in_features=128, hidden_dim=32)
        self.compressor = nn.Conv1d(128, 64, kernel_size=1)
        # Element-wise dropout on the token tensors (active only in train mode).
        self.imageCorrupter = nn.Dropout(0.1)

        # CLASSIFIER
        self.flatten = nn.Flatten()
        self.input_proj = nn.Sequential(
            nn.Linear(4096, 256),
            nn.GELU(),
            nn.Dropout(0.08),
        )
        self.res1 = LiteResidualBlock(256, 0.05)
        self.res2 = LiteResidualBlock(256, 0.05)
        self.head = nn.Sequential(
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Linear(128, 100),
        )

    def _encode_stream(self, stream, lookthem, x):
        """Run one CNN stream, turn its 8x8 map into 64 tokens, and mix them."""
        feat = stream(x)                          # (B, 64, 8, 8)
        tokens = feat.flatten(2).transpose(1, 2)  # (B, 64 positions, 64 channels)
        tokens = self.imageCorrupter(tokens)
        return lookthem(tokens)

    def extract_features(self, x):
        """Return the compressed token features, shape ``(B, 64, 64)``."""
        # Stream A is fully processed before stream B, so dropout draws random
        # numbers in a fixed order.
        feat_a_lt = self._encode_stream(self.stream_a, self.lookthemA, x)
        feat_b_lt = self._encode_stream(self.stream_b, self.lookthemB, x)

        # COMBINE: (B, 64, 128) -> mix -> (B, 128, 64) -> 1x1 conv -> (B, 64, 64)
        tokens_combined = torch.cat([feat_a_lt, feat_b_lt], dim=2)
        out_lookthem = self.lookthem(tokens_combined).transpose(1, 2)
        return self.compressor(out_lookthem)

    def forward(self, x):
        x = self.extract_features(x)
        x = self.flatten(x)
        x = self.input_proj(x)
        x = self.res1(x)
        x = self.res2(x)
        return self.head(x)


# ============================================================
# DATA / TRAIN / EVAL HELPERS
# ============================================================


def build_dataloaders(num_workers=NUM_WORKERS):
    """Download ImageNet-100 from the Hugging Face Hub and return (train, val) loaders."""
    print("📡 Loading ImageNet-100...")
    raw_train = load_dataset("clane9/imagenet-100", split="train")
    raw_val = load_dataset("clane9/imagenet-100", split="validation")

    train_dataset = ImageNet100ParquetDataset(raw_train, transform=transform_train)
    val_dataset = ImageNet100ParquetDataset(raw_val, transform=transform_val)

    # Pinned host memory only helps host->GPU copies; on CPU it just warns.
    pin = DEVICE.type == "cuda"
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE_TRAIN,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE_VAL,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin,
    )
    return train_loader, val_loader


def train_one_epoch(model, loader, criterion, optimizer, epoch):
    """Train ``model`` for one pass over ``loader``.

    Returns:
        (mean batch loss, training accuracy in percent).
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    num_steps = len(loader)

    for step, (data, target) in enumerate(loader, start=1):
        data = data.to(DEVICE)
        target = target.to(DEVICE)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

        if step % 100 == 0:
            print(
                f"Epoch [{epoch+1:02d}/{EPOCHS}] "
                f"| Step [{step}/{num_steps}] "
                f"| Loss: {loss.item():.4f}"
            )

    return total_loss / max(num_steps, 1), 100.0 * correct / max(total, 1)


@torch.no_grad()
def evaluate(model, loader, criterion):
    """Return (mean batch loss, accuracy in percent) of ``model`` on ``loader``."""
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0

    for data, target in loader:
        data = data.to(DEVICE)
        target = target.to(DEVICE)
        output = model(data)
        loss_sum += criterion(output, target).item()
        _, predicted = output.max(1)
        seen += target.size(0)
        correct += predicted.eq(target).sum().item()

    return loss_sum / max(len(loader), 1), 100.0 * correct / max(seen, 1)


# ============================================================
# MAIN
# ============================================================


def main():
    """Train, validate and save the model, returning the trained ``nn.Module``."""
    train_loader, val_loader = build_dataloaders()

    # MODEL INIT
    model = LookThemV76LiteResidual().to(DEVICE)

    # PARAMETER COUNT (size estimate assumes 4-byte float32 parameters)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"\n🧠 Total Parameters : {total_params:,}")
    size_mb = total_params * 4 / (1024 * 1024)
    print(f"📦 Estimated Size : {size_mb:.2f} MB")

    # LOSS & OPTIMIZER
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    # TRAINING
    print("\n🚀 Training Started...\n")
    for epoch in range(EPOCHS):
        # Read the LR before stepping the scheduler so the log reports the rate
        # this epoch actually trained with, not the next epoch's.
        epoch_lr = optimizer.param_groups[0]["lr"]
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, epoch
        )
        scheduler.step()
        print(
            f"\n🏁 Epoch [{epoch+1:02d}/{EPOCHS}] "
            f"| Loss: {train_loss:.4f} "
            f"| Train Acc: {train_acc:.2f}% "
            f"| LR: {epoch_lr:.6f}\n"
        )

    # VALIDATION
    print("\n🧪 Validation...\n")
    val_loss, val_acc = evaluate(model, val_loader, criterion)
    print(f"\n🏆 Validation Accuracy: {val_acc:.2f}%")
    print(f"📉 Validation Loss: {val_loss:.4f}")

    # SAVE MODEL
    torch.save(model.state_dict(), MODEL_SAVE_PATH)
    real_size = os.path.getsize(MODEL_SAVE_PATH) / (1024 * 1024)
    print("\n💾 MODEL SAVED!")
    print(f"📁 Path : {MODEL_SAVE_PATH}")
    print(f"📦 Size : {real_size:.2f} MB")

    return model


# The guard is required when DataLoader workers are started with "spawn"
# (Windows/macOS): each worker re-imports this module and must not re-run
# training. ``model`` stays available as a global after a script or notebook run.
if __name__ == "__main__":
    model = main()