protein-quest 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


protein_quest/cli.py CHANGED
@@ -6,14 +6,15 @@ import csv
 import logging
 import os
 import sys
-from collections.abc import Callable, Generator, Iterable
+from collections.abc import Callable, Generator, Iterable, Sequence
+from contextlib import suppress
 from importlib.util import find_spec
-from io import TextIOWrapper
+from io import BytesIO, TextIOWrapper
 from pathlib import Path
 from textwrap import dedent
 
 from cattrs import structure
-from rich import print as rprint
+from rich.console import Console
 from rich.logging import RichHandler
 from rich.markdown import Markdown
 from rich.panel import Panel
@@ -24,7 +25,7 @@ from protein_quest.__version__ import __version__
 from protein_quest.alphafold.confidence import ConfidenceFilterQuery, filter_files_on_confidence
 from protein_quest.alphafold.fetch import DownloadableFormat, downloadable_formats
 from protein_quest.alphafold.fetch import fetch_many as af_fetch
-from protein_quest.converter import converter
+from protein_quest.converter import PositiveInt, converter
 from protein_quest.emdb import fetch as emdb_fetch
 from protein_quest.filters import filter_files_on_chain, filter_files_on_residues
 from protein_quest.go import Aspect, allowed_aspects, search_gene_ontology_term, write_go_terms_to_csv
@@ -32,15 +33,20 @@ from protein_quest.io import (
     convert_to_cif_files,
     glob_structure_files,
     locate_structure_file,
+    read_structure,
     valid_structure_file_extensions,
 )
 from protein_quest.pdbe import fetch as pdbe_fetch
 from protein_quest.ss import SecondaryStructureFilterQuery, filter_files_on_secondary_structure
+from protein_quest.structure import structure2uniprot_accessions
 from protein_quest.taxonomy import SearchField, _write_taxonomy_csv, search_fields, search_taxon
 from protein_quest.uniprot import (
     ComplexPortalEntry,
-    PdbResult,
+    PdbResults,
     Query,
+    UniprotDetails,
+    filter_pdb_results_on_chain_length,
+    map_uniprot_accessions2uniprot_details,
     search4af,
     search4emdb,
     search4interaction_partners,
@@ -58,6 +64,8 @@ from protein_quest.utils import (
     user_cache_root_dir,
 )
 
+console = Console(stderr=True)
+rprint = console.print
 logger = logging.getLogger(__name__)
 
 
@@ -98,6 +106,8 @@ def _add_search_uniprot_parser(subparsers: argparse._SubParsersAction):
         action="append",
         help="GO term(s) for molecular function (e.g. GO:0003677). Can be given multiple times.",
     )
+    parser.add_argument("--min-sequence-length", type=int, help="Minimum length of the canonical sequence.")
+    parser.add_argument("--max-sequence-length", type=int, help="Maximum length of the canonical sequence.")
     parser.add_argument("--limit", type=int, default=10_000, help="Maximum number of uniprot accessions to return")
     parser.add_argument("--timeout", type=int, default=1_800, help="Maximum seconds to wait for query to complete")
 
@@ -111,7 +121,7 @@ def _add_search_pdbe_parser(subparsers: argparse._SubParsersAction):
         formatter_class=ArgumentDefaultsRichHelpFormatter,
     )
     parser.add_argument(
-        "uniprot_accs",
+        "uniprot_accessions",
        type=argparse.FileType("r", encoding="UTF-8"),
         help="Text file with UniProt accessions (one per line). Use `-` for stdin.",
     )
@@ -119,15 +129,27 @@ def _add_search_pdbe_parser(subparsers: argparse._SubParsersAction):
         "output_csv",
         type=argparse.FileType("w", encoding="UTF-8"),
         help=dedent("""\
-            Output CSV with `uniprot_acc`, `pdb_id`, `method`, `resolution`, `uniprot_chains`, `chain` columns.
+            Output CSV with following columns:
+            `uniprot_accession`, `pdb_id`, `method`, `resolution`, `uniprot_chains`, `chain`, `chain_length`.
             Where `uniprot_chains` is the raw UniProt chain string, for example `A=1-100`.
-            and where `chain` is the first chain from `uniprot_chains`, for example `A`.
+            and where `chain` is the first chain from `uniprot_chains`, for example `A`
+            and `chain_length` is the length of the chain, for example `100`.
             Use `-` for stdout.
             """),
     )
     parser.add_argument(
         "--limit", type=int, default=10_000, help="Maximum number of PDB uniprot accessions combinations to return"
     )
+    parser.add_argument(
+        "--min-residues",
+        type=int,
+        help="Minimum number of residues required in the chain mapped to the UniProt accession.",
+    )
+    parser.add_argument(
+        "--max-residues",
+        type=int,
+        help="Maximum number of residues allowed in chain mapped to the UniProt accession.",
+    )
     parser.add_argument("--timeout", type=int, default=1_800, help="Maximum seconds to wait for query to complete")
 
 
@@ -140,7 +162,7 @@ def _add_search_alphafold_parser(subparsers: argparse._SubParsersAction):
         formatter_class=ArgumentDefaultsRichHelpFormatter,
     )
     parser.add_argument(
-        "uniprot_accs",
+        "uniprot_accessions",
         type=argparse.FileType("r", encoding="UTF-8"),
         help="Text file with UniProt accessions (one per line). Use `-` for stdin.",
     )
@@ -149,6 +171,8 @@ def _add_search_alphafold_parser(subparsers: argparse._SubParsersAction):
         type=argparse.FileType("w", encoding="UTF-8"),
         help="Output CSV with AlphaFold IDs per UniProt accession. Use `-` for stdout.",
     )
+    parser.add_argument("--min-sequence-length", type=int, help="Minimum length of the canonical sequence.")
+    parser.add_argument("--max-sequence-length", type=int, help="Maximum length of the canonical sequence.")
     parser.add_argument(
         "--limit", type=int, default=10_000, help="Maximum number of Alphafold entry identifiers to return"
     )
@@ -247,7 +271,7 @@ def _add_search_interaction_partners_parser(subparsers: argparse._SubParsersActi
         formatter_class=ArgumentDefaultsRichHelpFormatter,
     )
     parser.add_argument(
-        "uniprot_acc",
+        "uniprot_accession",
         type=str,
         help="UniProt accession (for example P12345).",
     )
@@ -289,7 +313,7 @@ def _add_search_complexes_parser(subparsers: argparse._SubParsersAction):
         formatter_class=ArgumentDefaultsRichHelpFormatter,
     )
     parser.add_argument(
-        "uniprot_accs",
+        "uniprot_accessions",
         type=argparse.FileType("r", encoding="UTF-8"),
         help="Text file with UniProt accessions (one per line) as query for searching complexes. Use `-` for stdin.",
     )
@@ -302,6 +326,44 @@ def _add_search_complexes_parser(subparsers: argparse._SubParsersAction):
     parser.add_argument("--timeout", type=int, default=1_800, help="Maximum seconds to wait for query to complete")
 
 
+def _add_search_uniprot_details_parser(subparsers: argparse._SubParsersAction):
+    """Add search uniprot details subcommand parser."""
+    description = dedent("""\
+        Retrieve UniProt details for given UniProt accessions
+        from the Uniprot SPARQL endpoint.
+
+        The output CSV file has the following columns:
+
+        - uniprot_accession: UniProt accession.
+        - uniprot_id: UniProt ID (mnemonic).
+        - sequence_length: Length of the canonical sequence.
+        - reviewed: Whether the entry is reviewed (Swiss-Prot) or unreviewed (TrEMBL).
+        - protein_name: Recommended protein name.
+        - taxon_id: NCBI Taxonomy ID of the organism.
+        - taxon_name: Scientific name of the organism.
+
+        The order of the output CSV can be different from the input order.
+        """)
+    parser = subparsers.add_parser(
+        "uniprot-details",
+        help="Retrieve UniProt details for given UniProt accessions",
+        description=Markdown(description, style="argparse.text"),  # type: ignore using rich formatter makes this OK
+        formatter_class=ArgumentDefaultsRichHelpFormatter,
+    )
+    parser.add_argument(
+        "uniprot_accessions",
+        type=argparse.FileType("r", encoding="UTF-8"),
+        help="Text file with UniProt accessions (one per line). Use `-` for stdin.",
+    )
+    parser.add_argument(
+        "output_csv",
+        type=argparse.FileType("w", encoding="UTF-8"),
+        help="Output CSV with UniProt details. Use `-` for stdout.",
+    )
+    parser.add_argument("--timeout", type=int, default=1_800, help="Maximum seconds to wait for query to complete")
+    parser.add_argument("--batch-size", type=int, default=1_000, help="Number of accessions to query per batch")
+
+
 def _add_copy_method_arguments(parser):
     parser.add_argument(
         "--copy-method",
@@ -387,6 +449,14 @@ def _add_retrieve_alphafold_parser(subparsers: argparse._SubParsersAction):
         action="store_true",
         help="Whether to gzip the downloaded files. Excludes summary files, they are always uncompressed.",
     )
+    parser.add_argument(
+        "--all-isoforms",
+        action="store_true",
+        help=(
+            "Whether to return all isoforms of each uniprot entry. "
+            "If not given then only the Alphafold entry for the canonical sequence is returned."
+        ),
+    )
     parser.add_argument(
         "--max-parallel-downloads",
         type=int,
@@ -575,6 +645,7 @@ def _add_search_subcommands(subparsers: argparse._SubParsersAction):
     _add_search_taxonomy_parser(subsubparsers)
     _add_search_interaction_partners_parser(subsubparsers)
     _add_search_complexes_parser(subsubparsers)
+    _add_search_uniprot_details_parser(subsubparsers)
 
 
 def _add_retrieve_subcommands(subparsers: argparse._SubParsersAction):
@@ -603,10 +674,39 @@ def _add_filter_subcommands(subparsers: argparse._SubParsersAction):
     _add_filter_ss_parser(subsubparsers)
 
 
-def _add_convert_subcommands(subparsers: argparse._SubParsersAction):
-    """Add convert command."""
+def _add_convert_uniprot_parser(subparsers: argparse._SubParsersAction):
+    """Add convert uniprot subcommand parser."""
+    parser = subparsers.add_parser(
+        "uniprot",
+        help="Convert structure files to list of UniProt accessions.",
+        description="Convert structure files to list of UniProt accessions. "
+        "Uniprot accessions are read from database reference of each structure.",
+        formatter_class=ArgumentDefaultsRichHelpFormatter,
+    )
+    parser.add_argument(
+        "input_dir",
+        type=Path,
+        help=f"Directory with structure files. Supported extensions are {valid_structure_file_extensions}",
+    )
+    parser.add_argument(
+        "output",
+        type=argparse.FileType("wt", encoding="UTF-8"),
+        help="Output text file with UniProt accessions (one per line). Use '-' for stdout.",
+    )
+    parser.add_argument(
+        "--grouped",
+        action="store_true",
+        help="Whether to group accessions by structure file. "
+        "If set output changes to `<structure_file1>,<acc1>\\n<structure_file1>,<acc2>` format.",
+    )
+
+
+def _add_convert_structures_parser(subparsers: argparse._SubParsersAction):
+    """Add convert structures subcommand parser."""
     parser = subparsers.add_parser(
-        "convert", help="Convert structure files between formats", formatter_class=ArgumentDefaultsRichHelpFormatter
+        "structures",
+        help="Convert structure files between formats",
+        formatter_class=ArgumentDefaultsRichHelpFormatter,
     )
     parser.add_argument(
         "input_dir",
@@ -630,6 +730,19 @@ def _add_convert_subcommands(subparsers: argparse._SubParsersAction):
     _add_copy_method_arguments(parser)
 
 
+def _add_convert_subcommands(subparsers: argparse._SubParsersAction):
+    """Add convert command and its subcommands."""
+    parser = subparsers.add_parser(
+        "convert",
+        help="Convert files between formats",
+        formatter_class=ArgumentDefaultsRichHelpFormatter,
+    )
+    subsubparsers = parser.add_subparsers(dest="convert_cmd", required=True)
+
+    _add_convert_structures_parser(subsubparsers)
+    _add_convert_uniprot_parser(subsubparsers)
+
+
 def _add_mcp_command(subparsers: argparse._SubParsersAction):
     """Add MCP command."""
 
@@ -667,12 +780,22 @@ def make_parser() -> argparse.ArgumentParser:
     return parser
 
 
+def _name_of(file: TextIOWrapper | BytesIO) -> str:
+    try:
+        return file.name
+    except AttributeError:
+        # In pytest BytesIO is used stdout which has no 'name' attribute
+        return "<stdout>"
+
+
 def _handle_search_uniprot(args):
     taxon_id = args.taxon_id
     reviewed = args.reviewed
     subcellular_location_uniprot = args.subcellular_location_uniprot
     subcellular_location_go = args.subcellular_location_go
     molecular_function_go = args.molecular_function_go
+    min_sequence_length = args.min_sequence_length
+    max_sequence_length = args.max_sequence_length
     limit = args.limit
     timeout = args.timeout
     output_file = args.output
@@ -684,54 +807,78 @@ def _handle_search_uniprot(args):
             "subcellular_location_uniprot": subcellular_location_uniprot,
             "subcellular_location_go": subcellular_location_go,
             "molecular_function_go": molecular_function_go,
+            "min_sequence_length": min_sequence_length,
+            "max_sequence_length": max_sequence_length,
         },
         Query,
     )
     rprint("Searching for UniProt accessions")
     accs = search4uniprot(query=query, limit=limit, timeout=timeout)
-    rprint(f"Found {len(accs)} UniProt accessions, written to {output_file.name}")
+    rprint(f"Found {len(accs)} UniProt accessions, written to {_name_of(output_file)}")
     _write_lines(output_file, sorted(accs))
 
 
 def _handle_search_pdbe(args):
-    uniprot_accs = args.uniprot_accs
+    uniprot_accessions = args.uniprot_accessions
     limit = args.limit
     timeout = args.timeout
     output_csv = args.output_csv
+    min_residues = converter.structure(args.min_residues, PositiveInt | None)  # pyright: ignore[reportArgumentType]
+    max_residues = converter.structure(args.max_residues, PositiveInt | None)  # pyright: ignore[reportArgumentType]
 
-    accs = set(_read_lines(uniprot_accs))
+    accs = set(_read_lines(uniprot_accessions))
     rprint(f"Finding PDB entries for {len(accs)} uniprot accessions")
     results = search4pdb(accs, limit=limit, timeout=timeout)
-    total_pdbs = sum([len(v) for v in results.values()])
-    rprint(f"Found {total_pdbs} PDB entries for {len(results)} uniprot accessions")
-    rprint(f"Written to {output_csv.name}")
+
+    raw_nr_results = len(results)
+    raw_total_pdbs = sum([len(v) for v in results.values()])
+    if min_residues or max_residues:
+        results = filter_pdb_results_on_chain_length(results, min_residues, max_residues)
+        total_pdbs = sum([len(v) for v in results.values()])
+        rprint(f"Before filtering found {raw_total_pdbs} PDB entries for {raw_nr_results} uniprot accessions.")
+        rprint(
+            f"After filtering on chain length ({min_residues}, {max_residues}) "
+            f"remained {total_pdbs} PDB entries for {len(results)} uniprot accessions."
+        )
+    else:
+        rprint(f"Found {raw_total_pdbs} PDB entries for {raw_nr_results} uniprot accessions")
+
    _write_pdbe_csv(output_csv, results)
+    rprint(f"Written to {_name_of(output_csv)}")
 
 
 def _handle_search_alphafold(args):
-    uniprot_accs = args.uniprot_accs
+    uniprot_accessions = args.uniprot_accessions
+    min_sequence_length = converter.structure(args.min_sequence_length, PositiveInt | None)  # pyright: ignore[reportArgumentType]
+    max_sequence_length = converter.structure(args.max_sequence_length, PositiveInt | None)  # pyright: ignore[reportArgumentType]
     limit = args.limit
     timeout = args.timeout
    output_csv = args.output_csv
 
-    accs = _read_lines(uniprot_accs)
+    accs = _read_lines(uniprot_accessions)
     rprint(f"Finding AlphaFold entries for {len(accs)} uniprot accessions")
-    results = search4af(accs, limit=limit, timeout=timeout)
-    rprint(f"Found {len(results)} AlphaFold entries, written to {output_csv.name}")
+    results = search4af(
+        accs,
+        min_sequence_length=min_sequence_length,
+        max_sequence_length=max_sequence_length,
+        limit=limit,
+        timeout=timeout,
+    )
+    rprint(f"Found {len(results)} AlphaFold entries, written to {_name_of(output_csv)}")
     _write_dict_of_sets2csv(output_csv, results, "af_id")
 
 
 def _handle_search_emdb(args):
-    uniprot_accs = args.uniprot_accs
+    uniprot_accessions = args.uniprot_accessions
     limit = args.limit
     timeout = args.timeout
     output_csv = args.output_csv
 
-    accs = _read_lines(uniprot_accs)
+    accs = _read_lines(uniprot_accessions)
     rprint(f"Finding EMDB entries for {len(accs)} uniprot accessions")
     results = search4emdb(accs, limit=limit, timeout=timeout)
     total_emdbs = sum([len(v) for v in results.values()])
-    rprint(f"Found {total_emdbs} EMDB entries, written to {output_csv.name}")
+    rprint(f"Found {total_emdbs} EMDB entries, written to {_name_of(output_csv)}")
     _write_dict_of_sets2csv(output_csv, results, "emdb_id")
 
 
@@ -746,7 +893,7 @@ def _handle_search_go(args):
     else:
         rprint(f"Searching for GO terms matching '{term}'")
     results = asyncio.run(search_gene_ontology_term(term, aspect=aspect, limit=limit))
-    rprint(f"Found {len(results)} GO terms, written to {output_csv.name}")
+    rprint(f"Found {len(results)} GO terms, written to {_name_of(output_csv)}")
     write_go_terms_to_csv(results, output_csv)
 
 
@@ -761,36 +908,49 @@ def _handle_search_taxonomy(args):
     else:
         rprint(f"Searching for taxon information matching '{query}'")
     results = asyncio.run(search_taxon(query=query, field=field, limit=limit))
-    rprint(f"Found {len(results)} taxons, written to {output_csv.name}")
+    rprint(f"Found {len(results)} taxons, written to {_name_of(output_csv)}")
     _write_taxonomy_csv(results, output_csv)
 
 
 def _handle_search_interaction_partners(args: argparse.Namespace):
-    uniprot_acc: str = args.uniprot_acc
+    uniprot_accession: str = args.uniprot_accession
     excludes: set[str] = set(args.exclude) if args.exclude else set()
     limit: int = args.limit
     timeout: int = args.timeout
     output_csv: TextIOWrapper = args.output_csv
 
-    rprint(f"Searching for interaction partners of '{uniprot_acc}'")
-    results = search4interaction_partners(uniprot_acc, excludes=excludes, limit=limit, timeout=timeout)
-    rprint(f"Found {len(results)} interaction partners, written to {output_csv.name}")
+    rprint(f"Searching for interaction partners of '{uniprot_accession}'")
+    results = search4interaction_partners(uniprot_accession, excludes=excludes, limit=limit, timeout=timeout)
+    rprint(f"Found {len(results)} interaction partners, written to {_name_of(output_csv)}")
     _write_lines(output_csv, results.keys())
 
 
 def _handle_search_complexes(args: argparse.Namespace):
-    uniprot_accs = args.uniprot_accs
+    uniprot_accessions = args.uniprot_accessions
     limit = args.limit
     timeout = args.timeout
     output_csv = args.output_csv
 
-    accs = _read_lines(uniprot_accs)
+    accs = _read_lines(uniprot_accessions)
     rprint(f"Finding complexes for {len(accs)} uniprot accessions")
     results = search4macromolecular_complexes(accs, limit=limit, timeout=timeout)
-    rprint(f"Found {len(results)} complexes, written to {output_csv.name}")
+    rprint(f"Found {len(results)} complexes, written to {_name_of(output_csv)}")
     _write_complexes_csv(results, output_csv)
 
 
+def _handle_search_uniprot_details(args: argparse.Namespace):
+    uniprot_accessions = args.uniprot_accessions
+    timeout = args.timeout
+    batch_size = args.batch_size
+    output_csv: TextIOWrapper = args.output_csv
+
+    accs = _read_lines(uniprot_accessions)
+    rprint(f"Retrieving UniProt entry details for {len(accs)} uniprot accessions")
+    results = list(map_uniprot_accessions2uniprot_details(accs, timeout=timeout, batch_size=batch_size))
+    _write_uniprot_details_csv(output_csv, results)
+    rprint(f"Retrieved details for {len(results)} UniProt entries, written to {_name_of(output_csv)}")
+
+
 def _initialize_cacher(args: argparse.Namespace) -> Cacher:
     if args.no_cache:
         return PassthroughCacher()
@@ -821,11 +981,12 @@ def _handle_retrieve_alphafold(args):
     max_parallel_downloads = args.max_parallel_downloads
     cacher = _initialize_cacher(args)
     gzip_files = args.gzip_files
+    all_isoforms = args.all_isoforms
 
     if what_formats is None:
         what_formats = {"summary", "cif"}
 
-    # TODO besides `uniprot_acc,af_id\n` csv also allow headless single column format
+    # TODO besides `uniprot_accession,af_id\n` csv also allow headless single column format
     #
     af_ids = _read_column_from_csv(alphafold_csv, "af_id")
     validated_what: set[DownloadableFormat] = structure(what_formats, set[DownloadableFormat])
@@ -837,6 +998,7 @@ def _handle_retrieve_alphafold(args):
         max_parallel_downloads=max_parallel_downloads,
         cacher=cacher,
         gzip_files=gzip_files,
+        all_isoforms=all_isoforms,
     )
     total_nr_files = sum(af.nr_of_files() for af in afs)
     rprint(f"Retrieved {total_nr_files} AlphaFold files and {len(afs)} summaries, written to {download_dir}")
@@ -891,11 +1053,11 @@ def _handle_filter_confidence(args: argparse.Namespace):
         if r.filtered_file:
             passed_count += 1
         if stats_file:
-            writer.writerow([r.input_file, r.count, r.filtered_file is not None, r.filtered_file])
+            writer.writerow([r.input_file, r.count, r.filtered_file is not None, r.filtered_file])  # pyright: ignore[reportPossiblyUnboundVariable]
 
     rprint(f"Filtered {passed_count} mmcif/PDB files by confidence, written to {output_dir} directory")
     if stats_file:
-        rprint(f"Statistics written to {stats_file.name}")
+        rprint(f"Statistics written to {_name_of(stats_file)}")
 
 
 def _handle_filter_chain(args):
@@ -961,13 +1123,13 @@ def _handle_filter_residue(args):
         input_files, output_dir, min_residues=min_residues, max_residues=max_residues, copy_method=copy_method
     ):
         if stats_file:
-            writer.writerow([r.input_file, r.residue_count, r.passed, r.output_file])
+            writer.writerow([r.input_file, r.residue_count, r.passed, r.output_file])  # pyright: ignore[reportPossiblyUnboundVariable]
         if r.passed:
             nr_passed += 1
 
     rprint(f"Wrote {nr_passed} files to {output_dir} directory.")
     if stats_file:
-        rprint(f"Statistics written to {stats_file.name}")
+        rprint(f"Statistics written to {_name_of(stats_file)}")
 
 
 def _handle_filter_ss(args):
@@ -1015,7 +1177,7 @@ def _handle_filter_ss(args):
             copyfile(input_file, output_file, copy_method)
             nr_passed += 1
         if stats_file:
-            writer.writerow(
+            writer.writerow(  # pyright: ignore[reportPossiblyUnboundVariable]
                 [
                     input_file,
                     result.stats.nr_residues,
@@ -1029,7 +1191,7 @@ def _handle_filter_ss(args):
             )
     rprint(f"Wrote {nr_passed} files to {output_dir} directory.")
     if stats_file:
-        rprint(f"Statistics written to {stats_file.name}")
+        rprint(f"Statistics written to {_name_of(stats_file)}")
 
 
 def _handle_mcp(args):
@@ -1045,9 +1207,30 @@ def _handle_mcp(args):
     mcp.run(transport=args.transport, host=args.host, port=args.port)
 
 
-def _handle_convert(args):
+def _handle_convert_uniprot(args):
+    input_dir = structure(args.input_dir, Path)
+    output_file: TextIOWrapper = args.output
+    grouped: bool = args.grouped
+    input_files = sorted(glob_structure_files(input_dir))
+    if grouped:
+        for input_file in tqdm(input_files, unit="file"):
+            s = read_structure(input_file)
+            uniprot_accessions = structure2uniprot_accessions(s)
+            _write_lines(
+                output_file, [f"{input_file},{uniprot_accession}" for uniprot_accession in sorted(uniprot_accessions)]
+            )
+    else:
+        uniprot_accessions: set[str] = set()
+        for input_file in tqdm(input_files, unit="file"):
+            s = read_structure(input_file)
+            uniprot_accessions.update(structure2uniprot_accessions(s))
+        _write_lines(output_file, sorted(uniprot_accessions))
+
+
+def _handle_convert_structures(args):
     input_dir = structure(args.input_dir, Path)
     output_dir = input_dir if args.output_dir is None else structure(args.output_dir, Path)
+    output_dir.mkdir(parents=True, exist_ok=True)
     copy_method: CopyMethod = structure(args.copy_method, CopyMethod)  # pyright: ignore[reportArgumentType]
 
     input_files = sorted(glob_structure_files(input_dir))
@@ -1070,7 +1253,8 @@ def _read_lines(file: TextIOWrapper) -> list[str]:
 
 
 def _make_sure_parent_exists(file: TextIOWrapper):
-    if file.name != "<stdout>":
+    # Can not create dir for stdout
+    with suppress(AttributeError):
         Path(file.name).parent.mkdir(parents=True, exist_ok=True)
 
 
@@ -1079,34 +1263,35 @@ def _write_lines(file: TextIOWrapper, lines: Iterable[str]):
     file.writelines(line + os.linesep for line in lines)
 
 
-def _write_pdbe_csv(path: TextIOWrapper, data: dict[str, set[PdbResult]]):
+def _write_pdbe_csv(path: TextIOWrapper, data: PdbResults):
     _make_sure_parent_exists(path)
-    fieldnames = ["uniprot_acc", "pdb_id", "method", "resolution", "uniprot_chains", "chain"]
+    fieldnames = ["uniprot_accession", "pdb_id", "method", "resolution", "uniprot_chains", "chain", "chain_length"]
     writer = csv.DictWriter(path, fieldnames=fieldnames)
     writer.writeheader()
-    for uniprot_acc, entries in sorted(data.items()):
+    for uniprot_accession, entries in sorted(data.items()):
         for e in sorted(entries, key=lambda x: (x.id, x.method)):
             writer.writerow(
                 {
-                    "uniprot_acc": uniprot_acc,
+                    "uniprot_accession": uniprot_accession,
                     "pdb_id": e.id,
                     "method": e.method,
                     "resolution": e.resolution or "",
                     "uniprot_chains": e.uniprot_chains,
                     "chain": e.chain,
+                    "chain_length": e.chain_length,
                 }
             )
 
 
 def _write_dict_of_sets2csv(file: TextIOWrapper, data: dict[str, set[str]], ref_id_field: str):
     _make_sure_parent_exists(file)
-    fieldnames = ["uniprot_acc", ref_id_field]
+    fieldnames = ["uniprot_accession", ref_id_field]
 
     writer = csv.DictWriter(file, fieldnames=fieldnames)
     writer.writeheader()
-    for uniprot_acc, ref_ids in sorted(data.items()):
+    for uniprot_accession, ref_ids in sorted(data.items()):
         for ref_id in sorted(ref_ids):
-            writer.writerow({"uniprot_acc": uniprot_acc, ref_id_field: ref_id})
+            writer.writerow({"uniprot_accession": uniprot_accession, ref_id_field: ref_id})
 
 
 def _iter_csv_rows(file: TextIOWrapper) -> Generator[dict[str, str]]:
@@ -1148,6 +1333,21 @@ def _write_complexes_csv(complexes: list[ComplexPortalEntry], output_csv: TextIO
         )
 
 
+def _write_uniprot_details_csv(
+    output_csv: TextIOWrapper,
+    uniprot_details_list: Iterable[UniprotDetails],
+) -> None:
+    if not uniprot_details_list:
+        msg = "No UniProt entries found for given accessions"
+        raise ValueError(msg)
+    # As all props of UniprotDetails are scalar, we can directly unstructure to dicts
+    rows = converter.unstructure(uniprot_details_list)
+    fieldnames = rows[0].keys()
+    writer = csv.DictWriter(output_csv, fieldnames=fieldnames)
+    writer.writeheader()
+    writer.writerows(rows)
+
+
 HANDLERS: dict[tuple[str, str | None], Callable] = {
     ("search", "uniprot"): _handle_search_uniprot,
     ("search", "pdbe"): _handle_search_pdbe,
@@ -1157,6 +1357,7 @@ HANDLERS: dict[tuple[str, str | None], Callable] = {
     ("search", "taxonomy"): _handle_search_taxonomy,
     ("search", "interaction-partners"): _handle_search_interaction_partners,
     ("search", "complexes"): _handle_search_complexes,
+    ("search", "uniprot-details"): _handle_search_uniprot_details,
     ("retrieve", "pdbe"): _handle_retrieve_pdbe,
     ("retrieve", "alphafold"): _handle_retrieve_alphafold,
     ("retrieve", "emdb"): _handle_retrieve_emdb,
@@ -1165,15 +1366,20 @@ HANDLERS: dict[tuple[str, str | None], Callable] = {
     ("filter", "residue"): _handle_filter_residue,
     ("filter", "secondary-structure"): _handle_filter_ss,
     ("mcp", None): _handle_mcp,
-    ("convert", None): _handle_convert,
+    ("convert", "structures"): _handle_convert_structures,
+    ("convert", "uniprot"): _handle_convert_uniprot,
 }
 

-def main():
-    """Main entry point for the CLI."""
+def main(argv: Sequence[str] | None = None):
+    """Main entry point for the CLI.
+
+    Args:
+        argv: List of command line arguments. If None, uses sys.argv.
+    """
     parser = make_parser()
-    args = parser.parse_args()
-    logging.basicConfig(level=args.log_level, handlers=[RichHandler(show_level=False)])
+    args = parser.parse_args(argv)
+    logging.basicConfig(level=args.log_level, handlers=[RichHandler(show_level=False, console=console)])
 
     # Dispatch table to reduce complexity
     cmd = args.command
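
The rewritten main() now takes an explicit argv list, so the subcommands added in 0.7.0 can be driven programmatically as well as from the shell. The snippet below is a minimal sketch based only on the argument parsers shown in this diff, not on the package's documentation; the file accessions.txt (one UniProt accession per line) and the ./structures directory of mmCIF/PDB files are hypothetical inputs you would supply yourself.

from protein_quest.cli import main

# "search uniprot-details": fetch UniProt entry details for the accessions in
# accessions.txt and write them to uniprot_details.csv in batches of 500
# (--batch-size defaults to 1000).
main([
    "search", "uniprot-details",
    "accessions.txt", "uniprot_details.csv",
    "--batch-size", "500",
])

# "convert uniprot": collect the UniProt accessions referenced by the structure
# files in ./structures and write one accession per line. Note that the former
# top-level "convert" command is now the "convert structures" subcommand.
main(["convert", "uniprot", "./structures", "structure_accessions.txt"])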