openprotein-python 0.8.2__1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openprotein/__init__.py +164 -0
- openprotein/_version.py +48 -0
- openprotein/align/__init__.py +8 -0
- openprotein/align/align.py +395 -0
- openprotein/align/api.py +428 -0
- openprotein/align/future.py +55 -0
- openprotein/align/msa.py +129 -0
- openprotein/align/schemas.py +165 -0
- openprotein/base.py +181 -0
- openprotein/chains.py +88 -0
- openprotein/common/__init__.py +5 -0
- openprotein/common/features.py +7 -0
- openprotein/common/model_metadata.py +33 -0
- openprotein/common/reduction.py +8 -0
- openprotein/config.py +9 -0
- openprotein/csv.py +31 -0
- openprotein/data/__init__.py +9 -0
- openprotein/data/api.py +218 -0
- openprotein/data/assaydataset.py +178 -0
- openprotein/data/data.py +93 -0
- openprotein/data/schemas.py +27 -0
- openprotein/design/__init__.py +16 -0
- openprotein/design/api.py +259 -0
- openprotein/design/design.py +125 -0
- openprotein/design/future.py +146 -0
- openprotein/design/schemas.py +607 -0
- openprotein/embeddings/__init__.py +27 -0
- openprotein/embeddings/api.py +619 -0
- openprotein/embeddings/embeddings.py +151 -0
- openprotein/embeddings/esm.py +33 -0
- openprotein/embeddings/future.py +146 -0
- openprotein/embeddings/models.py +421 -0
- openprotein/embeddings/openprotein.py +21 -0
- openprotein/embeddings/poet.py +446 -0
- openprotein/embeddings/poet2.py +505 -0
- openprotein/embeddings/schemas.py +78 -0
- openprotein/errors.py +76 -0
- openprotein/fasta.py +92 -0
- openprotein/fold/__init__.py +21 -0
- openprotein/fold/alphafold2.py +131 -0
- openprotein/fold/api.py +287 -0
- openprotein/fold/boltz.py +691 -0
- openprotein/fold/esmfold.py +54 -0
- openprotein/fold/fold.py +107 -0
- openprotein/fold/future.py +509 -0
- openprotein/fold/models.py +139 -0
- openprotein/fold/schemas.py +39 -0
- openprotein/jobs/__init__.py +9 -0
- openprotein/jobs/api.py +71 -0
- openprotein/jobs/futures.py +746 -0
- openprotein/jobs/jobs.py +69 -0
- openprotein/jobs/schemas.py +135 -0
- openprotein/models/__init__.py +4 -0
- openprotein/models/base.py +63 -0
- openprotein/models/foundation/rfdiffusion.py +283 -0
- openprotein/models/models.py +33 -0
- openprotein/predictor/__init__.py +25 -0
- openprotein/predictor/api.py +384 -0
- openprotein/predictor/models.py +374 -0
- openprotein/predictor/prediction.py +79 -0
- openprotein/predictor/predictor.py +242 -0
- openprotein/predictor/schemas.py +113 -0
- openprotein/predictor/validate.py +40 -0
- openprotein/prompt/__init__.py +9 -0
- openprotein/prompt/api.py +505 -0
- openprotein/prompt/models.py +142 -0
- openprotein/prompt/prompt.py +130 -0
- openprotein/prompt/schemas.py +49 -0
- openprotein/protein.py +587 -0
- openprotein/svd/__init__.py +9 -0
- openprotein/svd/api.py +206 -0
- openprotein/svd/models.py +288 -0
- openprotein/svd/schemas.py +31 -0
- openprotein/svd/svd.py +134 -0
- openprotein/umap/__init__.py +9 -0
- openprotein/umap/api.py +259 -0
- openprotein/umap/models.py +211 -0
- openprotein/umap/schemas.py +35 -0
- openprotein/umap/umap.py +175 -0
- openprotein/utils/uuid.py +29 -0
- openprotein_python-0.8.2.dist-info/METADATA +176 -0
- openprotein_python-0.8.2.dist-info/RECORD +84 -0
- openprotein_python-0.8.2.dist-info/WHEEL +4 -0
- openprotein_python-0.8.2.dist-info/licenses/LICENSE.txt +30 -0
|
@@ -0,0 +1,746 @@
|
|
|
1
|
+
"""Application futures for waiting for results from jobs."""
|
|
2
|
+
|
|
3
|
+
import concurrent.futures
|
|
4
|
+
import logging
|
|
5
|
+
import time
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from datetime import datetime
|
|
8
|
+
from types import UnionType
|
|
9
|
+
from typing import Collection, Generator
|
|
10
|
+
|
|
11
|
+
import tqdm
|
|
12
|
+
from requests import Response
|
|
13
|
+
from typing_extensions import Self
|
|
14
|
+
|
|
15
|
+
from openprotein import config
|
|
16
|
+
from openprotein.base import APISession
|
|
17
|
+
from openprotein.errors import TimeoutException
|
|
18
|
+
from openprotein.jobs.schemas import Job, JobStatus, JobType
|
|
19
|
+
|
|
20
|
+
from . import api
|
|
21
|
+
|
|
22
|
+
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class Future(ABC):
    """
    Base class for all Futures returning results from a job.

    A Future pairs an `APISession` with a `Job` and provides helpers to
    inspect the job, refresh it from the API, and block until it finishes.
    Subclasses implement `get` to fetch the actual results.
    """

    # NOTE: This base class should be directly inherited for class discovery by the factory `create` method.
    session: APISession
    job: Job

    def __init__(self, session: APISession, job: Job):
        """Initialize the future.

        Parameters
        ----------
        session : APISession
            Session for API interactions.
        job : Job
            The job this future tracks.

        """
        self.session = session
        self.job = job

    @classmethod
    def create(
        cls: type[Self],
        session: APISession,
        job_id: str | None = None,
        job: Job | None = None,
        response: Response | dict | None = None,
        **kwargs,
    ) -> Self:
        """Create an instance of the appropriate Future class based on the job type.

        Parameters
        ----------
        session : APISession
            Session for API interactions.
        job_id : str | None, optional
            The ID of the Job to initialize this future with.
            Takes precedence over `job` and `response` when given.
        job : Job | None, optional
            The Job object to initialize this future with.
        response : Response | dict | None, optional
            The response from a job request returning a job-like object.
        **kwargs
            Additional keyword arguments passed to `Job.create` and to the
            Future class constructor (or its own `create`).

        Returns
        -------
        Self
            An instance of the appropriate Future class.

        Raises
        ------
        ValueError
            If `job_id`, `job`, and `response` are all None.
        ValueError
            If an appropriate Future subclass cannot be found for the job type.

        :meta private:
        """
        # default to use job_id first
        if job_id is not None:
            job = api.job_get(session=session, job_id=job_id)

        # Compare against None explicitly rather than relying on truthiness:
        # an empty dict is falsy, and a `requests.Response` is falsy for
        # error statuses, so `job or response` would misreport a supplied
        # response as "nothing given".
        obj = job if job is not None else response
        if obj is None:
            raise ValueError("Expected job_id, job or response")

        # parse into the specific Job subclass
        job = Job.create(obj, **kwargs)
        job_cls = type(job)

        # Only *direct* subclasses of Future are discovered (see class NOTE).
        for future_class in Future.__subclasses__():
            # A Future declares the job type(s) it handles through the
            # annotation of its `job` attribute: a single class or a union.
            future_type = future_class.__annotations__.get("job")
            handles_job = job_cls == future_type or (
                isinstance(future_type, UnionType) and job_cls in future_type.__args__
            )
            if not handles_job:
                continue
            # Prefer the subclass's own `create` factory when it defines one.
            if isinstance(future_class.__dict__.get("create"), classmethod):
                future = future_class.create(session=session, job=job, **kwargs)
            else:
                future = future_class(session=session, job=job, **kwargs)
            return future  # type: ignore - needed since type checker doesnt know subclass

        raise ValueError(f"Unsupported job type: {job.job_type}")

    def __str__(self) -> str:
        return str(self.job)

    def __repr__(self):
        return repr(self.job)

    @property
    def id(self) -> str:
        """The unique identifier of the job."""
        return self.job.job_id

    # Alias so the identifier is reachable as either `id` or `job_id`.
    job_id = id

    @property
    def job_type(self) -> str:
        """The type of the job."""
        return self.job.job_type

    @property
    def status(self) -> JobStatus:
        """The current status of the job."""
        return self.job.status

    @property
    def created_date(self) -> datetime:
        """The creation timestamp of the job."""
        return self.job.created_date

    @property
    def start_date(self) -> datetime | None:
        """The start timestamp of the job."""
        return self.job.start_date

    @property
    def end_date(self) -> datetime | None:
        """The end timestamp of the job."""
        return self.job.end_date

    @property
    def progress_counter(self) -> int:
        """The progress counter of the job (0 when the job reports none)."""
        return self.job.progress_counter or 0

    def done(self) -> bool:
        """Check if the job has completed.

        Returns
        -------
        bool
            True if the job is done, False otherwise.

        """
        return self.status.done()

    def cancelled(self) -> bool:
        """Check if the job has been cancelled.

        Returns
        -------
        bool
            True if the job is cancelled, False otherwise.

        """
        return self.status.cancelled()

    def _update_progress(self, job: Job) -> int:
        """Update progress for jobs that may not have explicit counters.

        Parameters
        ----------
        job : Job
            The job object to update progress from.

        Returns
        -------
        int
            The job's own progress counter when it has one. Otherwise an
            estimate derived from the status: 5 for pending, 25 for running,
            100 for success or failure, and 0 for any other status.

        """
        progress = job.progress_counter
        if progress is None:
            if job.status == JobStatus.PENDING:
                progress = 5
            elif job.status == JobStatus.RUNNING:
                progress = 25
            elif job.status in [JobStatus.SUCCESS, JobStatus.FAILURE]:
                progress = 100
        return progress or 0  # never None

    def _refresh_job(self) -> Job:
        """Refresh and return the internal job object.

        The job held by this future is not modified; use `refresh` for that.

        Returns
        -------
        Job
            The refreshed job object.

        """
        # dump extra kwargs to keep on refresh
        kwargs = {
            k: v for k, v in self.job.model_dump().items() if k not in Job.model_fields
        }
        job = Job.create(
            api.job_get(session=self.session, job_id=self.job_id), **kwargs
        )
        return job

    def refresh(self):
        """Refresh the job status and internal job object."""
        self.job = self._refresh_job()

    @abstractmethod
    def get(self, verbose: bool = False, **kwargs):
        """
        Return the results from this job.

        Parameters
        ----------
        verbose : bool, optional
            Flag to enable verbose output, by default False.
        **kwargs
            Additional keyword arguments.
        """
        raise NotImplementedError()

    def _wait_job(
        self,
        interval: float = config.POLLING_INTERVAL,
        timeout: int | None = None,
        verbose: bool = False,
    ) -> Job:
        """Wait for a job to finish and return the final job object.

        Parameters
        ----------
        interval : float, optional
            Time in seconds to wait between polls.
            Defaults to `config.POLLING_INTERVAL`.
        timeout : int | None, optional
            Maximum time in seconds to wait before raising an error.
            Defaults to None (unlimited).
        verbose : bool, optional
            If True, show a progress bar with status updates.
            Defaults to False.

        Returns
        -------
        Job
            The completed job object.

        Raises
        ------
        TimeoutException
            If the wait time exceeds the specified timeout.

        """
        # Monotonic clock: the timeout must not be affected by wall-clock
        # adjustments made while we are polling.
        start_time = time.monotonic()

        def is_done(job: Job) -> bool:
            # The timeout is checked before the status, so an expired
            # timeout raises even if the job has just finished.
            if timeout is not None:
                elapsed_time = time.monotonic() - start_time
                if elapsed_time >= timeout:
                    raise TimeoutException(
                        f"Wait time exceeded timeout {timeout}, waited {elapsed_time}"
                    )
            return job.status.done()

        pbar = tqdm.tqdm(total=100, desc="Waiting", position=0) if verbose else None

        def report(job: Job) -> None:
            # Set the bar position absolutely since progress is not
            # reported as increments.
            if pbar is not None:
                pbar.n = self._update_progress(job)
                pbar.set_postfix({"status": job.status})

        try:
            job = self._refresh_job()
            while not is_done(job):
                report(job)
                time.sleep(interval)
                job = self._refresh_job()
            # show the terminal state
            report(job)
        finally:
            # Always release the bar, including on timeout or API errors.
            if pbar is not None:
                pbar.close()

        return job

    def wait_until_done(
        self,
        interval: float = config.POLLING_INTERVAL,
        timeout: int | None = None,
        verbose: bool = False,
    ):
        """Wait for the job to complete.

        Parameters
        ----------
        interval : float, optional
            Time in seconds between polling. Defaults to `config.POLLING_INTERVAL`.
        timeout : int, optional
            Maximum time in seconds to wait. Defaults to None.
        verbose : bool, optional
            Verbosity flag. Defaults to False.

        Returns
        -------
        bool
            The value of `done()` after waiting, i.e. whether the job has
            finished. This reflects completion as reported by the job status
            and is not a separate success check.

        Raises
        ------
        TimeoutException
            If the wait time exceeds the specified timeout.

        Notes
        -----
        This method does not fetch the job results, unlike `wait()`.

        """
        job = self._wait_job(interval=interval, timeout=timeout, verbose=verbose)
        self.job = job
        return self.done()

    def wait(
        self,
        interval: float = config.POLLING_INTERVAL,
        timeout: int | None = None,
        verbose: bool = False,
    ):
        """Wait for the job to complete, then fetch results.

        Parameters
        ----------
        interval : float, optional
            Time in seconds between polling. Defaults to `config.POLLING_INTERVAL`.
        timeout : int | None, optional
            Maximum time in seconds to wait. Defaults to None.
        verbose : bool, optional
            Verbosity flag for the waiting phase. Defaults to False.

        Returns
        -------
        Any
            The results of the job, as returned by `get()`.

        Raises
        ------
        TimeoutException
            If the wait time exceeds the specified timeout.

        """
        time.sleep(1)  # buffer for BE to register job
        job = self._wait_job(interval=interval, timeout=timeout, verbose=verbose)
        self.job = job
        return self.get()
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
class StreamingFuture(ABC):
    """Abstract base class for Futures whose results can be streamed."""

    @abstractmethod
    def stream(self, **kwargs) -> Generator:
        """Return the results from this job as a generator.

        Parameters
        ----------
        **kwargs
            Keyword arguments passed to the streaming implementation.

        Returns
        -------
        Generator
            A generator that yields job results.

        Raises
        ------
        NotImplementedError
            This is an abstract method and must be implemented by a subclass.

        """
        raise NotImplementedError()

    def get(self, verbose: bool = False, **kwargs) -> list:
        """Collect every streamed result into a list.

        Parameters
        ----------
        verbose : bool, optional
            If True, display a progress bar while retrieving.
            Defaults to False.
        **kwargs
            Keyword arguments passed to the `stream` method.

        Returns
        -------
        list
            A list containing all results from the job.

        """
        results = self.stream(**kwargs)
        if not verbose:
            return list(results)

        # The bar gets a total only when the future knows its own length.
        expected = len(self) if hasattr(self, "__len__") else None  # type: ignore - static type checker doesnt know
        progress = tqdm.tqdm(
            results, desc="Retrieving", total=expected, position=0, mininterval=1.0
        )
        return list(progress)
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
class MappedFuture(StreamingFuture, ABC):
    """Base future for jobs whose results are addressed by key.

    Each result is associated with a unique key (e.g., sequence to
    embedding). Retrieved values are kept in an in-memory cache so a key is
    fetched through `get_item` at most once per successful lookup.

    """

    def __init__(
        self,
        session: APISession,
        job: Job,
        max_workers: int = config.MAX_CONCURRENT_WORKERS,
    ):
        """Initialize the MappedFuture.

        Parameters
        ----------
        session : APISession
            The session for API interactions.
        job : Job
            The job to retrieve results from.
        max_workers : int, optional
            The number of workers for concurrent result retrieval.
            Defaults to `config.MAX_CONCURRENT_WORKERS`.

        Notes
        -----
        Use `max_workers` > 0 to enable concurrent retrieval.

        """
        self.session = session
        self.job = job
        self.max_workers = max_workers
        self._cache = {}

    @abstractmethod
    def __keys__(self):
        """Return the keys for the mapped results.

        Raises
        ------
        NotImplementedError
            This is an abstract method and must be implemented by a subclass.

        """
        raise NotImplementedError()

    @abstractmethod
    def get_item(self, k):
        """Retrieve a single item by its key.

        Parameters
        ----------
        k
            The key of the item to retrieve.

        Raises
        ------
        NotImplementedError
            This is an abstract method and must be implemented by a subclass.

        """
        raise NotImplementedError()

    def stream_sync(self):
        """Stream the results one key at a time.

        Yields
        ------
        tuple
            A tuple of (key, value) for each result.

        :meta private:
        """
        for key in self.__keys__():
            yield key, self[key]

    def stream_parallel(self):
        """Stream the results using a thread pool for uncached keys.

        Cached entries are yielded as soon as their key is reached; entries
        that need fetching are yielded afterwards, in submission order.

        Yields
        ------
        tuple
            A tuple of (key, value) for each result.

        :meta private:
        """

        def fetch(key):
            # goes through __getitem__ so the cache is populated
            return key, self[key]

        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self.max_workers
        ) as pool:
            pending = []
            for key in self.__keys__():
                if key in self._cache:
                    yield key, self._cache[key]
                    continue
                pending.append(pool.submit(fetch, key))

            for task in pending:
                yield task.result()

    def stream(self):
        """Retrieve results for this job as a stream.

        Returns
        -------
        Generator
            A generator that yields (key, value) tuples.

        """
        return self.stream_parallel() if self.max_workers > 0 else self.stream_sync()

    def __getitem__(self, k):
        """Get an item by key, using the cache if available.

        Parameters
        ----------
        k
            The key of the item to retrieve.

        Returns
        -------
        Any
            The value associated with the key.

        """
        try:
            return self._cache[k]
        except KeyError:
            pass
        # cache miss: fetch, remember, return
        value = self.get_item(k)
        self._cache[k] = value
        return value

    def __len__(self):
        """Return the total number of items."""
        keys = self.__keys__()
        return len(keys)

    def __iter__(self):
        """Return an iterator over the results."""
        return self.stream()
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
class PagedFuture(StreamingFuture, ABC):
    """Base future class for jobs whose results are fetched page by page."""

    DEFAULT_PAGE_SIZE = 1024

    def __init__(
        self,
        session: APISession,
        job: Job,
        page_size: int | None = None,
        num_records: int | None = None,
        max_workers: int = config.MAX_CONCURRENT_WORKERS,
    ):
        """Initialize the PagedFuture.

        Parameters
        ----------
        session : APISession
            The session for API interactions.
        job : Job
            The job to retrieve results from.
        page_size : int | None, optional
            The number of records per page. Defaults to `DEFAULT_PAGE_SIZE`.
        num_records : int | None, optional
            The total number of records expected.
        max_workers : int, optional
            Number of workers for concurrent page retrieval.
            Defaults to `config.MAX_CONCURRENT_WORKERS`.

        Notes
        -----
        Use `max_workers` > 0 to enable concurrent retrieval of multiple pages.

        """
        self.session = session
        self.job = job
        self.page_size = self.DEFAULT_PAGE_SIZE if page_size is None else page_size
        self.max_workers = max_workers
        self._num_records = num_records

    @abstractmethod
    def get_slice(self, start: int, end: int, **kwargs) -> Collection:
        """Retrieve a slice of results.

        Parameters
        ----------
        start : int
            The starting index of the slice.
        end : int
            The ending index of the slice.
        **kwargs
            Additional keyword arguments.

        Returns
        -------
        Collection
            A collection of results for the specified slice.

        Raises
        ------
        NotImplementedError
            This is an abstract method and must be implemented by a subclass.

        """
        raise NotImplementedError()

    def stream_sync(self):
        """Stream results by fetching one page after another.

        Fetching stops after the first page that holds fewer records than
        the page size.

        Yields
        ------
        Any
            Individual results from the paged endpoint.

        :meta private:
        """
        page_len = self.page_size
        position = 0
        while True:
            page = self.get_slice(start=position, end=position + page_len)
            yield from page
            fetched = len(page)
            position += fetched
            if fetched < page_len:
                break

    def stream_parallel(self):
        """Stream results by fetching pages in parallel.

        Pages are requested in rounds. Within a round, pages are yielded in
        offset order even though they may complete out of order.

        Yields
        ------
        Any
            Individual results from the paged endpoint.

        Notes
        -----
        The number of results should be checked, or stored somehow, so that
        we don't need to check the number of returned entries to see if we're
        finished (very awkward when using concurrency).

        :meta private:
        """
        page_len = self.page_size
        worker_count = self.max_workers
        next_offset = 0

        with concurrent.futures.ThreadPoolExecutor(
            max_workers=worker_count
        ) as pool:
            # first round: two requests per worker, keyed by page position
            pending: dict[concurrent.futures.Future, int] = {}
            for slot in range(worker_count * 2):
                request = pool.submit(
                    self.get_slice, next_offset, next_offset + page_len
                )
                pending[request] = slot
                next_offset += page_len

            # a page shorter than requested marks the end of the data
            finished = False
            while not finished:
                pages: list[list | None] = [None] * len(pending)
                upcoming: dict[concurrent.futures.Future, int] = {}
                cursor = 0
                for completed in concurrent.futures.as_completed(pending):
                    page = completed.result()
                    pages[pending[completed]] = page
                    if not finished and len(page) < page_len:
                        finished = True
                    # keep the pool busy with the next round's requests
                    if not finished:
                        request = pool.submit(
                            self.get_slice, next_offset, next_offset + page_len
                        )
                        upcoming[request] = len(upcoming)
                        next_offset += page_len
                    # emit every page that is now contiguous with those
                    # already emitted in this round
                    while cursor < len(pages) and pages[cursor] is not None:
                        ready = pages[cursor]
                        assert ready is not None  # checked above
                        yield from ready
                        cursor += 1
                # the requests submitted above form the next round
                pending = upcoming

    def stream(self):
        """Retrieve results for this job as a stream.

        Returns
        -------
        Generator
            A generator that yields job results.

        """
        return self.stream_parallel() if self.max_workers > 0 else self.stream_sync()
|
|
727
|
+
|
|
728
|
+
|
|
729
|
+
class InvalidFutureError(Exception):
    """Error for when an unexpected future is created from a job."""

    def __init__(self, future: "Future", expected: "type[Future]"):
        """Initialize the InvalidFutureError.

        Parameters
        ----------
        future : Future
            The future instance that was created.
        expected : type[Future]
            The type of future that was expected.

        """
        self.future = future
        # Store the expected *type*; this was previously assigned `future`,
        # so `expected` and `future` pointed at the same object.
        self.expected = expected
        self.message = f"Expected future of type {expected}, got {type(future)}"
        super().__init__(self.message)
|