protein-quest 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -0,0 +1,511 @@
+ """Module for searching UniProtKB using SPARQL."""
+
+ import logging
+ from collections.abc import Collection, Iterable
+ from dataclasses import dataclass
+ from itertools import batched
+ from textwrap import dedent
+
+ from SPARQLWrapper import JSON, SPARQLWrapper
+ from tqdm.auto import tqdm
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class Query:
+     """Search query for UniProtKB.
+
+     Parameters:
+         taxon_id: NCBI Taxon ID to filter results by organism (e.g., "9606" for human).
+         reviewed: Whether to filter results by reviewed status (True for reviewed, False for unreviewed).
+         subcellular_location_uniprot: Subcellular location in UniProt format (e.g., "nucleus").
+         subcellular_location_go: Subcellular location in GO format. Can be a single GO term
+             (e.g., ["GO:0005634"]) or a collection of GO terms (e.g., ["GO:0005634", "GO:0005737"]).
+         molecular_function_go: Molecular function in GO format. Can be a single GO term
+             (e.g., ["GO:0003674"]) or a collection of GO terms (e.g., ["GO:0003674", "GO:0008150"]).
+     """
+
+     # TODO make taxon_id an int
+     taxon_id: str | None
+     reviewed: bool | None = None
+     subcellular_location_uniprot: str | None = None
+     subcellular_location_go: list[str] | None = None
+     molecular_function_go: list[str] | None = None
+
+
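Since Query is a plain dataclass, constructing one is just keyword arguments. A minimal sketch; the taxon and GO values below are illustrative, not package defaults:

# Hypothetical query: reviewed human proteins annotated as located in the nucleus.
# "9606" (human) and "GO:0005634" (nucleus) are example values.
query = Query(
    taxon_id="9606",
    reviewed=True,
    subcellular_location_go=["GO:0005634"],
)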
+ def _first_chain_from_uniprot_chains(uniprot_chains: str) -> str:
+     """Extracts the first chain identifier from a UniProt chains string.
+
+     The UniProt chains string is formatted (in EBNF notation) as follows:
+
+         chain_group(=range)?(,chain_group(=range)?)*
+
+     where:
+
+         chain_group := chain_id(/chain_id)*
+         chain_id := [A-Za-z]+
+         range := start-end
+         start, end := integer
+
+     Args:
+         uniprot_chains: A string representing UniProt chains, for example "B/D=1-81".
+
+     Returns:
+         The first chain identifier from the UniProt chains string. For example "B".
+     """
+     chains = uniprot_chains.split("=")
+     parts = chains[0].split("/")
+     chain = parts[0]
+     try:
+         # Workaround for Q9Y2Q5 │ 5YK3 │ 1/B/G=1-124: chain 1 does not exist but B does
+         int(chain)
+         if len(parts) > 1:
+             return parts[1]
+     except ValueError:
+         # A letter, so a regular chain identifier
+         pass
+     return chain
+
+
+ @dataclass(frozen=True)
+ class PdbResult:
+     """Result of a PDB search in UniProtKB.
+
+     Parameters:
+         id: PDB ID (e.g., "1H3O").
+         method: Method used for the PDB entry (e.g., "X-ray diffraction").
+         uniprot_chains: Chains in UniProt format (e.g., "A/B=1-42,A/B=50-99").
+         resolution: Resolution of the PDB entry (e.g., "2.0" for 2.0 Å). Optional.
+     """
+
+     id: str
+     method: str
+     uniprot_chains: str
+     resolution: str | None = None
+
+     @property
+     def chain(self) -> str:
+         """The first chain identifier parsed from self.uniprot_chains."""
+         return _first_chain_from_uniprot_chains(self.uniprot_chains)
+
+
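For illustration, a PdbResult can be constructed directly and its derived chain read back. The values here just mirror the docstring examples:

entry = PdbResult(id="1H3O", method="X-ray diffraction", uniprot_chains="A/B=1-42,A/B=50-99", resolution="2.0")
print(entry.chain)  # -> "A", the first identifier of the first chain group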
+ def _query2dynamic_sparql_triples(query: Query) -> str:
+     parts: list[str] = []
+     if query.taxon_id:
+         parts.append(f"?protein up:organism taxon:{query.taxon_id} .")
+
+     if query.reviewed:
+         parts.append("?protein up:reviewed true .")
+     elif query.reviewed is False:
+         parts.append("?protein up:reviewed false .")
+
+     parts.append(_append_subcellular_location_filters(query))
+
+     if query.molecular_function_go:
+         # Handle both single GO term (string) and multiple GO terms (list)
+         if isinstance(query.molecular_function_go, str):
+             go_terms = [query.molecular_function_go]
+         else:
+             go_terms = query.molecular_function_go
+
+         molecular_function_filter = _create_go_filter(go_terms, "Molecular function")
+         parts.append(molecular_function_filter)
+
+     return "\n".join(parts)
+
+
+ def _create_go_filter(go_terms: Collection[str], term_type: str) -> str:
+     """Create SPARQL filter for GO terms.
+
+     Args:
+         go_terms: Collection of GO terms to filter by.
+         term_type: Type of GO terms for error messages (e.g., "Molecular function", "Subcellular location").
+
+     Returns:
+         SPARQL filter string.
+     """
+     # Validate all GO terms start with "GO:"
+     for term in go_terms:
+         if not term.startswith("GO:"):
+             msg = f"{term_type} GO term must start with 'GO:', got: {term}"
+             raise ValueError(msg)
+
+     if len(go_terms) == 1:
+         # Single GO term - get the first (and only) term
+         term = next(iter(go_terms))
+         return dedent(f"""
+             ?protein up:classifiedWith|(up:classifiedWith/rdfs:subClassOf+) {term} .
+         """)
+
+     # Multiple GO terms - use UNION for OR logic
+     union_parts = [
+         dedent(f"""
+             {{ ?protein up:classifiedWith|(up:classifiedWith/rdfs:subClassOf+) {term} . }}
+         """).strip()
+         for term in go_terms
+     ]
+     return " UNION ".join(union_parts)
+
+
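To make the OR logic concrete: for two terms the function emits one bracketed pattern per term, joined by UNION on a single line (shown wrapped here for readability):

print(_create_go_filter(["GO:0005634", "GO:0005737"], "Subcellular location"))
# { ?protein up:classifiedWith|(up:classifiedWith/rdfs:subClassOf+) GO:0005634 . }
#   UNION
# { ?protein up:classifiedWith|(up:classifiedWith/rdfs:subClassOf+) GO:0005737 . }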
+ def _append_subcellular_location_filters(query: Query) -> str:
+     subcellular_location_uniprot_part = ""
+     subcellular_location_go_part = ""
+
+     if query.subcellular_location_uniprot:
+         subcellular_location_uniprot_part = dedent(f"""
+             ?protein up:annotation ?subcellAnnotation .
+             ?subcellAnnotation up:locatedIn/up:cellularComponent ?cellcmpt .
+             ?cellcmpt skos:prefLabel "{query.subcellular_location_uniprot}" .
+         """)
+
+     if query.subcellular_location_go:
+         # Handle both single GO term (string) and multiple GO terms (list)
+         if isinstance(query.subcellular_location_go, str):
+             go_terms = [query.subcellular_location_go]
+         else:
+             go_terms = query.subcellular_location_go
+
+         subcellular_location_go_part = _create_go_filter(go_terms, "Subcellular location")
+
+     if subcellular_location_uniprot_part and subcellular_location_go_part:
+         # If both are provided, include results for either with a logical OR
+         return dedent(f"""
+             {{
+                 {subcellular_location_uniprot_part}
+             }} UNION {{
+                 {subcellular_location_go_part}
+             }}
+         """)
+
+     return subcellular_location_uniprot_part or subcellular_location_go_part
+
+
+ def _build_sparql_generic_query(select_clause: str, where_clause: str, limit: int = 10_000, groupby_clause: str = "") -> str:
+     """
+     Builds a generic SPARQL query with the given select and where clauses.
+     """
+     groupby = f" GROUP BY {groupby_clause}" if groupby_clause else ""
+     return dedent(f"""
+         PREFIX up: <http://purl.uniprot.org/core/>
+         PREFIX taxon: <http://purl.uniprot.org/taxonomy/>
+         PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
+         PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
+         PREFIX skos: <http://www.w3.org/2004/02/skos/core#>
+         PREFIX GO: <http://purl.obolibrary.org/obo/GO_>
+
+         SELECT {select_clause}
+         WHERE {{
+             {where_clause}
+         }}
+         {groupby}
+         LIMIT {limit}
+     """)
+
+
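Every query in this module goes through this template; a minimal call makes the shape obvious (whitespace aside):

print(_build_sparql_generic_query("DISTINCT ?protein", "?protein a up:Protein .", limit=5))
# Prints the PREFIX block followed by:
#   SELECT DISTINCT ?protein
#   WHERE { ?protein a up:Protein . }
#   LIMIT 5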
+ def _build_sparql_generic_by_uniprot_accessions_query(
+     uniprot_accs: Iterable[str], select_clause: str, where_clause: str, limit: int = 10_000, groupby_clause: str = ""
+ ) -> str:
+     values = " ".join(f'("{ac}")' for ac in uniprot_accs)
+     where_clause2 = dedent(f"""
+         # --- Protein Selection ---
+         VALUES (?ac) {{ {values} }}
+         BIND (IRI(CONCAT("http://purl.uniprot.org/uniprot/", ?ac)) AS ?protein)
+         ?protein a up:Protein .
+
+         {where_clause}
+     """)
+     return _build_sparql_generic_query(
+         select_clause=select_clause,
+         where_clause=where_clause2,
+         limit=limit,
+         groupby_clause=groupby_clause,
+     )
+
+
+ def _build_sparql_query_uniprot(query: Query, limit: int = 10_000) -> str:
+     dynamic_triples = _query2dynamic_sparql_triples(query)
+     # TODO add useful columns that have a 1:1 mapping to protein,
+     # like uniprot_id with `?protein up:mnemonic ?mnemonic .`
+     # and sequence; take care to take the first isoform:
+     # ?protein up:sequence ?isoform .
+     # ?isoform rdf:value ?sequence .
+     select_clause = "DISTINCT ?protein"
+     where_clause = dedent(f"""
+         # --- Protein Selection ---
+         ?protein a up:Protein .
+         {dynamic_triples}
+     """)
+     return _build_sparql_generic_query(select_clause, dedent(where_clause), limit)
+
+
+ def _build_sparql_query_pdb(uniprot_accs: Iterable[str], limit: int = 10_000) -> str:
+     # For http://purl.uniprot.org/uniprot/O00268 + http://rdf.wwpdb.org/pdb/1H3O
+     # the chainSequenceMapping values are
+     # http://purl.uniprot.org/isoforms/O00268-1#PDB_1H3O_tt872tt945
+     # http://purl.uniprot.org/isoforms/Q16514-1#PDB_1H3O_tt57tt128
+     # For http://purl.uniprot.org/uniprot/O00255 + http://rdf.wwpdb.org/pdb/3U84
+     # the chainSequenceMapping values are
+     # http://purl.uniprot.org/isoforms/O00255-2#PDB_3U84_tt520tt610
+     # http://purl.uniprot.org/isoforms/O00255-2#PDB_3U84_tt2tt459
+     # To get the chain belonging to the uniprot/pdb pair we need to
+     # do some string filtering.
+     # Also there can be multiple chains for the same uniprot/pdb pair, so we need to
+     # do a group by and concat.
+
+     select_clause = dedent("""\
+         ?protein ?pdb_db ?pdb_method ?pdb_resolution
+         (GROUP_CONCAT(DISTINCT ?pdb_chain; separator=",") AS ?pdb_chains)
+     """)
+
+     where_clause = dedent("""
+         # --- PDB Info ---
+         ?protein rdfs:seeAlso ?pdb_db .
+         ?pdb_db up:database <http://purl.uniprot.org/database/PDB> .
+         ?pdb_db up:method ?pdb_method .
+         ?pdb_db up:chainSequenceMapping ?chainSequenceMapping .
+         BIND(STRAFTER(STR(?chainSequenceMapping), "isoforms/") AS ?isoformPart)
+         FILTER(STRSTARTS(?isoformPart, CONCAT(?ac, "-")))
+         ?chainSequenceMapping up:chain ?pdb_chain .
+         OPTIONAL { ?pdb_db up:resolution ?pdb_resolution . }
+     """)
+
+     groupby_clause = "?protein ?pdb_db ?pdb_method ?pdb_resolution"
+     return _build_sparql_generic_by_uniprot_accessions_query(
+         uniprot_accs, select_clause, where_clause, limit, groupby_clause
+     )
+
+
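The string filtering in that WHERE clause is doing real work. Restated in plain Python, using the O00268/1H3O example from the comments above:

# Mimics BIND(STRAFTER(...)) + FILTER(STRSTARTS(...)) for ?ac = "O00268".
mapping = "http://purl.uniprot.org/isoforms/Q16514-1#PDB_1H3O_tt57tt128"
isoform_part = mapping.split("isoforms/", 1)[1]  # "Q16514-1#PDB_1H3O_tt57tt128"
print(isoform_part.startswith("O00268" + "-"))   # False: this mapping belongs to Q16514, so it is filtered out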
+ def _build_sparql_query_af(uniprot_accs: Iterable[str], limit: int = 10_000) -> str:
+     select_clause = "?protein ?af_db"
+     where_clause = dedent("""
+         # --- Protein Selection ---
+         ?protein a up:Protein .
+
+         # --- AlphaFoldDB Info ---
+         ?protein rdfs:seeAlso ?af_db .
+         ?af_db up:database <http://purl.uniprot.org/database/AlphaFoldDB> .
+     """)
+     return _build_sparql_generic_by_uniprot_accessions_query(uniprot_accs, select_clause, dedent(where_clause), limit)
+
+
+ def _build_sparql_query_emdb(uniprot_accs: Iterable[str], limit: int = 10_000) -> str:
+     select_clause = "?protein ?emdb_db"
+     where_clause = dedent("""
+         # --- Protein Selection ---
+         ?protein a up:Protein .
+
+         # --- EMDB Info ---
+         ?protein rdfs:seeAlso ?emdb_db .
+         ?emdb_db up:database <http://purl.uniprot.org/database/EMDB> .
+     """)
+     return _build_sparql_generic_by_uniprot_accessions_query(uniprot_accs, select_clause, dedent(where_clause), limit)
+
+
+ def _execute_sparql_search(
+     sparql_query: str,
+     timeout: int,
+ ) -> list:
+     """
+     Execute a SPARQL query.
+     """
+     if timeout > 2_700:
+         msg = "UniProt SPARQL timeout is limited to 2700 seconds (45 minutes)."
+         raise ValueError(msg)
+
+     # Execute the query
+     sparql = SPARQLWrapper("https://sparql.uniprot.org/sparql")
+     sparql.setReturnFormat(JSON)
+     sparql.setTimeout(timeout)
+
+     # The default is the GET method, which can be cached by the server and so is preferred.
+     # To prevent URITooLong errors, we use the POST method for large queries.
+     too_long_for_get = 5_000
+     if len(sparql_query) > too_long_for_get:
+         sparql.setMethod("POST")
+
+     sparql.setQuery(sparql_query)
+     rawresults = sparql.queryAndConvert()
+     if not isinstance(rawresults, dict):
+         msg = f"Expected rawresults to be a dict, but got {type(rawresults)}"
+         raise TypeError(msg)
+
+     bindings = rawresults.get("results", {}).get("bindings")
+     if not isinstance(bindings, list):
+         logger.warning("SPARQL query did not return 'bindings' list as expected.")
+         return []
+
+     logger.debug(bindings)
+     return bindings
+
+
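For readers unfamiliar with the SPARQL 1.1 JSON results format: each returned binding is a dict mapping variable names to typed values, which is what the _flatten_results_* helpers below unpack. An illustrative binding (the method IRI is an assumption; the accession, PDB ID, and chain string come from the examples earlier in this module):

binding = {
    "protein": {"type": "uri", "value": "http://purl.uniprot.org/uniprot/O00268"},
    "pdb_db": {"type": "uri", "value": "http://rdf.wwpdb.org/pdb/1H3O"},
    "pdb_method": {"type": "uri", "value": "http://purl.uniprot.org/core/X-Ray_Crystallography"},
    "pdb_chains": {"type": "literal", "value": "B/D=1-81"},
}
# binding["protein"]["value"].split("/")[-1] then yields the accession "O00268".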
+ def _flatten_results_pdb(rawresults: Iterable) -> dict[str, set[PdbResult]]:
+     pdb_entries: dict[str, set[PdbResult]] = {}
+     for result in rawresults:
+         protein = result["protein"]["value"].split("/")[-1]
+         if "pdb_db" not in result:  # Should not happen with _build_sparql_query_pdb
+             continue
+         pdb_id = result["pdb_db"]["value"].split("/")[-1]
+         method = result["pdb_method"]["value"].split("/")[-1]
+         uniprot_chains = result["pdb_chains"]["value"]
+         pdb = PdbResult(
+             id=pdb_id,
+             method=method,
+             uniprot_chains=uniprot_chains,
+             # resolution is OPTIONAL in the query, so it may be absent
+             resolution=result["pdb_resolution"]["value"] if "pdb_resolution" in result else None,
+         )
+         pdb_entries.setdefault(protein, set()).add(pdb)
+
+     return pdb_entries
+
+
+ def _flatten_results_af(rawresults: Iterable) -> dict[str, set[str]]:
+     alphafold_entries: dict[str, set[str]] = {}
+     for result in rawresults:
+         protein = result["protein"]["value"].split("/")[-1]
+         if "af_db" in result:
+             af_id = result["af_db"]["value"].split("/")[-1]
+             alphafold_entries.setdefault(protein, set()).add(af_id)
+     return alphafold_entries
+
+
+ def _flatten_results_emdb(rawresults: Iterable) -> dict[str, set[str]]:
+     emdb_entries: dict[str, set[str]] = {}
+     for result in rawresults:
+         protein = result["protein"]["value"].split("/")[-1]
+         if "emdb_db" in result:
+             emdb_id = result["emdb_db"]["value"].split("/")[-1]
+             emdb_entries.setdefault(protein, set()).add(emdb_id)
+     return emdb_entries
+
+
+ def limit_check(what: str, limit: int, len_raw_results: int) -> None:
+     if len_raw_results >= limit:
+         logger.warning(
+             "%s returned %d results. "
+             "There may be more results available, "
+             "but they are not returned due to the limit of %d. "
+             "Consider increasing the limit to get more results.",
+             what,
+             len_raw_results,
+             limit,
+         )
+
+
+ def search4uniprot(query: Query, limit: int = 10_000, timeout: int = 1_800) -> set[str]:
+     """
+     Search for UniProtKB entries based on the given query.
+
+     Args:
+         query: Query object containing search parameters.
+         limit: Maximum number of results to return.
+         timeout: Timeout for the SPARQL query in seconds.
+
+     Returns:
+         Set of UniProt accessions.
+     """
+     sparql_query = _build_sparql_query_uniprot(query, limit)
+     logger.info("Executing SPARQL query for UniProt: %s", sparql_query)
+
+     raw_results = _execute_sparql_search(
+         sparql_query=sparql_query,
+         timeout=timeout,
+     )
+     limit_check("Search for uniprot accessions", limit, len(raw_results))
+     return {result["protein"]["value"].split("/")[-1] for result in raw_results}
+
+
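A usage sketch tying Query and search4uniprot together; the function is synchronous and returns a plain set, and the taxon and GO term are the same illustrative values as above:

accessions = search4uniprot(
    Query(taxon_id="9606", reviewed=True, subcellular_location_go=["GO:0005634"]),
    limit=100,
)
print(len(accessions))  # at most 100, per the limit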
+ def search4pdb(
+     uniprot_accs: Collection[str], limit: int = 10_000, timeout: int = 1_800, batch_size: int = 10_000
+ ) -> dict[str, set[PdbResult]]:
+     """
+     Search for PDB entries linked to the given UniProtKB accessions.
+
+     Args:
+         uniprot_accs: UniProt accessions.
+         limit: Maximum number of results to return.
+         timeout: Timeout for the SPARQL query in seconds.
+         batch_size: Size of the batches in which the UniProt accessions are processed.
+
+     Returns:
+         Dictionary with protein IDs as keys and sets of PDB results as values.
+     """
+     all_raw_results = []
+     total = len(uniprot_accs)
+     with tqdm(total=total, desc="Searching for PDBs of uniprots", disable=total < batch_size, unit="acc") as pbar:
+         for batch in batched(uniprot_accs, batch_size, strict=False):
+             sparql_query = _build_sparql_query_pdb(batch, limit)
+             logger.info("Executing SPARQL query for PDB: %s", sparql_query)
+
+             raw_results = _execute_sparql_search(
+                 sparql_query=sparql_query,
+                 timeout=timeout,
+             )
+             all_raw_results.extend(raw_results)
+             pbar.update(len(batch))
+
+     limit_check("Search for pdbs on uniprot", limit, len(all_raw_results))
+     return _flatten_results_pdb(all_raw_results)
+
+
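And a sketch of feeding accessions into search4pdb; "O00268" is taken from the comments above, and actual results depend on the live endpoint:

pdbs = search4pdb({"O00268"})
for accession, entries in pdbs.items():
    for entry in entries:
        print(accession, entry.id, entry.method, entry.chain, entry.resolution)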
+ def search4af(
+     uniprot_accs: Collection[str], limit: int = 10_000, timeout: int = 1_800, batch_size: int = 10_000
+ ) -> dict[str, set[str]]:
+     """
+     Search for AlphaFold entries linked to the given UniProtKB accessions.
+
+     Args:
+         uniprot_accs: UniProt accessions.
+         limit: Maximum number of results to return.
+         timeout: Timeout for the SPARQL query in seconds.
+         batch_size: Size of the batches in which the UniProt accessions are processed.
+
+     Returns:
+         Dictionary with protein IDs as keys and sets of AlphaFold IDs as values.
+     """
+     all_raw_results = []
+     total = len(uniprot_accs)
+     with tqdm(total=total, desc="Searching for AlphaFolds of uniprots", disable=total < batch_size, unit="acc") as pbar:
+         for batch in batched(uniprot_accs, batch_size, strict=False):
+             sparql_query = _build_sparql_query_af(batch, limit)
+             logger.info("Executing SPARQL query for AlphaFold: %s", sparql_query)
+
+             raw_results = _execute_sparql_search(
+                 sparql_query=sparql_query,
+                 timeout=timeout,
+             )
+             all_raw_results.extend(raw_results)
+             pbar.update(len(batch))
+
+     limit_check("Search for alphafold entries on uniprot", limit, len(all_raw_results))
+     return _flatten_results_af(all_raw_results)
+
+
+ def search4emdb(uniprot_accs: Iterable[str], limit: int = 10_000, timeout: int = 1_800) -> dict[str, set[str]]:
+     """
+     Search for EMDB entries linked to the given UniProtKB accessions.
+
+     Args:
+         uniprot_accs: UniProt accessions.
+         limit: Maximum number of results to return.
+         timeout: Timeout for the SPARQL query in seconds.
+
+     Returns:
+         Dictionary with protein IDs as keys and sets of EMDB IDs as values.
+     """
+     sparql_query = _build_sparql_query_emdb(uniprot_accs, limit)
+     logger.info("Executing SPARQL query for EMDB: %s", sparql_query)
+
+     raw_results = _execute_sparql_search(
+         sparql_query=sparql_query,
+         timeout=timeout,
+     )
+     limit_check("Search for EMDB entries on uniprot", limit, len(raw_results))
+     return _flatten_results_emdb(raw_results)
protein_quest/utils.py ADDED
@@ -0,0 +1,105 @@
+ """Module for functions that are used in multiple places."""
+
+ import asyncio
+ import logging
+ from collections.abc import Iterable
+ from contextlib import asynccontextmanager
+ from pathlib import Path
+
+ import aiofiles
+ import aiohttp
+ from aiohttp_retry import ExponentialRetry, RetryClient
+ from tqdm.asyncio import tqdm
+
+ logger = logging.getLogger(__name__)
+
+
+ async def retrieve_files(
+     urls: Iterable[tuple[str, str]],
+     save_dir: Path,
+     max_parallel_downloads: int = 5,
+     retries: int = 3,
+     total_timeout: int = 300,
+     desc: str = "Downloading files",
+ ) -> list[Path]:
+     """Retrieve files from a list of URLs and save them to a directory.
+
+     Args:
+         urls: A list of tuples, where each tuple contains a URL and a filename.
+         save_dir: The directory to save the downloaded files to.
+         max_parallel_downloads: The maximum number of files to download in parallel.
+         retries: The number of times to retry a failed download.
+         total_timeout: The total timeout for a download in seconds.
+         desc: Description for the progress bar.
+
+     Returns:
+         A list of paths to the downloaded files.
+     """
+     save_dir.mkdir(parents=True, exist_ok=True)
+     semaphore = asyncio.Semaphore(max_parallel_downloads)
+     async with friendly_session(retries, total_timeout) as session:
+         tasks = [_retrieve_file(session, url, save_dir / filename, semaphore) for url, filename in urls]
+         files: list[Path] = await tqdm.gather(*tasks, desc=desc)
+     return files
+
+
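Being a coroutine, retrieve_files is driven with asyncio.run; a minimal sketch, where the URLs and filenames are placeholders:

# Hypothetical download of two files into ./downloads
urls = [
    ("https://example.com/a.pdb", "a.pdb"),
    ("https://example.com/b.pdb", "b.pdb"),
]
paths = asyncio.run(retrieve_files(urls, Path("downloads")))
print(paths)  # e.g. [PosixPath('downloads/a.pdb'), PosixPath('downloads/b.pdb')] on POSIX systems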
+ async def _retrieve_file(
+     session: RetryClient,
+     url: str,
+     save_path: Path,
+     semaphore: asyncio.Semaphore,
+     overwrite: bool = False,
+     chunk_size: int = 131072,  # 128 KiB
+ ) -> Path:
+     """Retrieve a single file from a URL and save it to a specified path.
+
+     Args:
+         session: The aiohttp session to use for the request.
+         url: The URL to download the file from.
+         save_path: The path where the file should be saved.
+         semaphore: A semaphore to limit the number of concurrent downloads.
+         overwrite: Whether to overwrite the file if it already exists.
+         chunk_size: The size of each chunk to read from the response.
+
+     Returns:
+         The path to the saved file.
+     """
+     if save_path.exists():
+         if overwrite:
+             save_path.unlink()
+         else:
+             logger.debug(f"File {save_path} already exists. Skipping download from {url}.")
+             return save_path
+     async with (
+         semaphore,
+         aiofiles.open(save_path, "xb") as f,
+         session.get(url) as resp,
+     ):
+         resp.raise_for_status()
+         async for chunk in resp.content.iter_chunked(chunk_size):
+             await f.write(chunk)
+     return save_path
+
+
+ @asynccontextmanager
+ async def friendly_session(retries: int = 3, total_timeout: int = 300):
+     """Create an aiohttp session with retry capabilities.
+
+     Examples:
+         Use as an async context manager:
+
+         >>> async with friendly_session(retries=5, total_timeout=60) as session:
+         ...     r = await session.get("https://example.com/api/data")
+         ...     print(r)
+         <ClientResponse(https://example.com/api/data) [404 Not Found]>
+         <CIMultiDictProxy('Accept-Ranges': 'bytes', ...
+
+     Args:
+         retries: The number of retry attempts for failed requests.
+         total_timeout: The total timeout for a request in seconds.
+     """
+     retry_options = ExponentialRetry(attempts=retries)
+     timeout = aiohttp.ClientTimeout(total=total_timeout)  # pyrefly: ignore false positive
+     async with aiohttp.ClientSession(timeout=timeout) as session:
+         client = RetryClient(client_session=session, retry_options=retry_options)
+         yield client