protein-quest 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


protein_quest/cli.py CHANGED
@@ -5,7 +5,8 @@ import asyncio
 import csv
 import logging
 import os
-from collections.abc import Callable, Iterable
+import sys
+from collections.abc import Callable, Generator, Iterable
 from importlib.util import find_spec
 from io import TextIOWrapper
 from pathlib import Path
@@ -14,6 +15,7 @@ from textwrap import dedent
 from cattrs import structure
 from rich import print as rprint
 from rich.logging import RichHandler
+from rich.panel import Panel
 from rich_argparse import ArgumentDefaultsRichHelpFormatter
 from tqdm.rich import tqdm
 
@@ -21,13 +23,16 @@ from protein_quest.__version__ import __version__
 from protein_quest.alphafold.confidence import ConfidenceFilterQuery, filter_files_on_confidence
 from protein_quest.alphafold.fetch import DownloadableFormat, downloadable_formats
 from protein_quest.alphafold.fetch import fetch_many as af_fetch
+from protein_quest.converter import converter
 from protein_quest.emdb import fetch as emdb_fetch
 from protein_quest.filters import filter_files_on_chain, filter_files_on_residues
 from protein_quest.go import Aspect, allowed_aspects, search_gene_ontology_term, write_go_terms_to_csv
 from protein_quest.pdbe import fetch as pdbe_fetch
-from protein_quest.pdbe.io import glob_structure_files
+from protein_quest.pdbe.io import glob_structure_files, locate_structure_file
+from protein_quest.ss import SecondaryStructureFilterQuery, filter_files_on_secondary_structure
 from protein_quest.taxonomy import SearchField, _write_taxonomy_csv, search_fields, search_taxon
 from protein_quest.uniprot import PdbResult, Query, search4af, search4emdb, search4pdb, search4uniprot
+from protein_quest.utils import CopyMethod, copy_methods, copyfile
 
 logger = logging.getLogger(__name__)
 
@@ -246,12 +251,12 @@ def _add_retrieve_alphafold_parser(subparsers: argparse._SubParsersAction):
     )
     parser.add_argument("output_dir", type=Path, help="Directory to store downloaded AlphaFold files")
     parser.add_argument(
-        "--what-af-formats",
+        "--what-formats",
         type=str,
         action="append",
        choices=sorted(downloadable_formats),
        help=dedent("""AlphaFold formats to retrieve. Can be specified multiple times.
-            Default is 'pdb'. Summary is always downloaded as `<entryId>.json`."""),
+            Default is 'summary' and 'cif'."""),
     )
     parser.add_argument(
         "--max-parallel-downloads",
@@ -280,6 +285,22 @@ def _add_retrieve_emdb_parser(subparsers: argparse._SubParsersAction):
     parser.add_argument("output_dir", type=Path, help="Directory to store downloaded EMDB volume files")
 
 
+def _add_copy_method_argument(parser: argparse.ArgumentParser):
+    """Add copy method argument to parser."""
+    default_copy_method = "symlink"
+    if os.name == "nt":
+        # On Windows you need developer mode or admin privileges to create symlinks
+        # so we default to copying files instead of symlinking
+        default_copy_method = "copy"
+    parser.add_argument(
+        "--copy-method",
+        type=str,
+        choices=copy_methods,
+        default=default_copy_method,
+        help="How to copy files when no changes are needed to output file.",
+    )
+
+
 def _add_filter_confidence_parser(subparsers: argparse._SubParsersAction):
     """Add filter confidence subcommand parser."""
     parser = subparsers.add_parser(
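The new `--copy-method` flag added by `_add_copy_method_argument` is ultimately handed to `protein_quest.utils.copyfile`, which this release threads through the filter commands. A minimal sketch of that flow, assuming `copyfile(src, dst, method)` as used elsewhere in this diff; the paths are hypothetical:

```python
from pathlib import Path

from protein_quest.utils import copyfile

# "symlink" is the CLI default on POSIX; on Windows (os.name == "nt") the CLI defaults to "copy".
copy_method = "symlink"
Path("filtered").mkdir(exist_ok=True)
# Hypothetical input/output paths, for illustration only.
copyfile(Path("downloads/example.cif"), Path("filtered/example.cif"), copy_method)
```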
@@ -310,6 +331,7 @@ def _add_filter_confidence_parser(subparsers: argparse._SubParsersAction):
             In CSV format with `<input_file>,<residue_count>,<passed>,<output_file>` columns.
             Use `-` for stdout."""),
     )
+    _add_copy_method_argument(parser)
 
 
 def _add_filter_chain_parser(subparsers: argparse._SubParsersAction):
@@ -345,8 +367,11 @@ def _add_filter_chain_parser(subparsers: argparse._SubParsersAction):
     )
     parser.add_argument(
         "--scheduler-address",
-        help="Address of the Dask scheduler to connect to. If not provided, will create a local cluster.",
+        help=dedent("""Address of the Dask scheduler to connect to.
+            If not provided, will create a local cluster.
+            If set to `sequential` will run tasks sequentially."""),
     )
+    _add_copy_method_argument(parser)
 
 
 def _add_filter_residue_parser(subparsers: argparse._SubParsersAction):
@@ -369,6 +394,7 @@ def _add_filter_residue_parser(subparsers: argparse._SubParsersAction):
     )
     parser.add_argument("--min-residues", type=int, default=0, help="Min residues in chain A")
     parser.add_argument("--max-residues", type=int, default=10_000_000, help="Max residues in chain A")
+    _add_copy_method_argument(parser)
     parser.add_argument(
         "--write-stats",
         type=argparse.FileType("w", encoding="UTF-8"),
@@ -379,6 +405,43 @@ def _add_filter_residue_parser(subparsers: argparse._SubParsersAction):
     )
 
 
+def _add_filter_ss_parser(subparsers: argparse._SubParsersAction):
+    """Add filter secondary structure subcommand parser."""
+    parser = subparsers.add_parser(
+        "secondary-structure",
+        help="Filter PDB/mmCIF files by secondary structure",
+        description="Filter PDB/mmCIF files by secondary structure",
+        formatter_class=ArgumentDefaultsRichHelpFormatter,
+    )
+    parser.add_argument("input_dir", type=Path, help="Directory with PDB/mmCIF files (e.g., from 'filter chain')")
+    parser.add_argument(
+        "output_dir",
+        type=Path,
+        help=dedent("""\
+            Directory to write filtered PDB/mmCIF files. Files are copied without modification.
+            """),
+    )
+    parser.add_argument("--abs-min-helix-residues", type=int, help="Min residues in helices")
+    parser.add_argument("--abs-max-helix-residues", type=int, help="Max residues in helices")
+    parser.add_argument("--abs-min-sheet-residues", type=int, help="Min residues in sheets")
+    parser.add_argument("--abs-max-sheet-residues", type=int, help="Max residues in sheets")
+    parser.add_argument("--ratio-min-helix-residues", type=float, help="Min residues in helices (relative)")
+    parser.add_argument("--ratio-max-helix-residues", type=float, help="Max residues in helices (relative)")
+    parser.add_argument("--ratio-min-sheet-residues", type=float, help="Min residues in sheets (relative)")
+    parser.add_argument("--ratio-max-sheet-residues", type=float, help="Max residues in sheets (relative)")
+    _add_copy_method_argument(parser)
+    parser.add_argument(
+        "--write-stats",
+        type=argparse.FileType("w", encoding="UTF-8"),
+        help=dedent("""
+            Write filter statistics to file. In CSV format with columns:
+            `<input_file>,<nr_residues>,<nr_helix_residues>,<nr_sheet_residues>,
+            <helix_ratio>,<sheet_ratio>,<passed>,<output_file>`.
+            Use `-` for stdout.
+            """),
+    )
+
+
 def _add_search_subcommands(subparsers: argparse._SubParsersAction):
     """Add search command and its subcommands."""
     parser = subparsers.add_parser(
@@ -420,6 +483,7 @@ def _add_filter_subcommands(subparsers: argparse._SubParsersAction):
     _add_filter_confidence_parser(subsubparsers)
     _add_filter_chain_parser(subsubparsers)
     _add_filter_residue_parser(subsubparsers)
+    _add_filter_ss_parser(subsubparsers)
 
 
 def _add_mcp_command(subparsers: argparse._SubParsersAction):
@@ -585,17 +649,17 @@ def _handle_retrieve_pdbe(args):
 
 def _handle_retrieve_alphafold(args):
     download_dir = args.output_dir
-    what_af_formats = args.what_af_formats
+    what_formats = args.what_formats
     alphafold_csv = args.alphafold_csv
     max_parallel_downloads = args.max_parallel_downloads
 
-    if what_af_formats is None:
-        what_af_formats = {"pdb"}
+    if what_formats is None:
+        what_formats = {"summary", "cif"}
 
     # TODO besides `uniprot_acc,af_id\n` csv also allow headless single column format
     #
-    af_ids = [r["af_id"] for r in _read_alphafold_csv(alphafold_csv)]
-    validated_what: set[DownloadableFormat] = structure(what_af_formats, set[DownloadableFormat])
+    af_ids = _read_column_from_csv(alphafold_csv, "af_id")
+    validated_what: set[DownloadableFormat] = structure(what_formats, set[DownloadableFormat])
     rprint(f"Retrieving {len(af_ids)} AlphaFold entries with formats {validated_what}")
     afs = af_fetch(af_ids, download_dir, what=validated_what, max_parallel_downloads=max_parallel_downloads)
     total_nr_files = sum(af.nr_of_files() for af in afs)
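With the handler above, omitting `--what-formats` now downloads the `summary` and `cif` formats instead of `pdb`. A hedged sketch of the equivalent programmatic call, using a made-up AlphaFold identifier:

```python
from pathlib import Path

from protein_quest.alphafold.fetch import fetch_many as af_fetch

# "AF-A8MT69-F1" is an illustrative identifier, not taken from this diff.
afs = af_fetch(["AF-A8MT69-F1"], Path("af_downloads"), what={"summary", "cif"}, max_parallel_downloads=5)
print(sum(af.nr_of_files() for af in afs), "files downloaded")
```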
@@ -618,21 +682,22 @@ def _handle_filter_confidence(args: argparse.Namespace):
     # to get rid of duplication
     input_dir = structure(args.input_dir, Path)
     output_dir = structure(args.output_dir, Path)
-    confidence_threshold = structure(args.confidence_threshold, float)
-    # TODO add min/max
-    min_residues = structure(args.min_residues, int)
-    max_residues = structure(args.max_residues, int)
+
+    confidence_threshold = args.confidence_threshold
+    min_residues = args.min_residues
+    max_residues = args.max_residues
     stats_file: TextIOWrapper | None = args.write_stats
+    copy_method: CopyMethod = structure(args.copy_method, CopyMethod)  # pyright: ignore[reportArgumentType]
 
     output_dir.mkdir(parents=True, exist_ok=True)
     input_files = sorted(glob_structure_files(input_dir))
     nr_input_files = len(input_files)
     rprint(f"Starting confidence filtering of {nr_input_files} mmcif/PDB files in {input_dir} directory.")
-    query = structure(
+    query = converter.structure(
         {
             "confidence": confidence_threshold,
-            "min_threshold": min_residues,
-            "max_threshold": max_residues,
+            "min_residues": min_residues,
+            "max_residues": max_residues,
         },
         ConfidenceFilterQuery,
     )
@@ -641,7 +706,11 @@ def _handle_filter_confidence(args: argparse.Namespace):
         writer.writerow(["input_file", "residue_count", "passed", "output_file"])
 
     passed_count = 0
-    for r in tqdm(filter_files_on_confidence(input_files, query, output_dir), total=len(input_files), unit="file"):
+    for r in tqdm(
+        filter_files_on_confidence(input_files, query, output_dir, copy_method=copy_method),
+        total=len(input_files),
+        unit="file",
+    ):
         if r.filtered_file:
             passed_count += 1
             if stats_file:
@@ -654,25 +723,53 @@ def _handle_filter_confidence(args: argparse.Namespace):
 
 def _handle_filter_chain(args):
     input_dir = args.input_dir
-    output_dir = args.output_dir
+    output_dir = structure(args.output_dir, Path)
     pdb_id2chain_mapping_file = args.chains
-    scheduler_address = args.scheduler_address
+    scheduler_address = structure(args.scheduler_address, str | None)  # pyright: ignore[reportArgumentType]
+    copy_method: CopyMethod = structure(args.copy_method, CopyMethod)  # pyright: ignore[reportArgumentType]
 
+    # make sure files in input dir with entries in mapping file are the same
+    # complain when files from mapping file are missing on disk
     rows = list(_iter_csv_rows(pdb_id2chain_mapping_file))
-    id2chains: dict[str, str] = {row["pdb_id"]: row["chain"] for row in rows}
+    file2chain: set[tuple[Path, str]] = set()
+    errors: list[FileNotFoundError] = []
 
-    new_files = filter_files_on_chain(input_dir, id2chains, output_dir, scheduler_address)
+    for row in rows:
+        pdb_id = row["pdb_id"]
+        chain = row["chain"]
+        try:
+            f = locate_structure_file(input_dir, pdb_id)
+            file2chain.add((f, chain))
+        except FileNotFoundError as e:
+            errors.append(e)
 
-    nr_written = len([r for r in new_files if r[2] is not None])
+    if errors:
+        msg = f"Some structure files could not be found ({len(errors)} missing), skipping them"
+        rprint(Panel(os.linesep.join(map(str, errors)), title=msg, style="red"))
+
+    if not file2chain:
+        rprint("[red]No valid structure files found. Exiting.")
+        sys.exit(1)
+
+    results = filter_files_on_chain(
+        file2chain, output_dir, scheduler_address=scheduler_address, copy_method=copy_method
+    )
+
+    nr_written = len([r for r in results if r.passed])
 
     rprint(f"Wrote {nr_written} single-chain PDB/mmCIF files to {output_dir}.")
 
+    for result in results:
+        if result.discard_reason:
+            rprint(f"[red]Discarding {result.input_file} ({result.discard_reason})[/red]")
+
 
 def _handle_filter_residue(args):
     input_dir = structure(args.input_dir, Path)
     output_dir = structure(args.output_dir, Path)
     min_residues = structure(args.min_residues, int)
     max_residues = structure(args.max_residues, int)
+    copy_method: CopyMethod = structure(args.copy_method, CopyMethod)  # pyright: ignore[reportArgumentType]
     stats_file: TextIOWrapper | None = args.write_stats
 
     if stats_file:
@@ -683,7 +780,9 @@ def _handle_filter_residue(args):
     input_files = sorted(glob_structure_files(input_dir))
     nr_total = len(input_files)
     rprint(f"Filtering {nr_total} files in {input_dir} directory by number of residues in chain A.")
-    for r in filter_files_on_residues(input_files, output_dir, min_residues=min_residues, max_residues=max_residues):
+    for r in filter_files_on_residues(
+        input_files, output_dir, min_residues=min_residues, max_residues=max_residues, copy_method=copy_method
+    ):
         if stats_file:
             writer.writerow([r.input_file, r.residue_count, r.passed, r.output_file])
         if r.passed:
@@ -694,6 +793,68 @@ def _handle_filter_residue(args):
         rprint(f"Statistics written to {stats_file.name}")
 
 
+def _handle_filter_ss(args):
+    input_dir = structure(args.input_dir, Path)
+    output_dir = structure(args.output_dir, Path)
+    copy_method: CopyMethod = structure(args.copy_method, CopyMethod)  # pyright: ignore[reportArgumentType]
+    stats_file: TextIOWrapper | None = args.write_stats
+
+    raw_query = {
+        "abs_min_helix_residues": args.abs_min_helix_residues,
+        "abs_max_helix_residues": args.abs_max_helix_residues,
+        "abs_min_sheet_residues": args.abs_min_sheet_residues,
+        "abs_max_sheet_residues": args.abs_max_sheet_residues,
+        "ratio_min_helix_residues": args.ratio_min_helix_residues,
+        "ratio_max_helix_residues": args.ratio_max_helix_residues,
+        "ratio_min_sheet_residues": args.ratio_min_sheet_residues,
+        "ratio_max_sheet_residues": args.ratio_max_sheet_residues,
+    }
+    query = converter.structure(raw_query, SecondaryStructureFilterQuery)
+    input_files = sorted(glob_structure_files(input_dir))
+    nr_total = len(input_files)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    if stats_file:
+        writer = csv.writer(stats_file)
+        writer.writerow(
+            [
+                "input_file",
+                "nr_residues",
+                "nr_helix_residues",
+                "nr_sheet_residues",
+                "helix_ratio",
+                "sheet_ratio",
+                "passed",
+                "output_file",
+            ]
+        )
+
+    rprint(f"Filtering {nr_total} files in {input_dir} directory by secondary structure.")
+    nr_passed = 0
+    for input_file, result in filter_files_on_secondary_structure(input_files, query=query):
+        output_file: Path | None = None
+        if result.passed:
+            output_file = output_dir / input_file.name
+            copyfile(input_file, output_file, copy_method)
+            nr_passed += 1
+        if stats_file:
+            writer.writerow(
+                [
+                    input_file,
+                    result.stats.nr_residues,
+                    result.stats.nr_helix_residues,
+                    result.stats.nr_sheet_residues,
+                    round(result.stats.helix_ratio, 3),
+                    round(result.stats.sheet_ratio, 3),
+                    result.passed,
+                    output_file,
+                ]
+            )
+    rprint(f"Wrote {nr_passed} files to {output_dir} directory.")
+    if stats_file:
+        rprint(f"Statistics written to {stats_file.name}")
+
+
 def _handle_mcp(args):
     if find_spec("fastmcp") is None:
         msg = "Unable to start MCP server, please install `protein-quest[mcp]`."
@@ -720,6 +881,7 @@ HANDLERS: dict[tuple[str, str | None], Callable] = {
     ("filter", "confidence"): _handle_filter_confidence,
     ("filter", "chain"): _handle_filter_chain,
     ("filter", "residue"): _handle_filter_residue,
+    ("filter", "secondary-structure"): _handle_filter_ss,
     ("mcp", None): _handle_mcp,
 }
 
@@ -768,12 +930,7 @@ def _write_dict_of_sets2csv(file: TextIOWrapper, data: dict[str, set[str]], ref_
             writer.writerow({"uniprot_acc": uniprot_acc, ref_id_field: ref_id})
 
 
-def _read_alphafold_csv(file: TextIOWrapper):
-    reader = csv.DictReader(file)
-    yield from reader
-
-
-def _iter_csv_rows(file: TextIOWrapper):
+def _iter_csv_rows(file: TextIOWrapper) -> Generator[dict[str, str]]:
     reader = csv.DictReader(file)
     yield from reader
 
protein_quest/converter.py ADDED
@@ -0,0 +1,45 @@
+"""Convert json or dict to Python objects."""
+
+from cattrs.preconf.orjson import make_converter
+from yarl import URL
+
+type Percentage = float
+"""Type alias for percentage values (0.0-100.0)."""
+type Ratio = float
+"""Type alias for ratio values (0.0-1.0)."""
+type PositiveInt = int
+"""Type alias for positive integer values (>= 0)."""
+
+converter = make_converter()
+"""cattrs converter to read JSON document or dict to Python objects."""
+converter.register_structure_hook(URL, lambda v, _: URL(v))
+
+
+@converter.register_structure_hook
+def percentage_hook(val, _) -> Percentage:
+    """Cattrs hook to validate percentage values."""
+    value = float(val)
+    if not 0.0 <= value <= 100.0:
+        msg = f"Value {value} is not a valid percentage (0.0-100.0)"
+        raise ValueError(msg)
+    return value
+
+
+@converter.register_structure_hook
+def ratio_hook(val, _) -> Ratio:
+    """Cattrs hook to validate ratio values."""
+    value = float(val)
+    if not 0.0 <= value <= 1.0:
+        msg = f"Value {value} is not a valid ratio (0.0-1.0)"
+        raise ValueError(msg)
+    return value
+
+
+@converter.register_structure_hook
+def positive_int_hook(val, _) -> PositiveInt:
+    """Cattrs hook to validate positive integer values."""
+    value = int(val)
+    if value < 0:
+        msg = f"Value {value} is not a valid positive integer (>= 0)"
+        raise ValueError(msg)
+    return value
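This shared converter is what `cli.py` now uses to build query objects (see `_handle_filter_confidence` above). A hedged usage sketch, assuming the query fields are annotated with the aliases defined in this module so the validation hooks fire; the numbers are illustrative:

```python
from protein_quest.alphafold.confidence import ConfidenceFilterQuery
from protein_quest.converter import converter

query = converter.structure(
    {"confidence": 70.0, "min_residues": 100, "max_residues": 2000},
    ConfidenceFilterQuery,
)
# Out-of-range values raise ValueError, e.g. a confidence of 150.0 would be
# rejected by percentage_hook if that field is typed as Percentage.
```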
protein_quest/filters.py CHANGED
@@ -1,70 +1,107 @@
 """Module for filtering structure files and their contents."""
 
 import logging
-from collections.abc import Generator
+from collections.abc import Collection, Generator
 from dataclasses import dataclass
 from pathlib import Path
-from shutil import copyfile
-from typing import cast
+from typing import Literal
 
-from dask.distributed import Client, progress
+from dask.distributed import Client
 from distributed.deploy.cluster import Cluster
 from tqdm.auto import tqdm
 
-from protein_quest.parallel import configure_dask_scheduler
+from protein_quest.parallel import configure_dask_scheduler, dask_map_with_progress
 from protein_quest.pdbe.io import (
-    locate_structure_file,
     nr_residues_in_chain,
     write_single_chain_pdb_file,
 )
+from protein_quest.utils import CopyMethod, copyfile
 
 logger = logging.getLogger(__name__)
 
 
+@dataclass
+class ChainFilterStatistics:
+    input_file: Path
+    chain_id: str
+    passed: bool = False
+    output_file: Path | None = None
+    discard_reason: Exception | None = None
+
+
+def filter_file_on_chain(
+    file_and_chain: tuple[Path, str],
+    output_dir: Path,
+    out_chain: str = "A",
+    copy_method: CopyMethod = "copy",
+) -> ChainFilterStatistics:
+    input_file, chain_id = file_and_chain
+    logger.debug("Filtering %s on chain %s", input_file, chain_id)
+    try:
+        output_file = write_single_chain_pdb_file(
+            input_file, chain_id, output_dir, out_chain=out_chain, copy_method=copy_method
+        )
+        return ChainFilterStatistics(
+            input_file=input_file,
+            chain_id=chain_id,
+            output_file=output_file,
+            passed=True,
+        )
+    except Exception as e:  # noqa: BLE001 - error is handled downstream
+        return ChainFilterStatistics(input_file=input_file, chain_id=chain_id, discard_reason=e)
+
+
 def filter_files_on_chain(
-    input_dir: Path,
-    id2chains: dict[str, str],
+    file2chains: Collection[tuple[Path, str]],
     output_dir: Path,
-    scheduler_address: str | Cluster | None = None,
     out_chain: str = "A",
-) -> list[tuple[str, str, Path | None]]:
+    scheduler_address: str | Cluster | Literal["sequential"] | None = None,
+    copy_method: CopyMethod = "copy",
+) -> list[ChainFilterStatistics]:
     """Filter mmcif/PDB files by chain.
 
     Args:
-        input_dir: The directory containing the input mmcif/PDB files.
-        id2chains: Which chain to keep for each PDB ID. Key is the PDB ID, value is the chain ID.
+        file2chains: Which chain to keep for each PDB file.
+            First item is the PDB file path, second item is the chain ID.
         output_dir: The directory where the filtered files will be written.
-        scheduler_address: The address of the Dask scheduler.
         out_chain: Under what name to write the kept chain.
+        scheduler_address: The address of the Dask scheduler.
+            If not provided, will create a local cluster.
+            If set to `sequential` will run tasks sequentially.
+        copy_method: How to copy when a direct copy is possible.
 
     Returns:
-        A list of tuples containing the PDB ID, chain ID, and path to the filtered file.
-        Last tuple item is None if something went wrong like chain not present.
+        Result of the filtering process.
     """
     output_dir.mkdir(parents=True, exist_ok=True)
+    if scheduler_address == "sequential":
+
+        def task(file_and_chain: tuple[Path, str]) -> ChainFilterStatistics:
+            return filter_file_on_chain(file_and_chain, output_dir, out_chain=out_chain, copy_method=copy_method)
+
+        return list(map(task, file2chains))
+
+    # TODO make logger.debug in filter_file_on_chain show to user when --log
+    # GPT-5 generated a fairly difficult setup with a WorkerPlugin, need to find a simpler approach
     scheduler_address = configure_dask_scheduler(
         scheduler_address,
         name="filter-chain",
     )
 
-    def task(id2chain: tuple[str, str]) -> tuple[str, str, Path | None]:
-        pdb_id, chain = id2chain
-        input_file = locate_structure_file(input_dir, pdb_id)
-        return pdb_id, chain, write_single_chain_pdb_file(input_file, chain, output_dir, out_chain=out_chain)
-
     with Client(scheduler_address) as client:
-        logger.info(f"Follow progress on dask dashboard at: {client.dashboard_link}")
-
-        futures = client.map(task, id2chains.items())
-
-        progress(futures)
-
-        results = client.gather(futures)
-        return cast("list[tuple[str,str, Path | None]]", results)
+        client.forward_logging()
+        return dask_map_with_progress(
+            client,
+            filter_file_on_chain,
+            file2chains,
+            output_dir=output_dir,
+            out_chain=out_chain,
+            copy_method=copy_method,
+        )
 
 
 @dataclass
-class FilterStat:
+class ResidueFilterStatistics:
     """Statistics for filtering files based on residue count in a specific chain.
 
     Parameters:
@@ -81,8 +118,13 @@ class FilterStat:
 
 
 def filter_files_on_residues(
-    input_files: list[Path], output_dir: Path, min_residues: int, max_residues: int, chain: str = "A"
-) -> Generator[FilterStat]:
+    input_files: list[Path],
+    output_dir: Path,
+    min_residues: int,
+    max_residues: int,
+    chain: str = "A",
+    copy_method: CopyMethod = "copy",
+) -> Generator[ResidueFilterStatistics]:
     """Filter PDB/mmCIF files by number of residues in given chain.
 
     Args:
@@ -91,9 +133,10 @@ def filter_files_on_residues(
         min_residues: The minimum number of residues in chain.
         max_residues: The maximum number of residues in chain.
        chain: The chain to count residues of.
+        copy_method: How to copy passed files to output directory:
 
     Yields:
-        FilterStat objects containing information about the filtering process for each input file.
+        Objects containing information about the filtering process for each input file.
     """
     output_dir.mkdir(parents=True, exist_ok=True)
     for input_file in tqdm(input_files, unit="file"):
@@ -101,7 +144,7 @@ def filter_files_on_residues(
         passed = min_residues <= residue_count <= max_residues
         if passed:
             output_file = output_dir / input_file.name
-            copyfile(input_file, output_file)
-            yield FilterStat(input_file, residue_count, True, output_file)
+            copyfile(input_file, output_file, copy_method)
+            yield ResidueFilterStatistics(input_file, residue_count, True, output_file)
         else:
-            yield FilterStat(input_file, residue_count, False, None)
+            yield ResidueFilterStatistics(input_file, residue_count, False, None)
protein_quest/go.py CHANGED
@@ -8,8 +8,8 @@ from io import TextIOWrapper
 from typing import Literal, get_args
 
 from cattrs.gen import make_dict_structure_fn, override
-from cattrs.preconf.orjson import make_converter
 
+from protein_quest.converter import converter
 from protein_quest.utils import friendly_session
 
 logger = logging.getLogger(__name__)
@@ -52,9 +52,6 @@ class SearchResponse:
     page_info: PageInfo
 
 
-converter = make_converter()
-
-
 def flatten_definition(definition, _context) -> str:
     return definition["text"]
 
@@ -24,12 +24,11 @@ npx @modelcontextprotocol/inspector
 # Choose STDIO
 # command: uv run protein-quest mcp
 # id: protein-quest
-# Prompt: What are the PDBe structures for `A8MT69` uniprot accession?
 ```
 
 Examples:
 
-For search pdb use `A8MT69` as input.
+- What are the PDBe structures for `A8MT69` uniprot accession?
 
 """
 
@@ -47,6 +46,7 @@ from protein_quest.emdb import fetch as emdb_fetch
 from protein_quest.go import search_gene_ontology_term
 from protein_quest.pdbe.fetch import fetch as pdbe_fetch
 from protein_quest.pdbe.io import glob_structure_files, nr_residues_in_chain, write_single_chain_pdb_file
+from protein_quest.ss import filter_file_on_secondary_structure
 from protein_quest.taxonomy import search_taxon
 from protein_quest.uniprot import PdbResult, Query, search4af, search4emdb, search4pdb, search4uniprot
 
@@ -90,7 +90,7 @@ def extract_single_chain_from_structure(
     chain2keep: str,
     output_dir: Path,
     out_chain: str = "A",
-) -> Path | None:
+) -> Path:
     """
     Extract a single chain from a mmCIF/pdb file and write to a new file.
 
@@ -101,7 +101,7 @@ def extract_single_chain_from_structure(
         out_chain: The chain identifier for the output file.
 
     Returns:
-        Path to the output mmCIF/pdb file or None if not created.
+        Path to the output mmCIF/pdb file
     """
     return write_single_chain_pdb_file(input_file, chain2keep, output_dir, out_chain)
 
@@ -150,7 +150,7 @@ def fetch_alphafold_structures(uniprot_accs: set[str], save_dir: Path) -> list[A
     Returns:
         A list of AlphaFold entries.
     """
-    what: set[DownloadableFormat] = {"cif"}
+    what: set[DownloadableFormat] = {"summary", "cif"}
     return alphafold_fetch(uniprot_accs, save_dir, what)
 
 
@@ -166,6 +166,9 @@ def alphafold_confidence_filter(file: Path, query: ConfidenceFilterQuery, filter
     return filter_file_on_residues(file, query, filtered_dir)
 
 
+mcp.tool(filter_file_on_secondary_structure)
+
+
 @mcp.prompt
 def candidate_structures(
     species: str = "Human",