protein-quest 0.3.2__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of protein-quest was flagged for review on the registry.
- protein_quest/__version__.py +1 -1
- protein_quest/alphafold/fetch.py +34 -9
- protein_quest/cli.py +207 -26
- protein_quest/converter.py +1 -0
- protein_quest/emdb.py +6 -3
- protein_quest/mcp_server.py +34 -3
- protein_quest/pdbe/fetch.py +6 -3
- protein_quest/ss.py +20 -0
- protein_quest/uniprot.py +157 -4
- protein_quest/utils.py +367 -23
- {protein_quest-0.3.2.dist-info → protein_quest-0.5.0.dist-info}/METADATA +41 -3
- protein_quest-0.5.0.dist-info/RECORD +26 -0
- protein_quest-0.3.2.dist-info/RECORD +0 -26
- {protein_quest-0.3.2.dist-info → protein_quest-0.5.0.dist-info}/WHEEL +0 -0
- {protein_quest-0.3.2.dist-info → protein_quest-0.5.0.dist-info}/entry_points.txt +0 -0
- {protein_quest-0.3.2.dist-info → protein_quest-0.5.0.dist-info}/licenses/LICENSE +0 -0
protein_quest/uniprot.py
CHANGED

@@ -201,7 +201,7 @@ def _build_sparql_generic_query(select_clause: str, where_clause: str, limit: in
     """)


-def
+def _build_sparql_generic_by_uniprot_accessions_query(
     uniprot_accs: Iterable[str], select_clause: str, where_clause: str, limit: int = 10_000, groupby_clause=""
 ) -> str:
     values = " ".join(f'("{ac}")' for ac in uniprot_accs)
@@ -269,7 +269,7 @@ def _build_sparql_query_pdb(uniprot_accs: Iterable[str], limit=10_000) -> str:
     """)

     groupby_clause = "?protein ?pdb_db ?pdb_method ?pdb_resolution"
-    return
+    return _build_sparql_generic_by_uniprot_accessions_query(
         uniprot_accs, select_clause, where_clause, limit, groupby_clause
     )

@@ -284,7 +284,7 @@ def _build_sparql_query_af(uniprot_accs: Iterable[str], limit=10_000) -> str:
     ?protein rdfs:seeAlso ?af_db .
     ?af_db up:database <http://purl.uniprot.org/database/AlphaFoldDB> .
     """)
-    return
+    return _build_sparql_generic_by_uniprot_accessions_query(uniprot_accs, select_clause, dedent(where_clause), limit)


 def _build_sparql_query_emdb(uniprot_accs: Iterable[str], limit=10_000) -> str:
@@ -297,7 +297,7 @@ def _build_sparql_query_emdb(uniprot_accs: Iterable[str], limit=10_000) -> str:
     ?protein rdfs:seeAlso ?emdb_db .
     ?emdb_db up:database <http://purl.uniprot.org/database/EMDB> .
     """)
-    return
+    return _build_sparql_generic_by_uniprot_accessions_query(uniprot_accs, select_clause, dedent(where_clause), limit)


 def _execute_sparql_search(
@@ -509,3 +509,156 @@ def search4emdb(uniprot_accs: Iterable[str], limit: int = 10_000, timeout: int =
     )
     limit_check("Search for EMDB entries on uniprot", limit, len(raw_results))
     return _flatten_results_emdb(raw_results)
+
+
+def _build_complex_sparql_query(uniprot_accs: Iterable[str], limit: int) -> str:
+    """Builds a SPARQL query to retrieve ComplexPortal information for given UniProt accessions.
+
+    Example:
+
+        ```sparql
+        PREFIX up: <http://purl.uniprot.org/core/>
+        PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
+        PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
+
+        SELECT
+            ?protein
+            ?cp_db
+            ?cp_comment
+            (GROUP_CONCAT(DISTINCT ?member; separator=",") AS ?complex_members)
+            (COUNT(DISTINCT ?member) AS ?member_count)
+        WHERE {
+            # Input UniProt accessions
+            VALUES (?ac) { ("P05067") ("P60709") ("Q05471")}
+            BIND (IRI(CONCAT("http://purl.uniprot.org/uniprot/", ?ac)) AS ?protein)
+
+            # ComplexPortal cross-reference for each input protein
+            ?protein a up:Protein ;
+                rdfs:seeAlso ?cp_db .
+            ?cp_db up:database <http://purl.uniprot.org/database/ComplexPortal> .
+            OPTIONAL { ?cp_db rdfs:comment ?cp_comment . }
+
+            # All member proteins of the same ComplexPortal complex
+            ?member a up:Protein ;
+                rdfs:seeAlso ?cp_db .
+        }
+        GROUP BY ?protein ?cp_db ?cp_comment
+        ORDER BY ?protein ?cp_db
+        LIMIT 500
+        ```
+
+    """
+    select_clause = dedent("""\
+        ?protein ?cp_db ?cp_comment
+        (GROUP_CONCAT(DISTINCT ?member; separator=",") AS ?complex_members)
+    """)
+    where_clause = dedent("""
+        # --- Complex Info ---
+        ?protein a up:Protein ;
+            rdfs:seeAlso ?cp_db .
+        ?cp_db up:database <http://purl.uniprot.org/database/ComplexPortal> .
+        OPTIONAL { ?cp_db rdfs:comment ?cp_comment . }
+        # All member proteins of the same ComplexPortal complex
+        ?member a up:Protein ;
+            rdfs:seeAlso ?cp_db .
+    """)
+    group_by = dedent("""
+        ?protein ?cp_db ?cp_comment
+    """)
+    return _build_sparql_generic_by_uniprot_accessions_query(
+        uniprot_accs, select_clause, where_clause, limit, groupby_clause=group_by
+    )
+
+
+@dataclass(frozen=True)
+class ComplexPortalEntry:
+    """A ComplexPortal entry.
+
+    Parameters:
+        query_protein: The UniProt accession used to find entry.
+        complex_id: The ComplexPortal identifier (for example "CPX-1234").
+        complex_url: The URL to the ComplexPortal entry.
+        complex_title: The title of the complex.
+        members: UniProt accessions which are members of the complex.
+    """
+
+    query_protein: str
+    complex_id: str
+    complex_url: str
+    complex_title: str
+    members: set[str]
+
+
+def _flatten_results_complex(raw_results) -> list[ComplexPortalEntry]:
+    results = []
+    for raw_result in raw_results:
+        query_protein = raw_result["protein"]["value"].split("/")[-1]
+        complex_id = raw_result["cp_db"]["value"].split("/")[-1]
+        complex_url = f"https://www.ebi.ac.uk/complexportal/complex/{complex_id}"
+        complex_title = raw_result.get("cp_comment", {}).get("value", "")
+        members = {m.split("/")[-1] for m in raw_result["complex_members"]["value"].split(",")}
+        results.append(
+            ComplexPortalEntry(
+                query_protein=query_protein,
+                complex_id=complex_id,
+                complex_url=complex_url,
+                complex_title=complex_title,
+                members=members,
+            )
+        )
+    return results
+
+
+def search4macromolecular_complexes(
+    uniprot_accs: Iterable[str], limit: int = 10_000, timeout: int = 1_800
+) -> list[ComplexPortalEntry]:
+    """Search for macromolecular complexes by UniProtKB accessions.
+
+    Queries for references to/from https://www.ebi.ac.uk/complexportal/ database in the Uniprot SPARQL endpoint.
+
+    Args:
+        uniprot_accs: UniProt accessions.
+        limit: Maximum number of results to return.
+        timeout: Timeout for the SPARQL query in seconds.
+
+    Returns:
+        List of ComplexPortalEntry objects.
+    """
+    sparql_query = _build_complex_sparql_query(uniprot_accs, limit)
+    logger.info("Executing SPARQL query for macromolecular complexes: %s", sparql_query)
+    raw_results = _execute_sparql_search(
+        sparql_query=sparql_query,
+        timeout=timeout,
+    )
+    limit_check("Search for complexes", limit, len(raw_results))
+    return _flatten_results_complex(raw_results)
+
+
+def search4interaction_partners(
+    uniprot_acc: str, excludes: set[str] | None = None, limit: int = 10_000, timeout: int = 1_800
+) -> dict[str, set[str]]:
+    """Search for interaction partners of a given UniProt accession using ComplexPortal database references.
+
+    Args:
+        uniprot_acc: UniProt accession to search interaction partners for.
+        excludes: Set of UniProt accessions to exclude from the results.
+            For example already known interaction partners.
+            If None then no complex members are excluded.
+        limit: Maximum number of results to return.
+        timeout: Timeout for the SPARQL query in seconds.
+
+    Returns:
+        Dictionary with UniProt accessions of interaction partners as keys and sets of ComplexPortal entry IDs
+        in which the interaction occurs as values.
+    """
+    ucomplexes = search4macromolecular_complexes([uniprot_acc], limit=limit, timeout=timeout)
+    hits: dict[str, set[str]] = {}
+    if excludes is None:
+        excludes = set()
+    for ucomplex in ucomplexes:
+        for member in ucomplex.members:
+            if member != uniprot_acc and member not in excludes:
+                if member not in hits:
+                    hits[member] = set()
+                hits[member].add(ucomplex.complex_id)
+    return hits
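For orientation, a minimal usage sketch of the new ComplexPortal helpers follows. It is not part of the diff: the accessions are taken from the docstring example above, the import path is assumed from the module location, and a live connection to the UniProt SPARQL endpoint is required.

```python
# Illustrative sketch only; assumes the functions are importable from
# protein_quest.uniprot and that the UniProt SPARQL endpoint is reachable.
from protein_quest.uniprot import search4interaction_partners, search4macromolecular_complexes

# ComplexPortal complexes that the query protein belongs to.
complexes = search4macromolecular_complexes(["P05067"], limit=100)
for entry in complexes:
    print(entry.complex_id, entry.complex_title, sorted(entry.members))

# Proteins sharing a complex with the query, mapped to the complex IDs
# in which the interaction occurs (already known partners can be excluded).
partners = search4interaction_partners("P05067", excludes={"P60709"})
for partner, complex_ids in partners.items():
    print(partner, sorted(complex_ids))
```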
protein_quest/utils.py
CHANGED

@@ -1,22 +1,260 @@
 """Module for functions that are used in multiple places."""

+import argparse
 import asyncio
+import hashlib
 import logging
 import shutil
-from collections.abc import Coroutine, Iterable
+from collections.abc import Coroutine, Iterable, Sequence
 from contextlib import asynccontextmanager
+from functools import lru_cache
 from pathlib import Path
 from textwrap import dedent
-from typing import Any, Literal, get_args
+from typing import Any, Literal, Protocol, get_args, runtime_checkable

 import aiofiles
+import aiofiles.os
 import aiohttp
+import rich
+from aiohttp.streams import AsyncStreamIterator
 from aiohttp_retry import ExponentialRetry, RetryClient
+from platformdirs import user_cache_dir
+from rich_argparse import ArgumentDefaultsRichHelpFormatter
 from tqdm.asyncio import tqdm
 from yarl import URL

 logger = logging.getLogger(__name__)

+CopyMethod = Literal["copy", "symlink", "hardlink"]
+"""Methods for copying files."""
+copy_methods = set(get_args(CopyMethod))
+"""Set of valid copy methods."""
+
+
+@lru_cache
+def _cache_sub_dir(root_cache_dir: Path, filename: str, hash_length: int = 4) -> Path:
+    """Get the cache sub-directory for a given path.
+
+    To not have too many files in a single directory,
+    we create sub-directories based on the hash of the filename.
+
+    Args:
+        root_cache_dir: The root directory for the cache.
+        filename: The filename to be cached.
+        hash_length: The length of the hash to use for the sub-directory.
+
+    Returns:
+        The parent path to the cached file.
+    """
+    full_hash = hashlib.blake2b(filename.encode("utf-8")).hexdigest()
+    cache_sub_dir = full_hash[:hash_length]
+    cache_sub_dir_path = root_cache_dir / cache_sub_dir
+    cache_sub_dir_path.mkdir(parents=True, exist_ok=True)
+    return cache_sub_dir_path
+
+
+@runtime_checkable
+class Cacher(Protocol):
+    """Protocol for a cacher."""
+
+    def __contains__(self, item: str | Path) -> bool:
+        """Check if a file is in the cache.
+
+        Args:
+            item: The filename or Path to check.
+
+        Returns:
+            True if the file is in the cache, False otherwise.
+        """
+        ...
+
+    async def copy_from_cache(self, target: Path) -> Path | None:
+        """Copy a file from the cache to a target location if it exists in the cache.
+
+        Assumes:
+
+        - target does not exist.
+        - the parent directory of target exists.
+
+        Args:
+            target: The path to copy the file to.
+
+        Returns:
+            The path to the cached file if it was copied, None otherwise.
+        """
+        ...
+
+    async def write_iter(self, target: Path, content: AsyncStreamIterator[bytes]) -> Path:
+        """Write content to a file and cache it.
+
+        Args:
+            target: The path to write the content to.
+            content: An async iterator that yields bytes to write to the file.
+
+        Returns:
+            The path to the cached file.
+
+        Raises:
+            FileExistsError: If the target file already exists.
+        """
+        ...
+
+    async def write_bytes(self, target: Path, content: bytes) -> Path:
+        """Write bytes to a file and cache it.
+
+        Args:
+            target: The path to write the content to.
+            content: The bytes to write to the file.
+
+        Returns:
+            The path to the cached file.
+
+        Raises:
+            FileExistsError: If the target file already exists.
+        """
+        ...
+
+
+class PassthroughCacher(Cacher):
+    """A cacher that caches nothing.
+
+    On writes it just writes to the target path.
+    """
+
+    def __contains__(self, item: str | Path) -> bool:
+        # We don't have anything cached ever
+        return False
+
+    async def copy_from_cache(self, target: Path) -> Path | None:  # noqa: ARG002
+        # We don't have anything cached ever
+        return None
+
+    async def write_iter(self, target: Path, content: AsyncStreamIterator[bytes]) -> Path:
+        if target.exists():
+            raise FileExistsError(target)
+        target.write_bytes(b"".join([chunk async for chunk in content]))
+        return target
+
+    async def write_bytes(self, target: Path, content: bytes) -> Path:
+        if target.exists():
+            raise FileExistsError(target)
+        target.write_bytes(content)
+        return target
+
+
+def user_cache_root_dir() -> Path:
+    """Get the users root directory for caching files.
+
+    Returns:
+        The path to the user's cache directory for protein-quest.
+    """
+    return Path(user_cache_dir("protein-quest"))
+
+
+class DirectoryCacher(Cacher):
+    """Class to cache files in a directory.
+
+    Caching logic is based on the file name only.
+    If file name of paths are the same then the files are considered the same.
+
+    Attributes:
+        cache_dir: The directory to use for caching.
+        copy_method: The method to use for copying files.
+    """
+
+    def __init__(
+        self,
+        cache_dir: Path | None = None,
+        copy_method: CopyMethod = "hardlink",
+    ) -> None:
+        """Initialize the cacher.
+
+        If file name of paths are the same then the files are considered the same.
+
+        Args:
+            cache_dir: The directory to use for caching.
+                If None, a default cache directory (~/.cache/protein-quest) is used.
+            copy_method: The method to use for copying.
+        """
+        if cache_dir is None:
+            cache_dir = user_cache_root_dir()
+        self.cache_dir: Path = cache_dir
+        self.cache_dir.mkdir(parents=True, exist_ok=True)
+        if copy_method == "copy":
+            logger.warning(
+                "Using copy as copy_method to cache files is not recommended. "
+                "This will use more disk space and be slower than symlink or hardlink."
+            )
+        if copy_method not in copy_methods:
+            msg = f"Unknown copy method: {copy_method}. Must be one of {copy_methods}."
+            raise ValueError(msg)
+        self.copy_method: CopyMethod = copy_method
+
+    def __contains__(self, item: str | Path) -> bool:
+        cached_file = self._as_cached_path(item)
+        return cached_file.exists()
+
+    def _as_cached_path(self, item: str | Path) -> Path:
+        file_name = item.name if isinstance(item, Path) else item
+        cache_sub_dir = _cache_sub_dir(self.cache_dir, file_name)
+        return cache_sub_dir / file_name
+
+    async def copy_from_cache(self, target: Path) -> Path | None:
+        cached_file = self._as_cached_path(target.name)
+        exists = await aiofiles.os.path.exists(str(cached_file))
+        if exists:
+            await async_copyfile(cached_file, target, copy_method=self.copy_method)
+            return cached_file
+        return None
+
+    async def write_iter(self, target: Path, content: AsyncStreamIterator[bytes]) -> Path:
+        cached_file = self._as_cached_path(target.name)
+        # Write file to cache dir
+        async with aiofiles.open(cached_file, "xb") as f:
+            async for chunk in content:
+                await f.write(chunk)
+        # Copy to target location
+        await async_copyfile(cached_file, target, copy_method=self.copy_method)
+        return cached_file
+
+    async def write_bytes(self, target: Path, content: bytes) -> Path:
+        cached_file = self._as_cached_path(target.name)
+        # Write file to cache dir
+        async with aiofiles.open(cached_file, "xb") as f:
+            await f.write(content)
+        # Copy to target location
+        await async_copyfile(cached_file, target, copy_method=self.copy_method)
+        return cached_file
+
+    def populate_cache(self, source_dir: Path) -> dict[Path, Path]:
+        """Populate the cache from an existing directory.
+
+        This will copy all files from the source directory to the cache directory.
+        If a file with the same name already exists in the cache, it will be skipped.
+
+        Args:
+            source_dir: The directory to populate the cache from.
+
+        Returns:
+            A dictionary mapping source file paths to their cached paths.
+
+        Raises:
+            NotADirectoryError: If the source_dir is not a directory.
+        """
+        if not source_dir.is_dir():
+            raise NotADirectoryError(source_dir)
+        cached = {}
+        for file_path in source_dir.iterdir():
+            if not file_path.is_file():
+                continue
+            cached_path = self._as_cached_path(file_path.name)
+            if cached_path.exists():
+                logger.debug(f"File {file_path.name} already in cache. Skipping.")
+                continue
+            copyfile(file_path, cached_path, copy_method=self.copy_method)
+            cached[file_path] = cached_path
+        return cached
+

 async def retrieve_files(
     urls: Iterable[tuple[URL | str, str]],
@@ -25,6 +263,8 @@ async def retrieve_files(
     retries: int = 3,
     total_timeout: int = 300,
     desc: str = "Downloading files",
+    cacher: Cacher | None = None,
+    chunk_size: int = 524288,  # 512 KiB
 ) -> list[Path]:
     """Retrieve files from a list of URLs and save them to a directory.

@@ -35,6 +275,8 @@ async def retrieve_files(
         retries: The number of times to retry a failed download.
         total_timeout: The total timeout for a download in seconds.
         desc: Description for the progress bar.
+        cacher: An optional cacher to use for caching files.
+        chunk_size: The size of each chunk to read from the response.

     Returns:
         A list of paths to the downloaded files.
@@ -42,7 +284,17 @@ async def retrieve_files(
     save_dir.mkdir(parents=True, exist_ok=True)
     semaphore = asyncio.Semaphore(max_parallel_downloads)
     async with friendly_session(retries, total_timeout) as session:
-        tasks = [
+        tasks = [
+            _retrieve_file(
+                session=session,
+                url=url,
+                save_path=save_dir / filename,
+                semaphore=semaphore,
+                cacher=cacher,
+                chunk_size=chunk_size,
+            )
+            for url, filename in urls
+        ]
         files: list[Path] = await tqdm.gather(*tasks, desc=desc)
         return files

@@ -52,8 +304,8 @@ async def _retrieve_file(
     url: URL | str,
     save_path: Path,
     semaphore: asyncio.Semaphore,
-
-    chunk_size: int =
+    cacher: Cacher | None = None,
+    chunk_size: int = 524288,  # 512 KiB
 ) -> Path:
     """Retrieve a single file from a URL and save it to a specified path.

@@ -62,26 +314,28 @@ async def _retrieve_file(
         url: The URL to download the file from.
         save_path: The path where the file should be saved.
         semaphore: A semaphore to limit the number of concurrent downloads.
-
+        cacher: An optional cacher to use for caching files.
         chunk_size: The size of each chunk to read from the response.

     Returns:
         The path to the saved file.
     """
     if save_path.exists():
-
-
-
-
-
+        logger.debug(f"File {save_path} already exists. Skipping download from {url}.")
+        return save_path
+
+    if cacher is None:
+        cacher = PassthroughCacher()
+    if cached_file := await cacher.copy_from_cache(save_path):
+        logger.debug(f"File {save_path} was copied from cache {cached_file}. Skipping download from {url}.")
+        return save_path
+
     async with (
         semaphore,
-        aiofiles.open(save_path, "xb") as f,
         session.get(url) as resp,
     ):
         resp.raise_for_status()
-
-            await f.write(chunk)
+        await cacher.write_iter(save_path, resp.content.iter_chunked(chunk_size))
     return save_path


@@ -141,27 +395,117 @@ def run_async[R](coroutine: Coroutine[Any, Any, R]) -> R:
         raise NestedAsyncIOLoopError from e


-CopyMethod = Literal["copy", "symlink"]
-copy_methods = set(get_args(CopyMethod))
-
-
 def copyfile(source: Path, target: Path, copy_method: CopyMethod = "copy"):
-    """Make target path be same file as source by either copying or symlinking.
+    """Make target path be same file as source by either copying or symlinking or hardlinking.
+
+    Note that the hardlink copy method only works within the same filesystem and is harder to track.
+    If you want to track cached files easily then use 'symlink'.
+    On Windows you need developer mode or admin privileges to create symlinks.

     Args:
-        source: The source file to copy or
+        source: The source file to copy or link.
         target: The target file to create.
         copy_method: The method to use for copying.

     Raises:
         FileNotFoundError: If the source file or parent of target does not exist.
-
+        FileExistsError: If the target file already exists.
+        ValueError: If an unknown copy method is provided.
     """
     if copy_method == "copy":
         shutil.copyfile(source, target)
     elif copy_method == "symlink":
-        rel_source = source.relative_to(target.parent, walk_up=True)
+        rel_source = source.absolute().relative_to(target.parent.absolute(), walk_up=True)
         target.symlink_to(rel_source)
+    elif copy_method == "hardlink":
+        target.hardlink_to(source)
     else:
-        msg = f"Unknown method: {copy_method}"
+        msg = f"Unknown method: {copy_method}. Valid methods are: {copy_methods}"
         raise ValueError(msg)
+
+
+async def async_copyfile(
+    source: Path,
+    target: Path,
+    copy_method: CopyMethod = "copy",
+):
+    """Asynchronously make target path be same file as source by either copying or symlinking or hardlinking.
+
+    Note that the hardlink copy method only works within the same filesystem and is harder to track.
+    If you want to track cached files easily then use 'symlink'.
+    On Windows you need developer mode or admin privileges to create symlinks.
+
+    Args:
+        source: The source file to copy.
+        target: The target file to create.
+        copy_method: The method to use for copying.
+
+    Raises:
+        FileNotFoundError: If the source file or parent of target does not exist.
+        FileExistsError: If the target file already exists.
+        ValueError: If an unknown copy method is provided.
+    """
+    if copy_method == "copy":
+        # Could use loop of chunks with aiofiles,
+        # but shutil is ~1.9x faster on my machine
+        # due to fastcopy and sendfile optimizations in shutil.
+        await asyncio.to_thread(shutil.copyfile, source, target)
+    elif copy_method == "symlink":
+        rel_source = source.relative_to(target.parent, walk_up=True)
+        await aiofiles.os.symlink(str(rel_source), str(target))
+    elif copy_method == "hardlink":
+        await aiofiles.os.link(str(source), str(target))
+    else:
+        msg = f"Unknown method: {copy_method}. Valid methods are: {copy_methods}"
+        raise ValueError(msg)
+
+
+def populate_cache_command(raw_args: Sequence[str] | None = None):
+    """Command line interface to populate the cache from an existing directory.
+
+    Can be called from the command line as:
+
+    ```bash
+    python3 -m protein_quest.utils populate-cache /path/to/source/dir
+    ```
+
+    Args:
+        raw_args: The raw command line arguments to parse. If None, uses sys.argv.
+    """
+    root_parser = argparse.ArgumentParser(formatter_class=ArgumentDefaultsRichHelpFormatter)
+    subparsers = root_parser.add_subparsers(dest="command")
+
+    desc = "Populate the cache directory with files from the source directory."
+    populate_cache_parser = subparsers.add_parser(
+        "populate-cache",
+        help=desc,
+        description=desc,
+        formatter_class=ArgumentDefaultsRichHelpFormatter,
+    )
+    populate_cache_parser.add_argument("source_dir", type=Path)
+    populate_cache_parser.add_argument(
+        "--cache-dir",
+        type=Path,
+        default=user_cache_root_dir(),
+        help="Directory to use for caching. If not provided, a default cache directory is used.",
+    )
+    populate_cache_parser.add_argument(
+        "--copy-method",
+        type=str,
+        default="hardlink",
+        choices=copy_methods,
+        help="Method to use for copying files to cache.",
+    )
+
+    args = root_parser.parse_args(raw_args)
+    if args.command == "populate-cache":
+        source_dir = args.source_dir
+        cacher = DirectoryCacher(cache_dir=args.cache_dir, copy_method=args.copy_method)
+        cached_files = cacher.populate_cache(source_dir)
+        rich.print(f"Cached {len(cached_files)} files from {source_dir} to {cacher.cache_dir}")
+        for src, cached in cached_files.items():
+            rich.print(f"- {src.relative_to(source_dir)} -> {cached.relative_to(cacher.cache_dir)}")
+
+
+if __name__ == "__main__":
+    populate_cache_command()
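As a rough illustration of how the new caching layer plugs into downloads, a minimal sketch follows. It is not part of the diff: the URL and filename are placeholders, and it assumes retrieve_files accepts save_dir as a keyword argument (the parameter appears in the function body above).

```python
# Illustrative sketch only; the URL/filename are placeholders and the
# save_dir keyword is assumed from the function body shown above.
import asyncio
from pathlib import Path

from protein_quest.utils import DirectoryCacher, retrieve_files


async def main() -> None:
    # Downloads are written once into the cache directory and then
    # hardlinked into save_dir; repeat runs are served from the cache.
    cacher = DirectoryCacher(cache_dir=Path("./pq-cache"), copy_method="hardlink")
    urls = [("https://example.org/files/structure.cif", "structure.cif")]
    files = await retrieve_files(urls, save_dir=Path("./downloads"), cacher=cacher)
    print(files)


asyncio.run(main())
```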