protein_quest-0.3.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


protein_quest/cli.py ADDED
@@ -0,0 +1,782 @@
+ """Module for cli parsers and handlers."""
+
+ import argparse
+ import asyncio
+ import csv
+ import logging
+ import os
+ from collections.abc import Callable, Iterable
+ from importlib.util import find_spec
+ from io import TextIOWrapper
+ from pathlib import Path
+ from textwrap import dedent
+
+ from cattrs import structure
+ from rich import print as rprint
+ from rich.logging import RichHandler
+ from rich_argparse import ArgumentDefaultsRichHelpFormatter
+ from tqdm.rich import tqdm
+
+ from protein_quest.__version__ import __version__
+ from protein_quest.alphafold.confidence import ConfidenceFilterQuery, filter_files_on_confidence
+ from protein_quest.alphafold.fetch import DownloadableFormat, downloadable_formats
+ from protein_quest.alphafold.fetch import fetch_many as af_fetch
+ from protein_quest.emdb import fetch as emdb_fetch
+ from protein_quest.filters import filter_files_on_chain, filter_files_on_residues
+ from protein_quest.go import Aspect, allowed_aspects, search_gene_ontology_term, write_go_terms_to_csv
+ from protein_quest.pdbe import fetch as pdbe_fetch
+ from protein_quest.pdbe.io import glob_structure_files
+ from protein_quest.taxonomy import SearchField, _write_taxonomy_csv, search_fields, search_taxon
+ from protein_quest.uniprot import PdbResult, Query, search4af, search4emdb, search4pdb, search4uniprot
+
+ logger = logging.getLogger(__name__)
+
+
+ def _add_search_uniprot_parser(subparsers: argparse._SubParsersAction):
+     """Add search uniprot subcommand parser."""
+     parser = subparsers.add_parser(
+         "uniprot",
+         help="Search UniProt accessions",
+         description="Search for UniProt accessions based on various criteria in the UniProt SPARQL endpoint.",
+         formatter_class=ArgumentDefaultsRichHelpFormatter,
+     )
+     parser.add_argument(
+         "output",
+         type=argparse.FileType("w", encoding="UTF-8"),
+         help="Output text file for UniProt accessions (one per line). Use `-` for stdout.",
+     )
+     parser.add_argument("--taxon-id", type=str, help="NCBI Taxon ID, e.g. 9606 for Homo sapiens")
+     parser.add_argument(
+         "--reviewed",
+         action=argparse.BooleanOptionalAction,
+         help="Reviewed=swissprot, no-reviewed=trembl. Default is uniprot=swissprot+trembl.",
+         default=None,
+     )
+     parser.add_argument(
+         "--subcellular-location-uniprot",
+         type=str,
+         help="Subcellular location label as used by UniProt (e.g. nucleus)",
+     )
+     parser.add_argument(
+         "--subcellular-location-go",
+         dest="subcellular_location_go",
+         action="append",
+         help="GO term(s) for subcellular location (e.g. GO:0005634). Can be given multiple times.",
+     )
+     parser.add_argument(
+         "--molecular-function-go",
+         dest="molecular_function_go",
+         action="append",
+         help="GO term(s) for molecular function (e.g. GO:0003677). Can be given multiple times.",
+     )
+     parser.add_argument("--limit", type=int, default=10_000, help="Maximum number of uniprot accessions to return")
+     parser.add_argument("--timeout", type=int, default=1_800, help="Maximum seconds to wait for query to complete")
+
+
+ def _add_search_pdbe_parser(subparsers: argparse._SubParsersAction):
+     """Add search pdbe subcommand parser."""
+     parser = subparsers.add_parser(
+         "pdbe",
+         help="Search PDBe structures of given UniProt accessions",
+         description="Search for PDB structures of given UniProt accessions in the UniProt SPARQL endpoint.",
+         formatter_class=ArgumentDefaultsRichHelpFormatter,
+     )
+     parser.add_argument(
+         "uniprot_accs",
+         type=argparse.FileType("r", encoding="UTF-8"),
+         help="Text file with UniProt accessions (one per line). Use `-` for stdin.",
+     )
+     parser.add_argument(
+         "output_csv",
+         type=argparse.FileType("w", encoding="UTF-8"),
+         help=dedent("""\
+             Output CSV with `uniprot_acc`, `pdb_id`, `method`, `resolution`, `uniprot_chains`, `chain` columns.
+             Where `uniprot_chains` is the raw UniProt chain string, for example `A=1-100`.
+             and where `chain` is the first chain from `uniprot_chains`, for example `A`.
+             Use `-` for stdout.
+             """),
+     )
+     parser.add_argument(
+         "--limit", type=int, default=10_000, help="Maximum number of PDB uniprot accessions combinations to return"
+     )
+     parser.add_argument("--timeout", type=int, default=1_800, help="Maximum seconds to wait for query to complete")
+
+
+ def _add_search_alphafold_parser(subparsers: argparse._SubParsersAction):
+     """Add search alphafold subcommand parser."""
+     parser = subparsers.add_parser(
+         "alphafold",
+         help="Search AlphaFold structures of given UniProt accessions",
+         description="Search for AlphaFold structures of given UniProt accessions in the UniProt SPARQL endpoint.",
+         formatter_class=ArgumentDefaultsRichHelpFormatter,
+     )
+     parser.add_argument(
+         "uniprot_accs",
+         type=argparse.FileType("r", encoding="UTF-8"),
+         help="Text file with UniProt accessions (one per line). Use `-` for stdin.",
+     )
+     parser.add_argument(
+         "output_csv",
+         type=argparse.FileType("w", encoding="UTF-8"),
+         help="Output CSV with AlphaFold IDs per UniProt accession. Use `-` for stdout.",
+     )
+     parser.add_argument(
+         "--limit", type=int, default=10_000, help="Maximum number of AlphaFold entry identifiers to return"
+     )
+     parser.add_argument("--timeout", type=int, default=1_800, help="Maximum seconds to wait for query to complete")
+
+
+ def _add_search_emdb_parser(subparsers: argparse._SubParsersAction):
+     """Add search emdb subcommand parser."""
+     parser = subparsers.add_parser(
+         "emdb",
+         help="Search Electron Microscopy Data Bank (EMDB) identifiers of given UniProt accessions",
+         description=dedent("""\
+             Search for Electron Microscopy Data Bank (EMDB) identifiers of given UniProt accessions
+             in the UniProt SPARQL endpoint.
+             """),
+         formatter_class=ArgumentDefaultsRichHelpFormatter,
+     )
+     parser.add_argument(
+         "uniprot_accs",
+         type=argparse.FileType("r", encoding="UTF-8"),
+         help="Text file with UniProt accessions (one per line). Use `-` for stdin.",
+     )
+     parser.add_argument(
+         "output_csv",
+         type=argparse.FileType("w", encoding="UTF-8"),
+         help="Output CSV with EMDB IDs per UniProt accession. Use `-` for stdout.",
+     )
+     parser.add_argument("--limit", type=int, default=10_000, help="Maximum number of EMDB entry identifiers to return")
+     parser.add_argument("--timeout", type=int, default=1_800, help="Maximum seconds to wait for query to complete")
+
+
+ def _add_search_go_parser(subparsers: argparse._SubParsersAction):
+     """Add search go subcommand parser"""
+     parser = subparsers.add_parser(
+         "go",
+         help="Search for Gene Ontology (GO) terms",
+         description="Search for Gene Ontology (GO) terms in the EBI QuickGO API.",
+         formatter_class=ArgumentDefaultsRichHelpFormatter,
+     )
+     parser.add_argument(
+         "term",
+         type=str,
+         help="GO term to search for. For example `apoptosome`.",
+     )
+     parser.add_argument("--aspect", type=str, choices=allowed_aspects, help="Filter on aspect.")
+     parser.add_argument(
+         "output_csv",
+         type=argparse.FileType("w", encoding="UTF-8"),
+         help="Output CSV with GO term results. Use `-` for stdout.",
+     )
+     parser.add_argument("--limit", type=int, default=100, help="Maximum number of GO term results to return")
+
+
+ def _add_search_taxonomy_parser(subparser: argparse._SubParsersAction):
+     """Add search taxonomy subcommand parser."""
+     parser = subparser.add_parser(
+         "taxonomy",
+         help="Search for taxon information in UniProt",
+         description=dedent("""\
+             Search for taxon information in UniProt.
+             Uses https://www.uniprot.org/taxonomy?query=*.
+             """),
+         formatter_class=ArgumentDefaultsRichHelpFormatter,
+     )
+     parser.add_argument(
+         "query", type=str, help="Search query for the taxon. Surround multiple words with quotes (' or \")."
+     )
+     parser.add_argument(
+         "output_csv",
+         type=argparse.FileType("w", encoding="UTF-8"),
+         help="Output CSV with taxonomy results. Use `-` for stdout.",
+     )
+     parser.add_argument(
+         "--field",
+         type=str,
+         choices=search_fields,
+         help=dedent("""\
+             Field to search in. If not given then searches all fields.
+             If "tax_id" then searches by taxon ID.
+             If "parent" then given a parent taxon ID returns all its children.
+             For example, if the parent taxon ID is 9606 (Human), it will return Neanderthal and Denisovan.
+             """),
+     )
+     parser.add_argument("--limit", type=int, default=100, help="Maximum number of results to return")
+
+
+ def _add_retrieve_pdbe_parser(subparsers: argparse._SubParsersAction):
+     """Add retrieve pdbe subcommand parser."""
+     parser = subparsers.add_parser(
+         "pdbe",
+         help="Retrieve PDBe gzipped mmCIF files for PDB IDs in CSV.",
+         description=dedent("""\
+             Retrieve mmCIF files from the Protein Data Bank in Europe (PDBe) website
+             for unique PDB IDs listed in a CSV file.
+             """),
+         formatter_class=ArgumentDefaultsRichHelpFormatter,
+     )
+     parser.add_argument(
+         "pdbe_csv",
+         type=argparse.FileType("r", encoding="UTF-8"),
+         help="CSV file with `pdb_id` column. Other columns are ignored. Use `-` for stdin.",
+     )
+     parser.add_argument("output_dir", type=Path, help="Directory to store downloaded PDBe mmCIF files")
+     parser.add_argument(
+         "--max-parallel-downloads",
+         type=int,
+         default=5,
+         help="Maximum number of parallel downloads",
+     )
+
+
+ def _add_retrieve_alphafold_parser(subparsers: argparse._SubParsersAction):
+     """Add retrieve alphafold subcommand parser."""
+     parser = subparsers.add_parser(
+         "alphafold",
+         help="Retrieve AlphaFold files for IDs in CSV",
+         description="Retrieve AlphaFold files from the AlphaFold Protein Structure Database.",
+         formatter_class=ArgumentDefaultsRichHelpFormatter,
+     )
+     parser.add_argument(
+         "alphafold_csv",
+         type=argparse.FileType("r", encoding="UTF-8"),
+         help="CSV file with `af_id` column. Other columns are ignored. Use `-` for stdin.",
+     )
+     parser.add_argument("output_dir", type=Path, help="Directory to store downloaded AlphaFold files")
+     parser.add_argument(
+         "--what-af-formats",
+         type=str,
+         action="append",
+         choices=sorted(downloadable_formats),
+         help=dedent("""AlphaFold formats to retrieve. Can be specified multiple times.
+             Default is 'pdb'. Summary is always downloaded as `<entryId>.json`."""),
+     )
+     parser.add_argument(
+         "--max-parallel-downloads",
+         type=int,
+         default=5,
+         help="Maximum number of parallel downloads",
+     )
+
+
+ def _add_retrieve_emdb_parser(subparsers: argparse._SubParsersAction):
+     """Add retrieve emdb subcommand parser."""
+     parser = subparsers.add_parser(
+         "emdb",
+         help="Retrieve Electron Microscopy Data Bank (EMDB) gzipped 3D volume files for EMDB IDs in CSV.",
+         description=dedent("""\
+             Retrieve volume files from Electron Microscopy Data Bank (EMDB) website
+             for unique EMDB IDs listed in a CSV file.
+             """),
+         formatter_class=ArgumentDefaultsRichHelpFormatter,
+     )
+     parser.add_argument(
+         "emdb_csv",
+         type=argparse.FileType("r", encoding="UTF-8"),
+         help="CSV file with `emdb_id` column. Other columns are ignored. Use `-` for stdin.",
+     )
+     parser.add_argument("output_dir", type=Path, help="Directory to store downloaded EMDB volume files")
+
+
+ def _add_filter_confidence_parser(subparsers: argparse._SubParsersAction):
+     """Add filter confidence subcommand parser."""
+     parser = subparsers.add_parser(
+         "confidence",
+         help="Filter AlphaFold mmcif/PDB files by confidence",
+         description=dedent("""\
+             Filter AlphaFold mmcif/PDB files by confidence (pLDDT).
+             Passed files are written with residues below threshold removed."""),
+         formatter_class=ArgumentDefaultsRichHelpFormatter,
+     )
+     parser.add_argument("input_dir", type=Path, help="Directory with AlphaFold mmcif/PDB files")
+     parser.add_argument("output_dir", type=Path, help="Directory to write filtered mmcif/PDB files")
+     parser.add_argument("--confidence-threshold", type=float, default=70, help="pLDDT confidence threshold (0-100)")
+     parser.add_argument(
+         "--min-residues", type=int, default=0, help="Minimum number of high-confidence residues a structure should have"
+     )
+     parser.add_argument(
+         "--max-residues",
+         type=int,
+         default=10_000_000,
+         help="Maximum number of high-confidence residues a structure should have",
+     )
+     parser.add_argument(
+         "--write-stats",
+         type=argparse.FileType("w", encoding="UTF-8"),
+         help=dedent("""\
+             Write filter statistics to file.
+             In CSV format with `<input_file>,<residue_count>,<passed>,<output_file>` columns.
+             Use `-` for stdout."""),
+     )
+
+
+ def _add_filter_chain_parser(subparsers: argparse._SubParsersAction):
+     """Add filter chain subcommand parser."""
+     parser = subparsers.add_parser(
+         "chain",
+         help="Filter on chain.",
+         description=dedent("""\
+             For each input PDB/mmCIF and chain combination
+             write a PDB/mmCIF file with just the given chain
+             and rename it to chain `A`.
+             Filtering is done in parallel using a Dask cluster."""),
+         formatter_class=ArgumentDefaultsRichHelpFormatter,
+     )
+     parser.add_argument(
+         "chains",
+         type=argparse.FileType("r", encoding="UTF-8"),
+         help="CSV file with `pdb_id` and `chain` columns. Other columns are ignored.",
+     )
+     parser.add_argument(
+         "input_dir",
+         type=Path,
+         help=dedent("""\
+             Directory with PDB/mmCIF files.
+             Expected filenames are `{pdb_id}.cif.gz`, `{pdb_id}.cif`, `{pdb_id}.pdb.gz` or `{pdb_id}.pdb`.
+             """),
+     )
+     parser.add_argument(
+         "output_dir",
+         type=Path,
+         help=dedent("""\
+             Directory to write the single-chain PDB/mmCIF files. Output files are in same format as input files."""),
+     )
+     parser.add_argument(
+         "--scheduler-address",
+         help="Address of the Dask scheduler to connect to. If not provided, will create a local cluster.",
+     )
+
+
+ def _add_filter_residue_parser(subparsers: argparse._SubParsersAction):
+     """Add filter residue subcommand parser."""
+     parser = subparsers.add_parser(
+         "residue",
+         help="Filter PDB/mmCIF files by number of residues in chain A",
+         description=dedent("""\
+             Filter PDB/mmCIF files by number of residues in chain A.
+             """),
+         formatter_class=ArgumentDefaultsRichHelpFormatter,
+     )
+     parser.add_argument("input_dir", type=Path, help="Directory with PDB/mmCIF files (e.g., from 'filter chain')")
+     parser.add_argument(
+         "output_dir",
+         type=Path,
+         help=dedent("""\
+             Directory to write filtered PDB/mmCIF files. Files are copied without modification.
+             """),
+     )
+     parser.add_argument("--min-residues", type=int, default=0, help="Min residues in chain A")
+     parser.add_argument("--max-residues", type=int, default=10_000_000, help="Max residues in chain A")
+     parser.add_argument(
+         "--write-stats",
+         type=argparse.FileType("w", encoding="UTF-8"),
+         help=dedent("""\
+             Write filter statistics to file.
+             In CSV format with `<input_file>,<residue_count>,<passed>,<output_file>` columns.
+             Use `-` for stdout."""),
+     )
+
+
+ def _add_search_subcommands(subparsers: argparse._SubParsersAction):
+     """Add search command and its subcommands."""
+     parser = subparsers.add_parser(
+         "search",
+         help="Search data sources",
+         description="Search various online data sources.",
+         formatter_class=ArgumentDefaultsRichHelpFormatter,
+     )
+     subsubparsers = parser.add_subparsers(dest="search_cmd", required=True)
+
+     _add_search_uniprot_parser(subsubparsers)
+     _add_search_pdbe_parser(subsubparsers)
+     _add_search_alphafold_parser(subsubparsers)
+     _add_search_emdb_parser(subsubparsers)
+     _add_search_go_parser(subsubparsers)
+     _add_search_taxonomy_parser(subsubparsers)
+
+
+ def _add_retrieve_subcommands(subparsers: argparse._SubParsersAction):
+     """Add retrieve command and its subcommands."""
+     parser = subparsers.add_parser(
+         "retrieve",
+         help="Retrieve structure files",
+         description="Retrieve structure files from online resources.",
+         formatter_class=ArgumentDefaultsRichHelpFormatter,
+     )
+     subsubparsers = parser.add_subparsers(dest="retrieve_cmd", required=True)
+
+     _add_retrieve_pdbe_parser(subsubparsers)
+     _add_retrieve_alphafold_parser(subsubparsers)
+     _add_retrieve_emdb_parser(subsubparsers)
+
+
+ def _add_filter_subcommands(subparsers: argparse._SubParsersAction):
+     """Add filter command and its subcommands."""
+     parser = subparsers.add_parser("filter", help="Filter files", formatter_class=ArgumentDefaultsRichHelpFormatter)
+     subsubparsers = parser.add_subparsers(dest="filter_cmd", required=True)
+
+     _add_filter_confidence_parser(subsubparsers)
+     _add_filter_chain_parser(subsubparsers)
+     _add_filter_residue_parser(subsubparsers)
+
+
+ def _add_mcp_command(subparsers: argparse._SubParsersAction):
+     """Add MCP command."""
+
+     parser = subparsers.add_parser(
+         "mcp",
+         help="Run Model Context Protocol (MCP) server",
+         description=(
+             "Run Model Context Protocol (MCP) server. "
+             "Can be used by agentic LLMs like Claude Sonnet 4 as a set of tools."
+         ),
+         formatter_class=ArgumentDefaultsRichHelpFormatter,
+     )
+     parser.add_argument(
+         "--transport", default="stdio", choices=["stdio", "http", "streamable-http"], help="Transport protocol to use"
+     )
+     parser.add_argument("--host", default="127.0.0.1", help="Host to bind the server to")
+     parser.add_argument("--port", default=8000, type=int, help="Port to bind the server to")
+
+
+ def make_parser() -> argparse.ArgumentParser:
+     parser = argparse.ArgumentParser(
+         description="Protein Quest CLI", prog="protein-quest", formatter_class=ArgumentDefaultsRichHelpFormatter
+     )
+     parser.add_argument("--log-level", default="WARNING", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"])
+     parser.add_argument("--version", action="version", version=f"%(prog)s {__version__}")
+
+     subparsers = parser.add_subparsers(dest="command", required=True)
+
+     _add_search_subcommands(subparsers)
+     _add_retrieve_subcommands(subparsers)
+     _add_filter_subcommands(subparsers)
+     _add_mcp_command(subparsers)
+
+     return parser
+
+
+ def main():
+     """Main entry point for the CLI."""
+     parser = make_parser()
+     args = parser.parse_args()
+     logging.basicConfig(level=args.log_level, handlers=[RichHandler(show_level=False)])
+
+     # Dispatch table to reduce complexity
+     cmd = args.command
+     sub = getattr(args, f"{cmd}_cmd", None)
+     handler = HANDLERS.get((cmd, sub))
+     if handler is None:
+         msg = f"Unknown command: {cmd} {sub}"
+         raise SystemExit(msg)
+     handler(args)
+
+
+ def _handle_search_uniprot(args):
+     taxon_id = args.taxon_id
+     reviewed = args.reviewed
+     subcellular_location_uniprot = args.subcellular_location_uniprot
+     subcellular_location_go = args.subcellular_location_go
+     molecular_function_go = args.molecular_function_go
+     limit = args.limit
+     timeout = args.timeout
+     output_file = args.output
+
+     query = structure(
+         {
+             "taxon_id": taxon_id,
+             "reviewed": reviewed,
+             "subcellular_location_uniprot": subcellular_location_uniprot,
+             "subcellular_location_go": subcellular_location_go,
+             "molecular_function_go": molecular_function_go,
+         },
+         Query,
+     )
+     rprint("Searching for UniProt accessions")
+     accs = search4uniprot(query=query, limit=limit, timeout=timeout)
+     rprint(f"Found {len(accs)} UniProt accessions, written to {output_file.name}")
+     _write_lines(output_file, sorted(accs))
+
+
+ def _handle_search_pdbe(args):
+     uniprot_accs = args.uniprot_accs
+     limit = args.limit
+     timeout = args.timeout
+     output_csv = args.output_csv
+
+     accs = set(_read_lines(uniprot_accs))
+     rprint(f"Finding PDB entries for {len(accs)} uniprot accessions")
+     results = search4pdb(accs, limit=limit, timeout=timeout)
+     total_pdbs = sum([len(v) for v in results.values()])
+     rprint(f"Found {total_pdbs} PDB entries for {len(results)} uniprot accessions")
+     rprint(f"Written to {output_csv.name}")
+     _write_pdbe_csv(output_csv, results)
+
+
+ def _handle_search_alphafold(args):
+     uniprot_accs = args.uniprot_accs
+     limit = args.limit
+     timeout = args.timeout
+     output_csv = args.output_csv
+
+     accs = _read_lines(uniprot_accs)
+     rprint(f"Finding AlphaFold entries for {len(accs)} uniprot accessions")
+     results = search4af(accs, limit=limit, timeout=timeout)
+     rprint(f"Found {len(results)} AlphaFold entries, written to {output_csv.name}")
+     _write_dict_of_sets2csv(output_csv, results, "af_id")
+
+
+ def _handle_search_emdb(args):
+     uniprot_accs = args.uniprot_accs
+     limit = args.limit
+     timeout = args.timeout
+     output_csv = args.output_csv
+
+     accs = _read_lines(uniprot_accs)
+     rprint(f"Finding EMDB entries for {len(accs)} uniprot accessions")
+     results = search4emdb(accs, limit=limit, timeout=timeout)
+     total_emdbs = sum([len(v) for v in results.values()])
+     rprint(f"Found {total_emdbs} EMDB entries, written to {output_csv.name}")
+     _write_dict_of_sets2csv(output_csv, results, "emdb_id")
+
+
+ def _handle_search_go(args):
+     term = structure(args.term, str)
+     aspect: Aspect | None = args.aspect
+     limit = structure(args.limit, int)
+     output_csv: TextIOWrapper = args.output_csv
+
+     if aspect:
+         rprint(f"Searching for GO terms matching '{term}' with aspect '{aspect}'")
+     else:
+         rprint(f"Searching for GO terms matching '{term}'")
+     results = asyncio.run(search_gene_ontology_term(term, aspect=aspect, limit=limit))
+     rprint(f"Found {len(results)} GO terms, written to {output_csv.name}")
+     write_go_terms_to_csv(results, output_csv)
+
+
+ def _handle_search_taxonomy(args):
+     query: str = args.query
+     field: SearchField | None = args.field
+     limit: int = args.limit
+     output_csv: TextIOWrapper = args.output_csv
+
+     if field:
+         rprint(f"Searching for taxon information matching '{query}' in field '{field}'")
+     else:
+         rprint(f"Searching for taxon information matching '{query}'")
+     results = asyncio.run(search_taxon(query=query, field=field, limit=limit))
+     rprint(f"Found {len(results)} taxons, written to {output_csv.name}")
+     _write_taxonomy_csv(results, output_csv)
+
+
+ def _handle_retrieve_pdbe(args):
+     pdbe_csv = args.pdbe_csv
+     output_dir = args.output_dir
+     max_parallel_downloads = args.max_parallel_downloads
+
+     pdb_ids = _read_column_from_csv(pdbe_csv, "pdb_id")
+     rprint(f"Retrieving {len(pdb_ids)} PDBe entries")
+     result = asyncio.run(pdbe_fetch.fetch(pdb_ids, output_dir, max_parallel_downloads=max_parallel_downloads))
+     rprint(f"Retrieved {len(result)} PDBe entries")
+
+
+ def _handle_retrieve_alphafold(args):
+     download_dir = args.output_dir
+     what_af_formats = args.what_af_formats
+     alphafold_csv = args.alphafold_csv
+     max_parallel_downloads = args.max_parallel_downloads
+
+     if what_af_formats is None:
+         what_af_formats = {"pdb"}
+
+     # TODO besides `uniprot_acc,af_id\n` csv also allow headerless single column format
+     #
+     af_ids = [r["af_id"] for r in _read_alphafold_csv(alphafold_csv)]
+     validated_what: set[DownloadableFormat] = structure(what_af_formats, set[DownloadableFormat])
+     rprint(f"Retrieving {len(af_ids)} AlphaFold entries with formats {validated_what}")
+     afs = af_fetch(af_ids, download_dir, what=validated_what, max_parallel_downloads=max_parallel_downloads)
+     total_nr_files = sum(af.nr_of_files() for af in afs)
+     rprint(f"Retrieved {total_nr_files} AlphaFold files and {len(afs)} summaries, written to {download_dir}")
+
+
+ def _handle_retrieve_emdb(args):
+     emdb_csv = args.emdb_csv
+     output_dir = args.output_dir
+
+     emdb_ids = _read_column_from_csv(emdb_csv, "emdb_id")
+     rprint(f"Retrieving {len(emdb_ids)} EMDB entries")
+     result = asyncio.run(emdb_fetch(emdb_ids, output_dir))
+     rprint(f"Retrieved {len(result)} EMDB entries")
+
+
+ def _handle_filter_confidence(args: argparse.Namespace):
+     # we are repeating types here and in add_argument call
+     # TODO replace argparse with modern alternative like cyclopts
+     # to get rid of duplication
+     input_dir = structure(args.input_dir, Path)
+     output_dir = structure(args.output_dir, Path)
+     confidence_threshold = structure(args.confidence_threshold, float)
+     # TODO add min/max
+     min_residues = structure(args.min_residues, int)
+     max_residues = structure(args.max_residues, int)
+     stats_file: TextIOWrapper | None = args.write_stats
+
+     output_dir.mkdir(parents=True, exist_ok=True)
+     input_files = sorted(glob_structure_files(input_dir))
+     nr_input_files = len(input_files)
+     rprint(f"Starting confidence filtering of {nr_input_files} mmcif/PDB files in {input_dir} directory.")
+     query = structure(
+         {
+             "confidence": confidence_threshold,
+             "min_threshold": min_residues,
+             "max_threshold": max_residues,
+         },
+         ConfidenceFilterQuery,
+     )
+     if stats_file:
+         writer = csv.writer(stats_file)
+         writer.writerow(["input_file", "residue_count", "passed", "output_file"])
+
+     passed_count = 0
+     for r in tqdm(filter_files_on_confidence(input_files, query, output_dir), total=len(input_files), unit="file"):
+         if r.filtered_file:
+             passed_count += 1
+         if stats_file:
+             writer.writerow([r.input_file, r.count, r.filtered_file is not None, r.filtered_file])
+
+     rprint(f"Filtered {passed_count} mmcif/PDB files by confidence, written to {output_dir} directory")
+     if stats_file:
+         rprint(f"Statistics written to {stats_file.name}")
+
+
+ def _handle_filter_chain(args):
+     input_dir = args.input_dir
+     output_dir = args.output_dir
+     pdb_id2chain_mapping_file = args.chains
+     scheduler_address = args.scheduler_address
+
+     rows = list(_iter_csv_rows(pdb_id2chain_mapping_file))
+     id2chains: dict[str, str] = {row["pdb_id"]: row["chain"] for row in rows}
+
+     new_files = filter_files_on_chain(input_dir, id2chains, output_dir, scheduler_address)
+
+     nr_written = len([r for r in new_files if r[2] is not None])
+
+     rprint(f"Wrote {nr_written} single-chain PDB/mmCIF files to {output_dir}.")
+
+
+ def _handle_filter_residue(args):
+     input_dir = structure(args.input_dir, Path)
+     output_dir = structure(args.output_dir, Path)
+     min_residues = structure(args.min_residues, int)
+     max_residues = structure(args.max_residues, int)
+     stats_file: TextIOWrapper | None = args.write_stats
+
+     if stats_file:
+         writer = csv.writer(stats_file)
+         writer.writerow(["input_file", "residue_count", "passed", "output_file"])
+
+     nr_passed = 0
+     input_files = sorted(glob_structure_files(input_dir))
+     nr_total = len(input_files)
+     rprint(f"Filtering {nr_total} files in {input_dir} directory by number of residues in chain A.")
+     for r in filter_files_on_residues(input_files, output_dir, min_residues=min_residues, max_residues=max_residues):
+         if stats_file:
+             writer.writerow([r.input_file, r.residue_count, r.passed, r.output_file])
+         if r.passed:
+             nr_passed += 1
+
+     rprint(f"Wrote {nr_passed} files to {output_dir} directory.")
+     if stats_file:
+         rprint(f"Statistics written to {stats_file.name}")
+
+
+ def _handle_mcp(args):
+     if find_spec("fastmcp") is None:
+         msg = "Unable to start MCP server, please install `protein-quest[mcp]`."
+         raise ImportError(msg)
+
+     from protein_quest.mcp_server import mcp  # noqa: PLC0415
+
+     if args.transport == "stdio":
+         mcp.run(transport=args.transport)
+     else:
+         mcp.run(transport=args.transport, host=args.host, port=args.port)
+
+
+ HANDLERS: dict[tuple[str, str | None], Callable] = {
+     ("search", "uniprot"): _handle_search_uniprot,
+     ("search", "pdbe"): _handle_search_pdbe,
+     ("search", "alphafold"): _handle_search_alphafold,
+     ("search", "emdb"): _handle_search_emdb,
+     ("search", "go"): _handle_search_go,
+     ("search", "taxonomy"): _handle_search_taxonomy,
+     ("retrieve", "pdbe"): _handle_retrieve_pdbe,
+     ("retrieve", "alphafold"): _handle_retrieve_alphafold,
+     ("retrieve", "emdb"): _handle_retrieve_emdb,
+     ("filter", "confidence"): _handle_filter_confidence,
+     ("filter", "chain"): _handle_filter_chain,
+     ("filter", "residue"): _handle_filter_residue,
+     ("mcp", None): _handle_mcp,
+ }
+
+
+ def _read_lines(file: TextIOWrapper) -> list[str]:
+     return [line.strip() for line in file]
+
+
+ def _make_sure_parent_exists(file: TextIOWrapper):
+     if file.name != "<stdout>":
+         Path(file.name).parent.mkdir(parents=True, exist_ok=True)
+
+
+ def _write_lines(file: TextIOWrapper, lines: Iterable[str]):
+     _make_sure_parent_exists(file)
+     file.writelines(line + os.linesep for line in lines)
+
+
+ def _write_pdbe_csv(path: TextIOWrapper, data: dict[str, set[PdbResult]]):
+     _make_sure_parent_exists(path)
+     fieldnames = ["uniprot_acc", "pdb_id", "method", "resolution", "uniprot_chains", "chain"]
+     writer = csv.DictWriter(path, fieldnames=fieldnames)
+     writer.writeheader()
+     for uniprot_acc, entries in sorted(data.items()):
+         for e in sorted(entries, key=lambda x: (x.id, x.method)):
+             writer.writerow(
+                 {
+                     "uniprot_acc": uniprot_acc,
+                     "pdb_id": e.id,
+                     "method": e.method,
+                     "resolution": e.resolution or "",
+                     "uniprot_chains": e.uniprot_chains,
+                     "chain": e.chain,
+                 }
+             )
+
+
+ def _write_dict_of_sets2csv(file: TextIOWrapper, data: dict[str, set[str]], ref_id_field: str):
+     _make_sure_parent_exists(file)
+     fieldnames = ["uniprot_acc", ref_id_field]
+
+     writer = csv.DictWriter(file, fieldnames=fieldnames)
+     writer.writeheader()
+     for uniprot_acc, ref_ids in sorted(data.items()):
+         for ref_id in sorted(ref_ids):
+             writer.writerow({"uniprot_acc": uniprot_acc, ref_id_field: ref_id})
+
+
+ def _read_alphafold_csv(file: TextIOWrapper):
+     reader = csv.DictReader(file)
+     yield from reader
+
+
+ def _iter_csv_rows(file: TextIOWrapper):
+     reader = csv.DictReader(file)
+     yield from reader
+
+
+ def _read_column_from_csv(file: TextIOWrapper, column: str) -> set[str]:
+     return {row[column] for row in _iter_csv_rows(file)}
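
The module above wires argparse subcommands to handler functions through the HANDLERS dispatch table. As a minimal sketch (not part of the wheel contents shown above), the same parser and dispatch can be driven from Python; the output file name, taxon ID and limit are illustrative values, and actually invoking the handler would run a live UniProt SPARQL query:

from protein_quest.cli import HANDLERS, make_parser

# Build the same parser the `protein-quest` console script uses.
parser = make_parser()

# Equivalent to: protein-quest search uniprot accessions.txt --taxon-id 9606 --limit 5
# Note: argparse.FileType opens accessions.txt for writing already at parse time.
args = parser.parse_args(["search", "uniprot", "accessions.txt", "--taxon-id", "9606", "--limit", "5"])

# main() resolves the handler from (command, subcommand); the same lookup works directly.
handler = HANDLERS[(args.command, args.search_cmd)]
# handler(args) would perform the live UniProt query and write the accessions file.

Keeping main() as a flat (command, subcommand) lookup means a new subcommand only needs a parser function and one HANDLERS entry, which is the "dispatch table to reduce complexity" noted in main().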