openprotein-python 0.8.2__1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. openprotein/__init__.py +164 -0
  2. openprotein/_version.py +48 -0
  3. openprotein/align/__init__.py +8 -0
  4. openprotein/align/align.py +395 -0
  5. openprotein/align/api.py +428 -0
  6. openprotein/align/future.py +55 -0
  7. openprotein/align/msa.py +129 -0
  8. openprotein/align/schemas.py +165 -0
  9. openprotein/base.py +181 -0
  10. openprotein/chains.py +88 -0
  11. openprotein/common/__init__.py +5 -0
  12. openprotein/common/features.py +7 -0
  13. openprotein/common/model_metadata.py +33 -0
  14. openprotein/common/reduction.py +8 -0
  15. openprotein/config.py +9 -0
  16. openprotein/csv.py +31 -0
  17. openprotein/data/__init__.py +9 -0
  18. openprotein/data/api.py +218 -0
  19. openprotein/data/assaydataset.py +178 -0
  20. openprotein/data/data.py +93 -0
  21. openprotein/data/schemas.py +27 -0
  22. openprotein/design/__init__.py +16 -0
  23. openprotein/design/api.py +259 -0
  24. openprotein/design/design.py +125 -0
  25. openprotein/design/future.py +146 -0
  26. openprotein/design/schemas.py +607 -0
  27. openprotein/embeddings/__init__.py +27 -0
  28. openprotein/embeddings/api.py +619 -0
  29. openprotein/embeddings/embeddings.py +151 -0
  30. openprotein/embeddings/esm.py +33 -0
  31. openprotein/embeddings/future.py +146 -0
  32. openprotein/embeddings/models.py +421 -0
  33. openprotein/embeddings/openprotein.py +21 -0
  34. openprotein/embeddings/poet.py +446 -0
  35. openprotein/embeddings/poet2.py +505 -0
  36. openprotein/embeddings/schemas.py +78 -0
  37. openprotein/errors.py +76 -0
  38. openprotein/fasta.py +92 -0
  39. openprotein/fold/__init__.py +21 -0
  40. openprotein/fold/alphafold2.py +131 -0
  41. openprotein/fold/api.py +287 -0
  42. openprotein/fold/boltz.py +691 -0
  43. openprotein/fold/esmfold.py +54 -0
  44. openprotein/fold/fold.py +107 -0
  45. openprotein/fold/future.py +509 -0
  46. openprotein/fold/models.py +139 -0
  47. openprotein/fold/schemas.py +39 -0
  48. openprotein/jobs/__init__.py +9 -0
  49. openprotein/jobs/api.py +71 -0
  50. openprotein/jobs/futures.py +746 -0
  51. openprotein/jobs/jobs.py +69 -0
  52. openprotein/jobs/schemas.py +135 -0
  53. openprotein/models/__init__.py +4 -0
  54. openprotein/models/base.py +63 -0
  55. openprotein/models/foundation/rfdiffusion.py +283 -0
  56. openprotein/models/models.py +33 -0
  57. openprotein/predictor/__init__.py +25 -0
  58. openprotein/predictor/api.py +384 -0
  59. openprotein/predictor/models.py +374 -0
  60. openprotein/predictor/prediction.py +79 -0
  61. openprotein/predictor/predictor.py +242 -0
  62. openprotein/predictor/schemas.py +113 -0
  63. openprotein/predictor/validate.py +40 -0
  64. openprotein/prompt/__init__.py +9 -0
  65. openprotein/prompt/api.py +505 -0
  66. openprotein/prompt/models.py +142 -0
  67. openprotein/prompt/prompt.py +130 -0
  68. openprotein/prompt/schemas.py +49 -0
  69. openprotein/protein.py +587 -0
  70. openprotein/svd/__init__.py +9 -0
  71. openprotein/svd/api.py +206 -0
  72. openprotein/svd/models.py +288 -0
  73. openprotein/svd/schemas.py +31 -0
  74. openprotein/svd/svd.py +134 -0
  75. openprotein/umap/__init__.py +9 -0
  76. openprotein/umap/api.py +259 -0
  77. openprotein/umap/models.py +211 -0
  78. openprotein/umap/schemas.py +35 -0
  79. openprotein/umap/umap.py +175 -0
  80. openprotein/utils/uuid.py +29 -0
  81. openprotein_python-0.8.2.dist-info/METADATA +176 -0
  82. openprotein_python-0.8.2.dist-info/RECORD +84 -0
  83. openprotein_python-0.8.2.dist-info/WHEEL +4 -0
  84. openprotein_python-0.8.2.dist-info/licenses/LICENSE.txt +30 -0
openprotein/fasta.py ADDED
@@ -0,0 +1,92 @@
1
+ from typing import Iterator, Sequence, overload
2
+
3
+
4
+ @overload
5
+ def parse_stream(
6
+ lines: Iterator[str], comment: str = "#"
7
+ ) -> Iterator[tuple[str, str]]: ...
8
+
9
+
10
+ @overload
11
+ def parse_stream(
12
+ lines: Iterator[bytes], comment: str = "#"
13
+ ) -> Iterator[tuple[bytes, bytes]]: ...
14
+
15
+
16
+ def parse_stream(
17
+ lines: Iterator[str] | Iterator[bytes], comment: str = "#"
18
+ ) -> Iterator[tuple[str, str]] | Iterator[tuple[bytes, bytes]]:
19
+ is_bytes: bool | None = None
20
+ name = None
21
+ sequence = []
22
+
23
+ for line in lines:
24
+ if not line:
25
+ continue # skip empty lines
26
+ if is_bytes := isinstance(line, bytes):
27
+ line = line.decode()
28
+ if line.startswith(comment):
29
+ continue
30
+ line = line.strip()
31
+ if line.startswith(">"):
32
+ if name is not None:
33
+ sequence = "".join(sequence)
34
+ if is_bytes:
35
+ name = name.encode()
36
+ sequence = sequence.encode()
37
+ yield name, sequence
38
+ else:
39
+ yield name, sequence
40
+ name = line[1:].strip()
41
+ sequence = []
42
+ else:
43
+ sequence.append(line.strip())
44
+
45
+ if name is not None:
46
+ sequence = "".join(sequence)
47
+ if is_bytes:
48
+ name = name.encode()
49
+ sequence = sequence.encode()
50
+ yield name, sequence
51
+ else:
52
+ yield name, sequence
53
+
54
+
55
+ def parse(
56
+ f: Sequence[str] | Sequence[bytes], comment: str = "#"
57
+ ) -> tuple[list[str], list[str]] | tuple[list[bytes], list[bytes]]:
58
+ is_bytes: bool | None = None
59
+ names = []
60
+ sequences = []
61
+ name = None
62
+ sequence = []
63
+ for line in f:
64
+ if is_bytes := isinstance(line, bytes):
65
+ line = line.decode()
66
+ if line.startswith(comment):
67
+ continue
68
+ line = line.strip()
69
+ if line.startswith(">"):
70
+ # its a new entry
71
+ if name is not None:
72
+ sequence = "".join(sequence)
73
+ if is_bytes:
74
+ name = name.encode()
75
+ sequence = sequence.encode()
76
+ names.append(name)
77
+ sequences.append(sequence)
78
+ # reset the reading
79
+ name = line[1:]
80
+ sequence = []
81
+ else:
82
+ sequence.append(line.upper())
83
+ if name is not None:
84
+ # last entry
85
+ sequence = "".join(sequence)
86
+ if is_bytes:
87
+ name = name.encode()
88
+ sequence = sequence.encode()
89
+ names.append(name)
90
+ sequences.append(sequence)
91
+
92
+ return names, sequences
@@ -0,0 +1,21 @@
1
+ """
2
+ Fold module for predicting structures on OpenProtein.
3
+
4
+ isort:skip_file
5
+ """
6
+
7
+ from .schemas import FoldJob, FoldMetadata
8
+ from .models import FoldModel
9
+ from .esmfold import ESMFoldModel
10
+ from .alphafold2 import AlphaFold2Model
11
+ from .boltz import (
12
+ Boltz1Model,
13
+ Boltz1xModel,
14
+ Boltz2Model,
15
+ BoltzAffinity,
16
+ BoltzConfidence,
17
+ BoltzConstraint,
18
+ BoltzProperty,
19
+ )
20
+ from .future import FoldResultFuture, FoldComplexResultFuture
21
+ from .fold import FoldAPI
@@ -0,0 +1,131 @@
1
+ """Community-based AlphaFold 2 model running using ColabFold."""
2
+
3
+ import warnings
4
+ from collections import Counter
5
+
6
+ from openprotein.align import MSAFuture
7
+ from openprotein.base import APISession
8
+ from openprotein.common import ModelMetadata
9
+ from openprotein.protein import Protein
10
+
11
+ from . import api
12
+ from .future import FoldComplexResultFuture
13
+ from .models import FoldModel
14
+
15
+
16
class AlphaFold2Model(FoldModel):
    """
    Class providing inference endpoints for AlphaFold2 structure prediction models, based on the implementation by ColabFold.
    """

    model_id: str = "alphafold2"

    def __init__(
        self,
        session: APISession,
        model_id: str,
        metadata: ModelMetadata | None = None,
    ):
        super().__init__(session=session, model_id=model_id, metadata=metadata)

    def fold(
        self,
        proteins: list[Protein] | MSAFuture | None = None,
        num_recycles: int | None = None,
        num_models: int = 1,
        num_relax: int = 0,
        **kwargs,
    ) -> FoldComplexResultFuture:
        """
        Post sequences to alphafold model.

        Parameters
        ----------
        proteins : List[Protein] | MSAFuture
            List of protein sequences to fold. `Protein` objects must be tagged with an `msa`, all of them must refer to the same MSA, and together they must account for every sequence in that MSA's seed. Alternatively, supply an `MSAFuture` to use all query sequences as a multimer.
        num_recycles : int, optional
            Number of recycles for structure prediction.
        num_models : int
            Number of models to generate.
        num_relax : int
            Number of relaxation steps.
        **kwargs
            ``msa`` is accepted as a deprecated alias for ``proteins`` (a
            warning is emitted). ``ligands``, ``dnas`` and ``rnas`` trigger a
            warning and are otherwise ignored.

        Returns
        -------
        FoldComplexResultFuture
            Future wrapping the submitted fold job.

        Raises
        ------
        TypeError
            If no proteins are supplied, or `proteins` is neither a list nor
            an `MSAFuture`.
        ValueError
            If the list is empty, a protein has no msa, a protein's sequence
            is not part of its msa seed, more than one msa is referenced, or
            the proteins do not exactly make up the msa seed.
        """
        if "msa" in kwargs:
            warnings.warn(
                "Inputs to AlphaFold 2 have been updated. 'msa' should be supplied as 'proteins' argument. Support will be dropped in the future.",
                stacklevel=2,  # attribute the warning to the caller's line
            )
            proteins = kwargs["msa"]
        if "ligands" in kwargs or "dnas" in kwargs or "rnas" in kwargs:
            with warnings.catch_warnings():
                warnings.simplefilter("always")  # Force warning to always show
                warnings.warn(
                    "Alphafold 2 only supports proteins. All other chains will be ignored",
                    stacklevel=2,
                )
        if proteins is None:
            raise TypeError("Expected 'proteins' argument")
        if isinstance(proteins, list):
            if not proteins:
                # Without this guard the lookup of the single msa below would
                # fail with an unhelpful IndexError.
                raise ValueError("Expected at least one protein when using AlphaFold 2")
            msa_to_seed: dict[str, Counter] = {}
            for protein in proteins:
                msa = protein.msa
                if msa is None:
                    raise ValueError("Expected msa for protein when using AlphaFold 2")
                msa_id = msa.id if isinstance(msa, MSAFuture) else msa
                seeds = msa_to_seed.get(msa_id)
                if seeds is None:
                    from openprotein.align import AlignAPI

                    align_api = getattr(self.session, "align", None)
                    assert isinstance(align_api, AlignAPI)
                    seed = align_api.get_seed(job_id=msa_id)
                    # need a counter so we can make sure later that the proteins make up the msa completely
                    seeds = Counter(seed.split(":"))
                    msa_to_seed[msa_id] = seeds
                # check that this protein is in the seed
                if protein.sequence.decode() not in seeds:
                    raise ValueError(
                        f"Expected specified msa_id {msa_id} for protein {protein.sequence} to contain the sequence as part of its seed/query"
                    )
            # now make sure we only have one msa
            if len(msa_to_seed) > 1:
                raise ValueError("Expected only 1 unique msa when using AlphaFold 2")
            # exactly one entry remains: the list was non-empty and every
            # protein either registered its msa above or raised
            msa_id, seeds = next(iter(msa_to_seed.items()))
            # now check that the list of proteins completely make up the msa
            for protein in proteins:
                sequence = protein.sequence.decode()
                # make sure to account for multimers
                seeds[sequence] -= (
                    len(protein.chain_id) if isinstance(protein.chain_id, list) else 1
                )
                # handle when too many of a sequence in the list of proteins
                if seeds[sequence] < 0:
                    raise ValueError(
                        "List of proteins does not completely make up the MSA seed"
                    )
            if seeds.total() != 0:
                # handle when overall mismatch - 1 and -1 case is handled above
                raise ValueError(
                    "List of proteins does not completely make up the MSA seed"
                )
        elif isinstance(proteins, MSAFuture):
            msa_id = proteins.id
        else:
            raise TypeError("Expected either list of Proteins or MSAFuture")

        return FoldComplexResultFuture.create(
            session=self.session,
            job=api.fold_models_post(
                self.session,
                model_id=self.model_id,
                msa_id=msa_id,
                num_recycles=num_recycles,
                num_models=num_models,
                num_relax=num_relax,
            ),
        )
@@ -0,0 +1,287 @@
1
+ """Fold REST API interface for making HTTP calls to our fold backend."""
2
+
3
+ import io
4
+ from typing import Literal
5
+
6
+ import numpy as np
7
+ from pydantic import TypeAdapter
8
+
9
+ from openprotein.base import APISession
10
+ from openprotein.common import ModelMetadata
11
+ from openprotein.errors import HTTPError
12
+
13
+ from .schemas import FoldJob, FoldMetadata
14
+
15
+ PATH_PREFIX = "v1/fold"
16
+
17
+
18
def fold_models_list_get(session: APISession) -> list[str]:
    """
    Fetch the fold models exposed by the backend.

    Parameters
    ----------
    session : APISession
        API session used to issue the GET request.

    Returns
    -------
    list of str
        The decoded JSON body of the ``/models`` response, returned as-is
        (no validation is performed here).
    """
    return session.get(f"{PATH_PREFIX}/models").json()
36
+
37
+
38
def fold_model_get(session: APISession, model_id: str) -> ModelMetadata:
    """
    Fetch the metadata describing one fold model.

    Parameters
    ----------
    session : APISession
        API session used to issue the GET request.
    model_id : str
        Identifier of the model; interpolated into the request path.

    Returns
    -------
    ModelMetadata
        Built by passing the decoded JSON body as keyword arguments.
    """
    payload = session.get(f"{PATH_PREFIX}/models/{model_id}").json()
    return ModelMetadata(**payload)
58
+
59
+
60
def fold_get(session: APISession, job_id: str) -> FoldMetadata:
    """
    Fetch the metadata recorded for a fold job.

    Parameters
    ----------
    session : APISession
        API session used to issue the GET request.
    job_id : str
        Identifier of the fold job; interpolated into the request path.

    Returns
    -------
    FoldMetadata
        The decoded JSON body validated against `FoldMetadata`.
    """
    payload = session.get(f"{PATH_PREFIX}/{job_id}").json()
    return FoldMetadata.model_validate(payload)
80
+
81
+
82
def fold_get_sequences(session: APISession, job_id: str) -> list[bytes]:
    """
    Fetch the sequences associated with a fold job.

    Parameters
    ----------
    session : APISession
        API session used to issue the GET request.
    job_id : str
        Identifier of the fold job; interpolated into the request path.

    Returns
    -------
    list of bytes
        The decoded JSON body validated (and coerced) as ``list[bytes]``.
    """
    payload = session.get(f"{PATH_PREFIX}/{job_id}/sequences").json()
    adapter = TypeAdapter(list[bytes])
    return adapter.validate_python(payload)
101
+
102
+
103
def fold_get_sequence_result(
    session: APISession, job_id: str, sequence: bytes | str
) -> bytes:
    """
    Fetch the raw result stored for one sequence of a fold job.

    Parameters
    ----------
    session : APISession
        API session used to issue the GET request.
    job_id : str
        Identifier of the fold job.
    sequence : bytes or str
        Sequence whose result is wanted. ``bytes`` input is decoded to
        ``str``; the text is placed in the request path unchanged.

    Returns
    -------
    bytes
        The unparsed response body.
    """
    seq_text = sequence.decode() if isinstance(sequence, bytes) else sequence
    return session.get(f"{PATH_PREFIX}/{job_id}/{seq_text}").content
128
+
129
+
130
def fold_get_complex_result(
    session: APISession, job_id: str, format: Literal["pdb", "mmcif"]
) -> bytes:
    """
    Fetch the raw predicted complex of a fold job.

    Parameters
    ----------
    session : APISession
        API session used to issue the GET request.
    job_id : str
        Identifier of the fold job.
    format : {'pdb', 'mmcif'}
        Requested file format; sent as the ``format`` query parameter.

    Returns
    -------
    bytes
        The unparsed response body.
    """
    query = {"format": format}
    response = session.get(f"{PATH_PREFIX}/{job_id}/complex", params=query)
    return response.content
158
+
159
+
160
def fold_get_complex_extra_result(
    session: APISession,
    job_id: str,
    key: Literal["pae", "pde", "plddt", "confidence", "affinity"],
) -> np.ndarray | list[dict]:
    """
    Fetch one auxiliary output of a complex fold job.

    Parameters
    ----------
    session : APISession
        API session used to issue the GET request.
    job_id : str
        Identifier of the fold job.
    key : {'pae', 'pde', 'plddt', 'confidence', 'affinity'}
        Which auxiliary output to fetch.

    Returns
    -------
    numpy.ndarray or list of dict
        For "pae", "pde" and "plddt" the response body loaded with
        ``numpy.load``; for "confidence" and "affinity" the decoded JSON body.

    Raises
    ------
    ValueError
        If `key` is not one of the supported names (checked before any
        request is made), or if the backend answers 400 for "affinity".
    HTTPError
        Any other HTTP failure is propagated unchanged.
    """
    array_keys = {"pae", "pde", "plddt"}
    json_keys = {"confidence", "affinity"}
    if key not in array_keys and key not in json_keys:
        raise ValueError(f"Unexpected key: {key}")

    try:
        response = session.get(f"{PATH_PREFIX}/{job_id}/complex/{key}")
    except HTTPError as err:
        # A 400 on the affinity endpoint is reported as a plain ValueError.
        if err.status_code == 400 and key == "affinity":
            raise ValueError("affinity not found for request") from None
        raise err

    if key in array_keys:
        # Array outputs arrive as a numpy-serialised payload.
        return np.load(io.BytesIO(response.content))
    return response.json()
199
+
200
+
201
def fold_models_post(
    session: APISession,
    model_id: str,
    **kwargs,
) -> FoldJob:
    """
    POST a request for structure prediction.

    Returns a Job object referring to this request
    that can be used to retrieve results later.

    Only the keyword arguments listed below are forwarded, and only when
    their value is truthy: ``None``, ``0``, ``False`` and empty containers
    are left out of the request body.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    model_id : str
        Model ID to use for prediction; interpolated into the request path.
    sequences : sequence of bytes or str, optional
        Sequences to request results for. ``bytes`` items are decoded to
        ``str``; any other item is forwarded unchanged.
    msa_id : str, optional
        MSA ID to use.
    num_recycles : int, optional
        Number of recycles for structure prediction.
    num_models : int, optional
        Number of models to generate.
    num_relax : int, optional
        Number of relaxation steps.
    use_potentials : bool, optional
        Whether to use potentials.
    diffusion_samples : int, optional
        Number of diffusion samples (boltz).
    recycling_steps : int, optional
        Number of recycling steps (boltz).
    sampling_steps : int, optional
        Number of sampling steps (boltz).
    step_scale : float, optional
        Step scale (boltz).
    constraints : dict, optional
        Constraints to apply.
    templates : list, optional
        Templates to use.
    properties : dict, optional
        Additional properties.
    method : optional
        Forwarded under the ``method`` key.

    Returns
    -------
    FoldJob
        Job object referring to this request.
    """
    url = f"{PATH_PREFIX}/models/{model_id}"

    payload: dict = {}

    raw_sequences = kwargs.get("sequences")
    if raw_sequences:
        # NOTE we are handling the boltz form here too
        payload["sequences"] = [
            item.decode() if isinstance(item, bytes) else item
            for item in raw_sequences
        ]

    # NOTE(review): the truthiness test means an explicit 0 or False is not
    # sent and the backend default applies instead - confirm this is intended.
    passthrough = (
        "msa_id",
        "num_recycles",
        "num_models",
        "num_relax",
        "use_potentials",
        # boltz
        "diffusion_samples",
        "recycling_steps",
        "sampling_steps",
        "step_scale",
        "constraints",
        "templates",
        "properties",
        "method",
    )
    for option in passthrough:
        value = kwargs.get(option)
        if value:
            payload[option] = value

    response = session.post(url, json=payload)
    return FoldJob.model_validate(response.json())