mlxsmith-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxsmith/__init__.py +2 -0
- mlxsmith/accel/__init__.py +10 -0
- mlxsmith/accel/base.py +17 -0
- mlxsmith/accel/none.py +13 -0
- mlxsmith/accel/zmlx_backend.py +42 -0
- mlxsmith/adapters.py +46 -0
- mlxsmith/api/__init__.py +48 -0
- mlxsmith/api/handlers.py +1217 -0
- mlxsmith/api/schemas.py +436 -0
- mlxsmith/auth.py +88 -0
- mlxsmith/bench.py +102 -0
- mlxsmith/cli.py +950 -0
- mlxsmith/config.py +543 -0
- mlxsmith/config_models.py +261 -0
- mlxsmith/data.py +493 -0
- mlxsmith/envs/__init__.py +33 -0
- mlxsmith/envs/system.py +388 -0
- mlxsmith/envs/token_env.py +191 -0
- mlxsmith/eval.py +112 -0
- mlxsmith/infer.py +140 -0
- mlxsmith/llm/__init__.py +16 -0
- mlxsmith/llm/backend.py +126 -0
- mlxsmith/llm/interface.py +212 -0
- mlxsmith/llm/mlx_lm_backend.py +509 -0
- mlxsmith/llm/mock_backend.py +228 -0
- mlxsmith/llm/registry.py +12 -0
- mlxsmith/models.py +257 -0
- mlxsmith/orchestrator/__init__.py +25 -0
- mlxsmith/orchestrator/daemon.py +454 -0
- mlxsmith/orchestrator/inference_worker.py +496 -0
- mlxsmith/orchestrator/queue.py +355 -0
- mlxsmith/orchestrator/trainer_worker.py +437 -0
- mlxsmith/rlm/__init__.py +8 -0
- mlxsmith/rlm/corpus.py +74 -0
- mlxsmith/rlm/gating.py +90 -0
- mlxsmith/rlm/generate.py +249 -0
- mlxsmith/rlm/history.py +12 -0
- mlxsmith/rlm/inference.py +150 -0
- mlxsmith/rlm/loop.py +1297 -0
- mlxsmith/rlm/mutate.py +82 -0
- mlxsmith/rlm/trainer.py +73 -0
- mlxsmith/rlm/weights.py +263 -0
- mlxsmith/runs.py +44 -0
- mlxsmith/sdk/__init__.py +392 -0
- mlxsmith/sdk/future.py +486 -0
- mlxsmith/sdk/losses.py +262 -0
- mlxsmith/sdk/sampling_client.py +729 -0
- mlxsmith/sdk/training_client.py +676 -0
- mlxsmith/server.py +376 -0
- mlxsmith/train/__init__.py +0 -0
- mlxsmith/train/distill.py +279 -0
- mlxsmith/train/lora.py +280 -0
- mlxsmith/train/pref.py +180 -0
- mlxsmith/train/rft.py +458 -0
- mlxsmith/train/sft.py +151 -0
- mlxsmith/util.py +174 -0
- mlxsmith/verifiers/__init__.py +3 -0
- mlxsmith/verifiers/compose.py +109 -0
- mlxsmith/verifiers/docker_verifier.py +111 -0
- mlxsmith/verifiers/jsonschema.py +54 -0
- mlxsmith/verifiers/pytest_verifier.py +82 -0
- mlxsmith/verifiers/regex.py +15 -0
- mlxsmith/verifiers/types.py +10 -0
- mlxsmith-0.1.0.dist-info/METADATA +163 -0
- mlxsmith-0.1.0.dist-info/RECORD +69 -0
- mlxsmith-0.1.0.dist-info/WHEEL +5 -0
- mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
- mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
- mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/orchestrator/inference_worker.py
@@ -0,0 +1,496 @@
+"""Inference Worker Process for MLXSmith Orchestrator.
+
+Runs as a separate process with OpenAI-compatible API.
+Handles weight update messages from orchestrator.
+Supports explicit weight reloading without restart.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import threading
+import json
+import signal
+import sys
+import time
+import uuid
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import uvicorn
+from fastapi import FastAPI
+from fastapi.responses import StreamingResponse
+
+# Relative imports will work when run as module
+from ..config import ProjectConfig
+from ..llm.registry import get_llm_backend
+from ..models import resolve_model_spec
+from ..rlm.weights import WeightPointerStore, WeightPointerIPC
+from .queue import MessageQueue, MessageType, Message
+
+
+@dataclass
+class InferenceConfig:
+    """Configuration for inference worker."""
+    model_spec: str
+    backend: str = "mlx-lm"
+    host: str = "0.0.0.0"
+    port: int = 8080
+    max_seq_len: int = 8192
+    dtype: str = "bf16"
+    trust_remote_code: bool = False
+    use_chat_template: bool = True
+    weights_dir: Optional[Path] = None
+    hot_reload: bool = True
+    reload_poll_interval: float = 2.0
+
+
+class InferenceWorker:
+    """Inference worker process with OpenAI-compatible API.
+
+    Runs in a separate process, handles:
+    - OpenAI-compatible /v1/chat/completions endpoint
+    - Internal /internal/rollout endpoint for RLM
+    - Hot-reloading of adapter weights via weight pointer updates
+    - Queue-based communication with orchestrator
+    """
+
+    def __init__(
+        self,
+        config: InferenceConfig,
+        queue: Optional[MessageQueue] = None,
+    ):
+        self.config = config
+        self.queue = queue
+        self._llm = None
+        self._base_model: Optional[str] = None
+        self._current_adapter: Optional[str] = None
+        self._pointer_store: Optional[WeightPointerStore] = None
+        self._current_version = -1
+        self._shutdown_event = asyncio.Event()
+        self._app: Optional[FastAPI] = None
+
+    def _load_model(self) -> None:
+        """Load the base model and initial adapter."""
+        self._llm = get_llm_backend(self.config.backend)
+
+        # Resolve model spec
+        base_model, adapter_path, _ = resolve_model_spec(
+            Path.cwd(), self.config.model_spec, ProjectConfig()
+        )
+        self._base_model = base_model
+
+        # Load base model
+        self._llm.load(
+            base_model,
+            max_seq_len=self.config.max_seq_len,
+            dtype=self.config.dtype,
+            trust_remote_code=self.config.trust_remote_code,
+        )
+
+        # Apply initial adapter if exists
+        if adapter_path:
+            self._llm.apply_adapter(str(adapter_path))
+            self._current_adapter = str(adapter_path)
+
+        # Setup weight pointer store for hot-reloading
+        if self.config.hot_reload and self.config.weights_dir:
+            self._pointer_store = WeightPointerStore(self.config.weights_dir)
+            # Initial load of inference pointer
+            pointer = self._pointer_store.load("inference", base_model)
+            self._current_version = pointer.version
+            if pointer.adapter_path and pointer.adapter_path != self._current_adapter:
+                self._apply_adapter(pointer.adapter_path)
+
+    def _apply_adapter(self, adapter_path: str) -> bool:
+        """Apply a new adapter, reloading if necessary."""
+        try:
+            if adapter_path == self._current_adapter:
+                return True
+
+            # Reload base model and apply new adapter
+            self._llm.load(
+                self._base_model,
+                max_seq_len=self.config.max_seq_len,
+                dtype=self.config.dtype,
+                trust_remote_code=self.config.trust_remote_code,
+            )
+            self._llm.apply_adapter(adapter_path)
+            self._current_adapter = adapter_path
+            return True
+        except Exception as e:
+            print(f"[InferenceWorker] Failed to apply adapter {adapter_path}: {e}")
+            return False
+
+    def _check_weight_updates(self) -> None:
+        """Check for weight pointer updates."""
+        if not self._pointer_store:
+            return
+
+        pointer = self._pointer_store.load("inference", self._base_model)
+        if pointer.version > self._current_version:
+            self._current_version = pointer.version
+            if pointer.adapter_path:
+                success = self._apply_adapter(pointer.adapter_path)
+                if success:
+                    print(f"[InferenceWorker] Hot-reloaded adapter: {pointer.adapter_path}")
+
+    def _handle_queue_message(self, msg: Message) -> Optional[Message]:
+        """Handle a message from the queue."""
+        if msg.msg_type == MessageType.ROLLOUT_REQUEST:
+            return self._handle_rollout_request(msg)
+        elif msg.msg_type == MessageType.WEIGHT_UPDATE:
+            return self._handle_weight_update(msg)
+        elif msg.msg_type == MessageType.HEALTH_CHECK:
+            return self._handle_health_check(msg)
+        elif msg.msg_type == MessageType.SHUTDOWN:
+            self._shutdown_event.set()
+            return None
+        return None
+
+    def _handle_rollout_request(self, msg: Message) -> Message:
+        """Handle a rollout request."""
+        payload = msg.payload
+        prompt = payload.get("prompt", "")
+        max_tokens = payload.get("max_tokens", 256)
+        temperature = payload.get("temperature", 0.8)
+        top_p = payload.get("top_p", 1.0)
+        top_k = payload.get("top_k")
+        seed = payload.get("seed")
+
+        # Check for weight updates before generating
+        self._check_weight_updates()
+
+        # Generate rollout
+        try:
+            gen = self._llm.generate_with_logprobs(
+                prompt,
+                max_new_tokens=max_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                seed=seed,
+            )
+        except TypeError:
+            try:
+                gen = self._llm.generate_with_logprobs(
+                    prompt,
+                    max_new_tokens=max_tokens,
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    seed=seed,
+                )
+            except TypeError:
+                gen = self._llm.generate_with_logprobs(
+                    prompt,
+                    max_new_tokens=max_tokens,
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k_sampling=top_k,
+                    seed=seed,
+                )
+
+        completion = gen.text[len(prompt):] if gen.text.startswith(prompt) else gen.text
+
+        return Message(
+            msg_type=MessageType.ROLLOUT_RESPONSE,
+            payload={
+                "request_id": msg.msg_id,
+                "prompt_len": gen.prompt_len,
+                "token_ids": list(gen.token_ids),
+                "logprobs": list(gen.logprobs) if gen.logprobs else None,
+                "completion": completion,
+                "adapter_version": self._current_version,
+            },
+            source="inference",
+        )
+
+    def _handle_weight_update(self, msg: Message) -> Message:
+        """Handle a weight update from the trainer."""
+        payload = msg.payload
+        adapter_path = payload.get("adapter_path")
+        version = payload.get("version", 0)
+
+        success = False
+        if adapter_path:
+            success = self._apply_adapter(adapter_path)
+            if success:
+                self._current_version = version
+
+        return Message(
+            msg_type=MessageType.WEIGHT_ACK,
+            payload={
+                "request_id": msg.msg_id,
+                "success": success,
+                "adapter_path": self._current_adapter,
+                "version": self._current_version,
+            },
+            source="inference",
+        )
+
+    def _handle_health_check(self, msg: Message) -> Message:
+        """Handle a health check request."""
+        return Message(
+            msg_type=MessageType.HEALTH_RESPONSE,
+            payload={
+                "request_id": msg.msg_id,
+                "status": "healthy",
+                "base_model": self._base_model,
+                "adapter_path": self._current_adapter,
+                "adapter_version": self._current_version,
+            },
+            source="inference",
+        )
+
+    def _create_app(self) -> FastAPI:
+        """Create the FastAPI application."""
+        app = FastAPI(title="mlxsmith-inference")
+
+        @app.get("/health")
+        def health():
+            return {
+                "ok": True,
+                "base_model": self._base_model,
+                "adapter_path": self._current_adapter,
+                "adapter_version": self._current_version,
+            }
+
+        @app.post("/v1/chat/completions")
+        def chat_completions(request: Dict[str, Any]):
+            """OpenAI-compatible chat completions endpoint."""
+            messages = request.get("messages", [])
+            max_tokens = request.get("max_tokens", 256)
+            temperature = request.get("temperature", 0.7)
+            top_p = request.get("top_p", 1.0)
+            stream = request.get("stream", False)
+
+            # Check for weight updates
+            self._check_weight_updates()
+
+            # Build prompt from messages
+            prompt = self._messages_to_prompt(messages)
+
+            if stream:
+                return StreamingResponse(
+                    self._stream_generate(prompt, max_tokens, temperature, top_p),
+                    media_type="text/event-stream",
+                )
+
+            gen = self._llm.generate(
+                prompt,
+                max_new_tokens=max_tokens,
+                temperature=temperature,
+                top_p=top_p,
+            )
+
+            completion = gen.text[len(prompt):] if gen.text.startswith(prompt) else gen.text
+
+            return {
+                "id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
+                "object": "chat.completion",
+                "created": int(time.time()),
+                "model": self._base_model,
+                "choices": [
+                    {
+                        "index": 0,
+                        "message": {"role": "assistant", "content": completion},
+                        "finish_reason": "stop",
+                    }
+                ],
+                "usage": {
+                    "prompt_tokens": len(self._llm.encode(prompt)),
+                    "completion_tokens": len(self._llm.encode(completion)),
+                    "total_tokens": len(self._llm.encode(prompt)) + len(self._llm.encode(completion)),
+                },
+            }
+
+        @app.post("/internal/rollout")
+        def internal_rollout(request: Dict[str, Any]):
+            """Internal rollout endpoint for RLM."""
+            prompt = request.get("prompt", "")
+            max_tokens = request.get("max_tokens", 256)
+            temperature = request.get("temperature", 0.7)
+            top_p = request.get("top_p", 1.0)
+            top_k = request.get("top_k")
+            seed = request.get("seed")
+            include_tokens = request.get("include_tokens", True)
+            include_logprobs = request.get("include_logprobs", True)
+
+            # Check for weight updates
+            self._check_weight_updates()
+
+            gen = self._llm.generate_with_logprobs(
+                prompt,
+                max_new_tokens=max_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                top_k_sampling=top_k,
+                seed=seed,
+            )
+
+            completion = gen.text[len(prompt):] if gen.text.startswith(prompt) else gen.text
+
+            return {
+                "id": f"rollout-{uuid.uuid4().hex[:12]}",
+                "created": int(time.time()),
+                "model": self._base_model,
+                "prompt_len": gen.prompt_len,
+                "token_ids": list(gen.token_ids) if include_tokens else None,
+                "logprobs": list(gen.logprobs) if (include_logprobs and gen.logprobs) else None,
+                "completion": completion,
+                "adapter_version": self._current_version,
+            }
+
+        @app.post("/internal/adapter/reload")
+        def reload_adapter(request: Dict[str, Any]):
+            """Explicitly reload adapter weights."""
+            adapter_path = request.get("adapter_path")
+
+            if adapter_path:
+                success = self._apply_adapter(adapter_path)
+            else:
+                # Reload from pointer store
+                self._check_weight_updates()
+                success = True
+
+            return {
+                "ok": success,
+                "base_model": self._base_model,
+                "adapter_path": self._current_adapter,
+                "adapter_version": self._current_version,
+            }
+
+        return app
+
+    def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
+        """Convert chat messages to prompt."""
+        if self.config.use_chat_template and hasattr(self._llm.tokenizer, "apply_chat_template"):
+            msgs = [{"role": m["role"], "content": m["content"]} for m in messages]
+            return self._llm.tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
+
+        # Fallback
+        return "\n".join([f"{m['role']}: {m['content']}" for m in messages]) + "\nassistant:"
+
+    async def _stream_generate(
+        self,
+        prompt: str,
+        max_tokens: int,
+        temperature: float,
+        top_p: float,
+    ):
+        """Stream generate tokens."""
+        try:
+            import mlx_lm
+
+            acc = ""
+            emitted = ""
+
+            for out in mlx_lm.stream_generate(
+                self._llm.model,
+                self._llm.tokenizer,
+                prompt,
+                max_tokens=max_tokens,
+                temp=temperature,
+                top_p=top_p,
+            ):
+                if out.text:
+                    acc += out.text
+                    delta = acc[len(emitted):]
+                    emitted = acc
+
+                    payload = {
+                        "id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
+                        "object": "chat.completion.chunk",
+                        "choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}],
+                    }
+                    yield f"data: {json.dumps(payload)}\n\n"
+
+                if out.finish_reason:
+                    break
+
+            yield "data: [DONE]\n\n"
+
+        except ImportError:
+            # Non-streaming fallback
+            gen = self._llm.generate(
+                prompt,
+                max_new_tokens=max_tokens,
+                temperature=temperature,
+                top_p=top_p,
+            )
+            completion = gen.text[len(prompt):] if gen.text.startswith(prompt) else gen.text
+
+            payload = {
+                "id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
+                "object": "chat.completion.chunk",
+                "choices": [{"index": 0, "delta": {"content": completion}, "finish_reason": "stop"}],
+            }
+            yield f"data: {json.dumps(payload)}\n\n"
+            yield "data: [DONE]\n\n"
+
+    def _queue_worker_loop(self) -> None:
+        """Background thread for queue-based communication."""
+        if not self.queue:
+            return
+
+        while not self._shutdown_event.is_set():
+            try:
+                msg = self.queue.receive("rollout_requests", timeout=0.1)
+                if msg:
+                    response = self._handle_queue_message(msg)
+                    if response:
+                        self.queue.get_queue("rollout_responses").put(response.to_dict())
+
+                weight_msg = self.queue.receive("weight_forward", timeout=0)
+                if weight_msg:
+                    self._handle_weight_update(weight_msg)
+
+            except Exception as e:
+                print(f"[InferenceWorker] Queue error: {e}")
+
+            time.sleep(0.01)
+
+    def run(self) -> None:
+        """Run the inference worker."""
+        # Setup signal handlers
+        def signal_handler(sig, frame):
+            print("[InferenceWorker] Shutting down...")
+            self._shutdown_event.set()
+            sys.exit(0)
+
+        signal.signal(signal.SIGTERM, signal_handler)
+        signal.signal(signal.SIGINT, signal_handler)
+
+        # Load model
+        print("[InferenceWorker] Loading model...")
+        self._load_model()
+        print(f"[InferenceWorker] Model loaded: {self._base_model}")
+        print(f"[InferenceWorker] Adapter: {self._current_adapter}")
+
+        # Create FastAPI app
+        self._app = self._create_app()
+
+        # Start queue worker if enabled
+        if self.queue:
+            thread = threading.Thread(target=self._queue_worker_loop, daemon=True)
+            thread.start()
+
+        # Start uvicorn server
+        print(f"[InferenceWorker] Starting server on {self.config.host}:{self.config.port}")
+        uvicorn.run(
+            self._app,
+            host=self.config.host,
+            port=self.config.port,
+            log_level="warning",
+        )
+
+
+def run_inference_worker(
+    config: InferenceConfig,
+    queue: Optional[MessageQueue] = None,
+) -> None:
+    """Entry point for inference worker process."""
+    worker = InferenceWorker(config, queue)
+    worker.run()
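
For orientation, here is a minimal sketch of launching the worker standalone, based on the InferenceConfig dataclass and run_inference_worker entry point shown in the diff above. The model spec string and weights directory are placeholder values chosen for illustration, not anything shipped with the package, and the queue is omitted since it is optional in the constructor.

# Hypothetical launch script; the model spec and weights_dir values are
# placeholders, not values bundled with mlxsmith 0.1.0.
from pathlib import Path

from mlxsmith.orchestrator.inference_worker import InferenceConfig, run_inference_worker

config = InferenceConfig(
    model_spec="mlx-community/SomeModel-4bit",     # placeholder model spec
    backend="mlx-lm",
    host="127.0.0.1",
    port=8080,
    weights_dir=Path("./runs/current/weights"),    # placeholder pointer-store dir
    hot_reload=True,
)

# Blocks: loads the base model (and adapter, if the spec resolves one),
# then serves the FastAPI app via uvicorn on the configured host/port.
run_inference_worker(config, queue=None)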
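
The request and response fields below mirror the /health, /v1/chat/completions, and /internal/rollout handlers in the diff. This is a hedged client sketch that assumes a worker listening on 127.0.0.1:8080 and uses the third-party requests package for brevity; neither is implied by the package itself.

# Sketch of client calls against an assumed local worker.
import requests

BASE = "http://127.0.0.1:8080"

# Health probe: reports base model, current adapter path, and adapter version.
print(requests.get(f"{BASE}/health").json())

# OpenAI-compatible chat completion (non-streaming).
chat = requests.post(
    f"{BASE}/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 64,
        "temperature": 0.7,
        "stream": False,
    },
).json()
print(chat["choices"][0]["message"]["content"])

# Internal rollout endpoint used by the RLM loop: returns token ids and
# per-token logprobs along with the adapter version that produced them.
rollout = requests.post(
    f"{BASE}/internal/rollout",
    json={"prompt": "def add(a, b):", "max_tokens": 32, "include_logprobs": True},
).json()
print(rollout["adapter_version"], len(rollout["token_ids"] or []))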
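
The /internal/adapter/reload handler supports two reload paths: an explicit adapter_path, or an empty body that re-reads the "inference" weight pointer. The sketch below assumes the same local worker as above; the adapter checkpoint path is a hypothetical example, not a path created by the package.

# Sketch of both reload paths exposed by /internal/adapter/reload.
import requests

BASE = "http://127.0.0.1:8080"

# Explicit reload from a specific adapter checkpoint (hypothetical path).
resp = requests.post(
    f"{BASE}/internal/adapter/reload",
    json={"adapter_path": "./runs/current/adapters/step_100"},
).json()
print(resp["ok"], resp["adapter_path"], resp["adapter_version"])

# Omitting adapter_path re-checks the pointer store instead, picking up
# whatever adapter version the trainer published last.
resp = requests.post(f"{BASE}/internal/adapter/reload", json={}).json()
print(resp["adapter_version"])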