openprotein-python 0.8.2__1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openprotein/__init__.py +164 -0
- openprotein/_version.py +48 -0
- openprotein/align/__init__.py +8 -0
- openprotein/align/align.py +395 -0
- openprotein/align/api.py +428 -0
- openprotein/align/future.py +55 -0
- openprotein/align/msa.py +129 -0
- openprotein/align/schemas.py +165 -0
- openprotein/base.py +181 -0
- openprotein/chains.py +88 -0
- openprotein/common/__init__.py +5 -0
- openprotein/common/features.py +7 -0
- openprotein/common/model_metadata.py +33 -0
- openprotein/common/reduction.py +8 -0
- openprotein/config.py +9 -0
- openprotein/csv.py +31 -0
- openprotein/data/__init__.py +9 -0
- openprotein/data/api.py +218 -0
- openprotein/data/assaydataset.py +178 -0
- openprotein/data/data.py +93 -0
- openprotein/data/schemas.py +27 -0
- openprotein/design/__init__.py +16 -0
- openprotein/design/api.py +259 -0
- openprotein/design/design.py +125 -0
- openprotein/design/future.py +146 -0
- openprotein/design/schemas.py +607 -0
- openprotein/embeddings/__init__.py +27 -0
- openprotein/embeddings/api.py +619 -0
- openprotein/embeddings/embeddings.py +151 -0
- openprotein/embeddings/esm.py +33 -0
- openprotein/embeddings/future.py +146 -0
- openprotein/embeddings/models.py +421 -0
- openprotein/embeddings/openprotein.py +21 -0
- openprotein/embeddings/poet.py +446 -0
- openprotein/embeddings/poet2.py +505 -0
- openprotein/embeddings/schemas.py +78 -0
- openprotein/errors.py +76 -0
- openprotein/fasta.py +92 -0
- openprotein/fold/__init__.py +21 -0
- openprotein/fold/alphafold2.py +131 -0
- openprotein/fold/api.py +287 -0
- openprotein/fold/boltz.py +691 -0
- openprotein/fold/esmfold.py +54 -0
- openprotein/fold/fold.py +107 -0
- openprotein/fold/future.py +509 -0
- openprotein/fold/models.py +139 -0
- openprotein/fold/schemas.py +39 -0
- openprotein/jobs/__init__.py +9 -0
- openprotein/jobs/api.py +71 -0
- openprotein/jobs/futures.py +746 -0
- openprotein/jobs/jobs.py +69 -0
- openprotein/jobs/schemas.py +135 -0
- openprotein/models/__init__.py +4 -0
- openprotein/models/base.py +63 -0
- openprotein/models/foundation/rfdiffusion.py +283 -0
- openprotein/models/models.py +33 -0
- openprotein/predictor/__init__.py +25 -0
- openprotein/predictor/api.py +384 -0
- openprotein/predictor/models.py +374 -0
- openprotein/predictor/prediction.py +79 -0
- openprotein/predictor/predictor.py +242 -0
- openprotein/predictor/schemas.py +113 -0
- openprotein/predictor/validate.py +40 -0
- openprotein/prompt/__init__.py +9 -0
- openprotein/prompt/api.py +505 -0
- openprotein/prompt/models.py +142 -0
- openprotein/prompt/prompt.py +130 -0
- openprotein/prompt/schemas.py +49 -0
- openprotein/protein.py +587 -0
- openprotein/svd/__init__.py +9 -0
- openprotein/svd/api.py +206 -0
- openprotein/svd/models.py +288 -0
- openprotein/svd/schemas.py +31 -0
- openprotein/svd/svd.py +134 -0
- openprotein/umap/__init__.py +9 -0
- openprotein/umap/api.py +259 -0
- openprotein/umap/models.py +211 -0
- openprotein/umap/schemas.py +35 -0
- openprotein/umap/umap.py +175 -0
- openprotein/utils/uuid.py +29 -0
- openprotein_python-0.8.2.dist-info/METADATA +176 -0
- openprotein_python-0.8.2.dist-info/RECORD +84 -0
- openprotein_python-0.8.2.dist-info/WHEEL +4 -0
- openprotein_python-0.8.2.dist-info/licenses/LICENSE.txt +30 -0
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
"""Schemas for OpenProtein align system."""
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import Literal
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, Field
|
|
7
|
+
|
|
8
|
+
from openprotein.jobs import Job, JobType
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AlignType(str, Enum):
    """
    Kinds of alignment known to the align system.

    Each member is also a ``str``; its value is the literal string listed
    below.

    Attributes
    ----------
    INPUT : str
        Value ``"RAW"``.
    MSA : str
        Value ``"GENERATED"``.
    PROMPT : str
        Value ``"PROMPT"``.
    """

    # Note: member names and values intentionally differ for INPUT and MSA.
    INPUT = "RAW"
    MSA = "GENERATED"
    PROMPT = "PROMPT"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class MSASamplingMethod(str, Enum):
    """
    Strategies available for sampling sequences from an MSA.

    Every member is a ``str`` whose value equals its own name.

    Attributes
    ----------
    RANDOM : str
        Random sampling.
    NEIGHBORS : str
        Neighbor-based sampling.
    NEIGHBORS_NO_LIMIT : str
        Neighbor-based sampling, "no limit" variant.
    NEIGHBORS_NONGAP_NORM_NO_LIMIT : str
        Neighbor-based sampling, "nongap norm, no limit" variant.
    TOP : str
        Top-scoring sampling.
    """

    RANDOM = "RANDOM"
    NEIGHBORS = "NEIGHBORS"
    NEIGHBORS_NO_LIMIT = "NEIGHBORS_NO_LIMIT"
    NEIGHBORS_NONGAP_NORM_NO_LIMIT = "NEIGHBORS_NONGAP_NORM_NO_LIMIT"
    TOP = "TOP"
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class PromptPostParams(BaseModel):
    """
    Request parameters for sampling a prompt from an existing MSA.

    Attributes
    ----------
    msa_id : str
        Identifier of the MSA to sample from (required).
    num_sequences : int or None, optional
        Sequence budget; ``None`` by default, otherwise ``0 <= n < 100``.
    num_residues : int or None, optional
        Residue budget; ``None`` by default, otherwise ``0 <= n < 24577``.
    method : MSASamplingMethod, optional
        Sampling method; defaults to ``NEIGHBORS_NONGAP_NORM_NO_LIMIT``.
    homology_level : float, optional
        Value in ``[0, 1]``; defaults to 0.8.
    max_similarity : float, optional
        Value in ``[0, 1]``; defaults to 1.0.
    min_similarity : float, optional
        Value in ``[0, 1]``; defaults to 0.0.
    always_include_seed_sequence : bool, optional
        Defaults to False.
    num_ensemble_prompts : int, optional
        Defaults to 1.
    random_seed : int or None, optional
        Defaults to None.
    """

    msa_id: str

    # Sampling budgets: lower bound inclusive, upper bound exclusive.
    num_sequences: int | None = Field(default=None, ge=0, lt=100)
    num_residues: int | None = Field(default=None, ge=0, lt=24577)

    method: MSASamplingMethod = MSASamplingMethod.NEIGHBORS_NONGAP_NORM_NO_LIMIT

    # Thresholds: both bounds inclusive.
    homology_level: float = Field(default=0.8, ge=0, le=1)
    max_similarity: float = Field(default=1.0, ge=0, le=1)
    min_similarity: float = Field(default=0.0, ge=0, le=1)

    always_include_seed_sequence: bool = False
    num_ensemble_prompts: int = 1
    random_seed: int | None = None
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class MSAJob(Job):
    """
    Base class for MSA-related jobs.

    Attributes
    ----------
    job_type : Literal[JobType.align_align]
        The type of job (must be JobType.align_align).
    """

    job_type: Literal[JobType.align_align]

    @property
    def msa_id(self):
        """
        Returns the MSA identifier for this job.

        Returns
        -------
        str
            The MSA identifier, which is the job's own identifier.
        """
        # The property used to return ``self.msa_id`` -- i.e. itself -- so any
        # access recursed until RecursionError.
        # NOTE(review): assumes the base ``Job`` model exposes ``job_id`` and
        # that the MSA is addressed by it -- confirm against openprotein.jobs.
        return self.job_id
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class MafftJob(MSAJob, Job):
    """
    MSA job specialised for MAFFT alignments.

    Attributes
    ----------
    job_type : Literal[JobType.mafft]
        Narrows the inherited field so that only ``JobType.mafft`` validates.
    """

    job_type: Literal[JobType.mafft]
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class ClustalOJob(MSAJob, Job):
    """
    MSA job specialised for Clustal Omega alignments.

    Attributes
    ----------
    job_type : Literal[JobType.clustalo]
        Narrows the inherited field so that only ``JobType.clustalo`` validates.
    """

    job_type: Literal[JobType.clustalo]
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class AbNumberJob(MSAJob, Job):
    """
    MSA job specialised for AbNumber alignments.

    Attributes
    ----------
    job_type : Literal[JobType.abnumber]
        Narrows the inherited field so that only ``JobType.abnumber`` validates.
    """

    job_type: Literal[JobType.abnumber]
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class AbNumberScheme(str, Enum):
    """
    Antibody numbering schemes.

    Each member is a ``str`` whose value is the lowercase form of its name.
    """

    IMGT = "imgt"
    CHOTHIA = "chothia"
    KABAT = "kabat"
    AHO = "aho"
|
openprotein/base.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import warnings
|
|
4
|
+
from collections.abc import Container, Mapping
|
|
5
|
+
from typing import Union
|
|
6
|
+
from urllib.parse import urljoin
|
|
7
|
+
|
|
8
|
+
import requests
|
|
9
|
+
import requests.auth
|
|
10
|
+
from requests.adapters import HTTPAdapter
|
|
11
|
+
from requests.packages.urllib3.util.retry import Retry # type: ignore
|
|
12
|
+
|
|
13
|
+
import openprotein.config as config
|
|
14
|
+
from openprotein.errors import APIError, AuthError, HTTPError
|
|
15
|
+
|
|
16
|
+
# Credentials and API base URL, read from the environment once at import time.
# USERNAME and PASSWORD are None when their variables are unset; BACKEND falls
# back to the public API URL.
USERNAME = os.environ.get("OPENPROTEIN_USERNAME")
PASSWORD = os.environ.get("OPENPROTEIN_PASSWORD")
BACKEND = os.environ.get(
    "OPENPROTEIN_API_BACKEND",
    "https://api.openprotein.ai/api/",
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class BearerAuth(requests.auth.AuthBase):
    """
    Requests auth hook that attaches a bearer token to outgoing requests.

    See https://stackoverflow.com/a/58055668
    """

    def __init__(self, token):
        # Stored as given; it is concatenated into the header on every call.
        self.token = token

    def __call__(self, r):
        """Set the ``Authorization`` header on request ``r`` and return ``r``."""
        header_value = "Bearer " + self.token
        r.headers["Authorization"] = header_value
        return r
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class APISession(requests.Session):
    """
    A class to handle API sessions. This class provides a connection session to the OpenProtein API.

    Relative URLs are resolved against ``backend``, a bearer token is obtained
    at construction time via :meth:`login`, and unsuccessful responses are
    raised as exceptions by :meth:`request`.

    Parameters
    ----------
    username : str or None
        The username of the user. Defaults to the module-level ``USERNAME``.
    password : str or None
        The password of the user. Defaults to the module-level ``PASSWORD``.
    backend : str
        Base URL that request paths are joined onto. Defaults to the
        module-level ``BACKEND``.
    timeout : int
        Default timeout (seconds) used by :meth:`post` when the caller does
        not pass one.

    Raises
    ------
    AuthError
        If ``username`` or ``password`` is missing or empty, or if the login
        request does not yield an access token.

    Examples
    --------
    >>> session = APISession("username", "password")
    """

    def __init__(
        self,
        username: str | None = USERNAME,
        password: str | None = PASSWORD,
        backend: str = BACKEND,
        timeout: int = 180,
    ):
        if not username or not password:
            raise AuthError(
                "Expected username and password. Or use environment variables `OPENPROTEIN_USERNAME` and `OPENPROTEIN_PASSWORD`"
            )
        super().__init__()
        self.backend = backend
        self.verify = True
        self.timeout = timeout

        # Custom retry strategies
        # auto retry for pesky connection reset errors and others
        # 503 will catch if BE is refreshing
        # NOTE(review): 101 and 104 look like errno values rather than HTTP
        # status codes -- confirm they are intended in status_forcelist.
        retry = Retry(
            total=4,
            backoff_factor=3,  # 0,1,4,13s
            status_forcelist=[500, 502, 503, 504, 101, 104],
        )
        adapter = HTTPAdapter(max_retries=retry)
        # The retrying adapter is only mounted for https:// URLs.
        self.mount("https://", adapter)
        self.login(username, password)

    def post(self, url, data=None, json=None, **kwargs):
        r"""Sends a POST request. Returns :class:`Response` object.

        Uses the session's default ``timeout`` unless one is passed explicitly.

        :param url: URL for the new :class:`Request` object.
        :param data: (optional) Dictionary, list of tuples, bytes, or file-like
            object to send in the body of the :class:`Request`.
        :param json: (optional) json to send in the body of the :class:`Request`.
        :param \*\*kwargs: Optional arguments that ``request`` takes.
        :rtype: requests.Response
        """
        # An explicitly passed timeout (even ``None``) wins over the default.
        timeout = kwargs.pop("timeout", self.timeout)
        return self.request(
            "POST", url, data=data, json=json, timeout=timeout, **kwargs
        )

    def login(self, username: str, password: str):
        """
        Authenticate connection to OpenProtein with your credentials.

        Parameters
        -----------
        username: str
            username
        password: str
            password

        Raises
        ------
        AuthError
            If authentication fails.
        """
        # unset the auth first so the login request is sent without a stale token
        self.auth = None
        self.auth = self._get_auth_token(username, password)

    def _get_auth_token(self, username: str, password: str):
        """
        Exchange credentials for a :class:`BearerAuth` carrying the access token.

        Raises
        ------
        AuthError
            If the login request raises ``HTTPError`` or the response carries
            no ``access_token``.
        """
        endpoint = "v1/login/access-token"
        url = urljoin(self.backend, endpoint)
        try:
            response = self.post(
                url, data={"username": username, "password": password}, timeout=3
            )
        except HTTPError as e:
            # if an error occurred during auth, we raise an AuthError with reference to the HTTPError
            raise AuthError(
                "Authentication failed. Please check your credentials and connection."
            ) from e

        result = response.json()
        token = result.get("access_token")
        if token is None:
            raise AuthError("Unable to authenticate with given credentials.")
        return BearerAuth(token)

    def request(self, method: str, url: str, *args, **kwargs):
        """
        Send a request to ``url`` resolved against the backend base URL.

        Returns
        -------
        requests.Response
            The response, when it is successful.

        Raises
        ------
        CloudFrontError
            On a 502/503 response served by CloudFront.
        TimeoutError
            On a 504 response served by CloudFront (the module-level class
            defined in this file, not the builtin).
        HTTPError
            On any other unsuccessful response.
        """
        # Warn about oversized JSON payloads *before* sending, so the warning
        # is visible even when the request is slow or ends up raising.
        payload = kwargs.get("json")
        if payload and total_size(payload) > 1e6:
            warnings.warn(
                "The requested payload is >1MB. There might be some delays or issues in processing. If the request fails, please try again with smaller sizes."
            )

        full_url = urljoin(self.backend, url)
        response = super().request(method, full_url, *args, **kwargs)

        # intercept CloudFront errors
        if "cloudfront" in response.headers.get("Server", "").lower():
            if response.status_code in (502, 503):
                raise CloudFrontError(
                    f"We're experiencing a temporary backend issue via CloudFront. Please try again later. Error {response.status_code}."
                )
            elif response.status_code == 504:
                raise TimeoutError(
                    "Your request took too long to process likely due to it's size. Please try breaking it up into smaller requests if possible."
                )
        elif not response.ok:
            # raise custom exception that prints better error message than requests.HTTPError
            raise HTTPError(response)
        return response
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def total_size(o, seen=None):
    """
    Sum ``sys.getsizeof`` over ``o`` and everything it contains.

    Only dicts (keys and values) and lists, tuples, sets and frozensets
    (elements) are descended into. Each distinct object, identified by
    ``id``, is counted once; ``seen`` holds the ids already counted and is
    updated in place.
    """
    visited = set() if seen is None else seen

    marker = id(o)
    if marker in visited:
        # Already accounted for earlier in the traversal.
        return 0
    visited.add(marker)

    total = sys.getsizeof(o)
    if isinstance(o, dict):
        for key, value in o.items():
            total += total_size(key, visited) + total_size(value, visited)
    elif isinstance(o, (list, tuple, set, frozenset)):
        for item in o:
            total += total_size(item, visited)
    return total
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class RestEndpoint:
    """Empty marker class; it defines no attributes or behaviour of its own."""
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class TimeoutError(requests.exceptions.HTTPError):
    """
    HTTP error raised when a request takes too long to process.

    Note that within this module the name shadows the builtin ``TimeoutError``.
    """
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class CloudFrontError(requests.exceptions.HTTPError):
    """HTTP error raised for a temporary backend failure reported via CloudFront."""
|
openprotein/chains.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""Additional chains that can be used with OpenProtein."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass
|
|
7
|
+
class DNA:
|
|
8
|
+
"""
|
|
9
|
+
Represents a DNA sequence.
|
|
10
|
+
|
|
11
|
+
Attributes:
|
|
12
|
+
sequence (str): The nucleotide sequence of the DNA.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
sequence: str
|
|
16
|
+
chain_id: str | list[str] | None = None
|
|
17
|
+
cyclic: bool = False
|
|
18
|
+
|
|
19
|
+
def __init__(
|
|
20
|
+
self,
|
|
21
|
+
sequence: str,
|
|
22
|
+
chain_id: str | list[str] | None = None,
|
|
23
|
+
cyclic: bool = False,
|
|
24
|
+
):
|
|
25
|
+
# validate the sequence matches DNA
|
|
26
|
+
if not all(nt in set("ACGT") for nt in sequence.upper()):
|
|
27
|
+
raise ValueError("Sequence contains invalid DNA nucleotides.")
|
|
28
|
+
self.sequence = sequence
|
|
29
|
+
self.chain_id = chain_id
|
|
30
|
+
self.cyclic = cyclic
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
|
|
34
|
+
class RNA:
|
|
35
|
+
"""
|
|
36
|
+
Represents an RNA sequence.
|
|
37
|
+
|
|
38
|
+
Attributes:
|
|
39
|
+
sequence (str): The nucleotide sequence of the RNA.
|
|
40
|
+
"""
|
|
41
|
+
|
|
42
|
+
sequence: str
|
|
43
|
+
chain_id: str | list[str] | None = None
|
|
44
|
+
cyclic: bool = False
|
|
45
|
+
|
|
46
|
+
def __init__(
|
|
47
|
+
self,
|
|
48
|
+
sequence: str,
|
|
49
|
+
chain_id: str | list[str] | None = None,
|
|
50
|
+
cyclic: bool = False,
|
|
51
|
+
):
|
|
52
|
+
# validate the sequence matches RNA
|
|
53
|
+
if not all(nt in set("ACGU") for nt in sequence.upper()):
|
|
54
|
+
raise ValueError("Sequence contains invalid RNA nucleotides.")
|
|
55
|
+
self.sequence = sequence
|
|
56
|
+
self.chain_id = chain_id
|
|
57
|
+
self.cyclic = cyclic
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@dataclass
|
|
61
|
+
class Ligand:
|
|
62
|
+
"""
|
|
63
|
+
Represents a ligand with optional Chemical Component Dictionary (CCD) identifier and SMILES string.
|
|
64
|
+
|
|
65
|
+
Requires either a CCD identifier or SMILES string.
|
|
66
|
+
|
|
67
|
+
Attributes:
|
|
68
|
+
ccd (str | None): The CCD identifier for the ligand.
|
|
69
|
+
smiles (str | None): The SMILES representation of the ligand.
|
|
70
|
+
"""
|
|
71
|
+
|
|
72
|
+
chain_id: str | list[str] | None = None
|
|
73
|
+
ccd: str | None = None
|
|
74
|
+
smiles: str | None = None
|
|
75
|
+
|
|
76
|
+
def __init__(
|
|
77
|
+
self,
|
|
78
|
+
*,
|
|
79
|
+
chain_id: str | list[str] | None = None,
|
|
80
|
+
ccd: str | None = None,
|
|
81
|
+
smiles: str | None = None,
|
|
82
|
+
):
|
|
83
|
+
self.chain_id = chain_id
|
|
84
|
+
if (ccd is None and smiles is None) or (ccd is not None and smiles is not None):
|
|
85
|
+
raise ValueError("Exactly one of 'ccd' or 'smiles' must be provided.")
|
|
86
|
+
# TODO add validation
|
|
87
|
+
self.ccd = ccd
|
|
88
|
+
self.smiles = smiles
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""Model metadata for OpenProtein models."""
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Field
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ModelDescription(BaseModel):
    """Descriptive information about a model: citation, DOI and a summary."""

    # Citation details are optional and default to None.
    citation_title: str | None = None
    doi: str | None = None
    # Summary text used when none is supplied.
    summary: str = "Protein language model for embeddings"
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TokenInfo(BaseModel):
    """A single token entry of a model. All four fields are required."""

    id: int
    token: str
    primary: bool
    description: str
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ModelMetadata(BaseModel):
    """Metadata record describing one available model."""

    # Required; populated from the ``model_id`` key via the field alias.
    id: str = Field(..., alias="model_id")
    description: ModelDescription
    # Optional; None when no maximum sequence length is given.
    max_sequence_length: int | None = None
    dimension: int
    output_types: list[str]
    input_tokens: list[str]
    # Optional; None when no output tokens are given.
    output_tokens: list[str] | None = None
    # Nested list: a list of lists of TokenInfo entries.
    token_descriptions: list[list[TokenInfo]]
|
openprotein/config.py
ADDED
openprotein/csv.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import csv
|
|
2
|
+
from typing import Iterator
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def ensure_str_lines(
|
|
6
|
+
lines: Iterator[str] | Iterator[bytes], encoding="utf-8"
|
|
7
|
+
) -> Iterator[str]:
|
|
8
|
+
for line in lines:
|
|
9
|
+
if isinstance(line, bytes):
|
|
10
|
+
yield line.decode(encoding)
|
|
11
|
+
else:
|
|
12
|
+
yield line
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def parse_stream(lines: Iterator[str] | Iterator[bytes]) -> Iterator[list[str]]:
    """
    Parse an iterator of CSV lines and yield one row at a time.

    Parameters
    ----------
    lines : Iterator[str] or Iterator[bytes]
        Lines of CSV text. ``bytes`` lines are decoded by
        ``ensure_str_lines`` with its default encoding before parsing.

    Yields
    ------
    list of str
        The fields of each parsed CSV row.
    """
    # This is a generator: nothing is read from ``lines`` until it is iterated.
    yield from csv.reader(ensure_str_lines(lines))
|