weirdo 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- weirdo/__init__.py +104 -0
- weirdo/amino_acid.py +33 -0
- weirdo/amino_acid_alphabet.py +158 -0
- weirdo/amino_acid_properties.py +358 -0
- weirdo/api.py +372 -0
- weirdo/blosum.py +74 -0
- weirdo/chou_fasman.py +73 -0
- weirdo/cli.py +597 -0
- weirdo/common.py +22 -0
- weirdo/data_manager.py +475 -0
- weirdo/distances.py +16 -0
- weirdo/matrices/BLOSUM30 +25 -0
- weirdo/matrices/BLOSUM50 +21 -0
- weirdo/matrices/BLOSUM62 +27 -0
- weirdo/matrices/__init__.py +0 -0
- weirdo/matrices/amino_acid_properties.txt +829 -0
- weirdo/matrices/helix_vs_coil.txt +28 -0
- weirdo/matrices/helix_vs_strand.txt +27 -0
- weirdo/matrices/pmbec.mat +21 -0
- weirdo/matrices/strand_vs_coil.txt +27 -0
- weirdo/model_manager.py +346 -0
- weirdo/peptide_vectorizer.py +78 -0
- weirdo/pmbec.py +85 -0
- weirdo/reduced_alphabet.py +61 -0
- weirdo/residue_contact_energies.py +74 -0
- weirdo/scorers/__init__.py +95 -0
- weirdo/scorers/base.py +223 -0
- weirdo/scorers/config.py +299 -0
- weirdo/scorers/mlp.py +1126 -0
- weirdo/scorers/reference.py +265 -0
- weirdo/scorers/registry.py +282 -0
- weirdo/scorers/similarity.py +386 -0
- weirdo/scorers/swissprot.py +510 -0
- weirdo/scorers/trainable.py +219 -0
- weirdo/static_data.py +17 -0
- weirdo-2.1.0.dist-info/METADATA +294 -0
- weirdo-2.1.0.dist-info/RECORD +41 -0
- weirdo-2.1.0.dist-info/WHEEL +5 -0
- weirdo-2.1.0.dist-info/entry_points.txt +2 -0
- weirdo-2.1.0.dist-info/licenses/LICENSE +201 -0
- weirdo-2.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,510 @@
|
|
|
1
|
+
"""SwissProt reference dataset implementation.
|
|
2
|
+
|
|
3
|
+
Loads pre-computed k-mer data from SwissProt protein database.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import csv
|
|
7
|
+
import os
|
|
8
|
+
from typing import Dict, Iterator, List, Optional, Set, Tuple
|
|
9
|
+
|
|
10
|
+
from .reference import StreamingReference
|
|
11
|
+
from .registry import register_reference
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _get_default_data_path(auto_download: bool = False) -> str:
    """Resolve where the SwissProt 8-mer CSV should be read from.

    Resolution order:

    1. ``data/swissprot-8mers.csv`` under the directory three levels above
       this module (a development checkout), if that file exists.
    2. The path reported by the package data manager for the
       ``'swissprot-8mers'`` dataset.
    3. The development-checkout path again, when the data manager cannot be
       imported or reports the file missing, so that a later ``open`` fails
       with a clear "file not found" error.

    Parameters
    ----------
    auto_download : bool, default=False
        Forwarded to the data manager (both when constructing it and when
        asking for the path); also used as its ``verbose`` flag.

    Returns
    -------
    path : str
        Path to the k-mer CSV. The file is not guaranteed to exist.
    """
    checkout_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
    checkout_csv = os.path.join(checkout_root, 'data', 'swissprot-8mers.csv')
    if os.path.exists(checkout_csv):
        return checkout_csv

    try:
        from ..data_manager import get_data_manager
        manager = get_data_manager(auto_download=auto_download, verbose=auto_download)
        managed_csv = manager.get_data_path('swissprot-8mers', auto_download=auto_download)
    except (ImportError, FileNotFoundError):
        # Deliberate fallback: hand back the checkout path so the eventual
        # open() reports the missing file instead of failing here.
        return checkout_csv
    return str(managed_csv)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Default data location resolved once, at import time, without downloading.
# Kept for backwards compatibility; SwissProtReference re-resolves the path
# per instance (honoring its own ``auto_download`` flag).
DEFAULT_DATA_PATH = _get_default_data_path()

# Organism-category columns expected in the SwissProt k-mer CSV.
# Kept as a list: it is copied with ``.copy()`` and interpolated into
# error messages elsewhere in this module.
ALL_CATEGORIES = [
    'archaea',
    'bacteria',
    'fungi',
    'human',
    'invertebrates',
    'mammals',
    'plants',
    'rodents',
    'vertebrates',
    'viruses',
]
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@register_reference('swissprot', description='SwissProt k-mer reference data')
class SwissProtReference(StreamingReference):
    """Reference dataset from SwissProt protein database.

    Loads pre-computed k-mer presence from a CSV file containing
    organism category membership for each k-mer. Each row is expected to
    have a ``seq`` column plus one column per name in ``ALL_CATEGORIES``
    holding the string ``'True'`` or ``'False'``; a missing category column
    is treated as ``'False'``.

    A k-mer belongs to this reference when it is flagged ``'True'`` in at
    least one of the selected ``categories`` (or in at least one of
    ``ALL_CATEGORIES`` when no filter is given). Eager and lazy mode apply
    exactly the same rule.

    Parameters
    ----------
    categories : list of str, optional
        Filter to specific organism categories.
        Available: archaea, bacteria, fungi, human, invertebrates,
        mammals, plants, rodents, vertebrates, viruses.
        If None, consider k-mer present if in ANY category.
    k : int, default=8
        K-mer size (must match data file).
    data_path : str, optional
        Path to k-mer CSV file. Defaults to bundled data.
    lazy : bool, default=False
        If True, don't load data into memory; stream from disk. Every
        lookup then rescans the CSV file.
    use_set : bool, default=False
        If True, only track k-mer presence (not category counts).
        Reduces memory but loses category breakdown.
    auto_download : bool, default=False
        If True, automatically download reference data if not present.

    Raises
    ------
    ValueError
        If ``categories`` contains names not in ``ALL_CATEGORIES``.

    Example
    -------
    >>> ref = SwissProtReference(categories=['human', 'mammals']).load()
    >>> ref.contains('MTMDKSEL')
    True
    >>> ref.get_frequency('MTMDKSEL')  # present in filtered categories
    1.0
    """

    def __init__(
        self,
        categories: Optional[List[str]] = None,
        k: int = 8,
        data_path: Optional[str] = None,
        lazy: bool = False,
        use_set: bool = False,
        auto_download: bool = False,
        **kwargs
    ):
        # Validate first so a typo fails fast, before the default data path
        # is resolved (which may trigger a download when auto_download=True).
        if categories is not None:
            invalid = set(categories) - set(ALL_CATEGORIES)
            if invalid:
                raise ValueError(
                    f"Invalid categories: {invalid}. "
                    f"Available: {ALL_CATEGORIES}"
                )

        super().__init__(categories=categories, k=k, lazy=lazy, use_set=use_set, **kwargs)
        self._auto_download = auto_download
        if data_path:
            self._data_path = data_path
        else:
            self._data_path = _get_default_data_path(auto_download=auto_download)
        self._kmers: Dict[str, Dict[str, bool]] = {}  # kmer -> {category: bool}
        self._kmer_set: Set[str] = set()  # For use_set mode
        self._total_kmers = 0

    @property
    def data_path(self) -> str:
        """Get path to data file."""
        return self._data_path

    # ------------------------------------------------------------------
    # Row helpers shared by the eager and lazy code paths
    # ------------------------------------------------------------------

    def _open_data(self):
        """Open the k-mer CSV for reading (text mode, as ``csv`` requires)."""
        return open(self._data_path, 'r', newline='', encoding='utf-8')

    @staticmethod
    def _parse_row_categories(row: Dict[str, str]) -> Dict[str, bool]:
        """Map every name in ``ALL_CATEGORIES`` to whether ``row`` flags it 'True'."""
        return {cat: row.get(cat, 'False') == 'True' for cat in ALL_CATEGORIES}

    def _row_in_scope(self, row: Dict[str, str]) -> bool:
        """Return True if ``row`` is flagged 'True' in any selected category.

        The selected categories are ``self._categories`` when a filter was
        given, otherwise all of ``ALL_CATEGORIES``.
        """
        selected = self._categories if self._categories is not None else ALL_CATEGORIES
        return any(row.get(cat, 'False') == 'True' for cat in selected)

    # ------------------------------------------------------------------
    # Loading and lookup
    # ------------------------------------------------------------------

    def load(self) -> 'SwissProtReference':
        """Load k-mer data from CSV file.

        In lazy mode this only checks that the data file exists.

        Returns
        -------
        self : SwissProtReference
            Returns self for method chaining.

        Raises
        ------
        FileNotFoundError
            If the data file does not exist.
        """
        if self._lazy:
            # In lazy mode, just verify file exists
            if not os.path.exists(self._data_path):
                raise FileNotFoundError(f"Data file not found: {self._data_path}")
            self._is_loaded = True
            return self

        self._kmers.clear()
        self._kmer_set.clear()
        self._total_kmers = 0

        with self._open_data() as handle:
            for row in csv.DictReader(handle):
                kmer = row['seq']
                if not self._row_in_scope(row):
                    continue
                if self._use_set:
                    self._kmer_set.add(kmer)
                else:
                    self._kmers[kmer] = self._parse_row_categories(row)
                self._total_kmers += 1

        self._is_loaded = True
        return self

    def contains(self, kmer: str) -> bool:
        """Check if k-mer exists in reference.

        Parameters
        ----------
        kmer : str
            K-mer sequence to look up.

        Returns
        -------
        exists : bool
            True if k-mer is in reference (after category filtering).
        """
        if self._lazy:
            # Stream through file to check
            return self._lazy_contains(kmer)

        if self._use_set:
            return kmer in self._kmer_set
        return kmer in self._kmers

    def _lazy_contains(self, kmer: str) -> bool:
        """Check containment by streaming file (lazy mode).

        Uses the first row whose ``seq`` equals ``kmer``.
        """
        with self._open_data() as handle:
            for row in csv.DictReader(handle):
                if row['seq'] == kmer:
                    return self._row_in_scope(row)
        return False

    def get_frequency(self, kmer: str, default: float = 0.0) -> float:
        """Get frequency of k-mer in reference.

        For SwissProt data, returns 1.0 if present, 0.0 if not.
        True frequency computation requires count data.

        Parameters
        ----------
        kmer : str
            K-mer sequence to look up.
        default : float, default=0.0
            Value to return if k-mer not found.

        Returns
        -------
        frequency : float
            1.0 if present, default if not found.
        """
        if self.contains(kmer):
            return 1.0
        return default

    def get_categories(self) -> List[str]:
        """Get list of available organism categories.

        Returns
        -------
        categories : list of str
            The filter categories if a filter was given, otherwise a copy
            of ``ALL_CATEGORIES``.
        """
        if self._categories is not None:
            return list(self._categories)
        return ALL_CATEGORIES.copy()

    def get_kmer_categories(self, kmer: str) -> Dict[str, bool]:
        """Get category presence for a specific k-mer.

        Parameters
        ----------
        kmer : str
            K-mer sequence to look up.

        Returns
        -------
        categories : dict
            Mapping of every category in ``ALL_CATEGORIES`` to presence
            (True/False). A fresh dict; mutating it does not affect the
            reference. Empty dict if the k-mer is not in the reference
            (after category filtering).

        Raises
        ------
        RuntimeError
            In use_set mode, where category information is not kept.
        """
        self._check_is_loaded()

        if self._use_set:
            raise RuntimeError(
                "Category information not available in use_set mode. "
                "Set use_set=False to access category data."
            )

        if self._lazy:
            return self._lazy_get_categories(kmer)

        cats = self._kmers.get(kmer)
        # Copy so callers cannot mutate the loaded reference data.
        return dict(cats) if cats is not None else {}

    def _lazy_get_categories(self, kmer: str) -> Dict[str, bool]:
        """Get categories by streaming file (lazy mode).

        Applies the same category filter as eager mode, so a k-mer that is
        filtered out yields an empty dict.
        """
        with self._open_data() as handle:
            for row in csv.DictReader(handle):
                if row['seq'] == kmer:
                    if not self._row_in_scope(row):
                        return {}
                    return self._parse_row_categories(row)
        return {}

    # ------------------------------------------------------------------
    # Iteration
    # ------------------------------------------------------------------

    def iter_kmers(self) -> Iterator[str]:
        """Iterate over all k-mers in reference.

        Yields
        ------
        kmer : str
            Each k-mer sequence in the reference.
        """
        self._check_is_loaded()

        if self._lazy:
            yield from self._lazy_iter_kmers()
            return

        if self._use_set:
            yield from self._kmer_set
        else:
            yield from self._kmers.keys()

    def _lazy_iter_kmers(self) -> Iterator[str]:
        """Iterate k-mers by streaming file (lazy mode)."""
        with self._open_data() as handle:
            for row in csv.DictReader(handle):
                kmer = row['seq']
                if self._row_in_scope(row):
                    yield kmer

    def iter_kmers_with_counts(self) -> Iterator[Tuple[str, int]]:
        """Iterate over k-mers with their category counts.

        Yields
        ------
        kmer : str
            K-mer sequence.
        count : int
            Number of categories the k-mer appears in (always 1 in use_set
            mode, where category data is not kept).
        """
        self._check_is_loaded()

        if self._lazy:
            yield from self._lazy_iter_with_counts()
            return

        if self._use_set:
            for kmer in self._kmer_set:
                yield kmer, 1
        else:
            for kmer, cats in self._kmers.items():
                yield kmer, sum(1 for v in cats.values() if v)

    def _lazy_iter_with_counts(self) -> Iterator[Tuple[str, int]]:
        """Iterate with counts by streaming file (lazy mode).

        The count is derived from the category flags, exactly as in eager
        mode, rather than from any precomputed column in the CSV.
        """
        with self._open_data() as handle:
            for row in csv.DictReader(handle):
                kmer = row['seq']
                if not self._row_in_scope(row):
                    continue
                cats = self._parse_row_categories(row)
                yield kmer, sum(1 for v in cats.values() if v)

    def iter_kmers_with_categories(self) -> Iterator[Tuple[str, Dict[str, bool]]]:
        """Iterate over k-mers with category presence.

        Yields
        ------
        kmer : str
            K-mer sequence.
        categories : dict
            Fresh mapping of category name to presence (True/False).

        Raises
        ------
        RuntimeError
            In use_set mode, where category information is not kept.
        """
        self._check_is_loaded()

        if self._use_set:
            raise RuntimeError(
                "Category information not available in use_set mode."
            )

        if self._lazy:
            yield from self._lazy_iter_with_categories()
            return

        for kmer, cats in self._kmers.items():
            yield kmer, cats.copy()

    def _lazy_iter_with_categories(self) -> Iterator[Tuple[str, Dict[str, bool]]]:
        """Iterate with categories by streaming file (lazy mode)."""
        with self._open_data() as handle:
            for row in csv.DictReader(handle):
                kmer = row['seq']
                if self._row_in_scope(row):
                    yield kmer, self._parse_row_categories(row)

    def __len__(self) -> int:
        """Return number of k-mers in reference (streams the file in lazy mode)."""
        self._check_is_loaded()

        if self._lazy:
            # Count by streaming
            return sum(1 for _ in self.iter_kmers())

        if self._use_set:
            return len(self._kmer_set)
        return len(self._kmers)

    # ------------------------------------------------------------------
    # Training data
    # ------------------------------------------------------------------

    def get_training_data(
        self,
        target_categories: Optional[List[str]] = None,
        max_samples: Optional[int] = None,
        multi_label: bool = False,
        shuffle: bool = False,
        seed: Optional[int] = None,
    ) -> Tuple[List[str], 'np.ndarray']:
        """Generate training data for MLP scorer.

        Parameters
        ----------
        target_categories : list of str, optional
            Categories to use as targets. Default: ['human', 'viruses', 'bacteria'].
            Must be non-empty.
        max_samples : int, optional
            Maximum number of samples to return (for memory efficiency).
            ``None`` or ``0`` means no limit.
        multi_label : bool, default=False
            If True, return multi-label array (one column per category).
            If False, return single foreignness score (1 if not in first target category).
        shuffle : bool, default=False
            If True and max_samples is set, randomly sample k-mers using
            reservoir sampling instead of taking the first N.
        seed : int, optional
            Random seed for reproducible sampling. A private random
            generator is used, so the global ``random`` state is left
            untouched.

        Returns
        -------
        peptides : list of str
            K-mer sequences.
        labels : np.ndarray
            If multi_label=True: shape (n_samples, n_categories) with binary labels.
            If multi_label=False: shape (n_samples,) with foreignness scores.
            Shapes hold even when n_samples is 0.

        Raises
        ------
        ValueError
            If ``target_categories`` is empty or contains unknown names.
        RuntimeError
            In use_set mode, where category information is not kept.

        Example
        -------
        >>> ref = SwissProtReference().load()
        >>> peptides, labels = ref.get_training_data(
        ...     target_categories=['human', 'viruses', 'bacteria'],
        ...     multi_label=True,
        ...     max_samples=100000
        ... )
        >>> labels.shape
        (100000, 3)
        """
        import numpy as np

        if target_categories is None:
            target_categories = ['human', 'viruses', 'bacteria']

        if not target_categories:
            raise ValueError("target_categories must contain at least one category.")

        # Validate categories
        for cat in target_categories:
            if cat not in ALL_CATEGORIES:
                raise ValueError(
                    f"Unknown category: {cat}. "
                    f"Available: {ALL_CATEGORIES}"
                )

        self._check_is_loaded()

        if self._use_set:
            raise RuntimeError(
                "Training data requires category info. Set use_set=False."
            )

        rng = None
        if shuffle or seed is not None:
            import random
            # A seeded private generator gives the same sequence as seeding
            # the global module would, without mutating process-wide state.
            # Without a seed, keep using the module-level generator so any
            # global random.seed() call made by the caller still applies.
            rng = random.Random(seed) if seed is not None else random

        peptides: List[str] = []
        labels_list: List[List[float]] = []

        def make_row(cats: Dict[str, bool]) -> List[float]:
            """Build one label row: per-target flags, or a single foreignness score."""
            if multi_label:
                return [1.0 if cats.get(cat, False) else 0.0 for cat in target_categories]
            in_self = cats.get(target_categories[0], False)
            return [0.0 if in_self else 1.0]

        use_reservoir = bool(max_samples) and shuffle
        for idx, (kmer, cats) in enumerate(self.iter_kmers_with_categories()):
            if use_reservoir:
                # Reservoir sampling (Algorithm R): item idx replaces a random
                # slot with probability max_samples / (idx + 1).
                if len(peptides) < max_samples:
                    peptides.append(kmer)
                    labels_list.append(make_row(cats))
                else:
                    j = rng.randint(0, idx)
                    if j < max_samples:
                        peptides[j] = kmer
                        labels_list[j] = make_row(cats)
            else:
                peptides.append(kmer)
                labels_list.append(make_row(cats))
                if max_samples and len(peptides) >= max_samples:
                    break

        # Explicit reshape so an empty result still has the documented
        # 2-D shape (0, n_categories) in multi-label mode.
        n_columns = len(target_categories) if multi_label else 1
        labels = np.array(labels_list, dtype=np.float32).reshape(len(labels_list), n_columns)
        if not multi_label:
            labels = labels.ravel()

        return peptides, labels
|