openprotein-python 0.8.1__tar.gz → 0.8.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87) hide show
  1. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/PKG-INFO +1 -1
  2. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/__init__.py +11 -0
  3. openprotein_python-0.8.2/openprotein/_version.py +48 -0
  4. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/align/msa.py +0 -27
  5. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/base.py +1 -1
  6. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/embeddings/embeddings.py +11 -2
  7. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/embeddings/esm.py +5 -5
  8. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/embeddings/future.py +9 -1
  9. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/embeddings/models.py +50 -50
  10. openprotein_python-0.8.2/openprotein/embeddings/openprotein.py +21 -0
  11. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/embeddings/poet.py +6 -4
  12. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/embeddings/poet2.py +5 -3
  13. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/fold/boltz.py +4 -5
  14. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/fold/fold.py +13 -8
  15. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/fold/future.py +2 -2
  16. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/jobs/futures.py +342 -74
  17. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/jobs/jobs.py +8 -4
  18. openprotein_python-0.8.2/openprotein/models/__init__.py +4 -0
  19. openprotein_python-0.8.2/openprotein/models/base.py +63 -0
  20. openprotein_python-0.8.2/openprotein/models/foundation/rfdiffusion.py +283 -0
  21. openprotein_python-0.8.2/openprotein/models/models.py +33 -0
  22. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/predictor/api.py +9 -9
  23. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/predictor/prediction.py +1 -1
  24. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/predictor/predictor.py +29 -17
  25. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/svd/__init__.py +1 -1
  26. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/svd/api.py +20 -20
  27. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/svd/models.py +50 -28
  28. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/svd/svd.py +41 -22
  29. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/umap/__init__.py +1 -1
  30. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/umap/api.py +11 -12
  31. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/umap/models.py +69 -14
  32. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/umap/umap.py +36 -12
  33. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/pyproject.toml +6 -2
  34. openprotein_python-0.8.1/openprotein/_version.py +0 -9
  35. openprotein_python-0.8.1/openprotein/embeddings/openprotein.py +0 -21
  36. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/.gitignore +0 -0
  37. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/LICENSE.txt +0 -0
  38. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/README.md +0 -0
  39. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/align/__init__.py +0 -0
  40. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/align/align.py +0 -0
  41. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/align/api.py +0 -0
  42. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/align/future.py +0 -0
  43. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/align/schemas.py +0 -0
  44. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/chains.py +0 -0
  45. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/common/__init__.py +0 -0
  46. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/common/features.py +0 -0
  47. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/common/model_metadata.py +0 -0
  48. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/common/reduction.py +0 -0
  49. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/config.py +0 -0
  50. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/csv.py +0 -0
  51. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/data/__init__.py +0 -0
  52. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/data/api.py +0 -0
  53. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/data/assaydataset.py +0 -0
  54. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/data/data.py +0 -0
  55. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/data/schemas.py +0 -0
  56. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/design/__init__.py +0 -0
  57. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/design/api.py +0 -0
  58. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/design/design.py +0 -0
  59. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/design/future.py +0 -0
  60. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/design/schemas.py +0 -0
  61. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/embeddings/__init__.py +0 -0
  62. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/embeddings/api.py +0 -0
  63. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/embeddings/schemas.py +0 -0
  64. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/errors.py +0 -0
  65. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/fasta.py +0 -0
  66. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/fold/__init__.py +0 -0
  67. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/fold/alphafold2.py +0 -0
  68. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/fold/api.py +0 -0
  69. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/fold/esmfold.py +0 -0
  70. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/fold/models.py +0 -0
  71. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/fold/schemas.py +0 -0
  72. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/jobs/__init__.py +0 -0
  73. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/jobs/api.py +0 -0
  74. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/jobs/schemas.py +0 -0
  75. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/predictor/__init__.py +0 -0
  76. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/predictor/models.py +0 -0
  77. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/predictor/schemas.py +0 -0
  78. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/predictor/validate.py +0 -0
  79. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/prompt/__init__.py +0 -0
  80. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/prompt/api.py +0 -0
  81. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/prompt/models.py +0 -0
  82. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/prompt/prompt.py +0 -0
  83. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/prompt/schemas.py +0 -0
  84. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/protein.py +0 -0
  85. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/svd/schemas.py +0 -0
  86. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/umap/schemas.py +0 -0
  87. {openprotein_python-0.8.1 → openprotein_python-0.8.2}/openprotein/utils/uuid.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: openprotein-python
3
- Version: 0.8.1
3
+ Version: 0.8.2
4
4
  Summary: OpenProtein Python interface.
5
5
  Author-email: Mark Gee <markgee@ne47.bio>, "Timothy Truong Jr." <ttruong@ne47.bio>, Tristan Bepler <tbepler@ne47.bio>
6
6
  License-Expression: MIT
@@ -17,6 +17,7 @@ from openprotein.align import AlignAPI
17
17
  from openprotein.prompt import PromptAPI
18
18
  from openprotein.embeddings import EmbeddingsAPI
19
19
  from openprotein.fold import FoldAPI
20
+ from openprotein.models import ModelsAPI
20
21
  from openprotein.svd import SVDAPI
21
22
  from openprotein.umap import UMAPAPI
22
23
  from openprotein.predictor import PredictorAPI
@@ -40,6 +41,7 @@ class OpenProtein(APISession):
40
41
  _fold = None
41
42
  _predictor = None
42
43
  _design = None
44
+ _models = None
43
45
 
44
46
  def wait(self, future: Future, *args, **kwargs):
45
47
  return future.wait(*args, **kwargs)
@@ -149,5 +151,14 @@ class OpenProtein(APISession):
149
151
  self._fold = FoldAPI(self)
150
152
  return self._fold
151
153
 
154
+ @property
155
+ def models(self) -> "ModelsAPI":
156
+ """
157
+ The models submodule provides a unified entry point to all protein models.
158
+ """
159
+ if self._models is None:
160
+ self._models = ModelsAPI(self)
161
+ return self._models
162
+
152
163
 
153
164
  connect = OpenProtein
@@ -0,0 +1,48 @@
1
+ """Compute the version number and store it in the `__version__` variable.
2
+
3
+ Based on <https://github.com/maresb/hatch-vcs-footgun-example>.
4
+ """
5
+
6
+
7
+ def _get_hatch_version():
8
+ """Compute the most up-to-date version number in a development environment.
9
+
10
+ Returns `None` if Hatchling is not installed, e.g. in a production environment.
11
+
12
+ For more details, see <https://github.com/maresb/hatch-vcs-footgun-example/>.
13
+ """
14
+ import os
15
+
16
+ try:
17
+ from hatchling.metadata.core import ProjectMetadata
18
+ from hatchling.plugin.manager import PluginManager
19
+ from hatchling.utils.fs import locate_file
20
+ except ImportError:
21
+ # Hatchling is not installed, so probably we are not in
22
+ # a development environment.
23
+ return None
24
+
25
+ pyproject_toml = locate_file(__file__, "pyproject.toml")
26
+ if pyproject_toml is None:
27
+ raise RuntimeError("pyproject.toml not found although hatchling is installed")
28
+ root = os.path.dirname(pyproject_toml)
29
+ metadata = ProjectMetadata(root=root, plugin_manager=PluginManager())
30
+ # Version can be either statically set in pyproject.toml or computed dynamically:
31
+ return metadata.core.version or metadata.hatch.version.cached
32
+
33
+
34
+ def _get_importlib_metadata_version():
35
+ """Compute the version number using importlib.metadata.
36
+
37
+ This is the official Pythonic way to get the version number of an installed
38
+ package. However, it is only updated when a package is installed. Thus, if a
39
+ package is installed in editable mode, and a different version is checked out,
40
+ then the version number will not be updated.
41
+ """
42
+ from importlib.metadata import version
43
+
44
+ __version__ = version(__package__) # type: ignore
45
+ return __version__
46
+
47
+
48
+ __version__ = _get_hatch_version() or _get_importlib_metadata_version()
@@ -22,33 +22,6 @@ from .schemas import (
22
22
  class MSAFuture(AlignFuture, Future):
23
23
  """
24
24
  Represents a future for MSA (Multiple Sequence Alignment) results.
25
-
26
- Parameters
27
- ----------
28
- session : APISession
29
- An instance of APISession for API interactions.
30
- job : MSAJob
31
- The MSA job.
32
- page_size : int, optional
33
- The number of results to fetch in a single page. Defaults to config.POET_PAGE_SIZE.
34
-
35
- Attributes
36
- ----------
37
- session : APISession
38
- An instance of APISession for API interactions.
39
- job : MSAJob | MafftJob | ClustalOJob | AbNumberJob
40
- The MSA job.
41
- page_size : int
42
- The number of results to fetch in a single page.
43
- msa_id : str
44
- The job ID for the MSA.
45
-
46
- Methods
47
- -------
48
- get(verbose=False)
49
- Retrieve the MSA of the job as an iterator over CSV rows.
50
- sample_prompt(...)
51
- Create a protein sequence prompt from the linked MSA for PoET Jobs.
52
25
  """
53
26
 
54
27
  job: MSAJob | MafftJob | ClustalOJob | AbNumberJob
@@ -27,7 +27,7 @@ class BearerAuth(requests.auth.AuthBase):
27
27
  self.token = token
28
28
 
29
29
  def __call__(self, r):
30
- r.headers["authorization"] = "Bearer " + self.token
30
+ r.headers["Authorization"] = "Bearer " + self.token
31
31
  return r
32
32
 
33
33
 
@@ -43,16 +43,24 @@ class EmbeddingsAPI:
43
43
 
44
44
  # added for static typing, eg pylance, for autocomplete
45
45
  # at init these are all overwritten.
46
+
47
+ #: PoET-2 model
48
+ poet2: PoET2Model
49
+ #: PoET model
50
+ poet: PoETModel
51
+ #: Prot-seq model
46
52
  prot_seq: OpenProteinModel
53
+ #: Rotaprot model trained on UniRef50
47
54
  rotaprot_large_uniref50w: OpenProteinModel
55
+ #: Rotaprot model trained on UniRef90
48
56
  rotaprot_large_uniref90_ft: OpenProteinModel
49
- poet: PoETModel
50
57
  poet_2: PoET2Model
51
- poet2: PoET2Model
52
58
 
59
+ #: ESM1b model
53
60
  esm1b: ESMModel # alias
54
61
  esm1b_t33_650M_UR50S: ESMModel
55
62
 
63
+ #: ESM1v model
56
64
  esm1v: ESMModel # alias
57
65
  esm1v_t33_650M_UR90S_1: ESMModel
58
66
  esm1v_t33_650M_UR90S_2: ESMModel
@@ -60,6 +68,7 @@ class EmbeddingsAPI:
60
68
  esm1v_t33_650M_UR90S_4: ESMModel
61
69
  esm1v_t33_650M_UR90S_5: ESMModel
62
70
 
71
+ #: ESM2 model
63
72
  esm2: ESMModel # alias
64
73
  esm2_t12_35M_UR50D: ESMModel
65
74
  esm2_t30_150M_UR50D: ESMModel
@@ -1,9 +1,9 @@
1
1
  """Community-based ESM models."""
2
2
 
3
- from .models import EmbeddingModel
3
+ from .models import AttnModel, EmbeddingModel
4
4
 
5
5
 
6
- class ESMModel(EmbeddingModel):
6
+ class ESMModel(AttnModel, EmbeddingModel):
7
7
  """
8
8
  Class providing inference endpoints for Facebook's ESM protein language models.
9
9
 
@@ -13,9 +13,9 @@ class ESMModel(EmbeddingModel):
13
13
 
14
14
  .. code-block:: python
15
15
 
16
- >>> import openprotein
17
- >>> session = openprotein.connect(username="user", password="password")
18
- >>> session.embedding.esm2_t12_35M_UR50D?
16
+ >>> import openprotein
17
+ >>> session = openprotein.connect(username="user", password="password")
18
+ >>> session.embedding.esm2_t12_35M_UR50D?
19
19
  """
20
20
 
21
21
  model_id = [
@@ -63,7 +63,15 @@ class EmbeddingsResultFuture(MappedFuture, Future):
63
63
  def id(self):
64
64
  return self.job.job_id
65
65
 
66
- def keys(self):
66
+ def __keys__(self):
67
+ """
68
+ Get the list of sequences submitted for the embed request.
69
+
70
+ Returns
71
+ -------
72
+ list of bytes
73
+ List of sequences.
74
+ """
67
75
  return self.sequences
68
76
 
69
77
  def get_item(self, sequence: bytes) -> np.ndarray:
@@ -17,9 +17,10 @@ if TYPE_CHECKING:
17
17
 
18
18
 
19
19
  class EmbeddingModel:
20
+ """Base embeddings model used to understand and provide embeddings from sequences."""
20
21
 
21
22
  # overridden by subclasses
22
- # used to get correct emb model
23
+ # used to get correct emb model during factory create
23
24
  model_id: list[str] | str = "protembed"
24
25
 
25
26
  def __init__(
@@ -78,9 +79,9 @@ class EmbeddingModel:
78
79
  The API session to use.
79
80
  model_id : str
80
81
  The model identifier.
81
- default : type[EmbeddingModel] or None, optional
82
+ default : type variable of EmbeddingModel or None, optional
82
83
  Default EmbeddingModel subclass to use if no match is found.
83
- **kwargs : dict, optional
84
+ kwargs :
84
85
  Additional keyword arguments to pass to the model constructor.
85
86
 
86
87
  Returns
@@ -149,8 +150,8 @@ class EmbeddingModel:
149
150
  Sequences to embed.
150
151
  reduction : ReductionType or None, optional
151
152
  Reduction to use (e.g. mean). Defaults to mean embedding.
152
- **kwargs : dict, optional
153
- Additional keyword arguments to pass to the embedding request.
153
+ kwargs:
154
+ Additional keyword arguments to be used from foundational models, e.g. prompt_id for PoET models.
154
155
 
155
156
  Returns
156
157
  -------
@@ -179,8 +180,8 @@ class EmbeddingModel:
179
180
  ----------
180
181
  sequences : list of bytes or list of str
181
182
  Sequences to compute logits for.
182
- **kwargs : dict, optional
183
- Additional keyword arguments to pass to the logits request.
183
+ kwargs :
184
+ Additional keyword arguments to be used from foundational models, e.g. prompt_id for PoET models.
184
185
 
185
186
  Returns
186
187
  -------
@@ -195,32 +196,6 @@ class EmbeddingModel:
195
196
  sequences=sequences,
196
197
  )
197
198
 
198
- def attn(
199
- self, sequences: list[bytes] | list[str], **kwargs
200
- ) -> EmbeddingsResultFuture:
201
- """
202
- Compute attention embeddings for sequences using this model.
203
-
204
- Parameters
205
- ----------
206
- sequences : list of bytes or list of str
207
- Sequences to compute attention embeddings for.
208
- **kwargs : dict, optional
209
- Additional keyword arguments to pass to the attention request.
210
-
211
- Returns
212
- -------
213
- EmbeddingsResultFuture
214
- Future object representing the attention result.
215
- """
216
- return EmbeddingsResultFuture.create(
217
- session=self.session,
218
- job=api.request_attn_post(
219
- session=self.session, model_id=self.id, sequences=sequences, **kwargs
220
- ),
221
- sequences=sequences,
222
- )
223
-
224
199
  def fit_svd(
225
200
  self,
226
201
  sequences: list[bytes] | list[str] | None = None,
@@ -245,8 +220,8 @@ class EmbeddingModel:
245
220
  Number of components in SVD. Determines output shapes. Default is 1024.
246
221
  reduction : ReductionType or None, optional
247
222
  Embeddings reduction to use (e.g. mean).
248
- **kwargs : dict, optional
249
- Additional keyword arguments to pass to the SVD fitting.
223
+ kwargs :
224
+ Additional keyword arguments to be used from foundational models, e.g. prompt_id for PoET models.
250
225
 
251
226
  Returns
252
227
  -------
@@ -261,7 +236,7 @@ class EmbeddingModel:
261
236
  # local import for cyclic dep
262
237
  from openprotein.svd import SVDAPI
263
238
 
264
- svd_api = getattr(self.session, "data", None)
239
+ svd_api = getattr(self.session, "svd", None)
265
240
  assert isinstance(svd_api, SVDAPI)
266
241
 
267
242
  # Ensure either or
@@ -273,10 +248,9 @@ class EmbeddingModel:
273
248
  )
274
249
  model_id = self.id
275
250
  return svd_api.fit_svd(
276
- session=self.session,
277
251
  model_id=model_id,
278
252
  sequences=sequences,
279
- assay_id=assay.id if assay is not None else None,
253
+ assay=assay,
280
254
  n_components=n_components,
281
255
  reduction=reduction,
282
256
  **kwargs,
@@ -306,8 +280,8 @@ class EmbeddingModel:
306
280
  Number of components in UMAP fit. Determines output shapes. Default is 2.
307
281
  reduction : ReductionType or None, optional
308
282
  Embeddings reduction to use (e.g. mean). Defaults to MEAN.
309
- **kwargs : dict, optional
310
- Additional keyword arguments to pass to the UMAP fitting.
283
+ kwargs :
284
+ Additional keyword arguments to be used from foundational models, e.g. prompt_id for PoET models.
311
285
 
312
286
  Returns
313
287
  -------
@@ -322,9 +296,8 @@ class EmbeddingModel:
322
296
  # local import for cyclic dep
323
297
  from openprotein.umap import UMAPAPI
324
298
 
325
- umap_api = UMAPAPI(
326
- session=self.session,
327
- )
299
+ umap_api = getattr(self.session, "umap", None)
300
+ assert isinstance(umap_api, UMAPAPI)
328
301
 
329
302
  # Ensure either or
330
303
  if (assay is None and sequences is None) or (
@@ -335,7 +308,6 @@ class EmbeddingModel:
335
308
  )
336
309
  model_id = self.id
337
310
  return umap_api.fit_umap(
338
- session=self.session,
339
311
  model_id=model_id,
340
312
  feature_type=FeatureType.PLM,
341
313
  sequences=sequences,
@@ -369,8 +341,8 @@ class EmbeddingModel:
369
341
  Optional name for the predictor model.
370
342
  description : str or None, optional
371
343
  Optional description for the predictor model.
372
- **kwargs : dict, optional
373
- Additional keyword arguments to pass to the GP fitting.
344
+ kwargs :
345
+ Additional keyword arguments to be used from foundational models, e.g. prompt_id for PoET models.
374
346
 
375
347
  Returns
376
348
  -------
@@ -391,11 +363,9 @@ class EmbeddingModel:
391
363
  predictor_api = getattr(self.session, "predictor", None)
392
364
  assert isinstance(predictor_api, PredictorAPI)
393
365
 
394
- model_id = self.id
395
366
  # get assay if str
396
367
  assay = data_api.get(assay_id=assay) if isinstance(assay, str) else assay
397
368
  # extract assay_id
398
- assay_id = assay.assay_id if isinstance(assay, AssayMetadata) else assay.id
399
369
  if len(properties) == 0:
400
370
  raise InvalidParameterError("Expected (at-least) 1 property to train")
401
371
  if not set(properties) <= set(assay.measurement_names):
@@ -410,12 +380,42 @@ class EmbeddingModel:
410
380
 
411
381
  # inject into predictor api
412
382
  return predictor_api.fit_gp(
413
- assay_id=assay_id,
383
+ assay=assay,
414
384
  properties=properties,
415
385
  feature_type=FeatureType.PLM,
416
- model_id=model_id,
386
+ model=self,
417
387
  reduction=reduction,
418
388
  name=name,
419
389
  description=description,
420
390
  **kwargs,
421
391
  )
392
+
393
+
394
+ class AttnModel(EmbeddingModel):
395
+ """Embeddings model that provides attention computation."""
396
+
397
+ def attn(
398
+ self, sequences: list[bytes] | list[str], **kwargs
399
+ ) -> EmbeddingsResultFuture:
400
+ """
401
+ Compute attention embeddings for sequences using this model.
402
+
403
+ Parameters
404
+ ----------
405
+ sequences : list of bytes or list of str
406
+ Sequences to compute attention embeddings for.
407
+ kwargs :
408
+ Additional keyword arguments to be used from foundational models.
409
+
410
+ Returns
411
+ -------
412
+ EmbeddingsResultFuture
413
+ Future object representing the attention result.
414
+ """
415
+ return EmbeddingsResultFuture.create(
416
+ session=self.session,
417
+ job=api.request_attn_post(
418
+ session=self.session, model_id=self.id, sequences=sequences, **kwargs
419
+ ),
420
+ sequences=sequences,
421
+ )
@@ -0,0 +1,21 @@
1
+ """OpenProtein-proprietary models."""
2
+
3
+ from .models import AttnModel, EmbeddingModel
4
+
5
+
6
+ class OpenProteinModel(AttnModel, EmbeddingModel):
7
+ """
8
+ Proprietary protein embedding models served by OpenProtein.
9
+
10
+ Examples
11
+ --------
12
+ View specific model details (inc supported tokens) with the `?` operator.
13
+
14
+ .. code-block:: python
15
+
16
+ >>> import openprotein
17
+ >>> session = openprotein.connect(username="user", password="password")
18
+ >>> session.embedding.prot_seq?
19
+ """
20
+
21
+ model_id = ["prot-seq", "rotaprot-large-uniref50w", "rotaprot_large_uniref90_ft"]
@@ -33,9 +33,9 @@ class PoETModel(EmbeddingModel):
33
33
  --------
34
34
  View specific model details (including supported tokens) with the `?` operator.
35
35
 
36
- >>> import openprotein
37
- >>> session = openprotein.connect(username="user", password="password")
38
- >>> session.embedding.poet.<embeddings_method>
36
+ >>> import openprotein
37
+ >>> session = openprotein.connect(username="user", password="password")
38
+ >>> session.embedding.poet.<embeddings_method>
39
39
  """
40
40
 
41
41
  model_id = "poet"
@@ -113,7 +113,7 @@ class PoETModel(EmbeddingModel):
113
113
  prompt_id = None
114
114
  else:
115
115
  prompt_id = prompt if isinstance(prompt, str) else prompt.id
116
- return super().logits(sequences=sequences, prompt_id=prompt_id)
116
+ return super().logits(sequences=sequences, prompt_id=prompt_id, **kwargs)
117
117
 
118
118
  def attn(self):
119
119
  """
@@ -123,6 +123,8 @@ class PoETModel(EmbeddingModel):
123
123
  ------
124
124
  ValueError
125
125
  Always raised, as attention is not supported for PoET.
126
+
127
+ :meta private:
126
128
  """
127
129
  raise ValueError("Attn not yet supported for poet")
128
130
 
@@ -38,9 +38,11 @@ class PoET2Model(PoETModel, EmbeddingModel):
38
38
 
39
39
  Examples
40
40
  --------
41
- >>> import openprotein
42
- >>> session = openprotein.connect(username="user", password="password")
43
- >>> session.embedding.poet2.<embeddings_method>
41
+ .. code-block:: python
42
+
43
+ >>> import openprotein
44
+ >>> session = openprotein.connect(username="user", password="password")
45
+ >>> session.embedding.poet2?
44
46
  """
45
47
 
46
48
  model_id = "poet-2"
@@ -102,11 +102,10 @@ class BoltzModel(FoldModel):
102
102
  step_scale: float = 1.638,
103
103
  use_potentials: bool = False,
104
104
  constraints: list[dict] | None = None,
105
- force_single_sequence_mode: bool = False,
106
105
  **kwargs,
107
106
  ) -> FoldComplexResultFuture:
108
107
  """
109
- Post sequences to boltz model.
108
+ Request structure prediction with boltz model.
110
109
 
111
110
  Parameters
112
111
  ----------
@@ -287,7 +286,7 @@ class Boltz2Model(BoltzModel, FoldModel):
287
286
  method: str | None = None,
288
287
  ) -> FoldComplexResultFuture:
289
288
  """
290
- Post sequences to Boltz-2 model.
289
+ Request structure prediction with Boltz-2 model.
291
290
 
292
291
  Parameters
293
292
  ----------
@@ -392,7 +391,7 @@ class Boltz1xModel(BoltzModel, FoldModel):
392
391
  constraints: list[dict] | None = None,
393
392
  ) -> FoldComplexResultFuture:
394
393
  """
395
- Post sequences to Boltz-1x model. Uses potentials with Boltz-1 model.
394
+ Request structure prediction with Boltz-1x model. Uses potentials with Boltz-1 model.
396
395
 
397
396
  Parameters
398
397
  ----------
@@ -456,7 +455,7 @@ class Boltz1Model(BoltzModel, FoldModel):
456
455
  constraints: list[dict] | None = None,
457
456
  ) -> FoldComplexResultFuture:
458
457
  """
459
- Post sequences to Boltz-1 model.
458
+ Request structure prediction with Boltz-1 model.
460
459
 
461
460
  Parameters
462
461
  ----------
@@ -17,15 +17,20 @@ class FoldAPI:
17
17
  Fold API provides a high level interface for making protein structure predictions.
18
18
  """
19
19
 
20
- esmfold: ESMFoldModel
21
- alphafold2: AlphaFold2Model
22
- af2: AlphaFold2Model
23
- boltz_1: Boltz1Model
24
- boltz1: Boltz1Model
25
- boltz_1x: Boltz1xModel
26
- boltz1x: Boltz1xModel
27
- boltz_2: Boltz2Model
20
+ #: Boltz-2 model
28
21
  boltz2: Boltz2Model
22
+ boltz_2: Boltz2Model
23
+ #: Boltz-1x model
24
+ boltz1x: Boltz1xModel
25
+ boltz_1x: Boltz1xModel
26
+ #: Boltz-1 model
27
+ boltz1: Boltz1Model
28
+ boltz_1: Boltz1Model
29
+ af2: AlphaFold2Model
30
+ #: AlphaFold-2 model
31
+ alphafold2: AlphaFold2Model
32
+ #: ESMFold model
33
+ esmfold: ESMFoldModel
29
34
 
30
35
  def __init__(self, session: APISession):
31
36
  self.session = session
@@ -129,13 +129,13 @@ class FoldResultFuture(MappedFuture, Future):
129
129
  """
130
130
  return self.job.job_id
131
131
 
132
- def keys(self):
132
+ def __keys__(self):
133
133
  """
134
134
  Get the list of sequences submitted for the fold request.
135
135
 
136
136
  Returns
137
137
  -------
138
- list[bytes]
138
+ list of bytes
139
139
  List of sequences.
140
140
  """
141
141
  return self.sequences