protein-quest 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
protein_quest/utils.py ADDED
@@ -0,0 +1,547 @@
1
+ """Module for functions that are used in multiple places."""
2
+
3
+ import argparse
4
+ import asyncio
5
+ import hashlib
6
+ import logging
7
+ import shutil
8
+ from collections.abc import Coroutine, Iterable, Sequence
9
+ from contextlib import asynccontextmanager
10
+ from functools import lru_cache
11
+ from pathlib import Path
12
+ from textwrap import dedent
13
+ from typing import Any, Literal, Protocol, get_args, runtime_checkable
14
+
15
+ import aiofiles
16
+ import aiofiles.os
17
+ import aiohttp
18
+ import rich
19
+ from aiohttp.streams import AsyncStreamIterator
20
+ from aiohttp_retry import ExponentialRetry, RetryClient
21
+ from platformdirs import user_cache_dir
22
+ from rich_argparse import ArgumentDefaultsRichHelpFormatter
23
+ from tqdm.asyncio import tqdm
24
+ from yarl import URL
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+ CopyMethod = Literal["copy", "symlink", "hardlink"]
29
+ """Methods for copying files."""
30
+ copy_methods = set(get_args(CopyMethod))
31
+ """Set of valid copy methods."""
32
+
33
+
34
+ @lru_cache
35
+ def _cache_sub_dir(root_cache_dir: Path, filename: str, hash_length: int = 4) -> Path:
36
+ """Get the cache sub-directory for a given path.
37
+
38
+ To not have too many files in a single directory,
39
+ we create sub-directories based on the hash of the filename.
40
+
41
+ Args:
42
+ root_cache_dir: The root directory for the cache.
43
+ filename: The filename to be cached.
44
+ hash_length: The length of the hash to use for the sub-directory.
45
+
46
+ Returns:
47
+ The parent path to the cached file.
48
+ """
49
+ full_hash = hashlib.blake2b(filename.encode("utf-8")).hexdigest()
50
+ cache_sub_dir = full_hash[:hash_length]
51
+ cache_sub_dir_path = root_cache_dir / cache_sub_dir
52
+ cache_sub_dir_path.mkdir(parents=True, exist_ok=True)
53
+ return cache_sub_dir_path
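To illustrate the fan-out scheme described above: the sub-directory is simply the first `hash_length` hex characters of the BLAKE2b digest of the file name. A minimal sketch, assuming a hypothetical file name:

```python
import hashlib

from protein_quest.utils import user_cache_root_dir

filename = "1abc.cif"  # hypothetical file name
# First 4 hex characters of the BLAKE2b digest pick the cache bucket,
# mirroring _cache_sub_dir (which additionally creates the directory).
bucket = hashlib.blake2b(filename.encode("utf-8")).hexdigest()[:4]
print(user_cache_root_dir() / bucket / filename)
```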
54
+
55
+
56
+ @runtime_checkable
57
+ class Cacher(Protocol):
58
+ """Protocol for a cacher."""
59
+
60
+ def __contains__(self, item: str | Path) -> bool:
61
+ """Check if a file is in the cache.
62
+
63
+ Args:
64
+ item: The filename or Path to check.
65
+
66
+ Returns:
67
+ True if the file is in the cache, False otherwise.
68
+ """
69
+ ...
70
+
71
+ async def copy_from_cache(self, target: Path) -> Path | None:
72
+ """Copy a file from the cache to a target location if it exists in the cache.
73
+
74
+ Assumes:
75
+
76
+ - target does not exist.
77
+ - the parent directory of target exists.
78
+
79
+ Args:
80
+ target: The path to copy the file to.
81
+
82
+ Returns:
83
+ The path to the cached file if it was copied, None otherwise.
84
+ """
85
+ ...
86
+
87
+ async def write_iter(self, target: Path, content: AsyncStreamIterator[bytes]) -> Path:
88
+ """Write content to a file and cache it.
89
+
90
+ Args:
91
+ target: The path to write the content to.
92
+ content: An async iterator that yields bytes to write to the file.
93
+
94
+ Returns:
95
+ The path to the cached file.
96
+
97
+ Raises:
98
+ FileExistsError: If the target file already exists.
99
+ """
100
+ ...
101
+
102
+ async def write_bytes(self, target: Path, content: bytes) -> Path:
103
+ """Write bytes to a file and cache it.
104
+
105
+ Args:
106
+ target: The path to write the content to.
107
+ content: The bytes to write to the file.
108
+
109
+ Returns:
110
+ The path to the cached file.
111
+
112
+ Raises:
113
+ FileExistsError: If the target file already exists.
114
+ """
115
+ ...
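Since `Cacher` is a runtime-checkable `Protocol`, any object providing these four methods satisfies it. A minimal in-memory sketch (illustration only, not part of the package):

```python
from pathlib import Path

from aiohttp.streams import AsyncStreamIterator


class InMemoryCacher:
    """Toy cacher that keeps file contents in a dict keyed by file name."""

    def __init__(self) -> None:
        self._blobs: dict[str, bytes] = {}

    def __contains__(self, item: str | Path) -> bool:
        name = item.name if isinstance(item, Path) else item
        return name in self._blobs

    async def copy_from_cache(self, target: Path) -> Path | None:
        blob = self._blobs.get(target.name)
        if blob is None:
            return None
        # No on-disk cache path exists here, so return the freshly written target.
        target.write_bytes(blob)
        return target

    async def write_iter(self, target: Path, content: AsyncStreamIterator[bytes]) -> Path:
        data = b"".join([chunk async for chunk in content])
        return await self.write_bytes(target, data)

    async def write_bytes(self, target: Path, content: bytes) -> Path:
        if target.exists():
            raise FileExistsError(target)
        self._blobs[target.name] = content
        target.write_bytes(content)
        return target


# isinstance() works because Cacher is @runtime_checkable:
# from protein_quest.utils import Cacher; assert isinstance(InMemoryCacher(), Cacher)
```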
116
+
117
+
118
+ class PassthroughCacher(Cacher):
119
+ """A cacher that caches nothing.
120
+
121
+ On writes it just writes to the target path.
122
+ """
123
+
124
+ def __contains__(self, item: str | Path) -> bool:
125
+ # We don't have anything cached ever
126
+ return False
127
+
128
+ async def copy_from_cache(self, target: Path) -> Path | None: # noqa: ARG002
129
+ # We don't have anything cached ever
130
+ return None
131
+
132
+ async def write_iter(self, target: Path, content: AsyncStreamIterator[bytes]) -> Path:
133
+ if target.exists():
134
+ raise FileExistsError(target)
135
+ target.write_bytes(b"".join([chunk async for chunk in content]))
136
+ return target
137
+
138
+ async def write_bytes(self, target: Path, content: bytes) -> Path:
139
+ if target.exists():
140
+ raise FileExistsError(target)
141
+ target.write_bytes(content)
142
+ return target
143
+
144
+
145
+ def user_cache_root_dir() -> Path:
146
+ """Get the user's root directory for caching files.
147
+
148
+ Returns:
149
+ The path to the user's cache directory for protein-quest.
150
+ """
151
+ return Path(user_cache_dir("protein-quest"))
152
+
153
+
154
+ class DirectoryCacher(Cacher):
155
+ """Class to cache files in a directory.
156
+
157
+ Caching logic is based on the file name only.
158
+ Paths with the same file name are considered to be the same file.
159
+
160
+ Attributes:
161
+ cache_dir: The directory to use for caching.
162
+ copy_method: The method to use for copying files.
163
+ """
164
+
165
+ def __init__(
166
+ self,
167
+ cache_dir: Path | None = None,
168
+ copy_method: CopyMethod = "hardlink",
169
+ ) -> None:
170
+ """Initialize the cacher.
171
+
172
+ Paths with the same file name are considered to be the same file.
173
+
174
+ Args:
175
+ cache_dir: The directory to use for caching.
176
+ If None, a default cache directory (~/.cache/protein-quest) is used.
177
+ copy_method: The method to use for copying.
178
+ """
179
+ if cache_dir is None:
180
+ cache_dir = user_cache_root_dir()
181
+ self.cache_dir: Path = cache_dir
182
+ self.cache_dir.mkdir(parents=True, exist_ok=True)
183
+ if copy_method == "copy":
184
+ logger.warning(
185
+ "Using copy as copy_method to cache files is not recommended. "
186
+ "This will use more disk space and be slower than symlink or hardlink."
187
+ )
188
+ if copy_method not in copy_methods:
189
+ msg = f"Unknown copy method: {copy_method}. Must be one of {copy_methods}."
190
+ raise ValueError(msg)
191
+ self.copy_method: CopyMethod = copy_method
192
+
193
+ def __contains__(self, item: str | Path) -> bool:
194
+ cached_file = self._as_cached_path(item)
195
+ return cached_file.exists()
196
+
197
+ def _as_cached_path(self, item: str | Path) -> Path:
198
+ file_name = item.name if isinstance(item, Path) else item
199
+ cache_sub_dir = _cache_sub_dir(self.cache_dir, file_name)
200
+ return cache_sub_dir / file_name
201
+
202
+ async def copy_from_cache(self, target: Path) -> Path | None:
203
+ cached_file = self._as_cached_path(target.name)
204
+ exists = await aiofiles.os.path.exists(str(cached_file))
205
+ if exists:
206
+ await async_copyfile(cached_file, target, copy_method=self.copy_method)
207
+ return cached_file
208
+ return None
209
+
210
+ async def write_iter(self, target: Path, content: AsyncStreamIterator[bytes]) -> Path:
211
+ cached_file = self._as_cached_path(target.name)
212
+ # Write file to cache dir
213
+ async with aiofiles.open(cached_file, "xb") as f:
214
+ async for chunk in content:
215
+ await f.write(chunk)
216
+ # Copy to target location
217
+ await async_copyfile(cached_file, target, copy_method=self.copy_method)
218
+ return cached_file
219
+
220
+ async def write_bytes(self, target: Path, content: bytes) -> Path:
221
+ cached_file = self._as_cached_path(target.name)
222
+ # Write file to cache dir
223
+ async with aiofiles.open(cached_file, "xb") as f:
224
+ await f.write(content)
225
+ # Copy to target location
226
+ await async_copyfile(cached_file, target, copy_method=self.copy_method)
227
+ return cached_file
228
+
229
+ def populate_cache(self, source_dir: Path) -> dict[Path, Path]:
230
+ """Populate the cache from an existing directory.
231
+
232
+ This will copy all files from the source directory to the cache directory.
233
+ If a file with the same name already exists in the cache, it will be skipped.
234
+
235
+ Args:
236
+ source_dir: The directory to populate the cache from.
237
+
238
+ Returns:
239
+ A dictionary mapping source file paths to their cached paths.
240
+
241
+ Raises:
242
+ NotADirectoryError: If the source_dir is not a directory.
243
+ """
244
+ if not source_dir.is_dir():
245
+ raise NotADirectoryError(source_dir)
246
+ cached = {}
247
+ for file_path in source_dir.iterdir():
248
+ if not file_path.is_file():
249
+ continue
250
+ cached_path = self._as_cached_path(file_path.name)
251
+ if cached_path.exists():
252
+ logger.debug(f"File {file_path.name} already in cache. Skipping.")
253
+ continue
254
+ copyfile(file_path, cached_path, copy_method=self.copy_method)
255
+ cached[file_path] = cached_path
256
+ return cached
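A short usage sketch of DirectoryCacher (the directories and file name below are hypothetical): a file written through the cacher lands in the cache directory once and is then linked into each target location, so a later run asking for the same file name can skip the download.

```python
import asyncio
from pathlib import Path

from protein_quest.utils import DirectoryCacher


async def main() -> None:
    cacher = DirectoryCacher(cache_dir=Path("/tmp/pq-cache"), copy_method="symlink")

    target = Path("/tmp/run1/1abc.cif")
    target.parent.mkdir(parents=True, exist_ok=True)
    # Written once into the cache, then symlinked to the target.
    await cacher.write_bytes(target, b"fake mmCIF content")

    other = Path("/tmp/run2/1abc.cif")
    other.parent.mkdir(parents=True, exist_ok=True)
    if "1abc.cif" in cacher:
        # Reuse the cached copy instead of re-downloading.
        await cacher.copy_from_cache(other)


asyncio.run(main())
```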
257
+
258
+
259
+ async def retrieve_files(
260
+ urls: Iterable[tuple[URL | str, str]],
261
+ save_dir: Path,
262
+ max_parallel_downloads: int = 5,
263
+ retries: int = 3,
264
+ total_timeout: int = 300,
265
+ desc: str = "Downloading files",
266
+ cacher: Cacher | None = None,
267
+ chunk_size: int = 524288, # 512 KiB
268
+ gzip_files: bool = False,
269
+ raise_for_not_found: bool = True,
270
+ ) -> list[Path]:
271
+ """Retrieve files from a list of URLs and save them to a directory.
272
+
273
+ Args:
274
+ urls: A list of tuples, where each tuple contains a URL and a filename.
275
+ save_dir: The directory to save the downloaded files to.
276
+ max_parallel_downloads: The maximum number of files to download in parallel.
277
+ retries: The number of times to retry a failed download.
278
+ total_timeout: The total timeout for a download in seconds.
279
+ desc: Description for the progress bar.
280
+ cacher: An optional cacher to use for caching files.
281
+ chunk_size: The size of each chunk to read from the response.
282
+ gzip_files: Whether to gzip the downloaded files.
283
+ This requires that the server can send gzip-encoded content.
284
+ raise_for_not_found: Whether to raise an error for HTTP 404 errors.
285
+ If false, URLs that return HTTP 404 are logged at debug level and no path is returned for them.
286
+
287
+ Returns:
288
+ A list of paths to the downloaded files.
289
+ """
290
+ save_dir.mkdir(parents=True, exist_ok=True)
291
+ semaphore = asyncio.Semaphore(max_parallel_downloads)
292
+ async with friendly_session(retries, total_timeout) as session:
293
+ tasks = [
294
+ _retrieve_file(
295
+ session=session,
296
+ url=url,
297
+ save_path=save_dir / filename,
298
+ semaphore=semaphore,
299
+ cacher=cacher,
300
+ chunk_size=chunk_size,
301
+ gzip_files=gzip_files,
302
+ raise_for_not_found=raise_for_not_found,
303
+ )
304
+ for url, filename in urls
305
+ ]
306
+ raw_files: list[Path | None] = await tqdm.gather(*tasks, desc=desc)
307
+ return [f for f in raw_files if f is not None]
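A rough usage sketch of retrieve_files driven from synchronous code via run_async (the URLs, file names and directory are placeholders, not endpoints this package guarantees):

```python
from pathlib import Path

from protein_quest.utils import DirectoryCacher, retrieve_files, run_async

# Placeholder URLs and file names for illustration only.
urls = [
    ("https://example.org/files/entry1.cif", "entry1.cif"),
    ("https://example.org/files/entry2.cif", "entry2.cif"),
]
downloaded = run_async(
    retrieve_files(
        urls,
        save_dir=Path("downloads"),
        max_parallel_downloads=3,
        cacher=DirectoryCacher(),
        raise_for_not_found=False,  # log and skip 404s instead of raising
    )
)
print(f"Retrieved {len(downloaded)} files")
```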
308
+
309
+
310
+ class InvalidContentEncodingError(aiohttp.ClientResponseError):
311
+ """Content encoding is invalid."""
312
+
313
+
314
+ async def _retrieve_file(
315
+ session: RetryClient,
316
+ url: URL | str,
317
+ save_path: Path,
318
+ semaphore: asyncio.Semaphore,
319
+ cacher: Cacher | None = None,
320
+ chunk_size: int = 524288, # 512 KiB
321
+ gzip_files: bool = False,
322
+ raise_for_not_found=True,
323
+ ) -> Path | None:
324
+ """Retrieve a single file from a URL and save it to a specified path.
325
+
326
+ Args:
327
+ session: The aiohttp session to use for the request.
328
+ url: The URL to download the file from.
329
+ save_path: The path where the file should be saved.
330
+ semaphore: A semaphore to limit the number of concurrent downloads.
331
+ cacher: An optional cacher to use for caching files.
332
+ chunk_size: The size of each chunk to read from the response.
333
+ gzip_files: Whether to gzip the downloaded file.
334
+ This requires that the server can send gzip-encoded content.
335
+ raise_for_not_found: Whether to raise an error for HTTP 404 errors.
336
+ If false, the function returns None on HTTP 404 errors and logs a debug message.
337
+
338
+ Returns:
339
+ The path to the saved file, or None if the file was not found and raise_for_not_found is false.
340
+ """
341
+ if save_path.exists():
342
+ logger.debug(f"File {save_path} already exists. Skipping download from {url}.")
343
+ return save_path
344
+
345
+ if cacher is None:
346
+ cacher = PassthroughCacher()
347
+ if cached_file := await cacher.copy_from_cache(save_path):
348
+ logger.debug(f"File {save_path} was copied from cache {cached_file}. Skipping download from {url}.")
349
+ return save_path
350
+
351
+ # Alphafold server and many other web servers can return gzipped responses,
352
+ # when we want to save as *.gz, we use raw stream
353
+ # otherwise aiohttp will decompress it automatically for us.
354
+ auto_decompress = not gzip_files
355
+ headers = {"Accept-Encoding": "gzip"}
356
+ async with (
357
+ semaphore,
358
+ session.get(url, headers=headers, auto_decompress=auto_decompress) as resp,
359
+ ):
360
+ if not raise_for_not_found and resp.status == 404:
361
+ logger.debug(f"File not found at {url}, skipping download.")
362
+ return None
363
+ resp.raise_for_status()
364
+ if gzip_files and resp.headers.get("Content-Encoding") != "gzip":
365
+ msg = f"Server did not send gzip-encoded content for {url}, cannot save as a gzipped file."
366
+ raise InvalidContentEncodingError(
367
+ request_info=resp.request_info,
368
+ history=resp.history,
369
+ status=415,
370
+ message=msg,
371
+ headers=resp.headers,
372
+ )
373
+ iterator = resp.content.iter_chunked(chunk_size)
374
+ await cacher.write_iter(save_path, iterator)
375
+ return save_path
376
+
377
+
378
+ @asynccontextmanager
379
+ async def friendly_session(retries: int = 3, total_timeout: int = 300):
380
+ """Create an aiohttp session with retry capabilities.
381
+
382
+ Examples:
383
+ Use as async context:
384
+
385
+ >>> async with friendly_session(retries=5, total_timeout=60) as session:
386
+ >>> r = await session.get("https://example.com/api/data")
387
+ >>> print(r)
388
+ <ClientResponse(https://example.com/api/data) [404 Not Found]>
389
+ <CIMultiDictProxy('Accept-Ranges': 'bytes', ...
390
+
391
+ Args:
392
+ retries: The number of retry attempts for failed requests.
393
+ total_timeout: The total timeout for a request in seconds.
394
+ """
395
+ retry_options = ExponentialRetry(attempts=retries)
396
+ timeout = aiohttp.ClientTimeout(total=total_timeout) # pyrefly: ignore false positive
397
+ async with aiohttp.ClientSession(timeout=timeout) as session:
398
+ client = RetryClient(client_session=session, retry_options=retry_options)
399
+ yield client
400
+
401
+
402
+ class NestedAsyncIOLoopError(RuntimeError):
403
+ """Custom error for nested async I/O loops."""
404
+
405
+ def __init__(self) -> None:
406
+ msg = dedent("""\
407
+ Cannot run an async method from an environment where the asyncio event loop is already running,
408
+ such as a Jupyter notebook.
409
+
410
+ Please use the async function directly or
411
+ call `import nest_asyncio; nest_asyncio.apply()` and try again.
412
+ """)
413
+ super().__init__(msg)
414
+
415
+
416
+ def run_async[R](coroutine: Coroutine[Any, Any, R]) -> R:
417
+ """Run an async coroutine, with a nicer error when an event loop is already running.
418
+
419
+ Args:
420
+ coroutine: The async coroutine to run.
421
+
422
+ Returns:
423
+ The result of the coroutine.
424
+
425
+ Raises:
426
+ NestedAsyncIOLoopError: If called from a nested async I/O loop like in a Jupyter notebook.
427
+ """
428
+ try:
429
+ return asyncio.run(coroutine)
430
+ except RuntimeError as e:
431
+ raise NestedAsyncIOLoopError from e
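A minimal sketch of run_async wrapping an async helper built on friendly_session (the URL is a placeholder). In a plain script this works as-is; inside a Jupyter notebook it raises NestedAsyncIOLoopError unless nest_asyncio.apply() has been called first:

```python
from protein_quest.utils import friendly_session, run_async


async def fetch_status(url: str) -> int:
    async with friendly_session(retries=2, total_timeout=30) as session:
        async with session.get(url) as resp:
            return resp.status


print(run_async(fetch_status("https://example.org")))  # placeholder URL
```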
432
+
433
+
434
+ def copyfile(source: Path, target: Path, copy_method: CopyMethod = "copy"):
435
+ """Make the target path the same file as the source by copying, symlinking, or hardlinking.
436
+
437
+ Note that the hardlink copy method only works within the same filesystem and is harder to track.
438
+ If you want to track cached files easily, use 'symlink'.
439
+ On Windows you need developer mode or admin privileges to create symlinks.
440
+
441
+ Args:
442
+ source: The source file to copy or link.
443
+ target: The target file to create.
444
+ copy_method: The method to use for copying.
445
+
446
+ Raises:
447
+ FileNotFoundError: If the source file or parent of target does not exist.
448
+ FileExistsError: If the target file already exists.
449
+ ValueError: If an unknown copy method is provided.
450
+ """
451
+ if copy_method == "copy":
452
+ shutil.copyfile(source, target)
453
+ elif copy_method == "symlink":
454
+ rel_source = source.absolute().relative_to(target.parent.absolute(), walk_up=True)
455
+ target.symlink_to(rel_source)
456
+ elif copy_method == "hardlink":
457
+ target.hardlink_to(source)
458
+ else:
459
+ msg = f"Unknown method: {copy_method}. Valid methods are: {copy_methods}"
460
+ raise ValueError(msg)
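A small sketch of the three copy methods (the paths are hypothetical; the source must exist and the target's parent directory must already exist):

```python
from pathlib import Path

from protein_quest.utils import copyfile

src = Path("cache/abcd/structure.cif")  # hypothetical existing file
dst = Path("results/structure.cif")

copyfile(src, dst, copy_method="hardlink")  # same inode, same filesystem only
# copyfile(src, dst, copy_method="symlink")  # relative symlink, easy to trace back
# copyfile(src, dst, copy_method="copy")     # full copy, uses more disk space
```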
461
+
462
+
463
+ async def async_copyfile(
464
+ source: Path,
465
+ target: Path,
466
+ copy_method: CopyMethod = "copy",
467
+ ):
468
+ """Asynchronously make the target path the same file as the source by copying, symlinking, or hardlinking.
469
+
470
+ Note that the hardlink copy method only works within the same filesystem and is harder to track.
471
+ If you want to track cached files easily, use 'symlink'.
472
+ On Windows you need developer mode or admin privileges to create symlinks.
473
+
474
+ Args:
475
+ source: The source file to copy or link.
476
+ target: The target file to create.
477
+ copy_method: The method to use for copying.
478
+
479
+ Raises:
480
+ FileNotFoundError: If the source file or parent of target does not exist.
481
+ FileExistsError: If the target file already exists.
482
+ ValueError: If an unknown copy method is provided.
483
+ """
484
+ if copy_method == "copy":
485
+ # Could use loop of chunks with aiofiles,
486
+ # but shutil is ~1.9x faster on my machine
487
+ # due to fastcopy and sendfile optimizations in shutil.
488
+ await asyncio.to_thread(shutil.copyfile, source, target)
489
+ elif copy_method == "symlink":
490
+ rel_source = source.relative_to(target.parent, walk_up=True)
491
+ await aiofiles.os.symlink(str(rel_source), str(target))
492
+ elif copy_method == "hardlink":
493
+ await aiofiles.os.link(str(source), str(target))
494
+ else:
495
+ msg = f"Unknown method: {copy_method}. Valid methods are: {copy_methods}"
496
+ raise ValueError(msg)
497
+
498
+
499
+ def populate_cache_command(raw_args: Sequence[str] | None = None):
500
+ """Command line interface to populate the cache from an existing directory.
501
+
502
+ Can be called from the command line as:
503
+
504
+ ```bash
505
+ python3 -m protein_quest.utils populate-cache /path/to/source/dir
506
+ ```
507
+
508
+ Args:
509
+ raw_args: The raw command line arguments to parse. If None, uses sys.argv.
510
+ """
511
+ root_parser = argparse.ArgumentParser(formatter_class=ArgumentDefaultsRichHelpFormatter)
512
+ subparsers = root_parser.add_subparsers(dest="command")
513
+
514
+ desc = "Populate the cache directory with files from the source directory."
515
+ populate_cache_parser = subparsers.add_parser(
516
+ "populate-cache",
517
+ help=desc,
518
+ description=desc,
519
+ formatter_class=ArgumentDefaultsRichHelpFormatter,
520
+ )
521
+ populate_cache_parser.add_argument("source_dir", type=Path)
522
+ populate_cache_parser.add_argument(
523
+ "--cache-dir",
524
+ type=Path,
525
+ default=user_cache_root_dir(),
526
+ help="Directory to use for caching. If not provided, a default cache directory is used.",
527
+ )
528
+ populate_cache_parser.add_argument(
529
+ "--copy-method",
530
+ type=str,
531
+ default="hardlink",
532
+ choices=copy_methods,
533
+ help="Method to use for copying files to cache.",
534
+ )
535
+
536
+ args = root_parser.parse_args(raw_args)
537
+ if args.command == "populate-cache":
538
+ source_dir = args.source_dir
539
+ cacher = DirectoryCacher(cache_dir=args.cache_dir, copy_method=args.copy_method)
540
+ cached_files = cacher.populate_cache(source_dir)
541
+ rich.print(f"Cached {len(cached_files)} files from {source_dir} to {cacher.cache_dir}")
542
+ for src, cached in cached_files.items():
543
+ rich.print(f"- {src.relative_to(source_dir)} -> {cached.relative_to(cacher.cache_dir)}")
544
+
545
+
546
+ if __name__ == "__main__":
547
+ populate_cache_command()