protein-quest 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of protein-quest might be problematic.

protein_quest/emdb.py ADDED
@@ -0,0 +1,34 @@
+ """Module dealing with Electron Microscopy Data Bank (EMDB)."""
+
+ from collections.abc import Iterable, Mapping
+ from pathlib import Path
+
+ from protein_quest.utils import retrieve_files
+
+
+ def _map_id2volume_url(emdb_id: str) -> tuple[str, str]:
+     # https://ftp.ebi.ac.uk/pub/databases/emdb/structures/EMD-19583/map/emd_19583.map.gz
+     fn = emdb_id.lower().replace("emd-", "emd_") + ".map.gz"
+     url = f"https://ftp.ebi.ac.uk/pub/databases/emdb/structures/{emdb_id}/map/{fn}"
+     return url, fn
+
+
+ async def fetch(emdb_ids: Iterable[str], save_dir: Path, max_parallel_downloads: int = 1) -> Mapping[str, Path]:
+     """Fetches volume files from the EMDB database.
+
+     Args:
+         emdb_ids: A list of EMDB IDs to fetch.
+         save_dir: The directory to save the downloaded files.
+         max_parallel_downloads: The maximum number of parallel downloads.
+
+     Returns:
+         A mapping of EMDB IDs to their downloaded files.
+     """
+     id2urls = {emdb_id: _map_id2volume_url(emdb_id) for emdb_id in emdb_ids}
+     urls = list(id2urls.values())
+     id2paths = {emdb_id: save_dir / fn for emdb_id, (_, fn) in id2urls.items()}
+
+     # TODO show progress of each item
+     # TODO handle failed downloads, by skipping them instead of raising an error
+     await retrieve_files(urls, save_dir, max_parallel_downloads, desc="Downloading EMDB volume files")
+     return id2paths
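
For orientation, a minimal usage sketch of the `fetch` coroutine above. The EMDB ID comes from the example URL in `_map_id2volume_url`; the behaviour of `protein_quest.utils.retrieve_files` (e.g. whether it creates the save directory) is not shown in this diff, so the directory is created up front as a precaution.

```python
import asyncio
from pathlib import Path

from protein_quest.emdb import fetch


async def main() -> None:
    save_dir = Path("emdb_volumes")
    # Precaution: retrieve_files is not part of this diff, so create the directory ourselves.
    save_dir.mkdir(parents=True, exist_ok=True)
    # Download the gzipped volume map for a single EMDB entry.
    id2paths = await fetch(["EMD-19583"], save_dir, max_parallel_downloads=1)
    for emdb_id, path in id2paths.items():
        print(emdb_id, path)  # EMD-19583 emdb_volumes/emd_19583.map.gz


asyncio.run(main())
```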
@@ -0,0 +1,107 @@
+ """Module for filtering structure files and their contents."""
+
+ import logging
+ from collections.abc import Generator
+ from dataclasses import dataclass
+ from pathlib import Path
+ from shutil import copyfile
+ from typing import cast
+
+ from dask.distributed import Client, progress
+ from distributed.deploy.cluster import Cluster
+ from tqdm.auto import tqdm
+
+ from protein_quest.parallel import configure_dask_scheduler
+ from protein_quest.pdbe.io import (
+     locate_structure_file,
+     nr_residues_in_chain,
+     write_single_chain_pdb_file,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ def filter_files_on_chain(
+     input_dir: Path,
+     id2chains: dict[str, str],
+     output_dir: Path,
+     scheduler_address: str | Cluster | None = None,
+     out_chain: str = "A",
+ ) -> list[tuple[str, str, Path | None]]:
+     """Filter mmCIF/PDB files by chain.
+
+     Args:
+         input_dir: The directory containing the input mmCIF/PDB files.
+         id2chains: Which chain to keep for each PDB ID. Key is the PDB ID, value is the chain ID.
+         output_dir: The directory where the filtered files will be written.
+         scheduler_address: The address of the Dask scheduler.
+         out_chain: The chain identifier under which to write the kept chain.
+
+     Returns:
+         A list of tuples containing the PDB ID, chain ID, and path to the filtered file.
+         The last tuple item is None if something went wrong, such as the chain not being present.
+     """
+     output_dir.mkdir(parents=True, exist_ok=True)
+     scheduler_address = configure_dask_scheduler(
+         scheduler_address,
+         name="filter-chain",
+     )
+
+     def task(id2chain: tuple[str, str]) -> tuple[str, str, Path | None]:
+         pdb_id, chain = id2chain
+         input_file = locate_structure_file(input_dir, pdb_id)
+         return pdb_id, chain, write_single_chain_pdb_file(input_file, chain, output_dir, out_chain=out_chain)
+
+     with Client(scheduler_address) as client:
+         logger.info(f"Follow progress on dask dashboard at: {client.dashboard_link}")
+
+         futures = client.map(task, id2chains.items())
+
+         progress(futures)
+
+         results = client.gather(futures)
+         return cast("list[tuple[str, str, Path | None]]", results)
+
+
+ @dataclass
+ class FilterStat:
+     """Statistics for filtering files based on residue count in a specific chain.
+
+     Parameters:
+         input_file: The path to the input file.
+         residue_count: The number of residues.
+         passed: Whether the file passed the filtering criteria.
+         output_file: The path to the output file, if passed.
+     """
+
+     input_file: Path
+     residue_count: int
+     passed: bool
+     output_file: Path | None
+
+
+ def filter_files_on_residues(
+     input_files: list[Path], output_dir: Path, min_residues: int, max_residues: int, chain: str = "A"
+ ) -> Generator[FilterStat]:
+     """Filter PDB/mmCIF files by the number of residues in a given chain.
+
+     Args:
+         input_files: The list of input PDB/mmCIF files.
+         output_dir: The directory where the filtered files will be written.
+         min_residues: The minimum number of residues in the chain.
+         max_residues: The maximum number of residues in the chain.
+         chain: The chain to count residues of.
+
+     Yields:
+         FilterStat objects containing information about the filtering process for each input file.
+     """
+     output_dir.mkdir(parents=True, exist_ok=True)
+     for input_file in tqdm(input_files, unit="file"):
+         residue_count = nr_residues_in_chain(input_file, chain=chain)
+         passed = min_residues <= residue_count <= max_residues
+         if passed:
+             output_file = output_dir / input_file.name
+             copyfile(input_file, output_file)
+             yield FilterStat(input_file, residue_count, True, output_file)
+         else:
+             yield FilterStat(input_file, residue_count, False, None)
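
A short sketch of how the synchronous `filter_files_on_residues` generator above could be driven. The import path `protein_quest.filter` is an assumption, since the hunk header for this module does not show its filename.

```python
from pathlib import Path

# NOTE: the module path below is assumed; the diff does not show this file's name.
from protein_quest.filter import FilterStat, filter_files_on_residues

input_files = sorted(Path("structures").glob("*.cif.gz"))
stats: list[FilterStat] = list(
    filter_files_on_residues(
        input_files,
        output_dir=Path("filtered"),
        min_residues=100,
        max_residues=200,
        chain="A",
    )
)
# Files whose chain A residue count falls inside [100, 200] are copied to ./filtered;
# the others are only reported with passed=False.
for stat in stats:
    print(stat.input_file.name, stat.residue_count, stat.passed)
```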
protein_quest/go.py ADDED
@@ -0,0 +1,168 @@
+ """Module for Gene Ontology (GO) functions."""
+
+ import csv
+ import logging
+ from collections.abc import Generator
+ from dataclasses import dataclass
+ from io import TextIOWrapper
+ from typing import Literal, get_args
+
+ from cattrs.gen import make_dict_structure_fn, override
+ from cattrs.preconf.orjson import make_converter
+
+ from protein_quest.utils import friendly_session
+
+ logger = logging.getLogger(__name__)
+
+ Aspect = Literal["cellular_component", "biological_process", "molecular_function"]
+ """The aspect of the GO term."""
+ allowed_aspects = set(get_args(Aspect))
+ """Allowed aspects for GO terms."""
+
+
+ @dataclass(frozen=True, slots=True)
+ class GoTerm:
+     """A Gene Ontology (GO) term.
+
+     Parameters:
+         id: The unique identifier for the GO term, e.g., 'GO:0043293'.
+         is_obsolete: Whether the GO term is obsolete.
+         name: The name of the GO term.
+         definition: The definition of the GO term.
+         aspect: The aspect of the GO term.
+     """
+
+     id: str
+     is_obsolete: bool
+     name: str
+     definition: str
+     aspect: Aspect
+
+
+ @dataclass(frozen=True, slots=True)
+ class PageInfo:
+     current: int
+     total: int
+
+
+ @dataclass(frozen=True, slots=True)
+ class SearchResponse:
+     results: list[GoTerm]
+     number_of_hits: int
+     page_info: PageInfo
+
+
+ converter = make_converter()
+
+
+ def flatten_definition(definition, _context) -> str:
+     return definition["text"]
+
+
+ # Use hook to convert incoming camelCase to snake_case
+ # and to flatten definition {text} to text
+ # see https://catt.rs/en/stable/customizing.html#rename
+ converter.register_structure_hook(
+     GoTerm,
+     make_dict_structure_fn(
+         GoTerm,
+         converter,
+         is_obsolete=override(rename="isObsolete"),
+         definition=override(struct_hook=flatten_definition),
+     ),
+ )
+ converter.register_structure_hook(
+     SearchResponse,
+     make_dict_structure_fn(
+         SearchResponse, converter, number_of_hits=override(rename="numberOfHits"), page_info=override(rename="pageInfo")
+     ),
+ )
+
+
+ async def search_gene_ontology_term(
+     term: str, aspect: Aspect | None = None, include_obsolete: bool = False, limit: int = 100
+ ) -> list[GoTerm]:
+     """Search for a Gene Ontology (GO) term by its name or ID.
+
+     Calls the EBI QuickGO API at https://www.ebi.ac.uk/QuickGO/api/index.html .
+
+     Examples:
+         To search for `apoptosome` terms do:
+
+         >>> from protein_quest.go import search_gene_ontology_term
+         >>> r = await search_gene_ontology_term('apoptosome')
+         >>> len(r)
+         5
+         >>> r[0]
+         GoTerm(id='GO:0043293', is_obsolete=False, name='apoptosome', definition='A multisubunit protein ...')
+
+     Args:
+         term: The GO term to search for. For example `nucleus` or `GO:0006816`.
+         aspect: The aspect to filter by. If not given, all aspects are included.
+         include_obsolete: Whether to include obsolete terms. By default, obsolete terms are excluded.
+         limit: The maximum number of results to return.
+
+     Returns:
+         List of GO terms.
+
+     Raises:
+         ValueError: If the aspect is invalid.
+     """
+     url = "https://www.ebi.ac.uk/QuickGO/services/ontology/go/search"
+     page_limit = 100
+     params = {"query": term, "limit": str(page_limit), "page": "1"}
+     if aspect is not None and aspect not in allowed_aspects:
+         msg = f"Invalid aspect: {aspect}. Allowed aspects are: {allowed_aspects} or None."
+         raise ValueError(msg)
+     logger.debug("Fetching GO terms from %s with params %s", url, params)
+     async with friendly_session() as session:
+         # Fetch first page to learn how many pages there are
+         async with session.get(url, params=params) as response:
+             response.raise_for_status()
+             raw_data = await response.read()
+             data = converter.loads(raw_data, SearchResponse)
+
+         terms = list(_filter_go_terms(data.results, aspect, include_obsolete))
+         if len(terms) >= limit:
+             # Do not fetch additional pages if we have enough results
+             return terms[:limit]
+         total_pages = data.page_info.total
+         logger.debug("GO search returned %s pages (current=%s)", total_pages, data.page_info.current)
+
+         # Retrieve remaining pages (if any) and extend results
+         if total_pages > 1:
+             for page in range(2, total_pages + 1):
+                 params["page"] = str(page)
+                 logger.debug("Fetching additional GO terms page %s/%s with params %s", page, total_pages, params)
+                 async with session.get(url, params=params) as response:
+                     response.raise_for_status()
+                     raw_data = await response.read()
+                     data = converter.loads(raw_data, SearchResponse)
+                 terms.extend(_filter_go_terms(data.results, aspect, include_obsolete))
+                 if len(terms) >= limit:
+                     # Do not fetch additional pages if we have enough results
+                     break
+
+     return terms[:limit]
+
+
+ def _filter_go_terms(terms: list[GoTerm], aspect: Aspect | None, include_obsolete: bool) -> Generator[GoTerm]:
+     for oboterm in terms:
+         if not include_obsolete and oboterm.is_obsolete:
+             continue
+         if aspect and oboterm.aspect != aspect:
+             continue
+         yield oboterm
+
+
+ def write_go_terms_to_csv(terms: list[GoTerm], csv_file: TextIOWrapper) -> None:
+     """Write a list of GO terms to a CSV file.
+
+     Args:
+         terms: The list of GO terms to write.
+         csv_file: The CSV file to write to.
+     """
+     writer = csv.writer(csv_file)
+     writer.writerow(["id", "name", "aspect", "definition"])
+     for term in terms:
+         writer.writerow([term.id, term.name, term.aspect, term.definition])
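
A minimal sketch of calling `search_gene_ontology_term` and writing the results with `write_go_terms_to_csv`, following the signatures above; it queries the live QuickGO API, so it needs network access.

```python
import asyncio

from protein_quest.go import search_gene_ontology_term, write_go_terms_to_csv


async def main() -> None:
    # Only cellular_component terms, capped at 10 results.
    terms = await search_gene_ontology_term("nucleus", aspect="cellular_component", limit=10)
    with open("go_terms.csv", "w", newline="") as csv_file:
        write_go_terms_to_csv(terms, csv_file)


asyncio.run(main())
```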
protein_quest/mcp_server.py ADDED
@@ -0,0 +1,208 @@
+ """MCP server for protein-quest.
+
+ Can be run with:
+
+ ```shell
+ # for development
+ fastmcp dev src/protein_quest/mcp_server.py
+ # or from inspector
+ npx @modelcontextprotocol/inspector
+ # transport type: stdio
+ # command: protein-quest
+ # arguments: mcp
+
+ # or with server and inspector
+ protein-quest mcp --transport streamable-http
+ # in another shell
+ npx @modelcontextprotocol/inspector
+ # transport type: streamable http
+ # URL: http://127.0.0.1:8000/mcp
+
+ # or with copilot in VS code
+ # ctrl + shift + p
+ # mcp: add server...
+ # Choose STDIO
+ # command: uv run protein-quest mcp
+ # id: protein-quest
+ # Prompt: What are the PDBe structures for `A8MT69` uniprot accession?
+ ```
+
+ Examples:
+
+     For search_pdb use `A8MT69` as input.
+
+ """
+
+ from pathlib import Path
+ from textwrap import dedent
+ from typing import Annotated
+
+ from fastmcp import FastMCP
+ from pydantic import Field
+
+ from protein_quest.alphafold.confidence import ConfidenceFilterQuery, ConfidenceFilterResult, filter_file_on_residues
+ from protein_quest.alphafold.fetch import AlphaFoldEntry, DownloadableFormat
+ from protein_quest.alphafold.fetch import fetch_many as alphafold_fetch
+ from protein_quest.emdb import fetch as emdb_fetch
+ from protein_quest.go import search_gene_ontology_term
+ from protein_quest.pdbe.fetch import fetch as pdbe_fetch
+ from protein_quest.pdbe.io import glob_structure_files, nr_residues_in_chain, write_single_chain_pdb_file
+ from protein_quest.taxonomy import search_taxon
+ from protein_quest.uniprot import PdbResult, Query, search4af, search4emdb, search4pdb, search4uniprot
+
+ mcp = FastMCP("protein-quest")
+
+ # do not want to make dataclasses in non-mcp code into Pydantic models,
+ # so we use Annotated here to add description on roots.
+
+
+ @mcp.tool
+ def search_uniprot(
+     uniprot_query: Annotated[Query, Field(description=Query.__doc__)],
+     limit: Annotated[int, Field(gt=0, description="Limit the number of uniprot accessions returned")] = 100,
+ ) -> set[str]:
+     """Search UniProt for proteins matching the given query."""
+     return search4uniprot(uniprot_query, limit=limit)
+
+
+ @mcp.tool
+ def search_pdb(
+     uniprot_accs: set[str],
+     limit: Annotated[int, Field(gt=0, description="Limit the number of entries returned")] = 100,
+ ) -> Annotated[
+     dict[str, set[PdbResult]],
+     Field(
+         description=dedent(f"""\
+             Dictionary with protein IDs as keys and sets of PDB results as values.
+             A PDB result is {PdbResult.__doc__}""")
+     ),
+ ]:
+     """Search PDBe structures for given uniprot accessions."""
+     return search4pdb(uniprot_accs, limit=limit)
+
+
+ mcp.tool(pdbe_fetch, name="fetch_pdbe_structures")
+
+
+ @mcp.tool
+ def extract_single_chain_from_structure(
+     input_file: Path,
+     chain2keep: str,
+     output_dir: Path,
+     out_chain: str = "A",
+ ) -> Path | None:
+     """
+     Extract a single chain from a mmCIF/PDB file and write it to a new file.
+
+     Args:
+         input_file: Path to the input mmCIF/PDB file.
+         chain2keep: The chain to keep.
+         output_dir: Directory to save the output file.
+         out_chain: The chain identifier for the output file.
+
+     Returns:
+         Path to the output mmCIF/PDB file or None if not created.
+     """
+     return write_single_chain_pdb_file(input_file, chain2keep, output_dir, out_chain)
+
+
+ @mcp.tool
+ def list_structure_files(path: Path) -> list[Path]:
+     """List structure files (.pdb, .pdb.gz, .cif, .cif.gz) in the specified directory."""
+     return list(glob_structure_files(path))
+
+
+ # TODO replace remaining decorators with wrapper if tool does single function call
+ # so we do not have to replicate docstring,
+ # minor con is that it does not show up in api docs
+ mcp.tool(nr_residues_in_chain)
+ mcp.tool(search_taxon)
+ mcp.tool(search_gene_ontology_term)
+
+
+ @mcp.tool
+ def search_alphafolds(
+     uniprot_accs: set[str],
+     limit: Annotated[int, Field(gt=0, description="Limit the number of entries returned")] = 100,
+ ) -> Annotated[
+     set[str],
+     Field(description="Set of uniprot accessions which have an AlphaFold entry"),
+ ]:
+     """Search for AlphaFold entries for the given UniProtKB accessions."""
+     # each uniprot accession can have one or more AlphaFold IDs
+     # an AlphaFold ID is the same as the uniprot accession
+     # so we return a subset of uniprot_accs
+     results = search4af(uniprot_accs, limit)
+     return {k for k, v in results.items() if v}
+
+
+ mcp.tool(search4emdb, name="search_emdb")
+
+
+ @mcp.tool
+ def fetch_alphafold_structures(uniprot_accs: set[str], save_dir: Path) -> list[AlphaFoldEntry]:
+     """Fetch the AlphaFold summary and mmCIF file for given UniProt accessions.
+
+     Args:
+         uniprot_accs: A set of UniProt accessions.
+         save_dir: The directory to save the fetched files.
+
+     Returns:
+         A list of AlphaFold entries.
+     """
+     what: set[DownloadableFormat] = {"cif"}
+     return alphafold_fetch(uniprot_accs, save_dir, what)
+
+
+ mcp.tool(emdb_fetch, name="fetch_emdb_volumes")
+
+
+ @mcp.tool
+ def alphafold_confidence_filter(file: Path, query: ConfidenceFilterQuery, filtered_dir: Path) -> ConfidenceFilterResult:
+     """Take a mmCIF/PDB file and filter it based on confidence (pLDDT) scores.
+
+     If the file passes the filter, a copy containing only residues above the confidence threshold is written to filtered_dir.
+     """
+     return filter_file_on_residues(file, query, filtered_dir)
+
+
+ @mcp.prompt
+ def candidate_structures(
+     species: str = "Human",
+     cellular_location: str = "nucleus",
+     confidence: int = 90,
+     min_residues: int = 100,
+     max_residues: int = 200,
+ ) -> str:
+     """Prompt to find candidate structures.
+
+     Args:
+         species: The species to search for (default: "Human").
+         cellular_location: The cellular location to search for (default: "nucleus").
+         confidence: The confidence threshold for AlphaFold structures (default: 90).
+         min_residues: Minimum number of high confidence residues (default: 100).
+         max_residues: Maximum number of high confidence residues (default: 200).
+
+     Returns:
+         A prompt string to find candidate structures.
+     """
+     return dedent(f"""\
+         Given the species '{species}' and cellular location '{cellular_location}' find the candidate structures.
+         Download structures from 2 sources, namely PDB and AlphaFold.
+         For AlphaFold I only want to use high confidence scores of over {confidence}
+         and only keep structures with number of high confidence residues between {min_residues} and {max_residues}.
+
+         1. Search uniprot for proteins related to {species} and {cellular_location}.
+             1. For the species find the NCBI taxonomy id.
+             2. For cellular location find the associated GO term.
+             3. Find uniprot accessions based on NCBI taxonomy id and cellular location GO term.
+         2. For PDB
+             1. Search for structures related to the identified proteins.
+             2. Download each PDB entry from PDBe.
+             3. Extract chain for the protein of interest.
+         3. For AlphaFold
+             1. Search for AlphaFold entries related to the identified proteins.
+             2. Download each AlphaFold entry.
+             3. Filter the structures based on {confidence} as confidence
+                and nr residues between {min_residues} and {max_residues}.
+         """)
protein_quest/parallel.py ADDED
@@ -0,0 +1,68 @@
+ """Dask helper functions."""
+
+ import logging
+ import os
+
+ from dask.distributed import LocalCluster
+ from distributed.deploy.cluster import Cluster
+ from psutil import cpu_count
+
+ logger = logging.getLogger(__name__)
+
+
+ def configure_dask_scheduler(
+     scheduler_address: str | Cluster | None,
+     name: str,
+     nproc: int = 1,
+ ) -> str | Cluster:
+     """Configure the Dask scheduler by reusing an existing cluster or creating a new local one.
+
+     Args:
+         scheduler_address: Address of the Dask scheduler to connect to, or None for a local cluster.
+         name: Name for the Dask cluster.
+         nproc: Number of processes to use per worker for CPU support.
+
+     Returns:
+         A Dask Cluster instance or a string address for the scheduler.
+     """
+     if scheduler_address is None:
+         scheduler_address = _configure_cpu_dask_scheduler(nproc, name)
+         logger.info(f"Using local Dask cluster: {scheduler_address}")
+
+     return scheduler_address
+
+
+ def nr_cpus() -> int:
+     """Determine the number of CPU cores to use.
+
+     If the environment variables SLURM_CPUS_PER_TASK or OMP_NUM_THREADS are set,
+     their value is used. Otherwise, the number of physical CPU cores is returned.
+
+     Returns:
+         The number of CPU cores to use.
+
+     Raises:
+         ValueError: If the number of physical CPU cores cannot be determined.
+     """
+     physical_cores = cpu_count(logical=False)
+     if physical_cores is None:
+         msg = "Cannot determine number of physical CPU cores."
+         raise ValueError(msg)
+     for var in ["SLURM_CPUS_PER_TASK", "OMP_NUM_THREADS"]:
+         value = os.environ.get(var)
+         if value is not None:
+             logger.warning(
+                 'Not using all CPU cores (%s) of machine, environment variable "%s" is set to %s.',
+                 physical_cores,
+                 var,
+                 value,
+             )
+             return int(value)
+     return physical_cores
+
+
+ def _configure_cpu_dask_scheduler(nproc: int, name: str) -> LocalCluster:
+     total_cpus = nr_cpus()
+     n_workers = total_cpus // nproc
+     # Use single thread per worker to prevent GIL slowing down the computations
+     return LocalCluster(name=name, threads_per_worker=1, n_workers=n_workers)
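
A small sketch showing how `configure_dask_scheduler` above is used, mirroring the call in `filter_files_on_chain`: passing None creates a `LocalCluster` sized by `nr_cpus()`, while a scheduler address or existing cluster would be passed through unchanged.

```python
from dask.distributed import Client

from protein_quest.parallel import configure_dask_scheduler, nr_cpus

print("CPU cores to use:", nr_cpus())

# No address given, so a LocalCluster with one thread per worker is created.
scheduler = configure_dask_scheduler(None, name="example", nproc=1)
with Client(scheduler) as client:
    print("Dashboard:", client.dashboard_link)
```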
protein_quest/pdbe/__init__.py ADDED
@@ -0,0 +1 @@
+ """Modules related to PDBe (Protein Data Bank in Europe)."""
protein_quest/pdbe/fetch.py ADDED
@@ -0,0 +1,51 @@
+ """Module for fetching structures from PDBe."""
+
+ from collections.abc import Iterable, Mapping
+ from pathlib import Path
+
+ from protein_quest.utils import retrieve_files
+
+
+ def _map_id_mmcif(pdb_id: str) -> tuple[str, str]:
+     """
+     Map a PDB id to a gzipped mmCIF download URL and filename.
+
+     For example for PDB id "8WAS", the url will be
+     "https://www.ebi.ac.uk/pdbe/entry-files/download/8was.cif.gz" and the file will be "8was.cif.gz".
+
+     Args:
+         pdb_id: The PDB ID to map.
+
+     Returns:
+         A tuple containing the URL to download the mmCIF file and the filename.
+     """
+     fn = f"{pdb_id.lower()}.cif.gz"
+     # On PDBe you can sometimes download an updated mmCIF file.
+     # The current url is for the archive mmCIF file.
+     # TODO check if archive is OK, or if we should try to download the updated file
+     # this will cause many more requests, so we should only do this if needed
+     url = f"https://www.ebi.ac.uk/pdbe/entry-files/download/{fn}"
+     return url, fn
+
+
+ async def fetch(ids: Iterable[str], save_dir: Path, max_parallel_downloads: int = 5) -> Mapping[str, Path]:
+     """Fetches mmCIF files from the PDBe database.
+
+     Args:
+         ids: A set of PDB IDs to fetch.
+         save_dir: The directory to save the fetched mmCIF files to.
+         max_parallel_downloads: The maximum number of parallel downloads.
+
+     Returns:
+         A dict of ids and paths to the downloaded mmCIF files.
+     """
+
+     # The downloaded results are in a different order than the input ids,
+     # so we need to map the ids to the urls and filenames ourselves.
+
+     id2urls = {pdb_id: _map_id_mmcif(pdb_id) for pdb_id in ids}
+     urls = list(id2urls.values())
+     id2paths = {pdb_id: save_dir / fn for pdb_id, (_, fn) in id2urls.items()}
+
+     await retrieve_files(urls, save_dir, max_parallel_downloads, desc="Downloading PDBe mmCIF files")
+     return id2paths
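
Finally, a minimal usage sketch for the PDBe `fetch` coroutine above, using the `8WAS` example id from `_map_id_mmcif`. As with the EMDB module, `retrieve_files` is not shown in this diff, so the save directory is created up front as a precaution.

```python
import asyncio
from pathlib import Path

from protein_quest.pdbe.fetch import fetch


async def main() -> None:
    save_dir = Path("mmcif")
    save_dir.mkdir(parents=True, exist_ok=True)  # precaution; retrieve_files is not in this diff
    id2paths = await fetch(["8WAS"], save_dir)
    print(id2paths)  # e.g. {'8WAS': PosixPath('mmcif/8was.cif.gz')}


asyncio.run(main())
```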