openprotein-python 0.8.2__1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. openprotein/__init__.py +164 -0
  2. openprotein/_version.py +48 -0
  3. openprotein/align/__init__.py +8 -0
  4. openprotein/align/align.py +395 -0
  5. openprotein/align/api.py +428 -0
  6. openprotein/align/future.py +55 -0
  7. openprotein/align/msa.py +129 -0
  8. openprotein/align/schemas.py +165 -0
  9. openprotein/base.py +181 -0
  10. openprotein/chains.py +88 -0
  11. openprotein/common/__init__.py +5 -0
  12. openprotein/common/features.py +7 -0
  13. openprotein/common/model_metadata.py +33 -0
  14. openprotein/common/reduction.py +8 -0
  15. openprotein/config.py +9 -0
  16. openprotein/csv.py +31 -0
  17. openprotein/data/__init__.py +9 -0
  18. openprotein/data/api.py +218 -0
  19. openprotein/data/assaydataset.py +178 -0
  20. openprotein/data/data.py +93 -0
  21. openprotein/data/schemas.py +27 -0
  22. openprotein/design/__init__.py +16 -0
  23. openprotein/design/api.py +259 -0
  24. openprotein/design/design.py +125 -0
  25. openprotein/design/future.py +146 -0
  26. openprotein/design/schemas.py +607 -0
  27. openprotein/embeddings/__init__.py +27 -0
  28. openprotein/embeddings/api.py +619 -0
  29. openprotein/embeddings/embeddings.py +151 -0
  30. openprotein/embeddings/esm.py +33 -0
  31. openprotein/embeddings/future.py +146 -0
  32. openprotein/embeddings/models.py +421 -0
  33. openprotein/embeddings/openprotein.py +21 -0
  34. openprotein/embeddings/poet.py +446 -0
  35. openprotein/embeddings/poet2.py +505 -0
  36. openprotein/embeddings/schemas.py +78 -0
  37. openprotein/errors.py +76 -0
  38. openprotein/fasta.py +92 -0
  39. openprotein/fold/__init__.py +21 -0
  40. openprotein/fold/alphafold2.py +131 -0
  41. openprotein/fold/api.py +287 -0
  42. openprotein/fold/boltz.py +691 -0
  43. openprotein/fold/esmfold.py +54 -0
  44. openprotein/fold/fold.py +107 -0
  45. openprotein/fold/future.py +509 -0
  46. openprotein/fold/models.py +139 -0
  47. openprotein/fold/schemas.py +39 -0
  48. openprotein/jobs/__init__.py +9 -0
  49. openprotein/jobs/api.py +71 -0
  50. openprotein/jobs/futures.py +746 -0
  51. openprotein/jobs/jobs.py +69 -0
  52. openprotein/jobs/schemas.py +135 -0
  53. openprotein/models/__init__.py +4 -0
  54. openprotein/models/base.py +63 -0
  55. openprotein/models/foundation/rfdiffusion.py +283 -0
  56. openprotein/models/models.py +33 -0
  57. openprotein/predictor/__init__.py +25 -0
  58. openprotein/predictor/api.py +384 -0
  59. openprotein/predictor/models.py +374 -0
  60. openprotein/predictor/prediction.py +79 -0
  61. openprotein/predictor/predictor.py +242 -0
  62. openprotein/predictor/schemas.py +113 -0
  63. openprotein/predictor/validate.py +40 -0
  64. openprotein/prompt/__init__.py +9 -0
  65. openprotein/prompt/api.py +505 -0
  66. openprotein/prompt/models.py +142 -0
  67. openprotein/prompt/prompt.py +130 -0
  68. openprotein/prompt/schemas.py +49 -0
  69. openprotein/protein.py +587 -0
  70. openprotein/svd/__init__.py +9 -0
  71. openprotein/svd/api.py +206 -0
  72. openprotein/svd/models.py +288 -0
  73. openprotein/svd/schemas.py +31 -0
  74. openprotein/svd/svd.py +134 -0
  75. openprotein/umap/__init__.py +9 -0
  76. openprotein/umap/api.py +259 -0
  77. openprotein/umap/models.py +211 -0
  78. openprotein/umap/schemas.py +35 -0
  79. openprotein/umap/umap.py +175 -0
  80. openprotein/utils/uuid.py +29 -0
  81. openprotein_python-0.8.2.dist-info/METADATA +176 -0
  82. openprotein_python-0.8.2.dist-info/RECORD +84 -0
  83. openprotein_python-0.8.2.dist-info/WHEEL +4 -0
  84. openprotein_python-0.8.2.dist-info/licenses/LICENSE.txt +30 -0
@@ -0,0 +1,619 @@
1
+ """Embeddings REST API for making HTTP calls to our embeddings backend."""
2
+
3
+ import io
4
+ import random
5
+ import struct
6
+ from io import BytesIO
7
+ from typing import BinaryIO, Iterator
8
+
9
+ import numpy as np
10
+ from pydantic import TypeAdapter
11
+
12
+ from openprotein import csv
13
+ from openprotein.base import APISession
14
+ from openprotein.common import ModelMetadata
15
+ from openprotein.errors import InvalidParameterError
16
+
17
+ from .schemas import (
18
+ AttnJob,
19
+ EmbeddingsJob,
20
+ GenerateJob,
21
+ JobType,
22
+ LogitsJob,
23
+ ScoreIndelJob,
24
+ ScoreJob,
25
+ ScoreSingleSiteJob,
26
+ )
27
+
28
# Route prefix shared by the embeddings endpoints in this module; the request
# helpers below append model- or job-specific path segments to it.
PATH_PREFIX = "v1/embeddings"
29
+
30
+
31
def list_models(session: APISession) -> list[str]:
    """
    Fetch the names of the embeddings models exposed by the backend.

    Parameters
    ----------
    session : APISession
        Session object for API communication.

    Returns
    -------
    list[str]
        The decoded JSON body of ``GET v1/embeddings/models``, returned as-is
        (expected to be a list of model names).
    """
    return session.get(f"{PATH_PREFIX}/models").json()
46
+
47
+
48
def get_model(session: APISession, model_id: str) -> ModelMetadata:
    """
    Fetch the metadata describing a single embeddings model.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    model_id : str
        Identifier of the model to describe.

    Returns
    -------
    ModelMetadata
        Constructed from the JSON body of ``GET v1/embeddings/models/{model_id}``.
    """
    model_url = f"{PATH_PREFIX}/models/{model_id}"
    return ModelMetadata(**session.get(model_url).json())
53
+
54
+
55
def get_request_sequences(
    session: APISession, job_id: str, job_type: JobType = JobType.embeddings_embed
) -> list[bytes]:
    """
    Fetch the input sequences that were submitted with a job.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    job_id : str
        ID of the job whose sequences should be fetched.
    job_type : JobType, optional
        Job type whose ``value`` is appended to ``"v1"`` to form the route.
        Defaults to ``JobType.embeddings_embed``.

    Returns
    -------
    list[bytes]
        The sequences from ``GET v1{job_type.value}/{job_id}/sequences``,
        validated as ``list[bytes]`` by pydantic.
    """
    # NOTE - allow to handle svd/embed and umap/embed directly too instead of redirect
    route = "v1" + job_type.value
    raw_sequences = session.get(f"{route}/{job_id}/sequences").json()
    return TypeAdapter(list[bytes]).validate_python(raw_sequences)
77
+
78
+
79
def request_get_sequence_result(
    session: APISession,
    job_id: str,
    sequence: str | bytes,
    job_type: JobType = JobType.embeddings_embed,
) -> bytes:
    """
    Fetch the raw, still-encoded result for one sequence of a job.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    job_id : str
        ID of the job to retrieve results from.
    sequence : str | bytes
        Sequence to retrieve the result for. ``bytes`` are decoded (UTF-8)
        and the text is placed directly into the URL path.
    job_type : JobType, optional
        Job type whose ``value`` is appended to ``"v1"`` to form the route.
        Defaults to ``JobType.embeddings_embed``.

    Returns
    -------
    bytes
        The undecoded response body.
    """
    # NOTE - allow to handle svd/embed and umap/embed directly too instead of redirect
    route = "v1" + job_type.value
    seq_text = sequence.decode() if isinstance(sequence, bytes) else sequence
    return session.get(f"{route}/{job_id}/{seq_text}").content
108
+
109
+
110
def result_decode(data: bytes) -> np.ndarray:
    """
    Deserialize an array payload received over the API.

    Parameters
    ----------
    data : bytes
        Raw bytes in NumPy ``.npy`` format.

    Returns
    -------
    np.ndarray
        The decoded array. Payloads that require unpickling (object arrays)
        are rejected with ``ValueError`` because ``allow_pickle=False``.
    """
    in_memory_file = io.BytesIO(data)
    decoded = np.load(in_memory_file, allow_pickle=False)
    return decoded
122
+
123
+
124
def request_get_score_result(session: APISession, job_id: str) -> Iterator[list[str]]:
    """
    Stream the CSV results of a scoring job.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    job_id : str
        ID of the scoring job to retrieve results from.

    Returns
    -------
    Iterator[list[str]]
        Rows produced by ``csv.parse_stream`` from the lines of the streamed
        ``GET v1/embeddings/{job_id}/scores`` response.
    """
    scores_response = session.get(f"{PATH_PREFIX}/{job_id}/scores", stream=True)
    return csv.parse_stream(scores_response.iter_lines())
142
+
143
+
144
def request_get_embeddings_stream(
    session: APISession, job_id: str
) -> Iterator[np.ndarray]:
    """
    Stream back the raw embeddings for a given embeddings job.

    This opens an HTTP GET to ``v1/embeddings/{job_id}/stream`` with
    ``stream=True``. It then reads a sequence of framed ``.npy`` payloads,
    where each chunk is prefixed by an 8-byte big-endian length header.
    Each chunk is decoded into a NumPy array and yielded as soon as it is
    received.

    The HTTP response is closed when the stream is exhausted, when an error
    is raised, or when the generator is closed (e.g. the consumer stops
    early and the generator is closed or garbage-collected). This keeps the
    streamed connection from being held open indefinitely.

    Parameters
    ----------
    session : APISession
        The API session to use for making requests.
    job_id : str
        The embeddings job identifier returned by `request_post`.

    Yields
    ------
    numpy.ndarray
        An embedding array for each input sequence.

    Raises
    ------
    requests.HTTPError
        If the HTTP request returns a non-2xx status code.
    ValueError
        If the framed stream is malformed (e.g. incomplete header or payload).
    """
    endpoint = PATH_PREFIX + f"/{job_id}/stream"
    response = session.get(endpoint, stream=True)
    try:
        response.raise_for_status()
        # Presumably a urllib3 response: ask it to undo any Content-Encoding
        # (e.g. gzip) so the parser sees the raw framed bytes.
        response.raw.decode_content = True
        buffered = io.BufferedReader(response.raw)  # type: ignore
        yield from parse_framed_npy_stream(buffered)
    finally:
        # Release the streamed connection even on early exit or error.
        response.close()
182
+
183
+
184
def _read_exact(stream: BinaryIO, size: int) -> bytes:
    """Read ``size`` bytes from ``stream``, retrying short reads; fewer are returned only at EOF."""
    chunks: list[bytes] = []
    remaining = size
    while remaining > 0:
        chunk = stream.read(remaining)
        if not chunk:  # EOF: the stream has no more data
            break
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def parse_framed_npy_stream(stream: BinaryIO) -> Iterator[np.ndarray]:
    """
    Read a binary stream of length-prefixed NumPy .npy arrays.

    This function parses a stream composed of consecutive frames. Each frame
    starts with an 8-byte big-endian unsigned integer indicating the size of
    the subsequent .npy payload. It then reads exactly that many bytes and
    deserializes them into a NumPy array via np.load(..., allow_pickle=False).
    Frames are yielded one by one until the stream is exhausted.

    Short reads (a ``read(n)`` call returning fewer than ``n`` bytes before
    EOF, as unbuffered or network streams may do) are retried. Only a
    genuine end of stream in the middle of a frame is treated as malformed
    input.

    Parameters
    ----------
    stream : BinaryIO
        A binary stream supporting read(n) that contains zero or more
        concatenated frames in the format:
        [8-byte big-endian length][.npy payload].

    Yields
    ------
    np.ndarray
        Each deserialized NumPy array from the stream.

    Raises
    ------
    ValueError
        If an 8-byte header cannot be read in full (unless at end of stream),
        or if the stream ends before the declared payload length is read.
    """
    while True:
        # Read the 8-byte big-endian length header.
        try:
            length_bytes = _read_exact(stream, 8)
        except ValueError:
            # Reading from a closed file raises ValueError -> treat as EOF.
            break
        if len(length_bytes) < 8:
            if length_bytes:
                raise ValueError("Incomplete length header")
            break  # Clean end of stream between frames.

        (npy_len,) = struct.unpack(">Q", length_bytes)
        npy_bytes = _read_exact(stream, npy_len)
        if len(npy_bytes) < npy_len:
            raise ValueError("Incomplete npy payload")

        yield np.load(BytesIO(npy_bytes), allow_pickle=False)
231
+
232
+
233
def request_get_generate_result(
    session: APISession, job_id: str
) -> Iterator[list[str]]:
    """
    Stream the CSV results of a sequence-generation job.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    job_id : str
        ID of the generation job to retrieve results from.

    Returns
    -------
    Iterator[list[str]]
        Rows produced by ``csv.parse_stream`` from the lines of the streamed
        ``GET v1/embeddings/{job_id}/generate`` response.
    """
    generated = session.get(f"{PATH_PREFIX}/{job_id}/generate", stream=True)
    return csv.parse_stream(generated.iter_lines())
253
+
254
+
255
def request_post(
    session: APISession,
    model_id: str,
    sequences: list[bytes] | list[str],
    reduction: str | None = "MEAN",
    **kwargs,
) -> EmbeddingsJob:
    """
    Submit an embeddings job for the given model.

    The returned job object refers to this request and can be used to
    retrieve results later.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    model_id : str
        ID of the model to request embeddings from.
    sequences : list[bytes] | list[str]
        Sequences to embed; non-``str`` items are ``.decode()``-d to text.
    reduction : str | None
        Reduction to apply to the embeddings. Options are None, "MEAN", or
        "SUM". Default: "MEAN". Omitted from the request when None.
    **kwargs:
        Optional model parameters (e.g. prompt_id for PoET). Forwarded keys:
        prompt_id, query_id and decoder_type when truthy;
        use_query_structure_in_decoder whenever present.

    Returns
    -------
    job : EmbeddingsJob
    """
    url = f"{PATH_PREFIX}/models/{model_id}/embed"

    payload: dict = {
        "sequences": [seq if isinstance(seq, str) else seq.decode() for seq in sequences],
    }
    if reduction is not None:
        payload["reduction"] = reduction
    for option in ("prompt_id", "query_id"):
        if kwargs.get(option):
            payload[option] = kwargs[option]
    if "use_query_structure_in_decoder" in kwargs:
        payload["use_query_structure_in_decoder"] = kwargs["use_query_structure_in_decoder"]
    if kwargs.get("decoder_type"):
        payload["decoder_type"] = kwargs["decoder_type"]

    job_json = session.post(url, json=payload).json()
    return EmbeddingsJob.model_validate(job_json)
303
+
304
+
305
def request_logits_post(
    session: APISession,
    model_id: str,
    sequences: list[bytes] | list[str],
    **kwargs,
) -> LogitsJob:
    """
    Submit a logits job for the given model.

    The returned job object refers to this request and can be used to
    retrieve results later.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    model_id : str
        ID of the model to request logits from.
    sequences : list[bytes] | list[str]
        Sequences to process; non-``str`` items are ``.decode()``-d to text.
    **kwargs:
        Optional model parameters (e.g. prompt_id for PoET). Forwarded keys:
        prompt_id, query_id and decoder_type when truthy;
        use_query_structure_in_decoder whenever present.

    Returns
    -------
    job : LogitsJob
    """
    url = f"{PATH_PREFIX}/models/{model_id}/logits"

    text_sequences = [seq if isinstance(seq, str) else seq.decode() for seq in sequences]
    payload: dict = {"sequences": text_sequences}
    payload.update(
        {key: kwargs[key] for key in ("prompt_id", "query_id") if kwargs.get(key)}
    )
    if "use_query_structure_in_decoder" in kwargs:
        payload["use_query_structure_in_decoder"] = kwargs["use_query_structure_in_decoder"]
    if kwargs.get("decoder_type"):
        payload["decoder_type"] = kwargs["decoder_type"]

    reply = session.post(url, json=payload)
    return LogitsJob.model_validate(reply.json())
348
+
349
+
350
def request_attn_post(
    session: APISession,
    model_id: str,
    sequences: list[bytes] | list[str],
    **kwargs,
) -> AttnJob:
    """
    Submit an attention-map job for the given model.

    The returned job object refers to this request and can be used to
    retrieve results later.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    model_id : str
        ID of the model to request attention from.
    sequences : list[bytes] | list[str]
        Sequences to process; non-``str`` items are ``.decode()``-d to text.
    **kwargs:
        Optional model parameters (e.g. prompt_id for PoET). Forwarded keys:
        prompt_id and query_id when truthy; use_query_structure_in_decoder
        whenever present. (Unlike the other request helpers, decoder_type is
        not forwarded here.)

    Returns
    -------
    job : AttnJob
    """
    url = f"{PATH_PREFIX}/models/{model_id}/attn"

    payload: dict = {
        "sequences": [item if isinstance(item, str) else item.decode() for item in sequences]
    }
    for option in ("prompt_id", "query_id"):
        value = kwargs.get(option)
        if value:
            payload[option] = kwargs[option]
    if "use_query_structure_in_decoder" in kwargs:
        payload["use_query_structure_in_decoder"] = kwargs["use_query_structure_in_decoder"]

    return AttnJob.model_validate(session.post(url, json=payload).json())
392
+
393
+
394
def request_score_post(
    session: APISession,
    model_id: str,
    sequences: list[bytes] | list[str],
    **kwargs,
) -> ScoreJob:
    """
    Submit a sequence-scoring job for the given model.

    The returned job object refers to this request and can be used to
    retrieve results later.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    model_id : str
        ID of the model to request scores from.
    sequences : list[bytes] | list[str]
        Sequences to score; non-``str`` items are ``.decode()``-d to text.
    **kwargs:
        Optional model parameters (e.g. prompt_id for PoET). Forwarded keys:
        prompt_id, query_id and decoder_type when truthy;
        use_query_structure_in_decoder whenever present.

    Returns
    -------
    job : ScoreJob
    """
    url = f"{PATH_PREFIX}/models/{model_id}/score"
    payload: dict = {
        "sequences": [seq if isinstance(seq, str) else seq.decode() for seq in sequences],
    }
    for option in ("prompt_id", "query_id"):
        if kwargs.get(option):
            payload[option] = kwargs[option]
    if "use_query_structure_in_decoder" in kwargs:
        payload["use_query_structure_in_decoder"] = kwargs["use_query_structure_in_decoder"]
    if kwargs.get("decoder_type"):
        payload["decoder_type"] = kwargs["decoder_type"]

    reply = session.post(url, json=payload)
    return ScoreJob.model_validate(reply.json())
435
+
436
+
437
def request_score_indel_post(
    session: APISession,
    model_id: str,
    base_sequence: bytes | str,
    insert: str | None = None,
    delete: list[int] | None = None,
    **kwargs,
) -> ScoreIndelJob:
    """
    POST a request for insertion/deletion (indel) scoring of a base sequence
    for the given model ID.

    Returns a Job object referring to this request that can be used to
    retrieve results later.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    model_id : str
        Model ID to request results from.
    base_sequence : bytes | str
        Sequence to score indels against; ``bytes`` are decoded (UTF-8).
    insert : str | None
        Insertion fragment at each site. Omitted from the request when None.
    delete : list[int] | None
        Sizes of the fragment to delete at each site. Any iterable of ints
        (e.g. ``range(1, 4)``) is accepted and sent as a JSON list. Omitted
        from the request when None.
    **kwargs:
        Optional model parameters (e.g. prompt_id for PoET). Forwarded keys:
        prompt_id, query_id and decoder_type when truthy;
        use_query_structure_in_decoder whenever present.

    Returns
    -------
    job : ScoreIndelJob
    """
    endpoint = PATH_PREFIX + f"/models/{model_id}/score/indel"

    body: dict = {
        "base_sequence": (
            base_sequence.decode()
            if isinstance(base_sequence, bytes)
            else base_sequence
        ),
    }
    if insert is not None:
        body["insert"] = insert
    if delete is not None:
        # Materialize so any iterable of ints (e.g. a range) is JSON-serializable.
        body["delete"] = list(delete)
    if kwargs.get("prompt_id"):
        body["prompt_id"] = kwargs["prompt_id"]
    if kwargs.get("query_id"):
        body["query_id"] = kwargs["query_id"]
    if "use_query_structure_in_decoder" in kwargs:
        body["use_query_structure_in_decoder"] = kwargs[
            "use_query_structure_in_decoder"
        ]
    if kwargs.get("decoder_type"):
        body["decoder_type"] = kwargs["decoder_type"]
    response = session.post(endpoint, json=body)
    return ScoreIndelJob.model_validate(response.json())
494
+
495
+
496
def request_score_single_site_post(
    session: APISession,
    model_id: str,
    base_sequence: bytes | str,
    **kwargs,
) -> ScoreSingleSiteJob:
    """
    POST a request for single site mutation scoring of a base sequence for
    the given model ID.

    Returns a Job object referring to this request that can be used to
    retrieve results later.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    model_id : str
        Model ID to request results from.
    base_sequence : bytes | str
        Sequence whose single-site mutations are scored; ``bytes`` are
        decoded (UTF-8).
    **kwargs:
        Optional model parameters (e.g. prompt_id for PoET). Forwarded keys:
        prompt_id, query_id and decoder_type when truthy;
        use_query_structure_in_decoder whenever present.

    Returns
    -------
    job : ScoreSingleSiteJob
    """
    endpoint = PATH_PREFIX + f"/models/{model_id}/score_single_site"

    body: dict = {
        "base_sequence": (
            base_sequence.decode()
            if isinstance(base_sequence, bytes)
            else base_sequence
        ),
    }
    if kwargs.get("prompt_id"):
        body["prompt_id"] = kwargs["prompt_id"]
    if kwargs.get("query_id"):
        body["query_id"] = kwargs["query_id"]
    if "use_query_structure_in_decoder" in kwargs:
        body["use_query_structure_in_decoder"] = kwargs[
            "use_query_structure_in_decoder"
        ]
    if kwargs.get("decoder_type"):
        body["decoder_type"] = kwargs["decoder_type"]
    response = session.post(endpoint, json=body)
    return ScoreSingleSiteJob.model_validate(response.json())
543
+
544
+
545
def request_generate_post(
    session: APISession,
    model_id: str,
    num_samples: int = 100,
    temperature: float = 1.0,
    topk: float | None = None,
    topp: float | None = None,
    max_length: int = 1000,
    random_seed: int | None = None,
    **kwargs,
) -> GenerateJob:
    """
    POST a request for sequence generation for the given model ID.

    Returns a Job object referring to this request that can be used to
    retrieve results later.

    Parameters
    ----------
    session : APISession
        Session object for API communication.
    model_id : str
        Model ID to request results from.
    num_samples : int
        Number of sequences to generate, sent as ``n_sequences``. Default: 100.
    temperature : float
        Sampling temperature; must be in [0.1, 2]. Default: 1.0.
    topk : float | None
        Optional top-k value; must be in [2, 20]. Omitted when None.
    topp : float | None
        Optional top-p value; must be in [0, 1]. Omitted when None.
    max_length : int
        Maximum sequence length, sent as ``maxlen``. Default: 1000.
    random_seed : int | None
        Seed in [0, 2**32 - 1]. When None, a seed is drawn with
        ``random.randrange(2**32)``, so a seed is always sent.
    **kwargs:
        Optional model parameters (e.g. prompt_id for PoET). Forwarded keys:
        prompt_id and query_id when truthy; use_query_structure_in_decoder
        whenever present; ensemble_weights (sent as a list) and
        ensemble_method when not None. query_id, ensemble_weights and
        ensemble_method are rejected for model ``"poet"``.

    Returns
    -------
    job : GenerateJob

    Raises
    ------
    InvalidParameterError
        If a sampling parameter is out of range, or an option that model
        ``"poet"`` does not support is supplied for it.
    """
    endpoint = PATH_PREFIX + f"/models/{model_id}/generate"

    if not (0.1 <= temperature <= 2):
        raise InvalidParameterError("The 'temperature' must be between 0.1 and 2.")
    if topk is not None and not (2 <= topk <= 20):
        raise InvalidParameterError("The 'topk' must be between 2 and 20.")
    if topp is not None and not (0 <= topp <= 1):
        raise InvalidParameterError("The 'topp' must be between 0 and 1.")
    # Seeds must fit in an unsigned 32-bit integer, i.e. [0, 2**32 - 1]; this
    # matches the range of the default drawn by random.randrange(2**32) below.
    if random_seed is not None and not (0 <= random_seed < 2**32):
        raise InvalidParameterError(
            "The 'random_seed' must be between 0 and 2^32 - 1."
        )

    if random_seed is None:
        random_seed = random.randrange(2**32)

    body: dict = {
        "n_sequences": num_samples,
        "temperature": temperature,
        "maxlen": max_length,
    }
    if topk is not None:
        body["topk"] = topk
    if topp is not None:
        body["topp"] = topp
    # random_seed is always set at this point.
    body["seed"] = random_seed
    if kwargs.get("prompt_id"):
        body["prompt_id"] = kwargs["prompt_id"]
    # Unsupported-option checks use explicit raises (not assert) so they are
    # not stripped when Python runs with -O.
    if kwargs.get("query_id"):
        if model_id == "poet":
            raise InvalidParameterError(
                f"Model with id {model_id} does not support query"
            )
        body["query_id"] = kwargs["query_id"]
    if "use_query_structure_in_decoder" in kwargs:
        body["use_query_structure_in_decoder"] = kwargs[
            "use_query_structure_in_decoder"
        ]
    if (ensemble_weights := kwargs.get("ensemble_weights")) is not None:
        if model_id == "poet":
            raise InvalidParameterError(
                f"Model with id {model_id} does not support ensemble_weights parameter"
            )
        body["ensemble_weights"] = list(ensemble_weights)
    if (ensemble_method := kwargs.get("ensemble_method")) is not None:
        if model_id == "poet":
            raise InvalidParameterError(
                f"Model with id {model_id} does not support ensemble_method parameter"
            )
        body["ensemble_method"] = ensemble_method
    response = session.post(endpoint, json=body)
    return GenerateJob.model_validate(response.json())