MS²-SMILES AlignNet 模型文档

📋 目录

  1. 概述
  2. 数据处理流程
  3. 模型架构
  4. 损失函数设计
  5. 训练流程
  6. 配置参数

概述

MS²-SMILES AlignNet 是一个基于对比学习的跨模态对齐模型,旨在学习质谱(MS²)数据与分子结构(SMILES)之间的语义对齐表示。该模型采用双分支架构,分别编码 MS 谱图和分子结构,并通过对比学习使得匹配的 MS-分子对在嵌入空间中靠近。

核心思想

┌─────────────────────────────────────────────────────────────────────────────┐
│                        MS²-SMILES AlignNet 架构                              │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                              │
│   SMILES ──→ [分子编码分支] ──→ 256维向量 ─┐                                 │
│                                            ├──→ 对比学习 ──→ 对齐损失        │
│   MS谱图 ──→ [MS编码分支]  ──→ 256维向量 ──┘                                 │
│                                                                              │
└─────────────────────────────────────────────────────────────────────────────┘

数据处理流程

2.1 数据加载入口

文件: train.py (第174-175行)

train_loader = build_loaders(train_set, "train", cfg, 10)
valid_loader = build_loaders(valid_set, "valid", cfg, 10)

数据支持两种输入格式:

  • JSON/MGF 文件列表: 使用 PathDataset
  • 预处理字典列表: 使用 Dataset

2.2 原始数据获取

文件: dataset.py (第103-128行)

PathDataset
    ├── proc_data(): 解析 JSON/MGF 文件
    │   └── 提取: ms (质谱峰列表), smiles (分子结构)
    │
    └── __getitem__(): 调用 calc_feats() 计算特征

数据格式:

  • ms: List of (m/z, intensity) 元组
  • smiles: 分子 SMILES 字符串

2.3 特征预处理

文件: dataset.py (第7-41行) - calc_feats() 函数

┌─────────────────────────────────────────────────────────────────────────────┐
│                           特征计算流程                                       │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                              │
│  输入: smi (SMILES), ms (质谱), nls (中性丢失)                               │
│                                                                              │
│  ┌────────────────────────┐  ┌────────────────────────────┐                 │
│  │   MS 特征处理          │  │   分子特征处理              │                 │
│  ├────────────────────────┤  ├────────────────────────────┤                 │
│  │ 1. ms_binner()         │  │ 1. mol_fp_encoder()        │                 │
│  │    → ms_bins           │  │    → mol_fps (指纹)        │                 │
│  │    (39600维稀疏向量)    │  │    (2048维)                │                 │
│  │                        │  │                            │                 │
│  │ 2. ms_feature_processor│  │ 2. mol_graph_featurizer()  │                 │
│  │    → ms_bins1 (100,29) │  │    → V (节点特征)          │                 │
│  │    → ms_bins2 (100,)   │  │    → A (邻接矩阵)          │                 │
│  │    (峰特征 + m/z值)     │  │    → mol_size             │                 │
│  └────────────────────────┘  └────────────────────────────┘                 │
│                                                                              │
└─────────────────────────────────────────────────────────────────────────────┘

2.3.1 MS 特征处理详解

方法一: ms_binner() - 稀疏分桶表示

# 参数配置
min_mz = 20, max_mz = 2000, bin_size = 0.05
# 计算: (2000-20)/0.05 = 39600 bins

# 流程:
1. 过滤 m/z 范围外的峰
2. 归一化强度 (0-1)
3. 分桶 (bin_idx = (m/z - min_mz) / bin_size)
4. 可选: 添加中性丢失特征 (额外 4000 bins)

方法二: ms_feature_processor() - 序列化峰表示

输入: ms, precursor_mz, metadata_vec(25维)
输出: 
  - ms_bins1: (100, 29) 峰特征矩阵
  - ms_bins2: (100,) m/z 值序列

# 每个峰的 29 维特征:
[强度(1), 诊断离子标记(1), 中性丢失标记(1), 前体权重(1), 元数据(25)]

2.3.2 分子特征处理详解

Morgan 指纹编码:

mol_fp_encoder(smiles, tp='morgan', nbits=2048)
# → 2048 维二进制向量

分子图特征化:

mol_graph_featurizer(smiles)
# → V: 节点特征矩阵 (N_atoms, 74)
# → A: 邻接张量 (N_atoms, 6, N_atoms)
#      6 通道 = 4 键类型 + 环内 + 共轭
# → mol_size: 原子数量

原子节点特征 (74维):

特征类型 维度 说明
原子类型 12 One-hot: H,C,N,O,S,F,Si,P,Cl,Br,I,B
显式价态 7 One-hot: 0-6
隐式价态 7 One-hot: 0-6
氢原子数 5 One-hot: 0-4
自由基电子 5 One-hot: 0-4
总度数 7 One-hot: 0-6
形式电荷 5 One-hot: -2 to 2
杂化类型 5 SP,SP2,SP3,SP3D,SP3D2
芳香性 2 True/False
环内原子 2 True/False
手性 4 4种类型
CIP编码 2 R/S
范德华半径 1 连续值
手性可能性 1 True/False
原子序数 1 连续值
原子质量 1 连续值 × 0.01
度数 1 连续值
特征不变量 6 6位二进制编码

2.4 数据批处理聚合

文件: train.py (第36-83行) - my_collate() 函数

┌─────────────────────────────────────────────────────────────────────────────┐
│                         Batch 聚合流程                                       │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                              │
│  输入: [sample_1, sample_2, ..., sample_B]  (B = batch_size)                │
│                                                                              │
│  1. 过滤 None 样本                                                           │
│                                                                              │
│  2. 堆叠各特征:                                                              │
│     ├── ms_bins:  stack → (B, 39600+4000)                                   │
│     ├── ms_bins1: stack → (B, 100, 29)                                      │
│     ├── ms_bins2: stack → (B, 100)                                          │
│     ├── mol_fps:  stack → (B, 2048)                                         │
│     └── mol_fmvec: stack → (B, 10)  [可选: 分子式向量]                       │
│                                                                              │
│  3. 图数据 Padding:                                                          │
│     ├── 计算 max_n = max(所有分子的原子数)                                   │
│     ├── V: pad → (B, max_n, 74)                                             │
│     ├── A: pad → (B, max_n, 6, max_n)                                       │
│     └── mol_size: concat → (B,)                                             │
│                                                                              │
└─────────────────────────────────────────────────────────────────────────────┘

模型架构

文件: modules.py (第298-457行) - FragSimiModelNew

3.1 整体架构图

┌─────────────────────────────────────────────────────────────────────────────────────────┐
│                              FragSimiModelNew 双分支架构                                 │
├─────────────────────────────────────────────────────────────────────────────────────────┤
│                                                                                          │
│  ┌──────────────────────────────────┐    ┌──────────────────────────────────────────┐   │
│  │     分子编码分支 (LEFT)          │    │        MS 编码分支 (RIGHT)                │   │
│  ├──────────────────────────────────┤    ├──────────────────────────────────────────┤   │
│  │                                  │    │                                          │   │
│  │  SMILES                          │    │  MS 谱图                                 │   │
│  │     │                            │    │     │                                    │   │
│  │     ▼                            │    │     ▼                                    │   │
│  │  ┌─────────────┐                 │    │  ┌─────────────────────────────────────┐│   │
│  │  │ RDKit 解析   │                 │    │  │ 分支A: ms_bins (稀疏分桶)           ││   │
│  │  └──────┬──────┘                 │    │  │   │                                 ││   │
│  │         │                        │    │  │   ▼ ms_input_proj (Linear)          ││   │
│  │    ┌────┴────┐                   │    │  │   │                                 ││   │
│  │    ▼         ▼                   │    │  │   ▼ 6层 Transformer                 ││   │
│  │ ┌──────┐ ┌───────┐               │    │  │   │                                 ││   │
│  │ │ 分子图 │ │ Morgan│               │    │  │   ▼ Mean Pooling                   ││   │
│  │ │构建   │ │ 指纹  │               │    │  │   │                                 ││   │
│  │ └──┬───┘ └───┬───┘               │    │  │   ▼ ms_final_proj                   ││   │
│  │    │         │                   │    │  │   └───→ 256维                       ││   │
│  │    ▼         │                   │    │  └─────────────────────────────────────┘│   │
│  │ ┌──────────┐ │                   │    │                     │                    │   │
│  │ │MolGNN    │ │                   │    │                     ▼                    │   │
│  │ │Encoder   │ │                   │    │  ┌─────────────────────────────────────┐│   │
│  │ │(3层GCN   │ │                   │    │  │ 分支B: ms_bins1/2 (峰序列)          ││   │
│  │ │+Attention│ │                   │    │  │   │                                 ││   │
│  │ │+Readout) │ │                   │    │  │   ▼ SinusoidalMzEmbedding           ││   │
│  │ └────┬─────┘ │                   │    │  │   + feature_proj (29→256)           ││   │
│  │      │       │                   │    │  │   │                                 ││   │
│  │      ▼       ▼                   │    │  │   ▼ 6层 Transformer                 ││   │
│  │ ┌──────────────────┐             │    │  │   │                                 ││   │
│  │ │  MolFusionHead   │             │    │  │   ▼ Mean Pooling                   ││   │
│  │ │ (GNN→128, FP→256)│             │    │  │   │                                 ││   │
│  │ │ → Concat → 256   │             │    │  │   ▼ final_proj                      ││   │
│  │ └────────┬─────────┘             │    │  │   └───→ 256维                       ││   │
│  │          │                       │    │  └─────────────────────────────────────┘│   │
│  │          ▼                       │    │                     │                    │   │
│  │     256维向量                    │    │                     ▼                    │   │
│  │                                  │    │  ┌─────────────────────────────────────┐│   │
│  │                                  │    │  │    Concat (256+256=512)             ││   │
│  │                                  │    │  │           │                         ││   │
│  │                                  │    │  │           ▼                         ││   │
│  │                                  │    │  │    all_final_proj → 256维           ││   │
│  │                                  │    │  └─────────────────────────────────────┘│   │
│  └──────────────────────────────────┘    └──────────────────────────────────────────┘   │
│                  │                                           │                          │
│                  ▼                                           ▼                          │
│          mol_embeddings (256)                        ms_embeddings (256)                │
│                  │                                           │                          │
│                  └───────────────────┬───────────────────────┘                          │
│                                      ▼                                                  │
│                          ┌───────────────────────┐                                      │
│                          │    L2 Normalize       │                                      │
│                          └───────────┬───────────┘                                      │
│                                      ▼                                                  │
│                          ┌───────────────────────┐                                      │
│                          │   HybridAlignLoss     │                                      │
│                          │ (InfoNCE + Tanimoto)  │                                      │
│                          └───────────────────────┘                                      │
│                                                                                          │
└─────────────────────────────────────────────────────────────────────────────────────────┘

3.2 分子编码分支

3.2.1 MolGNNEncoder

class MolGNNEncoder(nn.Module):
    # 输入: V (B, N, 74), A (B, N, 6, N), mol_size (B,)
    # 输出: (B, mol_embedding_dim)
    
    结构:
    ├── block_layers: 3 × GConvBlockNoGF
    │   └── GraphCNNLayer: 图卷积 + BatchNorm + ReLU
    │       [74256256256]
    │
    ├── attention_layer: MultiHeadGlobalAttention (4 heads)
    │   └── 对节点进行加权聚合,去除 padding 影响
    │       [256256×4 = 1024]
    │
    └── readout_layers: 2 × Linear + GELU
        └── [102420482048]

3.2.2 MolFusionHead

class MolFusionHead(nn.Module):
    # 输入: gnn_feat (B, 2048), fps (B, 2048)
    # 输出: (B, 256)
    
    结构:
    ├── gnn_proj: Linear(2048128)
    │
    ├── fp_proj: Linear(2048256) + LayerNorm + GELU
    │
    └── fusion_layer: 
        ├── Concat: [128, 256] → 384
        └── Linear(384256) + LayerNorm + Dropout + GELU

3.3 MS 编码分支

3.3.1 分支A: 稀疏分桶编码

# 输入: ms_bins (B, 43600)
# 输出: 256维

流程:
1. ms_input_proj: Linear(43600256)
2. unsqueeze → (B, 1, 256) [假序列]
3. ms_transformer: 6层 Transformer (width=256, heads=8)
4. mean(dim=1) → (B, 256)
5. ms_final_proj: Linear + LayerNorm + Dropout → 256

3.3.2 分支B: 峰序列编码

# 输入: ms_bins1 (B, 100, 29), ms_bins2 (B, 100)
# 输出: 256维

流程:
1. mz_embedder(ms_bins2) → (B, 100, 256)  # 正弦位置编码
2. feature_proj(ms_bins1) → (B, 100, 256)  # 峰特征投影
3. 相加融合 → (B, 100, 256)
4. transformer: 6层 Transformer
5. mean(dim=1) → (B, 256)
6. final_proj: Linear + LayerNorm + Dropout → 256

3.3.3 正弦 m/z 嵌入

class SinusoidalMzEmbedding(nn.Module):
    # 将 m/z 值编码为 256 维向量
    # 类似 Transformer 的位置编码
    
    PE[pos, 2i]   = sin(pos / 10000^(2i/256))
    PE[pos, 2i+1] = cos(pos / 10000^(2i/256))

3.3.4 双分支融合

# 两个 MS 分支的输出融合
out1 = torch.cat([ms_embeddings_A, ms_embeddings_B], dim=-1)  # (B, 512)
ms_embeddings = all_final_proj(out1)  # Linear(512→256) + LayerNorm + Dropout

3.4 Transformer 结构

文件: cliplayers.py - 采用 CLIP 风格的 Transformer

class ResidualAttentionBlock(nn.Module):
    # 单个 Transformer Block
    
    结构:
    ├── ln_1 → MultiheadAttention → 残差连接
    │
    └── ln_2 → MLP (Linear→QuickGELU→Linear) → 残差连接
    
    # QuickGELU: x * sigmoid(1.702 * x)
    # 比标准 GELU 更快

损失函数设计

文件: modules.py (第196-233行) - HybridAlignLoss

4.1 联合损失公式

Ltotal=βLInfoNCE+αLTanimotoMSE\mathcal{L}_{total} = \beta \cdot \mathcal{L}_{InfoNCE} + \alpha \cdot \mathcal{L}_{Tanimoto-MSE}

默认参数: α = 0.5, β = 1.0

4.2 InfoNCE 对比损失

# 目标: 让匹配的 MS-分子对相似度高,不匹配的对相似度低

logits = (ms_emb @ mol_emb.T) / temperature  # 相似度矩阵
labels = torch.arange(batch_size)            # 对角线为正样本

loss_i2t = CrossEntropy(logits, labels)      # MS → Mol
loss_t2i = CrossEntropy(logits.T, labels)    # Mol → MS

L_InfoNCE = (loss_i2t + loss_t2i) / 2

温度参数: τ = 0.07

4.3 Tanimoto MSE 损失

# 目标: 让嵌入空间的相似度矩阵逼近真实的分子结构相似度

# 1. 预测的相似度矩阵 (余弦相似度)
pred_sim = ms_emb @ mol_emb.T  # 因为已 L2 normalize

# 2. 真实的结构相似度 (Tanimoto 系数)
#    Tanimoto(A,B) = (A·B) / (|A|² + |B|² - A·B)
target_sim = batch_tanimoto_sim(mol_fps, mol_fps)

# 3. MSE 损失
L_MSE = MSELoss(pred_sim, target_sim)

4.4 损失设计的意义

损失组件 作用 特点
InfoNCE 跨模态对齐 区分正负样本对,对比学习核心
Tanimoto MSE 结构约束 保持分子结构相似性关系
┌─────────────────────────────────────────────────────────────────┐
│                    联合损失的几何意义                            │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│   InfoNCE: "匹配的 MS-Mol 对要靠近"                             │
│                                                                  │
│       MS₁ ←──── 拉近 ────→ Mol₁                                 │
│       MS₂ ←──── 拉近 ────→ Mol₂                                 │
│        ↑                    ↑                                    │
│     推远                  推远                                   │
│        ↓                    ↓                                    │
│       MS₃ ←──── 拉近 ────→ Mol₃                                 │
│                                                                  │
│   Tanimoto MSE: "结构相似的分子在嵌入空间也要相似"               │
│                                                                  │
│     如果 Tanimoto(Mol₁, Mol₂) = 0.8                             │
│     则希望 CosSim(Emb₁, Emb₂) ≈ 0.8                             │
│                                                                  │
└─────────────────────────────────────────────────────────────────┘

训练流程

文件: train.py

5.1 训练入口

# main() 函数

1. 数据划分: make_train_valid(data, valid_ratio=0.1)
2. 构建 DataLoader: build_loaders()
3. 初始化模型: FragSimiModelNew(cfg).to(device)
4. 优化器: AdamW(lr=1e-3, weight_decay=1e-3)
5. 学习率调度: ReduceLROnPlateau(patience=2, factor=0.5)
6. 训练循环: epochs=50

5.2 训练周期

def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
    for batch in train_loader:
        # 1. 数据移到 GPU
        batch = {k: v.to(device) for k, v in batch.items()}
        
        # 2. 前向传播
        loss = model(batch)  # 返回总损失
        
        # 3. 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

5.3 模型保存策略

# 保存最佳模型 (基于验证集损失)
if valid_loss < best_loss:
    checkpoint = {
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'config': dict(CFG)
    }
    torch.save(checkpoint, best_model_fn)

# 只保留最新的 3 个最佳模型
keep_best_models_num = 3

配置参数

文件: config.py

6.1 核心参数

参数 默认值 说明
batch_size 64 批次大小
lr 1e-3 学习率
epochs 50 训练轮数
dropout 0.1 Dropout 比例
projection_dim 256 最终嵌入维度

6.2 MS 相关参数

参数 默认值 说明
min_mz 20 最小 m/z
max_mz 2000 最大 m/z
bin_size 0.05 分桶大小
add_nl True 是否添加中性丢失特征
binary_intn False 是否二值化强度

6.3 分子编码参数

参数 默认值 说明
mol_encoder 'gnn+fp' 编码方式 (fp/gnn/gnn+fp)
mol_embedding_dim 2048 Morgan 指纹维度
fptype 'morgan' 指纹类型
molgnn_n_filters_list [256,256,256] GNN 各层维度
molgnn_nhead 4 注意力头数
molgnn_readout_layers 2 Readout 层数

6.4 Transformer 参数

参数 默认值 说明
tsfm_layers 6 Transformer 层数
tsfm_heads 8 注意力头数
tsfm_in_ms True MS 分支使用 Transformer
tsfm_in_mol False 分子分支使用 Transformer

附录: 文件结构

train-001/
├── train.py          # 训练入口
├── dataset.py        # 数据集定义
├── modules.py        # 模型定义 (FragSimiModelNew)
├── config.py         # 配置参数
├── utils.py          # 工具函数 (特征处理等)
├── cliplayers.py     # CLIP 风格的 Transformer
└── GNN/
    ├── layers.py     # 图卷积层
    └── featurizer.py # 分子图特征化

参考

本模型设计参考了以下工作:

  • CLIP (Contrastive Language-Image Pre-training)
  • Graph Attention Networks (GAT)
  • Morgan Fingerprints (Extended Connectivity Fingerprints)
  • Tanimoto Coefficient for molecular similarity

Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support