protein-quest 0.3.2__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


protein_quest/uniprot.py CHANGED
@@ -201,7 +201,7 @@ def _build_sparql_generic_query(select_clause: str, where_clause: str, limit: in
     """)
 
 
-def _build_sparql_generic_by_uniprot_accesions_query(
+def _build_sparql_generic_by_uniprot_accessions_query(
     uniprot_accs: Iterable[str], select_clause: str, where_clause: str, limit: int = 10_000, groupby_clause=""
 ) -> str:
     values = " ".join(f'("{ac}")' for ac in uniprot_accs)
@@ -269,7 +269,7 @@ def _build_sparql_query_pdb(uniprot_accs: Iterable[str], limit=10_000) -> str:
     """)
 
     groupby_clause = "?protein ?pdb_db ?pdb_method ?pdb_resolution"
-    return _build_sparql_generic_by_uniprot_accesions_query(
+    return _build_sparql_generic_by_uniprot_accessions_query(
         uniprot_accs, select_clause, where_clause, limit, groupby_clause
     )
 
@@ -284,7 +284,7 @@ def _build_sparql_query_af(uniprot_accs: Iterable[str], limit=10_000) -> str:
         ?protein rdfs:seeAlso ?af_db .
         ?af_db up:database <http://purl.uniprot.org/database/AlphaFoldDB> .
     """)
-    return _build_sparql_generic_by_uniprot_accesions_query(uniprot_accs, select_clause, dedent(where_clause), limit)
+    return _build_sparql_generic_by_uniprot_accessions_query(uniprot_accs, select_clause, dedent(where_clause), limit)
 
 
 def _build_sparql_query_emdb(uniprot_accs: Iterable[str], limit=10_000) -> str:
@@ -297,7 +297,7 @@ def _build_sparql_query_emdb(uniprot_accs: Iterable[str], limit=10_000) -> str:
         ?protein rdfs:seeAlso ?emdb_db .
         ?emdb_db up:database <http://purl.uniprot.org/database/EMDB> .
     """)
-    return _build_sparql_generic_by_uniprot_accesions_query(uniprot_accs, select_clause, dedent(where_clause), limit)
+    return _build_sparql_generic_by_uniprot_accessions_query(uniprot_accs, select_clause, dedent(where_clause), limit)
 
 
 def _execute_sparql_search(
@@ -509,3 +509,156 @@ def search4emdb(uniprot_accs: Iterable[str], limit: int = 10_000, timeout: int =
     )
     limit_check("Search for EMDB entries on uniprot", limit, len(raw_results))
     return _flatten_results_emdb(raw_results)
+
+
+def _build_complex_sparql_query(uniprot_accs: Iterable[str], limit: int) -> str:
+    """Builds a SPARQL query to retrieve ComplexPortal information for given UniProt accessions.
+
+    Example:
+
+        ```sparql
+        PREFIX up: <http://purl.uniprot.org/core/>
+        PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
+        PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
+
+        SELECT
+            ?protein
+            ?cp_db
+            ?cp_comment
+            (GROUP_CONCAT(DISTINCT ?member; separator=",") AS ?complex_members)
+            (COUNT(DISTINCT ?member) AS ?member_count)
+        WHERE {
+            # Input UniProt accessions
+            VALUES (?ac) { ("P05067") ("P60709") ("Q05471")}
+            BIND (IRI(CONCAT("http://purl.uniprot.org/uniprot/", ?ac)) AS ?protein)
+
+            # ComplexPortal cross-reference for each input protein
+            ?protein a up:Protein ;
+                rdfs:seeAlso ?cp_db .
+            ?cp_db up:database <http://purl.uniprot.org/database/ComplexPortal> .
+            OPTIONAL { ?cp_db rdfs:comment ?cp_comment . }
+
+            # All member proteins of the same ComplexPortal complex
+            ?member a up:Protein ;
+                rdfs:seeAlso ?cp_db .
+        }
+        GROUP BY ?protein ?cp_db ?cp_comment
+        ORDER BY ?protein ?cp_db
+        LIMIT 500
+        ```
+
+    """
+    select_clause = dedent("""\
+        ?protein ?cp_db ?cp_comment
+        (GROUP_CONCAT(DISTINCT ?member; separator=",") AS ?complex_members)
+    """)
+    where_clause = dedent("""
+        # --- Complex Info ---
+        ?protein a up:Protein ;
+            rdfs:seeAlso ?cp_db .
+        ?cp_db up:database <http://purl.uniprot.org/database/ComplexPortal> .
+        OPTIONAL { ?cp_db rdfs:comment ?cp_comment . }
+        # All member proteins of the same ComplexPortal complex
+        ?member a up:Protein ;
+            rdfs:seeAlso ?cp_db .
+    """)
+    group_by = dedent("""
+        ?protein ?cp_db ?cp_comment
+    """)
+    return _build_sparql_generic_by_uniprot_accessions_query(
+        uniprot_accs, select_clause, where_clause, limit, groupby_clause=group_by
+    )
+
+
+@dataclass(frozen=True)
+class ComplexPortalEntry:
+    """A ComplexPortal entry.
+
+    Parameters:
+        query_protein: The UniProt accession used to find entry.
+        complex_id: The ComplexPortal identifier (for example "CPX-1234").
+        complex_url: The URL to the ComplexPortal entry.
+        complex_title: The title of the complex.
+        members: UniProt accessions which are members of the complex.
+    """
+
+    query_protein: str
+    complex_id: str
+    complex_url: str
+    complex_title: str
+    members: set[str]
+
+
+def _flatten_results_complex(raw_results) -> list[ComplexPortalEntry]:
+    results = []
+    for raw_result in raw_results:
+        query_protein = raw_result["protein"]["value"].split("/")[-1]
+        complex_id = raw_result["cp_db"]["value"].split("/")[-1]
+        complex_url = f"https://www.ebi.ac.uk/complexportal/complex/{complex_id}"
+        complex_title = raw_result.get("cp_comment", {}).get("value", "")
+        members = {m.split("/")[-1] for m in raw_result["complex_members"]["value"].split(",")}
+        results.append(
+            ComplexPortalEntry(
+                query_protein=query_protein,
+                complex_id=complex_id,
+                complex_url=complex_url,
+                complex_title=complex_title,
+                members=members,
+            )
+        )
+    return results
+
+
+def search4macromolecular_complexes(
+    uniprot_accs: Iterable[str], limit: int = 10_000, timeout: int = 1_800
+) -> list[ComplexPortalEntry]:
+    """Search for macromolecular complexes by UniProtKB accessions.
+
+    Queries for references to/from https://www.ebi.ac.uk/complexportal/ database in the Uniprot SPARQL endpoint.
+
+    Args:
+        uniprot_accs: UniProt accessions.
+        limit: Maximum number of results to return.
+        timeout: Timeout for the SPARQL query in seconds.
+
+    Returns:
+        List of ComplexPortalEntry objects.
+    """
+    sparql_query = _build_complex_sparql_query(uniprot_accs, limit)
+    logger.info("Executing SPARQL query for macromolecular complexes: %s", sparql_query)
+    raw_results = _execute_sparql_search(
+        sparql_query=sparql_query,
+        timeout=timeout,
+    )
+    limit_check("Search for complexes", limit, len(raw_results))
+    return _flatten_results_complex(raw_results)
+
+
+def search4interaction_partners(
+    uniprot_acc: str, excludes: set[str] | None = None, limit: int = 10_000, timeout: int = 1_800
+) -> dict[str, set[str]]:
+    """Search for interaction partners of a given UniProt accession using ComplexPortal database references.
+
+    Args:
+        uniprot_acc: UniProt accession to search interaction partners for.
+        excludes: Set of UniProt accessions to exclude from the results.
+            For example already known interaction partners.
+            If None then no complex members are excluded.
+        limit: Maximum number of results to return.
+        timeout: Timeout for the SPARQL query in seconds.
+
+    Returns:
+        Dictionary with UniProt accessions of interaction partners as keys and sets of ComplexPortal entry IDs
+        in which the interaction occurs as values.
+    """
+    ucomplexes = search4macromolecular_complexes([uniprot_acc], limit=limit, timeout=timeout)
+    hits: dict[str, set[str]] = {}
+    if excludes is None:
+        excludes = set()
+    for ucomplex in ucomplexes:
+        for member in ucomplex.members:
+            if member != uniprot_acc and member not in excludes:
+                if member not in hits:
+                    hits[member] = set()
+                hits[member].add(ucomplex.complex_id)
+    return hits
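
For readers skimming the diff, the new ComplexPortal helpers above can be exercised roughly as follows. This is a minimal sketch based only on the signatures shown in this hunk; the accession numbers are the same illustrative ones used in the SPARQL example docstring, and the import path assumes the module layout of this diff.

```python
# Minimal sketch of the new ComplexPortal helpers (assumes protein-quest >= 0.5.0).
from protein_quest.uniprot import (
    search4interaction_partners,
    search4macromolecular_complexes,
)

# Complexes that contain the query proteins (accessions are illustrative).
for entry in search4macromolecular_complexes(["P05067", "P60709"], limit=100):
    print(entry.complex_id, entry.complex_title, sorted(entry.members))

# Interaction partners of one protein, derived from shared complex membership.
partners = search4interaction_partners("P05067", excludes={"P60709"})
for partner_acc, complex_ids in partners.items():
    print(partner_acc, sorted(complex_ids))
```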
protein_quest/utils.py CHANGED
@@ -1,22 +1,260 @@
 """Module for functions that are used in multiple places."""
 
+import argparse
 import asyncio
+import hashlib
 import logging
 import shutil
-from collections.abc import Coroutine, Iterable
+from collections.abc import Coroutine, Iterable, Sequence
 from contextlib import asynccontextmanager
+from functools import lru_cache
 from pathlib import Path
 from textwrap import dedent
-from typing import Any, Literal, get_args
+from typing import Any, Literal, Protocol, get_args, runtime_checkable
 
 import aiofiles
+import aiofiles.os
 import aiohttp
+import rich
+from aiohttp.streams import AsyncStreamIterator
 from aiohttp_retry import ExponentialRetry, RetryClient
+from platformdirs import user_cache_dir
+from rich_argparse import ArgumentDefaultsRichHelpFormatter
 from tqdm.asyncio import tqdm
 from yarl import URL
 
 logger = logging.getLogger(__name__)
 
+CopyMethod = Literal["copy", "symlink", "hardlink"]
+"""Methods for copying files."""
+copy_methods = set(get_args(CopyMethod))
+"""Set of valid copy methods."""
+
+
+@lru_cache
+def _cache_sub_dir(root_cache_dir: Path, filename: str, hash_length: int = 4) -> Path:
+    """Get the cache sub-directory for a given path.
+
+    To not have too many files in a single directory,
+    we create sub-directories based on the hash of the filename.
+
+    Args:
+        root_cache_dir: The root directory for the cache.
+        filename: The filename to be cached.
+        hash_length: The length of the hash to use for the sub-directory.
+
+    Returns:
+        The parent path to the cached file.
+    """
+    full_hash = hashlib.blake2b(filename.encode("utf-8")).hexdigest()
+    cache_sub_dir = full_hash[:hash_length]
+    cache_sub_dir_path = root_cache_dir / cache_sub_dir
+    cache_sub_dir_path.mkdir(parents=True, exist_ok=True)
+    return cache_sub_dir_path
+
+
+@runtime_checkable
+class Cacher(Protocol):
+    """Protocol for a cacher."""
+
+    def __contains__(self, item: str | Path) -> bool:
+        """Check if a file is in the cache.
+
+        Args:
+            item: The filename or Path to check.
+
+        Returns:
+            True if the file is in the cache, False otherwise.
+        """
+        ...
+
+    async def copy_from_cache(self, target: Path) -> Path | None:
+        """Copy a file from the cache to a target location if it exists in the cache.
+
+        Assumes:
+
+        - target does not exist.
+        - the parent directory of target exists.
+
+        Args:
+            target: The path to copy the file to.
+
+        Returns:
+            The path to the cached file if it was copied, None otherwise.
+        """
+        ...
+
+    async def write_iter(self, target: Path, content: AsyncStreamIterator[bytes]) -> Path:
+        """Write content to a file and cache it.
+
+        Args:
+            target: The path to write the content to.
+            content: An async iterator that yields bytes to write to the file.
+
+        Returns:
+            The path to the cached file.
+
+        Raises:
+            FileExistsError: If the target file already exists.
+        """
+        ...
+
+    async def write_bytes(self, target: Path, content: bytes) -> Path:
+        """Write bytes to a file and cache it.
+
+        Args:
+            target: The path to write the content to.
+            content: The bytes to write to the file.
+
+        Returns:
+            The path to the cached file.
+
+        Raises:
+            FileExistsError: If the target file already exists.
+        """
+        ...
+
+
+class PassthroughCacher(Cacher):
+    """A cacher that caches nothing.
+
+    On writes it just writes to the target path.
+    """
+
+    def __contains__(self, item: str | Path) -> bool:
+        # We don't have anything cached ever
+        return False
+
+    async def copy_from_cache(self, target: Path) -> Path | None:  # noqa: ARG002
+        # We don't have anything cached ever
+        return None
+
+    async def write_iter(self, target: Path, content: AsyncStreamIterator[bytes]) -> Path:
+        if target.exists():
+            raise FileExistsError(target)
+        target.write_bytes(b"".join([chunk async for chunk in content]))
+        return target
+
+    async def write_bytes(self, target: Path, content: bytes) -> Path:
+        if target.exists():
+            raise FileExistsError(target)
+        target.write_bytes(content)
+        return target
+
+
+def user_cache_root_dir() -> Path:
+    """Get the users root directory for caching files.
+
+    Returns:
+        The path to the user's cache directory for protein-quest.
+    """
+    return Path(user_cache_dir("protein-quest"))
+
+
+class DirectoryCacher(Cacher):
+    """Class to cache files in a directory.
+
+    Caching logic is based on the file name only.
+    If file name of paths are the same then the files are considered the same.
+
+    Attributes:
+        cache_dir: The directory to use for caching.
+        copy_method: The method to use for copying files.
+    """
+
+    def __init__(
+        self,
+        cache_dir: Path | None = None,
+        copy_method: CopyMethod = "hardlink",
+    ) -> None:
+        """Initialize the cacher.
+
+        If file name of paths are the same then the files are considered the same.
+
+        Args:
+            cache_dir: The directory to use for caching.
+                If None, a default cache directory (~/.cache/protein-quest) is used.
+            copy_method: The method to use for copying.
+        """
+        if cache_dir is None:
+            cache_dir = user_cache_root_dir()
+        self.cache_dir: Path = cache_dir
+        self.cache_dir.mkdir(parents=True, exist_ok=True)
+        if copy_method == "copy":
+            logger.warning(
+                "Using copy as copy_method to cache files is not recommended. "
+                "This will use more disk space and be slower than symlink or hardlink."
+            )
+        if copy_method not in copy_methods:
+            msg = f"Unknown copy method: {copy_method}. Must be one of {copy_methods}."
+            raise ValueError(msg)
+        self.copy_method: CopyMethod = copy_method
+
+    def __contains__(self, item: str | Path) -> bool:
+        cached_file = self._as_cached_path(item)
+        return cached_file.exists()
+
+    def _as_cached_path(self, item: str | Path) -> Path:
+        file_name = item.name if isinstance(item, Path) else item
+        cache_sub_dir = _cache_sub_dir(self.cache_dir, file_name)
+        return cache_sub_dir / file_name
+
+    async def copy_from_cache(self, target: Path) -> Path | None:
+        cached_file = self._as_cached_path(target.name)
+        exists = await aiofiles.os.path.exists(str(cached_file))
+        if exists:
+            await async_copyfile(cached_file, target, copy_method=self.copy_method)
+            return cached_file
+        return None
+
+    async def write_iter(self, target: Path, content: AsyncStreamIterator[bytes]) -> Path:
+        cached_file = self._as_cached_path(target.name)
+        # Write file to cache dir
+        async with aiofiles.open(cached_file, "xb") as f:
+            async for chunk in content:
+                await f.write(chunk)
+        # Copy to target location
+        await async_copyfile(cached_file, target, copy_method=self.copy_method)
+        return cached_file
+
+    async def write_bytes(self, target: Path, content: bytes) -> Path:
+        cached_file = self._as_cached_path(target.name)
+        # Write file to cache dir
+        async with aiofiles.open(cached_file, "xb") as f:
+            await f.write(content)
+        # Copy to target location
+        await async_copyfile(cached_file, target, copy_method=self.copy_method)
+        return cached_file
+
+    def populate_cache(self, source_dir: Path) -> dict[Path, Path]:
+        """Populate the cache from an existing directory.
+
+        This will copy all files from the source directory to the cache directory.
+        If a file with the same name already exists in the cache, it will be skipped.
+
+        Args:
+            source_dir: The directory to populate the cache from.
+
+        Returns:
+            A dictionary mapping source file paths to their cached paths.
+
+        Raises:
+            NotADirectoryError: If the source_dir is not a directory.
+        """
+        if not source_dir.is_dir():
+            raise NotADirectoryError(source_dir)
+        cached = {}
+        for file_path in source_dir.iterdir():
+            if not file_path.is_file():
+                continue
+            cached_path = self._as_cached_path(file_path.name)
+            if cached_path.exists():
+                logger.debug(f"File {file_path.name} already in cache. Skipping.")
+                continue
+            copyfile(file_path, cached_path, copy_method=self.copy_method)
+            cached[file_path] = cached_path
+        return cached
+
 
 async def retrieve_files(
     urls: Iterable[tuple[URL | str, str]],
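
The caching classes added above can also be used on their own, for example to pre-seed the cache from a directory of previously downloaded files. A minimal sketch, assuming only the names introduced in this hunk; the paths and file name are illustrative.

```python
# Minimal sketch of DirectoryCacher usage (paths are illustrative).
from pathlib import Path

from protein_quest.utils import DirectoryCacher

# Cache under the default user cache dir, hardlinking files into place.
cacher = DirectoryCacher(copy_method="hardlink")

# Pre-seed the cache from an existing download directory.
already_downloaded = Path("old_downloads")
cached = cacher.populate_cache(already_downloaded)
print(f"Added {len(cached)} files to {cacher.cache_dir}")

# Membership is keyed on file name only.
print("1abc.cif" in cacher)
```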
@@ -25,6 +263,8 @@ async def retrieve_files(
     retries: int = 3,
     total_timeout: int = 300,
     desc: str = "Downloading files",
+    cacher: Cacher | None = None,
+    chunk_size: int = 524288,  # 512 KiB
 ) -> list[Path]:
     """Retrieve files from a list of URLs and save them to a directory.
 
@@ -35,6 +275,8 @@ async def retrieve_files(
         retries: The number of times to retry a failed download.
         total_timeout: The total timeout for a download in seconds.
         desc: Description for the progress bar.
+        cacher: An optional cacher to use for caching files.
+        chunk_size: The size of each chunk to read from the response.
 
     Returns:
         A list of paths to the downloaded files.
@@ -42,7 +284,17 @@ async def retrieve_files(
     save_dir.mkdir(parents=True, exist_ok=True)
     semaphore = asyncio.Semaphore(max_parallel_downloads)
     async with friendly_session(retries, total_timeout) as session:
-        tasks = [_retrieve_file(session, url, save_dir / filename, semaphore) for url, filename in urls]
+        tasks = [
+            _retrieve_file(
+                session=session,
+                url=url,
+                save_path=save_dir / filename,
+                semaphore=semaphore,
+                cacher=cacher,
+                chunk_size=chunk_size,
+            )
+            for url, filename in urls
+        ]
         files: list[Path] = await tqdm.gather(*tasks, desc=desc)
     return files
 
@@ -52,8 +304,8 @@ async def _retrieve_file(
     url: URL | str,
     save_path: Path,
     semaphore: asyncio.Semaphore,
-    ovewrite: bool = False,
-    chunk_size: int = 131072,  # 128 KiB
+    cacher: Cacher | None = None,
+    chunk_size: int = 524288,  # 512 KiB
 ) -> Path:
     """Retrieve a single file from a URL and save it to a specified path.
 
@@ -62,26 +314,28 @@ async def _retrieve_file(
         url: The URL to download the file from.
         save_path: The path where the file should be saved.
         semaphore: A semaphore to limit the number of concurrent downloads.
-        ovewrite: Whether to overwrite the file if it already exists.
+        cacher: An optional cacher to use for caching files.
         chunk_size: The size of each chunk to read from the response.
 
     Returns:
         The path to the saved file.
     """
     if save_path.exists():
-        if ovewrite:
-            save_path.unlink()
-        else:
-            logger.debug(f"File {save_path} already exists. Skipping download from {url}.")
-            return save_path
+        logger.debug(f"File {save_path} already exists. Skipping download from {url}.")
+        return save_path
+
+    if cacher is None:
+        cacher = PassthroughCacher()
+    if cached_file := await cacher.copy_from_cache(save_path):
+        logger.debug(f"File {save_path} was copied from cache {cached_file}. Skipping download from {url}.")
+        return save_path
+
     async with (
         semaphore,
-        aiofiles.open(save_path, "xb") as f,
         session.get(url) as resp,
     ):
         resp.raise_for_status()
-        async for chunk in resp.content.iter_chunked(chunk_size):
-            await f.write(chunk)
+        await cacher.write_iter(save_path, resp.content.iter_chunked(chunk_size))
     return save_path
 
 
@@ -141,27 +395,117 @@ def run_async[R](coroutine: Coroutine[Any, Any, R]) -> R:
         raise NestedAsyncIOLoopError from e
 
 
-CopyMethod = Literal["copy", "symlink"]
-copy_methods = set(get_args(CopyMethod))
-
-
 def copyfile(source: Path, target: Path, copy_method: CopyMethod = "copy"):
-    """Make target path be same file as source by either copying or symlinking.
+    """Make target path be same file as source by either copying or symlinking or hardlinking.
+
+    Note that the hardlink copy method only works within the same filesystem and is harder to track.
+    If you want to track cached files easily then use 'symlink'.
+    On Windows you need developer mode or admin privileges to create symlinks.
 
     Args:
-        source: The source file to copy or symlink.
+        source: The source file to copy or link.
         target: The target file to create.
         copy_method: The method to use for copying.
 
     Raises:
         FileNotFoundError: If the source file or parent of target does not exist.
-        ValueError: If the method is not "copy" or "symlink".
+        FileExistsError: If the target file already exists.
+        ValueError: If an unknown copy method is provided.
     """
     if copy_method == "copy":
         shutil.copyfile(source, target)
     elif copy_method == "symlink":
-        rel_source = source.relative_to(target.parent, walk_up=True)
+        rel_source = source.absolute().relative_to(target.parent.absolute(), walk_up=True)
         target.symlink_to(rel_source)
+    elif copy_method == "hardlink":
+        target.hardlink_to(source)
     else:
-        msg = f"Unknown method: {copy_method}"
+        msg = f"Unknown method: {copy_method}. Valid methods are: {copy_methods}"
         raise ValueError(msg)
+
+
+async def async_copyfile(
+    source: Path,
+    target: Path,
+    copy_method: CopyMethod = "copy",
+):
+    """Asynchronously make target path be same file as source by either copying or symlinking or hardlinking.
+
+    Note that the hardlink copy method only works within the same filesystem and is harder to track.
+    If you want to track cached files easily then use 'symlink'.
+    On Windows you need developer mode or admin privileges to create symlinks.
+
+    Args:
+        source: The source file to copy.
+        target: The target file to create.
+        copy_method: The method to use for copying.
+
+    Raises:
+        FileNotFoundError: If the source file or parent of target does not exist.
+        FileExistsError: If the target file already exists.
+        ValueError: If an unknown copy method is provided.
+    """
+    if copy_method == "copy":
+        # Could use loop of chunks with aiofiles,
+        # but shutil is ~1.9x faster on my machine
+        # due to fastcopy and sendfile optimizations in shutil.
+        await asyncio.to_thread(shutil.copyfile, source, target)
+    elif copy_method == "symlink":
+        rel_source = source.relative_to(target.parent, walk_up=True)
+        await aiofiles.os.symlink(str(rel_source), str(target))
+    elif copy_method == "hardlink":
+        await aiofiles.os.link(str(source), str(target))
+    else:
+        msg = f"Unknown method: {copy_method}. Valid methods are: {copy_methods}"
+        raise ValueError(msg)
+
+
+def populate_cache_command(raw_args: Sequence[str] | None = None):
+    """Command line interface to populate the cache from an existing directory.
+
+    Can be called from the command line as:
+
+    ```bash
+    python3 -m protein_quest.utils populate-cache /path/to/source/dir
+    ```
+
+    Args:
+        raw_args: The raw command line arguments to parse. If None, uses sys.argv.
+    """
+    root_parser = argparse.ArgumentParser(formatter_class=ArgumentDefaultsRichHelpFormatter)
+    subparsers = root_parser.add_subparsers(dest="command")
+
+    desc = "Populate the cache directory with files from the source directory."
+    populate_cache_parser = subparsers.add_parser(
+        "populate-cache",
+        help=desc,
+        description=desc,
+        formatter_class=ArgumentDefaultsRichHelpFormatter,
+    )
+    populate_cache_parser.add_argument("source_dir", type=Path)
+    populate_cache_parser.add_argument(
+        "--cache-dir",
+        type=Path,
+        default=user_cache_root_dir(),
+        help="Directory to use for caching. If not provided, a default cache directory is used.",
+    )
+    populate_cache_parser.add_argument(
+        "--copy-method",
+        type=str,
+        default="hardlink",
+        choices=copy_methods,
+        help="Method to use for copying files to cache.",
+    )
+
+    args = root_parser.parse_args(raw_args)
+    if args.command == "populate-cache":
+        source_dir = args.source_dir
+        cacher = DirectoryCacher(cache_dir=args.cache_dir, copy_method=args.copy_method)
+        cached_files = cacher.populate_cache(source_dir)
+        rich.print(f"Cached {len(cached_files)} files from {source_dir} to {cacher.cache_dir}")
+        for src, cached in cached_files.items():
+            rich.print(f"- {src.relative_to(source_dir)} -> {cached.relative_to(cacher.cache_dir)}")
+
+
+if __name__ == "__main__":
+    populate_cache_command()
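
Taken together, the download path in retrieve_files now accepts an optional cacher, so repeated downloads of the same file name can be served from the cache instead of the network. A minimal sketch of how these pieces fit, assuming the signatures shown in this diff (the save_dir keyword follows the body shown above; the URL and paths are illustrative):

```python
# Minimal sketch of retrieve_files with a DirectoryCacher (URL and paths are illustrative).
import asyncio
from pathlib import Path

from protein_quest.utils import DirectoryCacher, retrieve_files


async def main() -> None:
    # Symlink cached files into the save directory so cache hits are easy to spot.
    cacher = DirectoryCacher(copy_method="symlink")
    urls = [
        ("https://example.org/structures/1abc.cif", "1abc.cif"),
    ]
    files = await retrieve_files(urls, save_dir=Path("downloads"), cacher=cacher)
    print(files)


asyncio.run(main())
```

The same cache can also be pre-populated from the command line via `python3 -m protein_quest.utils populate-cache /path/to/source/dir`, as documented in the populate_cache_command docstring above.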