spec2function 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- Spec2Function/MS2BioTextDataset.py +3159 -0
- Spec2Function/__init__.py +17 -0
- Spec2Function/assets.py +74 -0
- Spec2Function/biotext_processor.py +380 -0
- Spec2Function/config.py +118 -0
- Spec2Function/data_augmentation.py +354 -0
- Spec2Function/gpt_inference.py +739 -0
- Spec2Function/llm_client.py +114 -0
- Spec2Function/model/MS2BioText.py +522 -0
- Spec2Function/model/MSBERT.py +261 -0
- Spec2Function/model/__init__.py +56 -0
- Spec2Function/model/config.py +249 -0
- Spec2Function/model/utils.py +167 -0
- Spec2Function/model_manager.py +1102 -0
- Spec2Function/pubmed.py +251 -0
- Spec2Function/read_raw_data.py +154 -0
- Spec2Function/utils.py +216 -0
- Spec2Function/workflow.py +233 -0
- spec2function-0.1.1.dist-info/METADATA +91 -0
- spec2function-0.1.1.dist-info/RECORD +23 -0
- spec2function-0.1.1.dist-info/WHEEL +5 -0
- spec2function-0.1.1.dist-info/licenses/LICENSE +21 -0
- spec2function-0.1.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,3159 @@
|
|
|
1
|
+
# 所以MS2BioText的输入应该是包括MS2与BioText。然而一个BioText对应一个分子,一个分子对应多个MS2。所以在dataset构建的时候,需要储存:
|
|
2
|
+
# MS2的列表
|
|
3
|
+
# MS2对应的分子的列表
|
|
4
|
+
# 分子对应的BioText的列表
|
|
5
|
+
# 最后item的时候返回input_ids(m/z),intensity,BioText
|
|
6
|
+
# (这里添加一系列类内方法,在init的时候输入参数可以选择处理BioText的类内方法)
|
|
7
|
+
# 但是问题是但是在之后的实验里,还要记录对于每个MS2的其他信息,这个怎么储存。
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# 先完成dataset然后创建实例
|
|
11
|
+
# 读取HMDB数据集
|
|
12
|
+
# HMDB.h5是MS2数据读取为list?,HMDB.parquet是meta数据,读取为?然后Biotext是一个文件夹下包含的一系列txt文件,文件名为HMDB的id,读取为?
|
|
13
|
+
# ,读取test data跑通两个模型试试
|
|
14
|
+
|
|
15
|
+
import pickle
|
|
16
|
+
import h5py
|
|
17
|
+
import pandas as pd
|
|
18
|
+
import os
|
|
19
|
+
import torch
|
|
20
|
+
from torch.utils.data import Dataset
|
|
21
|
+
import numpy as np
|
|
22
|
+
import random
|
|
23
|
+
from pathlib import Path
|
|
24
|
+
from sklearn.model_selection import train_test_split
|
|
25
|
+
import json
|
|
26
|
+
from torch.utils.data import Sampler
|
|
27
|
+
import random, itertools
|
|
28
|
+
import math
|
|
29
|
+
from collections.abc import Iterator
|
|
30
|
+
from typing import Optional, TypeVar, Dict, List, Tuple
|
|
31
|
+
import random
|
|
32
|
+
import argparse
|
|
33
|
+
import torch
|
|
34
|
+
import torch.distributed as dist
|
|
35
|
+
from torch.utils.data.dataset import Dataset
|
|
36
|
+
from torch.utils.data.sampler import Sampler
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# Covariant element type for the Sampler generic.
_T_co = TypeVar("_T_co", covariant=True)


class MS2MoleculeDistributedSampler(Sampler[_T_co]):
    """
    Distributed sampler designed for MS2 contrastive learning.

    Added feature: assign a non-conflicting text to each sample in a batch
    - Ensure the text chosen for each molecule in a batch does not appear in any
      other molecule's candidate list
    - Guarantee that only the diagonal of the in-batch similarity matrix holds
      positive pairs
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
    ) -> None:
        # Fall back to the torch.distributed process group for world size / rank,
        # mirroring torch.utils.data.DistributedSampler's behavior.
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
            )

        self.dataset = dataset
        self.batch_size = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        self.shuffle = shuffle
        self.seed = seed

        # Build the molecule -> candidate texts mapping.
        self.molecule_to_texts = self._build_molecule_text_mapping()

        # Group dataset indices by molecule and sort molecules by MS2 count.
        self.molecule_groups = self._group_by_molecule()
        self.sorted_molecules = self._sort_molecules_by_ms2_count()

        # Pre-compute the batch allocation plan.
        self.batch_indices = self._create_batch_allocation()

        # Make the number of batches divisible by the number of GPUs.
        self._adjust_for_distributed()

        # Number of samples each replica will yield per epoch.
        total_batches = len(self.batch_indices)
        batches_per_replica = total_batches // self.num_replicas
        self.num_samples = batches_per_replica * self.batch_size

    def _build_molecule_text_mapping(self) -> Dict[str, List[str]]:
        """Build the candidate text list for every molecule.

        Supports three storage layouts observed in ``dataset.biotext_data``:
        a list of record dicts with a 'text' key, a dict with
        'original'/'paraphrases' keys, or a bare string.
        """
        molecule_to_texts = {}

        for mol_id, entry in self.dataset.biotext_data.items():
            texts = []
            if isinstance(entry, list):
                texts = [record['text'] for record in entry]
            elif isinstance(entry, dict):
                original = entry.get("original", "")
                paraphrases = entry.get("paraphrases", [])
                texts = [original] + paraphrases
            elif isinstance(entry, str):
                texts = [entry]

            molecule_to_texts[mol_id] = texts

        print(f"Built text mapping for {len(molecule_to_texts)} molecules")
        return molecule_to_texts

    def _group_by_molecule(self) -> Dict[str, List[int]]:
        """Group MS2 dataset indices by their molecule ID."""
        molecule_groups = {}
        for idx in range(len(self.dataset)):
            ms2_id = self.dataset.ms2_ids[idx]
            molecule_id = self.dataset.preprocessed_ms2_tensors[ms2_id]['molecule_id']

            if molecule_id not in molecule_groups:
                molecule_groups[molecule_id] = []
            molecule_groups[molecule_id].append(idx)

        return molecule_groups

    def _sort_molecules_by_ms2_count(self) -> List[Tuple[str, List[int]]]:
        """Sort molecules by their number of MS2 spectra, descending.

        NOTE(review): raises IndexError if the dataset is empty — presumably
        never the case in practice; confirm upstream.
        """
        molecule_items = [(mol_id, indices) for mol_id, indices in self.molecule_groups.items()]
        sorted_items = sorted(molecule_items, key=lambda x: len(x[1]), reverse=True)

        print(f"Molecule MS2 count distribution:")
        print(f"Max MS2 per molecule: {len(sorted_items[0][1])}")
        print(f"Min MS2 per molecule: {len(sorted_items[-1][1])}")
        print(f"Total molecules: {len(sorted_items)}")
        print(f"Total MS2 spectra: {sum(len(indices) for _, indices in sorted_items)}")

        return sorted_items

    def _assign_texts_for_batch(self, batch_molecule_ids: List[str]) -> Dict[str, int]:
        """
        Assign a text index to every molecule in the batch.
        Ensures the chosen text does not appear in any other molecule's
        candidate list; falls back (with warnings) when that is impossible.

        Returns: {molecule_id: text_index}

        NOTE(review): uses the module-level ``random`` without an explicit
        seed, so assignments are not reproducible across runs/ranks — confirm
        this is intended.
        """
        mol_to_text_idx = {}
        occupied_texts = set()  # texts already claimed within this batch

        # Process molecules with the fewest candidates first, since they have
        # the smallest choice space.
        sorted_mols = sorted(
            batch_molecule_ids,
            key=lambda mol_id: len(self.molecule_to_texts.get(mol_id, []))
        )

        for mol_id in sorted_mols:
            candidate_texts = self.molecule_to_texts.get(mol_id, [])

            if not candidate_texts:
                print(f"⚠️ Warning: Molecule {mol_id} has no candidate texts")
                mol_to_text_idx[mol_id] = 0
                continue

            # Collect texts that are unclaimed AND absent from every other
            # molecule's candidate list in this batch.
            available_indices = []
            for i, text in enumerate(candidate_texts):
                if text not in occupied_texts:
                    # Check whether this text appears among other molecules' candidates.
                    text_in_others = False
                    for other_mol_id in batch_molecule_ids:
                        if other_mol_id != mol_id:
                            other_texts = self.molecule_to_texts.get(other_mol_id, [])
                            if text in other_texts:
                                text_in_others = True
                                break

                    if not text_in_others:
                        available_indices.append(i)

            # If a conflict-free text exists, pick one at random.
            if available_indices:
                chosen_idx = random.choice(available_indices)
                mol_to_text_idx[mol_id] = chosen_idx
                occupied_texts.add(candidate_texts[chosen_idx])
            else:
                # Fallback: pick any unclaimed text (it may still appear in
                # another molecule's candidate list).
                fallback_indices = [i for i, text in enumerate(candidate_texts)
                                    if text not in occupied_texts]
                if fallback_indices:
                    chosen_idx = random.choice(fallback_indices)
                    mol_to_text_idx[mol_id] = chosen_idx
                    occupied_texts.add(candidate_texts[chosen_idx])
                    print(f"⚠️ Fallback: Molecule {mol_id} text may conflict with others")
                else:
                    # Extreme case: every candidate text is already claimed.
                    chosen_idx = random.randint(0, len(candidate_texts) - 1)
                    mol_to_text_idx[mol_id] = chosen_idx
                    print(f"⚠️ Extreme fallback: All texts occupied for {mol_id}")

        return mol_to_text_idx

    def _create_batch_allocation(self) -> List[List[int]]:
        """Create the batch allocation plan.

        Greedily fills each batch with at most one MS2 spectrum per molecule
        (keeping in-batch positives unique), then pads incomplete batches by
        re-sampling spectra from molecules not yet used in the batch.
        NOTE(review): the padding paths draw via ``random.choice`` from a
        molecule's full index list, so a spectrum can repeat across batches
        within an epoch — presumably an accepted trade-off; confirm.
        """
        # Track per-molecule consumption of its (ordered) index list.
        molecule_ms2_usage = {}
        for mol_id, indices in self.sorted_molecules:
            molecule_ms2_usage[mol_id] = {
                'indices': indices.copy(),
                'used_count': 0
            }

        batch_indices = []
        total_ms2_count = sum(len(indices) for _, indices in self.sorted_molecules)
        used_ms2_count = 0

        print(f"Starting batch allocation for {total_ms2_count} MS2 spectra...")

        while used_ms2_count < total_ms2_count:
            current_batch = []
            used_molecules_in_batch = set()

            # First pass: take the next unused spectrum from each molecule.
            for mol_id, mol_data in molecule_ms2_usage.items():
                if len(current_batch) >= self.batch_size:
                    break

                if mol_id in used_molecules_in_batch:
                    continue

                if mol_data['used_count'] < len(mol_data['indices']):
                    ms2_idx = mol_data['indices'][mol_data['used_count']]
                    current_batch.append(ms2_idx)
                    used_molecules_in_batch.add(mol_id)
                    mol_data['used_count'] += 1
                    used_ms2_count += 1

            # Second pass: pad with spectra from molecules not in this batch,
            # preferring molecules with the most spectra.
            if len(current_batch) < self.batch_size and len(current_batch) > 0:
                available_molecules = [mol_id for mol_id in molecule_ms2_usage.keys()
                                       if mol_id not in used_molecules_in_batch]

                while len(current_batch) < self.batch_size and available_molecules:
                    available_molecules.sort(key=lambda x: len(molecule_ms2_usage[x]['indices']),
                                             reverse=True)

                    mol_id = available_molecules[0]
                    mol_data = molecule_ms2_usage[mol_id]

                    ms2_idx = random.choice(mol_data['indices'])
                    current_batch.append(ms2_idx)
                    used_molecules_in_batch.add(mol_id)
                    available_molecules.remove(mol_id)

            if len(current_batch) == 0:
                break

            # Final padding: still short of a full batch.
            if len(current_batch) < self.batch_size:
                if self.drop_last:
                    print(f"Dropping incomplete batch with {len(current_batch)} samples")
                    break
                else:
                    while len(current_batch) < self.batch_size:
                        available_molecules = [mol_id for mol_id in molecule_ms2_usage.keys()
                                               if mol_id not in used_molecules_in_batch]

                        if not available_molecules:
                            print(f"Cannot fill batch further: only {len(self.sorted_molecules)} unique molecules available")
                            break

                        mol_id = random.choice(available_molecules)
                        mol_data = molecule_ms2_usage[mol_id]

                        ms2_idx = random.choice(mol_data['indices'])
                        current_batch.append(ms2_idx)
                        used_molecules_in_batch.add(mol_id)

            batch_indices.append(current_batch)

        print(f"Created {len(batch_indices)} batches")
        print(f"Used {used_ms2_count} MS2 spectra out of {total_ms2_count}")

        return batch_indices

    def _adjust_for_distributed(self):
        """Adjust the batch count so it is divisible by the number of GPUs.

        With ``drop_last`` the trailing remainder batches are removed;
        otherwise existing batches are duplicated (copied) to pad up.
        """
        total_batches = len(self.batch_indices)
        remainder = total_batches % self.num_replicas

        if remainder != 0:
            if self.drop_last:
                batches_to_remove = remainder
                self.batch_indices = self.batch_indices[:-batches_to_remove]
                print(f"Dropped {batches_to_remove} batches to ensure divisibility by {self.num_replicas} GPUs")
            else:
                batches_to_add = self.num_replicas - remainder
                for i in range(batches_to_add):
                    batch_to_copy = self.batch_indices[i % len(self.batch_indices)]
                    self.batch_indices.append(batch_to_copy.copy())
                print(f"Added {batches_to_add} batches to ensure divisibility by {self.num_replicas} GPUs")

        final_batches = len(self.batch_indices)
        print(f"Final batch count: {final_batches} (divisible by {self.num_replicas} GPUs)")
        print(f"Each GPU will process {final_batches // self.num_replicas} batches")

    def __iter__(self) -> Iterator[_T_co]:
        """Yield this rank's dataset indices, batch by batch (flattened).

        Side effect: rewrites ``self.dataset.text_assignment`` so that
        ``__getitem__`` can look up the text index chosen for each MS2 ID.
        """
        # Slice out the contiguous range of batches owned by this rank.
        total_batches = len(self.batch_indices)
        batches_per_replica = total_batches // self.num_replicas

        start_batch = self.rank * batches_per_replica
        end_batch = start_batch + batches_per_replica

        my_batches = self.batch_indices[start_batch:end_batch]

        if self.shuffle:
            # Deterministic per-epoch shuffle of batch order.
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            batch_order = torch.randperm(len(my_batches), generator=g).tolist()
            my_batches = [my_batches[i] for i in batch_order]

            # NOTE(review): a fresh Random with the same seed is created per
            # batch, so every batch is permuted by the identical sequence —
            # confirm this intra-batch shuffle is what was intended.
            for batch in my_batches:
                random.Random(self.seed + self.epoch).shuffle(batch)

        # *** Key step: assign a text to every sample of every batch ***
        self.dataset.text_assignment.clear()

        for batch_indices in my_batches:
            # Collect the molecule IDs appearing in this batch.
            batch_molecule_ids = []
            ms2_id_to_mol_id = {}

            for idx in batch_indices:
                ms2_id = self.dataset.ms2_ids[idx]
                mol_id = self.dataset.preprocessed_ms2_tensors[ms2_id]['molecule_id']
                batch_molecule_ids.append(mol_id)
                ms2_id_to_mol_id[ms2_id] = mol_id

            # Pick a conflict-free text index per molecule for this batch.
            mol_to_text_idx = self._assign_texts_for_batch(batch_molecule_ids)

            # Publish the assignment for the dataset's __getitem__ to consume.
            for idx in batch_indices:
                ms2_id = self.dataset.ms2_ids[idx]
                mol_id = ms2_id_to_mol_id[ms2_id]
                self.dataset.text_assignment[ms2_id] = mol_to_text_idx.get(mol_id, 0)

        # Flatten the batches into one index stream.
        all_indices = []
        for batch in my_batches:
            all_indices.extend(batch)

        return iter(all_indices)

    def __len__(self) -> int:
        # Per-replica sample count (batches_per_replica * batch_size).
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        """Set the current epoch so each epoch gets a different shuffle."""
        self.epoch = epoch
+
|
|
364
|
+
|
|
365
|
+
# function for dataset augmentation
|
|
366
|
+
import torch
|
|
367
|
+
import numpy as np
|
|
368
|
+
import random
|
|
369
|
+
import torch
|
|
370
|
+
import numpy as np
|
|
371
|
+
import random
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
def sample_truncated_normal(mean, std, low, high):
    """Draw a single sample from a normal distribution truncated to [low, high].

    Uses rejection sampling for up to 1000 draws; if no draw lands inside the
    interval, one final draw is clipped into [low, high].

    Args:
        mean: mean of the underlying normal distribution
        std: standard deviation of the underlying normal distribution
        low: lower bound of the accepted interval
        high: upper bound of the accepted interval

    Returns:
        A float within [low, high].
    """
    for _attempt in range(1000):
        draw = np.random.normal(mean, std)
        if low <= draw <= high:
            return draw
    # Rejection budget exhausted: clip one last draw into range.
    return np.clip(np.random.normal(mean, std), low, high)
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
def augment_tokenized_ms2_optimized(mz_tokens, intensity, word2idx, args):
    """
    Dynamic augmentation tuned to the characteristics of the External data
    (supports a randomized noise ratio).

    Decodes tokenized m/z values back to floats, injects synthetic "proximal"
    (near a signal peak) and "isolated" (far from all signal peaks) noise
    peaks, then re-tokenizes, optionally filters by relative intensity, and
    truncates to ``args.maxlen`` strongest peaks.

    Args:
        mz_tokens: 1-D tensor of m/z token ids (vocabulary encodes m/z
            strings at 0.01 precision; '[PAD]'/'[MASK]' are special tokens).
        intensity: tensor of peak intensities; a 2-D (1, N) tensor is
            squeezed to 1-D.
        word2idx: vocabulary mapping m/z string -> token id.
        args: namespace of augmentation hyper-parameters (all read via
            ``getattr`` with defaults, so any may be absent).

    Returns:
        (augmented_mz, augmented_intensity) — augmented_intensity carries a
        leading batch dim of 1. When augmentation is skipped the inputs are
        returned unchanged (intensity possibly squeezed to 1-D).
    """
    # ===== 0. Read parameters and decide whether to augment at all =====
    augment_prob = getattr(args, 'augment_prob', 0.5)
    if random.random() > augment_prob:
        return mz_tokens, intensity

    # Ensure intensity is 1-D.
    if intensity.dim() == 2:
        intensity = intensity.squeeze(0)

    device = mz_tokens.device

    # ===== 1. Build the token-id -> m/z mapping =====
    idx2word = {v: k for k, v in word2idx.items()}

    def token_to_mz(token_id):
        """Convert a token id to its actual m/z value (None for special/unknown tokens)."""
        word = idx2word.get(token_id.item(), None)
        if word and word not in ['[PAD]', '[MASK]']:
            try:
                return float(word)
            except ValueError:
                return None
        return None

    # Work in numpy from here on.
    mz_tokens_np = mz_tokens.cpu().numpy()
    intensity_np = intensity.cpu().numpy()

    # Decode real m/z values, skipping special tokens.
    original_mz = []
    original_intensity = []
    for i, token_id in enumerate(mz_tokens_np):
        mz_val = token_to_mz(torch.tensor(token_id))
        if mz_val is not None:
            original_mz.append(mz_val)
            original_intensity.append(intensity_np[i])

    if len(original_mz) == 0:
        return mz_tokens, intensity

    original_mz = np.array(original_mz)
    original_intensity = np.array(original_intensity)
    max_intensity = original_intensity.max()

    # ===== 2. Identify signal peaks (intensity > 5% of the max) =====
    signal_threshold = 0.05
    signal_mask = original_intensity >= signal_threshold * max_intensity
    signal_mz = original_mz[signal_mask]
    n_signal = len(signal_mz)

    if n_signal == 0:
        return mz_tokens, intensity

    # ===== 3. Randomize the noise parameters =====
    align_to_external = getattr(args, 'align_to_external', False)
    randomize_noise_ratio = getattr(args, 'randomize_noise_ratio', True)  # toggle for random ratio
    noise_sampling_strategy = getattr(args, 'noise_sampling_strategy', 'uniform')  # ratio sampling strategy

    if align_to_external:
        # Randomized noise ratio.
        if randomize_noise_ratio:
            if noise_sampling_strategy == 'uniform':
                # Strategy 1: uniform distribution.
                noise_ratio_range = getattr(args, 'noise_ratio_range', [0.60, 0.90])
                TARGET_NOISE_RATIO = np.random.uniform(noise_ratio_range[0], noise_ratio_range[1])

            elif noise_sampling_strategy == 'normal':
                # Strategy 2: normal distribution, clipped to [0.40, 0.95].
                target_noise_ratio = getattr(args, 'target_noise_ratio', 0.80)
                noise_ratio_std = getattr(args, 'noise_ratio_std', 0.10)
                TARGET_NOISE_RATIO = np.random.normal(target_noise_ratio, noise_ratio_std)
                TARGET_NOISE_RATIO = np.clip(TARGET_NOISE_RATIO, 0.40, 0.95)

            elif noise_sampling_strategy == 'bimodal':
                # Strategy 3: bimodal distribution ("dirty" vs "clean" spectra).
                bimodal_dirty_prob = getattr(args, 'bimodal_dirty_prob', 0.7)
                if np.random.random() < bimodal_dirty_prob:
                    # Dirty-data mode.
                    TARGET_NOISE_RATIO = np.random.uniform(0.70, 0.90)
                else:
                    # Clean-data mode.
                    TARGET_NOISE_RATIO = np.random.uniform(0.30, 0.60)
            else:
                # Unknown strategy: fall back to the fixed value.
                TARGET_NOISE_RATIO = getattr(args, 'target_noise_ratio', 0.80)
        else:
            # No randomization: use the fixed value.
            TARGET_NOISE_RATIO = getattr(args, 'target_noise_ratio', 0.80)

        # Optional: randomize the proximal fraction of the noise.
        randomize_proximal_ratio = getattr(args, 'randomize_proximal_ratio', False)
        if randomize_proximal_ratio:
            proximal_ratio_range = getattr(args, 'proximal_ratio_range', [0.15, 0.22])
            PROXIMAL_RATIO_OF_NOISE = np.random.uniform(proximal_ratio_range[0], proximal_ratio_range[1])
        else:
            PROXIMAL_RATIO_OF_NOISE = getattr(args, 'proximal_ratio_of_noise', 0.18)

        # Intensity parameters (fractions of max_intensity).
        PROXIMAL_MEAN = getattr(args, 'proximal_intensity_mean', 0.0115)
        PROXIMAL_STD = getattr(args, 'proximal_intensity_std', 0.0138)
        ISOLATED_MEAN = getattr(args, 'isolated_intensity_mean', 0.0079)
        ISOLATED_STD = getattr(args, 'isolated_intensity_std', 0.0114)

        # Regional m/z weights for isolated-noise placement.
        use_regional_weighting = getattr(args, 'use_regional_weighting', True)
        if use_regional_weighting:
            REGION_WEIGHTS = [
                (0, 100, 0.24),
                (100, 200, 0.53),
                (200, 300, 0.18),
                (300, 500, 0.05)
            ]
        else:
            REGION_WEIGHTS = None
    else:
        # Parameters when NOT aligning to the External data.
        if randomize_noise_ratio:
            noise_ratio_range = getattr(args, 'noise_ratio_range', [0.30, 0.70])
            TARGET_NOISE_RATIO = np.random.uniform(noise_ratio_range[0], noise_ratio_range[1])
        else:
            TARGET_NOISE_RATIO = getattr(args, 'target_noise_ratio', 0.50)

        PROXIMAL_RATIO_OF_NOISE = 0.25
        PROXIMAL_MEAN = 0.0141
        PROXIMAL_STD = 0.0209
        ISOLATED_MEAN = 0.0091
        ISOLATED_STD = 0.0144
        REGION_WEIGHTS = None

    # Spatial distribution parameters.
    proximal_distance_range = getattr(args, 'proximal_distance_range', [-1.5, 1.5])
    isolated_min_distance = getattr(args, 'isolated_min_distance', 5.0)

    # ===== 4. Compute the total number of noise peaks to add =====
    # noise/(signal+noise) = TARGET_NOISE_RATIO  =>  noise = signal*r/(1-r).
    n_noise_total = int(n_signal * TARGET_NOISE_RATIO / (1 - TARGET_NOISE_RATIO))
    n_proximal = int(n_noise_total * PROXIMAL_RATIO_OF_NOISE)
    n_isolated = n_noise_total - n_proximal

    mz_min, mz_max = original_mz.min(), original_mz.max()

    # ===== 5. Generate proximal noise =====
    proximal_mz = []
    proximal_intensity = []

    for _ in range(n_proximal):
        # Pick a signal peak as the base.
        base_mz = np.random.choice(signal_mz)

        # Offset within the proximal distance range, clipped to the spectrum span.
        offset = np.random.uniform(proximal_distance_range[0], proximal_distance_range[1])
        noise_mz = base_mz + offset
        noise_mz = np.clip(noise_mz, mz_min, mz_max)

        # Sample an intensity from a truncated normal distribution.
        noise_intensity = sample_truncated_normal(
            PROXIMAL_MEAN, PROXIMAL_STD, 0.001, 0.10
        ) * max_intensity

        proximal_mz.append(noise_mz)
        proximal_intensity.append(noise_intensity)

    # ===== 6. Generate isolated noise (regionally weighted) =====
    isolated_mz = []
    isolated_intensity = []

    if REGION_WEIGHTS:
        # Regionally weighted placement.
        for low, high, weight in REGION_WEIGHTS:
            # Only generate within the spectrum's m/z span.
            region_low = max(low, mz_min)
            region_high = min(high, mz_max)

            if region_low >= region_high:
                continue

            n_in_region = int(n_isolated * weight)
            attempts = 0
            max_attempts = n_in_region * 10
            generated = 0

            while generated < n_in_region and attempts < max_attempts:
                candidate_mz = np.random.uniform(region_low, region_high)

                # Require the candidate to be far from every signal peak.
                min_dist = np.min(np.abs(signal_mz - candidate_mz))

                if min_dist >= isolated_min_distance:
                    # Slightly higher intensity in the 100-200 Da region.
                    if 100 <= candidate_mz <= 200:
                        mean_adj = ISOLATED_MEAN * 1.1
                    else:
                        mean_adj = ISOLATED_MEAN

                    noise_intensity = sample_truncated_normal(
                        mean_adj, ISOLATED_STD, 0.0001, 0.05
                    ) * max_intensity

                    isolated_mz.append(candidate_mz)
                    isolated_intensity.append(noise_intensity)
                    generated += 1

                attempts += 1
    else:
        # No regional weighting (plain uniform placement).
        attempts = 0
        max_attempts = n_isolated * 10
        generated = 0

        while generated < n_isolated and attempts < max_attempts:
            candidate_mz = np.random.uniform(mz_min, mz_max)
            min_dist = np.min(np.abs(signal_mz - candidate_mz))

            if min_dist >= isolated_min_distance:
                noise_intensity = sample_truncated_normal(
                    ISOLATED_MEAN, ISOLATED_STD, 0.0001, 0.05
                ) * max_intensity

                isolated_mz.append(candidate_mz)
                isolated_intensity.append(noise_intensity)
                generated += 1

            attempts += 1

    # ===== 7. Merge all peaks =====
    all_mz = np.concatenate([original_mz, proximal_mz, isolated_mz])
    all_intensity = np.concatenate([original_intensity, proximal_intensity, isolated_intensity])

    # ===== 8. Convert back to token ids =====
    def mz_to_token(mz_val):
        """Convert an m/z value to a token id (rounded to 0.01 precision)."""
        mz_rounded = round(mz_val, 2)
        mz_str = f"{mz_rounded:.2f}"
        # Unknown m/z strings fall back to the '[MASK]' token (id 1 if absent).
        return word2idx.get(mz_str, word2idx.get('[MASK]', 1))

    all_tokens = [mz_to_token(mz) for mz in all_mz]

    # Sort by m/z.
    sorted_idx = np.argsort(all_mz)
    all_tokens = np.array(all_tokens)[sorted_idx]
    all_intensity = all_intensity[sorted_idx]

    # ===== 9. Optional: relative-intensity filtering and truncation =====
    filter_threshold = getattr(args, 'filter_threshold', None)
    if filter_threshold and filter_threshold > 0:
        max_int = all_intensity.max()
        if max_int > 0:
            threshold = max_int * filter_threshold
            mask = all_intensity >= threshold
            all_tokens = all_tokens[mask]
            all_intensity = all_intensity[mask]

    # Truncate to maxlen peaks.
    maxlen = getattr(args, 'maxlen', 100)
    if len(all_tokens) > maxlen:
        # Keep the strongest peaks.
        top_indices = np.argsort(all_intensity)[-maxlen:]
        top_indices = np.sort(top_indices)  # restore m/z order
        all_tokens = all_tokens[top_indices]
        all_intensity = all_intensity[top_indices]

    # ===== 10. Convert back to tensors =====
    augmented_mz = torch.tensor(all_tokens, dtype=mz_tokens.dtype, device=device)
    augmented_intensity = torch.tensor(all_intensity, dtype=intensity.dtype, device=device).unsqueeze(0)

    return augmented_mz, augmented_intensity
|
|
667
|
+
|
|
668
|
+
|
|
669
|
+
|
|
670
|
+
|
|
671
|
+
class MS2BioTextDataset(Dataset):
|
|
672
|
+
    def __init__(self, ms2_data, meta_data, biotext_data, tokenizer, max_length=512,
                 use_paraphrase=False, use_mlm=False, use_ms2_prediction=False,
                 prediction_label_columns=None, word2idx=None, args=None, split='test'):
        """Dataset pairing MS2 spectra with biological text descriptions.

        Args:
            ms2_data: dict of ms2_id -> {'mz': ..., 'intensity': ..., 'molecule_id': ...}.
            meta_data: per-molecule metadata table; must expose ``.columns``
                (presumably a pandas DataFrame — confirm with callers) when
                ``use_ms2_prediction`` is enabled.
            biotext_data: dict of molecule_id -> texts (list of records, a
                dict with 'original'/'paraphrases', or a plain string).
            tokenizer: text tokenizer (used by the MLM helper).
            max_length: maximum tokenized text length.
            use_paraphrase: also return a second, different candidate text.
            use_mlm: enable masked-language-model input creation.
            use_ms2_prediction: enable the multilabel prediction task;
                requires ``prediction_label_columns``.
            prediction_label_columns: list of meta_data column names used as
                multilabel targets.
            word2idx: m/z vocabulary for spectrum augmentation.
            args: namespace of options; ``hard_neg_path`` (path to
                hard_negatives_v2.json) and ``num_hard_neg_per_sample`` are
                read from it for the train split.
            split: 'train' enables augmentation and hard-negative loading.

        Raises:
            ValueError: if MS2 prediction is enabled without a valid list of
                label columns, or a listed column is missing from meta_data.
        """
        self.ms2_data = ms2_data
        self.meta_data = meta_data
        self.biotext_data = biotext_data
        # Pre-convert every spectrum to tensors once, up front.
        self.preprocessed_ms2_tensors = {}

        for ms2_id, ms2_entry in self.ms2_data.items():
            self.preprocessed_ms2_tensors[ms2_id] = {
                'mz': torch.tensor(ms2_entry['mz'], dtype=torch.float32),
                'intensity': torch.tensor(ms2_entry['intensity'], dtype=torch.float32),
                'molecule_id': ms2_entry['molecule_id']
            }
        # Filled in by the distributed sampler each epoch: ms2_id -> text index.
        self.text_assignment = {}
        self.ms2_ids = list(ms2_data.keys())
        self.word2idx = word2idx
        self.args = args or argparse.Namespace()  # guarantee a non-None namespace
        self.split = split
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.use_mlm = use_mlm
        self.use_ms2_prediction = use_ms2_prediction
        self.prediction_label_columns = prediction_label_columns
        self.use_paraphrase = use_paraphrase

        # === NEW: Load hard negatives (train split only) ===
        self.hard_negatives = {}
        self.num_hard_neg = getattr(self.args, "num_hard_neg_per_sample", 0)
        hard_neg_path = getattr(self.args, "hard_neg_path", None)

        if hard_neg_path and os.path.exists(hard_neg_path) and split == 'train':
            import json
            with open(hard_neg_path, 'r', encoding='utf-8') as f:
                self.hard_negatives = json.load(f)
            print(f"Loaded hard negatives for {len(self.hard_negatives)} molecules")
            print(f"Using {self.num_hard_neg} hard negatives per sample")
        else:
            if split == 'train':
                print(f"No valid hard_neg_path found ({hard_neg_path}), skipping hard negatives.")

        # === MS2 prediction checks ===
        if self.use_ms2_prediction:
            if not self.prediction_label_columns or not isinstance(self.prediction_label_columns, list):
                raise ValueError("When MS2 prediction task is enabled, a list of column names 'prediction_label_columns' must be provided.")

            for col in self.prediction_label_columns:
                if col not in self.meta_data.columns:
                    raise ValueError(f"Column '{col}' is not found in meta_data.")

            self.num_ms2_classes = len(self.prediction_label_columns)
            print(f"Found {self.num_ms2_classes} label columns for multilabel prediction task: {self.prediction_label_columns}")
|
|
728
|
+
|
|
729
|
+
|
|
730
|
+
    def __len__(self):
        """Return the number of MS2 spectra in the dataset."""
        return len(self.ms2_ids)
|
|
732
|
+
|
|
733
|
+
    def _create_mlm_inputs(self, input_ids):
        """Create masked inputs and labels for the MLM task (BERT-style 80/10/10).

        Args:
            input_ids: 1-D LongTensor of token ids; mutated in place where
                tokens are replaced.

        Returns:
            (input_ids, labels) — labels equal the original ids at masked
            positions and -100 elsewhere (ignored by the loss).
        """
        labels = input_ids.clone()
        probability_matrix = torch.full(labels.shape, 0.15)  # mask 15% of tokens

        # Never mask special tokens (e.g., [CLS], [SEP], [PAD]).
        special_tokens_mask = self.tokenizer.get_special_tokens_mask(labels.tolist(), already_has_special_tokens=True)
        special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()

        # Unmasked positions get label -100 so the loss ignores them.
        labels[~masked_indices] = -100

        # 80% of masked positions are replaced with the [MASK] token.
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        input_ids[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # Of the remaining 20%, half (i.e. 10% overall) get a random token;
        # the rest stay unchanged.
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer.vocab), labels.shape, dtype=torch.long)
        input_ids[indices_random] = random_words[indices_random]

        return input_ids, labels
|
|
758
|
+
|
|
759
|
+
def __getitem__(self, idx):
    """Assemble one sample: MS2 tensors plus the tokenized BioText.

    Args:
        idx: integer index into ``self.ms2_ids``.

    Returns:
        dict with at least 'mz', 'intensity', 'text_input_ids',
        'text_attention_mask', 'all_candidate_texts', 'has_paraphrase',
        'ms2_id', 'molecule_id', 'original_text'; optional MLM and
        paraphrase keys depending on configuration.
    """
    ms2_id = self.ms2_ids[idx]
    tensor_data = self.preprocessed_ms2_tensors[ms2_id]
    mz = tensor_data['mz']
    intensity = tensor_data['intensity']

    # Dynamic augmentation (training split only).
    if self.split == 'train' and hasattr(self, 'word2idx'):
        mz, intensity = augment_tokenized_ms2_optimized(
            mz, intensity, self.word2idx, self.args
        )

    # BUG FIX: the batch was previously built from tensor_data directly,
    # silently discarding the augmented mz/intensity computed above.
    batch = {
        'mz': mz,
        'intensity': intensity.unsqueeze(0)
    }

    # === Unified BioText handling ===
    molecule_id = tensor_data['molecule_id']
    biotext = ""
    paraphrase_text = None
    all_candidate_texts = []
    # BUG FIX: text_idx must be defined even when the sampler assignment is
    # missing, otherwise the paraphrase blocks below raise NameError.
    text_idx = None

    if molecule_id in self.biotext_data:
        entry = self.biotext_data[molecule_id]

        if isinstance(entry, list):
            all_candidate_texts = [record['text'] for record in entry]

            # Use the text index assigned by the sampler when available.
            if ms2_id in self.text_assignment:
                text_idx = self.text_assignment[ms2_id]
                biotext = all_candidate_texts[text_idx]
            else:
                # Fallback: random choice (shouldn't happen in training).
                biotext = random.choice(entry)['text']

            # Paraphrase: if requested, pick a different candidate from the rest.
            # With text_idx None (fallback path) every index qualifies.
            if self.use_paraphrase and len(entry) >= 2:
                remaining_indices = [i for i in range(len(all_candidate_texts))
                                     if i != text_idx]
                if remaining_indices:
                    para_idx = random.choice(remaining_indices)
                    paraphrase_text = all_candidate_texts[para_idx]

        elif isinstance(entry, dict):
            original = entry.get("original", "")
            paraphrases = entry.get("paraphrases", [])
            all_candidates = [original] + paraphrases
            all_candidate_texts = all_candidates

            # Same sampler-assignment logic as the list branch.
            if ms2_id in self.text_assignment:
                text_idx = self.text_assignment[ms2_id]
                biotext = all_candidates[text_idx]
            else:
                biotext = original

            if self.use_paraphrase and len(all_candidates) >= 2:
                remaining = [i for i in range(len(all_candidates)) if i != text_idx]
                if remaining:
                    para_idx = random.choice(remaining)
                    paraphrase_text = all_candidates[para_idx]

        elif isinstance(entry, str):
            biotext = entry
            all_candidate_texts = [entry]
    else:
        print(f"⚠️ Warning: Molecule ID '{molecule_id}' missing BioText")

    # === Tokenization (unchanged) ===
    encoded_text = self.tokenizer(
        biotext,
        padding="max_length",
        truncation=True,
        max_length=self.max_length,
        return_tensors="pt"
    )

    input_ids = encoded_text['input_ids'].squeeze(0)
    attention_mask = encoded_text['attention_mask'].squeeze(0)

    batch['text_input_ids'] = input_ids
    batch['text_attention_mask'] = attention_mask
    batch['all_candidate_texts'] = all_candidate_texts

    # MLM task
    if self.use_mlm:
        masked_input_ids, mlm_labels = self._create_mlm_inputs(input_ids.clone())
        batch['masked_text_input_ids'] = masked_input_ids
        batch['mlm_labels'] = mlm_labels

    # Paraphrase encoding
    if paraphrase_text is not None:
        encoded_para = self.tokenizer(
            paraphrase_text,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        batch['paraphrase_input_ids'] = encoded_para['input_ids'].squeeze(0)
        batch['paraphrase_attention_mask'] = encoded_para['attention_mask'].squeeze(0)
        batch['has_paraphrase'] = True
    else:
        batch['has_paraphrase'] = False

    batch['ms2_id'] = ms2_id
    batch['molecule_id'] = molecule_id
    batch['original_text'] = biotext

    return batch
@staticmethod
def custom_collate_fn(batch_list):
    """
    Custom collate_fn for dict-shaped samples.

    Supports optional keys and filters out hard negatives whose molecule_id
    collides with a positive sample in the current batch.

    NOTE(review): this definition is dead code — it is shadowed by the
    second ``custom_collate_fn`` defined later in this class (which adds a
    text-overlap matrix). Consider deleting one of the two.
    """
    if not batch_list:
        return {}

    # Collect the molecule_ids of all positive samples in the batch.
    batch_molecule_ids = set()
    for sample in batch_list:
        if 'molecule_id' in sample:
            batch_molecule_ids.add(sample['molecule_id'])

    # Collect every key present in any sample.
    all_keys = set()
    for d in batch_list:
        all_keys.update(d.keys())

    # Group values by key (missing keys are simply skipped per sample).
    collated_batch = {}
    for key in all_keys:
        values = [d[key] for d in batch_list if key in d]
        collated_batch[key] = values

    # === Filter hard negatives that conflict with batch positives ===
    if 'hard_neg_input_ids' in collated_batch and len(batch_molecule_ids) > 0:
        filtered_hard_neg_ids = []
        filtered_hard_neg_masks = []
        filtered_has_hard_neg = []

        for i, sample in enumerate(batch_list):
            if sample.get('has_hard_neg', False):
                hard_neg_ids = sample['hard_neg_input_ids']  # [num_hard_neg, seq_len]
                hard_neg_mask = sample['hard_neg_attention_mask']
                hard_neg_mol_ids = sample.get('hard_neg_molecule_ids', [])

                # Indices of hard negatives whose molecule is NOT in the batch.
                valid_indices = []
                for j, neg_mol_id in enumerate(hard_neg_mol_ids):
                    if neg_mol_id not in batch_molecule_ids:
                        valid_indices.append(j)

                if valid_indices:
                    # Keep only the non-conflicting hard negatives.
                    # Assumes these are tensors indexable by a list — TODO confirm.
                    filtered_hard_neg_ids.append(hard_neg_ids[valid_indices])
                    filtered_hard_neg_masks.append(hard_neg_mask[valid_indices])
                    filtered_has_hard_neg.append(True)
                else:
                    # Every hard negative conflicts: substitute empty tensors.
                    max_length = sample['text_input_ids'].shape[0]
                    filtered_hard_neg_ids.append(
                        torch.zeros((0, max_length), dtype=torch.long)
                    )
                    filtered_hard_neg_masks.append(
                        torch.zeros((0, max_length), dtype=torch.long)
                    )
                    filtered_has_hard_neg.append(False)
            else:
                # Sample had no hard negatives to begin with.
                max_length = batch_list[0]['text_input_ids'].shape[0]
                filtered_hard_neg_ids.append(
                    torch.zeros((0, max_length), dtype=torch.long)
                )
                filtered_hard_neg_masks.append(
                    torch.zeros((0, max_length), dtype=torch.long)
                )
                filtered_has_hard_neg.append(False)

        collated_batch['hard_neg_input_ids'] = filtered_hard_neg_ids
        collated_batch['hard_neg_attention_mask'] = filtered_hard_neg_masks
        collated_batch['has_hard_neg'] = filtered_has_hard_neg

    # Post-process each key into its final batched form.
    final_batch = {}
    for key, values in collated_batch.items():
        # Optional boolean flags: default missing entries to False.
        if key in ['has_paraphrase', 'has_hard_neg']:
            final_batch[key] = [d.get(key, False) for d in batch_list]

        # Keys kept as plain lists (including hard_neg_molecule_ids).
        elif key in ['hard_neg_input_ids', 'hard_neg_attention_mask',
                     'paraphrase_input_ids', 'paraphrase_attention_mask',
                     'hard_neg_molecule_ids']:
            final_batch[key] = values

        # Stack tensors only when every sample contributed one.
        elif isinstance(values[0], torch.Tensor):
            if len(values) == len(batch_list):
                final_batch[key] = torch.stack(values)
            else:
                final_batch[key] = values

        # Anything else passes through unchanged.
        else:
            final_batch[key] = values

    return final_batch
@staticmethod
def custom_collate_fn(batch_list):
    """
    Custom collate_fn that additionally identifies sample pairs whose
    candidate texts overlap.

    This redefinition shadows the earlier ``custom_collate_fn`` above and is
    the one actually bound on the class. Adds 'text_overlap_matrix'
    ([batch_size, batch_size] float tensor) to the output.
    """
    if not batch_list:
        return {}

    batch_size = len(batch_list)

    # Collect the molecule_ids of all positive samples in the batch.
    batch_molecule_ids = set()
    for sample in batch_list:
        if 'molecule_id' in sample:
            batch_molecule_ids.add(sample['molecule_id'])

    # === Build the text-overlap matrix ===
    # text_overlap[i][j] = 1 means samples i and j share a candidate text.
    text_overlap = torch.zeros(batch_size, batch_size, dtype=torch.float32)

    for i in range(batch_size):
        for j in range(batch_size):
            if i == j:
                text_overlap[i, j] = 1.0  # a sample always overlaps itself
            else:
                # Overlap test by exact string-set intersection.
                texts_i = set(batch_list[i].get('all_candidate_texts', []))
                texts_j = set(batch_list[j].get('all_candidate_texts', []))

                if texts_i & texts_j:  # non-empty intersection
                    text_overlap[i, j] = 1.0

    # Collect every key present in any sample.
    all_keys = set()
    for d in batch_list:
        all_keys.update(d.keys())

    # Group values by key (missing keys are skipped per sample).
    collated_batch = {}
    for key in all_keys:
        values = [d[key] for d in batch_list if key in d]
        collated_batch[key] = values

    # === Filter hard negatives that conflict with batch positives ===
    if 'hard_neg_input_ids' in collated_batch and len(batch_molecule_ids) > 0:
        filtered_hard_neg_ids = []
        filtered_hard_neg_masks = []
        filtered_has_hard_neg = []

        for i, sample in enumerate(batch_list):
            if sample.get('has_hard_neg', False):
                hard_neg_ids = sample['hard_neg_input_ids']
                hard_neg_mask = sample['hard_neg_attention_mask']
                hard_neg_mol_ids = sample.get('hard_neg_molecule_ids', [])

                # Keep indices whose molecule is not a positive in this batch.
                valid_indices = []
                for j, neg_mol_id in enumerate(hard_neg_mol_ids):
                    if neg_mol_id not in batch_molecule_ids:
                        valid_indices.append(j)

                if valid_indices:
                    filtered_hard_neg_ids.append(hard_neg_ids[valid_indices])
                    filtered_hard_neg_masks.append(hard_neg_mask[valid_indices])
                    filtered_has_hard_neg.append(True)
                else:
                    # All hard negatives conflict: substitute empty tensors.
                    max_length = sample['text_input_ids'].shape[0]
                    filtered_hard_neg_ids.append(
                        torch.zeros((0, max_length), dtype=torch.long)
                    )
                    filtered_hard_neg_masks.append(
                        torch.zeros((0, max_length), dtype=torch.long)
                    )
                    filtered_has_hard_neg.append(False)
            else:
                # Sample had no hard negatives to begin with.
                max_length = batch_list[0]['text_input_ids'].shape[0]
                filtered_hard_neg_ids.append(
                    torch.zeros((0, max_length), dtype=torch.long)
                )
                filtered_hard_neg_masks.append(
                    torch.zeros((0, max_length), dtype=torch.long)
                )
                filtered_has_hard_neg.append(False)

        collated_batch['hard_neg_input_ids'] = filtered_hard_neg_ids
        collated_batch['hard_neg_attention_mask'] = filtered_hard_neg_masks
        collated_batch['has_hard_neg'] = filtered_has_hard_neg

    # Post-process each key into its final batched form.
    final_batch = {}
    for key, values in collated_batch.items():
        if key in ['has_paraphrase', 'has_hard_neg']:
            final_batch[key] = [d.get(key, False) for d in batch_list]
        elif key in ['hard_neg_input_ids', 'hard_neg_attention_mask',
                     'paraphrase_input_ids', 'paraphrase_attention_mask',
                     'hard_neg_molecule_ids', 'all_candidate_texts']:  # all_candidate_texts stays a list
            final_batch[key] = values
        elif isinstance(values[0], torch.Tensor):
            if len(values) == len(batch_list):
                final_batch[key] = torch.stack(values)
            else:
                final_batch[key] = values
        else:
            final_batch[key] = values

    # === Attach the text-overlap information ===
    final_batch['text_overlap_matrix'] = text_overlap  # [batch_size, batch_size]

    return final_batch
@staticmethod
def load_hmdb_data_subsections(first_path, second_path, jsonl_path, max_text_sharing=5):
    """
    Load data in the subsections-JSONL format and filter texts shared by
    too many molecules.

    Args:
        first_path (str): MS2 data file path (.h5 or .pkl).
        second_path (str): meta data file path (.parquet or .csv).
        jsonl_path (str): BioText jsonl file path (one record per line with
            'accession', 'type', 'text').
        max_text_sharing (int): a text shared by more than this many
            molecules is removed from every molecule's candidate list.

    Returns:
        tuple: (ms2_data, meta_data, biotext_data). On read failures the
        not-yet-loaded members are returned as None.
        biotext_data format: {molecule_id: [{'type': 'xxx', 'text': 'xxx'}, ...]}
    """
    from collections import defaultdict

    # --- Read MS2 data ---
    ms2_data = {}
    try:
        _, ext1 = os.path.splitext(first_path)
        if ext1 == '.h5':
            with h5py.File(first_path, 'r') as f:
                spectra_group = f['spectra']
                for spectrum_id in spectra_group.keys():
                    group = spectra_group[spectrum_id]
                    # Spectrum id convention: "<molecule_id>_<n>".
                    parts = spectrum_id.split('_')
                    molecule_id = parts[0]
                    ms2_data[spectrum_id] = {
                        'mz': group['mz'][...].tolist(),
                        'intensity': group['intensity'][...].tolist(),
                        'molecule_id': molecule_id
                    }
        elif ext1 == '.pkl':
            with open(first_path, 'rb') as f:
                ms2_data = pickle.load(f)
        else:
            print(f"Unsupported file format for first path: {ext1}")
            return None, None, None
    except Exception as e:
        print(f"Error: Failed to read first file: {first_path}. Error message: {str(e)}")
        return None, None, None

    # --- Read meta data ---
    meta_data = None
    try:
        _, ext2 = os.path.splitext(second_path)
        if ext2 == '.parquet':
            meta_data = pd.read_parquet(second_path)
        elif ext2 == '.csv':
            meta_data = pd.read_csv(second_path)
        else:
            print(f"Unsupported file format for second path: {ext2}")
            return ms2_data, None, None
    except Exception as e:
        print(f"Error: Failed to read second file: {second_path}. Error message: {str(e)}")
        return ms2_data, None, None

    # --- Read the BioText JSONL file ---
    biotext_data = {}
    try:
        import json
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line in f:
                item = json.loads(line.strip())
                accession = item['accession']

                if accession not in biotext_data:
                    biotext_data[accession] = []

                biotext_data[accession].append({
                    'type': item['type'],
                    'text': item['text']
                })

        print(f"✓ Loaded BioText subsections from {os.path.basename(jsonl_path)}")
        print(f" Total molecules: {len(biotext_data)}")
        total_records_before = sum(len(v) for v in biotext_data.values())
        # BUG FIX: guard against an empty jsonl (division by zero used to be
        # swallowed by the except below, silently returning biotext_data=None).
        avg_before = total_records_before / len(biotext_data) if biotext_data else 0.0
        print(f" Total records (before filtering): {total_records_before}, Avg per molecule: {avg_before:.1f}")

    except Exception as e:
        print(f"Error reading BioText JSONL: {e}")
        return ms2_data, meta_data, None

    # ===== Filter texts shared by too many molecules =====
    print(f"\n=== Filtering texts shared by >{max_text_sharing} molecules ===")

    # 1. Build the inverted index text -> set of molecule ids.
    text_to_molecules = defaultdict(set)
    for mol_id, records in biotext_data.items():
        for record in records:
            text = record['text']
            if text:  # skip empty strings
                text_to_molecules[text].add(mol_id)

    # 2. Collect the over-shared texts to remove.
    texts_to_remove = set()
    sharing_distribution = defaultdict(int)  # histogram of sharing counts

    for text, molecules in text_to_molecules.items():
        sharing_count = len(molecules)
        sharing_distribution[sharing_count] += 1

        if sharing_count > max_text_sharing:
            texts_to_remove.add(text)

    print(f"Text sharing distribution (top 10):")
    for count in sorted(sharing_distribution.keys(), reverse=True)[:10]:
        print(f" {count} molecules share: {sharing_distribution[count]} texts")

    print(f"\nFound {len(texts_to_remove)} texts to remove (shared by >{max_text_sharing} molecules)")

    # 3. Drop those texts from every molecule's candidate list.
    filtered_biotext_data = {}
    total_removed = 0
    molecules_with_no_text = []

    for mol_id, records in biotext_data.items():
        filtered_records = [record for record in records
                            if record['text'] not in texts_to_remove]

        if filtered_records:
            filtered_biotext_data[mol_id] = filtered_records
            total_removed += len(records) - len(filtered_records)
        else:
            molecules_with_no_text.append(mol_id)
            total_removed += len(records)

    # 4. Report filtering statistics.
    print(f"\nFiltering results:")
    print(f" Text entries removed: {total_removed}")
    print(f" Molecules before: {len(biotext_data)}")
    print(f" Molecules after: {len(filtered_biotext_data)}")
    print(f" Molecules with no text left: {len(molecules_with_no_text)}")

    if molecules_with_no_text:
        print(f" ⚠️ Warning: {len(molecules_with_no_text)} molecules lost all texts")
        if len(molecules_with_no_text) <= 5:
            print(f" Lost: {molecules_with_no_text}")
        else:
            print(f" First 5: {molecules_with_no_text[:5]}")

    # 5. Verify the filtering effect.
    text_to_molecules_after = defaultdict(set)
    for mol_id, records in filtered_biotext_data.items():
        for record in records:
            text = record['text']
            if text:
                text_to_molecules_after[text].add(mol_id)

    max_sharing_after = max(len(mols) for mols in text_to_molecules_after.values()) if text_to_molecules_after else 0
    shared_texts_after = sum(1 for mols in text_to_molecules_after.values() if len(mols) > 1)

    print(f" Max sharing after filtering: {max_sharing_after} molecules")
    print(f" Texts still shared by multiple molecules: {shared_texts_after}")

    total_records_after = sum(len(v) for v in filtered_biotext_data.values())
    # BUG FIX: this division crashed with ZeroDivisionError when filtering
    # removed every molecule (it is outside any try/except).
    avg_after = total_records_after / len(filtered_biotext_data) if filtered_biotext_data else 0.0
    print(f" Total records (after filtering): {total_records_after}, Avg per molecule: {avg_after:.1f}")

    # Use the filtered data from here on.
    biotext_data = filtered_biotext_data

    # Final summary.
    unique_molecule_ids = set(item['molecule_id'] for item in ms2_data.values())
    print(f"\nFinal data summary:")
    print(f" Unique molecule IDs in MS2 data: {len(unique_molecule_ids)}")
    print(f" Molecule IDs in BioText data: {len(biotext_data)}")

    return ms2_data, meta_data, biotext_data
@staticmethod
def missing_biotext_handling(ms2_data, biotext_data, method="drop"):
    """
    Reconcile the MS2 table with the BioText lookup table.

    Args:
        ms2_data: dict, {ms2_id: {'mz': list, 'intensity': list, 'molecule_id': str}}
        biotext_data: dict, {molecule_id: BioText}
        method: "drop" removes spectra whose molecule has no BioText entry;
            "fill" inserts an empty-string BioText for those molecules
            (handled later during dataset initialization).

    Returns:
        (ms2_data, biotext_data) after reconciliation.

    Raises:
        ValueError: if *method* is neither "drop" nor "fill".
    """
    if method not in ("drop", "fill"):
        raise ValueError(f"Unknown method: {method}. Method must be 'drop' or 'fill'.")

    if method == "drop":
        # Keep only spectra whose molecule has an associated BioText.
        surviving = {}
        for spectrum_id, spectrum_info in ms2_data.items():
            if spectrum_info['molecule_id'] in biotext_data:
                surviving[spectrum_id] = spectrum_info
        ms2_data = surviving
        remaining_molecules = {info['molecule_id'] for info in ms2_data.values()}
        print(f"Post-processing statistics ('drop' method) - Unique molecule IDs in MS2 data: {len(remaining_molecules)}")
        print(f"Post-processing statistics ('drop' method) - Molecule IDs in BioText data: {len(biotext_data)}")
    else:
        # Backfill an empty BioText for every molecule that lacks one.
        for spectrum_info in ms2_data.values():
            biotext_data.setdefault(spectrum_info['molecule_id'], "")
        remaining_molecules = {info['molecule_id'] for info in ms2_data.values()}
        print(f"Post-processing statistics ('fill' method) - Unique molecule IDs in MS2 data: {len(remaining_molecules)}")
        print(f"Post-processing statistics ('fill' method) - Molecule IDs in BioText data: {len(biotext_data)}")

    return ms2_data, biotext_data
@staticmethod
def add_noise_peaks(peaks, intensities, noise_ratio=0.5, noise_intensity_range=(0.001, 0.05), seed=None):
    """
    Inject random low-intensity noise peaks to mimic external data.

    Args:
        peaks: list of float, original m/z values.
        intensities: list of float, original intensity values.
        noise_ratio: float, number of noise peaks = len(peaks) * noise_ratio.
        noise_intensity_range: tuple, noise intensity bounds relative to the
            maximum intensity.
        seed: int, optional RNG seed for reproducibility.

    Returns:
        (aug_peaks, aug_intensities): m/z-sorted lists containing the
        original peaks plus the injected noise. Inputs are returned
        unchanged when empty, all-zero, or when no noise would be added.
    """
    import numpy as np
    import random

    if seed is not None:
        np.random.seed(seed)
        random.seed(seed)

    # Early exits: nothing to do for empty or degenerate spectra.
    if len(peaks) == 0:
        return peaks, intensities

    peak_ceiling = max(intensities)
    if peak_ceiling == 0:
        return peaks, intensities

    noise_count = int(len(peaks) * noise_ratio)
    if noise_count == 0:
        return peaks, intensities

    # Draw noise m/z values uniformly inside the spectrum's m/z span.
    low_mz, high_mz = min(peaks), max(peaks)
    synthetic_mz = np.random.uniform(low_mz, high_mz, noise_count).tolist()

    # Draw low intensities relative to the strongest peak.
    synthetic_int = np.random.uniform(
        noise_intensity_range[0] * peak_ceiling,
        noise_intensity_range[1] * peak_ceiling,
        noise_count
    ).tolist()

    # Merge originals with noise, then re-sort by m/z only (stable sort,
    # so ties keep their insertion order just like an index sort would).
    combined = sorted(
        zip(peaks + synthetic_mz, intensities + synthetic_int),
        key=lambda pair: pair[0]
    )
    aug_peaks = [mz for mz, _ in combined]
    aug_intensities = [ival for _, ival in combined]

    return aug_peaks, aug_intensities
@staticmethod
def filter_low_intensity_peaks(peaks, intensities, threshold=0.01):
    """
    Drop peaks whose relative intensity falls below *threshold*.

    Args:
        peaks: list of float, m/z values.
        intensities: list of float, intensity values.
        threshold: float, relative-intensity cutoff (0.01 == 1% of the
            strongest peak).

    Returns:
        (filtered_peaks, filtered_intensities): parallel lists keeping only
        peaks at or above the cutoff. Inputs are returned unchanged when
        either list is empty or the maximum intensity is zero.
    """
    if len(peaks) == 0 or len(intensities) == 0:
        return peaks, intensities

    strongest = max(intensities)
    if strongest == 0:
        return peaks, intensities

    # Single pass; the relative intensity is computed by division so the
    # float comparison matches the normalized form exactly.
    survivors = [
        (mz, ival)
        for mz, ival in zip(peaks, intensities)
        if ival / strongest >= threshold
    ]
    if not survivors:
        return [], []
    kept_mz, kept_int = zip(*survivors)
    return list(kept_mz), list(kept_int)
@staticmethod
def augment_ms2_data(ms2_data, args):
    """
    Augment MS2 data (must be called BEFORE preprocessing/tokenization).

    Args:
        ms2_data: dict, {ms2_id: {'mz': list, 'intensity': list, 'molecule_id': str}}
            Note: mz and intensity must be raw float values, not token ids.
        args: argparse.Namespace with the augmentation parameters:
            - augment_noise: bool, enable noise augmentation (default: False)
            - augment_multiplier: int, versions generated per spectrum
              (1 = no augmentation, 2 = doubles the data)
            - noise_ratio: float, noise peaks added = len(peaks) * noise_ratio
            - noise_intensity_range: tuple, noise intensity bounds relative
              to the maximum intensity
            - filter_threshold: float or None, cutoff for dropping
              low-intensity peaks

    Returns:
        augmented_ms2_data: dict containing the original plus augmented
        versions. With augment_multiplier=1 the input is returned as-is;
        with augment_multiplier=2 roughly twice the spectra are returned
        (original + one augmented copy each).

    Example:
        >>> augmented_data = MS2BioTextDataset.augment_ms2_data(ms2_data, args)
        >>> processed_data, word2idx = MS2BioTextDataset.preprocess_ms2_data_positive_only(
        ...     augmented_data, meta_data
        ... )
    """
    import numpy as np

    # Read parameters defensively (args may lack any of them).
    augment_noise = getattr(args, 'augment_noise', False)
    augment_multiplier = getattr(args, 'augment_multiplier', 1)
    noise_ratio = getattr(args, 'noise_ratio', 0.5)
    noise_intensity_range = getattr(args, 'noise_intensity_range', (0.001, 0.05))
    filter_threshold = getattr(args, 'filter_threshold', None)

    # No augmentation requested: return the input untouched.
    if not augment_noise or augment_multiplier <= 1:
        print("ℹ️ 未启用数据增强 (augment_noise=False or augment_multiplier<=1)")
        return ms2_data

    print(f"\n{'='*60}")
    print(f"🔄 MS2数据增强")
    print(f"{'='*60}")
    print(f" 增强倍数: {augment_multiplier}x")
    print(f" Noise比例: {noise_ratio}")
    print(f" Noise强度范围: {noise_intensity_range}")
    if filter_threshold:
        print(f" 过滤阈值: {filter_threshold} (相对强度)")
    print(f" 原始光谱数: {len(ms2_data)}")

    augmented_ms2_data = {}

    for ms2_id, info in ms2_data.items():
        molecule_id = info.get('molecule_id')

        # Skip spectra whose peak data is not in raw list form
        # (e.g. already tokenized).
        if not isinstance(info['mz'], list) or not isinstance(info['intensity'], list):
            print(f"⚠️ 跳过 {ms2_id}: mz或intensity不是list格式")
            continue

        # Version 0: the original spectrum (optionally filtered).
        peaks_original = info['mz'].copy() if isinstance(info['mz'], list) else list(info['mz'])
        intensities_original = info['intensity'].copy() if isinstance(info['intensity'], list) else list(info['intensity'])

        # Optional: drop low-intensity peaks before augmenting.
        if filter_threshold is not None and filter_threshold > 0:
            peaks_original, intensities_original = MS2BioTextDataset.filter_low_intensity_peaks(
                peaks_original, intensities_original, threshold=filter_threshold
            )

        # Keep the (possibly filtered) original version.
        augmented_ms2_data[ms2_id] = {
            'mz': peaks_original,
            'intensity': intensities_original,
            'molecule_id': molecule_id
        }

        # Generate augmented versions 1..N-1 by injecting noise peaks.
        for aug_idx in range(1, augment_multiplier):
            peaks_aug, intensities_aug = MS2BioTextDataset.add_noise_peaks(
                peaks_original.copy(),
                intensities_original.copy(),
                noise_ratio=noise_ratio,
                noise_intensity_range=noise_intensity_range,
                seed=None  # fresh random noise each time
            )

            # New id: original id + suffix.
            aug_ms2_id = f"{ms2_id}_aug{aug_idx}"
            augmented_ms2_data[aug_ms2_id] = {
                'mz': peaks_aug,
                'intensity': intensities_aug,
                'molecule_id': molecule_id  # same molecule_id on purpose!
            }

    print(f" ✓ 增强后光谱数: {len(augmented_ms2_data)}")
    print(f" 增强版本数: {len(augmented_ms2_data) - len(ms2_data)}")
    print(f"{'='*60}\n")

    return augmented_ms2_data
@staticmethod
def preprocess_ms2_data_positive_only(ms2_data, meta_data, maxlen=100, min_peaks=0):
    """
    Preprocess ms2_data for model input, keeping positive-ion spectra only.

    Parameters:
    - ms2_data: dict, {ms2_id: {'mz': list, 'intensity': list, 'molecule_id': str}}
    - meta_data: pd.DataFrame, must contain 'file_name', 'Polarity' and
      'precursor_mass' columns; the molecule row is located via 'HMDB.ID'
      when that column exists, otherwise via the DataFrame index.
    - maxlen: int, maximum token sequence length (precursor token + fragments)
    - min_peaks: int, minimum number of in-range fragment peaks required
      (default: 0, no filtering)

    Returns:
    - ms_data: dict, same structure as ms2_data but with 'mz' mapped to token
      ids and 'intensity' normalized and padded/truncated to maxlen
    - word2idx: dict, maps "%.2f"-formatted m/z strings to token indices
    """

    def safe_convert_precursor(value):
        """Safely convert a precursor_mass cell to float; None on bad input."""
        if pd.isna(value):
            return None
        if isinstance(value, (int, float)):
            return float(value)
        value_str = str(value).strip()
        if value_str == '' or value_str.lower() == 'nan':
            return None
        # Handle "209/192"-style entries by taking the first value.
        if '/' in value_str:
            value_str = value_str.split('/')[0]
        try:
            return float(value_str)
        except (TypeError, ValueError):  # BUGFIX: was a bare `except:`
            return None

    # 1) Vocabulary: every 0.01 step in [0.00, 1000.00)
    word_list = list(np.round(np.linspace(0, 1000, 100*1000, endpoint=False), 2))
    word_list = ["%.2f" % i for i in word_list]

    # 2) word2idx with special tokens; real m/z tokens start at index 2
    word2idx = {'[PAD]': 0, '[MASK]': 1}
    for i, w in enumerate(word_list):
        word2idx[w] = i + 2

    # 3) Output dictionary
    ms_data = {}

    # Per-reason filter counters for the summary printed at the end.
    filter_stats = {
        'total': 0,
        'empty_mz': 0,
        'not_positive': 0,
        'no_meta': 0,
        'no_precursor': 0,
        'precursor_gt_1000': 0,
        'no_peaks_after_filter': 0,
        'too_few_peaks': 0,
        'kept': 0
    }

    # 4) Iterate through each ms2_id
    for ms2_id, info in ms2_data.items():
        filter_stats['total'] += 1

        mz_data = info.get('mz')
        if mz_data is None or len(mz_data) == 0:
            filter_stats['empty_mz'] += 1
            continue
        peaks = info['mz']
        intensities = info['intensity']
        molecule_id = info.get('molecule_id', None)

        # Spectrum-level row: required for the polarity check.
        specific_row = meta_data[meta_data["file_name"] == ms2_id]
        if specific_row.empty:
            filter_stats['no_meta'] += 1
            continue
        elif specific_row["Polarity"].values[0] not in ["Positive", "positive"]:
            filter_stats['not_positive'] += 1
            continue

        # 4.1 Molecule-level row carrying the precursor mass.
        if 'HMDB.ID' in meta_data.columns:
            row = meta_data.loc[meta_data['HMDB.ID'] == molecule_id]
        else:
            row = meta_data.loc[meta_data.index == molecule_id]
        if row.empty:
            filter_stats['no_meta'] += 1
            continue

        precursor_val = safe_convert_precursor(row['precursor_mass'].values[0])
        if precursor_val is None:
            filter_stats['no_precursor'] += 1
            continue

        # BUGFIX: use >= 1000 plus a vocabulary-membership check. The original
        # `> 1000` let exactly 1000.0 (and values rounding to "1000.00")
        # through, which then raised KeyError in word2idx below.
        precursor_str = "%.2f" % precursor_val
        if precursor_val >= 1000 or precursor_str not in word2idx:
            filter_stats['precursor_gt_1000'] += 1
            continue

        # 4.2 Keep fragment peaks that fit the vocabulary, together with their
        # intensities. BUGFIX: the original dropped out-of-range peaks from
        # the token sequence but kept all intensities, misaligning the two.
        peaks_str = []
        kept_intensities = []
        for mz_val, inten in zip(peaks, intensities):
            s = "%.2f" % mz_val
            if mz_val <= 1000 and s in word2idx:
                peaks_str.append(s)
                kept_intensities.append(inten)

        if len(peaks_str) == 0:
            filter_stats['no_peaks_after_filter'] += 1
            continue

        if len(peaks_str) < min_peaks:
            filter_stats['too_few_peaks'] += 1
            continue

        token_ids = [word2idx[precursor_str]] + [word2idx[p] for p in peaks_str]

        # 4.3 Prepend a fixed pseudo-intensity (2) for the precursor, then
        # scale so the maximum is 1.
        intensities = np.hstack((2, kept_intensities))
        max_intensity = np.max(intensities)
        if max_intensity != 0:
            intensities = intensities / max_intensity

        # 4.4 Pad or truncate both sequences to maxlen, keeping them aligned.
        n_pad = maxlen - len(token_ids)
        if n_pad < 0:
            token_ids = token_ids[:maxlen]
            intensities = intensities[:maxlen]
            n_pad = 0
        token_ids += [word2idx['[PAD]']] * n_pad
        if len(intensities) < maxlen:
            intensities = np.hstack([intensities, np.zeros(maxlen - len(intensities))])
        else:
            intensities = intensities[:maxlen]

        # 4.5 Save processed result
        ms_data[ms2_id] = {
            'mz': token_ids,
            'intensity': intensities.tolist(),
            'molecule_id': molecule_id
        }
        filter_stats['kept'] += 1

    # Summary. BUGFIX: guard the percentage against division by zero when
    # ms2_data is empty.
    print(f"\n预处理统计:")
    print(f"  总光谱: {filter_stats['total']}")
    print(f"  过滤:")
    print(f"    空M/Z: {filter_stats['empty_mz']}")
    print(f"    非Positive: {filter_stats['not_positive']}")
    print(f"    无meta: {filter_stats['no_meta']}")
    print(f"    无/异常precursor: {filter_stats['no_precursor']}")
    print(f"    precursor>1000: {filter_stats['precursor_gt_1000']}")
    print(f"    所有peaks>1000: {filter_stats['no_peaks_after_filter']}")
    if min_peaks > 0:
        print(f"    peaks<{min_peaks}: {filter_stats['too_few_peaks']}")
    kept_pct = (filter_stats['kept'] / filter_stats['total'] * 100) if filter_stats['total'] else 0.0
    print(f"  ✓ 保留: {filter_stats['kept']} ({kept_pct:.2f}%)")

    # 5) Return processed data and dictionary
    return ms_data, word2idx
|
|
1649
|
+
|
|
1650
|
+
|
|
1651
|
+
@staticmethod
def augment_ms2_data_parallel(ms2_data, args, n_workers=None):
    """Multiprocessing variant of augment_ms2_data.

    Splits the spectra into batches and hands each batch to
    ``_augment_worker`` in a worker pool. Returns ``ms2_data`` unchanged
    when noise augmentation is disabled or yields no extra copies.
    """
    from multiprocessing import Pool, cpu_count

    if n_workers is None:
        n_workers = min(cpu_count() - 1, 8)

    use_noise = getattr(args, 'augment_noise', False)
    multiplier = getattr(args, 'augment_multiplier', 1)

    # Nothing to do unless augmentation is on and produces at least one copy.
    if not use_noise or multiplier <= 1:
        print("ℹ️ 未启用数据增强")
        return ms2_data

    print(f"\n🚀 多进程数据增强 (workers={n_workers})...")

    entries = list(ms2_data.items())
    per_chunk = max(1, len(entries) // (n_workers * 4))

    # Augmentation knobs, read off args with the serial path's defaults.
    threshold = getattr(args, 'filter_threshold', None)
    ratio = getattr(args, 'noise_ratio', 0.5)
    intensity_range = getattr(args, 'noise_intensity_range', (0.001, 0.05))

    # Batch the entries and bundle each batch with its parameters.
    work = [
        (entries[start:start + per_chunk], threshold, ratio, intensity_range, multiplier)
        for start in range(0, len(entries), per_chunk)
    ]

    with Pool(n_workers) as pool:
        partial_results = pool.map(_augment_worker, work)

    # Merge the per-batch dictionaries back into one.
    merged = {}
    for part in partial_results:
        merged.update(part)

    print(f" ✓ 完成: {len(merged)} 光谱")
    return merged
|
|
1692
|
+
|
|
1693
|
+
|
|
1694
|
+
@staticmethod
def preprocess_ms2_data_positive_only_parallel(ms2_data, meta_data, maxlen=100, min_peaks=0, n_workers=None,
                                               precursor_mode='normalize_add', precursor_value=2.0):
    """
    Multiprocessing version of the positive-only preprocessing.

    Args:
        precursor_mode:
            - 'scale_fixed': scale fragments to precursor_value (e.g. 20000), precursor fixed at 2
            - 'normalize_add': normalize fragments to 1, set precursor to precursor_value (e.g. 2.0), then renormalize
            - 'original': original MSBERT behaviour
        precursor_value:
            - with mode='scale_fixed': target value fragments are scaled to (default 20000)
            - with mode='normalize_add': intensity assigned to the precursor (default 2.0)
    """
    from multiprocessing import Pool, cpu_count
    import numpy as np
    import pandas as pd

    if n_workers is None:
        n_workers = min(cpu_count() - 1, 8)

    print(f"\n🚀 多进程预处理 (workers={n_workers}, mode={precursor_mode}, value={precursor_value})...")

    # Token vocabulary: one entry per 0.01 Da in [0, 1000), after two special tokens.
    vocab_words = ["%.2f" % v for v in np.round(np.linspace(0, 1000, 100*1000, endpoint=False), 2)]
    word2idx = {'[PAD]': 0, '[MASK]': 1}
    word2idx.update({w: pos + 2 for pos, w in enumerate(vocab_words)})

    # Normalize polarity labels once, on a copy, so workers can string-compare.
    prepared_meta = meta_data.copy()
    if "Polarity" in prepared_meta.columns:
        prepared_meta["Polarity"] = prepared_meta["Polarity"].astype(str).str.lower().str.strip()

    # At most 100 fragments, always leaving room for the precursor token.
    max_frag = max(0, min(100, maxlen - 1))

    entries = list(ms2_data.items())
    per_chunk = max(1, len(entries) // (n_workers * 4))

    # Batch the spectra and bundle each batch with the shared parameters.
    tasks = [
        (entries[start:start + per_chunk], word2idx, prepared_meta, maxlen, max_frag, min_peaks,
         precursor_mode, precursor_value)
        for start in range(0, len(entries), per_chunk)
    ]

    with Pool(n_workers) as pool:
        outputs = pool.map(_preprocess_worker, tasks)

    # Merge worker outputs and accumulate their statistics.
    ms_data = {}
    total_kept = 0
    total_filtered = 0
    for chunk_result, chunk_stats in outputs:
        ms_data.update(chunk_result)
        total_kept += chunk_stats['kept']
        total_filtered += chunk_stats['filtered']

    print(f" ✓ 完成: {total_kept}/{len(entries)} 光谱 (过滤: {total_filtered})")
    return ms_data, word2idx
|
|
1757
|
+
|
|
1758
|
+
@staticmethod
def preprocess_ms2_data(ms2_data, meta_data, maxlen=100):
    """
    Preprocess ms2_data for model input (no polarity filtering).

    Parameters:
    - ms2_data: dict, {ms2_id: {'mz': list, 'intensity': list, 'molecule_id': str}}
    - meta_data: pd.DataFrame, must contain precursor information (column:
      'precursor_mass'); the molecule row is located via 'HMDB.ID' when that
      column exists, otherwise via the DataFrame index.
    - maxlen: int, maximum token sequence length (precursor token + fragments)

    Returns:
    - ms_data: dict, same structure as ms2_data but with 'mz' mapped to token
      ids and 'intensity' normalized and padded/truncated to maxlen
    - word2idx: dict, maps "%.2f"-formatted m/z strings to token indices
    """
    # 1) Vocabulary: every 0.01 step in [0.00, 1000.00)
    word_list = list(np.round(np.linspace(0, 1000, 100*1000, endpoint=False), 2))
    word_list = ["%.2f" % i for i in word_list]

    # 2) word2idx with special tokens; real m/z tokens start at index 2
    word2idx = {'[PAD]': 0, '[MASK]': 1}
    for i, w in enumerate(word_list):
        word2idx[w] = i + 2

    # 3) Output dictionary
    ms_data = {}

    # 4) Iterate through each ms2_id
    for ms2_id, info in ms2_data.items():
        # BUGFIX: `if not info['mz']` raises ValueError when 'mz' is a numpy
        # array; use an explicit None/length check instead.
        mz_data = info.get('mz')
        if mz_data is None or len(mz_data) == 0:
            continue
        peaks = mz_data
        intensities = info['intensity']
        molecule_id = info.get('molecule_id', None)

        # 4.1 Molecule-level row carrying the precursor mass.
        if 'HMDB.ID' in meta_data.columns:
            row = meta_data.loc[meta_data['HMDB.ID'] == molecule_id]
        else:
            row = meta_data.loc[meta_data.index == molecule_id]
        if row.empty:
            continue

        # BUGFIX: an unguarded float() on malformed entries (e.g. "209/192")
        # crashed the whole preprocessing pass; skip such spectra instead.
        try:
            precursor_val = float(row['precursor_mass'].values[0])
        except (TypeError, ValueError):
            continue
        if pd.isna(precursor_val):
            continue

        # BUGFIX: use >= 1000 plus a vocabulary-membership check. The original
        # `> 1000` let exactly 1000.0 (and values rounding to "1000.00")
        # through, which then raised KeyError in word2idx below.
        precursor_str = "%.2f" % precursor_val
        if precursor_val >= 1000 or precursor_str not in word2idx:
            continue

        # 4.2 Keep in-range peaks together with their intensities.
        # BUGFIX: the original dropped out-of-range peaks from the token
        # sequence but kept all intensities, misaligning the two sequences.
        peaks_str = []
        kept_intensities = []
        for mz_val, inten in zip(peaks, intensities):
            s = "%.2f" % mz_val
            if mz_val <= 1000 and s in word2idx:
                peaks_str.append(s)
                kept_intensities.append(inten)

        token_ids = [word2idx[precursor_str]] + [word2idx[p] for p in peaks_str]

        # 4.3 Prepend a fixed pseudo-intensity (2) for the precursor, then
        # scale so the maximum is 1.
        intensities = np.hstack((2, kept_intensities))
        max_intensity = np.max(intensities)
        if max_intensity != 0:
            intensities = intensities / max_intensity

        # 4.4 Pad or truncate both sequences to maxlen, keeping them aligned.
        n_pad = maxlen - len(token_ids)
        if n_pad < 0:
            token_ids = token_ids[:maxlen]
            intensities = intensities[:maxlen]
            n_pad = 0
        token_ids += [word2idx['[PAD]']] * n_pad
        if len(intensities) < maxlen:
            intensities = np.hstack([intensities, np.zeros(maxlen - len(intensities))])
        else:
            intensities = intensities[:maxlen]

        # 4.5 Save processed result
        ms_data[ms2_id] = {
            'mz': token_ids,
            'intensity': intensities.tolist(),
            'molecule_id': molecule_id
        }

    # 5) Return processed data and dictionary
    return ms_data, word2idx
|
|
1842
|
+
|
|
1843
|
+
@staticmethod
def fill_precursor_data(meta_data, ms_data):
    """
    Fill in missing precursor ion mass in mass spectrometry metadata.

    For every row whose precursor column is NaN (or negative, which is
    coerced to NaN first), a base m/z is taken either from an 'mz' column or
    from the largest fragment in the matching spectrum. Candidate precursor
    masses are then computed for a set of common adducts, filtered so that no
    fragment exceeds the candidate (plus an isotope tolerance), and the
    default adduct ([M+H]+ / [M-H]-) is preferred among the survivors. When
    no candidate survives, a fallback of max fragment ± proton + 1 Da is used.

    Parameters:
        meta_data (DataFrame): DataFrame containing metadata for mass spectrometry samples.
        ms_data (dict): Dictionary containing peak data, with keys as spectrum IDs and values as dicts with 'mz' and 'intensity'.

    Returns:
        DataFrame: Updated meta_data with filled precursor ion mass.
    """
    # Copy the DataFrame to avoid modifying the original
    meta_data = meta_data.copy()

    # Find column name containing 'precursor'; first match wins.
    precursor_cols = [col for col in meta_data.columns if 'precursor' in col.lower()]
    if not precursor_cols:
        raise ValueError("No column containing 'precursor' found in meta_data")
    precursor_col = precursor_cols[0]
    print(f"Using '{precursor_col}' as the precursor mass column")

    # Count of missing values before filling, for the summary at the end.
    init_nan = meta_data[precursor_col].isna().sum()

    # Find column named 'mz' (case-insensitive); used as the base mass source
    # when present instead of the spectrum's largest fragment.
    mz_cols = [col for col in meta_data.columns if col.lower() == 'mz']
    has_mz_column = len(mz_cols) > 0
    mz_col = mz_cols[0] if has_mz_column else None

    # Proton mass (H+) is approximately 1.007276 Da
    proton_mass = 1.007276

    # Tolerance threshold for isotopic effect (Da): a fragment may exceed the
    # candidate precursor by up to this much and still be accepted.
    isotope_threshold = 2.0

    # Dictionary of adduct ion modes with their corresponding mass calculations
    adduct_modes = {
        'positive': {
            '[M+H]+': lambda m: m + proton_mass,
            '[M+H-H2O]+': lambda m: m + proton_mass - 18.010565,
            '[M+Na]+': lambda m: m + 22.989218,
            '[M+K]+': lambda m: m + 39.098301,
            '[M+NH4]+': lambda m: m + 18.033823,
            '[2M+H]+': lambda m: 2*m + proton_mass,
            '[2M+Na]+': lambda m: 2*m + 22.989218,
            '[2M+K]+': lambda m: 2*m + 39.098301,
            '[2M+NH4]+': lambda m: 2*m + 18.033823,
            '[2M+H-H2O]+': lambda m: 2*m + proton_mass - 18.010565
        },
        'negative': {
            '[M-H]-': lambda m: m - proton_mass,
            '[M-H2O-H]-': lambda m: m - proton_mass - 18.010565,
            '[M+Cl]-': lambda m: m + 34.969402,
            '[M+HAc-H]-': lambda m: m + 59.013851,
            '[2M-H]-': lambda m: 2*m - proton_mass,
            '[2M+Cl]-': lambda m: 2*m + 34.969402,
            '[2M+HAc-H]-': lambda m: 2*m + 59.013851
        }
    }
    # Coerce non-numeric entries to NaN so they are treated as missing.
    meta_data[precursor_col] = pd.to_numeric(meta_data[precursor_col], errors='coerce')
    # Replace negative precursor values with NaN
    meta_data.loc[meta_data[precursor_col] < 0, precursor_col] = np.nan

    # Fill missing precursor masses
    for idx, row in meta_data.iterrows():
        if pd.isna(row[precursor_col]):
            spectrum_id = idx

            # Determine polarity; rows without a recognizable polarity
            # default to positive mode.
            polarity = str(row['Polarity']).lower()
            if 'positive' in polarity:
                polarity_type = 'positive'
            elif 'negative' in polarity:
                polarity_type = 'negative'
            else:
                polarity_type = 'positive'

            # Try to get base mz from meta_data; otherwise fall back to the
            # largest fragment in the spectrum keyed by file_name.
            base_mz = None
            if has_mz_column and not pd.isna(row[mz_col]):
                base_mz = row[mz_col]
            else:
                if row["file_name"] not in ms_data:
                    print(f"Warning: spectrum ID {spectrum_id} not found in ms_data")
                    continue

                spectrum = ms_data.get(row["file_name"], {})
                if 'mz' not in spectrum or len(spectrum['mz']) == 0:
                    print(f"Warning: spectrum ID {spectrum_id} has no mz data")
                    continue

                base_mz = max(spectrum['mz'])

            # One candidate precursor mass per adduct of the chosen polarity.
            candidate_precursors = {}
            for mode_name, mode_func in adduct_modes[polarity_type].items():
                candidate_mass = mode_func(base_mz)
                candidate_precursors[mode_name] = candidate_mass

            # Keep candidates not exceeded by any fragment (within tolerance).
            # NOTE: spectrum is re-fetched here unconditionally, so it is
            # defined even when base_mz came from the mz column.
            valid_candidates = {}
            spectrum = ms_data.get(row["file_name"], {})
            if 'mz' in spectrum and len(spectrum['mz']) > 0:
                max_fragment_mz = max(spectrum['mz'])

                for mode_name, precursor_mass in candidate_precursors.items():
                    if max_fragment_mz <= precursor_mass + isotope_threshold:
                        valid_candidates[mode_name] = precursor_mass

            if valid_candidates:
                # Prefer the protonated/deprotonated species; otherwise take
                # the first surviving candidate (insertion order).
                default_mode = '[M+H]+' if polarity_type == 'positive' else '[M-H]-'
                if default_mode in valid_candidates:
                    selected_mode = default_mode
                else:
                    selected_mode = list(valid_candidates.keys())[0]
                precursor_mass = valid_candidates[selected_mode]
                meta_data.at[idx, precursor_col] = precursor_mass
            else:
                # Fallback: derive a plausible precursor just above the
                # largest fragment (± proton, + 1 Da margin).
                if 'mz' in spectrum and len(spectrum['mz']) > 0:
                    max_mz = max(spectrum['mz'])
                    if polarity_type == 'positive':
                        adjusted_precursor = max_mz + proton_mass + 1.0
                    else:
                        adjusted_precursor = max_mz - proton_mass + 1.0
                    meta_data.at[idx, precursor_col] = adjusted_precursor
                else:
                    print(f"Warning: unable to determine precursor mass for spectrum ID {spectrum_id}")

    left_nan = meta_data[precursor_col].isna().sum()
    print(f"Precursor mass missing: initially {init_nan}; filled {init_nan - left_nan}; remaining {left_nan}.")

    return meta_data
|
|
1973
|
+
|
|
1974
|
+
|
|
1975
|
+
@staticmethod
def preprocess_ms2_data_positive_only(ms2_data, meta_data, maxlen=100, min_peaks=0):
    """
    Preprocess ms2_data for model input (positive-ion spectra only).

    NOTE(review): this class defines preprocess_ms2_data_positive_only twice;
    being the later definition, this one shadows the earlier variant. The
    min_peaks parameter (default 0 = no filtering) is accepted here so calls
    written against the earlier signature keep working.

    Parameters:
    - ms2_data: dict, {ms2_id: {'mz': list, 'intensity': list, 'molecule_id': str}}
    - meta_data: pd.DataFrame, must contain precursor information (column: 'precursor_mass')
    - maxlen: int, maximum sequence length
    - min_peaks: int, minimum number of in-range fragment peaks required

    Returns:
    - ms_data: dict, same structure as ms2_data but with processed 'mz' and 'intensity' sequences
    - word2idx: dict, maps string-formatted m/z values to token indices
    """
    import numpy as np
    import pandas as pd

    # 1) Create word list: ["0.00", "0.01", ..., "999.99"]
    word_list = list(np.round(np.linspace(0, 1000, 100 * 1000, endpoint=False), 2))
    word_list = ["%.2f" % i for i in word_list]

    # 2) Build word2idx dictionary with special tokens
    word2idx = {'[PAD]': 0, '[MASK]': 1}
    for i, w in enumerate(word_list):
        word2idx[w] = i + 2  # Start from 2 to avoid collision with special tokens

    # 3) Initialize output dictionary
    ms_data = {}

    # Make the positive-mode check robust (lower + strip), on a copy.
    if "Polarity" in meta_data.columns:
        meta_data = meta_data.copy()
        meta_data["Polarity"] = meta_data["Polarity"].astype(str).str.lower().str.strip()

    # Maximum fragment count (precursor + fragments <= maxlen, capped at 100)
    max_frag = max(0, min(100, maxlen - 1))

    # 4) Iterate through each ms2_id
    for ms2_id, info in ms2_data.items():
        # BUGFIX: `if not info.get('mz')` raises ValueError when 'mz' is a
        # numpy array; use an explicit None/length check instead.
        mz_raw = info.get('mz')
        if mz_raw is None or len(mz_raw) == 0:
            continue

        peaks = np.asarray(mz_raw, dtype=float)
        intensities = np.asarray(info['intensity'], dtype=float)
        molecule_id = info.get('molecule_id', None)

        # 4.0 Row matching the file name, used for the polarity check.
        specific_row = meta_data.loc[meta_data["file_name"] == ms2_id] if "file_name" in meta_data.columns else pd.DataFrame()
        if specific_row.empty:
            # Best effort: fall back to locating a row via molecule_id.
            if molecule_id is not None:
                if 'HMDB.ID' in meta_data.columns:
                    specific_row = meta_data.loc[meta_data['HMDB.ID'] == molecule_id]
                else:
                    specific_row = meta_data.loc[meta_data.index == molecule_id]
            if specific_row.empty:
                continue

        # Keep positive-ion spectra only.
        pol = str(specific_row["Polarity"].values[0]).lower().strip() if "Polarity" in specific_row.columns else ""
        if pol != "positive":
            continue

        # 4.1 Find precursor mass from meta_data
        if 'HMDB.ID' in meta_data.columns and (molecule_id is not None):
            row = meta_data.loc[meta_data['HMDB.ID'] == molecule_id]
        else:
            row = meta_data.loc[meta_data.index == molecule_id]

        if row.empty or ('precursor_mass' not in row.columns):
            continue

        try:
            precursor_val = float(row['precursor_mass'].values[0])
        except Exception:
            continue

        # Precursor must lie in [10, 1000); clamp so "%.2f" formatting can
        # never produce "1000.00" and overflow the vocabulary.
        if pd.isna(precursor_val) or (precursor_val < 10.0) or (precursor_val >= 1000.0):
            continue
        precursor_val = min(precursor_val, 999.99)
        precursor_str = "%.2f" % precursor_val

        # 4.2 Align lengths, then keep finite peaks in [10, 1000)
        if peaks.shape[0] != intensities.shape[0]:
            # Mismatched lengths: truncate both to the shorter one.
            n = min(len(peaks), len(intensities))
            peaks = peaks[:n]
            intensities = intensities[:n]

        mask = (peaks >= 10.0) & (peaks < 1000.0) & np.isfinite(peaks) & np.isfinite(intensities)
        peaks = peaks[mask]
        intensities = intensities[mask]

        if peaks.size == 0:
            continue
        # Optional minimum-peak filter (parity with the shadowed earlier variant).
        if peaks.size < min_peaks:
            continue

        # 4.3 Top-K fragments by intensity (at most max_frag)
        if peaks.size > max_frag:
            idx = np.argpartition(intensities, -max_frag)[-max_frag:]
            # Reorder the selection by ascending m/z.
            order = np.argsort(peaks[idx])
            idx = idx[order]
            peaks_sel = peaks[idx]
            intens_sel = intensities[idx]
        else:
            # Ascending m/z directly.
            order = np.argsort(peaks)
            peaks_sel = peaks[order]
            intens_sel = intensities[order]

        # 4.4 Build the token sequence (precursor first)
        peaks_str = ["%.2f" % p for p in peaks_sel]
        try:
            token_ids = [word2idx[precursor_str]] + [word2idx[p] for p in peaks_str]
        except KeyError:
            # A peak just below 1000 can round to "1000.00" and miss the vocab.
            continue

        # 4.5 Intensity: prepend 2 for the precursor, then scale the max to 1
        intens_seq = np.hstack((2.0, intens_sel))
        max_intensity = float(np.max(intens_seq)) if intens_seq.size else 1.0
        if max_intensity != 0.0:
            intens_seq = intens_seq / max_intensity

        # 4.6 Pad or truncate both sequences to maxlen, keeping them aligned
        if len(token_ids) > maxlen:
            token_ids = token_ids[:maxlen]
            intens_seq = intens_seq[:maxlen]

        n_pad = maxlen - len(token_ids)
        if n_pad > 0:
            token_ids += [word2idx['[PAD]']] * n_pad
            intens_seq = np.hstack([intens_seq, np.zeros(n_pad, dtype=float)])

        # 4.7 Save processed result
        ms_data[ms2_id] = {
            'mz': token_ids,
            'intensity': intens_seq.tolist(),
            'molecule_id': molecule_id
        }

    # 5) Return processed data and dictionary
    return ms_data, word2idx
|
|
2120
|
+
|
|
2121
|
+
|
|
2122
|
+
|
|
2123
|
+
@staticmethod
def load_external_test_dataset(
    external_data_dir,
    biotext_dir,
    paraphrase_dir,
    tokenizer,
    args,
    dataset_configs=None,
    **kwargs
):
    """
    Load and preprocess external test datasets (e.g. HILIC and RPLC).

    Pipeline: load MS2 pickles + meta CSVs per config, merge them, attach
    `molecule_id` to each spectrum (derived from the spectrum-id prefix before
    the first '_'), load per-molecule BioText/paraphrase files, drop spectra
    with no BioText, fill precursor info, and run the positive-mode MS2
    preprocessing (no augmentation for a test set).

    Parameters:
        external_data_dir (str): directory containing the external MS2/meta files.
        biotext_dir (str): directory of BioText files, one "<HMDB_ID>.txt" per molecule.
        paraphrase_dir (str): directory of "<HMDB_ID>_paraphrase.txt" files; falsy to skip.
        tokenizer: text tokenizer forwarded to the MS2BioTextDataset constructor.
        args: options object read via getattr (n_workers, precursor_mode, precursor_value).
        dataset_configs (list): configs with keys 'name', 'ms2_file', 'meta_file';
            defaults to the HILIC + RPLC file layout under external_data_dir.
        **kwargs: extra keyword arguments passed to the MS2BioTextDataset constructor.

    Returns:
        tuple: (external_test_dataset, data_statistics)
            - external_test_dataset: MS2BioTextDataset instance (use_paraphrase=False).
            - data_statistics: dict of counts (spectra, meta rows, BioTexts,
              unique molecules, vocab size, source dataset names).
    """
    import pickle
    import pandas as pd
    import os
    from pathlib import Path

    # Default configuration (HILIC and RPLC file layout).
    if dataset_configs is None:
        dataset_configs = [
            {
                'name': 'HILIC',
                'ms2_file': os.path.join(external_data_dir, 'hilic_ms_data.pkl'),
                'meta_file': os.path.join(external_data_dir, 'hilic_meta_data.csv'),
            },
            {
                'name': 'RPLC',
                'ms2_file': os.path.join(external_data_dir, 'rplc_ms_data.pkl'),
                'meta_file': os.path.join(external_data_dir, 'rplc_meta_data.csv'),
            }
        ]

    print("\n" + "="*60)
    print("加载外部测试数据集...")
    print("="*60)

    # 1. Load every configured dataset (MS2 pickle + meta CSV).
    all_ms2_data = {}
    all_meta_data = []

    for config in dataset_configs:
        print(f"\n📁 加载 {config['name']} 数据集...")

        # Load MS2 spectra. NOTE(review): pickle.load on external files —
        # only use with trusted data.
        with open(config['ms2_file'], 'rb') as f:
            ms2_data = pickle.load(f)

        # Load metadata table.
        meta_data = pd.read_csv(config['meta_file'])

        print(f" ✓ {config['name']}: {len(ms2_data)} 光谱, {len(meta_data)} meta")

        # Merge into the combined MS2 dict / meta list.
        all_ms2_data.update(ms2_data)
        all_meta_data.append(meta_data)

    # 2. Merge meta tables, aligning columns first (datasets may differ in schema).
    if len(all_meta_data) > 1:
        all_cols = set()
        for df in all_meta_data:
            all_cols.update(df.columns)
        all_cols = sorted(all_cols)

        aligned_meta_data = []
        for df in all_meta_data:
            # Missing columns are created and filled with None.
            df = df.reindex(columns=all_cols, fill_value=None)
            aligned_meta_data.append(df)

        external_meta_data = pd.concat(aligned_meta_data, ignore_index=True)
    else:
        external_meta_data = all_meta_data[0]

    external_ms2_data = all_ms2_data

    print(f"\n✓ 合并后外部数据集: {len(external_ms2_data)} 光谱, {len(external_meta_data)} meta")

    # 3. Index the meta table by HMDB.ID (when the column is present).
    if 'HMDB.ID' in external_meta_data.columns:
        external_meta_data = external_meta_data.set_index('HMDB.ID')
        print(f" 已设置HMDB.ID为索引")

    # 4. Normalize the MS2 format: ensure each spectrum carries a molecule_id
    #    (derived from the spectrum-id prefix before the first '_').
    print("\n🔧 修正MS2数据格式...")
    for spectrum_id, spectrum_data in external_ms2_data.items():
        if 'molecule_id' not in spectrum_data:
            molecule_id = spectrum_id.split('_')[0]
            spectrum_data['molecule_id'] = molecule_id

    # 5. Collect all unique HMDB IDs referenced by the spectra.
    unique_hmdb_ids = set()
    for spec_id in external_ms2_data.keys():
        hmdb_id = spec_id.split('_')[0]
        unique_hmdb_ids.add(hmdb_id)

    print(f" 外部数据集包含 {len(unique_hmdb_ids)} 个unique HMDB IDs")

    # 6. Load the matching BioText (and optional paraphrases) per molecule.
    print("\n📚 加载BioText数据...")
    external_biotext_data = {}
    missing_biotext = []

    biotext_dir = Path(biotext_dir)
    paraphrase_dir = Path(paraphrase_dir) if paraphrase_dir else None

    for hmdb_id in unique_hmdb_ids:
        # Skip malformed HMDB IDs (e.g. ones containing a literal "{}").
        if '{}' in hmdb_id:
            missing_biotext.append(hmdb_id)
            continue

        biotext_file = biotext_dir / f"{hmdb_id}.txt"
        if biotext_file.exists():
            with open(biotext_file, 'r', encoding='utf-8') as f:
                original_text = f.read().strip()

            # Load paraphrases when a paraphrase file exists.
            paraphrases = []
            if paraphrase_dir:
                paraphrase_file = paraphrase_dir / f"{hmdb_id}_paraphrase.txt"
                if paraphrase_file.exists():
                    with open(paraphrase_file, 'r', encoding='utf-8') as pf:
                        content = pf.read()
                        # File format: "=== version N === <text>" sections.
                        # NOTE(review): assumes each section contains a second
                        # "===" after the version label — verify file format.
                        versions = content.split("=== version")
                        for version in versions[1:]:
                            _, text = version.split("===", 1)
                            text = text.strip()
                            if text:
                                paraphrases.append(text)

            external_biotext_data[hmdb_id] = {
                'original': original_text,
                'paraphrases': paraphrases
            }
        else:
            missing_biotext.append(hmdb_id)

    print(f" ✓ 成功加载 {len(external_biotext_data)} 个BioText")
    print(f" ✗ 缺失BioText: {len(missing_biotext)} 个")

    # 7. Drop spectra whose molecule has no BioText (method="drop").
    initial_ms2_count = len(external_ms2_data)
    external_ms2_data, _ = MS2BioTextDataset.missing_biotext_handling(
        external_ms2_data,
        external_biotext_data,
        method="drop"
    )
    print(f" 删除 {initial_ms2_count - len(external_ms2_data)} 条缺失biotext的光谱")

    # 8. Keep only meta rows whose molecule still has at least one spectrum.
    remaining_hmdb_ids = set()
    for spectrum_id in external_ms2_data.keys():
        hmdb_id = spectrum_id.split('_')[0]
        remaining_hmdb_ids.add(hmdb_id)

    external_meta_data = external_meta_data[external_meta_data.index.isin(remaining_hmdb_ids)]

    # 9. Fill in precursor data (delegated to the class helper).
    print("\n⚙️ 填充precursor数据...")
    external_meta_data = MS2BioTextDataset.fill_precursor_data(
        external_meta_data,
        external_ms2_data
    )

    # 10. Preprocess MS2 spectra (no data augmentation for a test set).
    print("\n🚀 预处理MS2数据(不进行数据增强)...")
    external_processed_ms2, external_word2idx = MS2BioTextDataset.preprocess_ms2_data_positive_only_parallel(
        external_ms2_data,
        external_meta_data,
        n_workers=getattr(args, 'n_workers', 4),
        precursor_mode=getattr(args, 'precursor_mode', 'auto'),
        precursor_value=getattr(args, 'precursor_value', 2.0)
    )

    # 11. Summary statistics returned alongside the dataset.
    data_statistics = {
        'original_ms2_count': initial_ms2_count,
        'processed_ms2_count': len(external_processed_ms2),
        'meta_count': len(external_meta_data),
        'biotext_count': len(external_biotext_data),
        'unique_molecules': len(remaining_hmdb_ids),
        'vocab_size': len(external_word2idx),
        'datasets': [config['name'] for config in dataset_configs]
    }

    print("\n" + "="*60)
    print("📊 外部测试数据集最终统计")
    print("="*60)
    print(f" 原始MS2光谱数: {data_statistics['original_ms2_count']}")
    print(f" 处理后MS2光谱数: {data_statistics['processed_ms2_count']}")
    print(f" Meta记录数: {data_statistics['meta_count']}")
    print(f" BioText数: {data_statistics['biotext_count']}")
    print(f" Unique分子数: {data_statistics['unique_molecules']}")
    print(f" 词汇表大小: {data_statistics['vocab_size']}")
    print(f" 数据集来源: {', '.join(data_statistics['datasets'])}")

    # 12. Build the Dataset instance (paraphrase augmentation disabled for test).
    print("\n🎯 创建外部测试Dataset实例...")
    external_test_dataset = MS2BioTextDataset(
        ms2_data=external_processed_ms2,
        meta_data=external_meta_data,
        biotext_data=external_biotext_data,
        tokenizer=tokenizer,
        use_paraphrase=False,
        **kwargs
    )

    print(f"✅ 外部测试Dataset创建成功!大小: {len(external_test_dataset)}")

    return external_test_dataset, data_statistics
|
|
2347
|
+
|
|
2348
|
+
|
|
2349
|
+
# @staticmethod
|
|
2350
|
+
# def create_train_test_datasets_from_file(
|
|
2351
|
+
# data_dir,
|
|
2352
|
+
# ms2_data,
|
|
2353
|
+
# meta_data,
|
|
2354
|
+
# biotext_data,
|
|
2355
|
+
# tokenizer,
|
|
2356
|
+
# test_size=0.2,
|
|
2357
|
+
# random_state=42,
|
|
2358
|
+
# use_paraphrase = False,
|
|
2359
|
+
# **kwargs
|
|
2360
|
+
# ):
|
|
2361
|
+
# """
|
|
2362
|
+
# Split data into training and test sets based on molecule IDs.
|
|
2363
|
+
# This version checks whether a local file with pre-defined splits exists.
|
|
2364
|
+
# If it exists, it loads the split; otherwise, it creates the split and saves it to file.
|
|
2365
|
+
# (Enhanced to handle empty or corrupted JSON files gracefully)
|
|
2366
|
+
|
|
2367
|
+
# Parameters:
|
|
2368
|
+
# data_dir (str or Path): Path to the directory containing the split file (molecule_split.json).
|
|
2369
|
+
# ms2_data (dict): Full MS2 data dictionary.
|
|
2370
|
+
# meta_data (pd.DataFrame): Metadata dataframe indexed by molecule ID.
|
|
2371
|
+
# biotext_data (dict): Full BioText data dictionary.
|
|
2372
|
+
# tokenizer: Tokenizer used for initializing the Dataset.
|
|
2373
|
+
# test_size (float): Proportion of test set (used only when creating the split).
|
|
2374
|
+
# random_state (int): Random seed (used only when creating the split).
|
|
2375
|
+
# **kwargs: Other arguments passed to the MS2BioTextDataset constructor.
|
|
2376
|
+
|
|
2377
|
+
# Returns:
|
|
2378
|
+
# tuple: (train_dataset, test_dataset), two MS2BioTextDataset instances.
|
|
2379
|
+
# """
|
|
2380
|
+
# print("Preparing training and test datasets (with file persistence)...")
|
|
2381
|
+
|
|
2382
|
+
# # --- 1. Define split file path ---
|
|
2383
|
+
# data_dir = Path(data_dir)
|
|
2384
|
+
# split_file_path = data_dir / 'molecule_split.json'
|
|
2385
|
+
|
|
2386
|
+
# # --- 2. Load existing split file if valid; otherwise create new split ---
|
|
2387
|
+
# train_mol_ids, test_mol_ids = None, None
|
|
2388
|
+
|
|
2389
|
+
# if split_file_path.exists() and split_file_path.stat().st_size > 0:
|
|
2390
|
+
# print(f"✅ Found existing split file: {split_file_path}")
|
|
2391
|
+
# print(" Loading molecule IDs from file...")
|
|
2392
|
+
# try:
|
|
2393
|
+
# with open(split_file_path, 'r', encoding='utf-8') as f:
|
|
2394
|
+
# split_ids = json.load(f)
|
|
2395
|
+
# train_mol_ids = split_ids['train_ids']
|
|
2396
|
+
# test_mol_ids = split_ids['test_ids']
|
|
2397
|
+
# except json.JSONDecodeError:
|
|
2398
|
+
# print(f" ⚠️ Warning: File '{split_file_path}' exists but could not be parsed (may be empty or corrupted). Will recreate.")
|
|
2399
|
+
# except KeyError:
|
|
2400
|
+
# print(f" ⚠️ Warning: File '{split_file_path}' has invalid format (missing 'train_ids' or 'test_ids'). Will recreate.")
|
|
2401
|
+
|
|
2402
|
+
# if train_mol_ids is None or test_mol_ids is None:
|
|
2403
|
+
# print(f"⚠️ No valid split file found or could not be loaded. Creating a new split...")
|
|
2404
|
+
# valid_molecule_ids = set(item['molecule_id'] for item in ms2_data.values())
|
|
2405
|
+
# all_molecule_ids = [mol_id for mol_id in meta_data.index.unique() if mol_id in valid_molecule_ids]
|
|
2406
|
+
# print(f" Found {len(all_molecule_ids)} unique molecules for splitting.")
|
|
2407
|
+
|
|
2408
|
+
# train_mol_ids, test_mol_ids = train_test_split(
|
|
2409
|
+
# all_molecule_ids,
|
|
2410
|
+
# test_size=test_size,
|
|
2411
|
+
# random_state=random_state
|
|
2412
|
+
# )
|
|
2413
|
+
|
|
2414
|
+
# print(f" Saving new split to: {split_file_path}")
|
|
2415
|
+
# split_data_to_save = {'train_ids': train_mol_ids, 'test_ids': test_mol_ids}
|
|
2416
|
+
# split_file_path.parent.mkdir(parents=True, exist_ok=True)
|
|
2417
|
+
# with open(split_file_path, 'w', encoding='utf-8') as f:
|
|
2418
|
+
# json.dump(split_data_to_save, f, indent=4)
|
|
2419
|
+
# print(" Split file saved successfully.")
|
|
2420
|
+
|
|
2421
|
+
|
|
2422
|
+
# # --- 3. Filter data sources by ID lists ---
|
|
2423
|
+
# train_mol_ids_set = set(train_mol_ids)
|
|
2424
|
+
# test_mol_ids_set = set(test_mol_ids)
|
|
2425
|
+
|
|
2426
|
+
# train_meta_data = meta_data[meta_data.index.isin(train_mol_ids_set)]
|
|
2427
|
+
# test_meta_data = meta_data[meta_data.index.isin(test_mol_ids_set)]
|
|
2428
|
+
|
|
2429
|
+
# train_biotext_data = {mol_id: text for mol_id, text in biotext_data.items() if mol_id in train_mol_ids_set}
|
|
2430
|
+
# test_biotext_data = {mol_id: text for mol_id, text in biotext_data.items() if mol_id in test_mol_ids_set}
|
|
2431
|
+
# train_ms2_data = {ms2_id: info for ms2_id, info in ms2_data.items() if info['molecule_id'] in train_mol_ids_set}
|
|
2432
|
+
# test_ms2_data = {ms2_id: info for ms2_id, info in ms2_data.items() if info['molecule_id'] in test_mol_ids_set}
|
|
2433
|
+
|
|
2434
|
+
# print(f"Data filtering completed:")
|
|
2435
|
+
|
|
2436
|
+
# # --- 4. Create Dataset instances ---
|
|
2437
|
+
# print("Creating training Dataset instance...")
|
|
2438
|
+
# train_dataset = MS2BioTextDataset(
|
|
2439
|
+
# ms2_data=train_ms2_data,
|
|
2440
|
+
# meta_data=train_meta_data,
|
|
2441
|
+
# biotext_data=train_biotext_data,
|
|
2442
|
+
# tokenizer=tokenizer,
|
|
2443
|
+
# use_paraphrase = use_paraphrase,
|
|
2444
|
+
# **kwargs
|
|
2445
|
+
# )
|
|
2446
|
+
|
|
2447
|
+
# print("Creating test Dataset instance...")
|
|
2448
|
+
# test_dataset = MS2BioTextDataset(
|
|
2449
|
+
# ms2_data=test_ms2_data,
|
|
2450
|
+
# meta_data=test_meta_data,
|
|
2451
|
+
# biotext_data=test_biotext_data,
|
|
2452
|
+
# tokenizer=tokenizer,
|
|
2453
|
+
# use_paraphrase=False,
|
|
2454
|
+
# **kwargs
|
|
2455
|
+
# )
|
|
2456
|
+
|
|
2457
|
+
# return train_dataset, test_dataset
|
|
2458
|
+
|
|
2459
|
+
|
|
2460
|
+
|
|
2461
|
+
|
|
2462
|
+
# @staticmethod
|
|
2463
|
+
# def create_train_test_datasets_from_file(
|
|
2464
|
+
# data_dir,
|
|
2465
|
+
# ms2_data,
|
|
2466
|
+
# meta_data,
|
|
2467
|
+
# biotext_data,
|
|
2468
|
+
# tokenizer,
|
|
2469
|
+
# test_size=0.2,
|
|
2470
|
+
# use_paraphrase = False,
|
|
2471
|
+
# **kwargs
|
|
2472
|
+
# ):
|
|
2473
|
+
# """
|
|
2474
|
+
# Split data into training and test sets based on MS2 spectra within each molecule.
|
|
2475
|
+
# For each molecule with multiple MS2 spectra, one MS2 is reserved for testing,
|
|
2476
|
+
# and the rest are used for training. Molecules with only one MS2 spectrum are
|
|
2477
|
+
# only included in the training set.
|
|
2478
|
+
|
|
2479
|
+
# Parameters:
|
|
2480
|
+
# data_dir (str or Path): Path to the directory containing the split file (ms2_split.json).
|
|
2481
|
+
# ms2_data (dict): Full MS2 data dictionary {ms2_id: {molecule_id: ..., ...}}.
|
|
2482
|
+
# meta_data (pd.DataFrame): Metadata dataframe indexed by molecule ID.
|
|
2483
|
+
# biotext_data (dict): Full BioText data dictionary {molecule_id: text}.
|
|
2484
|
+
# tokenizer: Tokenizer used for initializing the Dataset.
|
|
2485
|
+
# test_size (float): Deprecated in this version (kept for compatibility).
|
|
2486
|
+
# random_state (int): Random seed for reproducible splits.
|
|
2487
|
+
# **kwargs: Other arguments passed to the MS2BioTextDataset constructor.
|
|
2488
|
+
|
|
2489
|
+
# Returns:
|
|
2490
|
+
# tuple: (train_dataset, test_dataset), two MS2BioTextDataset instances.
|
|
2491
|
+
# """
|
|
2492
|
+
# print("Preparing training and test datasets with per-molecule MS2 splitting...")
|
|
2493
|
+
|
|
2494
|
+
# # --- 1. Define split file path ---
|
|
2495
|
+
# data_dir = Path(data_dir)
|
|
2496
|
+
# split_file_path = data_dir / 'ms2_split.json' # 改名以区分新的划分方式
|
|
2497
|
+
|
|
2498
|
+
# # --- 2. Load existing split file if valid; otherwise create new split ---
|
|
2499
|
+
# train_ms2_ids, test_ms2_ids = None, None
|
|
2500
|
+
|
|
2501
|
+
# if split_file_path.exists() and split_file_path.stat().st_size > 0:
|
|
2502
|
+
# print(f"✅ Found existing split file: {split_file_path}")
|
|
2503
|
+
# print(" Loading MS2 IDs from file...")
|
|
2504
|
+
# try:
|
|
2505
|
+
# with open(split_file_path, 'r', encoding='utf-8') as f:
|
|
2506
|
+
# split_ids = json.load(f)
|
|
2507
|
+
# train_ms2_ids = split_ids['train_ms2_ids']
|
|
2508
|
+
# test_ms2_ids = split_ids['test_ms2_ids']
|
|
2509
|
+
# except json.JSONDecodeError:
|
|
2510
|
+
# print(f" ⚠️ Warning: File '{split_file_path}' exists but could not be parsed. Will recreate.")
|
|
2511
|
+
# except KeyError:
|
|
2512
|
+
# print(f" ⚠️ Warning: File '{split_file_path}' has invalid format. Will recreate.")
|
|
2513
|
+
|
|
2514
|
+
# if train_ms2_ids is None or test_ms2_ids is None:
|
|
2515
|
+
# print(f"⚠️ No valid split file found. Creating a new MS2-level split...")
|
|
2516
|
+
|
|
2517
|
+
# # 按分子ID分组MS2谱图
|
|
2518
|
+
# molecule_to_ms2 = {}
|
|
2519
|
+
# for ms2_id, ms2_info in ms2_data.items():
|
|
2520
|
+
# mol_id = ms2_info['molecule_id']
|
|
2521
|
+
# if mol_id not in molecule_to_ms2:
|
|
2522
|
+
# molecule_to_ms2[mol_id] = []
|
|
2523
|
+
# molecule_to_ms2[mol_id].append(ms2_id)
|
|
2524
|
+
|
|
2525
|
+
# # 统计信息
|
|
2526
|
+
# single_ms2_molecules = []
|
|
2527
|
+
# multi_ms2_molecules = []
|
|
2528
|
+
# for mol_id, ms2_list in molecule_to_ms2.items():
|
|
2529
|
+
# if len(ms2_list) == 1:
|
|
2530
|
+
# single_ms2_molecules.append(mol_id)
|
|
2531
|
+
# else:
|
|
2532
|
+
# multi_ms2_molecules.append(mol_id)
|
|
2533
|
+
|
|
2534
|
+
# print(f" Found {len(single_ms2_molecules)} molecules with single MS2 spectrum")
|
|
2535
|
+
# print(f" Found {len(multi_ms2_molecules)} molecules with multiple MS2 spectra")
|
|
2536
|
+
|
|
2537
|
+
|
|
2538
|
+
# train_ms2_ids = []
|
|
2539
|
+
# test_ms2_ids = []
|
|
2540
|
+
|
|
2541
|
+
# # 处理只有一个MS2的分子:全部放入训练集
|
|
2542
|
+
# for mol_id in single_ms2_molecules:
|
|
2543
|
+
# train_ms2_ids.extend(molecule_to_ms2[mol_id])
|
|
2544
|
+
|
|
2545
|
+
# # 处理有多个MS2的分子:随机选择一个作为测试集,其余作为训练集
|
|
2546
|
+
# for mol_id in multi_ms2_molecules:
|
|
2547
|
+
# ms2_list = molecule_to_ms2[mol_id]
|
|
2548
|
+
# # 随机选择一个MS2作为测试集
|
|
2549
|
+
# test_ms2_id = np.random.choice(ms2_list)
|
|
2550
|
+
# test_ms2_ids.append(test_ms2_id)
|
|
2551
|
+
# # 其余的作为训练集
|
|
2552
|
+
# train_ms2_ids.extend([ms2_id for ms2_id in ms2_list if ms2_id != test_ms2_id])
|
|
2553
|
+
|
|
2554
|
+
# print(f" Split results: {len(train_ms2_ids)} training MS2, {len(test_ms2_ids)} test MS2")
|
|
2555
|
+
|
|
2556
|
+
# # 保存划分结果
|
|
2557
|
+
# print(f" Saving new split to: {split_file_path}")
|
|
2558
|
+
# split_data_to_save = {
|
|
2559
|
+
# 'train_ms2_ids': train_ms2_ids,
|
|
2560
|
+
# 'test_ms2_ids': test_ms2_ids
|
|
2561
|
+
# }
|
|
2562
|
+
# split_file_path.parent.mkdir(parents=True, exist_ok=True)
|
|
2563
|
+
# with open(split_file_path, 'w', encoding='utf-8') as f:
|
|
2564
|
+
# json.dump(split_data_to_save, f, indent=4)
|
|
2565
|
+
# print(" Split file saved successfully.")
|
|
2566
|
+
|
|
2567
|
+
# # --- 3. Filter data sources by MS2 ID lists ---
|
|
2568
|
+
# train_ms2_ids_set = set(train_ms2_ids)
|
|
2569
|
+
# test_ms2_ids_set = set(test_ms2_ids)
|
|
2570
|
+
|
|
2571
|
+
# # 过滤MS2数据
|
|
2572
|
+
# train_ms2_data = {ms2_id: info for ms2_id, info in ms2_data.items() if ms2_id in train_ms2_ids_set}
|
|
2573
|
+
# test_ms2_data = {ms2_id: info for ms2_id, info in ms2_data.items() if ms2_id in test_ms2_ids_set}
|
|
2574
|
+
|
|
2575
|
+
# # 获取涉及的分子ID
|
|
2576
|
+
# train_molecule_ids = set(info['molecule_id'] for info in train_ms2_data.values())
|
|
2577
|
+
# test_molecule_ids = set(info['molecule_id'] for info in test_ms2_data.values())
|
|
2578
|
+
|
|
2579
|
+
# # 过滤元数据和biotext数据
|
|
2580
|
+
# # 注意:训练集和测试集可能包含相同的分子ID(因为同一分子的不同MS2可能分布在训练集和测试集中)
|
|
2581
|
+
# train_meta_data = meta_data[meta_data.index.isin(train_molecule_ids)]
|
|
2582
|
+
# test_meta_data = meta_data[meta_data.index.isin(test_molecule_ids)]
|
|
2583
|
+
|
|
2584
|
+
# train_biotext_data = {mol_id: text for mol_id, text in biotext_data.items() if mol_id in train_molecule_ids}
|
|
2585
|
+
# test_biotext_data = {mol_id: text for mol_id, text in biotext_data.items() if mol_id in test_molecule_ids}
|
|
2586
|
+
|
|
2587
|
+
# print(f"Data filtering completed:")
|
|
2588
|
+
# print(f" Training: {len(train_ms2_data)} MS2 spectra from {len(train_molecule_ids)} molecules")
|
|
2589
|
+
# print(f" Test: {len(test_ms2_data)} MS2 spectra from {len(test_molecule_ids)} molecules")
|
|
2590
|
+
|
|
2591
|
+
# # --- 4. Create Dataset instances ---
|
|
2592
|
+
# print("Creating training Dataset instance...")
|
|
2593
|
+
# train_dataset = MS2BioTextDataset(
|
|
2594
|
+
# ms2_data=train_ms2_data,
|
|
2595
|
+
# meta_data=train_meta_data,
|
|
2596
|
+
# biotext_data=train_biotext_data,
|
|
2597
|
+
# tokenizer=tokenizer,
|
|
2598
|
+
# use_paraphrase=use_paraphrase,
|
|
2599
|
+
# **kwargs
|
|
2600
|
+
# )
|
|
2601
|
+
|
|
2602
|
+
# print("Creating test Dataset instance...")
|
|
2603
|
+
# test_dataset = MS2BioTextDataset(
|
|
2604
|
+
# ms2_data=test_ms2_data,
|
|
2605
|
+
# meta_data=test_meta_data,
|
|
2606
|
+
# biotext_data=test_biotext_data,
|
|
2607
|
+
# tokenizer=tokenizer,
|
|
2608
|
+
# use_paraphrase=False,
|
|
2609
|
+
# **kwargs
|
|
2610
|
+
# )
|
|
2611
|
+
|
|
2612
|
+
# return train_dataset, test_dataset
|
|
2613
|
+
|
|
2614
|
+
@staticmethod
def filter_shared_texts(biotext_data, max_sharing_molecules=5):
    """
    Remove texts that are shared by too many molecules.

    A text shared by more than `max_sharing_molecules` molecules is assumed to
    be boilerplate (non-discriminative) and is dropped from every molecule's
    candidate texts. Molecules left with no text are dropped entirely.

    Args:
        biotext_data: mapping molecule_id -> entry, where entry is one of:
            - list of {'type': ..., 'text': ...} records,
            - dict {'original': str, 'paraphrases': [str, ...]},
            - plain str.
        max_sharing_molecules: maximum number of molecules that may share one text.

    Returns:
        tuple: (filtered_biotext_data, stats) where stats contains
            'removed_texts', 'removed_entries', 'molecules_no_text',
            'max_sharing_after'.
    """
    from collections import defaultdict

    def _entry_texts(entry):
        # Normalize the three supported entry formats to a flat list of texts.
        if isinstance(entry, list):
            return [record['text'] for record in entry]
        if isinstance(entry, dict):
            return [entry.get("original", "")] + entry.get("paraphrases", [])
        if isinstance(entry, str):
            return [entry]
        return []

    print(f"\n=== Filtering shared texts (max_sharing={max_sharing_molecules}) ===")

    # 1. Build an inverted index: text -> set of molecules using it.
    text_to_molecules = defaultdict(set)
    for mol_id, entry in biotext_data.items():
        for text in _entry_texts(entry):
            if text:  # skip empty strings
                text_to_molecules[text].add(mol_id)

    # 2. Texts shared by too many molecules are scheduled for removal.
    texts_to_remove = {
        text for text, molecules in text_to_molecules.items()
        if len(molecules) > max_sharing_molecules
    }

    print(f"Found {len(texts_to_remove)} texts shared by >{max_sharing_molecules} molecules")

    # 3. Strip the high-frequency texts from each molecule's candidates.
    filtered_biotext_data = {}
    total_removed = 0
    molecules_with_no_text = []

    for mol_id, entry in biotext_data.items():
        if isinstance(entry, list):
            # New format: list of records.
            filtered_entry = [record for record in entry
                              if record['text'] not in texts_to_remove]
            total_removed += len(entry) - len(filtered_entry)

            if filtered_entry:
                filtered_biotext_data[mol_id] = filtered_entry
            else:
                molecules_with_no_text.append(mol_id)

        elif isinstance(entry, dict):
            # Old format: dict with 'original' and 'paraphrases'.
            original = entry.get("original", "")
            paraphrases = entry.get("paraphrases", [])
            filtered_paraphrases = [p for p in paraphrases if p not in texts_to_remove]

            # Count removed paraphrases BEFORE any promotion below, so a
            # promoted (surviving) paraphrase is never counted as removed.
            total_removed += len(paraphrases) - len(filtered_paraphrases)

            if original in texts_to_remove:
                total_removed += 1  # the original itself was removed
                if filtered_paraphrases:
                    # Promote the first surviving paraphrase to 'original'.
                    original = filtered_paraphrases[0]
                    filtered_paraphrases = filtered_paraphrases[1:]
                else:
                    # Everything gone: drop the molecule (removals already counted).
                    molecules_with_no_text.append(mol_id)
                    continue

            filtered_biotext_data[mol_id] = {
                'original': original,
                'paraphrases': filtered_paraphrases
            }

        elif isinstance(entry, str):
            # Plain-string format.
            if entry not in texts_to_remove:
                filtered_biotext_data[mol_id] = entry
            else:
                total_removed += 1  # the single text was removed
                molecules_with_no_text.append(mol_id)

    # 4. Report statistics.
    print(f"Statistics:")
    print(f" Total text entries removed: {total_removed}")
    print(f" Molecules before filtering: {len(biotext_data)}")
    print(f" Molecules after filtering: {len(filtered_biotext_data)}")
    print(f" Molecules with no text left: {len(molecules_with_no_text)}")

    if molecules_with_no_text:
        print(f" Warning: {len(molecules_with_no_text)} molecules lost all texts!")
        print(f" First 5: {molecules_with_no_text[:5]}")

    # 5. Sanity-check: recompute the sharing level after filtering.
    text_to_molecules_after = defaultdict(set)
    for mol_id, entry in filtered_biotext_data.items():
        for text in _entry_texts(entry):
            if text:
                text_to_molecules_after[text].add(mol_id)

    max_sharing_after = max(len(mols) for mols in text_to_molecules_after.values()) if text_to_molecules_after else 0
    print(f" Max molecules sharing one text after filtering: {max_sharing_after}")

    return filtered_biotext_data, {
        'removed_texts': len(texts_to_remove),
        'removed_entries': total_removed,
        'molecules_no_text': len(molecules_with_no_text),
        'max_sharing_after': max_sharing_after
    }
|
|
2740
|
+
|
|
2741
|
+
|
|
2742
|
+
|
|
2743
|
+
@staticmethod
def create_train_test_datasets_from_file(
    data_dir,
    ms2_data,
    meta_data,
    biotext_data,
    tokenizer,
    word2idx,
    args,
    test_size=0.2,
    use_paraphrase = False,
    **kwargs
):
    """
    Split data into training and test sets.
    If a split file exists, use its test_ms2_ids as the test set,
    and use ALL OTHER MS2 spectra (including any new data) as the training set.
    If no split file exists, create a new split following the original logic.

    Split rule when creating a new split: for every molecule with >1 MS2
    spectrum, one spectrum is drawn at random (np.random.choice, no fixed
    seed here — seed externally for reproducibility) into the test set;
    molecules with a single spectrum go entirely to training. Only the test
    IDs are persisted to ms2_split.json; the training set is recomputed as
    the complement on every call.

    Parameters:
        data_dir (str or Path): Path to the directory containing the split file (ms2_split.json).
        ms2_data (dict): Full MS2 data dictionary {ms2_id: {molecule_id: ..., ...}}.
        meta_data (pd.DataFrame): Metadata dataframe indexed by molecule ID.
        biotext_data (dict): Full BioText data dictionary {molecule_id: text}.
        tokenizer: Tokenizer used for initializing the Dataset.
        word2idx (dict): m/z token vocabulary forwarded to both Dataset instances.
        args: options object forwarded to both Dataset instances.
        test_size (float): Deprecated in this version (kept for compatibility).
        use_paraphrase (bool): paraphrase augmentation flag for the TRAINING set
            only; the test set is always built with use_paraphrase=False.
        **kwargs: Other arguments passed to the MS2BioTextDataset constructor.

    Returns:
        tuple: (train_dataset, test_dataset), two MS2BioTextDataset instances.
            Note: the two sets may share molecule IDs (different spectra of
            the same molecule can land on both sides).
    """
    print("Preparing training and test datasets with per-molecule MS2 splitting...")

    # --- 1. Define split file path ---
    data_dir = Path(data_dir)
    split_file_path = data_dir / 'ms2_split.json'

    # --- 2. Load existing split file if valid; otherwise create new split ---
    test_ms2_ids = None

    if split_file_path.exists() and split_file_path.stat().st_size > 0:
        print(f"✅ Found existing split file: {split_file_path}")
        print(" Loading test MS2 IDs from file...")
        try:
            with open(split_file_path, 'r', encoding='utf-8') as f:
                split_ids = json.load(f)
                test_ms2_ids = split_ids['test_ms2_ids']
                print(f" Loaded {len(test_ms2_ids)} test MS2 IDs from existing split.")
        except json.JSONDecodeError:
            # Empty or corrupted file: fall through and recreate the split.
            print(f" ⚠️ Warning: File '{split_file_path}' exists but could not be parsed. Will recreate.")
        except KeyError:
            # File parses but lacks the expected key: recreate the split.
            print(f" ⚠️ Warning: File '{split_file_path}' has invalid format. Will recreate.")

    if test_ms2_ids is None:
        print(f"⚠️ No valid split file found. Creating a new MS2-level split...")

        # Group MS2 spectra by molecule ID.
        molecule_to_ms2 = {}
        for ms2_id, ms2_info in ms2_data.items():
            mol_id = ms2_info['molecule_id']
            if mol_id not in molecule_to_ms2:
                molecule_to_ms2[mol_id] = []
            molecule_to_ms2[mol_id].append(ms2_id)

        # Partition molecules by spectrum count.
        single_ms2_molecules = []
        multi_ms2_molecules = []
        for mol_id, ms2_list in molecule_to_ms2.items():
            if len(ms2_list) == 1:
                single_ms2_molecules.append(mol_id)
            else:
                multi_ms2_molecules.append(mol_id)

        print(f" Found {len(single_ms2_molecules)} molecules with single MS2 spectrum")
        print(f" Found {len(multi_ms2_molecules)} molecules with multiple MS2 spectra")

        test_ms2_ids = []

        # Molecules with multiple spectra: randomly pick one spectrum for the test set.
        for mol_id in multi_ms2_molecules:
            ms2_list = molecule_to_ms2[mol_id]
            # Random pick — not seeded here; callers wanting reproducibility
            # should seed numpy's global RNG beforehand.
            test_ms2_id = np.random.choice(ms2_list)
            test_ms2_ids.append(test_ms2_id)

        print(f" Created new test set with {len(test_ms2_ids)} MS2 spectra")

        # Persist only test_ms2_ids; the training set is computed dynamically.
        print(f" Saving new split to: {split_file_path}")
        split_data_to_save = {
            'test_ms2_ids': test_ms2_ids,
            'note': 'Training set uses all MS2 IDs not in test_ms2_ids'
        }
        split_file_path.parent.mkdir(parents=True, exist_ok=True)
        with open(split_file_path, 'w', encoding='utf-8') as f:
            json.dump(split_data_to_save, f, indent=4)
        print(" Split file saved successfully.")

    # --- 3. Create train set: ALL MS2 IDs except those in test set ---
    test_ms2_ids_set = set(test_ms2_ids)
    all_ms2_ids = set(ms2_data.keys())
    train_ms2_ids_set = all_ms2_ids - test_ms2_ids_set

    print(f"\n📊 Dataset Statistics:")
    print(f" Total MS2 spectra: {len(all_ms2_ids)}")
    print(f" Test MS2 spectra: {len(test_ms2_ids_set)}")
    print(f" Training MS2 spectra: {len(train_ms2_ids_set)}")

    # Filter the MS2 dictionaries by membership in the two ID sets.
    train_ms2_data = {ms2_id: info for ms2_id, info in ms2_data.items() if ms2_id in train_ms2_ids_set}
    test_ms2_data = {ms2_id: info for ms2_id, info in ms2_data.items() if ms2_id in test_ms2_ids_set}

    # Molecule IDs touched by each side.
    train_molecule_ids = set(info['molecule_id'] for info in train_ms2_data.values())
    test_molecule_ids = set(info['molecule_id'] for info in test_ms2_data.values())

    # Filter meta and biotext data.
    # NOTE: train and test may contain the SAME molecule IDs, since different
    # spectra of one molecule can be split across the two sets.
    train_meta_data = meta_data[meta_data.index.isin(train_molecule_ids)]
    test_meta_data = meta_data[meta_data.index.isin(test_molecule_ids)]

    train_biotext_data = {mol_id: text for mol_id, text in biotext_data.items() if mol_id in train_molecule_ids}
    test_biotext_data = {mol_id: text for mol_id, text in biotext_data.items() if mol_id in test_molecule_ids}

    print(f"\nData filtering completed:")
    print(f" Training: {len(train_ms2_data)} MS2 spectra from {len(train_molecule_ids)} molecules")
    print(f" Test: {len(test_ms2_data)} MS2 spectra from {len(test_molecule_ids)} molecules")

    # --- 4. Create Dataset instances ---
    print("\nCreating training Dataset instance...")
    train_dataset = MS2BioTextDataset(
        ms2_data=train_ms2_data,
        meta_data=train_meta_data,
        biotext_data=train_biotext_data,
        tokenizer=tokenizer,
        use_paraphrase=use_paraphrase,
        word2idx=word2idx,
        args=args,
        split='train',
        **kwargs
    )

    print("Creating test Dataset instance...")
    test_dataset = MS2BioTextDataset(
        ms2_data=test_ms2_data,
        meta_data=test_meta_data,
        biotext_data=test_biotext_data,
        tokenizer=tokenizer,
        use_paraphrase=False,
        word2idx=word2idx,
        args=args,
        split='test',
        **kwargs
    )

    return train_dataset, test_dataset
|
|
2898
|
+
|
|
2899
|
+
|
|
2900
|
+
|
|
2901
|
+
|
|
2902
|
+
# 在文件顶部,import 语句之后,类定义之前
|
|
2903
|
+
|
|
2904
|
+
def _augment_worker(batch_data):
|
|
2905
|
+
"""
|
|
2906
|
+
全局 worker 函数,用于数据增强
|
|
2907
|
+
batch_data: (batch, filter_threshold, noise_ratio, noise_intensity_range, augment_multiplier)
|
|
2908
|
+
"""
|
|
2909
|
+
import numpy as np
|
|
2910
|
+
# ⚠️ 关键修改:不要导入 MS2BioTextDataset,直接在这里定义需要的函数
|
|
2911
|
+
|
|
2912
|
+
batch, filter_threshold, noise_ratio, noise_intensity_range, augment_multiplier = batch_data
|
|
2913
|
+
|
|
2914
|
+
# 将 filter_low_intensity_peaks 和 add_noise_peaks 的逻辑直接复制到这里
|
|
2915
|
+
def filter_low_intensity_peaks(peaks, intensities, threshold):
|
|
2916
|
+
"""过滤低强度峰"""
|
|
2917
|
+
if not peaks or not intensities:
|
|
2918
|
+
return peaks, intensities
|
|
2919
|
+
|
|
2920
|
+
max_intensity = max(intensities)
|
|
2921
|
+
if max_intensity == 0:
|
|
2922
|
+
return peaks, intensities
|
|
2923
|
+
|
|
2924
|
+
filtered_peaks = []
|
|
2925
|
+
filtered_intensities = []
|
|
2926
|
+
for mz, intensity in zip(peaks, intensities):
|
|
2927
|
+
if intensity / max_intensity >= threshold:
|
|
2928
|
+
filtered_peaks.append(mz)
|
|
2929
|
+
filtered_intensities.append(intensity)
|
|
2930
|
+
|
|
2931
|
+
return filtered_peaks, filtered_intensities
|
|
2932
|
+
|
|
2933
|
+
def add_noise_peaks(peaks, intensities, noise_ratio, noise_intensity_range):
|
|
2934
|
+
"""添加噪声峰"""
|
|
2935
|
+
import random
|
|
2936
|
+
|
|
2937
|
+
if not peaks:
|
|
2938
|
+
return peaks, intensities
|
|
2939
|
+
|
|
2940
|
+
n_noise = int(len(peaks) * noise_ratio)
|
|
2941
|
+
if n_noise == 0:
|
|
2942
|
+
return peaks, intensities
|
|
2943
|
+
|
|
2944
|
+
# 获取m/z范围
|
|
2945
|
+
min_mz = min(peaks)
|
|
2946
|
+
max_mz = max(peaks)
|
|
2947
|
+
max_intensity = max(intensities) if intensities else 1.0
|
|
2948
|
+
|
|
2949
|
+
# 生成噪声峰
|
|
2950
|
+
for _ in range(n_noise):
|
|
2951
|
+
# 随机m/z(避免与现有峰重复)
|
|
2952
|
+
noise_mz = random.uniform(min_mz, max_mz)
|
|
2953
|
+
# 随机低强度
|
|
2954
|
+
noise_intensity = random.uniform(
|
|
2955
|
+
noise_intensity_range[0] * max_intensity,
|
|
2956
|
+
noise_intensity_range[1] * max_intensity
|
|
2957
|
+
)
|
|
2958
|
+
|
|
2959
|
+
peaks.append(noise_mz)
|
|
2960
|
+
intensities.append(noise_intensity)
|
|
2961
|
+
|
|
2962
|
+
# 按m/z排序
|
|
2963
|
+
sorted_pairs = sorted(zip(peaks, intensities), key=lambda x: x[0])
|
|
2964
|
+
peaks = [p for p, _ in sorted_pairs]
|
|
2965
|
+
intensities = [i for _, i in sorted_pairs]
|
|
2966
|
+
|
|
2967
|
+
return peaks, intensities
|
|
2968
|
+
|
|
2969
|
+
# 处理逻辑
|
|
2970
|
+
result = {}
|
|
2971
|
+
for ms2_id, info in batch:
|
|
2972
|
+
molecule_id = info.get('molecule_id')
|
|
2973
|
+
peaks_original = info['mz'] if isinstance(info['mz'], list) else list(info['mz'])
|
|
2974
|
+
intensities_original = info['intensity'] if isinstance(info['intensity'], list) else list(info['intensity'])
|
|
2975
|
+
|
|
2976
|
+
# 过滤
|
|
2977
|
+
if filter_threshold:
|
|
2978
|
+
peaks_original, intensities_original = filter_low_intensity_peaks(
|
|
2979
|
+
peaks_original, intensities_original, filter_threshold
|
|
2980
|
+
)
|
|
2981
|
+
|
|
2982
|
+
# 原始版本
|
|
2983
|
+
result[ms2_id] = {
|
|
2984
|
+
'mz': peaks_original,
|
|
2985
|
+
'intensity': intensities_original,
|
|
2986
|
+
'molecule_id': molecule_id
|
|
2987
|
+
}
|
|
2988
|
+
|
|
2989
|
+
# 增强版本
|
|
2990
|
+
for aug_idx in range(1, augment_multiplier):
|
|
2991
|
+
peaks_aug, intensities_aug = add_noise_peaks(
|
|
2992
|
+
peaks_original.copy(), intensities_original.copy(),
|
|
2993
|
+
noise_ratio,
|
|
2994
|
+
noise_intensity_range
|
|
2995
|
+
)
|
|
2996
|
+
result[f"{ms2_id}_aug{aug_idx}"] = {
|
|
2997
|
+
'mz': peaks_aug,
|
|
2998
|
+
'intensity': intensities_aug,
|
|
2999
|
+
'molecule_id': molecule_id
|
|
3000
|
+
}
|
|
3001
|
+
return result
|
|
3002
|
+
|
|
3003
|
+
def _preprocess_worker(batch_data):
|
|
3004
|
+
"""
|
|
3005
|
+
全局 worker 函数,用于多进程处理
|
|
3006
|
+
batch_data: (batch, word2idx, meta_data_processed, maxlen, max_frag, min_peaks, precursor_mode, precursor_value)
|
|
3007
|
+
"""
|
|
3008
|
+
import numpy as np
|
|
3009
|
+
import pandas as pd
|
|
3010
|
+
|
|
3011
|
+
batch, word2idx, meta_data_processed, maxlen, max_frag, min_peaks, precursor_mode, precursor_value = batch_data
|
|
3012
|
+
|
|
3013
|
+
result = {}
|
|
3014
|
+
stats = {'kept': 0, 'filtered': 0}
|
|
3015
|
+
|
|
3016
|
+
for ms2_id, info in batch:
|
|
3017
|
+
# 基础检查
|
|
3018
|
+
if not info.get('mz'):
|
|
3019
|
+
stats['filtered'] += 1
|
|
3020
|
+
continue
|
|
3021
|
+
|
|
3022
|
+
# 转为 numpy 数组
|
|
3023
|
+
peaks = np.asarray(info['mz'], dtype=float)
|
|
3024
|
+
intensities = np.asarray(info['intensity'], dtype=float)
|
|
3025
|
+
molecule_id = info.get('molecule_id', None)
|
|
3026
|
+
|
|
3027
|
+
# 长度对齐检查
|
|
3028
|
+
if peaks.shape[0] != intensities.shape[0]:
|
|
3029
|
+
n = min(len(peaks), len(intensities))
|
|
3030
|
+
peaks = peaks[:n]
|
|
3031
|
+
intensities = intensities[:n]
|
|
3032
|
+
|
|
3033
|
+
# 文件名对应行用于极性判断
|
|
3034
|
+
specific_row = meta_data_processed.loc[meta_data_processed["file_name"] == ms2_id] if "file_name" in meta_data_processed.columns else pd.DataFrame()
|
|
3035
|
+
if specific_row.empty:
|
|
3036
|
+
if molecule_id is not None:
|
|
3037
|
+
if 'HMDB.ID' in meta_data_processed.columns:
|
|
3038
|
+
specific_row = meta_data_processed.loc[meta_data_processed['HMDB.ID'] == molecule_id]
|
|
3039
|
+
else:
|
|
3040
|
+
specific_row = meta_data_processed.loc[meta_data_processed.index == molecule_id]
|
|
3041
|
+
|
|
3042
|
+
if specific_row.empty:
|
|
3043
|
+
stats['filtered'] += 1
|
|
3044
|
+
continue
|
|
3045
|
+
|
|
3046
|
+
# 只保留正离子
|
|
3047
|
+
pol = str(specific_row["Polarity"].values[0]).lower().strip() if "Polarity" in specific_row.columns else ""
|
|
3048
|
+
if pol != "positive":
|
|
3049
|
+
stats['filtered'] += 1
|
|
3050
|
+
continue
|
|
3051
|
+
|
|
3052
|
+
# 获取precursor
|
|
3053
|
+
if 'HMDB.ID' in meta_data_processed.columns and (molecule_id is not None):
|
|
3054
|
+
row = meta_data_processed.loc[meta_data_processed['HMDB.ID'] == molecule_id]
|
|
3055
|
+
else:
|
|
3056
|
+
row = meta_data_processed.loc[meta_data_processed.index == molecule_id]
|
|
3057
|
+
|
|
3058
|
+
if row.empty or ('precursor_mass' not in row.columns):
|
|
3059
|
+
stats['filtered'] += 1
|
|
3060
|
+
continue
|
|
3061
|
+
|
|
3062
|
+
try:
|
|
3063
|
+
precursor_val = float(row['precursor_mass'].values[0])
|
|
3064
|
+
except Exception:
|
|
3065
|
+
stats['filtered'] += 1
|
|
3066
|
+
continue
|
|
3067
|
+
|
|
3068
|
+
# 前体范围 [10, 1000)
|
|
3069
|
+
if pd.isna(precursor_val) or (precursor_val < 10.0) or (precursor_val >= 1000.0):
|
|
3070
|
+
stats['filtered'] += 1
|
|
3071
|
+
continue
|
|
3072
|
+
|
|
3073
|
+
precursor_val = min(precursor_val, 999.99)
|
|
3074
|
+
precursor_str = "%.2f" % precursor_val
|
|
3075
|
+
|
|
3076
|
+
# 过滤峰到 [10, 1000)
|
|
3077
|
+
mask = (peaks >= 10.0) & (peaks < 1000.0) & np.isfinite(peaks) & np.isfinite(intensities)
|
|
3078
|
+
peaks = peaks[mask]
|
|
3079
|
+
intensities = intensities[mask]
|
|
3080
|
+
|
|
3081
|
+
if peaks.size == 0:
|
|
3082
|
+
stats['filtered'] += 1
|
|
3083
|
+
continue
|
|
3084
|
+
|
|
3085
|
+
# 按强度选 Top-K 碎片
|
|
3086
|
+
if peaks.size > max_frag:
|
|
3087
|
+
idx = np.argpartition(intensities, -max_frag)[-max_frag:]
|
|
3088
|
+
order = np.argsort(peaks[idx])
|
|
3089
|
+
idx = idx[order]
|
|
3090
|
+
peaks_sel = peaks[idx]
|
|
3091
|
+
intens_sel = intensities[idx]
|
|
3092
|
+
else:
|
|
3093
|
+
order = np.argsort(peaks)
|
|
3094
|
+
peaks_sel = peaks[order]
|
|
3095
|
+
intens_sel = intensities[order]
|
|
3096
|
+
|
|
3097
|
+
# 检查 min_peaks
|
|
3098
|
+
if peaks_sel.size < min_peaks:
|
|
3099
|
+
stats['filtered'] += 1
|
|
3100
|
+
continue
|
|
3101
|
+
|
|
3102
|
+
# 构建 token 序列
|
|
3103
|
+
peaks_str = ["%.2f" % p for p in peaks_sel]
|
|
3104
|
+
try:
|
|
3105
|
+
token_ids = [word2idx[precursor_str]] + [word2idx[p] for p in peaks_str]
|
|
3106
|
+
except KeyError:
|
|
3107
|
+
stats['filtered'] += 1
|
|
3108
|
+
continue
|
|
3109
|
+
|
|
3110
|
+
# ⭐ 根据 precursor_mode 选择处理方式
|
|
3111
|
+
if precursor_mode == 'scale_fixed':
|
|
3112
|
+
# 方案一:缩放fragments到固定值precursor_value(如20000),然后precursor添加2
|
|
3113
|
+
if np.max(intens_sel) > 0:
|
|
3114
|
+
intens_sel = intens_sel / np.max(intens_sel) * precursor_value
|
|
3115
|
+
intens_seq = np.hstack((2.0, intens_sel))
|
|
3116
|
+
# 整体归一化
|
|
3117
|
+
max_intensity = float(np.max(intens_seq))
|
|
3118
|
+
if max_intensity > 0:
|
|
3119
|
+
intens_seq = intens_seq / max_intensity
|
|
3120
|
+
|
|
3121
|
+
elif precursor_mode == 'normalize_add':
|
|
3122
|
+
# 方案二:归一化fragments到1,添加precursor_value,再整体归一化
|
|
3123
|
+
if np.max(intens_sel) > 0:
|
|
3124
|
+
intens_sel = intens_sel / np.max(intens_sel)
|
|
3125
|
+
intens_seq = np.hstack((precursor_value, intens_sel))
|
|
3126
|
+
# 整体归一化
|
|
3127
|
+
max_intensity = float(np.max(intens_seq))
|
|
3128
|
+
if max_intensity > 0:
|
|
3129
|
+
intens_seq = intens_seq / max_intensity
|
|
3130
|
+
|
|
3131
|
+
else:
|
|
3132
|
+
# 默认:原始MSBERT方式
|
|
3133
|
+
intens_seq = np.hstack((2.0, intens_sel))
|
|
3134
|
+
max_intensity = float(np.max(intens_seq))
|
|
3135
|
+
if max_intensity > 0:
|
|
3136
|
+
intens_seq = intens_seq / max_intensity
|
|
3137
|
+
|
|
3138
|
+
# Pad 或截断到 maxlen
|
|
3139
|
+
if len(token_ids) > maxlen:
|
|
3140
|
+
token_ids = token_ids[:maxlen]
|
|
3141
|
+
intens_seq = intens_seq[:maxlen]
|
|
3142
|
+
|
|
3143
|
+
n_pad = maxlen - len(token_ids)
|
|
3144
|
+
if n_pad > 0:
|
|
3145
|
+
token_ids += [word2idx['[PAD]']] * n_pad
|
|
3146
|
+
intens_seq = np.hstack([intens_seq, np.zeros(n_pad, dtype=float)])
|
|
3147
|
+
|
|
3148
|
+
result[ms2_id] = {
|
|
3149
|
+
'mz': token_ids,
|
|
3150
|
+
'intensity': intens_seq.tolist(),
|
|
3151
|
+
'molecule_id': molecule_id
|
|
3152
|
+
}
|
|
3153
|
+
stats['kept'] += 1
|
|
3154
|
+
|
|
3155
|
+
return result, stats
|
|
3156
|
+
|
|
3157
|
+
|
|
3158
|
+
|
|
3159
|
+
|