openprotein-python 0.8.2__1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84) hide show
  1. openprotein/__init__.py +164 -0
  2. openprotein/_version.py +48 -0
  3. openprotein/align/__init__.py +8 -0
  4. openprotein/align/align.py +395 -0
  5. openprotein/align/api.py +428 -0
  6. openprotein/align/future.py +55 -0
  7. openprotein/align/msa.py +129 -0
  8. openprotein/align/schemas.py +165 -0
  9. openprotein/base.py +181 -0
  10. openprotein/chains.py +88 -0
  11. openprotein/common/__init__.py +5 -0
  12. openprotein/common/features.py +7 -0
  13. openprotein/common/model_metadata.py +33 -0
  14. openprotein/common/reduction.py +8 -0
  15. openprotein/config.py +9 -0
  16. openprotein/csv.py +31 -0
  17. openprotein/data/__init__.py +9 -0
  18. openprotein/data/api.py +218 -0
  19. openprotein/data/assaydataset.py +178 -0
  20. openprotein/data/data.py +93 -0
  21. openprotein/data/schemas.py +27 -0
  22. openprotein/design/__init__.py +16 -0
  23. openprotein/design/api.py +259 -0
  24. openprotein/design/design.py +125 -0
  25. openprotein/design/future.py +146 -0
  26. openprotein/design/schemas.py +607 -0
  27. openprotein/embeddings/__init__.py +27 -0
  28. openprotein/embeddings/api.py +619 -0
  29. openprotein/embeddings/embeddings.py +151 -0
  30. openprotein/embeddings/esm.py +33 -0
  31. openprotein/embeddings/future.py +146 -0
  32. openprotein/embeddings/models.py +421 -0
  33. openprotein/embeddings/openprotein.py +21 -0
  34. openprotein/embeddings/poet.py +446 -0
  35. openprotein/embeddings/poet2.py +505 -0
  36. openprotein/embeddings/schemas.py +78 -0
  37. openprotein/errors.py +76 -0
  38. openprotein/fasta.py +92 -0
  39. openprotein/fold/__init__.py +21 -0
  40. openprotein/fold/alphafold2.py +131 -0
  41. openprotein/fold/api.py +287 -0
  42. openprotein/fold/boltz.py +691 -0
  43. openprotein/fold/esmfold.py +54 -0
  44. openprotein/fold/fold.py +107 -0
  45. openprotein/fold/future.py +509 -0
  46. openprotein/fold/models.py +139 -0
  47. openprotein/fold/schemas.py +39 -0
  48. openprotein/jobs/__init__.py +9 -0
  49. openprotein/jobs/api.py +71 -0
  50. openprotein/jobs/futures.py +746 -0
  51. openprotein/jobs/jobs.py +69 -0
  52. openprotein/jobs/schemas.py +135 -0
  53. openprotein/models/__init__.py +4 -0
  54. openprotein/models/base.py +63 -0
  55. openprotein/models/foundation/rfdiffusion.py +283 -0
  56. openprotein/models/models.py +33 -0
  57. openprotein/predictor/__init__.py +25 -0
  58. openprotein/predictor/api.py +384 -0
  59. openprotein/predictor/models.py +374 -0
  60. openprotein/predictor/prediction.py +79 -0
  61. openprotein/predictor/predictor.py +242 -0
  62. openprotein/predictor/schemas.py +113 -0
  63. openprotein/predictor/validate.py +40 -0
  64. openprotein/prompt/__init__.py +9 -0
  65. openprotein/prompt/api.py +505 -0
  66. openprotein/prompt/models.py +142 -0
  67. openprotein/prompt/prompt.py +130 -0
  68. openprotein/prompt/schemas.py +49 -0
  69. openprotein/protein.py +587 -0
  70. openprotein/svd/__init__.py +9 -0
  71. openprotein/svd/api.py +206 -0
  72. openprotein/svd/models.py +288 -0
  73. openprotein/svd/schemas.py +31 -0
  74. openprotein/svd/svd.py +134 -0
  75. openprotein/umap/__init__.py +9 -0
  76. openprotein/umap/api.py +259 -0
  77. openprotein/umap/models.py +211 -0
  78. openprotein/umap/schemas.py +35 -0
  79. openprotein/umap/umap.py +175 -0
  80. openprotein/utils/uuid.py +29 -0
  81. openprotein_python-0.8.2.dist-info/METADATA +176 -0
  82. openprotein_python-0.8.2.dist-info/RECORD +84 -0
  83. openprotein_python-0.8.2.dist-info/WHEEL +4 -0
  84. openprotein_python-0.8.2.dist-info/licenses/LICENSE.txt +30 -0
@@ -0,0 +1,113 @@
1
+ """Schemas for OpenProtein predictor system."""
2
+
3
+ from datetime import datetime
4
+ from enum import Enum
5
+ from typing import Literal
6
+
7
+ from pydantic import BaseModel, ConfigDict
8
+
9
+ from openprotein.common import FeatureType
10
+ from openprotein.jobs import Job, JobStatus, JobType
11
+
12
+
13
class Kernel(BaseModel):
    """Kernel settings: a kernel type name plus a multitask flag."""

    # Name of the kernel; free-form string, not validated against an enum here.
    type: str
    # Multitask flag; off unless explicitly set.
    multitask: bool = False
16
+
17
+
18
class Constraints(BaseModel):
    """Optional constraints attached to a model spec."""

    # Sequence length constraint; None when no length constraint is recorded.
    sequence_length: int | None = None
20
+
21
+
22
class PredictorType(str, Enum):
    """Kinds of predictor; members are also plain strings via the ``str`` mixin."""

    GP = "GP"
    ENSEMBLE = "ENSEMBLE"
25
+
26
+
27
class Features(BaseModel):
    """Description of the features a predictor is built on."""

    type: FeatureType
    # Identifier of the model associated with the features, if any.
    model_id: str | None = None
    # Name of the reduction applied, if any.
    reduction: str | None = None

    # protected_namespaces is cleared; presumably so the ``model_id`` field does
    # not clash with pydantic's reserved ``model_`` prefix -- confirm against
    # the pydantic version in use.
    model_config = ConfigDict(protected_namespaces=())
33
+
34
+
35
class PredictorArgs(BaseModel):
    """Arguments for a predictor; currently only an optional kernel."""

    kernel: Kernel | None = None
37
+
38
+
39
class ModelSpec(PredictorArgs, BaseModel):
    """Specification of a predictor model.

    Inherits the optional ``kernel`` field from ``PredictorArgs`` and adds the
    predictor type together with optional constraints and features.
    """

    type: PredictorType
    constraints: Constraints | None = None
    features: Features | None = None
43
+
44
+
45
class Dataset(BaseModel):
    """Reference to an assay and a list of its property names."""

    # Identifier of the assay.
    assay_id: str
    # Names of the properties taken from that assay.
    properties: list[str]
48
+
49
+
50
class PredictorMetadata(BaseModel):
    """Metadata about the predictor."""

    class CalibrationStats(BaseModel):
        """Calibration stats for this predictor, based on the latest crossvalidation."""

        # Every statistic is optional and defaults to None.
        pearson: float | None = None
        spearman: float | None = None
        ece: float | None = None

    class CalibrationCurvePoint(BaseModel):
        """One (x, y) point; used as the element type of ``curve`` below."""

        x: float
        y: float

    id: str
    name: str
    description: str | None = None
    status: JobStatus
    created_date: datetime
    model_spec: ModelSpec
    ensemble_model_ids: list[str] | None = None
    training_dataset: Dataset
    # "TrainGraph" is a string forward reference: the nested class is declared
    # at the end of this class body, after the fields that use it.
    traingraphs: list["TrainGraph"] | None = None
    stats: CalibrationStats | None = None
    curve: list[CalibrationCurvePoint] | None = None

    def is_done(self):
        """Return the result of ``self.status.done()``."""
        return self.status.done()

    # protected_namespaces is cleared; presumably so the ``model_spec`` field
    # does not clash with pydantic's reserved ``model_`` prefix -- confirm.
    model_config = ConfigDict(protected_namespaces=())

    class TrainGraph(BaseModel):
        """Loss values for one measurement at one hyperparameter search step."""

        measurement_name: str
        hyperparam_search_step: int
        losses: list[float]
85
+
86
+
87
class PredictorEnsembleJob(Job):
    """Job variant whose ``job_id`` and ``progress_counter`` are always None."""

    # Both fields are re-declared with type None and default None, overriding
    # whatever ``Job`` declares for them.
    job_id: None = None
    progress_counter: None = None
90
+
91
+
92
class PredictorTrainJob(Job):
    """Job whose ``job_type`` must be ``JobType.predictor_train``."""

    job_type: Literal[JobType.predictor_train]
94
+
95
+
96
class PredictJob(Job):
    """Job whose ``job_type`` must be ``JobType.predictor_predict``."""

    job_type: Literal[JobType.predictor_predict]
98
+
99
+
100
class PredictSingleSiteJob(Job):
    """Job whose ``job_type`` must be ``JobType.predictor_predict_single_site``."""

    job_type: Literal[JobType.predictor_predict_single_site]
102
+
103
+
104
class PredictMultiJob(Job):
    """Job whose ``job_type`` must be ``JobType.predictor_predict_multi``."""

    job_type: Literal[JobType.predictor_predict_multi]
106
+
107
+
108
class PredictMultiSingleSiteJob(Job):
    """Job whose ``job_type`` must be ``JobType.predictor_predict_multi_single_site``."""

    job_type: Literal[JobType.predictor_predict_multi_single_site]
110
+
111
+
112
class PredictorCVJob(Job):
    """Job whose ``job_type`` must be ``JobType.predictor_crossvalidate``."""

    job_type: Literal[JobType.predictor_crossvalidate]
@@ -0,0 +1,40 @@
1
+ """Predictor validation results represented as futures."""
2
+
3
+ import numpy as np
4
+
5
+ from openprotein.base import APISession
6
+ from openprotein.jobs import Future
7
+
8
+ from . import api
9
+ from .schemas import PredictorCVJob
10
+
11
+
12
class CVResultFuture(Future):
    """Future wrapping a predictor cross-validation job and its results."""

    job: PredictorCVJob

    def __init__(
        self,
        session: APISession,
        job: PredictorCVJob,
    ):
        """
        Wrap a cross-validation job.

        Args:
            session (APISession): session used for API calls
            job (PredictorCVJob): the cross-validation job to wrap
        """
        super().__init__(session, job)

    @property
    def id(self):
        """ID of the underlying job (``job.job_id``)."""
        return self.job.job_id

    def get(self, verbose: bool = False) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Fetch and decode the cross-validation results of this job.

        Args:
            verbose (bool): accepted but not used by this method

        Returns:
            tuple[np.ndarray, np.ndarray, np.ndarray]: the three arrays
            produced by ``api.decode_crossvalidate``. NOTE(review): presumably
            the measured values followed by the means and variances of the
            predictions -- confirm against ``api.decode_crossvalidate``.
        """
        # Download the raw payload for this job, then decode it into arrays.
        data = api.predictor_crossvalidate_get(self.session, self.job.job_id)
        return api.decode_crossvalidate(data)
@@ -0,0 +1,9 @@
1
+ """
2
+ Prompt module for OpenProtein, for use with PoET models.
3
+
4
+ isort:skip_file
5
+ """
6
+
7
+ from .schemas import Context, PromptMetadata, QueryMetadata, PromptJob
8
+ from .models import Prompt, Query
9
+ from .prompt import PromptAPI
@@ -0,0 +1,505 @@
1
+ """Prompt REST API interface for making HTTP calls to the prompt backend."""
2
+
3
+ import copy
4
+ import io
5
+ import zipfile
6
+ from typing import BinaryIO, Sequence, cast
7
+
8
+ from openprotein.base import APISession
9
+ from openprotein.errors import APIError, InvalidParameterError, RawAPIError
10
+ from openprotein.protein import Protein
11
+
12
+ from .schemas import Context, PromptMetadata, QueryMetadata
13
+
14
+
15
def create_prompt(
    session: APISession,
    context: Context | Sequence[Context],
    name: str | None = None,
    description: str | None = None,
) -> PromptMetadata:
    """
    Upload a new prompt built from one or more contexts.

    Parameters
    ----------
    session : APISession
        Session used to talk to the backend.
    context : Context or Sequence[Context]
        A single context, or several contexts; each is zipped and uploaded as
        its own ``context-<i>.zip`` file.
    name : str or None, optional
        Name of the prompt; only sent when not None.
    description : str or None, optional
        Description of the prompt; only sent when not None.

    Returns
    -------
    PromptMetadata
        Metadata of the prompt that was created.

    Raises
    ------
    InvalidParameterError
        If the backend answers with status 400.
    APIError
        If the backend answers with 401 or any other non-200 status.
    """
    # Optional text fields: keep only the ones the caller actually supplied.
    optional_fields = (("name", name), ("description", description))
    payload = {key: value for key, value in optional_fields if value is not None}

    # One multipart "context" part per zipped context.
    uploads = []
    for position, archive in enumerate(zip_prompt(context=context)):
        uploads.append(
            ("context", (f"context-{position}.zip", archive, "application/zip"))
        )

    request_kwargs: dict = {"files": uploads}
    if payload:
        request_kwargs["data"] = payload

    response = session.post("v1/prompt/create_prompt", **request_kwargs)
    status = response.status_code

    if status == 200:
        return PromptMetadata.model_validate(response.json())
    if status == 400:
        raise InvalidParameterError(
            RawAPIError.model_validate(response.json()).detail
        )
    if status == 401:
        raise APIError(RawAPIError.model_validate(response.json()).detail)
    raise APIError(f"Unexpected response status code: {response.status_code}")
78
+
79
+
80
def get_prompt_metadata(session: APISession, prompt_id: str) -> PromptMetadata:
    """
    Fetch the metadata of a single prompt.

    Parameters
    ----------
    session : APISession
        Session used to talk to the backend.
    prompt_id : str
        ID of the prompt to look up.

    Returns
    -------
    PromptMetadata
        Metadata of the requested prompt.

    Raises
    ------
    APIError
        If the backend answers with 401, 404 or any other non-200 status.
    """
    response = session.get(f"v1/prompt/{prompt_id}")
    status = response.status_code

    if status == 200:
        return PromptMetadata.model_validate(response.json())
    if status in (401, 404):
        # Both statuses carry a JSON error body whose detail is surfaced as-is.
        raise APIError(RawAPIError.model_validate(response.json()).detail)
    raise APIError(f"Unexpected response status code: {response.status_code}")
114
+
115
+
116
def get_prompt(session: APISession, prompt_id: str) -> list[list[Protein]]:
    """
    Download a prompt and decode it into proteins.

    Parameters
    ----------
    session : APISession
        Session used to talk to the backend.
    prompt_id : str
        ID of the prompt to download.

    Returns
    -------
    list of list of Protein
        One inner list of proteins per context of the prompt.

    Raises
    ------
    APIError
        If the backend answers with 401, 404 or any other non-200 status.
    """
    response = session.get(f"v1/prompt/{prompt_id}/content", stream=True)
    status = response.status_code

    if status == 200:
        # The body is the prompt zip; hand it over as an in-memory file.
        archive = io.BytesIO(response.content)
        return unzip_prompt(archive)
    if status in (401, 404):
        raise APIError(RawAPIError.model_validate(response.json()).detail)
    raise APIError(f"Unexpected response status code: {response.status_code}")
150
+
151
+
152
def list_prompts(session: APISession) -> list[PromptMetadata]:
    """
    Fetch the metadata of every prompt returned by the backend.

    Parameters
    ----------
    session : APISession
        Session used to talk to the backend.

    Returns
    -------
    list of PromptMetadata
        One metadata object per entry of the JSON response.

    Raises
    ------
    APIError
        If the backend answers with 401 or any other non-200 status.
    """
    response = session.get("v1/prompt")
    status = response.status_code

    if status == 200:
        prompts = []
        for entry in response.json():
            prompts.append(PromptMetadata.model_validate(entry))
        return prompts
    if status == 401:
        raise APIError(RawAPIError.model_validate(response.json()).detail)
    raise APIError(f"Unexpected response status code: {response.status_code}")
181
+
182
+
183
def zip_prompt(
    context: Context | Sequence[Context],
) -> list[io.BytesIO]:
    """
    Pack prompt contexts into in-memory zip archives ready for upload.

    Parameters
    ----------
    context : Context or Sequence[Context]
        Either one context (a sequence of bytes / str / Protein items) or a
        sequence of such contexts (for ensembles).

    Returns
    -------
    list of io.BytesIO
        One zip archive per context, each rewound to position 0.
    """
    # Normalise the input to "sequence of contexts": empty input becomes one
    # empty context, and a flat context is wrapped in a list.
    if len(context) == 0:
        context = [[]]
    if isinstance(context[0], (bytes, str, Protein)):
        context = [cast(Context, context)]
    groups = cast(Sequence[Context], context)

    archives = []
    for group in groups:
        # Step 1: turn every member into a named Protein without touching the
        # caller's objects (Protein inputs are shallow-copied before renaming).
        proteins: list[Protein] = []
        for position, member in enumerate(group):
            fallback_name = f"unnamed-{position:06}"
            if isinstance(member, Protein):
                member = copy.copy(member)
                if member.name is None:
                    member.name = fallback_name
            else:
                member = Protein(name=fallback_name, sequence=member)
            proteins.append(member)

        # Step 2: lay the proteins out as archive entries. Structures get one
        # CIF entry each; runs of consecutive sequence-only proteins share a
        # single FASTA entry.
        entries: list[tuple[str, io.BytesIO]] = []
        for protein in proteins:
            slot = len(entries)
            if protein.has_structure:
                cif_bytes = protein.make_cif_string().encode()
                entries.append(
                    (f"{slot:06}.{protein.name}.cif", io.BytesIO(cif_bytes))
                )
                continue
            previous_is_fasta = bool(entries) and entries[-1][0].endswith(".fasta")
            if not previous_is_fasta:
                entries.append((f"{slot:06}.fasta", io.BytesIO()))
            entries[-1][1].write(protein.make_fasta_bytes())

        # Step 3: write the entries into a deflated zip held in memory.
        archive = io.BytesIO()
        with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as zf:
            for entry_name, buffer in entries:
                zf.writestr(entry_name, buffer.getvalue())
        archive.seek(0)
        archives.append(archive)

    return archives
244
+
245
+
246
def unzip_prompt(prompt_zip: BinaryIO) -> list[list[Protein]]:
    """
    Decode a prompt zip file into its context proteins.

    Counterpart of zip_prompt: every member of the outer archive whose name
    starts with ``context-`` is itself a zip of one context and is parsed
    into a list of proteins.

    Parameters
    ----------
    prompt_zip : BinaryIO
        Binary file object holding the prompt zip.

    Returns
    -------
    list of list of Protein
        One inner list of proteins per context archive found.
    """
    with zipfile.ZipFile(prompt_zip, "r") as outer:
        # Copy each nested context archive into its own in-memory buffer.
        nested_archives = [
            io.BytesIO(outer.read(member))
            for member in outer.namelist()
            if member.startswith("context-")
        ]
    return __parse_prompt(context_files=nested_archives)
274
+
275
+
276
def __parse_prompt(
    context_files: Sequence[BinaryIO],
) -> list[list[Protein]]:
    """
    Parse zipped context files into proteins.

    Parameters
    ----------
    context_files : Sequence[BinaryIO]
        Sequence of binary zip files, each representing a context group.

    Returns
    -------
    list of list of Protein
        List of context protein lists, where each inner list represents a
        context group. Members that end in neither ``.cif`` nor ``.fasta``
        are ignored.
    """
    import os
    import tempfile

    context: list[list[Protein]] = []

    # Process each context file (representing an ensemble)
    for context_file in context_files:
        # Reset the file pointer to the beginning
        context_file.seek(0)
        proteins_in_context: list[Protein] = []

        with zipfile.ZipFile(context_file, "r") as zf:
            # Sort filenames so members are processed in a consistent order
            # regardless of how the archive was written; zip_prompt prefixes
            # every member with a zero-padded index, so sorted order is the
            # original protein order.
            for filename in sorted(zf.namelist()):
                if filename.endswith(".cif"):
                    content = zf.read(filename)
                    # The structure is loaded from a path, so the bytes are
                    # written to disk first. A temporary directory is used
                    # rather than an open NamedTemporaryFile because an open
                    # named temporary file cannot be reopened by name on
                    # Windows.
                    with tempfile.TemporaryDirectory() as tmp_dir:
                        cif_path = os.path.join(tmp_dir, "structure.cif")
                        with open(cif_path, "wb") as cif_file:
                            cif_file.write(content)
                        # chain ID is fixed to 'A'
                        protein = Protein.from_filepath(
                            path=cif_path, chain_id="A", verbose=False
                        )
                    # override the name with the filename (without extension)
                    # NOTE(review): this keeps the index prefix that
                    # zip_prompt adds to CIF member names -- confirm whether
                    # it should be stripped.
                    protein.name = filename[:-4]
                    proteins_in_context.append(protein)

                elif filename.endswith(".fasta"):
                    # Imported lazily, as in the original implementation.
                    from openprotein import fasta

                    fasta_stream = io.BytesIO(zf.read(filename))
                    for name, sequence in fasta.parse_stream(fasta_stream):
                        proteins_in_context.append(
                            Protein(name=name, sequence=sequence)
                        )

        # Add this group of proteins to the context
        context.append(proteins_in_context)

    return context
345
+
346
+
347
def create_query(
    session: APISession,
    query: bytes | str | Protein,
) -> QueryMetadata:
    """
    Upload a query protein.

    Parameters
    ----------
    session : APISession
        Session used to talk to the backend.
    query : bytes or str or Protein
        The protein to upload as the query; a raw sequence is wrapped in a
        Protein named "query".

    Returns
    -------
    QueryMetadata
        Metadata of the query that was created.

    Raises
    ------
    InvalidParameterError
        If the backend answers with status 400.
    APIError
        If the backend answers with 401 or any other non-200 status.
    """
    protein = query if isinstance(query, Protein) else Protein(
        name="query", sequence=query
    )

    # Structures are uploaded as mmCIF, sequence-only proteins as FASTA.
    if protein.has_structure:
        payload = protein.make_cif_string().encode()
        upload_name, media_type = "query.cif", "chemical/x-mmcif"
    else:
        payload = protein.make_fasta_bytes()
        upload_name, media_type = "query.fasta", "text/x-fasta"

    response = session.post(
        "v1/prompt/query",
        files={"query": (upload_name, io.BytesIO(payload), media_type)},
    )
    status = response.status_code

    if status == 200:
        return QueryMetadata.model_validate(response.json())
    if status == 400:
        raise InvalidParameterError(
            RawAPIError.model_validate(response.json()).detail
        )
    if status == 401:
        raise APIError(RawAPIError.model_validate(response.json()).detail)
    raise APIError(f"Unexpected response status code: {response.status_code}")
398
+
399
+
400
def get_query_metadata(session: APISession, query_id: str) -> QueryMetadata:
    """
    Fetch the metadata of a single query.

    Parameters
    ----------
    session : APISession
        Session used to talk to the backend.
    query_id : str
        ID of the query to look up.

    Returns
    -------
    QueryMetadata
        Metadata of the requested query.

    Raises
    ------
    APIError
        If the backend answers with 401, 404 or any other non-200 status.
    """
    response = session.get(f"v1/prompt/query/{query_id}")
    status = response.status_code

    if status == 200:
        return QueryMetadata.model_validate(response.json())
    if status in (401, 404):
        # Both statuses carry a JSON error body whose detail is surfaced as-is.
        raise APIError(RawAPIError.model_validate(response.json()).detail)
    raise APIError(f"Unexpected response status code: {response.status_code}")
434
+
435
+
436
def get_query(session: APISession, query_id: str) -> Protein:
    """
    Get the query content for a given query ID.

    Parameters
    ----------
    session : APISession
        The API session.
    query_id : str
        The query ID.

    Returns
    -------
    Protein
        The query protein. For FASTA content only the first record is used.

    Raises
    ------
    APIError
        If the API returns an error status, the file format is unexpected, or
        no protein could be read from the returned file.
    """
    import os
    import tempfile
    from email.message import Message

    endpoint = f"v1/prompt/query/{query_id}/content"
    response = session.get(endpoint, stream=True)

    # Check the status before touching the body: an error response carries a
    # JSON error document, not a query file, and must be reported as such
    # instead of as an "unexpected file".
    if response.status_code in (401, 404):
        error = RawAPIError.model_validate(response.json())
        raise APIError(error.detail)
    if response.status_code != 200:
        raise APIError(f"Unexpected response status code: {response.status_code}")

    # Content-Disposition normally looks like `attachment; filename="x.cif"`,
    # so extract the filename parameter; fall back to the raw header value
    # when there is no such parameter.
    disposition = response.headers.get("Content-Disposition", "query")
    header = Message()
    header["Content-Disposition"] = disposition
    filename = header.get_filename() or disposition.strip().strip('"')

    # Drop parameters such as "; charset=utf-8" before comparing media types.
    content_type = response.headers.get("Content-Type", "text/plain")
    media_type = content_type.split(";", 1)[0].strip().lower()

    is_mmcif = filename.endswith(".cif") or media_type == "chemical/x-mmcif"
    is_fasta = filename.endswith(".fasta") or media_type == "text/x-fasta"

    query_protein = None
    if is_mmcif:
        # The structure is loaded from a path, so the bytes are written to
        # disk first. A temporary directory is used rather than an open
        # NamedTemporaryFile because an open named temporary file cannot be
        # reopened by name on Windows.
        with tempfile.TemporaryDirectory() as tmp_dir:
            cif_path = os.path.join(tmp_dir, "query.cif")
            with open(cif_path, "wb") as cif_file:
                cif_file.write(response.content)
            # chain id is fixed to 'A'
            query_protein = Protein.from_filepath(
                path=cif_path, chain_id="A", verbose=False
            )
    elif is_fasta:
        # Imported lazily, as in the original implementation.
        from openprotein import fasta

        # Process FASTA file - take only the first sequence
        fasta_stream = io.BytesIO(response.content)
        first_record = next(iter(fasta.parse_stream(fasta_stream)), None)
        if first_record is not None:
            name, sequence = first_record
            query_protein = Protein(name=name, sequence=sequence)
    else:
        raise APIError(
            f"Unexpected file returned with filename {filename} and type {media_type}"
        )

    if query_protein is None:
        raise APIError(f"Invalid query file returned from API {response.content[:10]}")

    return query_protein