mlxsmith 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69) hide show
  1. mlxsmith/__init__.py +2 -0
  2. mlxsmith/accel/__init__.py +10 -0
  3. mlxsmith/accel/base.py +17 -0
  4. mlxsmith/accel/none.py +13 -0
  5. mlxsmith/accel/zmlx_backend.py +42 -0
  6. mlxsmith/adapters.py +46 -0
  7. mlxsmith/api/__init__.py +48 -0
  8. mlxsmith/api/handlers.py +1217 -0
  9. mlxsmith/api/schemas.py +436 -0
  10. mlxsmith/auth.py +88 -0
  11. mlxsmith/bench.py +102 -0
  12. mlxsmith/cli.py +950 -0
  13. mlxsmith/config.py +543 -0
  14. mlxsmith/config_models.py +261 -0
  15. mlxsmith/data.py +493 -0
  16. mlxsmith/envs/__init__.py +33 -0
  17. mlxsmith/envs/system.py +388 -0
  18. mlxsmith/envs/token_env.py +191 -0
  19. mlxsmith/eval.py +112 -0
  20. mlxsmith/infer.py +140 -0
  21. mlxsmith/llm/__init__.py +16 -0
  22. mlxsmith/llm/backend.py +126 -0
  23. mlxsmith/llm/interface.py +212 -0
  24. mlxsmith/llm/mlx_lm_backend.py +509 -0
  25. mlxsmith/llm/mock_backend.py +228 -0
  26. mlxsmith/llm/registry.py +12 -0
  27. mlxsmith/models.py +257 -0
  28. mlxsmith/orchestrator/__init__.py +25 -0
  29. mlxsmith/orchestrator/daemon.py +454 -0
  30. mlxsmith/orchestrator/inference_worker.py +496 -0
  31. mlxsmith/orchestrator/queue.py +355 -0
  32. mlxsmith/orchestrator/trainer_worker.py +437 -0
  33. mlxsmith/rlm/__init__.py +8 -0
  34. mlxsmith/rlm/corpus.py +74 -0
  35. mlxsmith/rlm/gating.py +90 -0
  36. mlxsmith/rlm/generate.py +249 -0
  37. mlxsmith/rlm/history.py +12 -0
  38. mlxsmith/rlm/inference.py +150 -0
  39. mlxsmith/rlm/loop.py +1297 -0
  40. mlxsmith/rlm/mutate.py +82 -0
  41. mlxsmith/rlm/trainer.py +73 -0
  42. mlxsmith/rlm/weights.py +263 -0
  43. mlxsmith/runs.py +44 -0
  44. mlxsmith/sdk/__init__.py +392 -0
  45. mlxsmith/sdk/future.py +486 -0
  46. mlxsmith/sdk/losses.py +262 -0
  47. mlxsmith/sdk/sampling_client.py +729 -0
  48. mlxsmith/sdk/training_client.py +676 -0
  49. mlxsmith/server.py +376 -0
  50. mlxsmith/train/__init__.py +0 -0
  51. mlxsmith/train/distill.py +279 -0
  52. mlxsmith/train/lora.py +280 -0
  53. mlxsmith/train/pref.py +180 -0
  54. mlxsmith/train/rft.py +458 -0
  55. mlxsmith/train/sft.py +151 -0
  56. mlxsmith/util.py +174 -0
  57. mlxsmith/verifiers/__init__.py +3 -0
  58. mlxsmith/verifiers/compose.py +109 -0
  59. mlxsmith/verifiers/docker_verifier.py +111 -0
  60. mlxsmith/verifiers/jsonschema.py +54 -0
  61. mlxsmith/verifiers/pytest_verifier.py +82 -0
  62. mlxsmith/verifiers/regex.py +15 -0
  63. mlxsmith/verifiers/types.py +10 -0
  64. mlxsmith-0.1.0.dist-info/METADATA +163 -0
  65. mlxsmith-0.1.0.dist-info/RECORD +69 -0
  66. mlxsmith-0.1.0.dist-info/WHEEL +5 -0
  67. mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
  68. mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/sdk/future.py ADDED
@@ -0,0 +1,486 @@
1
+ """Enhanced futures API for MLXSmith SDK.
2
+
3
+ Provides thread-safe futures with callbacks, timeout handling, cancellation,
4
+ and progress tracking for async operations.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import threading
10
+ import time
11
+ from concurrent.futures import Future, ThreadPoolExecutor
12
+ from typing import Any, Callable, Generic, Iterable, Optional, TypeVar, Union
13
+
14
+ from ..llm.backend import DecodingConfig
15
+
16
+ T = TypeVar("T")
17
+
18
+
19
class APIFutureState:
    """String constants naming each lifecycle stage of an ``APIFuture``.

    This is a plain namespace of ``str`` values rather than an ``enum.Enum``,
    so the members compare equal to (and can be used as) ordinary strings.
    """

    # Stages in which the outcome is not yet known.
    PENDING = "pending"
    RUNNING = "running"

    # Stages in which the outcome is final.
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
26
+
27
+
28
class APIFuture(Generic[T]):
    """Future with promise-style callbacks, progress tracking and cancellation.

    An instance either wraps a ``concurrent.futures.Future`` (its outcome is
    mirrored through a done-callback) or stands alone (its outcome fields are
    set directly, as the module-level ``*_future`` factory helpers do).

    Features added on top of the wrapped future:
    - Promise-style callbacks (``then()``, ``catch()``, ``finally_()``)
    - Blocking ``result()`` with an optional timeout
    - Cancellation
    - Progress reporting (``update_progress()`` / ``on_progress()``)
    - State changes guarded by a re-entrant lock

    Example:
        >>> future = client.forward_backward(batch)
        >>> (future
        ...     .then(lambda result: print(f"Loss: {result.loss}"))
        ...     .catch(lambda e: print(f"Error: {e}"))
        ...     .finally_(lambda: print("Done")))
        >>>
        >>> # With timeout
        >>> result = future.result(timeout=30.0)
        >>>
        >>> # Check progress
        >>> print(f"Progress: {future.progress}%")
    """

    def __init__(self, future: Optional[Future] = None, operation_id: Optional[str] = None):
        """Initialize APIFuture.

        Args:
            future: The underlying concurrent.futures.Future (optional)
            operation_id: Unique identifier for this operation; defaults to
                ``"op-<id(self)>"`` when omitted or empty
        """
        self._future = future
        self._operation_id = operation_id or f"op-{id(self)}"
        self._state = APIFutureState.PENDING
        self._progress: float = 0.0
        self._progress_message: str = ""
        self._result: Optional[T] = None
        self._exception: Optional[BaseException] = None

        # Callbacks
        self._success_callbacks: list[Callable[[T], Any]] = []
        self._error_callbacks: list[Callable[[BaseException], Any]] = []
        self._finally_callbacks: list[Callable[[], Any]] = []
        self._progress_callbacks: list[Callable[[float, str], Any]] = []

        # Thread safety. The lock is re-entrant because Future.cancel() fires
        # _on_future_done synchronously while cancel() already holds it.
        self._lock = threading.RLock()
        self._done_event = threading.Event()

        # Registered last: if the wrapped future is already done, the callback
        # runs immediately and needs every attribute above to exist.
        if future is not None:
            future.add_done_callback(self._on_future_done)

    # ========================================================================
    # Internal helpers
    # ========================================================================

    @staticmethod
    def _safe_call(callback: Callable[..., Any], *args: Any) -> None:
        """Invoke a user callback; log, but never propagate, its errors."""
        try:
            callback(*args)
        except Exception:
            # Function-scope import keeps this block self-contained.
            import logging

            logging.getLogger(__name__).exception(
                "APIFuture callback %r raised", callback
            )

    def _is_terminal(self) -> bool:
        """Return True once completed, failed, or cancelled.

        Callers must hold ``self._lock``.
        """
        return self._state in (
            APIFutureState.COMPLETED,
            APIFutureState.FAILED,
            APIFutureState.CANCELLED,
        )

    def _on_future_done(self, future: Future) -> None:
        """Mirror the wrapped future's outcome and fire the matching callbacks."""
        with self._lock:
            try:
                if future.cancelled():
                    self._state = APIFutureState.CANCELLED
                else:
                    # Only the retrieval is guarded, so a misbehaving callback
                    # can never turn a successful outcome into a failure.
                    try:
                        outcome = future.result()
                    except BaseException as exc:
                        self._exception = exc
                        self._state = APIFutureState.FAILED
                        self._run_error_callbacks(exc)
                    else:
                        self._result = outcome
                        self._state = APIFutureState.COMPLETED
                        self._progress = 100.0
                        self._run_success_callbacks(outcome)
                self._run_finally_callbacks()
            finally:
                # Always release waiters, even if something above blew up.
                self._done_event.set()

    def _run_success_callbacks(self, result: T) -> None:
        """Execute success callbacks."""
        for callback in tuple(self._success_callbacks):
            self._safe_call(callback, result)

    def _run_error_callbacks(self, exception: BaseException) -> None:
        """Execute error callbacks."""
        for callback in tuple(self._error_callbacks):
            self._safe_call(callback, exception)

    def _run_finally_callbacks(self) -> None:
        """Execute finally callbacks."""
        for callback in tuple(self._finally_callbacks):
            self._safe_call(callback)

    def _run_progress_callbacks(self, progress: float, message: str) -> None:
        """Execute progress callbacks."""
        for callback in tuple(self._progress_callbacks):
            self._safe_call(callback, progress, message)

    # ========================================================================
    # Promise-style callbacks
    # ========================================================================

    def then(self, callback: Callable[[T], Any]) -> APIFuture[T]:
        """Register a success callback.

        If the future has already completed, the callback runs immediately,
        including when the result is ``None``.

        Args:
            callback: Function to call with the result when successful

        Returns:
            Self for chaining
        """
        with self._lock:
            if self._state == APIFutureState.COMPLETED:
                self._safe_call(callback, self._result)
            else:
                self._success_callbacks.append(callback)
        return self

    def catch(self, callback: Callable[[BaseException], Any]) -> APIFuture[T]:
        """Register an error callback.

        If the future has already failed, the callback runs immediately.

        Args:
            callback: Function to call with the exception when failed

        Returns:
            Self for chaining
        """
        with self._lock:
            if self._state == APIFutureState.FAILED and self._exception is not None:
                self._safe_call(callback, self._exception)
            else:
                self._error_callbacks.append(callback)
        return self

    def finally_(self, callback: Callable[[], Any]) -> APIFuture[T]:
        """Register a callback that runs on completion, failure, or cancellation.

        If the future is already done, the callback runs immediately.

        Args:
            callback: Function to call when future completes

        Returns:
            Self for chaining
        """
        with self._lock:
            if self._is_terminal():
                self._safe_call(callback)
            else:
                self._finally_callbacks.append(callback)
        return self

    def on_progress(self, callback: Callable[[float, str], Any]) -> APIFuture[T]:
        """Register a progress callback.

        The callback is also invoked right away with the current values when
        some progress (> 0) has already been reported.

        Args:
            callback: Function to call with (progress_percent, message)

        Returns:
            Self for chaining
        """
        with self._lock:
            self._progress_callbacks.append(callback)
            if self._progress > 0:
                self._safe_call(callback, self._progress, self._progress_message)
        return self

    # ========================================================================
    # Progress tracking
    # ========================================================================

    def update_progress(self, progress: float, message: str = "") -> None:
        """Update progress (called by the operation).

        Args:
            progress: Progress percentage; clamped into the range 0-100
            message: Optional progress message
        """
        with self._lock:
            self._progress = max(0.0, min(100.0, progress))
            self._progress_message = message
            self._run_progress_callbacks(self._progress, message)

    @property
    def progress(self) -> float:
        """Get current progress percentage (0-100)."""
        with self._lock:
            return self._progress

    @property
    def progress_message(self) -> str:
        """Get current progress message."""
        with self._lock:
            return self._progress_message

    # ========================================================================
    # State and result access
    # ========================================================================

    @property
    def state(self) -> str:
        """Get current state string (one of the ``APIFutureState`` values)."""
        with self._lock:
            return self._state

    @property
    def operation_id(self) -> str:
        """Get operation identifier."""
        return self._operation_id

    @property
    def done(self) -> bool:
        """Check if future is done (completed, failed, or cancelled)."""
        with self._lock:
            return self._is_terminal()

    @property
    def failed(self) -> bool:
        """Check if future failed with an exception."""
        with self._lock:
            return self._state == APIFutureState.FAILED

    def result(self, timeout: Optional[float] = None) -> T:
        """Get the result, blocking if necessary.

        When a ``concurrent.futures.Future`` is wrapped, the call is delegated
        to it, so its own timeout/cancellation exceptions are raised.

        Args:
            timeout: Maximum time to wait in seconds (``None`` waits forever)

        Returns:
            The result value

        Raises:
            TimeoutError: If timeout expires
            concurrent.futures.CancelledError: If future was cancelled
            BaseException: Whatever exception the operation failed with
        """
        if self._future is not None:
            return self._future.result(timeout=timeout)

        # Stand-alone future: wait on our own done event.
        if not self._done_event.wait(timeout=timeout):
            raise TimeoutError(f"Operation {self._operation_id} timed out after {timeout}s")

        with self._lock:
            if self._state == APIFutureState.CANCELLED:
                # CancelledError subclasses Exception, so callers that caught
                # the previous bare Exception keep working.
                from concurrent.futures import CancelledError

                raise CancelledError(f"Operation {self._operation_id} was cancelled")
            if self._state == APIFutureState.FAILED and self._exception is not None:
                raise self._exception
            return self._result  # type: ignore

    def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]:
        """Get the exception if one occurred.

        Implemented on top of ``result()``: any ``Exception`` it raises is
        returned instead of raised. That includes timeout and cancellation
        errors, not only the operation's own failure.

        Args:
            timeout: Maximum time to wait in seconds

        Returns:
            The exception or None if successful
        """
        try:
            self.result(timeout=timeout)
        except Exception as exc:
            return exc
        return None

    # ========================================================================
    # Cancellation
    # ========================================================================

    def cancel(self) -> bool:
        """Attempt to cancel the future.

        Returns:
            True if cancellation was successful, False otherwise (already
            done, or the wrapped future refused because it is running or
            finished)
        """
        with self._lock:
            if self._is_terminal():
                return False

            if self._future is not None:
                # A successful Future.cancel() synchronously invokes
                # _on_future_done, which records CANCELLED, runs the finally
                # callbacks once and sets the done event. Doing any of that
                # here as well would run the finally callbacks twice. On
                # refusal our state was never touched, so nothing to revert.
                return self._future.cancel()

            self._state = APIFutureState.CANCELLED
            self._done_event.set()
            self._run_finally_callbacks()
            return True

    def cancelled(self) -> bool:
        """Check if the future was cancelled.

        Note: unlike ``done`` and ``failed`` this is a method, not a property;
        call it as ``future.cancelled()``.
        """
        with self._lock:
            return self._state == APIFutureState.CANCELLED

    # ========================================================================
    # Async/await support
    # ========================================================================

    def __await__(self):
        """Support for async/await syntax.

        Example:
            >>> result = await future
        """
        return self.as_coroutine().__await__()

    def as_coroutine(self):
        """Return as an asyncio coroutine.

        The blocking ``result()`` call is pushed to the running loop's default
        executor so the event loop itself is not blocked.

        Returns:
            A coroutine that resolves to the future's result
        """
        import asyncio

        async def _coro():
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, self.result)

        return _coro()
383
+
384
+
385
class SdkFuturePool:
    """Runs SDK calls on a thread pool and hands back ``APIFuture`` handles."""

    def __init__(self, max_workers: int = 4):
        self._executor = ThreadPoolExecutor(max_workers=max_workers)

    @staticmethod
    def _wrap(prefix: str, handle: Future) -> APIFuture:
        """Wrap an executor future, labelling it ``<prefix>-<id(handle)>``."""
        return APIFuture(future=handle, operation_id=f"{prefix}-{id(handle)}")

    def submit_sample(self, backend: Any, prompts: Iterable[str], decoding: DecodingConfig) -> APIFuture:
        """Schedule the SDK ``sample`` function with the given arguments."""
        from . import sample

        return self._wrap("sample", self._executor.submit(sample, backend, prompts, decoding))

    def submit_forward_backward(self, backend: Any, loss_fn) -> APIFuture:
        """Schedule the SDK ``forward_backward`` function."""
        from . import forward_backward

        return self._wrap("fb", self._executor.submit(forward_backward, backend, loss_fn))

    def submit_optim_step(self, backend: Any, optimizer: Any, grads: Any) -> APIFuture:
        """Schedule the SDK ``optim_step`` function."""
        from . import optim_step

        return self._wrap("optim", self._executor.submit(optim_step, backend, optimizer, grads))

    def submit_create_optimizer(self, backend: Any, *, lr: float, weight_decay: float = 0.0) -> APIFuture:
        """Schedule the SDK ``create_optimizer`` function."""
        from . import create_optimizer

        handle = self._executor.submit(
            create_optimizer, backend, lr=lr, weight_decay=weight_decay
        )
        return self._wrap("opt-create", handle)

    def submit(self, fn: Callable[..., T], *args, **kwargs) -> APIFuture[T]:
        """Submit a generic function to the pool.

        Args:
            fn: Function to execute
            *args: Positional arguments forwarded to ``fn``
            **kwargs: Keyword arguments forwarded to ``fn``

        Returns:
            APIFuture wrapping the submitted task
        """
        return self._wrap("task", self._executor.submit(fn, *args, **kwargs))

    def shutdown(self, wait: bool = True):
        """Shutdown the thread pool.

        Args:
            wait: Whether to wait for pending tasks to complete
        """
        self._executor.shutdown(wait=wait)
436
+
437
+
438
+ # Convenience function for creating completed futures
439
def completed_future(result: T, operation_id: Optional[str] = None) -> APIFuture[T]:
    """Build a stand-alone future that is already resolved with ``result``.

    Args:
        result: The value the future should yield
        operation_id: Optional identifier; defaults to
            ``"completed-<id(result)>"`` when omitted or empty

    Returns:
        An APIFuture in the completed state with progress at 100
    """
    label = operation_id if operation_id else f"completed-{id(result)}"
    resolved = APIFuture[T](operation_id=label)
    # No underlying Future is wrapped, so the outcome fields are set directly.
    resolved._result = result
    resolved._progress = 100.0
    resolved._state = APIFutureState.COMPLETED
    resolved._done_event.set()
    return resolved
455
+
456
+
457
def failed_future(exception: BaseException, operation_id: Optional[str] = None) -> APIFuture[Any]:
    """Build a stand-alone future that has already failed with ``exception``.

    Args:
        exception: The exception the future should carry
        operation_id: Optional identifier; defaults to
            ``"failed-<id(exception)>"`` when omitted or empty

    Returns:
        An APIFuture in the failed state
    """
    label = operation_id if operation_id else f"failed-{id(exception)}"
    rejected = APIFuture[Any](operation_id=label)
    # No underlying Future is wrapped, so the outcome fields are set directly.
    rejected._exception = exception
    rejected._state = APIFutureState.FAILED
    rejected._done_event.set()
    return rejected
472
+
473
+
474
def cancelled_future(operation_id: Optional[str] = None) -> APIFuture[Any]:
    """Build a stand-alone future that is already cancelled.

    Args:
        operation_id: Optional identifier; when omitted or empty a
            ``"cancelled-<n>"`` label is generated from a throwaway object's id

    Returns:
        An APIFuture in the cancelled state
    """
    label = operation_id if operation_id else f"cancelled-{id(object())}"
    aborted = APIFuture[Any](operation_id=label)
    # No underlying Future is wrapped, so the state is set directly.
    aborted._state = APIFutureState.CANCELLED
    aborted._done_event.set()
    return aborted