protein-quest 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of protein-quest might be problematic.
- protein_quest/__version__.py +2 -1
- protein_quest/alphafold/confidence.py +44 -17
- protein_quest/alphafold/entry_summary.py +11 -9
- protein_quest/alphafold/fetch.py +37 -63
- protein_quest/cli.py +187 -30
- protein_quest/converter.py +45 -0
- protein_quest/filters.py +78 -35
- protein_quest/go.py +1 -4
- protein_quest/mcp_server.py +8 -5
- protein_quest/parallel.py +37 -1
- protein_quest/pdbe/fetch.py +15 -1
- protein_quest/pdbe/io.py +142 -46
- protein_quest/ss.py +264 -0
- protein_quest/taxonomy.py +13 -3
- protein_quest/utils.py +65 -3
- {protein_quest-0.3.0.dist-info → protein_quest-0.3.2.dist-info}/METADATA +21 -11
- protein_quest-0.3.2.dist-info/RECORD +26 -0
- protein_quest-0.3.0.dist-info/RECORD +0 -24
- {protein_quest-0.3.0.dist-info → protein_quest-0.3.2.dist-info}/WHEEL +0 -0
- {protein_quest-0.3.0.dist-info → protein_quest-0.3.2.dist-info}/entry_points.txt +0 -0
- {protein_quest-0.3.0.dist-info → protein_quest-0.3.2.dist-info}/licenses/LICENSE +0 -0
protein_quest/cli.py
CHANGED

@@ -5,7 +5,8 @@ import asyncio
 import csv
 import logging
 import os
-
+import sys
+from collections.abc import Callable, Generator, Iterable
 from importlib.util import find_spec
 from io import TextIOWrapper
 from pathlib import Path
@@ -14,6 +15,7 @@ from textwrap import dedent
 from cattrs import structure
 from rich import print as rprint
 from rich.logging import RichHandler
+from rich.panel import Panel
 from rich_argparse import ArgumentDefaultsRichHelpFormatter
 from tqdm.rich import tqdm

@@ -21,13 +23,16 @@ from protein_quest.__version__ import __version__
 from protein_quest.alphafold.confidence import ConfidenceFilterQuery, filter_files_on_confidence
 from protein_quest.alphafold.fetch import DownloadableFormat, downloadable_formats
 from protein_quest.alphafold.fetch import fetch_many as af_fetch
+from protein_quest.converter import converter
 from protein_quest.emdb import fetch as emdb_fetch
 from protein_quest.filters import filter_files_on_chain, filter_files_on_residues
 from protein_quest.go import Aspect, allowed_aspects, search_gene_ontology_term, write_go_terms_to_csv
 from protein_quest.pdbe import fetch as pdbe_fetch
-from protein_quest.pdbe.io import glob_structure_files
+from protein_quest.pdbe.io import glob_structure_files, locate_structure_file
+from protein_quest.ss import SecondaryStructureFilterQuery, filter_files_on_secondary_structure
 from protein_quest.taxonomy import SearchField, _write_taxonomy_csv, search_fields, search_taxon
 from protein_quest.uniprot import PdbResult, Query, search4af, search4emdb, search4pdb, search4uniprot
+from protein_quest.utils import CopyMethod, copy_methods, copyfile

 logger = logging.getLogger(__name__)

@@ -246,12 +251,12 @@ def _add_retrieve_alphafold_parser(subparsers: argparse._SubParsersAction):
     )
     parser.add_argument("output_dir", type=Path, help="Directory to store downloaded AlphaFold files")
     parser.add_argument(
-        "--what-
+        "--what-formats",
         type=str,
         action="append",
         choices=sorted(downloadable_formats),
         help=dedent("""AlphaFold formats to retrieve. Can be specified multiple times.
-            Default is '
+            Default is 'summary' and 'cif'."""),
     )
     parser.add_argument(
         "--max-parallel-downloads",
@@ -280,6 +285,22 @@ def _add_retrieve_emdb_parser(subparsers: argparse._SubParsersAction):
     parser.add_argument("output_dir", type=Path, help="Directory to store downloaded EMDB volume files")


+def _add_copy_method_argument(parser: argparse.ArgumentParser):
+    """Add copy method argument to parser."""
+    default_copy_method = "symlink"
+    if os.name == "nt":
+        # On Windows you need developer mode or admin privileges to create symlinks
+        # so we default to copying files instead of symlinking
+        default_copy_method = "copy"
+    parser.add_argument(
+        "--copy-method",
+        type=str,
+        choices=copy_methods,
+        default=default_copy_method,
+        help="How to copy files when no changes are needed to output file.",
+    )
+
+
 def _add_filter_confidence_parser(subparsers: argparse._SubParsersAction):
     """Add filter confidence subcommand parser."""
     parser = subparsers.add_parser(
@@ -310,6 +331,7 @@ def _add_filter_confidence_parser(subparsers: argparse._SubParsersAction):
            In CSV format with `<input_file>,<residue_count>,<passed>,<output_file>` columns.
            Use `-` for stdout."""),
     )
+    _add_copy_method_argument(parser)


 def _add_filter_chain_parser(subparsers: argparse._SubParsersAction):
@@ -345,8 +367,11 @@ def _add_filter_chain_parser(subparsers: argparse._SubParsersAction):
     )
     parser.add_argument(
         "--scheduler-address",
-        help="Address of the Dask scheduler to connect to.
+        help=dedent("""Address of the Dask scheduler to connect to.
+            If not provided, will create a local cluster.
+            If set to `sequential` will run tasks sequentially."""),
     )
+    _add_copy_method_argument(parser)


 def _add_filter_residue_parser(subparsers: argparse._SubParsersAction):
@@ -369,6 +394,7 @@ def _add_filter_residue_parser(subparsers: argparse._SubParsersAction):
     )
     parser.add_argument("--min-residues", type=int, default=0, help="Min residues in chain A")
     parser.add_argument("--max-residues", type=int, default=10_000_000, help="Max residues in chain A")
+    _add_copy_method_argument(parser)
     parser.add_argument(
         "--write-stats",
         type=argparse.FileType("w", encoding="UTF-8"),
@@ -379,6 +405,43 @@ def _add_filter_residue_parser(subparsers: argparse._SubParsersAction):
     )


+def _add_filter_ss_parser(subparsers: argparse._SubParsersAction):
+    """Add filter secondary structure subcommand parser."""
+    parser = subparsers.add_parser(
+        "secondary-structure",
+        help="Filter PDB/mmCIF files by secondary structure",
+        description="Filter PDB/mmCIF files by secondary structure",
+        formatter_class=ArgumentDefaultsRichHelpFormatter,
+    )
+    parser.add_argument("input_dir", type=Path, help="Directory with PDB/mmCIF files (e.g., from 'filter chain')")
+    parser.add_argument(
+        "output_dir",
+        type=Path,
+        help=dedent("""\
+            Directory to write filtered PDB/mmCIF files. Files are copied without modification.
+            """),
+    )
+    parser.add_argument("--abs-min-helix-residues", type=int, help="Min residues in helices")
+    parser.add_argument("--abs-max-helix-residues", type=int, help="Max residues in helices")
+    parser.add_argument("--abs-min-sheet-residues", type=int, help="Min residues in sheets")
+    parser.add_argument("--abs-max-sheet-residues", type=int, help="Max residues in sheets")
+    parser.add_argument("--ratio-min-helix-residues", type=float, help="Min residues in helices (relative)")
+    parser.add_argument("--ratio-max-helix-residues", type=float, help="Max residues in helices (relative)")
+    parser.add_argument("--ratio-min-sheet-residues", type=float, help="Min residues in sheets (relative)")
+    parser.add_argument("--ratio-max-sheet-residues", type=float, help="Max residues in sheets (relative)")
+    _add_copy_method_argument(parser)
+    parser.add_argument(
+        "--write-stats",
+        type=argparse.FileType("w", encoding="UTF-8"),
+        help=dedent("""
+            Write filter statistics to file. In CSV format with columns:
+            `<input_file>,<nr_residues>,<nr_helix_residues>,<nr_sheet_residues>,
+            <helix_ratio>,<sheet_ratio>,<passed>,<output_file>`.
+            Use `-` for stdout.
+            """),
+    )
+
+
 def _add_search_subcommands(subparsers: argparse._SubParsersAction):
     """Add search command and its subcommands."""
     parser = subparsers.add_parser(
@@ -420,6 +483,7 @@ def _add_filter_subcommands(subparsers: argparse._SubParsersAction):
     _add_filter_confidence_parser(subsubparsers)
     _add_filter_chain_parser(subsubparsers)
     _add_filter_residue_parser(subsubparsers)
+    _add_filter_ss_parser(subsubparsers)


 def _add_mcp_command(subparsers: argparse._SubParsersAction):
@@ -585,17 +649,17 @@ def _handle_retrieve_pdbe(args):

 def _handle_retrieve_alphafold(args):
     download_dir = args.output_dir
-
+    what_formats = args.what_formats
     alphafold_csv = args.alphafold_csv
     max_parallel_downloads = args.max_parallel_downloads

-    if
-
+    if what_formats is None:
+        what_formats = {"summary", "cif"}

     # TODO besides `uniprot_acc,af_id\n` csv also allow headless single column format
     #
-    af_ids =
-    validated_what: set[DownloadableFormat] = structure(
+    af_ids = _read_column_from_csv(alphafold_csv, "af_id")
+    validated_what: set[DownloadableFormat] = structure(what_formats, set[DownloadableFormat])
     rprint(f"Retrieving {len(af_ids)} AlphaFold entries with formats {validated_what}")
     afs = af_fetch(af_ids, download_dir, what=validated_what, max_parallel_downloads=max_parallel_downloads)
     total_nr_files = sum(af.nr_of_files() for af in afs)
@@ -618,21 +682,22 @@ def _handle_filter_confidence(args: argparse.Namespace):
     # to get rid of duplication
     input_dir = structure(args.input_dir, Path)
     output_dir = structure(args.output_dir, Path)
-
-
-    min_residues =
-    max_residues =
+
+    confidence_threshold = args.confidence_threshold
+    min_residues = args.min_residues
+    max_residues = args.max_residues
     stats_file: TextIOWrapper | None = args.write_stats
+    copy_method: CopyMethod = structure(args.copy_method, CopyMethod)  # pyright: ignore[reportArgumentType]

     output_dir.mkdir(parents=True, exist_ok=True)
     input_files = sorted(glob_structure_files(input_dir))
     nr_input_files = len(input_files)
     rprint(f"Starting confidence filtering of {nr_input_files} mmcif/PDB files in {input_dir} directory.")
-    query = structure(
+    query = converter.structure(
         {
             "confidence": confidence_threshold,
-            "
-            "
+            "min_residues": min_residues,
+            "max_residues": max_residues,
         },
         ConfidenceFilterQuery,
     )
@@ -641,7 +706,11 @@ def _handle_filter_confidence(args: argparse.Namespace):
         writer.writerow(["input_file", "residue_count", "passed", "output_file"])

     passed_count = 0
-    for r in tqdm(
+    for r in tqdm(
+        filter_files_on_confidence(input_files, query, output_dir, copy_method=copy_method),
+        total=len(input_files),
+        unit="file",
+    ):
         if r.filtered_file:
             passed_count += 1
             if stats_file:
@@ -654,25 +723,53 @@ def _handle_filter_confidence(args: argparse.Namespace):

 def _handle_filter_chain(args):
     input_dir = args.input_dir
-    output_dir = args.output_dir
+    output_dir = structure(args.output_dir, Path)
     pdb_id2chain_mapping_file = args.chains
-    scheduler_address = args.scheduler_address
+    scheduler_address = structure(args.scheduler_address, str | None)  # pyright: ignore[reportArgumentType]
+    copy_method: CopyMethod = structure(args.copy_method, CopyMethod)  # pyright: ignore[reportArgumentType]

+    # make sure files in input dir with entries in mapping file are the same
+    # complain when files from mapping file are missing on disk
     rows = list(_iter_csv_rows(pdb_id2chain_mapping_file))
-
+    file2chain: set[tuple[Path, str]] = set()
+    errors: list[FileNotFoundError] = []

-
+    for row in rows:
+        pdb_id = row["pdb_id"]
+        chain = row["chain"]
+        try:
+            f = locate_structure_file(input_dir, pdb_id)
+            file2chain.add((f, chain))
+        except FileNotFoundError as e:
+            errors.append(e)

-
+    if errors:
+        msg = f"Some structure files could not be found ({len(errors)} missing), skipping them"
+        rprint(Panel(os.linesep.join(map(str, errors)), title=msg, style="red"))
+
+    if not file2chain:
+        rprint("[red]No valid structure files found. Exiting.")
+        sys.exit(1)
+
+    results = filter_files_on_chain(
+        file2chain, output_dir, scheduler_address=scheduler_address, copy_method=copy_method
+    )
+
+    nr_written = len([r for r in results if r.passed])

     rprint(f"Wrote {nr_written} single-chain PDB/mmCIF files to {output_dir}.")

+    for result in results:
+        if result.discard_reason:
+            rprint(f"[red]Discarding {result.input_file} ({result.discard_reason})[/red]")
+

 def _handle_filter_residue(args):
     input_dir = structure(args.input_dir, Path)
     output_dir = structure(args.output_dir, Path)
     min_residues = structure(args.min_residues, int)
     max_residues = structure(args.max_residues, int)
+    copy_method: CopyMethod = structure(args.copy_method, CopyMethod)  # pyright: ignore[reportArgumentType]
     stats_file: TextIOWrapper | None = args.write_stats

     if stats_file:
@@ -683,7 +780,9 @@ def _handle_filter_residue(args):
     input_files = sorted(glob_structure_files(input_dir))
     nr_total = len(input_files)
     rprint(f"Filtering {nr_total} files in {input_dir} directory by number of residues in chain A.")
-    for r in filter_files_on_residues(
+    for r in filter_files_on_residues(
+        input_files, output_dir, min_residues=min_residues, max_residues=max_residues, copy_method=copy_method
+    ):
         if stats_file:
             writer.writerow([r.input_file, r.residue_count, r.passed, r.output_file])
         if r.passed:
@@ -694,6 +793,68 @@ def _handle_filter_residue(args):
     rprint(f"Statistics written to {stats_file.name}")


+def _handle_filter_ss(args):
+    input_dir = structure(args.input_dir, Path)
+    output_dir = structure(args.output_dir, Path)
+    copy_method: CopyMethod = structure(args.copy_method, CopyMethod)  # pyright: ignore[reportArgumentType]
+    stats_file: TextIOWrapper | None = args.write_stats
+
+    raw_query = {
+        "abs_min_helix_residues": args.abs_min_helix_residues,
+        "abs_max_helix_residues": args.abs_max_helix_residues,
+        "abs_min_sheet_residues": args.abs_min_sheet_residues,
+        "abs_max_sheet_residues": args.abs_max_sheet_residues,
+        "ratio_min_helix_residues": args.ratio_min_helix_residues,
+        "ratio_max_helix_residues": args.ratio_max_helix_residues,
+        "ratio_min_sheet_residues": args.ratio_min_sheet_residues,
+        "ratio_max_sheet_residues": args.ratio_max_sheet_residues,
+    }
+    query = converter.structure(raw_query, SecondaryStructureFilterQuery)
+    input_files = sorted(glob_structure_files(input_dir))
+    nr_total = len(input_files)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    if stats_file:
+        writer = csv.writer(stats_file)
+        writer.writerow(
+            [
+                "input_file",
+                "nr_residues",
+                "nr_helix_residues",
+                "nr_sheet_residues",
+                "helix_ratio",
+                "sheet_ratio",
+                "passed",
+                "output_file",
+            ]
+        )
+
+    rprint(f"Filtering {nr_total} files in {input_dir} directory by secondary structure.")
+    nr_passed = 0
+    for input_file, result in filter_files_on_secondary_structure(input_files, query=query):
+        output_file: Path | None = None
+        if result.passed:
+            output_file = output_dir / input_file.name
+            copyfile(input_file, output_file, copy_method)
+            nr_passed += 1
+        if stats_file:
+            writer.writerow(
+                [
+                    input_file,
+                    result.stats.nr_residues,
+                    result.stats.nr_helix_residues,
+                    result.stats.nr_sheet_residues,
+                    round(result.stats.helix_ratio, 3),
+                    round(result.stats.sheet_ratio, 3),
+                    result.passed,
+                    output_file,
+                ]
+            )
+    rprint(f"Wrote {nr_passed} files to {output_dir} directory.")
+    if stats_file:
+        rprint(f"Statistics written to {stats_file.name}")
+
+
 def _handle_mcp(args):
     if find_spec("fastmcp") is None:
         msg = "Unable to start MCP server, please install `protein-quest[mcp]`."
@@ -720,6 +881,7 @@ HANDLERS: dict[tuple[str, str | None], Callable] = {
     ("filter", "confidence"): _handle_filter_confidence,
     ("filter", "chain"): _handle_filter_chain,
     ("filter", "residue"): _handle_filter_residue,
+    ("filter", "secondary-structure"): _handle_filter_ss,
     ("mcp", None): _handle_mcp,
 }

@@ -768,12 +930,7 @@ def _write_dict_of_sets2csv(file: TextIOWrapper, data: dict[str, set[str]], ref_
         writer.writerow({"uniprot_acc": uniprot_acc, ref_id_field: ref_id})


-def
-    reader = csv.DictReader(file)
-    yield from reader
-
-
-def _iter_csv_rows(file: TextIOWrapper):
+def _iter_csv_rows(file: TextIOWrapper) -> Generator[dict[str, str]]:
     reader = csv.DictReader(file)
     yield from reader

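The new `--copy-method` option added above is threaded through every filter handler via `copyfile` from `protein_quest.utils`. A minimal sketch of what those handlers end up doing for a file that passes a filter, assuming protein-quest 0.3.2 is installed; the paths are hypothetical:

from pathlib import Path

from protein_quest.utils import copyfile

# Hypothetical paths. "symlink" is the default copy method on non-Windows
# platforms, "copy" on Windows (see _add_copy_method_argument above).
copyfile(Path("filtered/af-P12345.cif"), Path("out/af-P12345.cif"), "symlink")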
protein_quest/converter.py
ADDED

@@ -0,0 +1,45 @@
+"""Convert json or dict to Python objects."""
+
+from cattrs.preconf.orjson import make_converter
+from yarl import URL
+
+type Percentage = float
+"""Type alias for percentage values (0.0-100.0)."""
+type Ratio = float
+"""Type alias for ratio values (0.0-1.0)."""
+type PositiveInt = int
+"""Type alias for positive integer values (>= 0)."""
+
+converter = make_converter()
+"""cattrs converter to read JSON document or dict to Python objects."""
+converter.register_structure_hook(URL, lambda v, _: URL(v))
+
+
+@converter.register_structure_hook
+def percentage_hook(val, _) -> Percentage:
+    value = float(val)
+    """Cattrs hook to validate percentage values."""
+    if not 0.0 <= value <= 100.0:
+        msg = f"Value {value} is not a valid percentage (0.0-100.0)"
+        raise ValueError(msg)
+    return value
+
+
+@converter.register_structure_hook
+def ratio_hook(val, _) -> Ratio:
+    """Cattrs hook to validate ratio values."""
+    value = float(val)
+    if not 0.0 <= value <= 1.0:
+        msg = f"Value {value} is not a valid ratio (0.0-1.0)"
+        raise ValueError(msg)
+    return value
+
+
+@converter.register_structure_hook
+def positive_int_hook(val, _) -> PositiveInt:
+    """Cattrs hook to validate positive integer values."""
+    value = int(val)
+    if value < 0:
+        msg = f"Value {value} is not a valid positive integer (>= 0)"
+        raise ValueError(msg)
+    return value
protein_quest/filters.py
CHANGED

@@ -1,70 +1,107 @@
 """Module for filtering structure files and their contents."""

 import logging
-from collections.abc import Generator
+from collections.abc import Collection, Generator
 from dataclasses import dataclass
 from pathlib import Path
-from
-from typing import cast
+from typing import Literal

-from dask.distributed import Client
+from dask.distributed import Client
 from distributed.deploy.cluster import Cluster
 from tqdm.auto import tqdm

-from protein_quest.parallel import configure_dask_scheduler
+from protein_quest.parallel import configure_dask_scheduler, dask_map_with_progress
 from protein_quest.pdbe.io import (
-    locate_structure_file,
     nr_residues_in_chain,
     write_single_chain_pdb_file,
 )
+from protein_quest.utils import CopyMethod, copyfile

 logger = logging.getLogger(__name__)


+@dataclass
+class ChainFilterStatistics:
+    input_file: Path
+    chain_id: str
+    passed: bool = False
+    output_file: Path | None = None
+    discard_reason: Exception | None = None
+
+
+def filter_file_on_chain(
+    file_and_chain: tuple[Path, str],
+    output_dir: Path,
+    out_chain: str = "A",
+    copy_method: CopyMethod = "copy",
+) -> ChainFilterStatistics:
+    input_file, chain_id = file_and_chain
+    logger.debug("Filtering %s on chain %s", input_file, chain_id)
+    try:
+        output_file = write_single_chain_pdb_file(
+            input_file, chain_id, output_dir, out_chain=out_chain, copy_method=copy_method
+        )
+        return ChainFilterStatistics(
+            input_file=input_file,
+            chain_id=chain_id,
+            output_file=output_file,
+            passed=True,
+        )
+    except Exception as e:  # noqa: BLE001 - error is handled downstream
+        return ChainFilterStatistics(input_file=input_file, chain_id=chain_id, discard_reason=e)
+
+
 def filter_files_on_chain(
-
-    id2chains: dict[str, str],
+    file2chains: Collection[tuple[Path, str]],
     output_dir: Path,
-    scheduler_address: str | Cluster | None = None,
     out_chain: str = "A",
-
+    scheduler_address: str | Cluster | Literal["sequential"] | None = None,
+    copy_method: CopyMethod = "copy",
+) -> list[ChainFilterStatistics]:
     """Filter mmcif/PDB files by chain.

     Args:
-
-
+        file2chains: Which chain to keep for each PDB file.
+            First item is the PDB file path, second item is the chain ID.
         output_dir: The directory where the filtered files will be written.
-        scheduler_address: The address of the Dask scheduler.
         out_chain: Under what name to write the kept chain.
+        scheduler_address: The address of the Dask scheduler.
+            If not provided, will create a local cluster.
+            If set to `sequential` will run tasks sequentially.
+        copy_method: How to copy when a direct copy is possible.

     Returns:
-
-        Last tuple item is None if something went wrong like chain not present.
+        Result of the filtering process.
     """
     output_dir.mkdir(parents=True, exist_ok=True)
+    if scheduler_address == "sequential":
+
+        def task(file_and_chain: tuple[Path, str]) -> ChainFilterStatistics:
+            return filter_file_on_chain(file_and_chain, output_dir, out_chain=out_chain, copy_method=copy_method)
+
+        return list(map(task, file2chains))
+
+    # TODO make logger.debug in filter_file_on_chain show to user when --log
+    # GPT-5 generated a fairly difficult setup with a WorkerPlugin, need to find a simpler approach
     scheduler_address = configure_dask_scheduler(
         scheduler_address,
         name="filter-chain",
     )

-    def task(id2chain: tuple[str, str]) -> tuple[str, str, Path | None]:
-        pdb_id, chain = id2chain
-        input_file = locate_structure_file(input_dir, pdb_id)
-        return pdb_id, chain, write_single_chain_pdb_file(input_file, chain, output_dir, out_chain=out_chain)
-
     with Client(scheduler_address) as client:
-
-
-
-
-
-
-
-
+        client.forward_logging()
+        return dask_map_with_progress(
+            client,
+            filter_file_on_chain,
+            file2chains,
+            output_dir=output_dir,
+            out_chain=out_chain,
+            copy_method=copy_method,
+        )


 @dataclass
-class
+class ResidueFilterStatistics:
     """Statistics for filtering files based on residue count in a specific chain.

     Parameters:
@@ -81,8 +118,13 @@ class FilterStat:


 def filter_files_on_residues(
-    input_files: list[Path],
-
+    input_files: list[Path],
+    output_dir: Path,
+    min_residues: int,
+    max_residues: int,
+    chain: str = "A",
+    copy_method: CopyMethod = "copy",
+) -> Generator[ResidueFilterStatistics]:
     """Filter PDB/mmCIF files by number of residues in given chain.

     Args:
@@ -91,9 +133,10 @@ def filter_files_on_residues(
         min_residues: The minimum number of residues in chain.
         max_residues: The maximum number of residues in chain.
         chain: The chain to count residues of.
+        copy_method: How to copy passed files to output directory:

     Yields:
-
+        Objects containing information about the filtering process for each input file.
     """
     output_dir.mkdir(parents=True, exist_ok=True)
     for input_file in tqdm(input_files, unit="file"):
@@ -101,7 +144,7 @@ def filter_files_on_residues(
         passed = min_residues <= residue_count <= max_residues
         if passed:
             output_file = output_dir / input_file.name
-            copyfile(input_file, output_file)
-            yield
+            copyfile(input_file, output_file, copy_method)
+            yield ResidueFilterStatistics(input_file, residue_count, True, output_file)
         else:
-            yield
+            yield ResidueFilterStatistics(input_file, residue_count, False, None)
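With `scheduler_address="sequential"` the chain filter can now run in-process, without spinning up a Dask cluster. A minimal sketch of calling the new API from a script, assuming protein-quest 0.3.2 is installed; the structure files and chains are hypothetical:

from pathlib import Path

from protein_quest.filters import filter_files_on_chain

# Hypothetical (structure file, chain) pairs, mirroring what the CLI builds from its CSV mapping.
file2chains = {(Path("downloads/1abc.cif"), "B"), (Path("downloads/2xyz.cif"), "A")}
results = filter_files_on_chain(
    file2chains,
    Path("filtered/"),
    scheduler_address="sequential",  # run tasks in-process instead of on a Dask cluster
    copy_method="copy",
)
for r in results:
    if r.discard_reason:
        print(f"{r.input_file} skipped: {r.discard_reason}")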
protein_quest/go.py
CHANGED

@@ -8,8 +8,8 @@ from io import TextIOWrapper
 from typing import Literal, get_args

 from cattrs.gen import make_dict_structure_fn, override
-from cattrs.preconf.orjson import make_converter

+from protein_quest.converter import converter
 from protein_quest.utils import friendly_session

 logger = logging.getLogger(__name__)
@@ -52,9 +52,6 @@ class SearchResponse:
     page_info: PageInfo


-converter = make_converter()
-
-
 def flatten_definition(definition, _context) -> str:
     return definition["text"]

protein_quest/mcp_server.py
CHANGED

@@ -24,12 +24,11 @@ npx @modelcontextprotocol/inspector
 # Choose STDIO
 # command: uv run protein-quest mcp
 # id: protein-quest
-# Prompt: What are the PDBe structures for `A8MT69` uniprot accession?
 ```

 Examples:

-
+- What are the PDBe structures for `A8MT69` uniprot accession?

 """

@@ -47,6 +46,7 @@ from protein_quest.emdb import fetch as emdb_fetch
 from protein_quest.go import search_gene_ontology_term
 from protein_quest.pdbe.fetch import fetch as pdbe_fetch
 from protein_quest.pdbe.io import glob_structure_files, nr_residues_in_chain, write_single_chain_pdb_file
+from protein_quest.ss import filter_file_on_secondary_structure
 from protein_quest.taxonomy import search_taxon
 from protein_quest.uniprot import PdbResult, Query, search4af, search4emdb, search4pdb, search4uniprot

@@ -90,7 +90,7 @@ def extract_single_chain_from_structure(
     chain2keep: str,
     output_dir: Path,
     out_chain: str = "A",
-) -> Path
+) -> Path:
     """
     Extract a single chain from a mmCIF/pdb file and write to a new file.

@@ -101,7 +101,7 @@ def extract_single_chain_from_structure(
         out_chain: The chain identifier for the output file.

     Returns:
-        Path to the output mmCIF/pdb file
+        Path to the output mmCIF/pdb file
     """
     return write_single_chain_pdb_file(input_file, chain2keep, output_dir, out_chain)

@@ -150,7 +150,7 @@ def fetch_alphafold_structures(uniprot_accs: set[str], save_dir: Path) -> list[A
     Returns:
         A list of AlphaFold entries.
     """
-    what: set[DownloadableFormat] = {"cif"}
+    what: set[DownloadableFormat] = {"summary", "cif"}
     return alphafold_fetch(uniprot_accs, save_dir, what)


@@ -166,6 +166,9 @@ def alphafold_confidence_filter(file: Path, query: ConfidenceFilterQuery, filter
     return filter_file_on_residues(file, query, filtered_dir)


+mcp.tool(filter_file_on_secondary_structure)
+
+
 @mcp.prompt
 def candidate_structures(
     species: str = "Human",