spec2function 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3159 @@
1
+ # 所以MS2BioText的输入应该是包括MS2与BioText。然而一个BioText对应一个分子,一个分子对应多个MS2。所以在dataset构建的时候,需要储存:
2
+ # MS2的列表
3
+ # MS2对应的分子的列表
4
+ # 分子对应的BioText的列表
5
+ # 最后item的时候返回input_ids(m/z),intensity,BioText
6
+ # (这里添加一系列类内方法,在init的时候输入参数可以选择处理BioText的类内方法)
7
+ # 但是问题是但是在之后的实验里,还要记录对于每个MS2的其他信息,这个怎么储存。
8
+
9
+
10
+ # 先完成dataset然后创建实例
11
+ # 读取HMDB数据集
12
+ # HMDB.h5是MS2数据读取为list?,HMDB.parquet是meta数据,读取为?然后Biotext是一个文件夹下包含的一系列txt文件,文件名为HMDB的id,读取为?
13
+ # ,读取test data跑通两个模型试试
14
+
15
+ import pickle
16
+ import h5py
17
+ import pandas as pd
18
+ import os
19
+ import torch
20
+ from torch.utils.data import Dataset
21
+ import numpy as np
22
+ import random
23
+ from pathlib import Path
24
+ from sklearn.model_selection import train_test_split
25
+ import json
26
+ from torch.utils.data import Sampler
27
+ import random, itertools
28
+ import math
29
+ from collections.abc import Iterator
30
+ from typing import Optional, TypeVar, Dict, List, Tuple
31
+ import random
32
+ import argparse
33
+ import torch
34
+ import torch.distributed as dist
35
+ from torch.utils.data.dataset import Dataset
36
+ from torch.utils.data.sampler import Sampler
37
+
38
+
39
+ _T_co = TypeVar("_T_co", covariant=True)
40
+
41
class MS2MoleculeDistributedSampler(Sampler[_T_co]):
    """Distributed sampler designed for MS2 contrastive learning.

    Added feature: assigns a conflict-free text to every sample in a batch.
    - Ensures the text chosen for each molecule in the batch does not appear
      in any other in-batch molecule's candidate list.
    - Guarantees that only the diagonal of the batch similarity matrix holds
      positive pairs.

    The sampler writes its per-batch text choices into
    ``dataset.text_assignment`` (a dict ``ms2_id -> text index``) during
    ``__iter__``; the dataset's ``__getitem__`` is expected to read it.
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        # Resolve world size / rank from the default process group when not
        # supplied explicitly (mirrors torch.utils.data.DistributedSampler).
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
            )

        self.dataset = dataset
        self.batch_size = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        self.shuffle = shuffle
        self.seed = seed

        # Build the molecule -> candidate texts mapping.
        self.molecule_to_texts = self._build_molecule_text_mapping()

        # Group dataset indices by molecule and sort molecules by MS2 count.
        self.molecule_groups = self._group_by_molecule()
        self.sorted_molecules = self._sort_molecules_by_ms2_count()

        # Precompute the batch allocation plan for the whole epoch.
        self.batch_indices = self._create_batch_allocation()

        # Make the batch count divisible by the number of replicas.
        self._adjust_for_distributed()

        # Per-replica sample count (whole batches only).
        total_batches = len(self.batch_indices)
        batches_per_replica = total_batches // self.num_replicas
        self.num_samples = batches_per_replica * self.batch_size

    def _build_molecule_text_mapping(self) -> Dict[str, List[str]]:
        """Build each molecule's candidate text list from the dataset.

        Supports three biotext entry shapes: list of ``{'text': ...}`` records,
        dict with ``original``/``paraphrases`` keys, or a bare string.
        """
        molecule_to_texts = {}

        for mol_id, entry in self.dataset.biotext_data.items():
            texts = []
            if isinstance(entry, list):
                texts = [record['text'] for record in entry]
            elif isinstance(entry, dict):
                original = entry.get("original", "")
                paraphrases = entry.get("paraphrases", [])
                texts = [original] + paraphrases
            elif isinstance(entry, str):
                texts = [entry]

            molecule_to_texts[mol_id] = texts

        print(f"Built text mapping for {len(molecule_to_texts)} molecules")
        return molecule_to_texts

    def _group_by_molecule(self) -> Dict[str, List[int]]:
        """Group dataset indices by molecule ID."""
        molecule_groups = {}
        for idx in range(len(self.dataset)):
            ms2_id = self.dataset.ms2_ids[idx]
            molecule_id = self.dataset.preprocessed_ms2_tensors[ms2_id]['molecule_id']

            if molecule_id not in molecule_groups:
                molecule_groups[molecule_id] = []
            molecule_groups[molecule_id].append(idx)

        return molecule_groups

    def _sort_molecules_by_ms2_count(self) -> List[Tuple[str, List[int]]]:
        """Sort molecules by number of owned MS2 spectra, descending."""
        molecule_items = [(mol_id, indices) for mol_id, indices in self.molecule_groups.items()]
        sorted_items = sorted(molecule_items, key=lambda x: len(x[1]), reverse=True)

        print(f"Molecule MS2 count distribution:")
        print(f"Max MS2 per molecule: {len(sorted_items[0][1])}")
        print(f"Min MS2 per molecule: {len(sorted_items[-1][1])}")
        print(f"Total molecules: {len(sorted_items)}")
        print(f"Total MS2 spectra: {sum(len(indices) for _, indices in sorted_items)}")

        return sorted_items

    def _assign_texts_for_batch(self, batch_molecule_ids: List[str]) -> Dict[str, int]:
        """Assign one text index per molecule in the batch.

        Tries to ensure the chosen text does not appear in any other
        in-batch molecule's candidate list.

        Returns:
            ``{molecule_id: text_index}``.
        """
        mol_to_text_idx = {}
        occupied_texts = set()  # texts already claimed within this batch

        # Process molecules with the fewest candidates first, since they
        # have the smallest choice space.
        sorted_mols = sorted(
            batch_molecule_ids,
            key=lambda mol_id: len(self.molecule_to_texts.get(mol_id, []))
        )

        for mol_id in sorted_mols:
            candidate_texts = self.molecule_to_texts.get(mol_id, [])

            if not candidate_texts:
                # NOTE(review): index 0 is assigned even though there are no
                # candidates — the dataset must tolerate this (verify).
                print(f"⚠️ Warning: Molecule {mol_id} has no candidate texts")
                mol_to_text_idx[mol_id] = 0
                continue

            # Collect texts that are unclaimed AND absent from every other
            # in-batch molecule's candidate list.
            available_indices = []
            for i, text in enumerate(candidate_texts):
                if text not in occupied_texts:
                    # Check whether this text appears in another molecule's
                    # candidates (would make it a false negative).
                    text_in_others = False
                    for other_mol_id in batch_molecule_ids:
                        if other_mol_id != mol_id:
                            other_texts = self.molecule_to_texts.get(other_mol_id, [])
                            if text in other_texts:
                                text_in_others = True
                                break

                    if not text_in_others:
                        available_indices.append(i)

            # Prefer a conflict-free text when one exists.
            if available_indices:
                chosen_idx = random.choice(available_indices)
                mol_to_text_idx[mol_id] = chosen_idx
                occupied_texts.add(candidate_texts[chosen_idx])
            else:
                # Fallback: any unclaimed text (may overlap another
                # molecule's candidate list).
                fallback_indices = [i for i, text in enumerate(candidate_texts)
                                    if text not in occupied_texts]
                if fallback_indices:
                    chosen_idx = random.choice(fallback_indices)
                    mol_to_text_idx[mol_id] = chosen_idx
                    occupied_texts.add(candidate_texts[chosen_idx])
                    print(f"⚠️ Fallback: Molecule {mol_id} text may conflict with others")
                else:
                    # Extreme case: every candidate is already claimed.
                    chosen_idx = random.randint(0, len(candidate_texts) - 1)
                    mol_to_text_idx[mol_id] = chosen_idx
                    print(f"⚠️ Extreme fallback: All texts occupied for {mol_id}")

        return mol_to_text_idx

    def _create_batch_allocation(self) -> List[List[int]]:
        """Create the epoch's batch plan: one spectrum per molecule per batch.

        Walks molecules (largest first) handing out their next unused MS2
        index; short batches are padded by re-sampling spectra from molecules
        not yet present in the batch.
        """
        molecule_ms2_usage = {}
        for mol_id, indices in self.sorted_molecules:
            molecule_ms2_usage[mol_id] = {
                'indices': indices.copy(),
                'used_count': 0
            }

        batch_indices = []
        total_ms2_count = sum(len(indices) for _, indices in self.sorted_molecules)
        used_ms2_count = 0

        print(f"Starting batch allocation for {total_ms2_count} MS2 spectra...")

        while used_ms2_count < total_ms2_count:
            current_batch = []
            used_molecules_in_batch = set()

            # First pass: take each molecule's next unused spectrum.
            for mol_id, mol_data in molecule_ms2_usage.items():
                if len(current_batch) >= self.batch_size:
                    break

                if mol_id in used_molecules_in_batch:
                    continue

                if mol_data['used_count'] < len(mol_data['indices']):
                    ms2_idx = mol_data['indices'][mol_data['used_count']]
                    current_batch.append(ms2_idx)
                    used_molecules_in_batch.add(mol_id)
                    mol_data['used_count'] += 1
                    used_ms2_count += 1

            # Second pass: pad the batch by re-sampling from molecules not
            # yet used in this batch (does not advance used_count).
            if len(current_batch) < self.batch_size and len(current_batch) > 0:
                available_molecules = [mol_id for mol_id in molecule_ms2_usage.keys()
                                       if mol_id not in used_molecules_in_batch]

                while len(current_batch) < self.batch_size and available_molecules:
                    available_molecules.sort(key=lambda x: len(molecule_ms2_usage[x]['indices']),
                                             reverse=True)

                    mol_id = available_molecules[0]
                    mol_data = molecule_ms2_usage[mol_id]

                    ms2_idx = random.choice(mol_data['indices'])
                    current_batch.append(ms2_idx)
                    used_molecules_in_batch.add(mol_id)
                    available_molecules.remove(mol_id)

            if len(current_batch) == 0:
                break

            if len(current_batch) < self.batch_size:
                if self.drop_last:
                    print(f"Dropping incomplete batch with {len(current_batch)} samples")
                    break
                else:
                    # Keep padding until full or no unused molecule remains.
                    while len(current_batch) < self.batch_size:
                        available_molecules = [mol_id for mol_id in molecule_ms2_usage.keys()
                                               if mol_id not in used_molecules_in_batch]

                        if not available_molecules:
                            print(f"Cannot fill batch further: only {len(self.sorted_molecules)} unique molecules available")
                            break

                        mol_id = random.choice(available_molecules)
                        mol_data = molecule_ms2_usage[mol_id]

                        ms2_idx = random.choice(mol_data['indices'])
                        current_batch.append(ms2_idx)
                        used_molecules_in_batch.add(mol_id)

            batch_indices.append(current_batch)

        print(f"Created {len(batch_indices)} batches")
        print(f"Used {used_ms2_count} MS2 spectra out of {total_ms2_count}")

        return batch_indices

    def _adjust_for_distributed(self):
        """Make the batch count divisible by the number of replicas.

        Drops trailing batches when ``drop_last`` is set, otherwise
        duplicates leading batches to pad up.
        """
        total_batches = len(self.batch_indices)
        remainder = total_batches % self.num_replicas

        if remainder != 0:
            if self.drop_last:
                batches_to_remove = remainder
                self.batch_indices = self.batch_indices[:-batches_to_remove]
                print(f"Dropped {batches_to_remove} batches to ensure divisibility by {self.num_replicas} GPUs")
            else:
                batches_to_add = self.num_replicas - remainder
                for i in range(batches_to_add):
                    batch_to_copy = self.batch_indices[i % len(self.batch_indices)]
                    self.batch_indices.append(batch_to_copy.copy())
                print(f"Added {batches_to_add} batches to ensure divisibility by {self.num_replicas} GPUs")

        final_batches = len(self.batch_indices)
        print(f"Final batch count: {final_batches} (divisible by {self.num_replicas} GPUs)")
        print(f"Each GPU will process {final_batches // self.num_replicas} batches")

    def __iter__(self) -> Iterator[_T_co]:
        """Yield this replica's sample indices and publish text assignments."""
        # Take a contiguous slice of batches for this rank.
        total_batches = len(self.batch_indices)
        batches_per_replica = total_batches // self.num_replicas

        start_batch = self.rank * batches_per_replica
        end_batch = start_batch + batches_per_replica

        my_batches = self.batch_indices[start_batch:end_batch]

        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            batch_order = torch.randperm(len(my_batches), generator=g).tolist()
            my_batches = [my_batches[i] for i in batch_order]

            # NOTE(review): each batch is shuffled with a Random seeded the
            # same way, so equal-length batches get the identical
            # permutation every epoch step — confirm this is intended.
            for batch in my_batches:
                random.Random(self.seed + self.epoch).shuffle(batch)

        # *** Key step: assign a text to every sample, batch by batch ***
        self.dataset.text_assignment.clear()

        for batch_indices in my_batches:
            # Collect the molecule IDs present in this batch.
            batch_molecule_ids = []
            ms2_id_to_mol_id = {}

            for idx in batch_indices:
                ms2_id = self.dataset.ms2_ids[idx]
                mol_id = self.dataset.preprocessed_ms2_tensors[ms2_id]['molecule_id']
                batch_molecule_ids.append(mol_id)
                ms2_id_to_mol_id[ms2_id] = mol_id

            # Pick conflict-free text indices for this batch.
            mol_to_text_idx = self._assign_texts_for_batch(batch_molecule_ids)

            # Publish the assignment so dataset.__getitem__ can read it.
            for idx in batch_indices:
                ms2_id = self.dataset.ms2_ids[idx]
                mol_id = ms2_id_to_mol_id[ms2_id]
                self.dataset.text_assignment[ms2_id] = mol_to_text_idx.get(mol_id, 0)

        # Flatten the batch plan into a single index stream.
        all_indices = []
        for batch in my_batches:
            all_indices.extend(batch)

        return iter(all_indices)

    def __len__(self) -> int:
        # Number of samples this replica yields per epoch.
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Set the current epoch so each epoch shuffles differently."""
        self.epoch = epoch
363
+
364
+
365
+ # function for dataset augmentation
366
+ import torch
367
+ import numpy as np
368
+ import random
369
+ import torch
370
+ import numpy as np
371
+ import random
372
+
373
+
374
def sample_truncated_normal(mean, std, low, high):
    """Draw one sample from a normal distribution truncated to [low, high].

    Uses rejection sampling: repeatedly draws from N(mean, std) until a
    draw lands inside [low, high].

    Args:
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        low: lower truncation bound (inclusive).
        high: upper truncation bound (inclusive).

    Returns:
        A float inside [low, high].
    """
    for _attempt in range(1000):
        draw = np.random.normal(mean, std)
        if low <= draw <= high:
            return draw
    # Rejection failed 1000 times (bounds are far in the tails): fall back
    # to clipping a fresh draw into the interval.
    return np.clip(np.random.normal(mean, std), low, high)
394
+
395
+
396
def augment_tokenized_ms2_optimized(mz_tokens, intensity, word2idx, args):
    """Dynamic MS2 augmentation tuned to external-data noise statistics.

    Injects synthetic noise peaks (proximal to signal peaks and isolated)
    into a tokenized spectrum, with an optionally randomized noise ratio.

    Args:
        mz_tokens: 1-D tensor of m/z token ids (vocabulary indices).
        intensity: 1-D or [1, N] tensor of peak intensities aligned with
            ``mz_tokens``.
        word2idx: vocabulary mapping m/z strings (2-decimal) / special
            tokens to token ids.
        args: namespace of augmentation knobs (``augment_prob``,
            ``align_to_external``, ``noise_sampling_strategy``, ...).

    Returns:
        ``(augmented_mz, augmented_intensity)`` — token-id tensor and a
        [1, N] intensity tensor; the unmodified inputs when augmentation is
        skipped or no usable peaks are found.
    """
    # ===== 0. Read parameters and decide whether to augment at all =====
    augment_prob = getattr(args, 'augment_prob', 0.5)
    if random.random() > augment_prob:
        return mz_tokens, intensity

    # Ensure intensity is 1-D for per-peak indexing below.
    if intensity.dim() == 2:
        intensity = intensity.squeeze(0)

    device = mz_tokens.device

    # ===== 1. Build token-id -> m/z mapping =====
    idx2word = {v: k for k, v in word2idx.items()}

    def token_to_mz(token_id):
        """Convert a token id back to its numeric m/z value (None for
        special or non-numeric tokens)."""
        word = idx2word.get(token_id.item(), None)
        if word and word not in ['[PAD]', '[MASK]']:
            try:
                return float(word)
            except ValueError:
                return None
        return None

    # Work in numpy for the peak-level manipulation.
    mz_tokens_np = mz_tokens.cpu().numpy()
    intensity_np = intensity.cpu().numpy()

    # Recover actual m/z values, skipping special tokens.
    original_mz = []
    original_intensity = []
    for i, token_id in enumerate(mz_tokens_np):
        mz_val = token_to_mz(torch.tensor(token_id))
        if mz_val is not None:
            original_mz.append(mz_val)
            original_intensity.append(intensity_np[i])

    if len(original_mz) == 0:
        return mz_tokens, intensity

    original_mz = np.array(original_mz)
    original_intensity = np.array(original_intensity)
    max_intensity = original_intensity.max()

    # ===== 2. Identify signal peaks (>= 5% of base peak intensity) =====
    signal_threshold = 0.05
    signal_mask = original_intensity >= signal_threshold * max_intensity
    signal_mz = original_mz[signal_mask]
    n_signal = len(signal_mz)

    if n_signal == 0:
        return mz_tokens, intensity

    # ===== 3. Randomized noise parameters =====
    align_to_external = getattr(args, 'align_to_external', False)
    randomize_noise_ratio = getattr(args, 'randomize_noise_ratio', True)  # switch for random noise ratio
    noise_sampling_strategy = getattr(args, 'noise_sampling_strategy', 'uniform')  # ratio sampling strategy

    if align_to_external:
        # Randomize the target noise ratio.
        if randomize_noise_ratio:
            if noise_sampling_strategy == 'uniform':
                # Strategy 1: uniform distribution.
                noise_ratio_range = getattr(args, 'noise_ratio_range', [0.60, 0.90])
                TARGET_NOISE_RATIO = np.random.uniform(noise_ratio_range[0], noise_ratio_range[1])

            elif noise_sampling_strategy == 'normal':
                # Strategy 2: normal distribution, clipped to [0.40, 0.95].
                target_noise_ratio = getattr(args, 'target_noise_ratio', 0.80)
                noise_ratio_std = getattr(args, 'noise_ratio_std', 0.10)
                TARGET_NOISE_RATIO = np.random.normal(target_noise_ratio, noise_ratio_std)
                TARGET_NOISE_RATIO = np.clip(TARGET_NOISE_RATIO, 0.40, 0.95)

            elif noise_sampling_strategy == 'bimodal':
                # Strategy 3: bimodal "dirty vs clean" mixture.
                bimodal_dirty_prob = getattr(args, 'bimodal_dirty_prob', 0.7)
                if np.random.random() < bimodal_dirty_prob:
                    # Dirty-data mode.
                    TARGET_NOISE_RATIO = np.random.uniform(0.70, 0.90)
                else:
                    # Clean-data mode.
                    TARGET_NOISE_RATIO = np.random.uniform(0.30, 0.60)
            else:
                # Unknown strategy: fall back to the fixed value.
                TARGET_NOISE_RATIO = getattr(args, 'target_noise_ratio', 0.80)
        else:
            # No randomization: use the fixed value.
            TARGET_NOISE_RATIO = getattr(args, 'target_noise_ratio', 0.80)

        # Optionally randomize the proximal fraction of the noise.
        randomize_proximal_ratio = getattr(args, 'randomize_proximal_ratio', False)
        if randomize_proximal_ratio:
            proximal_ratio_range = getattr(args, 'proximal_ratio_range', [0.15, 0.22])
            PROXIMAL_RATIO_OF_NOISE = np.random.uniform(proximal_ratio_range[0], proximal_ratio_range[1])
        else:
            PROXIMAL_RATIO_OF_NOISE = getattr(args, 'proximal_ratio_of_noise', 0.18)

        # Intensity parameters (relative to base-peak intensity).
        PROXIMAL_MEAN = getattr(args, 'proximal_intensity_mean', 0.0115)
        PROXIMAL_STD = getattr(args, 'proximal_intensity_std', 0.0138)
        ISOLATED_MEAN = getattr(args, 'isolated_intensity_mean', 0.0079)
        ISOLATED_STD = getattr(args, 'isolated_intensity_std', 0.0114)

        # Regional weights: (low m/z, high m/z, fraction of isolated noise).
        use_regional_weighting = getattr(args, 'use_regional_weighting', True)
        if use_regional_weighting:
            REGION_WEIGHTS = [
                (0, 100, 0.24),
                (100, 200, 0.53),
                (200, 300, 0.18),
                (300, 500, 0.05)
            ]
        else:
            REGION_WEIGHTS = None
    else:
        # Parameters when NOT aligning to external data.
        if randomize_noise_ratio:
            noise_ratio_range = getattr(args, 'noise_ratio_range', [0.30, 0.70])
            TARGET_NOISE_RATIO = np.random.uniform(noise_ratio_range[0], noise_ratio_range[1])
        else:
            TARGET_NOISE_RATIO = getattr(args, 'target_noise_ratio', 0.50)

        PROXIMAL_RATIO_OF_NOISE = 0.25
        PROXIMAL_MEAN = 0.0141
        PROXIMAL_STD = 0.0209
        ISOLATED_MEAN = 0.0091
        ISOLATED_STD = 0.0144
        REGION_WEIGHTS = None

    # Spatial distribution parameters.
    proximal_distance_range = getattr(args, 'proximal_distance_range', [-1.5, 1.5])
    isolated_min_distance = getattr(args, 'isolated_min_distance', 5.0)

    # ===== 4. Number of noise peaks to add =====
    # noise/(noise+signal) = ratio  =>  noise = signal * ratio/(1-ratio).
    n_noise_total = int(n_signal * TARGET_NOISE_RATIO / (1 - TARGET_NOISE_RATIO))
    n_proximal = int(n_noise_total * PROXIMAL_RATIO_OF_NOISE)
    n_isolated = n_noise_total - n_proximal

    mz_min, mz_max = original_mz.min(), original_mz.max()

    # ===== 5. Generate proximal noise (near signal peaks) =====
    proximal_mz = []
    proximal_intensity = []

    for _ in range(n_proximal):
        # Pick a signal peak as the anchor.
        base_mz = np.random.choice(signal_mz)

        # Offset within the proximal distance range, clipped to the
        # spectrum's m/z span.
        offset = np.random.uniform(proximal_distance_range[0], proximal_distance_range[1])
        noise_mz = base_mz + offset
        noise_mz = np.clip(noise_mz, mz_min, mz_max)

        # Intensity from a truncated normal, scaled by base-peak intensity.
        noise_intensity = sample_truncated_normal(
            PROXIMAL_MEAN, PROXIMAL_STD, 0.001, 0.10
        ) * max_intensity

        proximal_mz.append(noise_mz)
        proximal_intensity.append(noise_intensity)

    # ===== 6. Generate isolated noise (region-weighted if configured) =====
    isolated_mz = []
    isolated_intensity = []

    if REGION_WEIGHTS:
        # Region-weighted placement.
        for low, high, weight in REGION_WEIGHTS:
            # Restrict to the spectrum's actual m/z range.
            region_low = max(low, mz_min)
            region_high = min(high, mz_max)

            if region_low >= region_high:
                continue

            n_in_region = int(n_isolated * weight)
            attempts = 0
            max_attempts = n_in_region * 10
            generated = 0

            while generated < n_in_region and attempts < max_attempts:
                candidate_mz = np.random.uniform(region_low, region_high)

                # Must be far from every signal peak to count as isolated.
                min_dist = np.min(np.abs(signal_mz - candidate_mz))

                if min_dist >= isolated_min_distance:
                    # Slightly higher intensity in the 100-200 Da region.
                    if 100 <= candidate_mz <= 200:
                        mean_adj = ISOLATED_MEAN * 1.1
                    else:
                        mean_adj = ISOLATED_MEAN

                    noise_intensity = sample_truncated_normal(
                        mean_adj, ISOLATED_STD, 0.0001, 0.05
                    ) * max_intensity

                    isolated_mz.append(candidate_mz)
                    isolated_intensity.append(noise_intensity)
                    generated += 1

                attempts += 1
    else:
        # Uniform placement across the full m/z span.
        attempts = 0
        max_attempts = n_isolated * 10
        generated = 0

        while generated < n_isolated and attempts < max_attempts:
            candidate_mz = np.random.uniform(mz_min, mz_max)
            min_dist = np.min(np.abs(signal_mz - candidate_mz))

            if min_dist >= isolated_min_distance:
                noise_intensity = sample_truncated_normal(
                    ISOLATED_MEAN, ISOLATED_STD, 0.0001, 0.05
                ) * max_intensity

                isolated_mz.append(candidate_mz)
                isolated_intensity.append(noise_intensity)
                generated += 1

            attempts += 1

    # ===== 7. Merge original and synthetic peaks =====
    all_mz = np.concatenate([original_mz, proximal_mz, isolated_mz])
    all_intensity = np.concatenate([original_intensity, proximal_intensity, isolated_intensity])

    # ===== 8. Convert m/z values back to token ids =====
    def mz_to_token(mz_val):
        """Map an m/z value to a token id (rounded to 0.01); unknown values
        fall back to the [MASK] token."""
        mz_rounded = round(mz_val, 2)
        mz_str = f"{mz_rounded:.2f}"
        return word2idx.get(mz_str, word2idx.get('[MASK]', 1))

    all_tokens = [mz_to_token(mz) for mz in all_mz]

    # Sort by m/z.
    sorted_idx = np.argsort(all_mz)
    all_tokens = np.array(all_tokens)[sorted_idx]
    all_intensity = all_intensity[sorted_idx]

    # ===== 9. Optional: filter weak peaks and truncate =====
    filter_threshold = getattr(args, 'filter_threshold', None)
    if filter_threshold and filter_threshold > 0:
        max_int = all_intensity.max()
        if max_int > 0:
            threshold = max_int * filter_threshold
            mask = all_intensity >= threshold
            all_tokens = all_tokens[mask]
            all_intensity = all_intensity[mask]

    # Truncate to maxlen, keeping the strongest peaks in m/z order.
    maxlen = getattr(args, 'maxlen', 100)
    if len(all_tokens) > maxlen:
        top_indices = np.argsort(all_intensity)[-maxlen:]
        top_indices = np.sort(top_indices)  # restore m/z order
        all_tokens = all_tokens[top_indices]
        all_intensity = all_intensity[top_indices]

    # ===== 10. Back to tensors (intensity gains a leading batch dim) =====
    augmented_mz = torch.tensor(all_tokens, dtype=mz_tokens.dtype, device=device)
    augmented_intensity = torch.tensor(all_intensity, dtype=intensity.dtype, device=device).unsqueeze(0)

    return augmented_mz, augmented_intensity
667
+
668
+
669
+
670
+
671
+ class MS2BioTextDataset(Dataset):
672
+ def __init__(self, ms2_data, meta_data, biotext_data, tokenizer, max_length=512,
673
+ use_paraphrase=False, use_mlm=False, use_ms2_prediction=False,
674
+ prediction_label_columns=None, word2idx=None, args=None, split='test'):
675
+ """
676
+ hard_neg_path: path to hard_negatives_v2.json
677
+ num_hard_neg_per_sample: number of hard negatives to use per sample
678
+ """
679
+ self.ms2_data = ms2_data
680
+ self.meta_data = meta_data
681
+ self.biotext_data = biotext_data
682
+ self.preprocessed_ms2_tensors = {}
683
+
684
+ for ms2_id, ms2_entry in self.ms2_data.items():
685
+ self.preprocessed_ms2_tensors[ms2_id] = {
686
+ 'mz': torch.tensor(ms2_entry['mz'], dtype=torch.float32),
687
+ 'intensity': torch.tensor(ms2_entry['intensity'], dtype=torch.float32),
688
+ 'molecule_id': ms2_entry['molecule_id']
689
+ }
690
+ self.text_assignment = {}
691
+ self.ms2_ids = list(ms2_data.keys())
692
+ self.word2idx = word2idx
693
+ self.args = args or argparse.Namespace() # 确保非None
694
+ self.split = split
695
+ self.tokenizer = tokenizer
696
+ self.max_length = max_length
697
+ self.use_mlm = use_mlm
698
+ self.use_ms2_prediction = use_ms2_prediction
699
+ self.prediction_label_columns = prediction_label_columns
700
+ self.use_paraphrase = use_paraphrase
701
+
702
+ # === NEW: Load hard negatives ===
703
+ self.hard_negatives = {}
704
+ self.num_hard_neg = getattr(self.args, "num_hard_neg_per_sample", 0)
705
+ hard_neg_path = getattr(self.args, "hard_neg_path", None)
706
+
707
+ if hard_neg_path and os.path.exists(hard_neg_path) and split == 'train':
708
+ import json
709
+ with open(hard_neg_path, 'r', encoding='utf-8') as f:
710
+ self.hard_negatives = json.load(f)
711
+ print(f"Loaded hard negatives for {len(self.hard_negatives)} molecules")
712
+ print(f"Using {self.num_hard_neg} hard negatives per sample")
713
+ else:
714
+ if split == 'train':
715
+ print(f"No valid hard_neg_path found ({hard_neg_path}), skipping hard negatives.")
716
+
717
+ # === MS2 prediction checks ===
718
+ if self.use_ms2_prediction:
719
+ if not self.prediction_label_columns or not isinstance(self.prediction_label_columns, list):
720
+ raise ValueError("When MS2 prediction task is enabled, a list of column names 'prediction_label_columns' must be provided.")
721
+
722
+ for col in self.prediction_label_columns:
723
+ if col not in self.meta_data.columns:
724
+ raise ValueError(f"Column '{col}' is not found in meta_data.")
725
+
726
+ self.num_ms2_classes = len(self.prediction_label_columns)
727
+ print(f"Found {self.num_ms2_classes} label columns for multilabel prediction task: {self.prediction_label_columns}")
728
+
729
+
730
    def __len__(self):
        """Dataset length = number of MS2 spectra (one sample per spectrum)."""
        return len(self.ms2_ids)
732
+
733
    def _create_mlm_inputs(self, input_ids):
        """Create masked inputs and labels for the MLM task (BERT-style).

        Mutates ``input_ids`` in place (callers pass a clone). 15% of
        non-special tokens are selected; of those, 80% become [MASK], 10%
        become a random token, 10% stay unchanged. Unselected positions get
        label -100 so the loss ignores them.
        """
        labels = input_ids.clone()
        probability_matrix = torch.full(labels.shape, 0.15)  # 15% masking probability

        # Never mask special tokens (e.g., [CLS], [SEP], [PAD]).
        special_tokens_mask = self.tokenizer.get_special_tokens_mask(labels.tolist(), already_has_special_tokens=True)
        special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()

        # Unmasked tokens get label -100 so they are ignored by the loss.
        labels[~masked_indices] = -100

        # 80% of selected positions are replaced with the [MASK] token.
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        input_ids[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% are replaced with a random token (half of the remaining 20%).
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer.vocab), labels.shape, dtype=torch.long)
        input_ids[indices_random] = random_words[indices_random]

        return input_ids, labels
758
+
759
+ def __getitem__(self, idx):
760
+ ms2_id = self.ms2_ids[idx]
761
+ tensor_data = self.preprocessed_ms2_tensors[ms2_id]
762
+ mz = tensor_data['mz']
763
+ intensity = tensor_data['intensity']
764
+
765
+ # Dynamic augmentation
766
+ if self.split == 'train' and hasattr(self, 'word2idx'):
767
+ mz, intensity = augment_tokenized_ms2_optimized(
768
+ mz, intensity, self.word2idx, self.args
769
+ )
770
+
771
+ batch = {
772
+ 'mz': tensor_data['mz'],
773
+ 'intensity': tensor_data['intensity'].unsqueeze(0)
774
+ }
775
+
776
+ # === 统一的BioText处理逻辑 ===
777
+ molecule_id = tensor_data['molecule_id']
778
+ biotext = ""
779
+ paraphrase_text = None
780
+ all_candidate_texts = []
781
+
782
+ if molecule_id in self.biotext_data:
783
+ entry = self.biotext_data[molecule_id]
784
+
785
+ if isinstance(entry, list):
786
+ all_candidate_texts = [record['text'] for record in entry]
787
+
788
+ # *** 关键修改:使用sampler分配的text index ***
789
+ if ms2_id in self.text_assignment:
790
+ text_idx = self.text_assignment[ms2_id]
791
+ biotext = all_candidate_texts[text_idx]
792
+ else:
793
+ # fallback: 随机选择(shouldn't happen in training)
794
+ biotext = random.choice(entry)['text']
795
+
796
+ # Paraphrase: 如果需要,从剩余的候选中选一个不同的
797
+ if self.use_paraphrase and len(entry) >= 2:
798
+ remaining_indices = [i for i in range(len(all_candidate_texts))
799
+ if i != text_idx]
800
+ if remaining_indices:
801
+ para_idx = random.choice(remaining_indices)
802
+ paraphrase_text = all_candidate_texts[para_idx]
803
+
804
+ elif isinstance(entry, dict):
805
+ original = entry.get("original", "")
806
+ paraphrases = entry.get("paraphrases", [])
807
+ all_candidates = [original] + paraphrases
808
+ all_candidate_texts = all_candidates
809
+
810
+ # *** 同样的逻辑 ***
811
+ if ms2_id in self.text_assignment:
812
+ text_idx = self.text_assignment[ms2_id]
813
+ biotext = all_candidates[text_idx]
814
+ else:
815
+ biotext = original
816
+
817
+ if self.use_paraphrase and len(all_candidates) >= 2:
818
+ remaining = [i for i in range(len(all_candidates)) if i != text_idx]
819
+ if remaining:
820
+ para_idx = random.choice(remaining)
821
+ paraphrase_text = all_candidates[para_idx]
822
+
823
+ elif isinstance(entry, str):
824
+ biotext = entry
825
+ all_candidate_texts = [entry]
826
+ else:
827
+ print(f"⚠️ Warning: Molecule ID '{molecule_id}' missing BioText")
828
+
829
+ # === 后续tokenization等保持不变 ===
830
+ encoded_text = self.tokenizer(
831
+ biotext,
832
+ padding="max_length",
833
+ truncation=True,
834
+ max_length=self.max_length,
835
+ return_tensors="pt"
836
+ )
837
+
838
+ input_ids = encoded_text['input_ids'].squeeze(0)
839
+ attention_mask = encoded_text['attention_mask'].squeeze(0)
840
+
841
+ batch['text_input_ids'] = input_ids
842
+ batch['text_attention_mask'] = attention_mask
843
+ batch['all_candidate_texts'] = all_candidate_texts
844
+
845
+ # MLM task
846
+ if self.use_mlm:
847
+ masked_input_ids, mlm_labels = self._create_mlm_inputs(input_ids.clone())
848
+ batch['masked_text_input_ids'] = masked_input_ids
849
+ batch['mlm_labels'] = mlm_labels
850
+
851
+ # Paraphrase
852
+ if paraphrase_text is not None:
853
+ encoded_para = self.tokenizer(
854
+ paraphrase_text,
855
+ padding="max_length",
856
+ truncation=True,
857
+ max_length=self.max_length,
858
+ return_tensors="pt"
859
+ )
860
+ batch['paraphrase_input_ids'] = encoded_para['input_ids'].squeeze(0)
861
+ batch['paraphrase_attention_mask'] = encoded_para['attention_mask'].squeeze(0)
862
+ batch['has_paraphrase'] = True
863
+ else:
864
+ batch['has_paraphrase'] = False
865
+
866
+ batch['ms2_id'] = ms2_id
867
+ batch['molecule_id'] = molecule_id
868
+ batch['original_text'] = biotext
869
+
870
+ return batch
871
+
872
    @staticmethod
    def custom_collate_fn(batch_list):
        """Collate a list of sample dicts into a batch dict.

        Supports optional keys, stacks tensors when every sample has the
        key, and filters each sample's hard negatives so none of them
        corresponds to a molecule that is a positive elsewhere in the batch.

        NOTE(review): a second ``custom_collate_fn`` definition appears
        later in this class and shadows this one at class-creation time —
        confirm which implementation is intended.
        """
        if not batch_list:
            return {}

        # Collect the molecule IDs of all positives in this batch.
        batch_molecule_ids = set()
        for sample in batch_list:
            if 'molecule_id' in sample:
                batch_molecule_ids.add(sample['molecule_id'])

        # Union of keys across all samples (keys may be optional).
        all_keys = set()
        for d in batch_list:
            all_keys.update(d.keys())

        # Group values per key.
        collated_batch = {}
        for key in all_keys:
            values = [d[key] for d in batch_list if key in d]
            collated_batch[key] = values

        # === Filter hard negatives that collide with in-batch positives ===
        if 'hard_neg_input_ids' in collated_batch and len(batch_molecule_ids) > 0:
            filtered_hard_neg_ids = []
            filtered_hard_neg_masks = []
            filtered_has_hard_neg = []

            for i, sample in enumerate(batch_list):
                if sample.get('has_hard_neg', False):
                    hard_neg_ids = sample['hard_neg_input_ids']  # [num_hard_neg, seq_len]
                    hard_neg_mask = sample['hard_neg_attention_mask']
                    hard_neg_mol_ids = sample.get('hard_neg_molecule_ids', [])

                    # Keep only negatives whose molecule is NOT a positive
                    # in this batch.
                    valid_indices = []
                    for j, neg_mol_id in enumerate(hard_neg_mol_ids):
                        if neg_mol_id not in batch_molecule_ids:
                            valid_indices.append(j)

                    if valid_indices:
                        # Retain only the non-conflicting hard negatives.
                        filtered_hard_neg_ids.append(hard_neg_ids[valid_indices])
                        filtered_hard_neg_masks.append(hard_neg_mask[valid_indices])
                        filtered_has_hard_neg.append(True)
                    else:
                        # All of this sample's negatives conflict: use empty
                        # tensors so shapes stay consistent downstream.
                        max_length = sample['text_input_ids'].shape[0]
                        filtered_hard_neg_ids.append(
                            torch.zeros((0, max_length), dtype=torch.long)
                        )
                        filtered_hard_neg_masks.append(
                            torch.zeros((0, max_length), dtype=torch.long)
                        )
                        filtered_has_hard_neg.append(False)
                else:
                    # Sample never had hard negatives: emit empty tensors.
                    max_length = batch_list[0]['text_input_ids'].shape[0]
                    filtered_hard_neg_ids.append(
                        torch.zeros((0, max_length), dtype=torch.long)
                    )
                    filtered_hard_neg_masks.append(
                        torch.zeros((0, max_length), dtype=torch.long)
                    )
                    filtered_has_hard_neg.append(False)

            collated_batch['hard_neg_input_ids'] = filtered_hard_neg_ids
            collated_batch['hard_neg_attention_mask'] = filtered_hard_neg_masks
            collated_batch['has_hard_neg'] = filtered_has_hard_neg

        # Final pass: decide per key how to materialize the batch value.
        final_batch = {}
        for key, values in collated_batch.items():
            # Optional boolean flags: default missing entries to False.
            if key in ['has_paraphrase', 'has_hard_neg']:
                final_batch[key] = [d.get(key, False) for d in batch_list]

            # Keys kept as plain lists (ragged / per-sample shapes).
            elif key in ['hard_neg_input_ids', 'hard_neg_attention_mask',
                         'paraphrase_input_ids', 'paraphrase_attention_mask',
                         'hard_neg_molecule_ids']:
                final_batch[key] = values

            # Stack tensors only when every sample provided the key.
            elif isinstance(values[0], torch.Tensor):
                if len(values) == len(batch_list):
                    final_batch[key] = torch.stack(values)
                else:
                    final_batch[key] = values

            # Everything else passes through as a list.
            else:
                final_batch[key] = values

        return final_batch
971
+
972
+ @staticmethod
973
+ def custom_collate_fn(batch_list):
974
+ """
975
+ 自定义collate_fn,识别batch中有text overlap的样本对
976
+ """
977
+ if not batch_list:
978
+ return {}
979
+
980
+ batch_size = len(batch_list)
981
+
982
+ # 收集batch中所有正样本的molecule_id
983
+ batch_molecule_ids = set()
984
+ for sample in batch_list:
985
+ if 'molecule_id' in sample:
986
+ batch_molecule_ids.add(sample['molecule_id'])
987
+
988
+ # === NEW: 构建text overlap矩阵 ===
989
+ # text_overlap[i][j] = 1 表示样本i和样本j有共享的候选text
990
+ text_overlap = torch.zeros(batch_size, batch_size, dtype=torch.float32)
991
+
992
+ for i in range(batch_size):
993
+ for j in range(batch_size):
994
+ if i == j:
995
+ text_overlap[i, j] = 1.0 # 自己和自己肯定overlap
996
+ else:
997
+ # 检查候选text集合是否有交集
998
+ texts_i = set(batch_list[i].get('all_candidate_texts', []))
999
+ texts_j = set(batch_list[j].get('all_candidate_texts', []))
1000
+
1001
+ if texts_i & texts_j: # 有交集
1002
+ text_overlap[i, j] = 1.0
1003
+
1004
+ # 收集所有可能的keys
1005
+ all_keys = set()
1006
+ for d in batch_list:
1007
+ all_keys.update(d.keys())
1008
+
1009
+ # 按key分组收集数据
1010
+ collated_batch = {}
1011
+ for key in all_keys:
1012
+ values = [d[key] for d in batch_list if key in d]
1013
+ collated_batch[key] = values
1014
+
1015
+ # === 过滤冲突的hard negatives ===
1016
+ if 'hard_neg_input_ids' in collated_batch and len(batch_molecule_ids) > 0:
1017
+ filtered_hard_neg_ids = []
1018
+ filtered_hard_neg_masks = []
1019
+ filtered_has_hard_neg = []
1020
+
1021
+ for i, sample in enumerate(batch_list):
1022
+ if sample.get('has_hard_neg', False):
1023
+ hard_neg_ids = sample['hard_neg_input_ids']
1024
+ hard_neg_mask = sample['hard_neg_attention_mask']
1025
+ hard_neg_mol_ids = sample.get('hard_neg_molecule_ids', [])
1026
+
1027
+ valid_indices = []
1028
+ for j, neg_mol_id in enumerate(hard_neg_mol_ids):
1029
+ if neg_mol_id not in batch_molecule_ids:
1030
+ valid_indices.append(j)
1031
+
1032
+ if valid_indices:
1033
+ filtered_hard_neg_ids.append(hard_neg_ids[valid_indices])
1034
+ filtered_hard_neg_masks.append(hard_neg_mask[valid_indices])
1035
+ filtered_has_hard_neg.append(True)
1036
+ else:
1037
+ max_length = sample['text_input_ids'].shape[0]
1038
+ filtered_hard_neg_ids.append(
1039
+ torch.zeros((0, max_length), dtype=torch.long)
1040
+ )
1041
+ filtered_hard_neg_masks.append(
1042
+ torch.zeros((0, max_length), dtype=torch.long)
1043
+ )
1044
+ filtered_has_hard_neg.append(False)
1045
+ else:
1046
+ max_length = batch_list[0]['text_input_ids'].shape[0]
1047
+ filtered_hard_neg_ids.append(
1048
+ torch.zeros((0, max_length), dtype=torch.long)
1049
+ )
1050
+ filtered_hard_neg_masks.append(
1051
+ torch.zeros((0, max_length), dtype=torch.long)
1052
+ )
1053
+ filtered_has_hard_neg.append(False)
1054
+
1055
+ collated_batch['hard_neg_input_ids'] = filtered_hard_neg_ids
1056
+ collated_batch['hard_neg_attention_mask'] = filtered_hard_neg_masks
1057
+ collated_batch['has_hard_neg'] = filtered_has_hard_neg
1058
+
1059
+ # 逐个处理字典中的键值
1060
+ final_batch = {}
1061
+ for key, values in collated_batch.items():
1062
+ if key in ['has_paraphrase', 'has_hard_neg']:
1063
+ final_batch[key] = [d.get(key, False) for d in batch_list]
1064
+ elif key in ['hard_neg_input_ids', 'hard_neg_attention_mask',
1065
+ 'paraphrase_input_ids', 'paraphrase_attention_mask',
1066
+ 'hard_neg_molecule_ids', 'all_candidate_texts']: # all_candidate_texts保持list
1067
+ final_batch[key] = values
1068
+ elif isinstance(values[0], torch.Tensor):
1069
+ if len(values) == len(batch_list):
1070
+ final_batch[key] = torch.stack(values)
1071
+ else:
1072
+ final_batch[key] = values
1073
+ else:
1074
+ final_batch[key] = values
1075
+
1076
+ # === 添加text_overlap信息 ===
1077
+ final_batch['text_overlap_matrix'] = text_overlap # [batch_size, batch_size]
1078
+
1079
+ return final_batch
1080
+
1081
+
1082
    @staticmethod
    def load_hmdb_data_subsections(first_path, second_path, jsonl_path, max_text_sharing=5):
        """
        Load HMDB data in the subsections JSONL format and drop texts shared by
        too many molecules.

        Parameters:
            first_path (str): MS2 data file path (h5, pkl, ...)
            second_path (str): meta data file path (parquet, csv, ...)
            jsonl_path (str): BioText jsonl file path
            max_text_sharing (int): maximum number of molecules allowed to share
                one text; texts shared more widely are removed

        Returns:
            tuple: (ms2_data, meta_data, biotext_data)
            biotext_data format: {molecule_id: [{'type': 'xxx', 'text': 'xxx'}, ...]}
            On a read failure the already-loaded parts are returned with None
            for the failed part(s).
        """
        from collections import defaultdict

        # --- Load MS2 spectra ---
        ms2_data = {}
        try:
            _, ext1 = os.path.splitext(first_path)
            if ext1 == '.h5':
                with h5py.File(first_path, 'r') as f:
                    spectra_group = f['spectra']
                    for spectrum_id in spectra_group.keys():
                        group = spectra_group[spectrum_id]
                        # spectrum ids look like "<molecule_id>_<suffix>" — TODO confirm
                        parts = spectrum_id.split('_')
                        molecule_id = parts[0]
                        ms2_data[spectrum_id] = {
                            'mz': group['mz'][...].tolist(),
                            'intensity': group['intensity'][...].tolist(),
                            'molecule_id': molecule_id
                        }
            elif ext1 == '.pkl':
                with open(first_path, 'rb') as f:
                    ms2_data = pickle.load(f)
            else:
                print(f"Unsupported file format for first path: {ext1}")
                return None, None, None
        except Exception as e:
            # Best-effort loader: report and bail out with Nones.
            print(f"Error: Failed to read first file: {first_path}. Error message: {str(e)}")
            return None, None, None

        # --- Load meta data ---
        meta_data = None
        try:
            _, ext2 = os.path.splitext(second_path)
            if ext2 == '.parquet':
                meta_data = pd.read_parquet(second_path)
            elif ext2 == '.csv':
                meta_data = pd.read_csv(second_path)
            else:
                print(f"Unsupported file format for second path: {ext2}")
                return ms2_data, None, None
        except Exception as e:
            print(f"Error: Failed to read second file: {second_path}. Error message: {str(e)}")
            return ms2_data, None, None

        # --- Load the BioText JSONL file (one record per line: accession/type/text) ---
        biotext_data = {}
        try:
            import json
            with open(jsonl_path, 'r', encoding='utf-8') as f:
                for line in f:
                    item = json.loads(line.strip())
                    accession = item['accession']

                    if accession not in biotext_data:
                        biotext_data[accession] = []

                    biotext_data[accession].append({
                        'type': item['type'],
                        'text': item['text']
                    })

            print(f"✓ Loaded BioText subsections from {os.path.basename(jsonl_path)}")
            print(f"  Total molecules: {len(biotext_data)}")
            total_records_before = sum(len(v) for v in biotext_data.values())
            print(f"  Total records (before filtering): {total_records_before}, Avg per molecule: {total_records_before / len(biotext_data):.1f}")

        except Exception as e:
            print(f"Error reading BioText JSONL: {e}")
            return ms2_data, meta_data, None

        # ===== Filter texts shared by too many molecules =====
        print(f"\n=== Filtering texts shared by >{max_text_sharing} molecules ===")

        # 1. Inverted index: text -> set of molecule ids that use it
        text_to_molecules = defaultdict(set)
        for mol_id, records in biotext_data.items():
            for record in records:
                text = record['text']
                if text:  # skip empty strings
                    text_to_molecules[text].add(mol_id)

        # 2. Mark over-shared texts for removal
        texts_to_remove = set()
        sharing_distribution = defaultdict(int)  # sharing count -> number of texts

        for text, molecules in text_to_molecules.items():
            sharing_count = len(molecules)
            sharing_distribution[sharing_count] += 1

            if sharing_count > max_text_sharing:
                texts_to_remove.add(text)

        print(f"Text sharing distribution (top 10):")
        for count in sorted(sharing_distribution.keys(), reverse=True)[:10]:
            print(f"  {count} molecules share: {sharing_distribution[count]} texts")

        print(f"\nFound {len(texts_to_remove)} texts to remove (shared by >{max_text_sharing} molecules)")

        # 3. Drop the over-shared texts from every molecule's candidate list
        filtered_biotext_data = {}
        total_removed = 0
        molecules_with_no_text = []

        for mol_id, records in biotext_data.items():
            filtered_records = [record for record in records
                                if record['text'] not in texts_to_remove]

            if filtered_records:
                filtered_biotext_data[mol_id] = filtered_records
                total_removed += len(records) - len(filtered_records)
            else:
                # Molecule lost every candidate text.
                molecules_with_no_text.append(mol_id)
                total_removed += len(records)

        # 4. Report the filtering results
        print(f"\nFiltering results:")
        print(f"  Text entries removed: {total_removed}")
        print(f"  Molecules before: {len(biotext_data)}")
        print(f"  Molecules after: {len(filtered_biotext_data)}")
        print(f"  Molecules with no text left: {len(molecules_with_no_text)}")

        if molecules_with_no_text:
            print(f"  ⚠️ Warning: {len(molecules_with_no_text)} molecules lost all texts")
            if len(molecules_with_no_text) <= 5:
                print(f"     Lost: {molecules_with_no_text}")
            else:
                print(f"     First 5: {molecules_with_no_text[:5]}")

        # 5. Verify the post-filter sharing statistics
        text_to_molecules_after = defaultdict(set)
        for mol_id, records in filtered_biotext_data.items():
            for record in records:
                text = record['text']
                if text:
                    text_to_molecules_after[text].add(mol_id)

        max_sharing_after = max(len(mols) for mols in text_to_molecules_after.values()) if text_to_molecules_after else 0
        shared_texts_after = sum(1 for mols in text_to_molecules_after.values() if len(mols) > 1)

        print(f"  Max sharing after filtering: {max_sharing_after} molecules")
        print(f"  Texts still shared by multiple molecules: {shared_texts_after}")

        total_records_after = sum(len(v) for v in filtered_biotext_data.values())
        print(f"  Total records (after filtering): {total_records_after}, Avg per molecule: {total_records_after / len(filtered_biotext_data):.1f}")

        # Use the filtered mapping from here on
        biotext_data = filtered_biotext_data

        # Final summary
        unique_molecule_ids = set(item['molecule_id'] for item in ms2_data.values())
        print(f"\nFinal data summary:")
        print(f"  Unique molecule IDs in MS2 data: {len(unique_molecule_ids)}")
        print(f"  Molecule IDs in BioText data: {len(biotext_data)}")

        return ms2_data, meta_data, biotext_data
1251
+
1252
+ @staticmethod
1253
+ def missing_biotext_handling(ms2_data, biotext_data, method="drop"):
1254
+ """
1255
+ ms2_data: dict, {ms2_id: {'mz': list, 'intensity': list, 'molecule_id': str}}
1256
+ biotext_data: dict, {molecule_id: BioText}
1257
+ """
1258
+ # If a molecule in ms2_data is missing in biotext_data, remove it from ms2_data
1259
+ # Handle missing biotext entries
1260
+ if method == "drop":
1261
+ ms2_data = {ms2_id: info for ms2_id, info in ms2_data.items() if info['molecule_id'] in biotext_data}
1262
+ unique_molecule_ids = set(item['molecule_id'] for item in ms2_data.values())
1263
+ print(f"Post-processing statistics ('drop' method) - Unique molecule IDs in MS2 data: {len(unique_molecule_ids)}")
1264
+ print(f"Post-processing statistics ('drop' method) - Molecule IDs in BioText data: {len(biotext_data)}")
1265
+ return ms2_data, biotext_data
1266
+
1267
+ if method == "fill":
1268
+ # Fill missing entries with empty values; will be handled during dataset initialization
1269
+ for info in ms2_data.values():
1270
+ molecule_id = info['molecule_id']
1271
+ if molecule_id not in biotext_data:
1272
+ biotext_data[molecule_id] = ""
1273
+ unique_molecule_ids = set(item['molecule_id'] for item in ms2_data.values())
1274
+ print(f"Post-processing statistics ('fill' method) - Unique molecule IDs in MS2 data: {len(unique_molecule_ids)}")
1275
+ print(f"Post-processing statistics ('fill' method) - Molecule IDs in BioText data: {len(biotext_data)}")
1276
+ return ms2_data, biotext_data
1277
+
1278
+ raise ValueError(f"Unknown method: {method}. Method must be 'drop' or 'fill'.")
1279
+
1280
+
1281
+ @staticmethod
1282
+ def add_noise_peaks(peaks, intensities, noise_ratio=0.5, noise_intensity_range=(0.001, 0.05), seed=None):
1283
+ """
1284
+ 添加随机noise peaks来模拟外部数据
1285
+
1286
+ Args:
1287
+ peaks: list of float, 原始m/z值
1288
+ intensities: list of float, 原始强度值
1289
+ noise_ratio: float, 添加的noise peaks数量 = 原peaks数 × noise_ratio
1290
+ noise_intensity_range: tuple, noise的相对强度范围(相对于max intensity)
1291
+ seed: int, 随机种子(可选)
1292
+
1293
+ Returns:
1294
+ aug_peaks: list of float, 添加noise后的m/z
1295
+ aug_intensities: list of float, 添加noise后的强度
1296
+ """
1297
+ import numpy as np
1298
+ import random
1299
+
1300
+ if seed is not None:
1301
+ np.random.seed(seed)
1302
+ random.seed(seed)
1303
+
1304
+ if len(peaks) == 0:
1305
+ return peaks, intensities
1306
+
1307
+ max_int = max(intensities)
1308
+ if max_int == 0:
1309
+ return peaks, intensities
1310
+
1311
+ # 计算要添加的noise数量
1312
+ n_noise = int(len(peaks) * noise_ratio)
1313
+ if n_noise == 0:
1314
+ return peaks, intensities
1315
+
1316
+ # 在光谱范围内随机生成noise peaks的m/z
1317
+ mz_min, mz_max = min(peaks), max(peaks)
1318
+ noise_mz = np.random.uniform(mz_min, mz_max, n_noise).tolist()
1319
+
1320
+ # 生成低强度noise(相对于max intensity)
1321
+ noise_int = np.random.uniform(
1322
+ noise_intensity_range[0] * max_int,
1323
+ noise_intensity_range[1] * max_int,
1324
+ n_noise
1325
+ ).tolist()
1326
+
1327
+ # 合并原始peaks和noise
1328
+ aug_peaks = peaks + noise_mz
1329
+ aug_intensities = intensities + noise_int
1330
+
1331
+ # 按m/z排序
1332
+ sorted_indices = sorted(range(len(aug_peaks)), key=lambda i: aug_peaks[i])
1333
+ aug_peaks = [aug_peaks[i] for i in sorted_indices]
1334
+ aug_intensities = [aug_intensities[i] for i in sorted_indices]
1335
+
1336
+ return aug_peaks, aug_intensities
1337
+
1338
+
1339
+ @staticmethod
1340
+ def filter_low_intensity_peaks(peaks, intensities, threshold=0.01):
1341
+ """
1342
+ 过滤低强度peaks
1343
+
1344
+ Args:
1345
+ peaks: list of float, m/z值
1346
+ intensities: list of float, 强度值
1347
+ threshold: float, 相对强度阈值(0.01 = 1%)
1348
+
1349
+ Returns:
1350
+ filtered_peaks: list of float
1351
+ filtered_intensities: list of float
1352
+ """
1353
+ if len(peaks) == 0 or len(intensities) == 0:
1354
+ return peaks, intensities
1355
+
1356
+ max_int = max(intensities)
1357
+ if max_int == 0:
1358
+ return peaks, intensities
1359
+
1360
+ # 归一化并过滤
1361
+ norm_intensities = [i / max_int for i in intensities]
1362
+ filtered_peaks = []
1363
+ filtered_intensities = []
1364
+
1365
+ for mz, intensity, norm_int in zip(peaks, intensities, norm_intensities):
1366
+ if norm_int >= threshold:
1367
+ filtered_peaks.append(mz)
1368
+ filtered_intensities.append(intensity)
1369
+
1370
+ return filtered_peaks, filtered_intensities
1371
+
1372
+
1373
+ @staticmethod
1374
+ def augment_ms2_data(ms2_data, args):
1375
+ """
1376
+ 对MS2数据进行增强(必须在preprocess之前调用)
1377
+
1378
+ Args:
1379
+ ms2_data: dict, {ms2_id: {'mz': list, 'intensity': list, 'molecule_id': str}}
1380
+ 注意:mz和intensity必须是原始的float值,不能是token_ids
1381
+ args: argparse.Namespace, 包含增强参数:
1382
+ - augment_noise: bool, 是否添加noise增强 (default: False)
1383
+ - augment_multiplier: int, 每个光谱生成几个版本 (1=不增强, 2=生成2倍数据)
1384
+ - noise_ratio: float, 添加的noise数量 = 原peaks数 × noise_ratio
1385
+ - noise_intensity_range: tuple, noise强度范围 (相对于max intensity)
1386
+ - filter_threshold: float or None, 过滤低强度peaks的阈值
1387
+
1388
+ Returns:
1389
+ augmented_ms2_data: dict, 包含原始+增强版本的数据
1390
+ 如果augment_multiplier=1,返回原始数据
1391
+ 如果augment_multiplier=2,返回2倍数据(原始+1个增强版本)
1392
+
1393
+ Example:
1394
+ >>> augmented_data = MS2BioTextDataset.augment_ms2_data(ms2_data, args)
1395
+ >>> processed_data, word2idx = MS2BioTextDataset.preprocess_ms2_data_positive_only(
1396
+ ... augmented_data, meta_data
1397
+ ... )
1398
+ """
1399
+ import numpy as np
1400
+
1401
+ # 获取参数(兼容没有这些参数的情况)
1402
+ augment_noise = getattr(args, 'augment_noise', False)
1403
+ augment_multiplier = getattr(args, 'augment_multiplier', 1)
1404
+ noise_ratio = getattr(args, 'noise_ratio', 0.5)
1405
+ noise_intensity_range = getattr(args, 'noise_intensity_range', (0.001, 0.05))
1406
+ filter_threshold = getattr(args, 'filter_threshold', None)
1407
+
1408
+ # 如果不需要增强,直接返回原数据
1409
+ if not augment_noise or augment_multiplier <= 1:
1410
+ print("ℹ️ 未启用数据增强 (augment_noise=False or augment_multiplier<=1)")
1411
+ return ms2_data
1412
+
1413
+ print(f"\n{'='*60}")
1414
+ print(f"🔄 MS2数据增强")
1415
+ print(f"{'='*60}")
1416
+ print(f" 增强倍数: {augment_multiplier}x")
1417
+ print(f" Noise比例: {noise_ratio}")
1418
+ print(f" Noise强度范围: {noise_intensity_range}")
1419
+ if filter_threshold:
1420
+ print(f" 过滤阈值: {filter_threshold} (相对强度)")
1421
+ print(f" 原始光谱数: {len(ms2_data)}")
1422
+
1423
+ augmented_ms2_data = {}
1424
+
1425
+ for ms2_id, info in ms2_data.items():
1426
+ molecule_id = info.get('molecule_id')
1427
+
1428
+ # 检查数据格式
1429
+ if not isinstance(info['mz'], list) or not isinstance(info['intensity'], list):
1430
+ print(f"⚠️ 跳过 {ms2_id}: mz或intensity不是list格式")
1431
+ continue
1432
+
1433
+ # 版本0: 原始数据(可选过滤)
1434
+ peaks_original = info['mz'].copy() if isinstance(info['mz'], list) else list(info['mz'])
1435
+ intensities_original = info['intensity'].copy() if isinstance(info['intensity'], list) else list(info['intensity'])
1436
+
1437
+ # 可选:过滤低强度peaks
1438
+ if filter_threshold is not None and filter_threshold > 0:
1439
+ peaks_original, intensities_original = MS2BioTextDataset.filter_low_intensity_peaks(
1440
+ peaks_original, intensities_original, threshold=filter_threshold
1441
+ )
1442
+
1443
+ # 保存原始版本
1444
+ augmented_ms2_data[ms2_id] = {
1445
+ 'mz': peaks_original,
1446
+ 'intensity': intensities_original,
1447
+ 'molecule_id': molecule_id
1448
+ }
1449
+
1450
+ # 生成增强版本(版本1到N-1)
1451
+ for aug_idx in range(1, augment_multiplier):
1452
+ peaks_aug, intensities_aug = MS2BioTextDataset.add_noise_peaks(
1453
+ peaks_original.copy(),
1454
+ intensities_original.copy(),
1455
+ noise_ratio=noise_ratio,
1456
+ noise_intensity_range=noise_intensity_range,
1457
+ seed=None # 每次随机生成不同的noise
1458
+ )
1459
+
1460
+ # 新的ID:原ID + 后缀
1461
+ aug_ms2_id = f"{ms2_id}_aug{aug_idx}"
1462
+ augmented_ms2_data[aug_ms2_id] = {
1463
+ 'mz': peaks_aug,
1464
+ 'intensity': intensities_aug,
1465
+ 'molecule_id': molecule_id # 保持相同的molecule_id!
1466
+ }
1467
+
1468
+ print(f" ✓ 增强后光谱数: {len(augmented_ms2_data)}")
1469
+ print(f" 增强版本数: {len(augmented_ms2_data) - len(ms2_data)}")
1470
+ print(f"{'='*60}\n")
1471
+
1472
+ return augmented_ms2_data
1473
+
1474
+ @staticmethod
1475
+ def preprocess_ms2_data_positive_only(ms2_data, meta_data, maxlen=100, min_peaks=0):
1476
+ """
1477
+ Preprocess ms2_data for model input.
1478
+
1479
+ Parameters:
1480
+ - ms2_data: dict, {ms2_id: {'mz': list, 'intensity': list, 'molecule_id': str}}
1481
+ - meta_data: pd.DataFrame, must contain precursor information (column: 'precursor_mass')
1482
+ - maxlen: int, maximum sequence length
1483
+ - min_peaks: int, minimum number of peaks required (default: 0, no filtering)
1484
+
1485
+ Returns:
1486
+ - ms_data: dict, same structure as ms2_data but with processed 'mz' and 'intensity' sequences
1487
+ - word2idx: dict, maps string-formatted m/z values to token indices
1488
+ """
1489
+
1490
+ # ===== 新增:安全转换precursor_mass的函数 =====
1491
+ def safe_convert_precursor(value):
1492
+ """安全转换precursor_mass值,处理异常格式"""
1493
+ if pd.isna(value):
1494
+ return None
1495
+
1496
+ # 如果已经是数字
1497
+ if isinstance(value, (int, float)):
1498
+ return float(value)
1499
+
1500
+ # 如果是字符串
1501
+ value_str = str(value).strip()
1502
+
1503
+ # 处理空字符串
1504
+ if value_str == '' or value_str.lower() == 'nan':
1505
+ return None
1506
+
1507
+ # 处理 "209/192" 这种格式(取第一个值)
1508
+ if '/' in value_str:
1509
+ try:
1510
+ return float(value_str.split('/')[0])
1511
+ except:
1512
+ return None
1513
+
1514
+ # 尝试直接转换
1515
+ try:
1516
+ return float(value_str)
1517
+ except:
1518
+ return None
1519
+ # ============================================
1520
+
1521
+ # 1) Create word list: ["0.00", "0.01", ..., "999.99"]
1522
+ word_list = list(np.round(np.linspace(0, 1000, 100*1000, endpoint=False), 2))
1523
+ word_list = ["%.2f" % i for i in word_list]
1524
+
1525
+ # 2) Build word2idx dictionary with special tokens
1526
+ word2idx = {'[PAD]': 0, '[MASK]': 1}
1527
+ for i, w in enumerate(word_list):
1528
+ word2idx[w] = i + 2 # Start from 2 to avoid collision with special tokens
1529
+
1530
+ # 3) Initialize output dictionary
1531
+ ms_data = {}
1532
+
1533
+ # ===== 新增:统计信息 =====
1534
+ filter_stats = {
1535
+ 'total': 0,
1536
+ 'empty_mz': 0,
1537
+ 'not_positive': 0,
1538
+ 'no_meta': 0,
1539
+ 'no_precursor': 0,
1540
+ 'precursor_gt_1000': 0,
1541
+ 'no_peaks_after_filter': 0,
1542
+ 'too_few_peaks': 0, # 新增
1543
+ 'kept': 0
1544
+ }
1545
+ # ===========================
1546
+
1547
+ # 4) Iterate through each ms2_id
1548
+ for ms2_id, info in ms2_data.items():
1549
+ filter_stats['total'] += 1
1550
+
1551
+ mz_data = info.get('mz')
1552
+ if mz_data is None or len(mz_data) == 0:
1553
+ filter_stats['empty_mz'] += 1
1554
+ continue
1555
+ peaks = info['mz']
1556
+ intensities = info['intensity']
1557
+ molecule_id = info.get('molecule_id', None)
1558
+
1559
+ specific_row = meta_data[meta_data["file_name"] == ms2_id]
1560
+ if specific_row.empty:
1561
+ filter_stats['no_meta'] += 1
1562
+ continue
1563
+ elif specific_row["Polarity"].values[0] not in ["Positive", "positive"]:
1564
+ filter_stats['not_positive'] += 1
1565
+ continue
1566
+
1567
+ # 4.1 Find precursor mass from meta_data
1568
+ if 'HMDB.ID' in meta_data.columns:
1569
+ row = meta_data.loc[meta_data['HMDB.ID'] == molecule_id]
1570
+ else:
1571
+ row = meta_data.loc[meta_data.index == molecule_id]
1572
+ if row.empty:
1573
+ filter_stats['no_meta'] += 1
1574
+ continue
1575
+
1576
+ # ===== 修改:使用安全转换函数 =====
1577
+ precursor_val = safe_convert_precursor(row['precursor_mass'].values[0])
1578
+ if precursor_val is None:
1579
+ filter_stats['no_precursor'] += 1
1580
+ continue
1581
+ # ===================================
1582
+
1583
+ if precursor_val > 1000:
1584
+ filter_stats['precursor_gt_1000'] += 1
1585
+ continue
1586
+ precursor_str = "%.2f" % precursor_val
1587
+
1588
+ # 4.2 Convert m/z values to string and map to indices
1589
+ peaks_str = []
1590
+ for mz_val in peaks:
1591
+ if mz_val <= 1000:
1592
+ peaks_str.append("%.2f" % mz_val)
1593
+
1594
+ # ===== 新增:检查peaks数量 =====
1595
+ if len(peaks_str) == 0:
1596
+ filter_stats['no_peaks_after_filter'] += 1
1597
+ continue
1598
+
1599
+ if len(peaks_str) < min_peaks:
1600
+ filter_stats['too_few_peaks'] += 1
1601
+ continue
1602
+ # ===============================
1603
+
1604
+ token_ids = [word2idx[precursor_str]] + [word2idx[p] for p in peaks_str]
1605
+
1606
+ # 4.3 Normalize intensity and prepend a fixed value (2)
1607
+ intensities = np.hstack((2, intensities))
1608
+ max_intensity = np.max(intensities)
1609
+ if max_intensity != 0:
1610
+ intensities = intensities / max_intensity
1611
+
1612
+ # 4.4 Pad or truncate to maxlen
1613
+ n_pad = maxlen - len(token_ids)
1614
+ if n_pad < 0:
1615
+ token_ids = token_ids[:maxlen]
1616
+ intensities = intensities[:maxlen]
1617
+ n_pad = 0
1618
+ token_ids += [word2idx['[PAD]']] * n_pad
1619
+ if len(intensities) < maxlen:
1620
+ intensities = np.hstack([intensities, np.zeros(maxlen - len(intensities))])
1621
+ else:
1622
+ intensities = intensities[:maxlen]
1623
+
1624
+ # 4.5 Save processed result
1625
+ ms_data[ms2_id] = {
1626
+ 'mz': token_ids,
1627
+ 'intensity': intensities.tolist(),
1628
+ 'molecule_id': molecule_id
1629
+ }
1630
+ filter_stats['kept'] += 1
1631
+
1632
+ # ===== 新增:打印统计信息 =====
1633
+ print(f"\n预处理统计:")
1634
+ print(f" 总光谱: {filter_stats['total']}")
1635
+ print(f" 过滤:")
1636
+ print(f" 空M/Z: {filter_stats['empty_mz']}")
1637
+ print(f" 非Positive: {filter_stats['not_positive']}")
1638
+ print(f" 无meta: {filter_stats['no_meta']}")
1639
+ print(f" 无/异常precursor: {filter_stats['no_precursor']}")
1640
+ print(f" precursor>1000: {filter_stats['precursor_gt_1000']}")
1641
+ print(f" 所有peaks>1000: {filter_stats['no_peaks_after_filter']}")
1642
+ if min_peaks > 0:
1643
+ print(f" peaks<{min_peaks}: {filter_stats['too_few_peaks']}")
1644
+ print(f" ✓ 保留: {filter_stats['kept']} ({filter_stats['kept']/filter_stats['total']*100:.2f}%)")
1645
+ # ================================
1646
+
1647
+ # 5) Return processed data and dictionary
1648
+ return ms_data, word2idx
1649
+
1650
+
1651
+ @staticmethod
1652
+ def augment_ms2_data_parallel(ms2_data, args, n_workers=None):
1653
+ """多进程版本的augment_ms2_data"""
1654
+ from multiprocessing import Pool, cpu_count
1655
+
1656
+ if n_workers is None:
1657
+ n_workers = min(cpu_count() - 1, 8)
1658
+
1659
+ augment_noise = getattr(args, 'augment_noise', False)
1660
+ augment_multiplier = getattr(args, 'augment_multiplier', 1)
1661
+
1662
+ if not augment_noise or augment_multiplier <= 1:
1663
+ print("ℹ️ 未启用数据增强")
1664
+ return ms2_data
1665
+
1666
+ print(f"\n🚀 多进程数据增强 (workers={n_workers})...")
1667
+
1668
+ # 准备参数
1669
+ items = list(ms2_data.items())
1670
+ chunk_size = max(1, len(items) // (n_workers * 4))
1671
+
1672
+ # 提取参数
1673
+ filter_threshold = getattr(args, 'filter_threshold', None)
1674
+ noise_ratio = getattr(args, 'noise_ratio', 0.5)
1675
+ noise_intensity_range = getattr(args, 'noise_intensity_range', (0.001, 0.05))
1676
+
1677
+ # 分批并打包参数
1678
+ batches = [items[i:i+chunk_size] for i in range(0, len(items), chunk_size)]
1679
+ batch_data = [(batch, filter_threshold, noise_ratio, noise_intensity_range, augment_multiplier)
1680
+ for batch in batches]
1681
+
1682
+ with Pool(n_workers) as pool:
1683
+ results = pool.map(_augment_worker, batch_data)
1684
+
1685
+ # 合并结果
1686
+ augmented_data = {}
1687
+ for r in results:
1688
+ augmented_data.update(r)
1689
+
1690
+ print(f" ✓ 完成: {len(augmented_data)} 光谱")
1691
+ return augmented_data
1692
+
1693
+
1694
+ @staticmethod
1695
+ def preprocess_ms2_data_positive_only_parallel(ms2_data, meta_data, maxlen=100, min_peaks=0, n_workers=None,
1696
+ precursor_mode='normalize_add', precursor_value=2.0):
1697
+ """
1698
+ 多进程版本的preprocess
1699
+
1700
+ Args:
1701
+ precursor_mode:
1702
+ - 'scale_fixed': 缩放fragments到precursor_value(如20000),precursor固定为2
1703
+ - 'normalize_add': 归一化fragments到1,precursor用precursor_value(如2.0),再整体归一化
1704
+ - 'original': 原始MSBERT方式
1705
+ precursor_value:
1706
+ - mode='scale_fixed'时: fragments缩放的目标值(默认20000)
1707
+ - mode='normalize_add'时: precursor的强度值(默认2.0)
1708
+ """
1709
+ from multiprocessing import Pool, cpu_count
1710
+ import numpy as np
1711
+ import pandas as pd
1712
+
1713
+ if n_workers is None:
1714
+ n_workers = min(cpu_count() - 1, 8)
1715
+
1716
+ print(f"\n🚀 多进程预处理 (workers={n_workers}, mode={precursor_mode}, value={precursor_value})...")
1717
+
1718
+ # 构建word2idx
1719
+ word_list = list(np.round(np.linspace(0, 1000, 100*1000, endpoint=False), 2))
1720
+ word_list = ["%.2f" % i for i in word_list]
1721
+ word2idx = {'[PAD]': 0, '[MASK]': 1}
1722
+ for i, w in enumerate(word_list):
1723
+ word2idx[w] = i + 2
1724
+
1725
+ # 预处理 meta_data
1726
+ meta_data_processed = meta_data.copy()
1727
+ if "Polarity" in meta_data_processed.columns:
1728
+ meta_data_processed["Polarity"] = meta_data_processed["Polarity"].astype(str).str.lower().str.strip()
1729
+
1730
+ # 计算最大碎片数
1731
+ max_frag = max(0, min(100, maxlen - 1))
1732
+
1733
+ # 准备数据
1734
+ items = list(ms2_data.items())
1735
+ chunk_size = max(1, len(items) // (n_workers * 4))
1736
+
1737
+ # 分批并打包参数
1738
+ batches = [items[i:i+chunk_size] for i in range(0, len(items), chunk_size)]
1739
+ batch_data = [(batch, word2idx, meta_data_processed, maxlen, max_frag, min_peaks,
1740
+ precursor_mode, precursor_value)
1741
+ for batch in batches]
1742
+
1743
+ with Pool(n_workers) as pool:
1744
+ results = pool.map(_preprocess_worker, batch_data)
1745
+
1746
+ # 合并
1747
+ ms_data = {}
1748
+ total_kept = 0
1749
+ total_filtered = 0
1750
+ for r, stats in results:
1751
+ ms_data.update(r)
1752
+ total_kept += stats['kept']
1753
+ total_filtered += stats['filtered']
1754
+
1755
+ print(f" ✓ 完成: {total_kept}/{len(items)} 光谱 (过滤: {total_filtered})")
1756
+ return ms_data, word2idx
1757
+
1758
+ @staticmethod
1759
+ def preprocess_ms2_data(ms2_data, meta_data, maxlen=100):
1760
+ """
1761
+ Preprocess ms2_data for model input.
1762
+
1763
+ Parameters:
1764
+ - ms2_data: dict, {ms2_id: {'mz': list, 'intensity': list, 'molecule_id': str}}
1765
+ - meta_data: pd.DataFrame, must contain precursor information (column: 'precursor_mass')
1766
+ - maxlen: int, maximum sequence length
1767
+
1768
+ Returns:
1769
+ - ms_data: dict, same structure as ms2_data but with processed 'mz' and 'intensity' sequences
1770
+ - word2idx: dict, maps string-formatted m/z values to token indices
1771
+ """
1772
+ # 1) Create word list: ["0.00", "0.01", ..., "999.99"]
1773
+ word_list = list(np.round(np.linspace(0, 1000, 100*1000, endpoint=False), 2))
1774
+ word_list = ["%.2f" % i for i in word_list]
1775
+
1776
+ # 2) Build word2idx dictionary with special tokens
1777
+ word2idx = {'[PAD]': 0, '[MASK]': 1}
1778
+ for i, w in enumerate(word_list):
1779
+ word2idx[w] = i + 2 # Start from 2 to avoid collision with special tokens
1780
+
1781
+ # 3) Initialize output dictionary
1782
+ ms_data = {}
1783
+
1784
+ # 4) Iterate through each ms2_id
1785
+ for ms2_id, info in ms2_data.items():
1786
+ if not info['mz']:
1787
+ continue
1788
+ peaks = info['mz']
1789
+ intensities = info['intensity']
1790
+ molecule_id = info.get('molecule_id', None)
1791
+
1792
+ # 4.1 Find precursor mass from meta_data
1793
+ if 'HMDB.ID' in meta_data.columns:
1794
+ row = meta_data.loc[meta_data['HMDB.ID'] == molecule_id]
1795
+ else:
1796
+ row = meta_data.loc[meta_data.index == molecule_id]
1797
+ if row.empty:
1798
+ continue
1799
+ precursor_val = float(row['precursor_mass'].values[0])
1800
+ if pd.isna(precursor_val):
1801
+ continue
1802
+ if precursor_val > 1000:
1803
+ continue
1804
+ precursor_str = "%.2f" % precursor_val
1805
+
1806
+ # 4.2 Convert m/z values to string and map to indices
1807
+ peaks_str = []
1808
+ for mz_val in peaks:
1809
+ if mz_val <= 1000:
1810
+ peaks_str.append("%.2f" % mz_val)
1811
+ else:
1812
+ continue
1813
+ token_ids = [word2idx[precursor_str]] + [word2idx[p] for p in peaks_str]
1814
+
1815
+ # 4.3 Normalize intensity and prepend a fixed value (2)
1816
+ intensities = np.hstack((2, intensities))
1817
+ max_intensity = np.max(intensities)
1818
+ if max_intensity != 0:
1819
+ intensities = intensities / max_intensity
1820
+
1821
+ # 4.4 Pad or truncate to maxlen
1822
+ n_pad = maxlen - len(token_ids)
1823
+ if n_pad < 0:
1824
+ token_ids = token_ids[:maxlen]
1825
+ intensities = intensities[:maxlen]
1826
+ n_pad = 0
1827
+ token_ids += [word2idx['[PAD]']] * n_pad
1828
+ if len(intensities) < maxlen:
1829
+ intensities = np.hstack([intensities, np.zeros(maxlen - len(intensities))])
1830
+ else:
1831
+ intensities = intensities[:maxlen]
1832
+
1833
+ # 4.5 Save processed result
1834
+ ms_data[ms2_id] = {
1835
+ 'mz': token_ids,
1836
+ 'intensity': intensities.tolist(),
1837
+ 'molecule_id': molecule_id
1838
+ }
1839
+
1840
+ # 5) Return processed data and dictionary
1841
+ return ms_data, word2idx
1842
+
1843
    @staticmethod
    def fill_precursor_data(meta_data, ms_data):
        """
        Fill in missing precursor ion mass in mass spectrometry metadata.

        For every row whose precursor value is NaN (or was negative and got
        coerced to NaN), a candidate precursor is derived from a base m/z —
        either the row's own 'mz' column or the largest fragment m/z in the
        matching spectrum — by trying a table of common adduct modes for the
        row's polarity, and keeping candidates not exceeded by the largest
        fragment (within an isotope tolerance).

        Parameters:
        meta_data (DataFrame): DataFrame containing metadata for mass spectrometry samples.
            Rows filled here are assumed to have 'Polarity' and 'file_name'
            columns — TODO confirm against callers.
        ms_data (dict): Dictionary containing peak data, with keys as spectrum IDs and values as dicts with 'mz' and 'intensity'.

        Returns:
        DataFrame: Updated meta_data with filled precursor ion mass.

        Raises:
        ValueError: if no column name contains 'precursor'.
        """
        # Copy the DataFrame to avoid modifying the original
        meta_data = meta_data.copy()

        # Find column name containing 'precursor' (first match wins)
        precursor_cols = [col for col in meta_data.columns if 'precursor' in col.lower()]
        if not precursor_cols:
            raise ValueError("No column containing 'precursor' found in meta_data")
        precursor_col = precursor_cols[0]
        print(f"Using '{precursor_col}' as the precursor mass column")

        init_nan = meta_data[precursor_col].isna().sum()

        # Find column named 'mz' (case-insensitive); optional
        mz_cols = [col for col in meta_data.columns if col.lower() == 'mz']
        has_mz_column = len(mz_cols) > 0
        mz_col = mz_cols[0] if has_mz_column else None

        # Proton mass (H+) is approximately 1.007276 Da
        proton_mass = 1.007276

        # Tolerance threshold for isotopic effect (Da): the largest fragment
        # may exceed the precursor by up to this much and still be plausible
        isotope_threshold = 2.0

        # Dictionary of adduct ion modes with their corresponding mass
        # calculations (m is the base m/z, not the neutral mass)
        adduct_modes = {
            'positive': {
                '[M+H]+': lambda m: m + proton_mass,
                '[M+H-H2O]+': lambda m: m + proton_mass - 18.010565,
                '[M+Na]+': lambda m: m + 22.989218,
                '[M+K]+': lambda m: m + 39.098301,
                '[M+NH4]+': lambda m: m + 18.033823,
                '[2M+H]+': lambda m: 2*m + proton_mass,
                '[2M+Na]+': lambda m: 2*m + 22.989218,
                '[2M+K]+': lambda m: 2*m + 39.098301,
                '[2M+NH4]+': lambda m: 2*m + 18.033823,
                '[2M+H-H2O]+': lambda m: 2*m + proton_mass - 18.010565
            },
            'negative': {
                '[M-H]-': lambda m: m - proton_mass,
                '[M-H2O-H]-': lambda m: m - proton_mass - 18.010565,
                '[M+Cl]-': lambda m: m + 34.969402,
                '[M+HAc-H]-': lambda m: m + 59.013851,
                '[2M-H]-': lambda m: 2*m - proton_mass,
                '[2M+Cl]-': lambda m: 2*m + 34.969402,
                '[2M+HAc-H]-': lambda m: 2*m + 59.013851
            }
        }
        # Non-numeric precursor entries become NaN so they get filled below
        meta_data[precursor_col] = pd.to_numeric(meta_data[precursor_col], errors='coerce')
        # Replace negative precursor values with NaN
        meta_data.loc[meta_data[precursor_col] < 0, precursor_col] = np.nan

        # Fill missing precursor masses row by row
        for idx, row in meta_data.iterrows():
            if pd.isna(row[precursor_col]):
                spectrum_id = idx

                # Determine polarity; anything unrecognized defaults to positive
                polarity = str(row['Polarity']).lower()
                if 'positive' in polarity:
                    polarity_type = 'positive'
                elif 'negative' in polarity:
                    polarity_type = 'negative'
                else:
                    polarity_type = 'positive'

                # Try to get base mz from meta_data; fall back to the largest
                # fragment m/z of the matching spectrum
                base_mz = None
                if has_mz_column and not pd.isna(row[mz_col]):
                    base_mz = row[mz_col]
                else:
                    if row["file_name"] not in ms_data:
                        print(f"Warning: spectrum ID {spectrum_id} not found in ms_data")
                        continue

                    spectrum = ms_data.get(row["file_name"], {})
                    if 'mz' not in spectrum or len(spectrum['mz']) == 0:
                        print(f"Warning: spectrum ID {spectrum_id} has no mz data")
                        continue

                    base_mz = max(spectrum['mz'])

                # Compute a candidate precursor mass for every adduct mode of
                # this polarity
                candidate_precursors = {}
                for mode_name, mode_func in adduct_modes[polarity_type].items():
                    candidate_mass = mode_func(base_mz)
                    candidate_precursors[mode_name] = candidate_mass

                # Keep only candidates at least as heavy as the largest
                # fragment (minus the isotope tolerance). NOTE(review): if
                # base_mz came from the 'mz' column and 'file_name' has no
                # spectrum, this stays empty and the fallback branch runs.
                valid_candidates = {}
                spectrum = ms_data.get(row["file_name"], {})
                if 'mz' in spectrum and len(spectrum['mz']) > 0:
                    max_fragment_mz = max(spectrum['mz'])

                    for mode_name, precursor_mass in candidate_precursors.items():
                        if max_fragment_mz <= precursor_mass + isotope_threshold:
                            valid_candidates[mode_name] = precursor_mass

                if valid_candidates:
                    # Prefer the canonical adduct for the polarity; otherwise
                    # take the first surviving candidate (insertion order)
                    default_mode = '[M+H]+' if polarity_type == 'positive' else '[M-H]-'
                    if default_mode in valid_candidates:
                        selected_mode = default_mode
                    else:
                        selected_mode = list(valid_candidates.keys())[0]
                    precursor_mass = valid_candidates[selected_mode]
                    meta_data.at[idx, precursor_col] = precursor_mass
                else:
                    # Last resort: estimate from the largest fragment plus a
                    # fixed 1.0 Da margin (sign of proton shift by polarity)
                    if 'mz' in spectrum and len(spectrum['mz']) > 0:
                        max_mz = max(spectrum['mz'])
                        if polarity_type == 'positive':
                            adjusted_precursor = max_mz + proton_mass + 1.0
                        else:
                            adjusted_precursor = max_mz - proton_mass + 1.0
                        meta_data.at[idx, precursor_col] = adjusted_precursor
                    else:
                        print(f"Warning: unable to determine precursor mass for spectrum ID {spectrum_id}")

        left_nan = meta_data[precursor_col].isna().sum()
        print(f"Precursor mass missing: initially {init_nan}; filled {init_nan - left_nan}; remaining {left_nan}.")

        return meta_data
1973
+
1974
+
1975
    @staticmethod
    def preprocess_ms2_data_positive_only(ms2_data, meta_data, maxlen=100):
        """
        Preprocess ms2_data for model input.

        Positive-ion-mode variant: spectra whose metadata row is not marked
        'positive' are skipped. Peaks are restricted to [10, 1000) Da, capped
        at the top-100 most intense fragments, and token/intensity sequences
        are kept strictly aligned throughout.

        Parameters:
        - ms2_data: dict, {ms2_id: {'mz': list, 'intensity': list, 'molecule_id': str}}
        - meta_data: pd.DataFrame, must contain precursor information (column: 'precursor_mass')
        - maxlen: int, maximum sequence length

        Returns:
        - ms_data: dict, same structure as ms2_data but with processed 'mz' and 'intensity' sequences
        - word2idx: dict, maps string-formatted m/z values to token indices
        """
        import numpy as np
        import pandas as pd

        # 1) Create word list: ["0.00", "0.01", ..., "999.99"]
        word_list = list(np.round(np.linspace(0, 1000, 100 * 1000, endpoint=False), 2))
        word_list = ["%.2f" % i for i in word_list]

        # 2) Build word2idx dictionary with special tokens
        word2idx = {'[PAD]': 0, '[MASK]': 1}
        for i, w in enumerate(word_list):
            word2idx[w] = i + 2  # Start from 2 to avoid collision with special tokens

        # 3) Initialize output dictionary
        ms_data = {}

        # Precompute: make positive-mode detection robust (lowercase + strip)
        if "Polarity" in meta_data.columns:
            meta_data = meta_data.copy()
            meta_data["Polarity"] = meta_data["Polarity"].astype(str).str.lower().str.strip()

        # Maximum number of fragments allowed (so precursor + fragments <= maxlen)
        max_frag = max(0, min(100, maxlen - 1))

        # 4) Iterate through each ms2_id
        for ms2_id, info in ms2_data.items():
            # Basic sanity check: empty/missing peak list is skipped
            if not info.get('mz'):
                continue

            peaks = np.asarray(info['mz'], dtype=float)
            intensities = np.asarray(info['intensity'], dtype=float)
            molecule_id = info.get('molecule_id', None)

            # 4.0 Locate the metadata row for this spectrum (by file name) to
            #     read its polarity
            specific_row = meta_data.loc[meta_data["file_name"] == ms2_id] if "file_name" in meta_data.columns else pd.DataFrame()
            if specific_row.empty:
                # If not found, fall back to locating a row via molecule_id
                # (best effort, not strictly required to match the spectrum)
                if molecule_id is not None:
                    if 'HMDB.ID' in meta_data.columns:
                        specific_row = meta_data.loc[meta_data['HMDB.ID'] == molecule_id]
                    else:
                        specific_row = meta_data.loc[meta_data.index == molecule_id]
                if specific_row.empty:
                    continue

            # Keep positive-ion spectra only
            pol = str(specific_row["Polarity"].values[0]).lower().strip() if "Polarity" in specific_row.columns else ""
            if pol != "positive":
                continue

            # 4.1 Find precursor mass from meta_data
            if 'HMDB.ID' in meta_data.columns and (molecule_id is not None):
                row = meta_data.loc[meta_data['HMDB.ID'] == molecule_id]
            else:
                row = meta_data.loc[meta_data.index == molecule_id]

            if row.empty or ('precursor_mass' not in row.columns):
                continue

            try:
                precursor_val = float(row['precursor_mass'].values[0])
            except Exception:
                continue

            # Precursor must lie in [10, 1000); clamp so "%.2f" can never
            # produce the out-of-vocabulary string "1000.00"
            if pd.isna(precursor_val) or (precursor_val < 10.0) or (precursor_val >= 1000.0):
                continue
            precursor_val = min(precursor_val, 999.99)
            precursor_str = "%.2f" % precursor_val

            # 4.2 Filter peaks to [10, 1000)
            if peaks.shape[0] != intensities.shape[0]:
                # Mismatched lengths are truncated to the shorter array so the
                # two stay aligned (not skipped outright)
                n = min(len(peaks), len(intensities))
                peaks = peaks[:n]
                intensities = intensities[:n]

            mask = (peaks >= 10.0) & (peaks < 1000.0) & np.isfinite(peaks) & np.isfinite(intensities)
            peaks = peaks[mask]
            intensities = intensities[mask]

            if peaks.size == 0:
                continue

            # 4.3 Select Top-K fragments by intensity (at most 100, and keep
            #     precursor + fragments <= maxlen)
            if peaks.size > max_frag:
                idx = np.argpartition(intensities, -max_frag)[-max_frag:]
                # After selection, order the kept peaks by ascending m/z
                # (could equally sort by descending intensity if preferred)
                order = np.argsort(peaks[idx])
                idx = idx[order]
                peaks_sel = peaks[idx]
                intens_sel = intensities[idx]
            else:
                # Few enough peaks: just order by ascending m/z
                order = np.argsort(peaks)
                peaks_sel = peaks[order]
                intens_sel = intensities[order]

            # 4.4 Build the token sequence (precursor token first)
            peaks_str = ["%.2f" % p for p in peaks_sel]
            try:
                token_ids = [word2idx[precursor_str]] + [word2idx[p] for p in peaks_str]
            except KeyError:
                # Should be unreachable (values constrained to [10, 999.99]),
                # kept as a defensive guard
                continue

            # 4.5 Intensity: prepend 2 for the precursor slot, then normalize
            #     the whole sequence by its maximum (same scheme as the
            #     original preprocessing)
            intens_seq = np.hstack((2.0, intens_sel))
            max_intensity = float(np.max(intens_seq)) if intens_seq.size else 1.0
            if max_intensity != 0.0:
                intens_seq = intens_seq / max_intensity

            # 4.6 Pad or truncate to maxlen (both sequences kept strictly aligned)
            if len(token_ids) > maxlen:
                token_ids = token_ids[:maxlen]
                intens_seq = intens_seq[:maxlen]

            n_pad = maxlen - len(token_ids)
            if n_pad > 0:
                token_ids += [word2idx['[PAD]']] * n_pad
                intens_seq = np.hstack([intens_seq, np.zeros(n_pad, dtype=float)])

            # 4.7 Save processed result
            ms_data[ms2_id] = {
                'mz': token_ids,
                'intensity': intens_seq.tolist(),
                'molecule_id': molecule_id
            }

        # 5) Return processed data and dictionary
        return ms_data, word2idx
2120
+
2121
+
2122
+
2123
    @staticmethod
    def load_external_test_dataset(
        external_data_dir,
        biotext_dir,
        paraphrase_dir,
        tokenizer,
        args,
        dataset_configs=None,
        **kwargs
    ):
        """
        Load and process external test datasets (e.g. HILIC and RPLC).

        Pipeline: load pickled MS2 + CSV metadata per configured dataset,
        merge them, attach BioText (and optional paraphrase) files per HMDB
        ID, drop spectra without BioText, fill missing precursor masses, run
        positive-only MS2 preprocessing, and build an MS2BioTextDataset.

        Parameters:
            external_data_dir (str): path to the external data directory
            biotext_dir (str): directory of BioText .txt files (named by HMDB ID)
            paraphrase_dir (str): directory of paraphrase .txt files; may be falsy to skip paraphrases
            tokenizer: text tokenizer forwarded to the dataset
            args: object carrying preprocessing options (precursor_mode, precursor_value, n_workers)
            dataset_configs (list): dataset configs, each a dict with 'name', 'ms2_file', 'meta_file';
                defaults to HILIC + RPLC files under external_data_dir
            **kwargs: extra arguments forwarded to the MS2BioTextDataset constructor

        Returns:
            tuple: (external_test_dataset, data_statistics)
                - external_test_dataset: MS2BioTextDataset instance
                - data_statistics: dict of summary counts
        """
        import pickle
        import pandas as pd
        import os
        from pathlib import Path

        # Default configs (HILIC and RPLC)
        if dataset_configs is None:
            dataset_configs = [
                {
                    'name': 'HILIC',
                    'ms2_file': os.path.join(external_data_dir, 'hilic_ms_data.pkl'),
                    'meta_file': os.path.join(external_data_dir, 'hilic_meta_data.csv'),
                },
                {
                    'name': 'RPLC',
                    'ms2_file': os.path.join(external_data_dir, 'rplc_ms_data.pkl'),
                    'meta_file': os.path.join(external_data_dir, 'rplc_meta_data.csv'),
                }
            ]

        print("\n" + "="*60)
        print("加载外部测试数据集...")
        print("="*60)

        # 1. Load every configured dataset
        all_ms2_data = {}
        all_meta_data = []

        for config in dataset_configs:
            print(f"\n📁 加载 {config['name']} 数据集...")

            # Load MS2 data. NOTE(review): pickle.load runs arbitrary code —
            # only use with trusted data files.
            with open(config['ms2_file'], 'rb') as f:
                ms2_data = pickle.load(f)

            # Load meta data
            meta_data = pd.read_csv(config['meta_file'])

            print(f" ✓ {config['name']}: {len(ms2_data)} 光谱, {len(meta_data)} meta")

            # Merge MS2 dicts (later datasets overwrite duplicate spectrum IDs)
            all_ms2_data.update(ms2_data)
            all_meta_data.append(meta_data)

        # 2. Concatenate meta data (align columns across datasets first)
        if len(all_meta_data) > 1:
            all_cols = set()
            for df in all_meta_data:
                all_cols.update(df.columns)
            all_cols = sorted(all_cols)

            aligned_meta_data = []
            for df in all_meta_data:
                df = df.reindex(columns=all_cols, fill_value=None)
                aligned_meta_data.append(df)

            external_meta_data = pd.concat(aligned_meta_data, ignore_index=True)
        else:
            external_meta_data = all_meta_data[0]

        external_ms2_data = all_ms2_data

        print(f"\n✓ 合并后外部数据集: {len(external_ms2_data)} 光谱, {len(external_meta_data)} meta")

        # 3. Use HMDB.ID as the index when the column exists
        if 'HMDB.ID' in external_meta_data.columns:
            external_meta_data = external_meta_data.set_index('HMDB.ID')
            print(f" 已设置HMDB.ID为索引")

        # 4. Ensure each MS2 record carries a molecule_id field
        #    (derived from the spectrum ID prefix before the first '_')
        print("\n🔧 修正MS2数据格式...")
        for spectrum_id, spectrum_data in external_ms2_data.items():
            if 'molecule_id' not in spectrum_data:
                molecule_id = spectrum_id.split('_')[0]
                spectrum_data['molecule_id'] = molecule_id

        # 5. Collect all unique HMDB IDs present in the MS2 data
        unique_hmdb_ids = set()
        for spec_id in external_ms2_data.keys():
            hmdb_id = spec_id.split('_')[0]
            unique_hmdb_ids.add(hmdb_id)

        print(f" 外部数据集包含 {len(unique_hmdb_ids)} 个unique HMDB IDs")

        # 6. Load the matching BioText entries
        print("\n📚 加载BioText数据...")
        external_biotext_data = {}
        missing_biotext = []

        biotext_dir = Path(biotext_dir)
        paraphrase_dir = Path(paraphrase_dir) if paraphrase_dir else None

        for hmdb_id in unique_hmdb_ids:
            # Skip malformed HMDB IDs (e.g. ones containing "{}")
            if '{}' in hmdb_id:
                missing_biotext.append(hmdb_id)
                continue

            biotext_file = biotext_dir / f"{hmdb_id}.txt"
            if biotext_file.exists():
                with open(biotext_file, 'r', encoding='utf-8') as f:
                    original_text = f.read().strip()

                # Load paraphrases, if available; the paraphrase file uses
                # "=== version N ===" separators between variants
                paraphrases = []
                if paraphrase_dir:
                    paraphrase_file = paraphrase_dir / f"{hmdb_id}_paraphrase.txt"
                    if paraphrase_file.exists():
                        with open(paraphrase_file, 'r', encoding='utf-8') as pf:
                            content = pf.read()
                            versions = content.split("=== version")
                            for version in versions[1:]:
                                _, text = version.split("===", 1)
                                text = text.strip()
                                if text:
                                    paraphrases.append(text)

                external_biotext_data[hmdb_id] = {
                    'original': original_text,
                    'paraphrases': paraphrases
                }
            else:
                missing_biotext.append(hmdb_id)

        print(f" ✓ 成功加载 {len(external_biotext_data)} 个BioText")
        print(f" ✗ 缺失BioText: {len(missing_biotext)} 个")

        # 7. Drop spectra whose molecule has no BioText ("drop" strategy)
        initial_ms2_count = len(external_ms2_data)
        external_ms2_data, _ = MS2BioTextDataset.missing_biotext_handling(
            external_ms2_data,
            external_biotext_data,
            method="drop"
        )
        print(f" 删除 {initial_ms2_count - len(external_ms2_data)} 条缺失biotext的光谱")

        # 8. Keep only meta entries whose molecule still has MS2 data
        remaining_hmdb_ids = set()
        for spectrum_id in external_ms2_data.keys():
            hmdb_id = spectrum_id.split('_')[0]
            remaining_hmdb_ids.add(hmdb_id)

        external_meta_data = external_meta_data[external_meta_data.index.isin(remaining_hmdb_ids)]

        # 9. Fill missing precursor masses
        print("\n⚙️ 填充precursor数据...")
        external_meta_data = MS2BioTextDataset.fill_precursor_data(
            external_meta_data,
            external_ms2_data
        )

        # 10. Preprocess MS2 data (no augmentation for a test set)
        print("\n🚀 预处理MS2数据(不进行数据增强)...")
        external_processed_ms2, external_word2idx = MS2BioTextDataset.preprocess_ms2_data_positive_only_parallel(
            external_ms2_data,
            external_meta_data,
            n_workers=getattr(args, 'n_workers', 4),
            precursor_mode=getattr(args, 'precursor_mode', 'auto'),
            precursor_value=getattr(args, 'precursor_value', 2.0)
        )

        # 11. Summary statistics
        data_statistics = {
            'original_ms2_count': initial_ms2_count,
            'processed_ms2_count': len(external_processed_ms2),
            'meta_count': len(external_meta_data),
            'biotext_count': len(external_biotext_data),
            'unique_molecules': len(remaining_hmdb_ids),
            'vocab_size': len(external_word2idx),
            'datasets': [config['name'] for config in dataset_configs]
        }

        print("\n" + "="*60)
        print("📊 外部测试数据集最终统计")
        print("="*60)
        print(f" 原始MS2光谱数: {data_statistics['original_ms2_count']}")
        print(f" 处理后MS2光谱数: {data_statistics['processed_ms2_count']}")
        print(f" Meta记录数: {data_statistics['meta_count']}")
        print(f" BioText数: {data_statistics['biotext_count']}")
        print(f" Unique分子数: {data_statistics['unique_molecules']}")
        print(f" 词汇表大小: {data_statistics['vocab_size']}")
        print(f" 数据集来源: {', '.join(data_statistics['datasets'])}")

        # 12. Build the Dataset instance (paraphrase disabled for evaluation)
        print("\n🎯 创建外部测试Dataset实例...")
        external_test_dataset = MS2BioTextDataset(
            ms2_data=external_processed_ms2,
            meta_data=external_meta_data,
            biotext_data=external_biotext_data,
            tokenizer=tokenizer,
            use_paraphrase=False,
            **kwargs
        )

        print(f"✅ 外部测试Dataset创建成功!大小: {len(external_test_dataset)}")

        return external_test_dataset, data_statistics
2347
+
2348
+
2349
+ # @staticmethod
2350
+ # def create_train_test_datasets_from_file(
2351
+ # data_dir,
2352
+ # ms2_data,
2353
+ # meta_data,
2354
+ # biotext_data,
2355
+ # tokenizer,
2356
+ # test_size=0.2,
2357
+ # random_state=42,
2358
+ # use_paraphrase = False,
2359
+ # **kwargs
2360
+ # ):
2361
+ # """
2362
+ # Split data into training and test sets based on molecule IDs.
2363
+ # This version checks whether a local file with pre-defined splits exists.
2364
+ # If it exists, it loads the split; otherwise, it creates the split and saves it to file.
2365
+ # (Enhanced to handle empty or corrupted JSON files gracefully)
2366
+
2367
+ # Parameters:
2368
+ # data_dir (str or Path): Path to the directory containing the split file (molecule_split.json).
2369
+ # ms2_data (dict): Full MS2 data dictionary.
2370
+ # meta_data (pd.DataFrame): Metadata dataframe indexed by molecule ID.
2371
+ # biotext_data (dict): Full BioText data dictionary.
2372
+ # tokenizer: Tokenizer used for initializing the Dataset.
2373
+ # test_size (float): Proportion of test set (used only when creating the split).
2374
+ # random_state (int): Random seed (used only when creating the split).
2375
+ # **kwargs: Other arguments passed to the MS2BioTextDataset constructor.
2376
+
2377
+ # Returns:
2378
+ # tuple: (train_dataset, test_dataset), two MS2BioTextDataset instances.
2379
+ # """
2380
+ # print("Preparing training and test datasets (with file persistence)...")
2381
+
2382
+ # # --- 1. Define split file path ---
2383
+ # data_dir = Path(data_dir)
2384
+ # split_file_path = data_dir / 'molecule_split.json'
2385
+
2386
+ # # --- 2. Load existing split file if valid; otherwise create new split ---
2387
+ # train_mol_ids, test_mol_ids = None, None
2388
+
2389
+ # if split_file_path.exists() and split_file_path.stat().st_size > 0:
2390
+ # print(f"✅ Found existing split file: {split_file_path}")
2391
+ # print(" Loading molecule IDs from file...")
2392
+ # try:
2393
+ # with open(split_file_path, 'r', encoding='utf-8') as f:
2394
+ # split_ids = json.load(f)
2395
+ # train_mol_ids = split_ids['train_ids']
2396
+ # test_mol_ids = split_ids['test_ids']
2397
+ # except json.JSONDecodeError:
2398
+ # print(f" ⚠️ Warning: File '{split_file_path}' exists but could not be parsed (may be empty or corrupted). Will recreate.")
2399
+ # except KeyError:
2400
+ # print(f" ⚠️ Warning: File '{split_file_path}' has invalid format (missing 'train_ids' or 'test_ids'). Will recreate.")
2401
+
2402
+ # if train_mol_ids is None or test_mol_ids is None:
2403
+ # print(f"⚠️ No valid split file found or could not be loaded. Creating a new split...")
2404
+ # valid_molecule_ids = set(item['molecule_id'] for item in ms2_data.values())
2405
+ # all_molecule_ids = [mol_id for mol_id in meta_data.index.unique() if mol_id in valid_molecule_ids]
2406
+ # print(f" Found {len(all_molecule_ids)} unique molecules for splitting.")
2407
+
2408
+ # train_mol_ids, test_mol_ids = train_test_split(
2409
+ # all_molecule_ids,
2410
+ # test_size=test_size,
2411
+ # random_state=random_state
2412
+ # )
2413
+
2414
+ # print(f" Saving new split to: {split_file_path}")
2415
+ # split_data_to_save = {'train_ids': train_mol_ids, 'test_ids': test_mol_ids}
2416
+ # split_file_path.parent.mkdir(parents=True, exist_ok=True)
2417
+ # with open(split_file_path, 'w', encoding='utf-8') as f:
2418
+ # json.dump(split_data_to_save, f, indent=4)
2419
+ # print(" Split file saved successfully.")
2420
+
2421
+
2422
+ # # --- 3. Filter data sources by ID lists ---
2423
+ # train_mol_ids_set = set(train_mol_ids)
2424
+ # test_mol_ids_set = set(test_mol_ids)
2425
+
2426
+ # train_meta_data = meta_data[meta_data.index.isin(train_mol_ids_set)]
2427
+ # test_meta_data = meta_data[meta_data.index.isin(test_mol_ids_set)]
2428
+
2429
+ # train_biotext_data = {mol_id: text for mol_id, text in biotext_data.items() if mol_id in train_mol_ids_set}
2430
+ # test_biotext_data = {mol_id: text for mol_id, text in biotext_data.items() if mol_id in test_mol_ids_set}
2431
+ # train_ms2_data = {ms2_id: info for ms2_id, info in ms2_data.items() if info['molecule_id'] in train_mol_ids_set}
2432
+ # test_ms2_data = {ms2_id: info for ms2_id, info in ms2_data.items() if info['molecule_id'] in test_mol_ids_set}
2433
+
2434
+ # print(f"Data filtering completed:")
2435
+
2436
+ # # --- 4. Create Dataset instances ---
2437
+ # print("Creating training Dataset instance...")
2438
+ # train_dataset = MS2BioTextDataset(
2439
+ # ms2_data=train_ms2_data,
2440
+ # meta_data=train_meta_data,
2441
+ # biotext_data=train_biotext_data,
2442
+ # tokenizer=tokenizer,
2443
+ # use_paraphrase = use_paraphrase,
2444
+ # **kwargs
2445
+ # )
2446
+
2447
+ # print("Creating test Dataset instance...")
2448
+ # test_dataset = MS2BioTextDataset(
2449
+ # ms2_data=test_ms2_data,
2450
+ # meta_data=test_meta_data,
2451
+ # biotext_data=test_biotext_data,
2452
+ # tokenizer=tokenizer,
2453
+ # use_paraphrase=False,
2454
+ # **kwargs
2455
+ # )
2456
+
2457
+ # return train_dataset, test_dataset
2458
+
2459
+
2460
+
2461
+
2462
+ # @staticmethod
2463
+ # def create_train_test_datasets_from_file(
2464
+ # data_dir,
2465
+ # ms2_data,
2466
+ # meta_data,
2467
+ # biotext_data,
2468
+ # tokenizer,
2469
+ # test_size=0.2,
2470
+ # use_paraphrase = False,
2471
+ # **kwargs
2472
+ # ):
2473
+ # """
2474
+ # Split data into training and test sets based on MS2 spectra within each molecule.
2475
+ # For each molecule with multiple MS2 spectra, one MS2 is reserved for testing,
2476
+ # and the rest are used for training. Molecules with only one MS2 spectrum are
2477
+ # only included in the training set.
2478
+
2479
+ # Parameters:
2480
+ # data_dir (str or Path): Path to the directory containing the split file (ms2_split.json).
2481
+ # ms2_data (dict): Full MS2 data dictionary {ms2_id: {molecule_id: ..., ...}}.
2482
+ # meta_data (pd.DataFrame): Metadata dataframe indexed by molecule ID.
2483
+ # biotext_data (dict): Full BioText data dictionary {molecule_id: text}.
2484
+ # tokenizer: Tokenizer used for initializing the Dataset.
2485
+ # test_size (float): Deprecated in this version (kept for compatibility).
2486
+ # random_state (int): Random seed for reproducible splits.
2487
+ # **kwargs: Other arguments passed to the MS2BioTextDataset constructor.
2488
+
2489
+ # Returns:
2490
+ # tuple: (train_dataset, test_dataset), two MS2BioTextDataset instances.
2491
+ # """
2492
+ # print("Preparing training and test datasets with per-molecule MS2 splitting...")
2493
+
2494
+ # # --- 1. Define split file path ---
2495
+ # data_dir = Path(data_dir)
2496
+ # split_file_path = data_dir / 'ms2_split.json' # 改名以区分新的划分方式
2497
+
2498
+ # # --- 2. Load existing split file if valid; otherwise create new split ---
2499
+ # train_ms2_ids, test_ms2_ids = None, None
2500
+
2501
+ # if split_file_path.exists() and split_file_path.stat().st_size > 0:
2502
+ # print(f"✅ Found existing split file: {split_file_path}")
2503
+ # print(" Loading MS2 IDs from file...")
2504
+ # try:
2505
+ # with open(split_file_path, 'r', encoding='utf-8') as f:
2506
+ # split_ids = json.load(f)
2507
+ # train_ms2_ids = split_ids['train_ms2_ids']
2508
+ # test_ms2_ids = split_ids['test_ms2_ids']
2509
+ # except json.JSONDecodeError:
2510
+ # print(f" ⚠️ Warning: File '{split_file_path}' exists but could not be parsed. Will recreate.")
2511
+ # except KeyError:
2512
+ # print(f" ⚠️ Warning: File '{split_file_path}' has invalid format. Will recreate.")
2513
+
2514
+ # if train_ms2_ids is None or test_ms2_ids is None:
2515
+ # print(f"⚠️ No valid split file found. Creating a new MS2-level split...")
2516
+
2517
+ # # 按分子ID分组MS2谱图
2518
+ # molecule_to_ms2 = {}
2519
+ # for ms2_id, ms2_info in ms2_data.items():
2520
+ # mol_id = ms2_info['molecule_id']
2521
+ # if mol_id not in molecule_to_ms2:
2522
+ # molecule_to_ms2[mol_id] = []
2523
+ # molecule_to_ms2[mol_id].append(ms2_id)
2524
+
2525
+ # # 统计信息
2526
+ # single_ms2_molecules = []
2527
+ # multi_ms2_molecules = []
2528
+ # for mol_id, ms2_list in molecule_to_ms2.items():
2529
+ # if len(ms2_list) == 1:
2530
+ # single_ms2_molecules.append(mol_id)
2531
+ # else:
2532
+ # multi_ms2_molecules.append(mol_id)
2533
+
2534
+ # print(f" Found {len(single_ms2_molecules)} molecules with single MS2 spectrum")
2535
+ # print(f" Found {len(multi_ms2_molecules)} molecules with multiple MS2 spectra")
2536
+
2537
+
2538
+ # train_ms2_ids = []
2539
+ # test_ms2_ids = []
2540
+
2541
+ # # 处理只有一个MS2的分子:全部放入训练集
2542
+ # for mol_id in single_ms2_molecules:
2543
+ # train_ms2_ids.extend(molecule_to_ms2[mol_id])
2544
+
2545
+ # # 处理有多个MS2的分子:随机选择一个作为测试集,其余作为训练集
2546
+ # for mol_id in multi_ms2_molecules:
2547
+ # ms2_list = molecule_to_ms2[mol_id]
2548
+ # # 随机选择一个MS2作为测试集
2549
+ # test_ms2_id = np.random.choice(ms2_list)
2550
+ # test_ms2_ids.append(test_ms2_id)
2551
+ # # 其余的作为训练集
2552
+ # train_ms2_ids.extend([ms2_id for ms2_id in ms2_list if ms2_id != test_ms2_id])
2553
+
2554
+ # print(f" Split results: {len(train_ms2_ids)} training MS2, {len(test_ms2_ids)} test MS2")
2555
+
2556
+ # # 保存划分结果
2557
+ # print(f" Saving new split to: {split_file_path}")
2558
+ # split_data_to_save = {
2559
+ # 'train_ms2_ids': train_ms2_ids,
2560
+ # 'test_ms2_ids': test_ms2_ids
2561
+ # }
2562
+ # split_file_path.parent.mkdir(parents=True, exist_ok=True)
2563
+ # with open(split_file_path, 'w', encoding='utf-8') as f:
2564
+ # json.dump(split_data_to_save, f, indent=4)
2565
+ # print(" Split file saved successfully.")
2566
+
2567
+ # # --- 3. Filter data sources by MS2 ID lists ---
2568
+ # train_ms2_ids_set = set(train_ms2_ids)
2569
+ # test_ms2_ids_set = set(test_ms2_ids)
2570
+
2571
+ # # 过滤MS2数据
2572
+ # train_ms2_data = {ms2_id: info for ms2_id, info in ms2_data.items() if ms2_id in train_ms2_ids_set}
2573
+ # test_ms2_data = {ms2_id: info for ms2_id, info in ms2_data.items() if ms2_id in test_ms2_ids_set}
2574
+
2575
+ # # 获取涉及的分子ID
2576
+ # train_molecule_ids = set(info['molecule_id'] for info in train_ms2_data.values())
2577
+ # test_molecule_ids = set(info['molecule_id'] for info in test_ms2_data.values())
2578
+
2579
+ # # 过滤元数据和biotext数据
2580
+ # # 注意:训练集和测试集可能包含相同的分子ID(因为同一分子的不同MS2可能分布在训练集和测试集中)
2581
+ # train_meta_data = meta_data[meta_data.index.isin(train_molecule_ids)]
2582
+ # test_meta_data = meta_data[meta_data.index.isin(test_molecule_ids)]
2583
+
2584
+ # train_biotext_data = {mol_id: text for mol_id, text in biotext_data.items() if mol_id in train_molecule_ids}
2585
+ # test_biotext_data = {mol_id: text for mol_id, text in biotext_data.items() if mol_id in test_molecule_ids}
2586
+
2587
+ # print(f"Data filtering completed:")
2588
+ # print(f" Training: {len(train_ms2_data)} MS2 spectra from {len(train_molecule_ids)} molecules")
2589
+ # print(f" Test: {len(test_ms2_data)} MS2 spectra from {len(test_molecule_ids)} molecules")
2590
+
2591
+ # # --- 4. Create Dataset instances ---
2592
+ # print("Creating training Dataset instance...")
2593
+ # train_dataset = MS2BioTextDataset(
2594
+ # ms2_data=train_ms2_data,
2595
+ # meta_data=train_meta_data,
2596
+ # biotext_data=train_biotext_data,
2597
+ # tokenizer=tokenizer,
2598
+ # use_paraphrase=use_paraphrase,
2599
+ # **kwargs
2600
+ # )
2601
+
2602
+ # print("Creating test Dataset instance...")
2603
+ # test_dataset = MS2BioTextDataset(
2604
+ # ms2_data=test_ms2_data,
2605
+ # meta_data=test_meta_data,
2606
+ # biotext_data=test_biotext_data,
2607
+ # tokenizer=tokenizer,
2608
+ # use_paraphrase=False,
2609
+ # **kwargs
2610
+ # )
2611
+
2612
+ # return train_dataset, test_dataset
2613
+
2614
+ @staticmethod
2615
+ def filter_shared_texts(biotext_data, max_sharing_molecules=5):
2616
+ """
2617
+ 删除被过多molecule共享的text
2618
+
2619
+ Args:
2620
+ biotext_data: {molecule_id: [{'type': ..., 'text': ...}, ...]}
2621
+ max_sharing_molecules: text最多可以被多少个molecule共享
2622
+
2623
+ Returns:
2624
+ filtered_biotext_data: 清洗后的数据
2625
+ stats: 统计信息
2626
+ """
2627
+ from collections import defaultdict
2628
+
2629
+ print(f"\n=== Filtering shared texts (max_sharing={max_sharing_molecules}) ===")
2630
+
2631
+ # 1. 构建text -> molecules的倒排索引
2632
+ text_to_molecules = defaultdict(set)
2633
+
2634
+ for mol_id, entry in biotext_data.items():
2635
+ texts = []
2636
+ if isinstance(entry, list):
2637
+ texts = [record['text'] for record in entry]
2638
+ elif isinstance(entry, dict):
2639
+ texts = [entry.get("original", "")] + entry.get("paraphrases", [])
2640
+ elif isinstance(entry, str):
2641
+ texts = [entry]
2642
+
2643
+ for text in texts:
2644
+ if text: # 避免空字符串
2645
+ text_to_molecules[text].add(mol_id)
2646
+
2647
+ # 2. 找出需要删除的高频text
2648
+ texts_to_remove = set()
2649
+ for text, molecules in text_to_molecules.items():
2650
+ if len(molecules) > max_sharing_molecules:
2651
+ texts_to_remove.add(text)
2652
+
2653
+ print(f"Found {len(texts_to_remove)} texts shared by >{max_sharing_molecules} molecules")
2654
+
2655
+ # 3. 从每个molecule的候选text中删除这些高频text
2656
+ filtered_biotext_data = {}
2657
+ total_removed = 0
2658
+ molecules_with_no_text = []
2659
+
2660
+ for mol_id, entry in biotext_data.items():
2661
+ if isinstance(entry, list):
2662
+ # 新格式:列表
2663
+ filtered_entry = [record for record in entry
2664
+ if record['text'] not in texts_to_remove]
2665
+
2666
+ if filtered_entry:
2667
+ filtered_biotext_data[mol_id] = filtered_entry
2668
+ else:
2669
+ molecules_with_no_text.append(mol_id)
2670
+
2671
+ total_removed += len(entry) - len(filtered_entry)
2672
+
2673
+ elif isinstance(entry, dict):
2674
+ # 旧格式:字典
2675
+ original = entry.get("original", "")
2676
+ paraphrases = entry.get("paraphrases", [])
2677
+
2678
+ filtered_paraphrases = [p for p in paraphrases if p not in texts_to_remove]
2679
+
2680
+ # 如果original也被删除了,用第一个paraphrase作为original
2681
+ if original in texts_to_remove:
2682
+ if filtered_paraphrases:
2683
+ original = filtered_paraphrases[0]
2684
+ filtered_paraphrases = filtered_paraphrases[1:]
2685
+ else:
2686
+ molecules_with_no_text.append(mol_id)
2687
+ continue
2688
+
2689
+ filtered_biotext_data[mol_id] = {
2690
+ 'original': original,
2691
+ 'paraphrases': filtered_paraphrases
2692
+ }
2693
+
2694
+ original_count = 1 if entry.get("original", "") not in texts_to_remove else 0
2695
+ total_removed += (len(paraphrases) - len(filtered_paraphrases) +
2696
+ (1 - original_count))
2697
+
2698
+ elif isinstance(entry, str):
2699
+ # 字符串格式
2700
+ if entry not in texts_to_remove:
2701
+ filtered_biotext_data[mol_id] = entry
2702
+ else:
2703
+ molecules_with_no_text.append(mol_id)
2704
+
2705
+ # 4. 统计信息
2706
+ print(f"Statistics:")
2707
+ print(f" Total text entries removed: {total_removed}")
2708
+ print(f" Molecules before filtering: {len(biotext_data)}")
2709
+ print(f" Molecules after filtering: {len(filtered_biotext_data)}")
2710
+ print(f" Molecules with no text left: {len(molecules_with_no_text)}")
2711
+
2712
+ if molecules_with_no_text:
2713
+ print(f" Warning: {len(molecules_with_no_text)} molecules lost all texts!")
2714
+ print(f" First 5: {molecules_with_no_text[:5]}")
2715
+
2716
+ # 5. 验证过滤效果
2717
+ text_to_molecules_after = defaultdict(set)
2718
+ for mol_id, entry in filtered_biotext_data.items():
2719
+ texts = []
2720
+ if isinstance(entry, list):
2721
+ texts = [record['text'] for record in entry]
2722
+ elif isinstance(entry, dict):
2723
+ texts = [entry.get("original", "")] + entry.get("paraphrases", [])
2724
+ elif isinstance(entry, str):
2725
+ texts = [entry]
2726
+
2727
+ for text in texts:
2728
+ if text:
2729
+ text_to_molecules_after[text].add(mol_id)
2730
+
2731
+ max_sharing_after = max(len(mols) for mols in text_to_molecules_after.values()) if text_to_molecules_after else 0
2732
+ print(f" Max molecules sharing one text after filtering: {max_sharing_after}")
2733
+
2734
+ return filtered_biotext_data, {
2735
+ 'removed_texts': len(texts_to_remove),
2736
+ 'removed_entries': total_removed,
2737
+ 'molecules_no_text': len(molecules_with_no_text),
2738
+ 'max_sharing_after': max_sharing_after
2739
+ }
2740
+
2741
+
2742
+
2743
    @staticmethod
    def create_train_test_datasets_from_file(
        data_dir,
        ms2_data,
        meta_data,
        biotext_data,
        tokenizer,
        word2idx,
        args,
        test_size=0.2,
        use_paraphrase = False,
        **kwargs
    ):
        """
        Split data into training and test sets.
        If a split file exists, use its test_ms2_ids as the test set,
        and use ALL OTHER MS2 spectra (including any new data) as the training set.
        If no split file exists, create a new split following the original logic.

        Parameters:
            data_dir (str or Path): Path to the directory containing the split file (ms2_split.json).
            ms2_data (dict): Full MS2 data dictionary {ms2_id: {molecule_id: ..., ...}}.
            meta_data (pd.DataFrame): Metadata dataframe indexed by molecule ID.
            biotext_data (dict): Full BioText data dictionary {molecule_id: text}.
            tokenizer: Tokenizer used for initializing the Dataset.
            word2idx (dict): Vocabulary mapping passed through to the Dataset constructor.
            args: Namespace of run options passed through to the Dataset constructor.
            test_size (float): Deprecated in this version (kept for compatibility).
            use_paraphrase (bool): Enables paraphrase texts for the TRAINING set
                only; the test set is always built with use_paraphrase=False.
            **kwargs: Other arguments passed to the MS2BioTextDataset constructor.

        Returns:
            tuple: (train_dataset, test_dataset), two MS2BioTextDataset instances.
        """
        print("Preparing training and test datasets with per-molecule MS2 splitting...")

        # --- 1. Define split file path ---
        data_dir = Path(data_dir)
        split_file_path = data_dir / 'ms2_split.json'

        # --- 2. Load existing split file if valid; otherwise create new split ---
        test_ms2_ids = None

        # An empty file is treated the same as a missing one.
        if split_file_path.exists() and split_file_path.stat().st_size > 0:
            print(f"✅ Found existing split file: {split_file_path}")
            print(" Loading test MS2 IDs from file...")
            try:
                with open(split_file_path, 'r', encoding='utf-8') as f:
                    split_ids = json.load(f)
                test_ms2_ids = split_ids['test_ms2_ids']
                print(f" Loaded {len(test_ms2_ids)} test MS2 IDs from existing split.")
            except json.JSONDecodeError:
                # Corrupt JSON: fall through with test_ms2_ids still None.
                print(f" ⚠️ Warning: File '{split_file_path}' exists but could not be parsed. Will recreate.")
            except KeyError:
                # Valid JSON but missing the 'test_ms2_ids' key.
                print(f" ⚠️ Warning: File '{split_file_path}' has invalid format. Will recreate.")

        if test_ms2_ids is None:
            print(f"⚠️ No valid split file found. Creating a new MS2-level split...")

            # Group MS2 spectra by molecule ID.
            molecule_to_ms2 = {}
            for ms2_id, ms2_info in ms2_data.items():
                mol_id = ms2_info['molecule_id']
                if mol_id not in molecule_to_ms2:
                    molecule_to_ms2[mol_id] = []
                molecule_to_ms2[mol_id].append(ms2_id)

            # Summary statistics on spectra-per-molecule.
            single_ms2_molecules = []
            multi_ms2_molecules = []
            for mol_id, ms2_list in molecule_to_ms2.items():
                if len(ms2_list) == 1:
                    single_ms2_molecules.append(mol_id)
                else:
                    multi_ms2_molecules.append(mol_id)

            print(f" Found {len(single_ms2_molecules)} molecules with single MS2 spectrum")
            print(f" Found {len(multi_ms2_molecules)} molecules with multiple MS2 spectra")

            test_ms2_ids = []

            # Molecules with several spectra: randomly pick one for the test
            # set. Molecules with a single spectrum contribute only to training
            # so every test molecule still has training coverage.
            for mol_id in multi_ms2_molecules:
                ms2_list = molecule_to_ms2[mol_id]
                # Randomly choose one MS2 spectrum as the test sample.
                # NOTE(review): np.random.choice uses global numpy RNG state;
                # the split is only reproducible via the saved JSON file.
                test_ms2_id = np.random.choice(ms2_list)
                test_ms2_ids.append(test_ms2_id)

            print(f" Created new test set with {len(test_ms2_ids)} MS2 spectra")

            # Persist the split (only test_ms2_ids is stored; the training set
            # is always derived dynamically as "everything else").
            print(f" Saving new split to: {split_file_path}")
            split_data_to_save = {
                'test_ms2_ids': test_ms2_ids,
                'note': 'Training set uses all MS2 IDs not in test_ms2_ids'
            }
            split_file_path.parent.mkdir(parents=True, exist_ok=True)
            with open(split_file_path, 'w', encoding='utf-8') as f:
                json.dump(split_data_to_save, f, indent=4)
            print(" Split file saved successfully.")

        # --- 3. Create train set: ALL MS2 IDs except those in test set ---
        test_ms2_ids_set = set(test_ms2_ids)
        all_ms2_ids = set(ms2_data.keys())
        train_ms2_ids_set = all_ms2_ids - test_ms2_ids_set

        print(f"\n📊 Dataset Statistics:")
        print(f" Total MS2 spectra: {len(all_ms2_ids)}")
        print(f" Test MS2 spectra: {len(test_ms2_ids_set)}")
        print(f" Training MS2 spectra: {len(train_ms2_ids_set)}")

        # Partition the MS2 data.
        train_ms2_data = {ms2_id: info for ms2_id, info in ms2_data.items() if ms2_id in train_ms2_ids_set}
        test_ms2_data = {ms2_id: info for ms2_id, info in ms2_data.items() if ms2_id in test_ms2_ids_set}

        # Collect the molecule IDs that actually appear in each partition.
        train_molecule_ids = set(info['molecule_id'] for info in train_ms2_data.values())
        test_molecule_ids = set(info['molecule_id'] for info in test_ms2_data.values())

        # Restrict metadata and biotext data to those molecules.
        train_meta_data = meta_data[meta_data.index.isin(train_molecule_ids)]
        test_meta_data = meta_data[meta_data.index.isin(test_molecule_ids)]

        train_biotext_data = {mol_id: text for mol_id, text in biotext_data.items() if mol_id in train_molecule_ids}
        test_biotext_data = {mol_id: text for mol_id, text in biotext_data.items() if mol_id in test_molecule_ids}

        print(f"\nData filtering completed:")
        print(f" Training: {len(train_ms2_data)} MS2 spectra from {len(train_molecule_ids)} molecules")
        print(f" Test: {len(test_ms2_data)} MS2 spectra from {len(test_molecule_ids)} molecules")

        # --- 4. Create Dataset instances ---
        print("\nCreating training Dataset instance...")
        train_dataset = MS2BioTextDataset(
            ms2_data=train_ms2_data,
            meta_data=train_meta_data,
            biotext_data=train_biotext_data,
            tokenizer=tokenizer,
            use_paraphrase=use_paraphrase,
            word2idx=word2idx,
            args=args,
            split='train',
            **kwargs
        )

        print("Creating test Dataset instance...")
        test_dataset = MS2BioTextDataset(
            ms2_data=test_ms2_data,
            meta_data=test_meta_data,
            biotext_data=test_biotext_data,
            tokenizer=tokenizer,
            use_paraphrase=False,  # evaluation always uses the original texts
            word2idx=word2idx,
            args=args,
            split='test',
            **kwargs
        )

        return train_dataset, test_dataset
2898
+
2899
+
2900
+
2901
+
2902
    # At the top of the file, after the import statements and before the class definitions
2903
+
2904
+ def _augment_worker(batch_data):
2905
+ """
2906
+ 全局 worker 函数,用于数据增强
2907
+ batch_data: (batch, filter_threshold, noise_ratio, noise_intensity_range, augment_multiplier)
2908
+ """
2909
+ import numpy as np
2910
+ # ⚠️ 关键修改:不要导入 MS2BioTextDataset,直接在这里定义需要的函数
2911
+
2912
+ batch, filter_threshold, noise_ratio, noise_intensity_range, augment_multiplier = batch_data
2913
+
2914
+ # 将 filter_low_intensity_peaks 和 add_noise_peaks 的逻辑直接复制到这里
2915
+ def filter_low_intensity_peaks(peaks, intensities, threshold):
2916
+ """过滤低强度峰"""
2917
+ if not peaks or not intensities:
2918
+ return peaks, intensities
2919
+
2920
+ max_intensity = max(intensities)
2921
+ if max_intensity == 0:
2922
+ return peaks, intensities
2923
+
2924
+ filtered_peaks = []
2925
+ filtered_intensities = []
2926
+ for mz, intensity in zip(peaks, intensities):
2927
+ if intensity / max_intensity >= threshold:
2928
+ filtered_peaks.append(mz)
2929
+ filtered_intensities.append(intensity)
2930
+
2931
+ return filtered_peaks, filtered_intensities
2932
+
2933
+ def add_noise_peaks(peaks, intensities, noise_ratio, noise_intensity_range):
2934
+ """添加噪声峰"""
2935
+ import random
2936
+
2937
+ if not peaks:
2938
+ return peaks, intensities
2939
+
2940
+ n_noise = int(len(peaks) * noise_ratio)
2941
+ if n_noise == 0:
2942
+ return peaks, intensities
2943
+
2944
+ # 获取m/z范围
2945
+ min_mz = min(peaks)
2946
+ max_mz = max(peaks)
2947
+ max_intensity = max(intensities) if intensities else 1.0
2948
+
2949
+ # 生成噪声峰
2950
+ for _ in range(n_noise):
2951
+ # 随机m/z(避免与现有峰重复)
2952
+ noise_mz = random.uniform(min_mz, max_mz)
2953
+ # 随机低强度
2954
+ noise_intensity = random.uniform(
2955
+ noise_intensity_range[0] * max_intensity,
2956
+ noise_intensity_range[1] * max_intensity
2957
+ )
2958
+
2959
+ peaks.append(noise_mz)
2960
+ intensities.append(noise_intensity)
2961
+
2962
+ # 按m/z排序
2963
+ sorted_pairs = sorted(zip(peaks, intensities), key=lambda x: x[0])
2964
+ peaks = [p for p, _ in sorted_pairs]
2965
+ intensities = [i for _, i in sorted_pairs]
2966
+
2967
+ return peaks, intensities
2968
+
2969
+ # 处理逻辑
2970
+ result = {}
2971
+ for ms2_id, info in batch:
2972
+ molecule_id = info.get('molecule_id')
2973
+ peaks_original = info['mz'] if isinstance(info['mz'], list) else list(info['mz'])
2974
+ intensities_original = info['intensity'] if isinstance(info['intensity'], list) else list(info['intensity'])
2975
+
2976
+ # 过滤
2977
+ if filter_threshold:
2978
+ peaks_original, intensities_original = filter_low_intensity_peaks(
2979
+ peaks_original, intensities_original, filter_threshold
2980
+ )
2981
+
2982
+ # 原始版本
2983
+ result[ms2_id] = {
2984
+ 'mz': peaks_original,
2985
+ 'intensity': intensities_original,
2986
+ 'molecule_id': molecule_id
2987
+ }
2988
+
2989
+ # 增强版本
2990
+ for aug_idx in range(1, augment_multiplier):
2991
+ peaks_aug, intensities_aug = add_noise_peaks(
2992
+ peaks_original.copy(), intensities_original.copy(),
2993
+ noise_ratio,
2994
+ noise_intensity_range
2995
+ )
2996
+ result[f"{ms2_id}_aug{aug_idx}"] = {
2997
+ 'mz': peaks_aug,
2998
+ 'intensity': intensities_aug,
2999
+ 'molecule_id': molecule_id
3000
+ }
3001
+ return result
3002
+
3003
    def _preprocess_worker(batch_data):
        """
        Self-contained worker that tokenizes and normalizes a batch of MS2 spectra.

        batch_data: (batch, word2idx, meta_data_processed, maxlen, max_frag,
                     min_peaks, precursor_mode, precursor_value)
            batch: iterable of (ms2_id, info) with info['mz'], info['intensity'],
                   info['molecule_id'].
            word2idx: vocabulary mapping "%.2f"-formatted m/z strings (and
                      '[PAD]') to token ids.
            meta_data_processed: pandas DataFrame with (at least) polarity and
                      precursor information; matched by 'file_name', 'HMDB.ID'
                      or index.
            maxlen / max_frag / min_peaks: sequence length, top-K fragments,
                      minimum surviving peaks.
            precursor_mode / precursor_value: intensity handling scheme.

        Returns:
            (result, stats): result maps ms2_id -> {'mz': token id list,
            'intensity': padded intensity list, 'molecule_id': ...};
            stats counts 'kept' and 'filtered' spectra.
        """
        import numpy as np
        import pandas as pd

        batch, word2idx, meta_data_processed, maxlen, max_frag, min_peaks, precursor_mode, precursor_value = batch_data

        result = {}
        stats = {'kept': 0, 'filtered': 0}

        for ms2_id, info in batch:
            # Basic sanity check: skip spectra without any peaks.
            if not info.get('mz'):
                stats['filtered'] += 1
                continue

            # Convert to numpy arrays.
            peaks = np.asarray(info['mz'], dtype=float)
            intensities = np.asarray(info['intensity'], dtype=float)
            molecule_id = info.get('molecule_id', None)

            # Align lengths if the m/z and intensity arrays disagree.
            if peaks.shape[0] != intensities.shape[0]:
                n = min(len(peaks), len(intensities))
                peaks = peaks[:n]
                intensities = intensities[:n]

            # Metadata row matched by file name; used for the polarity check.
            specific_row = meta_data_processed.loc[meta_data_processed["file_name"] == ms2_id] if "file_name" in meta_data_processed.columns else pd.DataFrame()
            if specific_row.empty:
                # Fall back to matching by molecule id (covers ids without a
                # file_name row, e.g. augmented copies named "<id>_aug<k>").
                if molecule_id is not None:
                    if 'HMDB.ID' in meta_data_processed.columns:
                        specific_row = meta_data_processed.loc[meta_data_processed['HMDB.ID'] == molecule_id]
                    else:
                        specific_row = meta_data_processed.loc[meta_data_processed.index == molecule_id]

            if specific_row.empty:
                stats['filtered'] += 1
                continue

            # Keep positive-ion-mode spectra only.
            pol = str(specific_row["Polarity"].values[0]).lower().strip() if "Polarity" in specific_row.columns else ""
            if pol != "positive":
                stats['filtered'] += 1
                continue

            # Look up the row holding the precursor mass for this molecule.
            if 'HMDB.ID' in meta_data_processed.columns and (molecule_id is not None):
                row = meta_data_processed.loc[meta_data_processed['HMDB.ID'] == molecule_id]
            else:
                row = meta_data_processed.loc[meta_data_processed.index == molecule_id]

            if row.empty or ('precursor_mass' not in row.columns):
                stats['filtered'] += 1
                continue

            try:
                precursor_val = float(row['precursor_mass'].values[0])
            except Exception:
                # Non-numeric / missing precursor mass: drop the spectrum.
                stats['filtered'] += 1
                continue

            # Precursor must lie in [10, 1000).
            if pd.isna(precursor_val) or (precursor_val < 10.0) or (precursor_val >= 1000.0):
                stats['filtered'] += 1
                continue

            # Clamp to 999.99 so the "%.2f" token stays below "1000.00".
            precursor_val = min(precursor_val, 999.99)
            precursor_str = "%.2f" % precursor_val

            # Keep only finite fragment peaks inside [10, 1000).
            mask = (peaks >= 10.0) & (peaks < 1000.0) & np.isfinite(peaks) & np.isfinite(intensities)
            peaks = peaks[mask]
            intensities = intensities[mask]

            if peaks.size == 0:
                stats['filtered'] += 1
                continue

            # Select the top-K fragments by intensity, then order them by m/z.
            if peaks.size > max_frag:
                idx = np.argpartition(intensities, -max_frag)[-max_frag:]
                order = np.argsort(peaks[idx])
                idx = idx[order]
                peaks_sel = peaks[idx]
                intens_sel = intensities[idx]
            else:
                order = np.argsort(peaks)
                peaks_sel = peaks[order]
                intens_sel = intensities[order]

            # Enforce the minimum number of surviving peaks.
            if peaks_sel.size < min_peaks:
                stats['filtered'] += 1
                continue

            # Build the token sequence: precursor token first, then fragments.
            peaks_str = ["%.2f" % p for p in peaks_sel]
            try:
                token_ids = [word2idx[precursor_str]] + [word2idx[p] for p in peaks_str]
            except KeyError:
                # Any m/z string outside the vocabulary drops the spectrum.
                stats['filtered'] += 1
                continue

            # Choose the intensity handling according to precursor_mode.
            if precursor_mode == 'scale_fixed':
                # Mode 1: scale fragments to the fixed value `precursor_value`
                # (e.g. 20000), then prepend 2.0 for the precursor slot.
                if np.max(intens_sel) > 0:
                    intens_sel = intens_sel / np.max(intens_sel) * precursor_value
                intens_seq = np.hstack((2.0, intens_sel))
                # Normalize the whole sequence.
                max_intensity = float(np.max(intens_seq))
                if max_intensity > 0:
                    intens_seq = intens_seq / max_intensity

            elif precursor_mode == 'normalize_add':
                # Mode 2: normalize fragments to 1, prepend `precursor_value`,
                # then normalize the whole sequence again.
                if np.max(intens_sel) > 0:
                    intens_sel = intens_sel / np.max(intens_sel)
                intens_seq = np.hstack((precursor_value, intens_sel))
                # Normalize the whole sequence.
                max_intensity = float(np.max(intens_seq))
                if max_intensity > 0:
                    intens_seq = intens_seq / max_intensity

            else:
                # Default: original MSBERT scheme (precursor slot fixed at 2.0).
                intens_seq = np.hstack((2.0, intens_sel))
                max_intensity = float(np.max(intens_seq))
                if max_intensity > 0:
                    intens_seq = intens_seq / max_intensity

            # Pad or truncate both sequences to `maxlen`.
            if len(token_ids) > maxlen:
                token_ids = token_ids[:maxlen]
                intens_seq = intens_seq[:maxlen]

            n_pad = maxlen - len(token_ids)
            if n_pad > 0:
                token_ids += [word2idx['[PAD]']] * n_pad
                intens_seq = np.hstack([intens_seq, np.zeros(n_pad, dtype=float)])

            result[ms2_id] = {
                'mz': token_ids,
                'intensity': intens_seq.tolist(),
                'molecule_id': molecule_id
            }
            stats['kept'] += 1

        return result, stats
3156
+
3157
+
3158
+
3159
+