MS²-SMILES AlignNet 模型文档
📋 目录
概述
MS²-SMILES AlignNet 是一个基于对比学习的跨模态对齐模型,旨在学习质谱(MS²)数据与分子结构(SMILES)之间的语义对齐表示。该模型采用双分支架构,分别编码 MS 谱图和分子结构,并通过对比学习使得匹配的 MS-分子对在嵌入空间中靠近。
核心思想
┌─────────────────────────────────────────────────────────────────────────────┐
│ MS²-SMILES AlignNet 架构 │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ SMILES ──→ [分子编码分支] ──→ 256维向量 ─┐ │
│ ├──→ 对比学习 ──→ 对齐损失 │
│ MS谱图 ──→ [MS编码分支] ──→ 256维向量 ──┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
数据处理流程
2.1 数据加载入口
文件: train.py (第174-175行)
train_loader = build_loaders(train_set, "train", cfg, 10)
valid_loader = build_loaders(valid_set, "valid", cfg, 10)
数据支持两种输入格式:
- JSON/MGF 文件列表: 使用
PathDataset类 - 预处理字典列表: 使用
Dataset类
2.2 原始数据获取
文件: dataset.py (第103-128行)
PathDataset
├── proc_data(): 解析 JSON/MGF 文件
│ └── 提取: ms (质谱峰列表), smiles (分子结构)
│
└── __getitem__(): 调用 calc_feats() 计算特征
数据格式:
ms: List of (m/z, intensity) 元组smiles: 分子 SMILES 字符串
2.3 特征预处理
文件: dataset.py (第7-41行) - calc_feats() 函数
┌─────────────────────────────────────────────────────────────────────────────┐
│ 特征计算流程 │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ 输入: smi (SMILES), ms (质谱), nls (中性丢失) │
│ │
│ ┌────────────────────────┐ ┌────────────────────────────┐ │
│ │ MS 特征处理 │ │ 分子特征处理 │ │
│ ├────────────────────────┤ ├────────────────────────────┤ │
│ │ 1. ms_binner() │ │ 1. mol_fp_encoder() │ │
│ │ → ms_bins │ │ → mol_fps (指纹) │ │
│ │ (39600维稀疏向量) │ │ (2048维) │ │
│ │ │ │ │ │
│ │ 2. ms_feature_processor│ │ 2. mol_graph_featurizer() │ │
│ │ → ms_bins1 (100,29) │ │ → V (节点特征) │ │
│ │ → ms_bins2 (100,) │ │ → A (邻接矩阵) │ │
│ │ (峰特征 + m/z值) │ │ → mol_size │ │
│ └────────────────────────┘ └────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
2.3.1 MS 特征处理详解
方法一: ms_binner() - 稀疏分桶表示
# 参数配置
min_mz = 20, max_mz = 2000, bin_size = 0.05
# 计算: (2000-20)/0.05 = 39600 bins
# 流程:
1. 过滤 m/z 范围外的峰
2. 归一化强度 (0-1)
3. 分桶 (bin_idx = (m/z - min_mz) / bin_size)
4. 可选: 添加中性丢失特征 (额外 4000 bins)
方法二: ms_feature_processor() - 序列化峰表示
输入: ms, precursor_mz, metadata_vec(25维)
输出:
- ms_bins1: (100, 29) 峰特征矩阵
- ms_bins2: (100,) m/z 值序列
# 每个峰的 29 维特征:
[强度(1), 诊断离子标记(1), 中性丢失标记(1), 前体权重(1), 元数据(25)]
2.3.2 分子特征处理详解
Morgan 指纹编码:
mol_fp_encoder(smiles, tp='morgan', nbits=2048)
# → 2048 维二进制向量
分子图特征化:
mol_graph_featurizer(smiles)
# → V: 节点特征矩阵 (N_atoms, 74)
# → A: 邻接张量 (N_atoms, 6, N_atoms)
# 6 通道 = 4 键类型 + 环内 + 共轭
# → mol_size: 原子数量
原子节点特征 (74维):
| 特征类型 | 维度 | 说明 |
|---|---|---|
| 原子类型 | 12 | One-hot: H,C,N,O,S,F,Si,P,Cl,Br,I,B |
| 显式价态 | 7 | One-hot: 0-6 |
| 隐式价态 | 7 | One-hot: 0-6 |
| 氢原子数 | 5 | One-hot: 0-4 |
| 自由基电子 | 5 | One-hot: 0-4 |
| 总度数 | 7 | One-hot: 0-6 |
| 形式电荷 | 5 | One-hot: -2 to 2 |
| 杂化类型 | 5 | SP,SP2,SP3,SP3D,SP3D2 |
| 芳香性 | 2 | True/False |
| 环内原子 | 2 | True/False |
| 手性 | 4 | 4种类型 |
| CIP编码 | 2 | R/S |
| 范德华半径 | 1 | 连续值 |
| 手性可能性 | 1 | True/False |
| 原子序数 | 1 | 连续值 |
| 原子质量 | 1 | 连续值 × 0.01 |
| 度数 | 1 | 连续值 |
| 特征不变量 | 6 | 6位二进制编码 |
2.4 数据批处理聚合
文件: train.py (第36-83行) - my_collate() 函数
┌─────────────────────────────────────────────────────────────────────────────┐
│ Batch 聚合流程 │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ 输入: [sample_1, sample_2, ..., sample_B] (B = batch_size) │
│ │
│ 1. 过滤 None 样本 │
│ │
│ 2. 堆叠各特征: │
│ ├── ms_bins: stack → (B, 39600+4000) │
│ ├── ms_bins1: stack → (B, 100, 29) │
│ ├── ms_bins2: stack → (B, 100) │
│ ├── mol_fps: stack → (B, 2048) │
│ └── mol_fmvec: stack → (B, 10) [可选: 分子式向量] │
│ │
│ 3. 图数据 Padding: │
│ ├── 计算 max_n = max(所有分子的原子数) │
│ ├── V: pad → (B, max_n, 74) │
│ ├── A: pad → (B, max_n, 6, max_n) │
│ └── mol_size: concat → (B,) │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
模型架构
文件: modules.py (第298-457行) - FragSimiModelNew 类
3.1 整体架构图
┌─────────────────────────────────────────────────────────────────────────────────────────┐
│ FragSimiModelNew 双分支架构 │
├─────────────────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────────────────────────┐ ┌──────────────────────────────────────────┐ │
│ │ 分子编码分支 (LEFT) │ │ MS 编码分支 (RIGHT) │ │
│ ├──────────────────────────────────┤ ├──────────────────────────────────────────┤ │
│ │ │ │ │ │
│ │ SMILES │ │ MS 谱图 │ │
│ │ │ │ │ │ │ │
│ │ ▼ │ │ ▼ │ │
│ │ ┌─────────────┐ │ │ ┌─────────────────────────────────────┐│ │
│ │ │ RDKit 解析 │ │ │ │ 分支A: ms_bins (稀疏分桶) ││ │
│ │ └──────┬──────┘ │ │ │ │ ││ │
│ │ │ │ │ │ ▼ ms_input_proj (Linear) ││ │
│ │ ┌────┴────┐ │ │ │ │ ││ │
│ │ ▼ ▼ │ │ │ ▼ 6层 Transformer ││ │
│ │ ┌──────┐ ┌───────┐ │ │ │ │ ││ │
│ │ │ 分子图 │ │ Morgan│ │ │ │ ▼ Mean Pooling ││ │
│ │ │构建 │ │ 指纹 │ │ │ │ │ ││ │
│ │ └──┬───┘ └───┬───┘ │ │ │ ▼ ms_final_proj ││ │
│ │ │ │ │ │ │ └───→ 256维 ││ │
│ │ ▼ │ │ │ └─────────────────────────────────────┘│ │
│ │ ┌──────────┐ │ │ │ │ │ │
│ │ │MolGNN │ │ │ │ ▼ │ │
│ │ │Encoder │ │ │ │ ┌─────────────────────────────────────┐│ │
│ │ │(3层GCN │ │ │ │ │ 分支B: ms_bins1/2 (峰序列) ││ │
│ │ │+Attention│ │ │ │ │ │ ││ │
│ │ │+Readout) │ │ │ │ │ ▼ SinusoidalMzEmbedding ││ │
│ │ └────┬─────┘ │ │ │ │ + feature_proj (29→256) ││ │
│ │ │ │ │ │ │ │ ││ │
│ │ ▼ ▼ │ │ │ ▼ 6层 Transformer ││ │
│ │ ┌──────────────────┐ │ │ │ │ ││ │
│ │ │ MolFusionHead │ │ │ │ ▼ Mean Pooling ││ │
│ │ │ (GNN→128, FP→256)│ │ │ │ │ ││ │
│ │ │ → Concat → 256 │ │ │ │ ▼ final_proj ││ │
│ │ └────────┬─────────┘ │ │ │ └───→ 256维 ││ │
│ │ │ │ │ └─────────────────────────────────────┘│ │
│ │ ▼ │ │ │ │ │
│ │ 256维向量 │ │ ▼ │ │
│ │ │ │ ┌─────────────────────────────────────┐│ │
│ │ │ │ │ Concat (256+256=512) ││ │
│ │ │ │ │ │ ││ │
│ │ │ │ │ ▼ ││ │
│ │ │ │ │ all_final_proj → 256维 ││ │
│ │ │ │ └─────────────────────────────────────┘│ │
│ └──────────────────────────────────┘ └──────────────────────────────────────────┘ │
│ │ │ │
│ ▼ ▼ │
│ mol_embeddings (256) ms_embeddings (256) │
│ │ │ │
│ └───────────────────┬───────────────────────┘ │
│ ▼ │
│ ┌───────────────────────┐ │
│ │ L2 Normalize │ │
│ └───────────┬───────────┘ │
│ ▼ │
│ ┌───────────────────────┐ │
│ │ HybridAlignLoss │ │
│ │ (InfoNCE + Tanimoto) │ │
│ └───────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────────────────┘
3.2 分子编码分支
3.2.1 MolGNNEncoder
class MolGNNEncoder(nn.Module):
# 输入: V (B, N, 74), A (B, N, 6, N), mol_size (B,)
# 输出: (B, mol_embedding_dim)
结构:
├── block_layers: 3 × GConvBlockNoGF
│ └── GraphCNNLayer: 图卷积 + BatchNorm + ReLU
│ [74 → 256 → 256 → 256]
│
├── attention_layer: MultiHeadGlobalAttention (4 heads)
│ └── 对节点进行加权聚合,去除 padding 影响
│ [256 → 256×4 = 1024]
│
└── readout_layers: 2 × Linear + GELU
└── [1024 → 2048 → 2048]
3.2.2 MolFusionHead
class MolFusionHead(nn.Module):
# 输入: gnn_feat (B, 2048), fps (B, 2048)
# 输出: (B, 256)
结构:
├── gnn_proj: Linear(2048 → 128)
│
├── fp_proj: Linear(2048 → 256) + LayerNorm + GELU
│
└── fusion_layer:
├── Concat: [128, 256] → 384
└── Linear(384 → 256) + LayerNorm + Dropout + GELU
3.3 MS 编码分支
3.3.1 分支A: 稀疏分桶编码
# 输入: ms_bins (B, 43600)
# 输出: 256维
流程:
1. ms_input_proj: Linear(43600 → 256)
2. unsqueeze → (B, 1, 256) [假序列]
3. ms_transformer: 6层 Transformer (width=256, heads=8)
4. mean(dim=1) → (B, 256)
5. ms_final_proj: Linear + LayerNorm + Dropout → 256维
3.3.2 分支B: 峰序列编码
# 输入: ms_bins1 (B, 100, 29), ms_bins2 (B, 100)
# 输出: 256维
流程:
1. mz_embedder(ms_bins2) → (B, 100, 256) # 正弦位置编码
2. feature_proj(ms_bins1) → (B, 100, 256) # 峰特征投影
3. 相加融合 → (B, 100, 256)
4. transformer: 6层 Transformer
5. mean(dim=1) → (B, 256)
6. final_proj: Linear + LayerNorm + Dropout → 256维
3.3.3 正弦 m/z 嵌入
class SinusoidalMzEmbedding(nn.Module):
# 将 m/z 值编码为 256 维向量
# 类似 Transformer 的位置编码
PE[pos, 2i] = sin(pos / 10000^(2i/256))
PE[pos, 2i+1] = cos(pos / 10000^(2i/256))
3.3.4 双分支融合
# 两个 MS 分支的输出融合
out1 = torch.cat([ms_embeddings_A, ms_embeddings_B], dim=-1) # (B, 512)
ms_embeddings = all_final_proj(out1) # Linear(512→256) + LayerNorm + Dropout
3.4 Transformer 结构
文件: cliplayers.py - 采用 CLIP 风格的 Transformer
class ResidualAttentionBlock(nn.Module):
# 单个 Transformer Block
结构:
├── ln_1 → MultiheadAttention → 残差连接
│
└── ln_2 → MLP (Linear→QuickGELU→Linear) → 残差连接
# QuickGELU: x * sigmoid(1.702 * x)
# 比标准 GELU 更快
损失函数设计
文件: modules.py (第196-233行) - HybridAlignLoss 类
4.1 联合损失公式
默认参数: α = 0.5, β = 1.0
4.2 InfoNCE 对比损失
# 目标: 让匹配的 MS-分子对相似度高,不匹配的对相似度低
logits = (ms_emb @ mol_emb.T) / temperature # 相似度矩阵
labels = torch.arange(batch_size) # 对角线为正样本
loss_i2t = CrossEntropy(logits, labels) # MS → Mol
loss_t2i = CrossEntropy(logits.T, labels) # Mol → MS
L_InfoNCE = (loss_i2t + loss_t2i) / 2
温度参数: τ = 0.07
4.3 Tanimoto MSE 损失
# 目标: 让嵌入空间的相似度矩阵逼近真实的分子结构相似度
# 1. 预测的相似度矩阵 (余弦相似度)
pred_sim = ms_emb @ mol_emb.T # 因为已 L2 normalize
# 2. 真实的结构相似度 (Tanimoto 系数)
# Tanimoto(A,B) = (A·B) / (|A|² + |B|² - A·B)
target_sim = batch_tanimoto_sim(mol_fps, mol_fps)
# 3. MSE 损失
L_MSE = MSELoss(pred_sim, target_sim)
4.4 损失设计的意义
| 损失组件 | 作用 | 特点 |
|---|---|---|
| InfoNCE | 跨模态对齐 | 区分正负样本对,对比学习核心 |
| Tanimoto MSE | 结构约束 | 保持分子结构相似性关系 |
┌─────────────────────────────────────────────────────────────────┐
│ 联合损失的几何意义 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ InfoNCE: "匹配的 MS-Mol 对要靠近" │
│ │
│ MS₁ ←──── 拉近 ────→ Mol₁ │
│ MS₂ ←──── 拉近 ────→ Mol₂ │
│ ↑ ↑ │
│ 推远 推远 │
│ ↓ ↓ │
│ MS₃ ←──── 拉近 ────→ Mol₃ │
│ │
│ Tanimoto MSE: "结构相似的分子在嵌入空间也要相似" │
│ │
│ 如果 Tanimoto(Mol₁, Mol₂) = 0.8 │
│ 则希望 CosSim(Emb₁, Emb₂) ≈ 0.8 │
│ │
└─────────────────────────────────────────────────────────────────┘
训练流程
文件: train.py
5.1 训练入口
# main() 函数
1. 数据划分: make_train_valid(data, valid_ratio=0.1)
2. 构建 DataLoader: build_loaders()
3. 初始化模型: FragSimiModelNew(cfg).to(device)
4. 优化器: AdamW(lr=1e-3, weight_decay=1e-3)
5. 学习率调度: ReduceLROnPlateau(patience=2, factor=0.5)
6. 训练循环: epochs=50
5.2 训练周期
def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
for batch in train_loader:
# 1. 数据移到 GPU
batch = {k: v.to(device) for k, v in batch.items()}
# 2. 前向传播
loss = model(batch) # 返回总损失
# 3. 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
5.3 模型保存策略
# 保存最佳模型 (基于验证集损失)
if valid_loss < best_loss:
checkpoint = {
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'config': dict(CFG)
}
torch.save(checkpoint, best_model_fn)
# 只保留最新的 3 个最佳模型
keep_best_models_num = 3
配置参数
文件: config.py
6.1 核心参数
| 参数 | 默认值 | 说明 |
|---|---|---|
batch_size |
64 | 批次大小 |
lr |
1e-3 | 学习率 |
epochs |
50 | 训练轮数 |
dropout |
0.1 | Dropout 比例 |
projection_dim |
256 | 最终嵌入维度 |
6.2 MS 相关参数
| 参数 | 默认值 | 说明 |
|---|---|---|
min_mz |
20 | 最小 m/z |
max_mz |
2000 | 最大 m/z |
bin_size |
0.05 | 分桶大小 |
add_nl |
True | 是否添加中性丢失特征 |
binary_intn |
False | 是否二值化强度 |
6.3 分子编码参数
| 参数 | 默认值 | 说明 |
|---|---|---|
mol_encoder |
'gnn+fp' | 编码方式 (fp/gnn/gnn+fp) |
mol_embedding_dim |
2048 | Morgan 指纹维度 |
fptype |
'morgan' | 指纹类型 |
molgnn_n_filters_list |
[256,256,256] | GNN 各层维度 |
molgnn_nhead |
4 | 注意力头数 |
molgnn_readout_layers |
2 | Readout 层数 |
6.4 Transformer 参数
| 参数 | 默认值 | 说明 |
|---|---|---|
tsfm_layers |
6 | Transformer 层数 |
tsfm_heads |
8 | 注意力头数 |
tsfm_in_ms |
True | MS 分支使用 Transformer |
tsfm_in_mol |
False | 分子分支使用 Transformer |
附录: 文件结构
train-001/
├── train.py # 训练入口
├── dataset.py # 数据集定义
├── modules.py # 模型定义 (FragSimiModelNew)
├── config.py # 配置参数
├── utils.py # 工具函数 (特征处理等)
├── cliplayers.py # CLIP 风格的 Transformer
└── GNN/
├── layers.py # 图卷积层
└── featurizer.py # 分子图特征化
参考
本模型设计参考了以下工作:
- CLIP (Contrastive Language-Image Pre-training)
- Graph Attention Networks (GAT)
- Morgan Fingerprints (Extended Connectivity Fingerprints)
- Tanimoto Coefficient for molecular similarity