mlxsmith 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxsmith/__init__.py +2 -0
- mlxsmith/accel/__init__.py +10 -0
- mlxsmith/accel/base.py +17 -0
- mlxsmith/accel/none.py +13 -0
- mlxsmith/accel/zmlx_backend.py +42 -0
- mlxsmith/adapters.py +46 -0
- mlxsmith/api/__init__.py +48 -0
- mlxsmith/api/handlers.py +1217 -0
- mlxsmith/api/schemas.py +436 -0
- mlxsmith/auth.py +88 -0
- mlxsmith/bench.py +102 -0
- mlxsmith/cli.py +950 -0
- mlxsmith/config.py +543 -0
- mlxsmith/config_models.py +261 -0
- mlxsmith/data.py +493 -0
- mlxsmith/envs/__init__.py +33 -0
- mlxsmith/envs/system.py +388 -0
- mlxsmith/envs/token_env.py +191 -0
- mlxsmith/eval.py +112 -0
- mlxsmith/infer.py +140 -0
- mlxsmith/llm/__init__.py +16 -0
- mlxsmith/llm/backend.py +126 -0
- mlxsmith/llm/interface.py +212 -0
- mlxsmith/llm/mlx_lm_backend.py +509 -0
- mlxsmith/llm/mock_backend.py +228 -0
- mlxsmith/llm/registry.py +12 -0
- mlxsmith/models.py +257 -0
- mlxsmith/orchestrator/__init__.py +25 -0
- mlxsmith/orchestrator/daemon.py +454 -0
- mlxsmith/orchestrator/inference_worker.py +496 -0
- mlxsmith/orchestrator/queue.py +355 -0
- mlxsmith/orchestrator/trainer_worker.py +437 -0
- mlxsmith/rlm/__init__.py +8 -0
- mlxsmith/rlm/corpus.py +74 -0
- mlxsmith/rlm/gating.py +90 -0
- mlxsmith/rlm/generate.py +249 -0
- mlxsmith/rlm/history.py +12 -0
- mlxsmith/rlm/inference.py +150 -0
- mlxsmith/rlm/loop.py +1297 -0
- mlxsmith/rlm/mutate.py +82 -0
- mlxsmith/rlm/trainer.py +73 -0
- mlxsmith/rlm/weights.py +263 -0
- mlxsmith/runs.py +44 -0
- mlxsmith/sdk/__init__.py +392 -0
- mlxsmith/sdk/future.py +486 -0
- mlxsmith/sdk/losses.py +262 -0
- mlxsmith/sdk/sampling_client.py +729 -0
- mlxsmith/sdk/training_client.py +676 -0
- mlxsmith/server.py +376 -0
- mlxsmith/train/__init__.py +0 -0
- mlxsmith/train/distill.py +279 -0
- mlxsmith/train/lora.py +280 -0
- mlxsmith/train/pref.py +180 -0
- mlxsmith/train/rft.py +458 -0
- mlxsmith/train/sft.py +151 -0
- mlxsmith/util.py +174 -0
- mlxsmith/verifiers/__init__.py +3 -0
- mlxsmith/verifiers/compose.py +109 -0
- mlxsmith/verifiers/docker_verifier.py +111 -0
- mlxsmith/verifiers/jsonschema.py +54 -0
- mlxsmith/verifiers/pytest_verifier.py +82 -0
- mlxsmith/verifiers/regex.py +15 -0
- mlxsmith/verifiers/types.py +10 -0
- mlxsmith-0.1.0.dist-info/METADATA +163 -0
- mlxsmith-0.1.0.dist-info/RECORD +69 -0
- mlxsmith-0.1.0.dist-info/WHEEL +5 -0
- mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
- mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
- mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/api/schemas.py
ADDED
|
@@ -0,0 +1,436 @@
|
|
|
1
|
+
"""Pydantic models for API request/response validation.
|
|
2
|
+
|
|
3
|
+
OpenAPI 3.1 compatible schemas for MLXSmith API.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from typing import Any, Dict, List, Literal, Optional, Union
|
|
9
|
+
from pydantic import BaseModel, Field
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# =============================================================================
|
|
13
|
+
# Common Schemas
|
|
14
|
+
# =============================================================================
|
|
15
|
+
|
|
16
|
+
class ErrorResponse(BaseModel):
    """Error response schema."""

    # Only the message is mandatory; code/details default to None.
    error: str = Field(description="Error message")
    code: Optional[str] = Field(default=None, description="Error code")
    details: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Additional error details",
    )
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class HealthResponse(BaseModel):
    """Health check response."""

    # Required status flag; version/model are informational and may be None.
    ok: bool = Field(description="Service health status")
    version: Optional[str] = Field(default=None, description="API version")
    model: Optional[str] = Field(default=None, description="Currently loaded model")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# =============================================================================
|
|
31
|
+
# Chat Completions (OpenAI-compatible)
|
|
32
|
+
# =============================================================================
|
|
33
|
+
|
|
34
|
+
class ChatMessage(BaseModel):
    """A single chat message."""

    # Accepted sender roles; anything else fails validation.
    role: Literal["system", "user", "assistant", "tool"] = Field(
        description="Role of the message sender",
    )
    content: str = Field(description="Message content")
    name: Optional[str] = Field(
        default=None,
        description="Optional name for the sender",
    )
    # Opaque tool-call payloads; the dict shape is not validated here.
    tool_calls: Optional[List[Dict[str, Any]]] = Field(
        default=None,
        description="Tool calls (if any)",
    )
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class ChatRequest(BaseModel):
    """OpenAI-compatible chat completion request."""

    model: Optional[str] = Field(
        default=None,
        description="Model identifier (optional, uses default if not provided)",
    )
    # At least one message is required.
    messages: List[ChatMessage] = Field(
        description="List of chat messages",
        min_length=1,
    )

    # --- sampling controls (ranges are enforced at validation time) ---
    max_tokens: int = Field(
        default=256, description="Maximum tokens to generate", ge=1, le=8192
    )
    temperature: float = Field(
        default=0.7, description="Sampling temperature", ge=0.0, le=2.0
    )
    top_p: float = Field(
        default=1.0, description="Nucleus sampling parameter", ge=0.0, le=1.0
    )
    top_k: Optional[int] = Field(
        default=None, description="Top-k sampling parameter", ge=1
    )
    stream: Optional[bool] = Field(
        default=False, description="Enable streaming response via SSE"
    )
    # A single stop string or a list of them.
    stop: Optional[Union[str, List[str]]] = Field(
        default=None, description="Stop sequences"
    )
    seed: Optional[int] = Field(
        default=None, description="Random seed for reproducibility"
    )
    presence_penalty: Optional[float] = Field(
        default=0.0, description="Presence penalty", ge=-2.0, le=2.0
    )
    frequency_penalty: Optional[float] = Field(
        default=0.0, description="Frequency penalty", ge=-2.0, le=2.0
    )

    # --- logprob reporting ---
    logprobs: Optional[bool] = Field(
        default=False, description="Return logprobs of output tokens"
    )
    top_logprobs: Optional[int] = Field(
        default=None,
        description="Number of top logprobs to return per token",
        ge=0,
        le=20,
    )
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class LogprobsContent(BaseModel):
    """Logprob information for a token."""

    token: str = Field(description="The token string")
    logprob: float = Field(description="The log probability of the token")
    # Field name mirrors the OpenAI wire format (it shadows the builtin only
    # inside this class namespace).
    bytes: Optional[List[int]] = Field(
        default=None,
        description="Bytes representation of token",
    )
    top_logprobs: Optional[List[Dict[str, float]]] = Field(
        default=None,
        description="Top logprobs for this position",
    )
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class ChoiceLogprobs(BaseModel):
    """Logprobs for a completion choice."""

    # One LogprobsContent entry per generated token, when present.
    content: Optional[List[LogprobsContent]] = Field(
        default=None,
        description="Logprobs for each token in the completion",
    )
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class UsageInfo(BaseModel):
    """Token usage information."""

    # All three counts are required.
    prompt_tokens: int = Field(description="Number of tokens in the prompt")
    completion_tokens: int = Field(description="Number of tokens in the completion")
    total_tokens: int = Field(description="Total number of tokens")
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class Choice(BaseModel):
    """A single completion choice."""

    index: int = Field(description="Index of the choice")
    message: ChatMessage = Field(description="The generated message")
    # None while unfinished or when the backend does not report a reason.
    finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = Field(
        default=None,
        description="Reason for completion finish",
    )
    logprobs: Optional[ChoiceLogprobs] = Field(
        default=None,
        description="Logprobs for this choice",
    )
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class ChatResponse(BaseModel):
    """OpenAI-compatible chat completion response."""

    id: str = Field(description="Unique identifier for the completion")
    # Constant discriminator; defaults so callers need not pass it.
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(description="Unix timestamp of creation")
    model: str = Field(description="Model used for the completion")
    choices: List[Choice] = Field(description="List of completion choices")
    usage: UsageInfo = Field(description="Token usage information")
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class DeltaMessage(BaseModel):
    """Delta message for streaming responses."""

    # Both fields are optional so a chunk may carry only one of them.
    role: Optional[Literal["assistant"]] = None
    content: Optional[str] = Field(default=None, description="Incremental content")
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class StreamChoice(BaseModel):
    """A streaming completion choice."""

    index: int = Field(..., description="Index of the choice")
    delta: DeltaMessage = Field(..., description="Incremental message delta")
    # Accept the same finish reasons as the non-streaming ``Choice`` model so a
    # stream that ends on a tool call validates just like the equivalent
    # non-streaming response does.
    finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = Field(None)
    logprobs: Optional[ChoiceLogprobs] = Field(None, description="Logprobs for this chunk")
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class ChatCompletionChunk(BaseModel):
    """Streaming chat completion chunk (SSE)."""

    id: str = Field(description="Unique identifier")
    # Constant discriminator for streamed chunks.
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(description="Unix timestamp")
    model: str = Field(description="Model used")
    choices: List[StreamChoice] = Field(description="List of choices")
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
# =============================================================================
|
|
153
|
+
# Internal Rollout (for RLM training)
|
|
154
|
+
# =============================================================================
|
|
155
|
+
|
|
156
|
+
class RolloutRequest(BaseModel):
    """Internal rollout request with detailed output options."""

    # --- generation inputs ---
    prompt: str = Field(description="Input prompt text", min_length=1)
    max_tokens: int = Field(default=256, description="Maximum tokens to generate", ge=1)
    temperature: float = Field(
        default=0.7, description="Sampling temperature", ge=0.0, le=2.0
    )
    top_p: float = Field(
        default=1.0, description="Nucleus sampling parameter", ge=0.0, le=1.0
    )
    top_k: Optional[int] = Field(default=None, description="Top-k sampling", ge=1)
    seed: Optional[int] = Field(default=None, description="Random seed")

    # --- output selection flags ---
    include_tokens: bool = Field(default=True, description="Include token IDs in response")
    include_logprobs: bool = Field(default=True, description="Include per-token logprobs")
    include_top_k_logprobs: Optional[int] = Field(
        default=None,
        description="Number of top logprobs per token to include",
        ge=0,
        le=20,
    )
    include_prompt_logprobs: bool = Field(
        default=False,
        description="Include per-token logprobs for prompt tokens",
    )
    include_prompt_top_k_logprobs: Optional[int] = Field(
        default=None,
        description="Number of top logprobs per prompt token to include",
        ge=0,
        le=20,
    )
    include_text: bool = Field(default=True, description="Include generated text")
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class RolloutResponse(BaseModel):
    """Internal rollout response with tokens and logprobs."""

    id: str = Field(description="Unique rollout identifier")
    created: int = Field(description="Unix timestamp")
    model: str = Field(description="Model used")
    prompt_len: int = Field(description="Length of prompt in tokens")

    # Optional payloads, each populated only when the matching include_* flag
    # on the request asked for it (presumably; set by the handler).
    token_ids: Optional[List[int]] = Field(default=None, description="Generated token IDs")
    logprobs: Optional[List[float]] = Field(
        default=None, description="Per-token log probabilities"
    )
    top_k_logprobs: Optional[List[Dict[str, float]]] = Field(
        default=None, description="Top-k logprobs per token"
    )
    prompt_logprobs: Optional[List[float]] = Field(
        default=None,
        description="Per-token log probabilities for prompt tokens (excluding first token)",
    )
    prompt_top_k_logprobs: Optional[List[Dict[str, float]]] = Field(
        default=None, description="Top-k logprobs per prompt token"
    )
    completion: Optional[str] = Field(
        default=None, description="Generated text (if requested)"
    )
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
# =============================================================================
|
|
203
|
+
# Training Endpoints
|
|
204
|
+
# =============================================================================
|
|
205
|
+
|
|
206
|
+
class ForwardBackwardRequest(BaseModel):
    """Request for forward/backward pass."""

    prompts: List[str] = Field(description="List of prompts", min_length=1)
    responses: Optional[List[str]] = Field(
        default=None, description="List of responses (for SFT)"
    )
    rejected_responses: Optional[List[str]] = Field(
        default=None,
        description="List of rejected responses (for preference training)",
    )
    # Selects the objective; "sft" unless the caller says otherwise.
    loss_type: Literal["sft", "dpo", "orpo", "ppo", "custom"] = Field(
        default="sft", description="Type of loss to compute"
    )
    train_on_prompt: bool = Field(
        default=False, description="Compute loss on prompt tokens"
    )
    max_seq_len: Optional[int] = Field(default=None, description="Maximum sequence length")
    extra: Optional[Dict[str, Any]] = Field(
        default=None, description="Additional loss parameters"
    )
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
class ForwardBackwardResponse(BaseModel):
    """Response from forward/backward pass."""

    loss: float = Field(description="Computed loss value")
    has_grads: bool = Field(description="Whether gradients were computed")
    batch_size: int = Field(description="Batch size processed")
    # Free-form scalar metrics keyed by name.
    metrics: Optional[Dict[str, float]] = Field(
        default=None, description="Additional metrics"
    )
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
class OptimStepRequest(BaseModel):
    """Request for optimizer step."""

    # Both overrides are optional; None leaves the trainer's setting in place.
    learning_rate: Optional[float] = Field(default=None, description="Override learning rate")
    grad_clip: Optional[float] = Field(default=None, description="Gradient clipping threshold")
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
class OptimStepResponse(BaseModel):
    """Response from optimizer step."""

    step: int = Field(description="Current training step")
    learning_rate: float = Field(description="Learning rate used")
    grad_norm: Optional[float] = Field(default=None, description="Gradient norm")
    # Defaults to True so handlers only need to set it on failure.
    success: bool = Field(default=True, description="Whether step succeeded")
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
class SaveStateRequest(BaseModel):
    """Request to save training state."""

    path: str = Field(description="Path to save checkpoint")
    # Arbitrary JSON-compatible extras stored alongside the checkpoint.
    metadata: Optional[Dict[str, Any]] = Field(
        default=None, description="Additional metadata to save"
    )
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
class SaveStateResponse(BaseModel):
    """Response from save state operation."""

    # Every field is required.
    path: str = Field(description="Path where checkpoint was saved")
    success: bool = Field(description="Whether save succeeded")
    message: str = Field(description="Status message")
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
class LoadStateRequest(BaseModel):
    """Request to load training state."""

    # Location of the checkpoint to restore (required).
    path: str = Field(description="Path to checkpoint to load")
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
class LoadStateResponse(BaseModel):
    """Response from load state operation."""

    path: str = Field(description="Path from which checkpoint was loaded")
    success: bool = Field(description="Whether load succeeded")
    message: str = Field(description="Status message")
    # None when the checkpoint carries no step information.
    step: Optional[int] = Field(default=None, description="Training step from checkpoint")
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
class GetWeightsResponse(BaseModel):
    """Response for get weights operation."""

    # Values are untyped here; per the description they may be partial data
    # or shape info rather than full tensors.
    weights: Dict[str, Any] = Field(
        description="Model weights (may be partial/shape info)"
    )
    success: bool = Field(description="Whether operation succeeded")
    message: str = Field(description="Status message")
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
class SetWeightsRequest(BaseModel):
    """Request to set model weights."""

    # Mapping of parameter name to payload; the payload format is not
    # validated by this schema.
    weights: Dict[str, Any] = Field(description="Model weights to set")
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
class SetWeightsResponse(BaseModel):
    """Response from set weights operation."""

    # Every field is required.
    success: bool = Field(description="Whether operation succeeded")
    message: str = Field(description="Status message")
    num_tensors: int = Field(description="Number of weight tensors set")
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
# =============================================================================
|
|
289
|
+
# Adapter Management
|
|
290
|
+
# =============================================================================
|
|
291
|
+
|
|
292
|
+
class AdapterReloadRequest(BaseModel):
    """Request to reload adapter weights."""

    adapter_path: Optional[str] = Field(
        default=None,
        description="Path to adapter directory (relative or absolute)",
    )
    # Off by default: only the adapter is swapped unless asked otherwise.
    reload_base: bool = Field(
        default=False,
        description="Reload the base model before applying adapter",
    )
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
class AdapterReloadResponse(BaseModel):
    """Response after adapter reload."""

    ok: bool = Field(description="Whether reload was successful")
    base_model: str = Field(description="Base model identifier")
    # None when no adapter is active.
    adapter_path: Optional[str] = Field(
        default=None, description="Currently loaded adapter path"
    )
    message: Optional[str] = Field(default=None, description="Status message")
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
# =============================================================================
|
|
311
|
+
# RLM State and History
|
|
312
|
+
# =============================================================================
|
|
313
|
+
|
|
314
|
+
class RLMTrainingMetrics(BaseModel):
    """RLM training metrics."""

    # Every metric is optional; unreported ones stay None.
    loss: Optional[float] = Field(default=None, description="Training loss")
    reward_mean: Optional[float] = Field(default=None, description="Mean reward")
    reward_std: Optional[float] = Field(
        default=None, description="Reward standard deviation"
    )
    kl_div: Optional[float] = Field(
        default=None, description="KL divergence from reference"
    )
    learning_rate: Optional[float] = Field(
        default=None, description="Current learning rate"
    )
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
class RLMState(BaseModel):
    """Current RLM training state."""

    # Lifecycle status is the only required field.
    status: Literal["idle", "running", "paused", "completed", "error"] = Field(
        description="Current training status",
    )
    iteration: Optional[int] = Field(
        default=None, description="Current training iteration"
    )
    total_iterations: Optional[int] = Field(
        default=None, description="Total planned iterations"
    )
    metrics: Optional[RLMTrainingMetrics] = Field(
        default=None, description="Current metrics"
    )
    started_at: Optional[int] = Field(default=None, description="Training start timestamp")
    updated_at: Optional[int] = Field(default=None, description="Last update timestamp")
    error_message: Optional[str] = Field(
        default=None, description="Error message if status is error"
    )
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
class RLMHistoryEntry(BaseModel):
    """Single RLM training history entry."""

    iteration: int = Field(description="Training iteration number")
    timestamp: int = Field(description="Unix timestamp")
    # Evaluation results; any of these may be missing for a given iteration.
    adapter_score: Optional[float] = Field(
        default=None, description="Adapter evaluation score"
    )
    base_score: Optional[float] = Field(default=None, description="Base model score")
    improvement: Optional[float] = Field(default=None, description="Relative improvement")
    metrics: Optional[Dict[str, Any]] = Field(
        default=None, description="Additional metrics"
    )
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
# =============================================================================
|
|
347
|
+
# Model Management
|
|
348
|
+
# =============================================================================
|
|
349
|
+
|
|
350
|
+
class ModelInfo(BaseModel):
    """Information about a cached model."""

    id: str = Field(description="Model identifier")
    path: str = Field(description="Local path to the model")
    size_bytes: Optional[int] = Field(default=None, description="Model size in bytes")
    # Storage format of the cached weights.
    format: Literal["mlx", "hf", "gguf"] = Field(description="Model format")
    has_adapter: bool = Field(
        default=False, description="Whether model has adapter weights"
    )
    adapter_path: Optional[str] = Field(
        default=None, description="Path to adapter if present"
    )
    metadata: Optional[Dict[str, Any]] = Field(
        default=None, description="Additional model metadata"
    )
    downloaded_at: Optional[int] = Field(default=None, description="Download timestamp")
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
class ModelsListResponse(BaseModel):
    """Response for listing cached models."""

    # Every field is required.
    models: List[ModelInfo] = Field(description="List of cached models")
    total: int = Field(description="Total number of models")
    cache_dir: str = Field(description="Current cache directory")
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
class ModelPullRequest(BaseModel):
    """Request to pull a model from HuggingFace."""

    # Pydantic v2 configuration. This replaces the deprecated v1-style inner
    # ``class Config`` and carries the same OpenAPI example.
    # ``protected_namespaces=()`` silences the warning pydantic v2 emits for
    # field names starting with ``model_`` (``model_id`` below).
    model_config = {
        "protected_namespaces": (),
        "json_schema_extra": {
            "example": {
                "model_id": "mlx-community/Llama-3.2-1B-Instruct-4bit",
                "convert": True,
                "quantize": False,
            }
        },
    }

    model_id: str = Field(..., description="HuggingFace model identifier", min_length=1)
    convert: bool = Field(True, description="Convert to MLX format")
    quantize: bool = Field(False, description="Quantize during conversion")
    # Quantization settings; the schema does not enforce that they are used
    # only when ``quantize`` is True.
    q_bits: Optional[int] = Field(4, description="Quantization bits", ge=1, le=8)
    q_group_size: Optional[int] = Field(64, description="Quantization group size")
    trust_remote_code: bool = Field(False, description="Trust remote code in model")
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
class ModelPullStatus(BaseModel):
    """Status of model pull operation."""

    status: Literal["pending", "downloading", "converting", "completed", "error"] = Field(
        description="Current pull status",
    )
    # Validated to lie within [0, 100].
    progress: Optional[float] = Field(
        default=None, description="Progress percentage (0-100)", ge=0, le=100
    )
    message: Optional[str] = Field(default=None, description="Status message")
    downloaded_bytes: Optional[int] = Field(
        default=None, description="Bytes downloaded so far"
    )
    total_bytes: Optional[int] = Field(
        default=None, description="Total bytes to download"
    )
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
class ModelPullResponse(BaseModel):
    """Response for model pull request."""

    # Pydantic v2 warns about fields whose names start with ``model_`` (its
    # protected namespace); ``model_id`` is intentional, so opt out.
    model_config = {"protected_namespaces": ()}

    ok: bool = Field(..., description="Whether pull was initiated successfully")
    model_id: str = Field(..., description="Model identifier")
    local_path: Optional[str] = Field(None, description="Local path where model will be stored")
    status: ModelPullStatus = Field(..., description="Current pull status")
    message: Optional[str] = Field(None, description="Status message")
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
# =============================================================================
|
|
409
|
+
# HuggingFace Token Management
|
|
410
|
+
# =============================================================================
|
|
411
|
+
|
|
412
|
+
class HFTokenRequest(BaseModel):
    """Request to store HuggingFace token."""

    # The token is a credential. ``repr=False`` keeps it out of the model's
    # repr/str, and therefore out of logs and tracebacks that print the request.
    # ``format: password`` marks it as secret in the generated OpenAPI schema.
    token: str = Field(
        ...,
        description="HuggingFace API token",
        min_length=1,
        repr=False,
        json_schema_extra={"format": "password"},
    )
    persist: bool = Field(
        True, description="Persist token to disk (encrypted if possible)"
    )
    # Named ``validate_token`` rather than ``validate`` (presumably to avoid
    # clashing with the BaseModel ``validate`` classmethod).
    validate_token: bool = Field(
        True, description="Validate token before storing"
    )
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
class HFTokenResponse(BaseModel):
    """Response after storing HF token."""

    ok: bool = Field(description="Whether token was stored successfully")
    validated: bool = Field(description="Whether token was validated")
    username: Optional[str] = Field(default=None, description="HF username if validated")
    message: str = Field(description="Status message")
    # Where the token ended up.
    storage_method: Literal["keyring", "file", "memory"] = Field(
        description="How the token is stored",
    )
|
mlxsmith/auth.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
from huggingface_hub import HfApi, get_token as hf_get_token, logout as hf_logout
|
|
9
|
+
from huggingface_hub import constants as hf_constants
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class AuthStatus:
    """Hugging Face authentication state, as built by ``get_status``."""

    # True when a non-empty token could be loaded.
    token_present: bool
    # Masked form of the token (first/last four chars); None when absent.
    token_hint: Optional[str] = None
    # Identity reported by ``whoami`` (name, fullname or email) after validation.
    user: Optional[str] = None
    # Non-fatal problems, e.g. a failed validation; a fresh list per instance.
    warnings: list[str] = field(default_factory=list)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _mask_token(token: str) -> str:
|
|
21
|
+
if not token:
|
|
22
|
+
return ""
|
|
23
|
+
if len(token) <= 8:
|
|
24
|
+
return "***"
|
|
25
|
+
return f"{token[:4]}...{token[-4:]}"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _token_path() -> Path:
    """Return the path of the Hugging Face token file.

    If the ``HF_HOME`` environment variable is set and non-empty when this is
    called, the token lives at ``$HF_HOME/token``. Otherwise
    ``huggingface_hub.constants.HF_TOKEN_PATH`` is used, falling back to
    ``<constants.HF_HOME>/token`` when that attribute does not exist.
    """
    env_home = os.environ.get("HF_HOME")
    if not env_home:
        default_path = Path(hf_constants.HF_HOME) / "token"
        return Path(getattr(hf_constants, "HF_TOKEN_PATH", default_path))
    return Path(env_home) / "token"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def load_token() -> Optional[str]:
    """Return the configured Hugging Face token, or ``None`` if there is none.

    ``huggingface_hub.get_token`` is consulted first. Any exception it raises
    is ignored so the token file from ``_token_path()`` can still be read. A
    token file that is empty or contains only whitespace yields ``None``
    (previously ``""``), which matches the ``Optional[str]`` contract.
    """
    try:
        token = hf_get_token()
    except Exception:
        # Best effort: fall back to reading the token file directly.
        token = None
    if token:
        return token
    path = _token_path()
    if path.exists():
        return path.read_text(encoding="utf-8").strip() or None
    return None
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def save_token(token: str) -> None:
    """Write *token* to the file at ``_token_path()``, readable only by the owner.

    The parent directory is created if needed. The token is a credential, so
    the file is created with mode 0o600 rather than the umask default (often
    world-readable). A pre-existing file is tightened to 0o600 after it has
    been truncated and before the new token is written. On Windows ``chmod``
    only affects the read-only flag, so the restriction is best effort there.
    """
    path = _token_path()
    path.parent.mkdir(parents=True, exist_ok=True)
    # os.open's mode argument only applies when the file is newly created.
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w", encoding="utf-8") as fh:
        # Also restrict files that already existed before writing the secret.
        os.chmod(path, 0o600)
        fh.write(token)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def delete_token() -> bool:
    """Remove the stored Hugging Face token file(s).

    Calls ``huggingface_hub.logout`` as a best effort (any exception is
    ignored), then deletes the file at ``_token_path()`` if it is still there.

    Returns:
        True if a token file existed beforehand and at least one such file is
        gone afterwards. Two locations are checked: ``_token_path()`` and
        ``huggingface_hub.constants.HF_TOKEN_PATH``. The original
        implementation returned True whenever ``logout`` did not raise. NOTE:
        ``logout`` can apparently return without error when nothing is stored,
        so its success alone is not treated as a removal.
    """
    path = _token_path()
    candidates = {path, Path(getattr(hf_constants, "HF_TOKEN_PATH", path))}
    existed_before = [p for p in candidates if p.exists()]
    try:
        hf_logout()
    except Exception:
        # Best effort: the local file removal below still runs.
        pass
    try:
        path.unlink()
    except FileNotFoundError:
        # Already removed (e.g. by hf_logout) or never existed.
        pass
    return any(not p.exists() for p in existed_before)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def get_status(validate: bool = False) -> AuthStatus:
    """Describe the current authentication state.

    Args:
        validate: when True, call ``HfApi().whoami`` with the token. On success
            the resulting name/fullname/email is stored in ``user``. On any
            exception a warning is recorded instead of raising.

    Returns:
        An ``AuthStatus``; ``token_present`` is False when no token is found.
    """
    token = load_token()
    if not token:
        return AuthStatus(token_present=False)

    result = AuthStatus(token_present=True, token_hint=_mask_token(token))
    if not validate:
        return result

    try:
        whoami = HfApi().whoami(token=token)
        result.user = whoami.get("name") or whoami.get("fullname") or whoami.get("email")
    except Exception as exc:
        result.warnings.append(f"Token validation failed: {exc}")
    return result
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def login(token: str, validate: bool = True) -> AuthStatus:
    """Persist *token* via ``save_token`` and return the resulting status.

    The token is written before any validation takes place. With
    ``validate=True`` the returned status carries a warning if the Hub rejects it.
    """
    save_token(token)
    status = get_status(validate=validate)
    return status
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def logout() -> bool:
    """Remove the stored token; returns whatever ``delete_token`` reports."""
    removed = delete_token()
    return removed
|
mlxsmith/bench.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import time
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
from .config import ProjectConfig
|
|
8
|
+
from .models import resolve_model_spec
|
|
9
|
+
from .util import ensure_dir, now_ts
|
|
10
|
+
from .llm.registry import get_llm_backend
|
|
11
|
+
from .accel import get_backend
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _bench_trainer(
    llm, cfg: ProjectConfig, prompt: str, max_tokens: int, reps: int, steps: int
) -> list[dict]:
    """Time ``steps`` SFT value-and-grad + optimizer updates per repetition."""
    opt, _params = llm.optimizer_and_params(lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)
    prompt_ids = llm.encode(prompt)
    # Synthetic training sequence: the prompt followed by ``max_tokens`` filler chars.
    ids = llm.encode(prompt + " " + "x" * max_tokens)

    # The loss closure does not depend on the step, so build it once.
    def loss_fn(_model):
        return llm.sft_loss(ids, train_on_prompt=cfg.train.train_on_prompt, prompt_len=len(prompt_ids))

    results = []
    for rep in range(reps):
        t0 = time.perf_counter()
        for _ in range(steps):
            _loss, grads = llm.value_and_grad(loss_fn)
            if grads is not None:
                llm.apply_grads(opt, grads)
        elapsed = max(time.perf_counter() - t0, 1e-6)
        results.append({"rep": rep, "steps": steps, "time_s": elapsed, "steps_per_s": steps / elapsed})
    return results


def _bench_generation(
    llm, cfg: ProjectConfig, prompt: str, max_tokens: int, reps: int, *, with_update: bool
) -> list[dict]:
    """Time greedy generation per repetition, optionally followed by one RL update.

    ``tps`` is generated tokens divided by the whole repetition time, so with
    ``with_update=True`` it includes the cost of the update.
    """
    opt = None
    if with_update:
        opt, _params = llm.optimizer_and_params(lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)

    results = []
    for rep in range(reps):
        t0 = time.perf_counter()
        gen = llm.generate(prompt, max_new_tokens=max_tokens, temperature=0.0)
        if with_update:
            def loss_fn(_model):
                return llm.rl_loss(gen.token_ids, prompt_len=gen.prompt_len, advantage=1.0)

            _loss, grads = llm.value_and_grad(loss_fn)
            if grads is not None:
                llm.apply_grads(opt, grads)
        elapsed = max(time.perf_counter() - t0, 1e-6)
        gen_tokens = max(0, len(gen.token_ids) - gen.prompt_len)
        results.append({"rep": rep, "tokens": gen_tokens, "time_s": elapsed, "tps": gen_tokens / elapsed})
    return results


def run_bench(
    project_root: Path,
    cfg: ProjectConfig,
    model_id_or_path: str,
    accel: str,
    *,
    prompt: str,
    max_tokens: int,
    reps: int,
    mode: str = "inference",
    steps: int = 5,
) -> Path:
    """Benchmark a model and write a JSON summary to ``<project_root>/bench/``.

    Modes (case-insensitive):
      * ``"trainer"``: time ``steps`` SFT updates per repetition (``avg_steps_per_s``).
      * ``"end_to_end"``: greedy generation plus one RL-loss update (``avg_tps``).
      * anything else: greedy generation only (``avg_tps``).

    ``reps`` and ``steps`` are clamped to at least 1. The summary records the
    counts that actually ran. Timing uses ``time.perf_counter`` (monotonic,
    high resolution) rather than the wall clock.

    Returns:
        Path of the written ``bench_<timestamp>.json`` file.
    """
    out_dir = ensure_dir(project_root / "bench")
    out_path = out_dir / f"bench_{now_ts()}.json"

    accel_backend = get_backend(accel)
    accel_backend.patch()

    llm = get_llm_backend(cfg.model.backend)
    base_model, adapter_path, _meta = resolve_model_spec(project_root, model_id_or_path, cfg)
    llm.load(
        base_model,
        max_seq_len=cfg.model.max_seq_len,
        dtype=cfg.model.dtype,
        trust_remote_code=cfg.model.trust_remote_code,
    )
    if adapter_path:
        llm.apply_adapter(str(adapter_path))

    mode = (mode or "inference").lower()
    n_reps = max(1, reps)
    n_steps = max(1, steps)

    if mode == "trainer":
        results = _bench_trainer(llm, cfg, prompt, max_tokens, n_reps, n_steps)
        rate_key, metric_name = "steps_per_s", "avg_steps_per_s"
    else:
        # "end_to_end" adds one RL update per rep; any other mode is plain inference.
        results = _bench_generation(
            llm, cfg, prompt, max_tokens, n_reps, with_update=(mode == "end_to_end")
        )
        rate_key, metric_name = "tps", "avg_tps"

    # results is never empty because n_reps >= 1.
    avg_metric = sum(r[rate_key] for r in results) / len(results)

    summary = {
        "model": base_model,
        "adapter": str(adapter_path) if adapter_path else None,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "reps": n_reps,
        "mode": mode,
        "steps": n_steps if mode == "trainer" else None,
        "results": results,
        metric_name: avg_metric,
        "accel": accel_backend.name,
    }
    out_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    return out_path
|