protein-quest 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of protein-quest might be problematic. Click here for more details.

File without changes
@@ -0,0 +1 @@
1
+ __version__ = "0.3.0"
@@ -0,0 +1 @@
1
+ """Modules related to AlphaFold Knowledge Base."""
@@ -0,0 +1,153 @@
1
+ """Module for filtering alphafold structures on confidence."""
2
+
3
+ import logging
4
+ from collections.abc import Generator
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+
8
+ import gemmi
9
+
10
+ from protein_quest.pdbe.io import write_structure
11
+
12
+ """
13
+ Methods to filter AlphaFoldDB structures on confidence scores.
14
+
15
+ In AlphaFold PDB files, the b-factor column has the
16
+ predicted local distance difference test (pLDDT).
17
+
18
+ See https://www.ebi.ac.uk/training/online/courses/alphafold/inputs-and-outputs/evaluating-alphafolds-predicted-structures-using-confidence-scores/plddt-understanding-local-confidence/
19
+ """
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
def find_high_confidence_residues(structure: gemmi.Structure, confidence: float) -> Generator[int]:
    """Yield sequence numbers of residues whose pLDDT exceeds a threshold.

    Args:
        structure: The AlphaFoldDB structure to search.
        confidence: The pLDDT threshold; only residues strictly above it are yielded.

    Yields:
        The sequence numbers of residues with pLDDT above the confidence threshold.
    """
    for model in structure:
        for chain in model:
            for residue in chain:
                # pLDDT lives in the b-factor column; the first atom's value
                # is used as the residue's confidence.
                plddt = residue[0].b_iso
                if plddt <= confidence:
                    continue
                seq_num = residue.seqid.num
                if seq_num is not None:
                    yield seq_num
42
+
43
+
44
def filter_out_low_confidence_residues(structure: gemmi.Structure, allowed_residues: set[int]) -> gemmi.Structure:
    """Return a copy of the structure containing only the allowed residues.

    Args:
        structure: The AlphaFoldDB structure to filter.
        allowed_residues: The set of residue sequence numbers to keep.

    Returns:
        A new AlphaFoldDB structure with low confidence residues removed.
    """
    filtered = structure.clone()
    for model in filtered:
        rebuilt = []
        for chain in model:
            kept = gemmi.Chain(chain.name)
            for residue in chain:
                if residue.seqid.num in allowed_residues:
                    kept.add_residue(residue)
            rebuilt.append(kept)
        # Swap each original chain for its filtered counterpart after the
        # iteration, so the model is not mutated while being traversed.
        for kept in rebuilt:
            model.remove_chain(kept.name)
            model.add_chain(kept)
    return filtered
67
+
68
+
69
@dataclass
class ConfidenceFilterQuery:
    """Query for filtering AlphaFoldDB structures based on confidence.

    Parameters:
        confidence: The confidence threshold for filtering residues.
            Residues with a pLDDT (b-factor) above this value are considered high confidence.
        min_threshold: The minimum number of high-confidence residues required to keep the structure.
        max_threshold: The maximum number of high-confidence residues allowed for the structure to be kept.
    """

    confidence: float
    min_threshold: int
    max_threshold: int
83
+
84
+
85
@dataclass
class ConfidenceFilterResult:
    """Result of filtering AlphaFoldDB structures based on confidence (pLDDT).

    Parameters:
        input_file: The name of the mmcif/PDB file that was processed.
        count: The number of residues with a pLDDT above the confidence threshold.
        filtered_file: The path to the filtered mmcif/PDB file.
            None when the structure did not pass the thresholds and no file was written.
    """

    input_file: str
    count: int
    filtered_file: Path | None = None
98
+
99
+
100
def filter_file_on_residues(file: Path, query: ConfidenceFilterQuery, filtered_dir: Path) -> ConfidenceFilterResult:
    """Filter a single AlphaFoldDB structure file based on confidence.

    Args:
        file: The path to the PDB file to filter.
        query: The confidence filter query.
        filtered_dir: The directory to save the filtered PDB file.

    Returns:
        Result with the count of high-confidence residues.
        Its filtered_file is set to the written file's path only when the
        structure passed the min/max thresholds; otherwise it stays None.
    """
    structure = gemmi.read_structure(str(file))
    high_confidence = set(find_high_confidence_residues(structure, query.confidence))
    nr_high = len(high_confidence)
    if not (query.min_threshold <= nr_high <= query.max_threshold):
        # Structure falls outside the requested range: report only the count,
        # without writing a filtered file.
        return ConfidenceFilterResult(
            input_file=file.name,
            count=nr_high,
        )
    out_file = filtered_dir / file.name
    filtered_structure = filter_out_low_confidence_residues(structure, high_confidence)
    write_structure(filtered_structure, out_file)
    return ConfidenceFilterResult(
        input_file=file.name,
        count=nr_high,
        filtered_file=out_file,
    )
133
+
134
+
135
def filter_files_on_confidence(
    alphafold_pdb_files: list[Path], query: ConfidenceFilterQuery, filtered_dir: Path
) -> Generator[ConfidenceFilterResult]:
    """Filter AlphaFoldDB structures based on confidence.

    Args:
        alphafold_pdb_files: List of mmcif/PDB files from AlphaFoldDB to filter.
        query: The confidence filter query containing the confidence thresholds.
        filtered_dir: Directory where the filtered mmcif/PDB files will be saved.

    Yields:
        One result per mmcif/PDB file, telling whether it passed the filter
        and how many residues had a pLDDT above the confidence threshold.
    """
    # Note on why code looks duplicated:
    # ../filter.py:filter_files_on_residues() filters on residue count at the
    # file level only; here we also strip low-confidence residues inside each file.
    yield from (filter_file_on_residues(pdb_file, query, filtered_dir) for pdb_file in alphafold_pdb_files)
@@ -0,0 +1,38 @@
1
+ # ruff: noqa: N815 allow camelCase follow what api returns
2
+ from dataclasses import dataclass
3
+
4
+
5
@dataclass
class EntrySummary:
    """Dataclass representing a summary of an AlphaFold entry.

    Modelled after EntrySummary in https://alphafold.ebi.ac.uk/api/openapi.json
    """

    # Identity of the model and of the UniProt entry it was predicted for.
    entryId: str
    uniprotAccession: str
    uniprotId: str
    uniprotDescription: str
    # Source organism.
    taxId: int
    organismScientificName: str
    # Region of the UniProt sequence covered by the model.
    uniprotStart: int
    uniprotEnd: int
    uniprotSequence: str
    # Versioning of the prediction.
    modelCreatedDate: str
    latestVersion: int
    allVersions: list[int]
    # Download URLs for structure files and predicted aligned error (PAE) artifacts.
    bcifUrl: str
    cifUrl: str
    pdbUrl: str
    paeImageUrl: str
    paeDocUrl: str
    # Optional metadata; not present for every entry.
    gene: str | None = None
    sequenceChecksum: str | None = None
    sequenceVersionDate: str | None = None
    # "am" annotation URLs (presumably AlphaMissense; hg19/hg38 look like
    # genome-build variants — confirm against the AlphaFold API docs).
    amAnnotationsUrl: str | None = None
    amAnnotationsHg19Url: str | None = None
    amAnnotationsHg38Url: str | None = None
    isReviewed: bool | None = None
    isReferenceProteome: bool | None = None
    # TODO add new fields from https://alphafold.ebi.ac.uk/#/public-api/get_uniprot_summary_api_uniprot_summary__qualifier__json_get
    # TODO like fractionPlddt* fields which can be used in filter_files_on_confidence()
@@ -0,0 +1,314 @@
1
+ """Module for fetch Alphafold data."""
2
+
3
import asyncio
import logging
from asyncio import Semaphore
from collections.abc import AsyncGenerator, Iterable
from dataclasses import dataclass, fields, replace
from pathlib import Path
from textwrap import dedent
from typing import Literal, get_args

from aiohttp_retry import RetryClient
from aiopath import AsyncPath
from cattrs.preconf.orjson import make_converter
from tqdm.asyncio import tqdm

from protein_quest.alphafold.entry_summary import EntrySummary
from protein_quest.utils import friendly_session, retrieve_files
19
+
20
+ logger = logging.getLogger(__name__)
21
+ converter = make_converter()
22
+
23
DownloadableFormat = Literal[
    "bcif",
    "cif",
    "pdb",
    "paeImage",
    "paeDoc",
    "amAnnotations",
    "amAnnotationsHg19",
    "amAnnotationsHg38",
]
"""Types of formats that can be downloaded from the AlphaFold web service."""

# Derive the runtime set from the Literal so the two can never drift apart
# (the set used to be a hand-maintained copy of the Literal members).
downloadable_formats: set[DownloadableFormat] = set(get_args(DownloadableFormat))
"""Set of formats that can be downloaded from the AlphaFold web service."""
46
+
47
+
48
+ def _camel_to_snake_case(name: str) -> str:
49
+ """Convert a camelCase string to snake_case."""
50
+ return "".join(["_" + c.lower() if c.isupper() else c for c in name]).lstrip("_")
51
+
52
+
53
@dataclass
class AlphaFoldEntry:
    """AlphaFoldEntry represents a minimal single entry in the AlphaFold database.

    See https://alphafold.ebi.ac.uk/api-docs for more details on the API and data structure.
    """

    uniprot_acc: str
    summary: EntrySummary | None
    bcif_file: Path | None = None
    cif_file: Path | None = None
    pdb_file: Path | None = None
    pae_image_file: Path | None = None
    pae_doc_file: Path | None = None
    am_annotations_file: Path | None = None
    am_annotations_hg19_file: Path | None = None
    am_annotations_hg38_file: Path | None = None

    @classmethod
    def format2attr(cls, dl_format: DownloadableFormat) -> str:
        """Map a download format to the name of the dataclass attribute holding its file.

        Args:
            dl_format: The format for which to get the attribute name.

        Returns:
            The attribute name corresponding to the download format.

        Raises:
            ValueError: If the format is not valid.
        """
        if dl_format not in downloadable_formats:
            msg = f"Invalid format: {dl_format}. Valid formats are: {downloadable_formats}"
            raise ValueError(msg)
        return f"{_camel_to_snake_case(dl_format)}_file"

    def by_format(self, dl_format: DownloadableFormat) -> Path | None:
        """Look up the file path stored for a given download format.

        Args:
            dl_format: The format for which to get the file path.

        Returns:
            The file path for the format, or None if the file is not set.

        Raises:
            ValueError: If the format is not valid.
        """
        return getattr(self, self.format2attr(dl_format), None)

    def nr_of_files(self) -> int:
        """Count how many of the *_file attributes are set.

        Returns:
            The number of _file properties that are set.
        """
        file_attrs = (name for name in vars(self) if name.endswith("_file"))
        return sum(getattr(self, name) is not None for name in file_attrs)
112
+
113
+
114
async def fetch_summary(
    qualifier: str, session: RetryClient, semaphore: Semaphore, save_dir: Path | None
) -> list[EntrySummary]:
    """Fetch the prediction summary for a single qualifier from the AlphaFold API.

    Args:
        qualifier: The uniprot accession for the protein or entry to fetch.
            For example `Q5VSL9`.
        session: An asynchronous HTTP client session with retry capabilities.
        semaphore: A semaphore to limit the number of concurrent requests.
        save_dir: Optional directory used as an on-disk cache for summaries.
            When set and the summary file exists, the summary is loaded from disk
            instead of being fetched from the API.
            When not set, the summary is never saved and is always fetched from the API.

    Returns:
        A list of EntrySummary objects representing the fetched summary.

    Raises:
        HTTPError: If the HTTP request returns an error status code.
        Exception: If there is an error during file reading/writing or data conversion.
    """
    url = f"https://alphafold.ebi.ac.uk/api/prediction/{qualifier}"
    cache_file: AsyncPath | None = None
    if save_dir is not None:
        cache_file = AsyncPath(save_dir / f"{qualifier}.json")
        if await cache_file.exists():
            logger.debug(f"File {cache_file} already exists. Skipping download from {url}.")
            cached = await cache_file.read_bytes()
            return converter.loads(cached, list[EntrySummary])
    async with semaphore, session.get(url) as response:
        response.raise_for_status()
        payload = await response.content.read()
        # Populate the cache so the next call can skip the network round-trip.
        if cache_file is not None:
            await cache_file.write_bytes(payload)
        return converter.loads(payload, list[EntrySummary])
149
+
150
+
151
async def fetch_summaries(
    qualifiers: Iterable[str], save_dir: Path | None = None, max_parallel_downloads: int = 5
) -> AsyncGenerator[EntrySummary]:
    """Fetch AlphaFold prediction summaries for many qualifiers concurrently.

    Args:
        qualifiers: Uniprot accessions to fetch summaries for.
        save_dir: Optional directory used as an on-disk cache for summaries.
        max_parallel_downloads: The maximum number of concurrent requests.

    Yields:
        Each EntrySummary from each qualifier's summary list, in qualifier order.
    """
    semaphore = Semaphore(max_parallel_downloads)
    if save_dir is not None:
        save_dir.mkdir(parents=True, exist_ok=True)
    async with friendly_session() as session:
        jobs = [fetch_summary(qualifier, session, semaphore, save_dir) for qualifier in qualifiers]
        per_qualifier: list[list[EntrySummary]] = await tqdm.gather(*jobs, desc="Fetching Alphafold summaries")
        # Flatten: one qualifier may map to several summaries.
        for batch in per_qualifier:
            for summary in batch:
                yield summary
165
+
166
+
167
def url2name(url: str) -> str:
    """Given a URL, return the final path component as the name of the file."""
    _, _, name = url.rpartition("/")
    return name if name or "/" in url else url
170
+
171
+
172
async def fetch_many_async(
    ids: Iterable[str], save_dir: Path, what: set[DownloadableFormat], max_parallel_downloads: int = 5
) -> AsyncGenerator[AlphaFoldEntry]:
    """Asynchronously fetches summaries and pdb and pae (predicted alignment error) files from
    [AlphaFold Protein Structure Database](https://alphafold.ebi.ac.uk/).

    Args:
        ids: A set of Uniprot IDs to fetch.
        save_dir: The directory to save the fetched files to.
        what: A set of formats to download.
        max_parallel_downloads: The maximum number of parallel downloads.

    Yields:
        A dataclass containing the summary and the local path of each downloaded file.
    """
    summaries = [s async for s in fetch_summaries(ids, save_dir, max_parallel_downloads=max_parallel_downloads)]

    files = files_to_download(what, summaries)

    await retrieve_files(
        files,
        save_dir,
        desc="Downloading AlphaFold files",
        max_parallel_downloads=max_parallel_downloads,
    )
    for summary in summaries:
        # Build the <format>_file keyword arguments generically from the requested
        # formats instead of eight hand-written conditionals, so adding a new
        # format only requires extending DownloadableFormat and EntrySummary.
        file_kwargs: dict[str, Path | None] = {}
        for dl_format in what:
            url = getattr(summary, f"{dl_format}Url", None)
            # Optional formats (e.g. amAnnotations*) may have no URL; map those to None.
            file_kwargs[AlphaFoldEntry.format2attr(dl_format)] = save_dir / url2name(url) if url else None
        yield AlphaFoldEntry(
            uniprot_acc=summary.uniprotAccession,
            summary=summary,
            **file_kwargs,
        )
222
+
223
+
224
def files_to_download(what: set[DownloadableFormat], summaries: Iterable[EntrySummary]) -> set[tuple[str, str]]:
    """Collect (url, filename) pairs for the requested formats of the given summaries.

    Args:
        what: The set of formats to download.
        summaries: The entry summaries providing the download URLs.

    Returns:
        A set of (url, filename) tuples; summaries missing a URL for a format are skipped.

    Raises:
        ValueError: If any requested format is not a valid downloadable format.
    """
    unknown = set(what) - downloadable_formats
    if unknown:
        msg = (
            f"Invalid format(s) specified: {unknown}. "
            f"Valid formats are: {downloadable_formats}"
        )
        raise ValueError(msg)

    pairs: set[tuple[str, str]] = set()
    for summary in summaries:
        for fmt in what:
            url = getattr(summary, f"{fmt}Url", None)
            if url is None:
                logger.warning(f"Summary {summary.uniprotAccession} does not have a URL for format '{fmt}'. Skipping.")
                continue
            pairs.add((url, url2name(url)))
    return pairs
242
+
243
+
244
class NestedAsyncIOLoopError(RuntimeError):
    """Custom error for nested async I/O loops.

    Raised by the synchronous wrappers (see fetch_many) when they are called
    from an environment whose asyncio event loop is already running,
    such as a Jupyter notebook.
    """
    # Note: the redundant `pass` after the docstring was removed; a docstring
    # alone is a valid class body.
248
+
249
+
250
def fetch_many(
    ids: Iterable[str], save_dir: Path, what: set[DownloadableFormat], max_parallel_downloads: int = 5
) -> list[AlphaFoldEntry]:
    """Synchronously fetches summaries and pdb and pae files from AlphaFold Protein Structure Database.

    Thin blocking wrapper around [fetch_many_async][] using asyncio.run.

    Args:
        ids: A set of Uniprot IDs to fetch.
        save_dir: The directory to save the fetched files to.
        what: A set of formats to download.
        max_parallel_downloads: The maximum number of parallel downloads.

    Returns:
        A list of AlphaFoldEntry dataclasses containing the summary, pdb file, and pae file.

    Raises:
        NestedAsyncIOLoopError: If called from a nested async I/O loop like in a Jupyter notebook.
    """

    async def _collect() -> list[AlphaFoldEntry]:
        entries: list[AlphaFoldEntry] = []
        async for entry in fetch_many_async(ids, save_dir, what, max_parallel_downloads=max_parallel_downloads):
            entries.append(entry)
        return entries

    try:
        return asyncio.run(_collect())
    except RuntimeError as e:
        # asyncio.run refuses to start when a loop is already running;
        # translate that into a package-specific, actionable error.
        msg = dedent("""\
            Can not run async method from an environment where the asyncio event loop is already running.
            Like a Jupyter notebook.

            Please use the `fetch_many_async` function directly or before call

            import nest_asyncio
            nest_asyncio.apply()
            """)
        raise NestedAsyncIOLoopError(msg) from e
287
+
288
+
289
def relative_to(entry: AlphaFoldEntry, session_dir: Path) -> AlphaFoldEntry:
    """Convert paths in an AlphaFoldEntry to be relative to the session directory.

    Args:
        entry: An AlphaFoldEntry instance with absolute paths.
        session_dir: The session directory to which the paths should be made relative.

    Returns:
        An AlphaFoldEntry instance with paths relative to the session directory.
    """
    # Rewrite every *_file field generically instead of listing all eight by
    # hand, so newly added formats are covered automatically (see the TODOs
    # about upcoming fields in EntrySummary).
    updates = {
        f.name: path.relative_to(session_dir) if (path := getattr(entry, f.name)) else None
        for f in fields(entry)
        if f.name.endswith("_file")
    }
    return replace(entry, **updates)