mlxsmith 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxsmith/__init__.py +2 -0
- mlxsmith/accel/__init__.py +10 -0
- mlxsmith/accel/base.py +17 -0
- mlxsmith/accel/none.py +13 -0
- mlxsmith/accel/zmlx_backend.py +42 -0
- mlxsmith/adapters.py +46 -0
- mlxsmith/api/__init__.py +48 -0
- mlxsmith/api/handlers.py +1217 -0
- mlxsmith/api/schemas.py +436 -0
- mlxsmith/auth.py +88 -0
- mlxsmith/bench.py +102 -0
- mlxsmith/cli.py +950 -0
- mlxsmith/config.py +543 -0
- mlxsmith/config_models.py +261 -0
- mlxsmith/data.py +493 -0
- mlxsmith/envs/__init__.py +33 -0
- mlxsmith/envs/system.py +388 -0
- mlxsmith/envs/token_env.py +191 -0
- mlxsmith/eval.py +112 -0
- mlxsmith/infer.py +140 -0
- mlxsmith/llm/__init__.py +16 -0
- mlxsmith/llm/backend.py +126 -0
- mlxsmith/llm/interface.py +212 -0
- mlxsmith/llm/mlx_lm_backend.py +509 -0
- mlxsmith/llm/mock_backend.py +228 -0
- mlxsmith/llm/registry.py +12 -0
- mlxsmith/models.py +257 -0
- mlxsmith/orchestrator/__init__.py +25 -0
- mlxsmith/orchestrator/daemon.py +454 -0
- mlxsmith/orchestrator/inference_worker.py +496 -0
- mlxsmith/orchestrator/queue.py +355 -0
- mlxsmith/orchestrator/trainer_worker.py +437 -0
- mlxsmith/rlm/__init__.py +8 -0
- mlxsmith/rlm/corpus.py +74 -0
- mlxsmith/rlm/gating.py +90 -0
- mlxsmith/rlm/generate.py +249 -0
- mlxsmith/rlm/history.py +12 -0
- mlxsmith/rlm/inference.py +150 -0
- mlxsmith/rlm/loop.py +1297 -0
- mlxsmith/rlm/mutate.py +82 -0
- mlxsmith/rlm/trainer.py +73 -0
- mlxsmith/rlm/weights.py +263 -0
- mlxsmith/runs.py +44 -0
- mlxsmith/sdk/__init__.py +392 -0
- mlxsmith/sdk/future.py +486 -0
- mlxsmith/sdk/losses.py +262 -0
- mlxsmith/sdk/sampling_client.py +729 -0
- mlxsmith/sdk/training_client.py +676 -0
- mlxsmith/server.py +376 -0
- mlxsmith/train/__init__.py +0 -0
- mlxsmith/train/distill.py +279 -0
- mlxsmith/train/lora.py +280 -0
- mlxsmith/train/pref.py +180 -0
- mlxsmith/train/rft.py +458 -0
- mlxsmith/train/sft.py +151 -0
- mlxsmith/util.py +174 -0
- mlxsmith/verifiers/__init__.py +3 -0
- mlxsmith/verifiers/compose.py +109 -0
- mlxsmith/verifiers/docker_verifier.py +111 -0
- mlxsmith/verifiers/jsonschema.py +54 -0
- mlxsmith/verifiers/pytest_verifier.py +82 -0
- mlxsmith/verifiers/regex.py +15 -0
- mlxsmith/verifiers/types.py +10 -0
- mlxsmith-0.1.0.dist-info/METADATA +163 -0
- mlxsmith-0.1.0.dist-info/RECORD +69 -0
- mlxsmith-0.1.0.dist-info/WHEEL +5 -0
- mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
- mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
- mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/sdk/future.py
ADDED
|
@@ -0,0 +1,486 @@
|
|
|
1
|
+
"""Enhanced futures API for MLXSmith SDK.
|
|
2
|
+
|
|
3
|
+
Provides thread-safe futures with callbacks, timeout handling, cancellation,
|
|
4
|
+
and progress tracking for async operations.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import logging
import threading
import time
import uuid
from concurrent.futures import CancelledError, Future, ThreadPoolExecutor
from typing import Any, Callable, Generic, Iterable, Optional, TypeVar, Union

from ..llm.backend import DecodingConfig
|
|
15
|
+
|
|
16
|
+
# Type variable for the value an APIFuture resolves to (see APIFuture(Generic[T])).
T = TypeVar("T")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class APIFutureState:
    """Lifecycle states of an ``APIFuture``, exposed as plain string constants.

    ``COMPLETED``, ``CANCELLED`` and ``FAILED`` are terminal states.
    ``PENDING`` is the initial state; ``RUNNING`` is defined but is not
    assigned anywhere in this module.
    """

    PENDING: str = "pending"  # created, not yet resolved
    RUNNING: str = "running"  # reserved; not set by APIFuture in this module
    COMPLETED: str = "completed"  # finished with a result
    CANCELLED: str = "cancelled"  # cancelled before producing a result
    FAILED: str = "failed"  # finished with an exception
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class APIFuture(Generic[T]):
    """Enhanced future with callbacks, timeout, cancellation, and progress tracking.

    This class wraps a standard ``concurrent.futures.Future``, or stands alone
    when constructed without one, and adds:

    - Promise-style callbacks (``.then()``, ``.catch()``, ``.finally_()``)
    - Timeout handling
    - Cancellation support
    - Progress tracking for long operations
    - Lock-protected state management

    Callbacks registered after the future has resolved are invoked immediately
    in the registering thread. Exceptions raised by callbacks are logged and
    never propagate.

    Example:
        >>> future = client.forward_backward(batch)
        >>> (future.then(lambda result: print(f"Loss: {result.loss}"))
        ...     .catch(lambda e: print(f"Error: {e}"))
        ...     .finally_(lambda: print("Done")))
        >>>
        >>> # With timeout
        >>> result = future.result(timeout=30.0)
        >>>
        >>> # Check progress
        >>> print(f"Progress: {future.progress}%")
    """

    # States from which the future can no longer change.
    _TERMINAL_STATES = (
        APIFutureState.COMPLETED,
        APIFutureState.FAILED,
        APIFutureState.CANCELLED,
    )

    def __init__(self, future: Optional[Future] = None, operation_id: Optional[str] = None):
        """Initialize APIFuture.

        Args:
            future: The underlying concurrent.futures.Future (optional).
            operation_id: Unique identifier for this operation; defaults to
                ``"op-<id(self)>"``.
        """
        self._future = future
        self._operation_id = operation_id or f"op-{id(self)}"
        self._state = APIFutureState.PENDING
        self._progress: float = 0.0
        self._progress_message: str = ""
        self._result: Optional[T] = None
        self._exception: Optional[BaseException] = None

        # Callbacks waiting for the future to resolve.
        self._success_callbacks: list[Callable[[T], Any]] = []
        self._error_callbacks: list[Callable[[BaseException], Any]] = []
        self._finally_callbacks: list[Callable[[], Any]] = []
        self._progress_callbacks: list[Callable[[float, str], Any]] = []

        # Re-entrant: callbacks run while the lock is held and may call back
        # into this object; Future.cancel() also re-enters via _on_future_done.
        self._lock = threading.RLock()
        self._done_event = threading.Event()

        # Must come last: if ``future`` is already done, add_done_callback
        # invokes _on_future_done immediately, which needs the lock and event.
        if future is not None:
            future.add_done_callback(self._on_future_done)

    def _on_future_done(self, future: Future) -> None:
        """Record the outcome of the wrapped future and fire callbacks.

        Registered via ``add_done_callback``, so it is invoked once, when the
        underlying future finishes or is cancelled. This is the single place
        where a wrapped future is finalized.
        """
        with self._lock:
            try:
                if future.cancelled():
                    self._state = APIFutureState.CANCELLED
                    self._run_finally_callbacks()
                    return
                # The future is already done, so exception() does not block.
                exc = future.exception()
                if exc is not None:
                    self._exception = exc
                    self._state = APIFutureState.FAILED
                    self._run_error_callbacks(exc)
                else:
                    self._result = future.result()
                    self._state = APIFutureState.COMPLETED
                    self._progress = 100.0
                    self._run_success_callbacks(self._result)
                self._run_finally_callbacks()
            finally:
                # Always release waiters, even if a callback raised a
                # BaseException that _safe_call does not intercept.
                self._done_event.set()

    @staticmethod
    def _safe_call(callback: Callable[..., Any], *args: Any) -> None:
        """Invoke ``callback(*args)``; log, rather than propagate, any Exception."""
        try:
            callback(*args)
        except Exception:
            logging.getLogger(__name__).exception("APIFuture callback %r raised", callback)

    def _run_success_callbacks(self, result: T) -> None:
        """Execute success callbacks (errors are logged, not propagated)."""
        # Iterate a snapshot so callbacks that register callbacks are safe.
        for callback in tuple(self._success_callbacks):
            self._safe_call(callback, result)

    def _run_error_callbacks(self, exception: BaseException) -> None:
        """Execute error callbacks (errors are logged, not propagated)."""
        for callback in tuple(self._error_callbacks):
            self._safe_call(callback, exception)

    def _run_finally_callbacks(self) -> None:
        """Execute finally callbacks (errors are logged, not propagated)."""
        for callback in tuple(self._finally_callbacks):
            self._safe_call(callback)

    def _run_progress_callbacks(self, progress: float, message: str) -> None:
        """Execute progress callbacks (errors are logged, not propagated)."""
        for callback in tuple(self._progress_callbacks):
            self._safe_call(callback, progress, message)

    # ========================================================================
    # Promise-style callbacks
    # ========================================================================

    def then(self, callback: Callable[[T], Any]) -> APIFuture[T]:
        """Register a success callback.

        Args:
            callback: Function to call with the result when successful. Runs
                immediately if the future has already completed.

        Returns:
            Self for chaining
        """
        with self._lock:
            # Check the state only: a legitimate ``None`` result must still
            # trigger the callback rather than queue it forever.
            if self._state == APIFutureState.COMPLETED:
                self._safe_call(callback, self._result)
            else:
                self._success_callbacks.append(callback)
        return self

    def catch(self, callback: Callable[[BaseException], Any]) -> APIFuture[T]:
        """Register an error callback.

        Args:
            callback: Function to call with the exception when failed. Runs
                immediately if the future has already failed.

        Returns:
            Self for chaining
        """
        with self._lock:
            if self._state == APIFutureState.FAILED:
                self._safe_call(callback, self._exception)
            else:
                self._error_callbacks.append(callback)
        return self

    def finally_(self, callback: Callable[[], Any]) -> APIFuture[T]:
        """Register a callback that runs on completion (success, failure or cancellation).

        Args:
            callback: Function to call when the future resolves. Runs
                immediately if the future is already done.

        Returns:
            Self for chaining
        """
        with self._lock:
            if self._state in self._TERMINAL_STATES:
                self._safe_call(callback)
            else:
                self._finally_callbacks.append(callback)
        return self

    def on_progress(self, callback: Callable[[float, str], Any]) -> APIFuture[T]:
        """Register a progress callback.

        Args:
            callback: Function to call with ``(progress_percent, message)``.
                If some progress was already reported, it is called once
                immediately with the current values.

        Returns:
            Self for chaining
        """
        with self._lock:
            self._progress_callbacks.append(callback)
            if self._progress > 0:
                self._safe_call(callback, self._progress, self._progress_message)
        return self

    # ========================================================================
    # Progress tracking
    # ========================================================================

    def update_progress(self, progress: float, message: str = "") -> None:
        """Update progress (called by the operation).

        Args:
            progress: Progress percentage; clamped to the range 0-100.
            message: Optional progress message.
        """
        with self._lock:
            self._progress = max(0.0, min(100.0, progress))
            self._progress_message = message
            self._run_progress_callbacks(self._progress, message)

    @property
    def progress(self) -> float:
        """Get current progress percentage (0-100)."""
        with self._lock:
            return self._progress

    @property
    def progress_message(self) -> str:
        """Get current progress message."""
        with self._lock:
            return self._progress_message

    # ========================================================================
    # State and result access
    # ========================================================================

    @property
    def state(self) -> str:
        """Get current state string (one of the ``APIFutureState`` constants)."""
        with self._lock:
            return self._state

    @property
    def operation_id(self) -> str:
        """Get operation identifier."""
        return self._operation_id

    @property
    def done(self) -> bool:
        """Check if future is done (completed, failed, or cancelled)."""
        with self._lock:
            return self._state in self._TERMINAL_STATES

    @property
    def failed(self) -> bool:
        """Check if future failed with an exception."""
        with self._lock:
            return self._state == APIFutureState.FAILED

    def result(self, timeout: Optional[float] = None) -> T:
        """Get the result, blocking if necessary.

        Args:
            timeout: Maximum time to wait in seconds (``None`` waits forever).

        Returns:
            The result value

        Raises:
            TimeoutError: If timeout expires
            concurrent.futures.CancelledError: If future was cancelled
            Exception: If the operation failed (the original exception is re-raised)
        """
        if self._future is not None:
            return self._future.result(timeout=timeout)

        # Standalone future: wait on our own done event.
        if not self._done_event.wait(timeout=timeout):
            raise TimeoutError(f"Operation {self._operation_id} timed out after {timeout}s")

        with self._lock:
            if self._state == APIFutureState.CANCELLED:
                # Same exception type the wrapped path raises.
                raise CancelledError(f"Operation {self._operation_id} was cancelled")
            if self._state == APIFutureState.FAILED and self._exception is not None:
                raise self._exception
            return self._result  # type: ignore

    def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]:
        """Get the exception the operation failed with, waiting if necessary.

        Args:
            timeout: Maximum time to wait in seconds (``None`` waits forever).

        Returns:
            The operation's exception, ``None`` if it succeeded, or a
            ``CancelledError`` instance if it was cancelled.

        Raises:
            TimeoutError: If the operation has not finished within ``timeout``.
        """
        if self._future is not None:
            try:
                # Raises TimeoutError on timeout instead of reporting it as a
                # failure of the (still running) operation.
                return self._future.exception(timeout=timeout)
            except CancelledError as e:
                # Cancellation is reported as a value, not raised.
                return e

        if not self._done_event.wait(timeout=timeout):
            raise TimeoutError(f"Operation {self._operation_id} timed out after {timeout}s")

        with self._lock:
            if self._state == APIFutureState.CANCELLED:
                return CancelledError(f"Operation {self._operation_id} was cancelled")
            return self._exception

    # ========================================================================
    # Cancellation
    # ========================================================================

    def cancel(self) -> bool:
        """Attempt to cancel the future.

        Returns:
            True if this call cancelled the future; False if it had already
            resolved or the underlying future is running or finished.
        """
        with self._lock:
            if self._state in self._TERMINAL_STATES:
                return False

            if self._future is not None:
                previous = self._state
                self._state = APIFutureState.CANCELLED
                if not self._future.cancel():
                    # Could not cancel (already running/finished): revert.
                    self._state = previous
                    return False
                # A successful Future.cancel() fires _on_future_done, which sets
                # the done event and runs finally callbacks. Doing that here as
                # well would run every finally callback twice.
                return True

            # Standalone future: finalize directly.
            self._state = APIFutureState.CANCELLED
            self._done_event.set()
            self._run_finally_callbacks()
            return True

    def cancelled(self) -> bool:
        """Check if the future was cancelled.

        A method (like ``concurrent.futures.Future.cancelled``), whereas
        ``done`` and ``failed`` are properties.
        """
        with self._lock:
            return self._state == APIFutureState.CANCELLED

    # ========================================================================
    # Async/await support
    # ========================================================================

    def __await__(self):
        """Support for async/await syntax.

        Example:
            >>> result = await future
        """
        return self.as_coroutine().__await__()

    def as_coroutine(self):
        """Return as an asyncio coroutine.

        The blocking ``result()`` call runs in the event loop's default
        executor, so awaiting does not block the loop.

        Returns:
            A coroutine that resolves to the future's result
        """
        import asyncio

        async def _coro():
            # get_running_loop() is the correct call from inside a coroutine.
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, self.result)

        return _coro()
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
class SdkFuturePool:
    """Thread-pool wrapper for async SDK calls with enhanced futures.

    Each ``submit_*`` method schedules work on an internal
    ``ThreadPoolExecutor`` and returns an ``APIFuture`` wrapping the raw
    executor future. The pool is also a context manager; leaving the ``with``
    block calls ``shutdown(wait=True)``.

    Example:
        >>> with SdkFuturePool(max_workers=2) as pool:
        ...     pool.submit(sum, [1, 2, 3]).result()
        6
    """

    def __init__(self, max_workers: int = 4):
        """Create the pool.

        Args:
            max_workers: Maximum number of worker threads.
        """
        self._executor = ThreadPoolExecutor(max_workers=max_workers)

    def __enter__(self) -> "SdkFuturePool":
        return self

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        # Wait for in-flight work so submitted tasks are not abandoned.
        self.shutdown(wait=True)

    @staticmethod
    def _wrap(future: Future, prefix: str) -> APIFuture:
        """Wrap an executor future in an APIFuture with id ``"<prefix>-<id(future)>"``."""
        return APIFuture(future=future, operation_id=f"{prefix}-{id(future)}")

    def submit_sample(self, backend: Any, prompts: Iterable[str], decoding: DecodingConfig) -> APIFuture:
        """Run the package-level ``sample(backend, prompts, decoding)`` on the pool."""
        from . import sample

        return self._wrap(self._executor.submit(sample, backend, prompts, decoding), "sample")

    def submit_forward_backward(self, backend: Any, loss_fn: Any) -> APIFuture:
        """Run the package-level ``forward_backward(backend, loss_fn)`` on the pool."""
        from . import forward_backward

        return self._wrap(self._executor.submit(forward_backward, backend, loss_fn), "fb")

    def submit_optim_step(self, backend: Any, optimizer: Any, grads: Any) -> APIFuture:
        """Run the package-level ``optim_step(backend, optimizer, grads)`` on the pool."""
        from . import optim_step

        return self._wrap(self._executor.submit(optim_step, backend, optimizer, grads), "optim")

    def submit_create_optimizer(self, backend: Any, *, lr: float, weight_decay: float = 0.0) -> APIFuture:
        """Run the package-level ``create_optimizer`` on the pool."""
        from . import create_optimizer

        future = self._executor.submit(create_optimizer, backend, lr=lr, weight_decay=weight_decay)
        return self._wrap(future, "opt-create")

    def submit(self, fn: Callable[..., T], *args, **kwargs) -> APIFuture[T]:
        """Submit a generic function to the pool.

        Args:
            fn: Function to execute
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            APIFuture wrapping the submitted task
        """
        return self._wrap(self._executor.submit(fn, *args, **kwargs), "task")

    def shutdown(self, wait: bool = True) -> None:
        """Shutdown the thread pool.

        Args:
            wait: Whether to wait for pending tasks to complete
        """
        self._executor.shutdown(wait=wait)
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
# Convenience function for creating completed futures
|
|
439
|
+
def completed_future(result: T, operation_id: Optional[str] = None) -> APIFuture[T]:
    """Create an already-completed future with a result.

    Args:
        result: The result value (``None`` is allowed).
        operation_id: Optional operation identifier; defaults to
            ``"completed-<random hex>"``.

    Returns:
        Completed APIFuture with progress at 100%.
    """
    # uuid4 rather than id(result): shared objects such as None or small ints
    # have one id, which made default operation ids collide.
    future = APIFuture[T](operation_id=operation_id or f"completed-{uuid.uuid4().hex}")
    future._state = APIFutureState.COMPLETED
    future._result = result
    future._progress = 100.0
    future._done_event.set()
    return future
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
def failed_future(exception: BaseException, operation_id: Optional[str] = None) -> APIFuture[Any]:
    """Create an already-failed future with an exception.

    Args:
        exception: The exception that ``result()`` re-raises and
            ``exception()`` returns.
        operation_id: Optional operation identifier; defaults to
            ``"failed-<random hex>"``.

    Returns:
        Failed APIFuture
    """
    # uuid4 rather than id(exception): reusing one exception instance for
    # several futures would otherwise give them identical ids.
    future = APIFuture[Any](operation_id=operation_id or f"failed-{uuid.uuid4().hex}")
    future._state = APIFutureState.FAILED
    future._exception = exception
    future._done_event.set()
    return future
|
|
472
|
+
|
|
473
|
+
|
|
474
|
+
def cancelled_future(operation_id: Optional[str] = None) -> APIFuture[Any]:
    """Create an already-cancelled future.

    Args:
        operation_id: Optional operation identifier; defaults to
            ``"cancelled-<random hex>"``.

    Returns:
        Cancelled APIFuture
    """
    # uuid4 rather than id(object()): the temporary object is freed at once and
    # its address is typically reused, so every default id was the same.
    future = APIFuture[Any](operation_id=operation_id or f"cancelled-{uuid.uuid4().hex}")
    future._state = APIFutureState.CANCELLED
    future._done_event.set()
    return future
|