protein-quest 0.6.0-py3-none-any.whl → 0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
protein_quest/cli.py CHANGED
@@ -6,14 +6,16 @@ import csv
6
6
  import logging
7
7
  import os
8
8
  import sys
9
- from collections.abc import Callable, Generator, Iterable
9
+ from collections.abc import Callable, Generator, Iterable, Sequence
10
+ from contextlib import suppress
10
11
  from importlib.util import find_spec
11
- from io import TextIOWrapper
12
+ from io import BytesIO, TextIOWrapper
12
13
  from pathlib import Path
13
14
  from textwrap import dedent
14
15
 
16
+ import shtab
15
17
  from cattrs import structure
16
- from rich import print as rprint
18
+ from rich.console import Console
17
19
  from rich.logging import RichHandler
18
20
  from rich.markdown import Markdown
19
21
  from rich.panel import Panel
@@ -24,7 +26,7 @@ from protein_quest.__version__ import __version__
24
26
  from protein_quest.alphafold.confidence import ConfidenceFilterQuery, filter_files_on_confidence
25
27
  from protein_quest.alphafold.fetch import DownloadableFormat, downloadable_formats
26
28
  from protein_quest.alphafold.fetch import fetch_many as af_fetch
27
- from protein_quest.converter import converter
29
+ from protein_quest.converter import PositiveInt, converter
28
30
  from protein_quest.emdb import fetch as emdb_fetch
29
31
  from protein_quest.filters import filter_files_on_chain, filter_files_on_residues
30
32
  from protein_quest.go import Aspect, allowed_aspects, search_gene_ontology_term, write_go_terms_to_csv
@@ -32,15 +34,20 @@ from protein_quest.io import (
32
34
  convert_to_cif_files,
33
35
  glob_structure_files,
34
36
  locate_structure_file,
37
+ read_structure,
35
38
  valid_structure_file_extensions,
36
39
  )
37
40
  from protein_quest.pdbe import fetch as pdbe_fetch
38
41
  from protein_quest.ss import SecondaryStructureFilterQuery, filter_files_on_secondary_structure
42
+ from protein_quest.structure import structure2uniprot_accessions
39
43
  from protein_quest.taxonomy import SearchField, _write_taxonomy_csv, search_fields, search_taxon
40
44
  from protein_quest.uniprot import (
41
45
  ComplexPortalEntry,
42
- PdbResult,
46
+ PdbResults,
43
47
  Query,
48
+ UniprotDetails,
49
+ filter_pdb_results_on_chain_length,
50
+ map_uniprot_accessions2uniprot_details,
44
51
  search4af,
45
52
  search4emdb,
46
53
  search4interaction_partners,
@@ -58,6 +65,8 @@ from protein_quest.utils import (
58
65
  user_cache_root_dir,
59
66
  )
60
67
 
68
+ console = Console(stderr=True)
69
+ rprint = console.print
61
70
  logger = logging.getLogger(__name__)
62
71
 
63
72
 
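Routing `rprint` through a dedicated `Console(stderr=True)` keeps progress and status messages off stdout, so subcommands that write their results to `-` stay cleanly pipeable. A minimal sketch of the pattern, using only what the hunk above introduces:

```python
from rich.console import Console

# Status output goes to stderr; data written to stdout (e.g. via `-`) stays clean.
console = Console(stderr=True)
rprint = console.print

rprint("Searching for UniProt accessions")  # visible in the terminal, not in piped output
```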
@@ -73,7 +82,7 @@ def _add_search_uniprot_parser(subparsers: argparse._SubParsersAction):
73
82
  "output",
74
83
  type=argparse.FileType("w", encoding="UTF-8"),
75
84
  help="Output text file for UniProt accessions (one per line). Use `-` for stdout.",
76
- )
85
+ ).complete = shtab.FILE
77
86
  parser.add_argument("--taxon-id", type=str, help="NCBI Taxon ID, e.g. 9606 for Homo Sapiens")
78
87
  parser.add_argument(
79
88
  "--reviewed",
@@ -98,6 +107,8 @@ def _add_search_uniprot_parser(subparsers: argparse._SubParsersAction):
98
107
  action="append",
99
108
  help="GO term(s) for molecular function (e.g. GO:0003677). Can be given multiple times.",
100
109
  )
110
+ parser.add_argument("--min-sequence-length", type=int, help="Minimum length of the canonical sequence.")
111
+ parser.add_argument("--max-sequence-length", type=int, help="Maximum length of the canonical sequence.")
101
112
  parser.add_argument("--limit", type=int, default=10_000, help="Maximum number of uniprot accessions to return")
102
113
  parser.add_argument("--timeout", type=int, default=1_800, help="Maximum seconds to wait for query to complete")
103
114
 
@@ -111,23 +122,44 @@ def _add_search_pdbe_parser(subparsers: argparse._SubParsersAction):
111
122
  formatter_class=ArgumentDefaultsRichHelpFormatter,
112
123
  )
113
124
  parser.add_argument(
114
- "uniprot_accs",
125
+ "uniprot_accessions",
115
126
  type=argparse.FileType("r", encoding="UTF-8"),
116
127
  help="Text file with UniProt accessions (one per line). Use `-` for stdin.",
117
- )
128
+ ).complete = shtab.FILE
118
129
  parser.add_argument(
119
130
  "output_csv",
120
131
  type=argparse.FileType("w", encoding="UTF-8"),
121
132
  help=dedent("""\
122
- Output CSV with `uniprot_acc`, `pdb_id`, `method`, `resolution`, `uniprot_chains`, `chain` columns.
133
+ Output CSV with following columns:
134
+ `uniprot_accession`, `pdb_id`, `method`, `resolution`, `uniprot_chains`, `chain`, `chain_length`.
123
135
  Where `uniprot_chains` is the raw UniProt chain string, for example `A=1-100`.
124
- and where `chain` is the first chain from `uniprot_chains`, for example `A`.
136
+ and where `chain` is the first chain from `uniprot_chains`, for example `A`
137
+ and `chain_length` is the length of the chain, for example `100`.
125
138
  Use `-` for stdout.
126
139
  """),
127
- )
140
+ ).complete = shtab.FILE
128
141
  parser.add_argument(
129
142
  "--limit", type=int, default=10_000, help="Maximum number of PDB uniprot accessions combinations to return"
130
143
  )
144
+ parser.add_argument(
145
+ "--min-residues",
146
+ type=int,
147
+ help="Minimum number of residues required in the chain mapped to the UniProt accession.",
148
+ )
149
+ parser.add_argument(
150
+ "--max-residues",
151
+ type=int,
152
+ help="Maximum number of residues allowed in chain mapped to the UniProt accession.",
153
+ )
154
+ parser.add_argument(
155
+ "--keep-invalid",
156
+ action="store_true",
157
+ help=dedent("""\
158
+ Keep PDB results when chain length could not be determined.
159
+ If not given, such results are dropped.
160
+ Only applies if min/max residues arguments are set.
161
+ """),
162
+ )
131
163
  parser.add_argument("--timeout", type=int, default=1_800, help="Maximum seconds to wait for query to complete")
132
164
 
133
165
 
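The new `--min-residues`, `--max-residues` and `--keep-invalid` options let `search pdbe` filter hits by the length of the chain mapped to each accession. A hypothetical invocation through the `main(argv)` entry point added later in this release (file names invented for illustration):

```python
from protein_quest.cli import main

# Keep only PDB chains of 100-500 residues; entries whose chain length
# cannot be determined are dropped because --keep-invalid is not given.
main([
    "search", "pdbe",
    "accessions.txt", "pdbe.csv",
    "--min-residues", "100",
    "--max-residues", "500",
])
```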
@@ -140,15 +172,17 @@ def _add_search_alphafold_parser(subparsers: argparse._SubParsersAction):
140
172
  formatter_class=ArgumentDefaultsRichHelpFormatter,
141
173
  )
142
174
  parser.add_argument(
143
- "uniprot_accs",
175
+ "uniprot_accessions",
144
176
  type=argparse.FileType("r", encoding="UTF-8"),
145
177
  help="Text file with UniProt accessions (one per line). Use `-` for stdin.",
146
- )
178
+ ).complete = shtab.FILE
147
179
  parser.add_argument(
148
180
  "output_csv",
149
181
  type=argparse.FileType("w", encoding="UTF-8"),
150
182
  help="Output CSV with AlphaFold IDs per UniProt accession. Use `-` for stdout.",
151
- )
183
+ ).complete = shtab.FILE
184
+ parser.add_argument("--min-sequence-length", type=int, help="Minimum length of the canonical sequence.")
185
+ parser.add_argument("--max-sequence-length", type=int, help="Maximum length of the canonical sequence.")
152
186
  parser.add_argument(
153
187
  "--limit", type=int, default=10_000, help="Maximum number of Alphafold entry identifiers to return"
154
188
  )
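`search alphafold` gains the same canonical-sequence-length bounds as `search uniprot`. A hypothetical call (file names invented):

```python
from protein_quest.cli import main

# Restrict AlphaFold hits to accessions whose canonical sequence is 100-300 residues long.
main([
    "search", "alphafold",
    "accessions.txt", "alphafold.csv",
    "--min-sequence-length", "100",
    "--max-sequence-length", "300",
])
```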
@@ -170,12 +204,12 @@ def _add_search_emdb_parser(subparsers: argparse._SubParsersAction):
170
204
  "uniprot_accs",
171
205
  type=argparse.FileType("r", encoding="UTF-8"),
172
206
  help="Text file with UniProt accessions (one per line). Use `-` for stdin.",
173
- )
207
+ ).complete = shtab.FILE
174
208
  parser.add_argument(
175
209
  "output_csv",
176
210
  type=argparse.FileType("w", encoding="UTF-8"),
177
211
  help="Output CSV with EMDB IDs per UniProt accession. Use `-` for stdout.",
178
- )
212
+ ).complete = shtab.FILE
179
213
  parser.add_argument("--limit", type=int, default=10_000, help="Maximum number of EMDB entry identifiers to return")
180
214
  parser.add_argument("--timeout", type=int, default=1_800, help="Maximum seconds to wait for query to complete")
181
215
 
@@ -198,7 +232,7 @@ def _add_search_go_parser(subparsers: argparse._SubParsersAction):
198
232
  "output_csv",
199
233
  type=argparse.FileType("w", encoding="UTF-8"),
200
234
  help="Output CSV with GO term results. Use `-` for stdout.",
201
- )
235
+ ).complete = shtab.FILE
202
236
  parser.add_argument("--limit", type=int, default=100, help="Maximum number of GO term results to return")
203
237
 
204
238
 
@@ -220,7 +254,7 @@ def _add_search_taxonomy_parser(subparser: argparse._SubParsersAction):
220
254
  "output_csv",
221
255
  type=argparse.FileType("w", encoding="UTF-8"),
222
256
  help="Output CSV with taxonomy results. Use `-` for stdout.",
223
- )
257
+ ).complete = shtab.FILE
224
258
  parser.add_argument(
225
259
  "--field",
226
260
  type=str,
@@ -247,7 +281,7 @@ def _add_search_interaction_partners_parser(subparsers: argparse._SubParsersActi
247
281
  formatter_class=ArgumentDefaultsRichHelpFormatter,
248
282
  )
249
283
  parser.add_argument(
250
- "uniprot_acc",
284
+ "uniprot_accession",
251
285
  type=str,
252
286
  help="UniProt accession (for example P12345).",
253
287
  )
@@ -261,7 +295,7 @@ def _add_search_interaction_partners_parser(subparsers: argparse._SubParsersActi
261
295
  "output_csv",
262
296
  type=argparse.FileType("w", encoding="UTF-8"),
263
297
  help="Output CSV with interaction partners per UniProt accession. Use `-` for stdout.",
264
- )
298
+ ).complete = shtab.FILE
265
299
  parser.add_argument(
266
300
  "--limit", type=int, default=10_000, help="Maximum number of interaction partner uniprot accessions to return"
267
301
  )
@@ -289,19 +323,57 @@ def _add_search_complexes_parser(subparsers: argparse._SubParsersAction):
289
323
  formatter_class=ArgumentDefaultsRichHelpFormatter,
290
324
  )
291
325
  parser.add_argument(
292
- "uniprot_accs",
326
+ "uniprot_accessions",
293
327
  type=argparse.FileType("r", encoding="UTF-8"),
294
328
  help="Text file with UniProt accessions (one per line) as query for searching complexes. Use `-` for stdin.",
295
- )
329
+ ).complete = shtab.FILE
296
330
  parser.add_argument(
297
331
  "output_csv",
298
332
  type=argparse.FileType("w", encoding="UTF-8"),
299
333
  help="Output CSV file with complex results. Use `-` for stdout.",
300
- )
334
+ ).complete = shtab.FILE
301
335
  parser.add_argument("--limit", type=int, default=100, help="Maximum number of complex results to return")
302
336
  parser.add_argument("--timeout", type=int, default=1_800, help="Maximum seconds to wait for query to complete")
303
337
 
304
338
 
339
+ def _add_search_uniprot_details_parser(subparsers: argparse._SubParsersAction):
340
+ """Add search uniprot details subcommand parser."""
341
+ description = dedent("""\
342
+ Retrieve UniProt details for given UniProt accessions
343
+ from the Uniprot SPARQL endpoint.
344
+
345
+ The output CSV file has the following columns:
346
+
347
+ - uniprot_accession: UniProt accession.
348
+ - uniprot_id: UniProt ID (mnemonic).
349
+ - sequence_length: Length of the canonical sequence.
350
+ - reviewed: Whether the entry is reviewed (Swiss-Prot) or unreviewed (TrEMBL).
351
+ - protein_name: Recommended protein name.
352
+ - taxon_id: NCBI Taxonomy ID of the organism.
353
+ - taxon_name: Scientific name of the organism.
354
+
355
+ The order of the output CSV can be different from the input order.
356
+ """)
357
+ parser = subparsers.add_parser(
358
+ "uniprot-details",
359
+ help="Retrieve UniProt details for given UniProt accessions",
360
+ description=Markdown(description, style="argparse.text"), # type: ignore using rich formatter makes this OK
361
+ formatter_class=ArgumentDefaultsRichHelpFormatter,
362
+ )
363
+ parser.add_argument(
364
+ "uniprot_accessions",
365
+ type=argparse.FileType("r", encoding="UTF-8"),
366
+ help="Text file with UniProt accessions (one per line). Use `-` for stdin.",
367
+ ).complete = shtab.FILE
368
+ parser.add_argument(
369
+ "output_csv",
370
+ type=argparse.FileType("w", encoding="UTF-8"),
371
+ help="Output CSV with UniProt details. Use `-` for stdout.",
372
+ ).complete = shtab.FILE
373
+ parser.add_argument("--timeout", type=int, default=1_800, help="Maximum seconds to wait for query to complete")
374
+ parser.add_argument("--batch-size", type=int, default=1_000, help="Number of accessions to query per batch")
375
+
376
+
305
377
  def _add_copy_method_arguments(parser):
306
378
  parser.add_argument(
307
379
  "--copy-method",
@@ -325,12 +397,13 @@ def _add_cacher_arguments(parser: argparse.ArgumentParser):
325
397
  action="store_true",
326
398
  help="Disable caching of files to central location.",
327
399
  )
328
- parser.add_argument(
400
+ cache_dir_action = parser.add_argument(
329
401
  "--cache-dir",
330
402
  type=Path,
331
403
  default=user_cache_root_dir(),
332
404
  help="Directory to use as cache for files.",
333
405
  )
406
+ cache_dir_action.complete = shtab.DIRECTORY # type: ignore[missing-attribute]
334
407
  _add_copy_method_arguments(parser)
335
408
 
336
409
 
@@ -349,8 +422,10 @@ def _add_retrieve_pdbe_parser(subparsers: argparse._SubParsersAction):
349
422
  "pdbe_csv",
350
423
  type=argparse.FileType("r", encoding="UTF-8"),
351
424
  help="CSV file with `pdb_id` column. Other columns are ignored. Use `-` for stdin.",
352
- )
353
- parser.add_argument("output_dir", type=Path, help="Directory to store downloaded PDBe mmCIF files")
425
+ ).complete = shtab.FILE
426
+ parser.add_argument(
427
+ "output_dir", type=Path, help="Directory to store downloaded PDBe mmCIF files"
428
+ ).complete = shtab.DIRECTORY
354
429
  parser.add_argument(
355
430
  "--max-parallel-downloads",
356
431
  type=int,
@@ -372,21 +447,36 @@ def _add_retrieve_alphafold_parser(subparsers: argparse._SubParsersAction):
372
447
  "alphafold_csv",
373
448
  type=argparse.FileType("r", encoding="UTF-8"),
374
449
  help="CSV file with `af_id` column. Other columns are ignored. Use `-` for stdin.",
375
- )
376
- parser.add_argument("output_dir", type=Path, help="Directory to store downloaded AlphaFold files")
450
+ ).complete = shtab.FILE
377
451
  parser.add_argument(
378
- "--what-formats",
452
+ "output_dir", type=Path, help="Directory to store downloaded AlphaFold files"
453
+ ).complete = shtab.DIRECTORY
454
+ parser.add_argument(
455
+ "--format",
379
456
  type=str,
380
457
  action="append",
381
458
  choices=sorted(downloadable_formats),
382
459
  help=dedent("""AlphaFold formats to retrieve. Can be specified multiple times.
383
- Default is 'summary' and 'cif'."""),
460
+ Default is 'cif'."""),
461
+ )
462
+ parser.add_argument(
463
+ "--db-version",
464
+ type=str,
465
+ help="AlphaFold database version to use. If not given, the latest version is used. For example '6'.",
384
466
  )
385
467
  parser.add_argument(
386
468
  "--gzip-files",
387
469
  action="store_true",
388
470
  help="Whether to gzip the downloaded files. Excludes summary files, they are always uncompressed.",
389
471
  )
472
+ parser.add_argument(
473
+ "--all-isoforms",
474
+ action="store_true",
475
+ help=(
476
+ "Whether to return all isoforms of each uniprot entry. "
477
+ "If not given then only the Alphafold entry for the canonical sequence is returned."
478
+ ),
479
+ )
390
480
  parser.add_argument(
391
481
  "--max-parallel-downloads",
392
482
  type=int,
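`retrieve alphafold` replaces `--what-formats` with a repeatable `--format` flag (default `cif`) and adds `--db-version` and `--all-isoforms`. A hypothetical invocation (CSV and directory names invented):

```python
from protein_quest.cli import main

# Download mmCIF files from AlphaFold DB version 6, including isoform entries.
main([
    "retrieve", "alphafold",
    "alphafold.csv", "downloads",
    "--format", "cif",
    "--db-version", "6",
    "--all-isoforms",
])
```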
@@ -411,8 +501,10 @@ def _add_retrieve_emdb_parser(subparsers: argparse._SubParsersAction):
411
501
  "emdb_csv",
412
502
  type=argparse.FileType("r", encoding="UTF-8"),
413
503
  help="CSV file with `emdb_id` column. Other columns are ignored. Use `-` for stdin.",
414
- )
415
- parser.add_argument("output_dir", type=Path, help="Directory to store downloaded EMDB volume files")
504
+ ).complete = shtab.FILE
505
+ parser.add_argument(
506
+ "output_dir", type=Path, help="Directory to store downloaded EMDB volume files"
507
+ ).complete = shtab.DIRECTORY
416
508
  _add_cacher_arguments(parser)
417
509
 
418
510
 
@@ -426,8 +518,12 @@ def _add_filter_confidence_parser(subparsers: argparse._SubParsersAction):
426
518
  Passed files are written with residues below threshold removed."""),
427
519
  formatter_class=ArgumentDefaultsRichHelpFormatter,
428
520
  )
429
- parser.add_argument("input_dir", type=Path, help="Directory with AlphaFold mmcif/PDB files")
430
- parser.add_argument("output_dir", type=Path, help="Directory to write filtered mmcif/PDB files")
521
+ parser.add_argument(
522
+ "input_dir", type=Path, help="Directory with AlphaFold mmcif/PDB files"
523
+ ).complete = shtab.DIRECTORY
524
+ parser.add_argument(
525
+ "output_dir", type=Path, help="Directory to write filtered mmcif/PDB files"
526
+ ).complete = shtab.DIRECTORY
431
527
  parser.add_argument("--confidence-threshold", type=float, default=70, help="pLDDT confidence threshold (0-100)")
432
528
  parser.add_argument(
433
529
  "--min-residues", type=int, default=0, help="Minimum number of high-confidence residues a structure should have"
@@ -445,7 +541,7 @@ def _add_filter_confidence_parser(subparsers: argparse._SubParsersAction):
445
541
  Write filter statistics to file.
446
542
  In CSV format with `<input_file>,<residue_count>,<passed>,<output_file>` columns.
447
543
  Use `-` for stdout."""),
448
- )
544
+ ).complete = shtab.FILE
449
545
  _add_copy_method_arguments(parser)
450
546
 
451
547
 
@@ -465,7 +561,7 @@ def _add_filter_chain_parser(subparsers: argparse._SubParsersAction):
465
561
  "chains",
466
562
  type=argparse.FileType("r", encoding="UTF-8"),
467
563
  help="CSV file with `pdb_id` and `chain` columns. Other columns are ignored.",
468
- )
564
+ ).complete = shtab.FILE
469
565
  parser.add_argument(
470
566
  "input_dir",
471
567
  type=Path,
@@ -473,13 +569,13 @@ def _add_filter_chain_parser(subparsers: argparse._SubParsersAction):
473
569
  Directory with PDB/mmCIF files.
474
570
  Expected filenames are `{pdb_id}.cif.gz`, `{pdb_id}.cif`, `{pdb_id}.pdb.gz` or `{pdb_id}.pdb`.
475
571
  """),
476
- )
572
+ ).complete = shtab.DIRECTORY
477
573
  parser.add_argument(
478
574
  "output_dir",
479
575
  type=Path,
480
576
  help=dedent("""\
481
577
  Directory to write the single-chain PDB/mmCIF files. Output files are in same format as input files."""),
482
- )
578
+ ).complete = shtab.DIRECTORY
483
579
  parser.add_argument(
484
580
  "--scheduler-address",
485
581
  help=dedent("""Address of the Dask scheduler to connect to.
@@ -499,14 +595,16 @@ def _add_filter_residue_parser(subparsers: argparse._SubParsersAction):
499
595
  """),
500
596
  formatter_class=ArgumentDefaultsRichHelpFormatter,
501
597
  )
502
- parser.add_argument("input_dir", type=Path, help="Directory with PDB/mmCIF files (e.g., from 'filter chain')")
598
+ parser.add_argument(
599
+ "input_dir", type=Path, help="Directory with PDB/mmCIF files (e.g., from 'filter chain')"
600
+ ).complete = shtab.DIRECTORY
503
601
  parser.add_argument(
504
602
  "output_dir",
505
603
  type=Path,
506
604
  help=dedent("""\
507
605
  Directory to write filtered PDB/mmCIF files. Files are copied without modification.
508
606
  """),
509
- )
607
+ ).complete = shtab.DIRECTORY
510
608
  parser.add_argument("--min-residues", type=int, default=0, help="Min residues in chain A")
511
609
  parser.add_argument("--max-residues", type=int, default=10_000_000, help="Max residues in chain A")
512
610
  parser.add_argument(
@@ -516,7 +614,7 @@ def _add_filter_residue_parser(subparsers: argparse._SubParsersAction):
516
614
  Write filter statistics to file.
517
615
  In CSV format with `<input_file>,<residue_count>,<passed>,<output_file>` columns.
518
616
  Use `-` for stdout."""),
519
- )
617
+ ).complete = shtab.FILE
520
618
  _add_copy_method_arguments(parser)
521
619
 
522
620
 
@@ -528,14 +626,16 @@ def _add_filter_ss_parser(subparsers: argparse._SubParsersAction):
528
626
  description="Filter PDB/mmCIF files by secondary structure",
529
627
  formatter_class=ArgumentDefaultsRichHelpFormatter,
530
628
  )
531
- parser.add_argument("input_dir", type=Path, help="Directory with PDB/mmCIF files (e.g., from 'filter chain')")
629
+ parser.add_argument(
630
+ "input_dir", type=Path, help="Directory with PDB/mmCIF files (e.g., from 'filter chain')"
631
+ ).complete = shtab.DIRECTORY
532
632
  parser.add_argument(
533
633
  "output_dir",
534
634
  type=Path,
535
635
  help=dedent("""\
536
636
  Directory to write filtered PDB/mmCIF files. Files are copied without modification.
537
637
  """),
538
- )
638
+ ).complete = shtab.DIRECTORY
539
639
  parser.add_argument("--abs-min-helix-residues", type=int, help="Min residues in helices")
540
640
  parser.add_argument("--abs-max-helix-residues", type=int, help="Max residues in helices")
541
641
  parser.add_argument("--abs-min-sheet-residues", type=int, help="Min residues in sheets")
@@ -553,7 +653,7 @@ def _add_filter_ss_parser(subparsers: argparse._SubParsersAction):
553
653
  <helix_ratio>,<sheet_ratio>,<passed>,<output_file>`.
554
654
  Use `-` for stdout.
555
655
  """),
556
- )
656
+ ).complete = shtab.FILE
557
657
  _add_copy_method_arguments(parser)
558
658
 
559
659
 
@@ -575,6 +675,7 @@ def _add_search_subcommands(subparsers: argparse._SubParsersAction):
575
675
  _add_search_taxonomy_parser(subsubparsers)
576
676
  _add_search_interaction_partners_parser(subsubparsers)
577
677
  _add_search_complexes_parser(subsubparsers)
678
+ _add_search_uniprot_details_parser(subsubparsers)
578
679
 
579
680
 
580
681
  def _add_retrieve_subcommands(subparsers: argparse._SubParsersAction):
@@ -603,23 +704,52 @@ def _add_filter_subcommands(subparsers: argparse._SubParsersAction):
603
704
  _add_filter_ss_parser(subsubparsers)
604
705
 
605
706
 
606
- def _add_convert_subcommands(subparsers: argparse._SubParsersAction):
607
- """Add convert command."""
707
+ def _add_convert_uniprot_parser(subparsers: argparse._SubParsersAction):
708
+ """Add convert uniprot subcommand parser."""
608
709
  parser = subparsers.add_parser(
609
- "convert", help="Convert structure files between formats", formatter_class=ArgumentDefaultsRichHelpFormatter
710
+ "uniprot",
711
+ help="Convert structure files to list of UniProt accessions.",
712
+ description="Convert structure files to list of UniProt accessions. "
713
+ "Uniprot accessions are read from database reference of each structure.",
714
+ formatter_class=ArgumentDefaultsRichHelpFormatter,
610
715
  )
611
716
  parser.add_argument(
612
717
  "input_dir",
613
718
  type=Path,
614
719
  help=f"Directory with structure files. Supported extensions are {valid_structure_file_extensions}",
720
+ ).complete = shtab.DIRECTORY
721
+ parser.add_argument(
722
+ "output",
723
+ type=argparse.FileType("wt", encoding="UTF-8"),
724
+ help="Output text file with UniProt accessions (one per line). Use '-' for stdout.",
725
+ ).complete = shtab.FILE
726
+ parser.add_argument(
727
+ "--grouped",
728
+ action="store_true",
729
+ help="Whether to group accessions by structure file. "
730
+ "If set output changes to `<structure_file1>,<acc1>\\n<structure_file1>,<acc2>` format.",
731
+ )
732
+
733
+
734
+ def _add_convert_structures_parser(subparsers: argparse._SubParsersAction):
735
+ """Add convert structures subcommand parser."""
736
+ parser = subparsers.add_parser(
737
+ "structures",
738
+ help="Convert structure files between formats",
739
+ formatter_class=ArgumentDefaultsRichHelpFormatter,
615
740
  )
741
+ parser.add_argument(
742
+ "input_dir",
743
+ type=Path,
744
+ help=f"Directory with structure files. Supported extensions are {valid_structure_file_extensions}",
745
+ ).complete = shtab.DIRECTORY
616
746
  parser.add_argument(
617
747
  "--output-dir",
618
748
  type=Path,
619
749
  help=dedent("""\
620
750
  Directory to write converted structure files. If not given, files are written to `input_dir`.
621
751
  """),
622
- )
752
+ ).complete = shtab.DIRECTORY
623
753
  parser.add_argument(
624
754
  "--format",
625
755
  type=str,
@@ -630,6 +760,19 @@ def _add_convert_subcommands(subparsers: argparse._SubParsersAction):
630
760
  _add_copy_method_arguments(parser)
631
761
 
632
762
 
763
+ def _add_convert_subcommands(subparsers: argparse._SubParsersAction):
764
+ """Add convert command and its subcommands."""
765
+ parser = subparsers.add_parser(
766
+ "convert",
767
+ help="Convert files between formats",
768
+ formatter_class=ArgumentDefaultsRichHelpFormatter,
769
+ )
770
+ subsubparsers = parser.add_subparsers(dest="convert_cmd", required=True)
771
+
772
+ _add_convert_structures_parser(subsubparsers)
773
+ _add_convert_uniprot_parser(subsubparsers)
774
+
775
+
633
776
  def _add_mcp_command(subparsers: argparse._SubParsersAction):
634
777
  """Add MCP command."""
635
778
 
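`convert` is now a command group: `convert structures` keeps the old format-conversion behaviour, while `convert uniprot` extracts UniProt accessions from the database references of structure files. Hypothetical invocations (directory and file names, and the `cif` format choice, are assumptions for illustration):

```python
from protein_quest.cli import main

# Convert structure files to another format, writing next to the inputs.
main(["convert", "structures", "structures/", "--format", "cif"])

# Write the UniProt accessions referenced by the structure files, one per line, to stdout.
main(["convert", "uniprot", "structures/", "-"])
```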
@@ -655,6 +798,7 @@ def make_parser() -> argparse.ArgumentParser:
655
798
  )
656
799
  parser.add_argument("--log-level", default="WARNING", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"])
657
800
  parser.add_argument("--version", action="version", version=f"%(prog)s {__version__}")
801
+ shtab.add_argument_to(parser, ["--print-completion"])
658
802
 
659
803
  subparsers = parser.add_subparsers(dest="command", required=True)
660
804
 
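With `shtab.add_argument_to` wired into the top-level parser, shell completion scripts can be generated by the CLI itself. Assuming shtab's standard behaviour for the flag added above:

```python
from protein_quest.cli import main

# Prints a bash completion script to stdout and exits; redirect it into the
# shell's completion directory to enable tab completion for protein-quest.
main(["--print-completion", "bash"])
```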
@@ -667,12 +811,22 @@ def make_parser() -> argparse.ArgumentParser:
667
811
  return parser
668
812
 
669
813
 
814
+ def _name_of(file: TextIOWrapper | BytesIO) -> str:
815
+ try:
816
+ return file.name
817
+ except AttributeError:
818
+ # In pytest a BytesIO is used as stdout, which has no 'name' attribute
819
+ return "<stdout>"
820
+
821
+
670
822
  def _handle_search_uniprot(args):
671
823
  taxon_id = args.taxon_id
672
824
  reviewed = args.reviewed
673
825
  subcellular_location_uniprot = args.subcellular_location_uniprot
674
826
  subcellular_location_go = args.subcellular_location_go
675
827
  molecular_function_go = args.molecular_function_go
828
+ min_sequence_length = args.min_sequence_length
829
+ max_sequence_length = args.max_sequence_length
676
830
  limit = args.limit
677
831
  timeout = args.timeout
678
832
  output_file = args.output
@@ -684,54 +838,79 @@ def _handle_search_uniprot(args):
684
838
  "subcellular_location_uniprot": subcellular_location_uniprot,
685
839
  "subcellular_location_go": subcellular_location_go,
686
840
  "molecular_function_go": molecular_function_go,
841
+ "min_sequence_length": min_sequence_length,
842
+ "max_sequence_length": max_sequence_length,
687
843
  },
688
844
  Query,
689
845
  )
690
846
  rprint("Searching for UniProt accessions")
691
847
  accs = search4uniprot(query=query, limit=limit, timeout=timeout)
692
- rprint(f"Found {len(accs)} UniProt accessions, written to {output_file.name}")
848
+ rprint(f"Found {len(accs)} UniProt accessions, written to {_name_of(output_file)}")
693
849
  _write_lines(output_file, sorted(accs))
694
850
 
695
851
 
696
852
  def _handle_search_pdbe(args):
697
- uniprot_accs = args.uniprot_accs
853
+ uniprot_accessions = args.uniprot_accessions
698
854
  limit = args.limit
699
855
  timeout = args.timeout
700
856
  output_csv = args.output_csv
857
+ min_residues = converter.structure(args.min_residues, PositiveInt | None) # pyright: ignore[reportArgumentType]
858
+ max_residues = converter.structure(args.max_residues, PositiveInt | None) # pyright: ignore[reportArgumentType]
859
+ keep_invalid = args.keep_invalid
701
860
 
702
- accs = set(_read_lines(uniprot_accs))
861
+ accs = set(_read_lines(uniprot_accessions))
703
862
  rprint(f"Finding PDB entries for {len(accs)} uniprot accessions")
704
863
  results = search4pdb(accs, limit=limit, timeout=timeout)
705
- total_pdbs = sum([len(v) for v in results.values()])
706
- rprint(f"Found {total_pdbs} PDB entries for {len(results)} uniprot accessions")
707
- rprint(f"Written to {output_csv.name}")
864
+
865
+ raw_nr_results = len(results)
866
+ raw_total_pdbs = sum([len(v) for v in results.values()])
867
+ if min_residues or max_residues:
868
+ results = filter_pdb_results_on_chain_length(results, min_residues, max_residues, keep_invalid=keep_invalid)
869
+ total_pdbs = sum([len(v) for v in results.values()])
870
+ rprint(f"Before filtering found {raw_total_pdbs} PDB entries for {raw_nr_results} uniprot accessions.")
871
+ rprint(
872
+ f"After filtering on chain length ({min_residues}, {max_residues}) "
873
+ f"remained {total_pdbs} PDB entries for {len(results)} uniprot accessions."
874
+ )
875
+ else:
876
+ rprint(f"Found {raw_total_pdbs} PDB entries for {raw_nr_results} uniprot accessions")
877
+
708
878
  _write_pdbe_csv(output_csv, results)
879
+ rprint(f"Written to {_name_of(output_csv)}")
709
880
 
710
881
 
711
882
  def _handle_search_alphafold(args):
712
- uniprot_accs = args.uniprot_accs
883
+ uniprot_accessions = args.uniprot_accessions
884
+ min_sequence_length = converter.structure(args.min_sequence_length, PositiveInt | None) # pyright: ignore[reportArgumentType]
885
+ max_sequence_length = converter.structure(args.max_sequence_length, PositiveInt | None) # pyright: ignore[reportArgumentType]
713
886
  limit = args.limit
714
887
  timeout = args.timeout
715
888
  output_csv = args.output_csv
716
889
 
717
- accs = _read_lines(uniprot_accs)
890
+ accs = _read_lines(uniprot_accessions)
718
891
  rprint(f"Finding AlphaFold entries for {len(accs)} uniprot accessions")
719
- results = search4af(accs, limit=limit, timeout=timeout)
720
- rprint(f"Found {len(results)} AlphaFold entries, written to {output_csv.name}")
892
+ results = search4af(
893
+ accs,
894
+ min_sequence_length=min_sequence_length,
895
+ max_sequence_length=max_sequence_length,
896
+ limit=limit,
897
+ timeout=timeout,
898
+ )
899
+ rprint(f"Found {len(results)} AlphaFold entries, written to {_name_of(output_csv)}")
721
900
  _write_dict_of_sets2csv(output_csv, results, "af_id")
722
901
 
723
902
 
724
903
  def _handle_search_emdb(args):
725
- uniprot_accs = args.uniprot_accs
904
+ uniprot_accessions = args.uniprot_accessions
726
905
  limit = args.limit
727
906
  timeout = args.timeout
728
907
  output_csv = args.output_csv
729
908
 
730
- accs = _read_lines(uniprot_accs)
909
+ accs = _read_lines(uniprot_accessions)
731
910
  rprint(f"Finding EMDB entries for {len(accs)} uniprot accessions")
732
911
  results = search4emdb(accs, limit=limit, timeout=timeout)
733
912
  total_emdbs = sum([len(v) for v in results.values()])
734
- rprint(f"Found {total_emdbs} EMDB entries, written to {output_csv.name}")
913
+ rprint(f"Found {total_emdbs} EMDB entries, written to {_name_of(output_csv)}")
735
914
  _write_dict_of_sets2csv(output_csv, results, "emdb_id")
736
915
 
737
916
 
@@ -746,7 +925,7 @@ def _handle_search_go(args):
746
925
  else:
747
926
  rprint(f"Searching for GO terms matching '{term}'")
748
927
  results = asyncio.run(search_gene_ontology_term(term, aspect=aspect, limit=limit))
749
- rprint(f"Found {len(results)} GO terms, written to {output_csv.name}")
928
+ rprint(f"Found {len(results)} GO terms, written to {_name_of(output_csv)}")
750
929
  write_go_terms_to_csv(results, output_csv)
751
930
 
752
931
 
@@ -761,36 +940,49 @@ def _handle_search_taxonomy(args):
761
940
  else:
762
941
  rprint(f"Searching for taxon information matching '{query}'")
763
942
  results = asyncio.run(search_taxon(query=query, field=field, limit=limit))
764
- rprint(f"Found {len(results)} taxons, written to {output_csv.name}")
943
+ rprint(f"Found {len(results)} taxons, written to {_name_of(output_csv)}")
765
944
  _write_taxonomy_csv(results, output_csv)
766
945
 
767
946
 
768
947
  def _handle_search_interaction_partners(args: argparse.Namespace):
769
- uniprot_acc: str = args.uniprot_acc
948
+ uniprot_accession: str = args.uniprot_accession
770
949
  excludes: set[str] = set(args.exclude) if args.exclude else set()
771
950
  limit: int = args.limit
772
951
  timeout: int = args.timeout
773
952
  output_csv: TextIOWrapper = args.output_csv
774
953
 
775
- rprint(f"Searching for interaction partners of '{uniprot_acc}'")
776
- results = search4interaction_partners(uniprot_acc, excludes=excludes, limit=limit, timeout=timeout)
777
- rprint(f"Found {len(results)} interaction partners, written to {output_csv.name}")
954
+ rprint(f"Searching for interaction partners of '{uniprot_accession}'")
955
+ results = search4interaction_partners(uniprot_accession, excludes=excludes, limit=limit, timeout=timeout)
956
+ rprint(f"Found {len(results)} interaction partners, written to {_name_of(output_csv)}")
778
957
  _write_lines(output_csv, results.keys())
779
958
 
780
959
 
781
960
  def _handle_search_complexes(args: argparse.Namespace):
782
- uniprot_accs = args.uniprot_accs
961
+ uniprot_accessions = args.uniprot_accessions
783
962
  limit = args.limit
784
963
  timeout = args.timeout
785
964
  output_csv = args.output_csv
786
965
 
787
- accs = _read_lines(uniprot_accs)
966
+ accs = _read_lines(uniprot_accessions)
788
967
  rprint(f"Finding complexes for {len(accs)} uniprot accessions")
789
968
  results = search4macromolecular_complexes(accs, limit=limit, timeout=timeout)
790
- rprint(f"Found {len(results)} complexes, written to {output_csv.name}")
969
+ rprint(f"Found {len(results)} complexes, written to {_name_of(output_csv)}")
791
970
  _write_complexes_csv(results, output_csv)
792
971
 
793
972
 
973
+ def _handle_search_uniprot_details(args: argparse.Namespace):
974
+ uniprot_accessions = args.uniprot_accessions
975
+ timeout = args.timeout
976
+ batch_size = args.batch_size
977
+ output_csv: TextIOWrapper = args.output_csv
978
+
979
+ accs = _read_lines(uniprot_accessions)
980
+ rprint(f"Retrieving UniProt entry details for {len(accs)} uniprot accessions")
981
+ results = list(map_uniprot_accessions2uniprot_details(accs, timeout=timeout, batch_size=batch_size))
982
+ _write_uniprot_details_csv(output_csv, results)
983
+ rprint(f"Retrieved details for {len(results)} UniProt entries, written to {_name_of(output_csv)}")
984
+
985
+
794
986
  def _initialize_cacher(args: argparse.Namespace) -> Cacher:
795
987
  if args.no_cache:
796
988
  return PassthroughCacher()
@@ -816,27 +1008,30 @@ def _handle_retrieve_pdbe(args: argparse.Namespace):
816
1008
 
817
1009
  def _handle_retrieve_alphafold(args):
818
1010
  download_dir = args.output_dir
819
- what_formats = args.what_formats
1011
+ raw_formats = args.format
820
1012
  alphafold_csv = args.alphafold_csv
821
1013
  max_parallel_downloads = args.max_parallel_downloads
822
1014
  cacher = _initialize_cacher(args)
823
1015
  gzip_files = args.gzip_files
1016
+ all_isoforms = args.all_isoforms
1017
+ db_version = args.db_version
824
1018
 
825
- if what_formats is None:
826
- what_formats = {"summary", "cif"}
1019
+ if raw_formats is None:
1020
+ raw_formats = {"cif"}
827
1021
 
828
- # TODO besides `uniprot_acc,af_id\n` csv also allow headless single column format
829
- #
1022
+ # TODO besides `uniprot_accession,af_id\n` csv also allow headless single column format
830
1023
  af_ids = _read_column_from_csv(alphafold_csv, "af_id")
831
- validated_what: set[DownloadableFormat] = structure(what_formats, set[DownloadableFormat])
832
- rprint(f"Retrieving {len(af_ids)} AlphaFold entries with formats {validated_what}")
1024
+ formats: set[DownloadableFormat] = structure(raw_formats, set[DownloadableFormat])
1025
+ rprint(f"Retrieving {len(af_ids)} AlphaFold entries with formats {formats}")
833
1026
  afs = af_fetch(
834
1027
  af_ids,
835
1028
  download_dir,
836
- what=validated_what,
1029
+ formats=formats,
1030
+ db_version=db_version,
837
1031
  max_parallel_downloads=max_parallel_downloads,
838
1032
  cacher=cacher,
839
1033
  gzip_files=gzip_files,
1034
+ all_isoforms=all_isoforms,
840
1035
  )
841
1036
  total_nr_files = sum(af.nr_of_files() for af in afs)
842
1037
  rprint(f"Retrieved {total_nr_files} AlphaFold files and {len(afs)} summaries, written to {download_dir}")
@@ -891,11 +1086,11 @@ def _handle_filter_confidence(args: argparse.Namespace):
891
1086
  if r.filtered_file:
892
1087
  passed_count += 1
893
1088
  if stats_file:
894
- writer.writerow([r.input_file, r.count, r.filtered_file is not None, r.filtered_file])
1089
+ writer.writerow([r.input_file, r.count, r.filtered_file is not None, r.filtered_file]) # pyright: ignore[reportPossiblyUnboundVariable]
895
1090
 
896
1091
  rprint(f"Filtered {passed_count} mmcif/PDB files by confidence, written to {output_dir} directory")
897
1092
  if stats_file:
898
- rprint(f"Statistics written to {stats_file.name}")
1093
+ rprint(f"Statistics written to {_name_of(stats_file)}")
899
1094
 
900
1095
 
901
1096
  def _handle_filter_chain(args):
@@ -961,13 +1156,13 @@ def _handle_filter_residue(args):
961
1156
  input_files, output_dir, min_residues=min_residues, max_residues=max_residues, copy_method=copy_method
962
1157
  ):
963
1158
  if stats_file:
964
- writer.writerow([r.input_file, r.residue_count, r.passed, r.output_file])
1159
+ writer.writerow([r.input_file, r.residue_count, r.passed, r.output_file]) # pyright: ignore[reportPossiblyUnboundVariable]
965
1160
  if r.passed:
966
1161
  nr_passed += 1
967
1162
 
968
1163
  rprint(f"Wrote {nr_passed} files to {output_dir} directory.")
969
1164
  if stats_file:
970
- rprint(f"Statistics written to {stats_file.name}")
1165
+ rprint(f"Statistics written to {_name_of(stats_file)}")
971
1166
 
972
1167
 
973
1168
  def _handle_filter_ss(args):
@@ -1015,7 +1210,7 @@ def _handle_filter_ss(args):
1015
1210
  copyfile(input_file, output_file, copy_method)
1016
1211
  nr_passed += 1
1017
1212
  if stats_file:
1018
- writer.writerow(
1213
+ writer.writerow( # pyright: ignore[reportPossiblyUnboundVariable]
1019
1214
  [
1020
1215
  input_file,
1021
1216
  result.stats.nr_residues,
@@ -1029,7 +1224,7 @@ def _handle_filter_ss(args):
1029
1224
  )
1030
1225
  rprint(f"Wrote {nr_passed} files to {output_dir} directory.")
1031
1226
  if stats_file:
1032
- rprint(f"Statistics written to {stats_file.name}")
1227
+ rprint(f"Statistics written to {_name_of(stats_file)}")
1033
1228
 
1034
1229
 
1035
1230
  def _handle_mcp(args):
@@ -1045,9 +1240,30 @@ def _handle_mcp(args):
1045
1240
  mcp.run(transport=args.transport, host=args.host, port=args.port)
1046
1241
 
1047
1242
 
1048
- def _handle_convert(args):
1243
+ def _handle_convert_uniprot(args):
1244
+ input_dir = structure(args.input_dir, Path)
1245
+ output_file: TextIOWrapper = args.output
1246
+ grouped: bool = args.grouped
1247
+ input_files = sorted(glob_structure_files(input_dir))
1248
+ if grouped:
1249
+ for input_file in tqdm(input_files, unit="file"):
1250
+ s = read_structure(input_file)
1251
+ uniprot_accessions = structure2uniprot_accessions(s)
1252
+ _write_lines(
1253
+ output_file, [f"{input_file},{uniprot_accession}" for uniprot_accession in sorted(uniprot_accessions)]
1254
+ )
1255
+ else:
1256
+ uniprot_accessions: set[str] = set()
1257
+ for input_file in tqdm(input_files, unit="file"):
1258
+ s = read_structure(input_file)
1259
+ uniprot_accessions.update(structure2uniprot_accessions(s))
1260
+ _write_lines(output_file, sorted(uniprot_accessions))
1261
+
1262
+
1263
+ def _handle_convert_structures(args):
1049
1264
  input_dir = structure(args.input_dir, Path)
1050
1265
  output_dir = input_dir if args.output_dir is None else structure(args.output_dir, Path)
1266
+ output_dir.mkdir(parents=True, exist_ok=True)
1051
1267
  copy_method: CopyMethod = structure(args.copy_method, CopyMethod) # pyright: ignore[reportArgumentType]
1052
1268
 
1053
1269
  input_files = sorted(glob_structure_files(input_dir))
@@ -1070,7 +1286,8 @@ def _read_lines(file: TextIOWrapper) -> list[str]:
1070
1286
 
1071
1287
 
1072
1288
  def _make_sure_parent_exists(file: TextIOWrapper):
1073
- if file.name != "<stdout>":
1289
+ # Can not create dir for stdout
1290
+ with suppress(AttributeError):
1074
1291
  Path(file.name).parent.mkdir(parents=True, exist_ok=True)
1075
1292
 
1076
1293
 
@@ -1079,34 +1296,35 @@ def _write_lines(file: TextIOWrapper, lines: Iterable[str]):
1079
1296
  file.writelines(line + os.linesep for line in lines)
1080
1297
 
1081
1298
 
1082
- def _write_pdbe_csv(path: TextIOWrapper, data: dict[str, set[PdbResult]]):
1299
+ def _write_pdbe_csv(path: TextIOWrapper, data: PdbResults):
1083
1300
  _make_sure_parent_exists(path)
1084
- fieldnames = ["uniprot_acc", "pdb_id", "method", "resolution", "uniprot_chains", "chain"]
1301
+ fieldnames = ["uniprot_accession", "pdb_id", "method", "resolution", "uniprot_chains", "chain", "chain_length"]
1085
1302
  writer = csv.DictWriter(path, fieldnames=fieldnames)
1086
1303
  writer.writeheader()
1087
- for uniprot_acc, entries in sorted(data.items()):
1304
+ for uniprot_accession, entries in sorted(data.items()):
1088
1305
  for e in sorted(entries, key=lambda x: (x.id, x.method)):
1089
1306
  writer.writerow(
1090
1307
  {
1091
- "uniprot_acc": uniprot_acc,
1308
+ "uniprot_accession": uniprot_accession,
1092
1309
  "pdb_id": e.id,
1093
1310
  "method": e.method,
1094
1311
  "resolution": e.resolution or "",
1095
1312
  "uniprot_chains": e.uniprot_chains,
1096
1313
  "chain": e.chain,
1314
+ "chain_length": e.chain_length,
1097
1315
  }
1098
1316
  )
1099
1317
 
1100
1318
 
1101
1319
  def _write_dict_of_sets2csv(file: TextIOWrapper, data: dict[str, set[str]], ref_id_field: str):
1102
1320
  _make_sure_parent_exists(file)
1103
- fieldnames = ["uniprot_acc", ref_id_field]
1321
+ fieldnames = ["uniprot_accession", ref_id_field]
1104
1322
 
1105
1323
  writer = csv.DictWriter(file, fieldnames=fieldnames)
1106
1324
  writer.writeheader()
1107
- for uniprot_acc, ref_ids in sorted(data.items()):
1325
+ for uniprot_accession, ref_ids in sorted(data.items()):
1108
1326
  for ref_id in sorted(ref_ids):
1109
- writer.writerow({"uniprot_acc": uniprot_acc, ref_id_field: ref_id})
1327
+ writer.writerow({"uniprot_accession": uniprot_accession, ref_id_field: ref_id})
1110
1328
 
1111
1329
 
1112
1330
  def _iter_csv_rows(file: TextIOWrapper) -> Generator[dict[str, str]]:
@@ -1148,6 +1366,21 @@ def _write_complexes_csv(complexes: list[ComplexPortalEntry], output_csv: TextIO
1148
1366
  )
1149
1367
 
1150
1368
 
1369
+ def _write_uniprot_details_csv(
1370
+ output_csv: TextIOWrapper,
1371
+ uniprot_details_list: Iterable[UniprotDetails],
1372
+ ) -> None:
1373
+ if not uniprot_details_list:
1374
+ msg = "No UniProt entries found for given accessions"
1375
+ raise ValueError(msg)
1376
+ # As all props of UniprotDetails are scalar, we can directly unstructure to dicts
1377
+ rows = converter.unstructure(uniprot_details_list)
1378
+ fieldnames = rows[0].keys()
1379
+ writer = csv.DictWriter(output_csv, fieldnames=fieldnames)
1380
+ writer.writeheader()
1381
+ writer.writerows(rows)
1382
+
1383
+
1151
1384
  HANDLERS: dict[tuple[str, str | None], Callable] = {
1152
1385
  ("search", "uniprot"): _handle_search_uniprot,
1153
1386
  ("search", "pdbe"): _handle_search_pdbe,
@@ -1157,6 +1390,7 @@ HANDLERS: dict[tuple[str, str | None], Callable] = {
1157
1390
  ("search", "taxonomy"): _handle_search_taxonomy,
1158
1391
  ("search", "interaction-partners"): _handle_search_interaction_partners,
1159
1392
  ("search", "complexes"): _handle_search_complexes,
1393
+ ("search", "uniprot-details"): _handle_search_uniprot_details,
1160
1394
  ("retrieve", "pdbe"): _handle_retrieve_pdbe,
1161
1395
  ("retrieve", "alphafold"): _handle_retrieve_alphafold,
1162
1396
  ("retrieve", "emdb"): _handle_retrieve_emdb,
@@ -1165,15 +1399,20 @@ HANDLERS: dict[tuple[str, str | None], Callable] = {
1165
1399
  ("filter", "residue"): _handle_filter_residue,
1166
1400
  ("filter", "secondary-structure"): _handle_filter_ss,
1167
1401
  ("mcp", None): _handle_mcp,
1168
- ("convert", None): _handle_convert,
1402
+ ("convert", "structures"): _handle_convert_structures,
1403
+ ("convert", "uniprot"): _handle_convert_uniprot,
1169
1404
  }
1170
1405
 
1171
1406
 
1172
- def main():
1173
- """Main entry point for the CLI."""
1407
+ def main(argv: Sequence[str] | None = None):
1408
+ """Main entry point for the CLI.
1409
+
1410
+ Args:
1411
+ argv: List of command line arguments. If None, uses sys.argv.
1412
+ """
1174
1413
  parser = make_parser()
1175
- args = parser.parse_args()
1176
- logging.basicConfig(level=args.log_level, handlers=[RichHandler(show_level=False)])
1414
+ args = parser.parse_args(argv)
1415
+ logging.basicConfig(level=args.log_level, handlers=[RichHandler(show_level=False, console=console)])
1177
1416
 
1178
1417
  # Dispatch table to reduce complexity
1179
1418
  cmd = args.command
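Finally, `main` now accepts an explicit `argv`, so the dispatcher can be exercised programmatically (for example from tests) without patching `sys.argv`. A minimal sketch with arguments taken from the parsers above (accession and file names invented):

```python
from protein_quest.cli import main

# Drive the CLI directly; previously this required monkey-patching sys.argv.
main(["search", "interaction-partners", "P12345", "partners.csv"])
```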