mlxsmith-0.1.0-py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (69)
  1. mlxsmith/__init__.py +2 -0
  2. mlxsmith/accel/__init__.py +10 -0
  3. mlxsmith/accel/base.py +17 -0
  4. mlxsmith/accel/none.py +13 -0
  5. mlxsmith/accel/zmlx_backend.py +42 -0
  6. mlxsmith/adapters.py +46 -0
  7. mlxsmith/api/__init__.py +48 -0
  8. mlxsmith/api/handlers.py +1217 -0
  9. mlxsmith/api/schemas.py +436 -0
  10. mlxsmith/auth.py +88 -0
  11. mlxsmith/bench.py +102 -0
  12. mlxsmith/cli.py +950 -0
  13. mlxsmith/config.py +543 -0
  14. mlxsmith/config_models.py +261 -0
  15. mlxsmith/data.py +493 -0
  16. mlxsmith/envs/__init__.py +33 -0
  17. mlxsmith/envs/system.py +388 -0
  18. mlxsmith/envs/token_env.py +191 -0
  19. mlxsmith/eval.py +112 -0
  20. mlxsmith/infer.py +140 -0
  21. mlxsmith/llm/__init__.py +16 -0
  22. mlxsmith/llm/backend.py +126 -0
  23. mlxsmith/llm/interface.py +212 -0
  24. mlxsmith/llm/mlx_lm_backend.py +509 -0
  25. mlxsmith/llm/mock_backend.py +228 -0
  26. mlxsmith/llm/registry.py +12 -0
  27. mlxsmith/models.py +257 -0
  28. mlxsmith/orchestrator/__init__.py +25 -0
  29. mlxsmith/orchestrator/daemon.py +454 -0
  30. mlxsmith/orchestrator/inference_worker.py +496 -0
  31. mlxsmith/orchestrator/queue.py +355 -0
  32. mlxsmith/orchestrator/trainer_worker.py +437 -0
  33. mlxsmith/rlm/__init__.py +8 -0
  34. mlxsmith/rlm/corpus.py +74 -0
  35. mlxsmith/rlm/gating.py +90 -0
  36. mlxsmith/rlm/generate.py +249 -0
  37. mlxsmith/rlm/history.py +12 -0
  38. mlxsmith/rlm/inference.py +150 -0
  39. mlxsmith/rlm/loop.py +1297 -0
  40. mlxsmith/rlm/mutate.py +82 -0
  41. mlxsmith/rlm/trainer.py +73 -0
  42. mlxsmith/rlm/weights.py +263 -0
  43. mlxsmith/runs.py +44 -0
  44. mlxsmith/sdk/__init__.py +392 -0
  45. mlxsmith/sdk/future.py +486 -0
  46. mlxsmith/sdk/losses.py +262 -0
  47. mlxsmith/sdk/sampling_client.py +729 -0
  48. mlxsmith/sdk/training_client.py +676 -0
  49. mlxsmith/server.py +376 -0
  50. mlxsmith/train/__init__.py +0 -0
  51. mlxsmith/train/distill.py +279 -0
  52. mlxsmith/train/lora.py +280 -0
  53. mlxsmith/train/pref.py +180 -0
  54. mlxsmith/train/rft.py +458 -0
  55. mlxsmith/train/sft.py +151 -0
  56. mlxsmith/util.py +174 -0
  57. mlxsmith/verifiers/__init__.py +3 -0
  58. mlxsmith/verifiers/compose.py +109 -0
  59. mlxsmith/verifiers/docker_verifier.py +111 -0
  60. mlxsmith/verifiers/jsonschema.py +54 -0
  61. mlxsmith/verifiers/pytest_verifier.py +82 -0
  62. mlxsmith/verifiers/regex.py +15 -0
  63. mlxsmith/verifiers/types.py +10 -0
  64. mlxsmith-0.1.0.dist-info/METADATA +163 -0
  65. mlxsmith-0.1.0.dist-info/RECORD +69 -0
  66. mlxsmith-0.1.0.dist-info/WHEEL +5 -0
  67. mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
  68. mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/orchestrator/inference_worker.py
@@ -0,0 +1,496 @@
+ """Inference Worker Process for MLXSmith Orchestrator.
+
+ Runs as a separate process with OpenAI-compatible API.
+ Handles weight update messages from orchestrator.
+ Supports explicit weight reloading without restart.
+ """
+
+ from __future__ import annotations
+
+ import asyncio
+ import threading
+ import json
+ import signal
+ import sys
+ import time
+ import uuid
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+
+ import uvicorn
+ from fastapi import FastAPI
+ from fastapi.responses import StreamingResponse
+
+ # Relative imports will work when run as module
+ from ..config import ProjectConfig
+ from ..llm.registry import get_llm_backend
+ from ..models import resolve_model_spec
+ from ..rlm.weights import WeightPointerStore, WeightPointerIPC
+ from .queue import MessageQueue, MessageType, Message
+
+
+ @dataclass
+ class InferenceConfig:
+     """Configuration for inference worker."""
+     model_spec: str
+     backend: str = "mlx-lm"
+     host: str = "0.0.0.0"
+     port: int = 8080
+     max_seq_len: int = 8192
+     dtype: str = "bf16"
+     trust_remote_code: bool = False
+     use_chat_template: bool = True
+     weights_dir: Optional[Path] = None
+     hot_reload: bool = True
+     reload_poll_interval: float = 2.0
+
+
+ class InferenceWorker:
+     """Inference worker process with OpenAI-compatible API.
+
+     Runs in a separate process, handles:
+     - OpenAI-compatible /v1/chat/completions endpoint
+     - Internal /internal/rollout endpoint for RLM
+     - Hot-reloading of adapter weights via weight pointer updates
+     - Queue-based communication with orchestrator
+     """
+
+     def __init__(
+         self,
+         config: InferenceConfig,
+         queue: Optional[MessageQueue] = None,
+     ):
+         self.config = config
+         self.queue = queue
+         self._llm = None
+         self._base_model: Optional[str] = None
+         self._current_adapter: Optional[str] = None
+         self._pointer_store: Optional[WeightPointerStore] = None
+         self._current_version = -1
+         self._shutdown_event = asyncio.Event()
+         self._app: Optional[FastAPI] = None
+
+     def _load_model(self) -> None:
+         """Load the base model and initial adapter."""
+         self._llm = get_llm_backend(self.config.backend)
+
+         # Resolve model spec
+         base_model, adapter_path, _ = resolve_model_spec(
+             Path.cwd(), self.config.model_spec, ProjectConfig()
+         )
+         self._base_model = base_model
+
+         # Load base model
+         self._llm.load(
+             base_model,
+             max_seq_len=self.config.max_seq_len,
+             dtype=self.config.dtype,
+             trust_remote_code=self.config.trust_remote_code,
+         )
+
+         # Apply initial adapter if exists
+         if adapter_path:
+             self._llm.apply_adapter(str(adapter_path))
+             self._current_adapter = str(adapter_path)
+
+         # Setup weight pointer store for hot-reloading
+         if self.config.hot_reload and self.config.weights_dir:
+             self._pointer_store = WeightPointerStore(self.config.weights_dir)
+             # Initial load of inference pointer
+             pointer = self._pointer_store.load("inference", base_model)
+             self._current_version = pointer.version
+             if pointer.adapter_path and pointer.adapter_path != self._current_adapter:
+                 self._apply_adapter(pointer.adapter_path)
+
+     def _apply_adapter(self, adapter_path: str) -> bool:
+         """Apply a new adapter, reloading if necessary."""
+         try:
+             if adapter_path == self._current_adapter:
+                 return True
+
+             # Reload base model and apply new adapter
+             self._llm.load(
+                 self._base_model,
+                 max_seq_len=self.config.max_seq_len,
+                 dtype=self.config.dtype,
+                 trust_remote_code=self.config.trust_remote_code,
+             )
+             self._llm.apply_adapter(adapter_path)
+             self._current_adapter = adapter_path
+             return True
+         except Exception as e:
+             print(f"[InferenceWorker] Failed to apply adapter {adapter_path}: {e}")
+             return False
+
+     def _check_weight_updates(self) -> None:
+         """Check for weight pointer updates."""
+         if not self._pointer_store:
+             return
+
+         pointer = self._pointer_store.load("inference", self._base_model)
+         if pointer.version > self._current_version:
+             self._current_version = pointer.version
+             if pointer.adapter_path:
+                 success = self._apply_adapter(pointer.adapter_path)
+                 if success:
+                     print(f"[InferenceWorker] Hot-reloaded adapter: {pointer.adapter_path}")
+
+     def _handle_queue_message(self, msg: Message) -> Optional[Message]:
+         """Handle a message from the queue."""
+         if msg.msg_type == MessageType.ROLLOUT_REQUEST:
+             return self._handle_rollout_request(msg)
+         elif msg.msg_type == MessageType.WEIGHT_UPDATE:
+             return self._handle_weight_update(msg)
+         elif msg.msg_type == MessageType.HEALTH_CHECK:
+             return self._handle_health_check(msg)
+         elif msg.msg_type == MessageType.SHUTDOWN:
+             self._shutdown_event.set()
+             return None
+         return None
+
+     def _handle_rollout_request(self, msg: Message) -> Message:
+         """Handle a rollout request."""
+         payload = msg.payload
+         prompt = payload.get("prompt", "")
+         max_tokens = payload.get("max_tokens", 256)
+         temperature = payload.get("temperature", 0.8)
+         top_p = payload.get("top_p", 1.0)
+         top_k = payload.get("top_k")
+         seed = payload.get("seed")
+
+         # Check for weight updates before generating
+         self._check_weight_updates()
+
+         # Generate rollout; fall back to the backend's alternate keyword name
+         # if it does not accept ``top_k``.
+         try:
+             gen = self._llm.generate_with_logprobs(
+                 prompt,
+                 max_new_tokens=max_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+                 top_k=top_k,
+                 seed=seed,
+             )
+         except TypeError:
+             gen = self._llm.generate_with_logprobs(
+                 prompt,
+                 max_new_tokens=max_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+                 top_k_sampling=top_k,
+                 seed=seed,
+             )
+
+         completion = gen.text[len(prompt):] if gen.text.startswith(prompt) else gen.text
+
+         return Message(
+             msg_type=MessageType.ROLLOUT_RESPONSE,
+             payload={
+                 "request_id": msg.msg_id,
+                 "prompt_len": gen.prompt_len,
+                 "token_ids": list(gen.token_ids),
+                 "logprobs": list(gen.logprobs) if gen.logprobs else None,
+                 "completion": completion,
+                 "adapter_version": self._current_version,
+             },
+             source="inference",
+         )
+
+     def _handle_weight_update(self, msg: Message) -> Message:
+         """Handle a weight update from the trainer."""
+         payload = msg.payload
+         adapter_path = payload.get("adapter_path")
+         version = payload.get("version", 0)
+
+         success = False
+         if adapter_path:
+             success = self._apply_adapter(adapter_path)
+             if success:
+                 self._current_version = version
+
+         return Message(
+             msg_type=MessageType.WEIGHT_ACK,
+             payload={
+                 "request_id": msg.msg_id,
+                 "success": success,
+                 "adapter_path": self._current_adapter,
+                 "version": self._current_version,
+             },
+             source="inference",
+         )
+
+     def _handle_health_check(self, msg: Message) -> Message:
+         """Handle a health check request."""
+         return Message(
+             msg_type=MessageType.HEALTH_RESPONSE,
+             payload={
+                 "request_id": msg.msg_id,
+                 "status": "healthy",
+                 "base_model": self._base_model,
+                 "adapter_path": self._current_adapter,
+                 "adapter_version": self._current_version,
+             },
+             source="inference",
+         )
+
+     def _create_app(self) -> FastAPI:
+         """Create the FastAPI application."""
+         app = FastAPI(title="mlxsmith-inference")
+
+         @app.get("/health")
+         def health():
+             return {
+                 "ok": True,
+                 "base_model": self._base_model,
+                 "adapter_path": self._current_adapter,
+                 "adapter_version": self._current_version,
+             }
+
+         @app.post("/v1/chat/completions")
+         def chat_completions(request: Dict[str, Any]):
+             """OpenAI-compatible chat completions endpoint."""
+             messages = request.get("messages", [])
+             max_tokens = request.get("max_tokens", 256)
+             temperature = request.get("temperature", 0.7)
+             top_p = request.get("top_p", 1.0)
+             stream = request.get("stream", False)
+
+             # Check for weight updates
+             self._check_weight_updates()
+
+             # Build prompt from messages
+             prompt = self._messages_to_prompt(messages)
+
+             if stream:
+                 return StreamingResponse(
+                     self._stream_generate(prompt, max_tokens, temperature, top_p),
+                     media_type="text/event-stream",
+                 )
+
+             gen = self._llm.generate(
+                 prompt,
+                 max_new_tokens=max_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+             )
+
+             completion = gen.text[len(prompt):] if gen.text.startswith(prompt) else gen.text
+
+             return {
+                 "id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
+                 "object": "chat.completion",
+                 "created": int(time.time()),
+                 "model": self._base_model,
+                 "choices": [
+                     {
+                         "index": 0,
+                         "message": {"role": "assistant", "content": completion},
+                         "finish_reason": "stop",
+                     }
+                 ],
+                 "usage": {
+                     "prompt_tokens": len(self._llm.encode(prompt)),
+                     "completion_tokens": len(self._llm.encode(completion)),
+                     "total_tokens": len(self._llm.encode(prompt)) + len(self._llm.encode(completion)),
+                 },
+             }
+
+ @app.post("/internal/rollout")
310
+ def internal_rollout(request: Dict[str, Any]):
311
+ """Internal rollout endpoint for RLM."""
312
+ prompt = request.get("prompt", "")
313
+ max_tokens = request.get("max_tokens", 256)
314
+ temperature = request.get("temperature", 0.7)
315
+ top_p = request.get("top_p", 1.0)
316
+ top_k = request.get("top_k")
317
+ seed = request.get("seed")
318
+ include_tokens = request.get("include_tokens", True)
319
+ include_logprobs = request.get("include_logprobs", True)
320
+
321
+ # Check for weight updates
322
+ self._check_weight_updates()
323
+
324
+ gen = self._llm.generate_with_logprobs(
325
+ prompt,
326
+ max_new_tokens=max_tokens,
327
+ temperature=temperature,
328
+ top_p=top_p,
329
+ top_k_sampling=top_k,
330
+ seed=seed,
331
+ )
332
+
333
+ completion = gen.text[len(prompt):] if gen.text.startswith(prompt) else gen.text
334
+
335
+ return {
336
+ "id": f"rollout-{uuid.uuid4().hex[:12]}",
337
+ "created": int(time.time()),
338
+ "model": self._base_model,
339
+ "prompt_len": gen.prompt_len,
340
+ "token_ids": list(gen.token_ids) if include_tokens else None,
341
+ "logprobs": list(gen.logprobs) if (include_logprobs and gen.logprobs) else None,
342
+ "completion": completion,
343
+ "adapter_version": self._current_version,
344
+ }
345
+
346
+ @app.post("/internal/adapter/reload")
347
+ def reload_adapter(request: Dict[str, Any]):
348
+ """Explicitly reload adapter weights."""
349
+ adapter_path = request.get("adapter_path")
350
+
351
+ if adapter_path:
352
+ success = self._apply_adapter(adapter_path)
353
+ else:
354
+ # Reload from pointer store
355
+ self._check_weight_updates()
356
+ success = True
357
+
358
+ return {
359
+ "ok": success,
360
+ "base_model": self._base_model,
361
+ "adapter_path": self._current_adapter,
362
+ "adapter_version": self._current_version,
363
+ }
364
+
365
+ return app
366
+
367
+     def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
+         """Convert chat messages to prompt."""
+         if self.config.use_chat_template and hasattr(self._llm.tokenizer, "apply_chat_template"):
+             msgs = [{"role": m["role"], "content": m["content"]} for m in messages]
+             return self._llm.tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
+
+         # Fallback
+         return "\n".join([f"{m['role']}: {m['content']}" for m in messages]) + "\nassistant:"
+
+     async def _stream_generate(
+         self,
+         prompt: str,
+         max_tokens: int,
+         temperature: float,
+         top_p: float,
+     ):
+         """Stream generate tokens."""
+         try:
+             import mlx_lm
+
+             acc = ""
+             emitted = ""
+
+             for out in mlx_lm.stream_generate(
+                 self._llm.model,
+                 self._llm.tokenizer,
+                 prompt,
+                 max_tokens=max_tokens,
+                 temp=temperature,
+                 top_p=top_p,
+             ):
+                 if out.text:
+                     acc += out.text
+                     delta = acc[len(emitted):]
+                     emitted = acc
+
+                     payload = {
+                         "id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
+                         "object": "chat.completion.chunk",
+                         "choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}],
+                     }
+                     yield f"data: {json.dumps(payload)}\n\n"
+
+                 if out.finish_reason:
+                     break
+
+             yield "data: [DONE]\n\n"
+
+         except ImportError:
+             # Non-streaming fallback
+             gen = self._llm.generate(
+                 prompt,
+                 max_new_tokens=max_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+             )
+             completion = gen.text[len(prompt):] if gen.text.startswith(prompt) else gen.text
+
+             payload = {
+                 "id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
+                 "object": "chat.completion.chunk",
+                 "choices": [{"index": 0, "delta": {"content": completion}, "finish_reason": "stop"}],
+             }
+             yield f"data: {json.dumps(payload)}\n\n"
+             yield "data: [DONE]\n\n"
+
+     def _queue_worker_loop(self) -> None:
+         """Background thread for queue-based communication."""
+         if not self.queue:
+             return
+
+         while not self._shutdown_event.is_set():
+             try:
+                 msg = self.queue.receive("rollout_requests", timeout=0.1)
+                 if msg:
+                     response = self._handle_queue_message(msg)
+                     if response:
+                         self.queue.get_queue("rollout_responses").put(response.to_dict())
+
+                 weight_msg = self.queue.receive("weight_forward", timeout=0)
+                 if weight_msg:
+                     self._handle_weight_update(weight_msg)
+
+             except Exception as e:
+                 print(f"[InferenceWorker] Queue error: {e}")
+
+             time.sleep(0.01)
+
+     def run(self) -> None:
+         """Run the inference worker."""
+         # Setup signal handlers
+         def signal_handler(sig, frame):
+             print("[InferenceWorker] Shutting down...")
+             self._shutdown_event.set()
+             sys.exit(0)
+
+         signal.signal(signal.SIGTERM, signal_handler)
+         signal.signal(signal.SIGINT, signal_handler)
+
+         # Load model
+         print("[InferenceWorker] Loading model...")
+         self._load_model()
+         print(f"[InferenceWorker] Model loaded: {self._base_model}")
+         print(f"[InferenceWorker] Adapter: {self._current_adapter}")
+
+         # Create FastAPI app
+         self._app = self._create_app()
+
+         # Start queue worker if enabled
+         if self.queue:
+             thread = threading.Thread(target=self._queue_worker_loop, daemon=True)
+             thread.start()
+
+         # Start uvicorn server
+         print(f"[InferenceWorker] Starting server on {self.config.host}:{self.config.port}")
+         uvicorn.run(
+             self._app,
+             host=self.config.host,
+             port=self.config.port,
+             log_level="warning",
+         )
+
+
+ def run_inference_worker(
+     config: InferenceConfig,
+     queue: Optional[MessageQueue] = None,
+ ) -> None:
+     """Entry point for inference worker process."""
+     worker = InferenceWorker(config, queue)
+     worker.run()
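
For context on how the file above is meant to be used, here is a minimal sketch of launching the worker directly with the InferenceConfig and run_inference_worker entry point from this diff. It assumes the wheel is installed; the model_spec value and weights_dir path are placeholders, not values shipped with the package.

from pathlib import Path

from mlxsmith.orchestrator.inference_worker import InferenceConfig, run_inference_worker

config = InferenceConfig(
    model_spec="mlx-community/SomeModel-4bit",  # hypothetical spec; use one your backend can resolve
    host="127.0.0.1",
    port=8080,
    weights_dir=Path("runs/weights"),           # hypothetical directory holding weight pointers for hot-reload
    hot_reload=True,
)

# Blocks the current process: loads the model, optionally starts the queue
# thread, then serves the FastAPI app with uvicorn until SIGINT/SIGTERM.
run_inference_worker(config)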
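And a client-side sketch against the HTTP surface the worker exposes (/health, /v1/chat/completions, /internal/adapter/reload). The base URL assumes the worker from the previous sketch is listening on 127.0.0.1:8080, and the example uses the third-party requests library; request and response fields match the handlers in the diff.

import requests

BASE = "http://127.0.0.1:8080"  # assumed host/port of the running worker

# Liveness plus the currently applied adapter and its version.
print(requests.get(f"{BASE}/health").json())

# OpenAI-style chat completion (non-streaming).
resp = requests.post(
    f"{BASE}/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "max_tokens": 64,
        "temperature": 0.7,
    },
)
print(resp.json()["choices"][0]["message"]["content"])

# Ask the worker to re-read its weight pointer; pass "adapter_path" to load a specific adapter instead.
print(requests.post(f"{BASE}/internal/adapter/reload", json={}).json())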