weirdo 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- weirdo/__init__.py +104 -0
- weirdo/amino_acid.py +33 -0
- weirdo/amino_acid_alphabet.py +158 -0
- weirdo/amino_acid_properties.py +358 -0
- weirdo/api.py +372 -0
- weirdo/blosum.py +74 -0
- weirdo/chou_fasman.py +73 -0
- weirdo/cli.py +597 -0
- weirdo/common.py +22 -0
- weirdo/data_manager.py +475 -0
- weirdo/distances.py +16 -0
- weirdo/matrices/BLOSUM30 +25 -0
- weirdo/matrices/BLOSUM50 +21 -0
- weirdo/matrices/BLOSUM62 +27 -0
- weirdo/matrices/__init__.py +0 -0
- weirdo/matrices/amino_acid_properties.txt +829 -0
- weirdo/matrices/helix_vs_coil.txt +28 -0
- weirdo/matrices/helix_vs_strand.txt +27 -0
- weirdo/matrices/pmbec.mat +21 -0
- weirdo/matrices/strand_vs_coil.txt +27 -0
- weirdo/model_manager.py +346 -0
- weirdo/peptide_vectorizer.py +78 -0
- weirdo/pmbec.py +85 -0
- weirdo/reduced_alphabet.py +61 -0
- weirdo/residue_contact_energies.py +74 -0
- weirdo/scorers/__init__.py +95 -0
- weirdo/scorers/base.py +223 -0
- weirdo/scorers/config.py +299 -0
- weirdo/scorers/mlp.py +1126 -0
- weirdo/scorers/reference.py +265 -0
- weirdo/scorers/registry.py +282 -0
- weirdo/scorers/similarity.py +386 -0
- weirdo/scorers/swissprot.py +510 -0
- weirdo/scorers/trainable.py +219 -0
- weirdo/static_data.py +17 -0
- weirdo-2.1.0.dist-info/METADATA +294 -0
- weirdo-2.1.0.dist-info/RECORD +41 -0
- weirdo-2.1.0.dist-info/WHEEL +5 -0
- weirdo-2.1.0.dist-info/entry_points.txt +2 -0
- weirdo-2.1.0.dist-info/licenses/LICENSE +201 -0
- weirdo-2.1.0.dist-info/top_level.txt +1 -0
weirdo/cli.py
ADDED
|
@@ -0,0 +1,597 @@
|
|
|
1
|
+
"""Command-line interface for WEIRDO.
|
|
2
|
+
|
|
3
|
+
Usage:
|
|
4
|
+
weirdo data status # Show data status
|
|
5
|
+
weirdo data download # Download reference data
|
|
6
|
+
weirdo data clear # Clear all data
|
|
7
|
+
weirdo score --model NAME PEPTIDE... # Score peptides
|
|
8
|
+
weirdo models list # List trained models
|
|
9
|
+
weirdo models train # Train a new model
|
|
10
|
+
weirdo models info NAME # Show model info
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import sys
|
|
14
|
+
import argparse
|
|
15
|
+
|
|
16
|
+
from .reduced_alphabet import alphabets
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def create_parser():
    """Build the top-level ``weirdo`` argument parser.

    Registers the ``data``, ``score``, ``translate``, ``setup`` and
    ``models`` sub-commands (with their own sub-parsers where needed).
    Sub-parsers that take no extra arguments are registered without
    binding the return value — the originals were assigned to locals
    (``list_parser``, ``path_parser``, ``setup_parser``, ...) that were
    never read.

    Returns:
        argparse.ArgumentParser: fully configured parser; dispatch keys
        are exposed as ``args.command``, ``args.data_command`` and
        ``args.models_command``.
    """
    parser = argparse.ArgumentParser(
        prog='weirdo',
        description='WEIRDO: Widely Estimated Immunological Recognition and Detection of Outliers',
    )
    subparsers = parser.add_subparsers(dest='command', help='Available commands')

    # -------------------------------------------------------------------------
    # Data management commands
    # -------------------------------------------------------------------------
    data_parser = subparsers.add_parser('data', help='Manage reference data')
    data_subparsers = data_parser.add_subparsers(dest='data_command', help='Data commands')

    # data list (aliased as 'ls' and 'status'); all three map to the same handler
    data_subparsers.add_parser('list', help='List datasets with status')
    data_subparsers.add_parser('ls', help='Alias for list')
    data_subparsers.add_parser('status', help='Alias for list')

    # data download
    download_parser = data_subparsers.add_parser('download', help='Download reference data')
    download_parser.add_argument(
        'dataset',
        nargs='?',
        default='swissprot-8mers',
        help='Dataset to download (default: swissprot-8mers)',
    )
    download_parser.add_argument(
        '--all',
        action='store_true',
        help='Download all available datasets',
    )
    download_parser.add_argument(
        '-f', '--force',
        action='store_true',
        help='Force re-download even if already present',
    )

    # data clear
    clear_parser = data_subparsers.add_parser('clear', help='Clear downloaded data')
    clear_parser.add_argument(
        '--downloads',
        action='store_true',
        help='Clear only downloaded data files',
    )
    clear_parser.add_argument(
        '--all',
        action='store_true',
        help='Clear all downloaded data',
    )
    clear_parser.add_argument(
        '-y', '--yes',
        action='store_true',
        help='Skip confirmation prompt',
    )

    # data path
    data_subparsers.add_parser('path', help='Show path to data directory')

    # -------------------------------------------------------------------------
    # Score command
    # -------------------------------------------------------------------------
    score_parser = subparsers.add_parser('score', help='Score peptides for foreignness')
    score_parser.add_argument(
        'peptides',
        nargs='+',
        help='Peptide sequences to score',
    )
    score_parser.add_argument(
        '-m', '--model',
        help='Trained model name to use for scoring',
    )

    # -------------------------------------------------------------------------
    # Translate command (legacy)
    # -------------------------------------------------------------------------
    translate_parser = subparsers.add_parser(
        'translate',
        help='Translate amino acid sequences to reduced alphabets',
    )
    # Exactly one input source must be given.
    translate_inputs = translate_parser.add_mutually_exclusive_group(required=True)
    translate_inputs.add_argument('--input-fasta')
    translate_inputs.add_argument('--input-sequence')
    translate_parser.add_argument(
        '-a', '--alphabet',
        dest='alphabet',
        help='Reduced alphabet name',
        choices=tuple(alphabets.keys()),
        required=True,
    )

    # -------------------------------------------------------------------------
    # Setup command
    # -------------------------------------------------------------------------
    subparsers.add_parser('setup', help='Initial setup - download reference data')

    # -------------------------------------------------------------------------
    # Model management commands
    # -------------------------------------------------------------------------
    models_parser = subparsers.add_parser('models', help='Manage trained ML models')
    models_subparsers = models_parser.add_subparsers(dest='models_command', help='Model commands')

    # models list
    models_subparsers.add_parser('list', help='List trained models')
    models_subparsers.add_parser('ls', help='Alias for list')

    # models info
    models_info_parser = models_subparsers.add_parser('info', help='Show model details')
    models_info_parser.add_argument('name', help='Model name')

    # models delete
    models_delete_parser = models_subparsers.add_parser('delete', help='Delete a trained model')
    models_delete_parser.add_argument('name', help='Model name to delete')
    models_delete_parser.add_argument('-y', '--yes', action='store_true', help='Skip confirmation')

    # models train
    models_train_parser = models_subparsers.add_parser('train', help='Train a new model')
    models_train_parser.add_argument(
        '--type',
        default='mlp',
        choices=['mlp'],
        help='Model type to train (default: mlp)',
    )
    models_train_parser.add_argument(
        '--data',
        required=True,
        help='Training data CSV (columns: peptide, label or peptide + category columns)',
    )
    models_train_parser.add_argument(
        '--val-data',
        help='Validation data CSV (optional)',
    )
    models_train_parser.add_argument(
        '--name',
        help='Name for saved model (default: auto-generated)',
    )
    models_train_parser.add_argument(
        '--epochs',
        type=int,
        default=100,
        help='Training epochs (default: 100)',
    )
    models_train_parser.add_argument(
        '--lr',
        type=float,
        default=1e-3,
        help='Learning rate (default: 1e-3)',
    )
    models_train_parser.add_argument(
        '--k',
        type=int,
        default=8,
        help='K-mer size (default: 8)',
    )
    models_train_parser.add_argument(
        '--hidden-layers',
        default='256,128,64',
        help='Hidden layer sizes, comma-separated (default: 256,128,64)',
    )
    models_train_parser.add_argument(
        '--overwrite',
        action='store_true',
        help='Overwrite existing model with same name',
    )

    # models path
    models_subparsers.add_parser('path', help='Show models directory')

    # models scorers
    models_subparsers.add_parser('scorers', help='List available scorer types')

    return parser
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def cmd_data_list(args):
    """Handle: weirdo data list/ls/status

    Prints the dataset status table via the shared data manager.
    Returns 0 explicitly so the exit code matches every other command
    handler (the original implicitly returned None, which ``run()``
    also maps to success — behavior is unchanged).
    """
    from .data_manager import get_data_manager
    dm = get_data_manager()
    dm.print_status()
    return 0
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def cmd_data_download(args):
    """Handle: weirdo data download

    Downloads one named dataset, or all of them with --all.  Returns 0
    on success, 1 when the requested dataset name is unknown.
    """
    from .data_manager import get_data_manager, DATASETS

    manager = get_data_manager()

    # --all takes precedence over any positional dataset name.
    if args.all:
        manager.download_all(force=args.force)
        return 0

    if args.dataset not in DATASETS:
        print(f"Unknown dataset: {args.dataset}")
        print(f"Available: {list(DATASETS.keys())}")
        return 1

    manager.download(args.dataset, force=args.force)
    return 0
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def cmd_data_clear(args):
    """Handle: weirdo data clear

    Deletes downloaded data files after an interactive confirmation
    (skipped with -y/--yes).  Returns 0 on success, 1 when aborted.

    NOTE(review): the original guard
    ``args.downloads or args.all or (not args.downloads)`` was a
    tautology — it evaluated True for every flag combination, so
    clearing was always unconditional and the --downloads/--all
    distinction never had any effect.  That (always-clear) behavior is
    preserved here, just written honestly; if the flags were meant to
    select different scopes, that needs a real implementation.
    """
    from .data_manager import get_data_manager
    dm = get_data_manager()

    # Confirm unless -y/--yes was given.
    if not args.yes:
        status = dm.status()
        size_mb = status['total_size_mb']
        print(f"This will delete downloads ({size_mb:.1f} MB)")
        response = input("Continue? [y/N] ")
        if response.lower() not in ('y', 'yes'):
            print("Aborted.")
            return 1

    # Both --downloads and --all currently clear the same thing.
    count = dm.delete_all_downloads()
    print(f"Deleted {count} downloads")

    return 0
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def cmd_data_path(args):
    """Handle: weirdo data path

    Prints the absolute path of the data directory.  Returns 0
    explicitly for consistency with the other handlers (the original
    implicitly returned None — behavior is unchanged).
    """
    from .data_manager import get_data_manager
    dm = get_data_manager()
    print(dm.data_dir)
    return 0
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def cmd_score(args):
    """Score peptides for foreignness with a trained model.

    Handles: weirdo score.  Requires -m/--model; returns 1 when the
    model name is missing or cannot be loaded, 0 otherwise.
    """
    if not args.model:
        print("A trained model is required to score peptides.")
        print("Train one with: weirdo models train --data train.csv --name my-model")
        return 1

    from .model_manager import load_model

    print(f"Scoring {len(args.peptides)} peptide(s) with model '{args.model}'...")
    print()

    try:
        model = load_model(args.model)
        if getattr(model, 'target_categories', None):
            # Multi-category model: show the full per-category table.
            table = model.predict_dataframe(args.peptides)
            print(table.to_string(index=False))
            print()
            print("Foreignness is derived from max(pathogens) vs max(self).")
        else:
            # Scalar model: print a fixed-width peptide/score table.
            values = model.score(args.peptides)
            print(f"{'Peptide':<40} {'Score':>10}")
            print("-" * 52)
            for peptide, value in zip(args.peptides, values):
                if len(peptide) <= 37:
                    shown = peptide
                else:
                    # Truncate long peptides so the column stays aligned.
                    shown = peptide[:34] + '...'
                print(f"{shown:<40} {value:>10.4f}")
            print()
            print("Higher scores = more foreign")
    except FileNotFoundError as exc:
        print(f"Error: {exc}")
        return 1

    return 0
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def cmd_translate(args):
    """Translate a sequence into a reduced amino-acid alphabet.

    Handles: weirdo translate.  Returns 1 for the not-yet-implemented
    FASTA path, 0 otherwise.
    """
    mapping = alphabets[args.alphabet]

    if args.input_sequence:
        sequence = args.input_sequence
        # Residues missing from the alphabet pass through unchanged.
        translated = "".join(mapping.get(residue, residue) for residue in sequence)
        print(f"{sequence} -> {translated}")
    elif args.input_fasta:
        print("FASTA translation not yet implemented")
        return 1

    return 0
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def cmd_setup(args):
    """Run first-time setup: fetch the default reference dataset.

    Handles: weirdo setup.  Always returns 0.
    """
    from .data_manager import get_data_manager

    manager = get_data_manager()

    print("WEIRDO Setup")
    print("=" * 60)
    print()

    # Step 1: pull down the default SwissProt 8-mer dataset.
    print("Step 1: Downloading reference data...")
    manager.download('swissprot-8mers')
    print()

    print("Setup complete!")
    print()
    manager.print_status()

    return 0
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
# -------------------------------------------------------------------------
|
|
320
|
+
# Model command handlers
|
|
321
|
+
# -------------------------------------------------------------------------
|
|
322
|
+
|
|
323
|
+
def cmd_models_list(args):
    """Print the table of saved models.  Handles: weirdo models list/ls"""
    from .model_manager import get_model_manager
    get_model_manager().print_models()
    return 0
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
def cmd_models_info(args):
    """Print details for one saved model.

    Handles: weirdo models info NAME.  Returns 1 when the model does
    not exist, 0 otherwise.
    """
    from .model_manager import get_model_manager

    info = get_model_manager().get_model_info(args.name)
    if info is None:
        print(f"Model not found: {args.name}")
        return 1

    print(f"Model: {info.name}")
    print("=" * 60)
    print(f"  Type: {info.scorer_type}")
    print(f"  Path: {info.path}")
    if info.created:
        # Keep only the date+time prefix of the timestamp string.
        print(f"  Created: {info.created[:19]}")
    print()

    print("Parameters:")
    for param, setting in info.params.items():
        print(f"  {param}: {setting}")
    print()

    meta = info.metadata
    if meta:
        print("Training info:")
        if 'n_train' in meta:
            print(f"  Training samples: {meta['n_train']}")
        # Each row: display label, candidate keys (preferred name first,
        # legacy fallback second), and the format spec for the value.
        for label, candidates, spec in (
            ('Epochs trained', ('n_epochs', 'n_iter'), ''),
            ('Final train loss', ('final_train_loss', 'loss'), '.4f'),
            ('Best val loss', ('best_val_loss', 'best_loss'), '.4f'),
        ):
            for key in candidates:
                if key in meta:
                    print(f"  {label}: {format(meta[key], spec)}")
                    break

    return 0
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
def cmd_models_delete(args):
    """Delete a saved model after confirmation.

    Handles: weirdo models delete NAME.  Returns 1 when the model is
    missing, the user aborts, or deletion fails; 0 on success.
    """
    from .model_manager import get_model_manager

    manager = get_model_manager()
    name = args.name

    if manager.get_model_info(name) is None:
        print(f"Model not found: {name}")
        return 1

    # Interactive confirmation unless -y/--yes was given.
    if not args.yes:
        answer = input(f"Delete model '{name}'? [y/N] ")
        if answer.lower() not in ('y', 'yes'):
            print("Aborted.")
            return 1

    if not manager.delete(name):
        print(f"Failed to delete model: {name}")
        return 1

    print(f"Deleted model: {name}")
    return 0
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
def cmd_models_train(args):
    """Handle: weirdo models train

    Loads training (and optionally validation) peptides from CSV,
    trains an MLPScorer, and saves it under the given or an
    auto-generated name.  Returns 0 on success, 1 on any input
    validation failure.
    """
    import csv
    from datetime import datetime

    from .model_manager import get_model_manager
    mm = get_model_manager()

    # Generate default name if not provided (timestamp-based).
    if args.name:
        model_name = args.name
    else:
        timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
        model_name = f"mlp-{timestamp}"

    # Refuse to clobber an existing model unless --overwrite was given.
    if mm.get_model_info(model_name) and not args.overwrite:
        print(f"Model already exists: {model_name}")
        print("Use --overwrite to replace.")
        return 1

    # Load training data.  Two CSV layouts are accepted:
    #   peptide,label            -> scalar labels (labels is a list of floats)
    #   peptide,<cat1>,<cat2>... -> per-category labels (list of lists), and
    #                               target_categories is set to those columns
    print(f"Loading training data from {args.data}...")
    peptides = []
    labels = []
    target_categories = None
    with open(args.data, 'r') as f:
        reader = csv.DictReader(f)
        if not reader.fieldnames or 'peptide' not in reader.fieldnames:
            print("Training CSV must include a 'peptide' column.")
            return 1
        label_columns = [c for c in reader.fieldnames if c != 'peptide']
        if not label_columns:
            print("Training CSV must include at least one label column.")
            return 1

        for row in reader:
            peptides.append(row['peptide'])
            if label_columns == ['label']:
                labels.append(float(row['label']))
            else:
                labels.append([float(row[c]) for c in label_columns])
        if label_columns != ['label']:
            target_categories = label_columns

    print(f"  Loaded {len(peptides)} samples")
    if target_categories:
        print(f"  Target categories: {', '.join(target_categories)}")

    # Load validation data if provided.  Its label columns must mirror
    # the training layout exactly — same category columns in the same
    # order, or a single 'label' column.
    val_peptides = None
    val_labels = None
    if args.val_data:
        print(f"Loading validation data from {args.val_data}...")
        val_peptides = []
        val_labels = []
        with open(args.val_data, 'r') as f:
            reader = csv.DictReader(f)
            if not reader.fieldnames or 'peptide' not in reader.fieldnames:
                print("Validation CSV must include a 'peptide' column.")
                return 1
            val_label_columns = [c for c in reader.fieldnames if c != 'peptide']
            if target_categories:
                if val_label_columns != target_categories:
                    print("Validation label columns must match training labels.")
                    print(f"Expected: {target_categories}")
                    print(f"Found: {val_label_columns}")
                    return 1
            else:
                if val_label_columns != ['label']:
                    print("Validation CSV must include a single 'label' column.")
                    return 1
            for row in reader:
                val_peptides.append(row['peptide'])
                if target_categories:
                    val_labels.append([float(row[c]) for c in val_label_columns])
                else:
                    val_labels.append(float(row['label']))
        print(f"  Loaded {len(val_peptides)} validation samples")

    # Create model.  --hidden-layers is a comma-separated int list,
    # e.g. "256,128,64" -> (256, 128, 64).
    from .scorers.mlp import MLPScorer
    hidden_layers = tuple(int(x) for x in args.hidden_layers.split(','))

    print()
    print(f"Training {args.type} model...")
    print(f"  K-mer size: {args.k}")
    print(f"  Hidden layers: {hidden_layers}")
    print(f"  Epochs: {args.epochs}")
    print(f"  Learning rate: {args.lr}")
    print()

    scorer = MLPScorer(
        k=args.k,
        hidden_layer_sizes=hidden_layers,
    )

    # Train.  val_peptides/val_labels may be None when no --val-data.
    scorer.train(
        peptides=peptides,
        labels=labels,
        val_peptides=val_peptides,
        val_labels=val_labels,
        epochs=args.epochs,
        learning_rate=args.lr,
        verbose=True,
        target_categories=target_categories,
    )

    # Save the trained scorer under the chosen name.
    print()
    print(f"Saving model as '{model_name}'...")
    path = mm.save(scorer, model_name, overwrite=args.overwrite)
    print(f"  Saved to: {path}")

    return 0
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
def cmd_models_path(args):
    """Print the models directory.  Handles: weirdo models path"""
    from .model_manager import get_model_manager
    print(get_model_manager().model_dir)
    return 0
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
def cmd_models_scorers(args):
    """List the registered scorer types.  Handles: weirdo models scorers"""
    from .scorers import list_scorers

    print("Available scorer types:")
    print()

    for scorer_name in list_scorers():
        print(f"  {scorer_name:<20} (ML-based, requires training)")

    print()
    print("ML-based scorers use train() with labeled peptide data.")

    return 0
|
|
540
|
+
|
|
541
|
+
|
|
542
|
+
def run(args_list=None):
    """Main entry point for CLI.

    Parses *args_list* (defaulting to ``sys.argv[1:]``) and dispatches
    to the matching command handler, returning its exit status.  With
    no command, prints help and returns 0.
    """
    if args_list is None:
        args_list = sys.argv[1:]

    parser = create_parser()
    args = parser.parse_args(args_list)

    if args.command is None:
        parser.print_help()
        return 0

    if args.command == 'data':
        # Bare `weirdo data` (data_command is None) defaults to list.
        data_handlers = {
            None: cmd_data_list,
            'list': cmd_data_list,
            'ls': cmd_data_list,
            'status': cmd_data_list,
            'download': cmd_data_download,
            'clear': cmd_data_clear,
            'path': cmd_data_path,
        }
        handler = data_handlers.get(args.data_command)
        if handler is not None:
            return handler(args)
    elif args.command == 'score':
        return cmd_score(args)
    elif args.command == 'translate':
        return cmd_translate(args)
    elif args.command == 'setup':
        return cmd_setup(args)
    elif args.command == 'models':
        # Bare `weirdo models` (models_command is None) defaults to list.
        models_handlers = {
            None: cmd_models_list,
            'list': cmd_models_list,
            'ls': cmd_models_list,
            'info': cmd_models_info,
            'delete': cmd_models_delete,
            'train': cmd_models_train,
            'path': cmd_models_path,
            'scorers': cmd_models_scorers,
        }
        handler = models_handlers.get(args.models_command)
        if handler is not None:
            return handler(args)

    # Unrecognized sub-command combinations fall through as success,
    # matching the original if/elif chain.
    return 0
|
|
594
|
+
|
|
595
|
+
|
|
596
|
+
if __name__ == '__main__':
    # `run` returns an exit status (or None on some paths); `or 0`
    # normalizes None to success before handing it to sys.exit.
    sys.exit(run() or 0)
|
weirdo/common.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
2
|
+
# you may not use this file except in compliance with the License.
|
|
3
|
+
# You may obtain a copy of the License at
|
|
4
|
+
#
|
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
6
|
+
#
|
|
7
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
8
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
9
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
10
|
+
# See the License for the specific language governing permissions and
|
|
11
|
+
# limitations under the License.
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
def transform_peptide(peptide, property_dict):
    """Map each residue of *peptide* to its value in *property_dict*.

    Returns a 1-D numpy array of per-residue property values, in
    sequence order.  Raises KeyError if a residue is missing from
    *property_dict*.
    """
    values = [property_dict[residue] for residue in peptide]
    return np.array(values)
|
|
17
|
+
|
|
18
|
+
def transform_peptides(peptides, property_dict):
    """Map every residue of every peptide to its property value.

    Returns a numpy array with one row per peptide; for a regular 2-D
    result the peptides should share a common length (presumably the
    callers guarantee this — TODO confirm).  Raises KeyError on
    residues missing from *property_dict*.
    """
    rows = []
    for peptide in peptides:
        rows.append([property_dict[residue] for residue in peptide])
    return np.array(rows)
|
|
22
|
+
|