mlxsmith 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxsmith/__init__.py +2 -0
- mlxsmith/accel/__init__.py +10 -0
- mlxsmith/accel/base.py +17 -0
- mlxsmith/accel/none.py +13 -0
- mlxsmith/accel/zmlx_backend.py +42 -0
- mlxsmith/adapters.py +46 -0
- mlxsmith/api/__init__.py +48 -0
- mlxsmith/api/handlers.py +1217 -0
- mlxsmith/api/schemas.py +436 -0
- mlxsmith/auth.py +88 -0
- mlxsmith/bench.py +102 -0
- mlxsmith/cli.py +950 -0
- mlxsmith/config.py +543 -0
- mlxsmith/config_models.py +261 -0
- mlxsmith/data.py +493 -0
- mlxsmith/envs/__init__.py +33 -0
- mlxsmith/envs/system.py +388 -0
- mlxsmith/envs/token_env.py +191 -0
- mlxsmith/eval.py +112 -0
- mlxsmith/infer.py +140 -0
- mlxsmith/llm/__init__.py +16 -0
- mlxsmith/llm/backend.py +126 -0
- mlxsmith/llm/interface.py +212 -0
- mlxsmith/llm/mlx_lm_backend.py +509 -0
- mlxsmith/llm/mock_backend.py +228 -0
- mlxsmith/llm/registry.py +12 -0
- mlxsmith/models.py +257 -0
- mlxsmith/orchestrator/__init__.py +25 -0
- mlxsmith/orchestrator/daemon.py +454 -0
- mlxsmith/orchestrator/inference_worker.py +496 -0
- mlxsmith/orchestrator/queue.py +355 -0
- mlxsmith/orchestrator/trainer_worker.py +437 -0
- mlxsmith/rlm/__init__.py +8 -0
- mlxsmith/rlm/corpus.py +74 -0
- mlxsmith/rlm/gating.py +90 -0
- mlxsmith/rlm/generate.py +249 -0
- mlxsmith/rlm/history.py +12 -0
- mlxsmith/rlm/inference.py +150 -0
- mlxsmith/rlm/loop.py +1297 -0
- mlxsmith/rlm/mutate.py +82 -0
- mlxsmith/rlm/trainer.py +73 -0
- mlxsmith/rlm/weights.py +263 -0
- mlxsmith/runs.py +44 -0
- mlxsmith/sdk/__init__.py +392 -0
- mlxsmith/sdk/future.py +486 -0
- mlxsmith/sdk/losses.py +262 -0
- mlxsmith/sdk/sampling_client.py +729 -0
- mlxsmith/sdk/training_client.py +676 -0
- mlxsmith/server.py +376 -0
- mlxsmith/train/__init__.py +0 -0
- mlxsmith/train/distill.py +279 -0
- mlxsmith/train/lora.py +280 -0
- mlxsmith/train/pref.py +180 -0
- mlxsmith/train/rft.py +458 -0
- mlxsmith/train/sft.py +151 -0
- mlxsmith/util.py +174 -0
- mlxsmith/verifiers/__init__.py +3 -0
- mlxsmith/verifiers/compose.py +109 -0
- mlxsmith/verifiers/docker_verifier.py +111 -0
- mlxsmith/verifiers/jsonschema.py +54 -0
- mlxsmith/verifiers/pytest_verifier.py +82 -0
- mlxsmith/verifiers/regex.py +15 -0
- mlxsmith/verifiers/types.py +10 -0
- mlxsmith-0.1.0.dist-info/METADATA +163 -0
- mlxsmith-0.1.0.dist-info/RECORD +69 -0
- mlxsmith-0.1.0.dist-info/WHEEL +5 -0
- mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
- mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
- mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/config_models.py
@@ -0,0 +1,261 @@
+"""Pydantic models for MLXSmith configuration sections."""
+
+from __future__ import annotations
+
+from typing import Dict, List, Literal, Optional, Any
+
+from pydantic import BaseModel, Field, field_validator
+
+AccelBackendName = Literal["none", "zmlx"]
+
+
+class ModelConfig(BaseModel):
+    """Model configuration for MLXSmith."""
+
+    id: str = Field(
+        default="mlx-community/Llama-3.2-3B-Instruct-4bit",
+        description="HF model id or local path",
+    )
+    backend: str = Field(default="mlx-lm", description="LLM backend implementation")
+    dtype: str = Field(default="bf16")
+    quantization: str = Field(default="none", description="none|q4|q6|q8")
+    max_seq_len: int = Field(default=8192)
+    trust_remote_code: bool = Field(default=False)
+    use_chat_template: bool = Field(default=True)
+
+    @field_validator("quantization")
+    @classmethod
+    def validate_quantization(cls, v: str) -> str:
+        allowed = ["none", "q4", "q6", "q8"]
+        if v not in allowed:
+            raise ValueError(f"quantization must be one of {allowed}, got {v}")
+        return v
+
+
+class AccelConfig(BaseModel):
+    """Acceleration backend configuration."""
+
+    backend: AccelBackendName = Field(default="none")
+    compile_cache: str = Field(default="cache/compiled_kernels")
+
+
+class TrainConfig(BaseModel):
+    """Training configuration for SFT and other training modes."""
+
+    seed: int = 1337
+    batch_size: int = 1
+    grad_accum: int = 8
+    lr: float = 2e-4
+    weight_decay: float = 0.0
+    iters: int = 1000
+    save_every: int = 100
+    eval_every: int = 100
+    log_every: int = 10
+    train_on_prompt: bool = False
+    max_grad_norm: float = 1.0
+
+    @field_validator("lr", "weight_decay")
+    @classmethod
+    def validate_positive(cls, v: float) -> float:
+        if v < 0:
+            raise ValueError("value must be non-negative")
+        return v
+
+
+class LoraConfig(BaseModel):
+    """LoRA/DoRA adapter configuration."""
+
+    r: int = 16
+    alpha: int = 32
+    dropout: float = 0.05
+    target_modules: List[str] = Field(default_factory=lambda: ["q_proj", "v_proj", "o_proj"])
+    num_layers: int = Field(default=0, description="0 = all layers (MLX-LM LoRA)")
+    scale: Optional[float] = Field(default=None, description="Optional LoRA scale (overrides alpha/r)")
+    fine_tune_type: Literal["lora", "dora", "full"] = "lora"
+
+    @field_validator("r", "alpha")
+    @classmethod
+    def validate_positive_int(cls, v: int) -> int:
+        if v <= 0:
+            raise ValueError("value must be positive")
+        return v
+
+    @field_validator("dropout")
+    @classmethod
+    def validate_dropout(cls, v: float) -> float:
+        if not 0 <= v <= 1:
+            raise ValueError("dropout must be between 0 and 1")
+        return v
+
+
+class PrefConfig(BaseModel):
+    """Preference tuning configuration (DPO, ORPO, GRPO)."""
+
+    algo: Literal["dpo", "orpo", "grpo"] = "dpo"
+    beta: float = 0.1
+    kl_coeff: float = 0.0
+    reference_model: Optional[str] = None
+
+
+class RftConfig(BaseModel):
+    """Reinforcement fine-tuning configuration."""
+
+    algo: Literal["grpo"] = "grpo"
+    rollouts: int = 8
+    kl_coeff: float = 0.02
+    max_steps_per_task: int = 1
+    temperature: float = 0.8
+    max_new_tokens: int = 256
+    normalize_advantage: bool = True
+    reference_model: Optional[str] = None
+
+
+class ServeConfig(BaseModel):
+    """Server configuration for inference serving."""
+
+    api: Literal["openai", "simple"] = "openai"
+    host: str = "0.0.0.0"
+    port: int = 8080
+    ui: bool = False
+    stream: bool = True
+
+    @field_validator("port")
+    @classmethod
+    def validate_port(cls, v: int) -> int:
+        if not 1 <= v <= 65535:
+            raise ValueError("port must be between 1 and 65535")
+        return v
+
+
+class InferConfig(BaseModel):
+    """Inference configuration for generation."""
+
+    max_new_tokens: int = 256
+    temperature: float = 0.7
+    top_p: float = 1.0
+    top_k: Optional[int] = None
+
+    @field_validator("temperature")
+    @classmethod
+    def validate_temperature(cls, v: float) -> float:
+        if v < 0:
+            raise ValueError("temperature must be non-negative")
+        return v
+
+    @field_validator("top_p")
+    @classmethod
+    def validate_top_p(cls, v: float) -> float:
+        if not 0 <= v <= 1:
+            raise ValueError("top_p must be between 0 and 1")
+        return v
+
+
+class LoggingConfig(BaseModel):
+    """Logging configuration."""
+
+    level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
+    file: Optional[str] = None
+    format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+
+
+# CLI argument aliases for field mapping
+CLI_ALIASES: dict[str, tuple[str, ...]] = {
+    "learning_rate": ("train", "lr"),
+    "lr": ("train", "lr"),
+    "batch_size": ("train", "batch_size"),
+    "iters": ("train", "iters"),
+    "model_id": ("model", "id"),
+    "accel_backend": ("accel", "backend"),
+    "host": ("serve", "host"),
+    "port": ("serve", "port"),
+    "ui": ("serve", "ui"),
+    "iterations": ("rlm", "iterations"),
+    "rollouts": ("rft", "rollouts"),
+    "algo": ("pref", "algo"),
+}
+
+
+class RlmConfig(BaseModel):
+    """Recursive Language Model (RLM) loop configuration."""
+
+    iterations: int = 50  # 0 = infinite
+    sleep_between: int = 0  # seconds
+    tasks_per_iter: int = 80
+    rollouts_per_task: int = 8
+    attempts_per_task: int = 3
+    corpus_max: int = 8000
+    mix_old_ratio: float = 0.4
+    hard_ratio: float = 0.6
+    mutations_per_task: int = 1
+    use_task_mutation: bool = True
+    gating: Literal["strict", "threshold", "ema"] = "strict"
+    gating_threshold: float = 0.0
+    gating_ema_alpha: float = 0.2
+    infer_staleness: int = 0  # iterations of staleness allowed between trainer and inference weights
+    require_recursion: bool = False
+    task_domains: List[str] = Field(default_factory=lambda: ["strings", "arrays", "math", "dp", "graphs"])
+    benchmark_suite: str = "eval/suites/rlm_bench.yaml"
+    holdout_suite: Optional[str] = "eval/suites/rlm_holdout.yaml"
+    verifier_timeout_s: int = 30
+    verifier_backend: Literal["pytest", "docker"] = "pytest"
+    docker_image: str = "python:3.11-slim"
+    docker_memory_mb: int = 512
+    docker_cpus: float = 1.0
+    docker_pids: int = 128
+    task_gen_max_new_tokens: int = 1024
+    similarity_threshold: float = 0.85
+    min_task_desc_len: int = 10
+    min_task_asserts: int = 2
+    max_task_prompt_len: int = 2000
+    min_task_tests_len: int = 20
+    max_task_tests_len: int = 8000
+    blocked_task_patterns: List[str] = Field(
+        default_factory=lambda: [
+            r"\bsubprocess\b",
+            r"\bos\.system\b",
+            r"\bshutil\.rmtree\b",
+            r"\brm\s+-rf\b",
+            r"\brequests\b",
+            r"\burllib\b",
+            r"\bsocket\b",
+            r"\bhttp[s]?://",
+            r"\bpip\s+install\b",
+            r"\bapt-get\b",
+            r"\bbrew\s+install\b",
+        ]
+    )
+
+    @field_validator("mix_old_ratio", "hard_ratio", "gating_ema_alpha", "similarity_threshold")
+    @classmethod
+    def validate_ratio(cls, v: float) -> float:
+        if not 0 <= v <= 1:
+            raise ValueError("ratio must be between 0 and 1")
+        return v
+
+
+class ProjectConfig(BaseModel):
+    """Root configuration model containing all sections."""
+
+    model: ModelConfig = Field(default_factory=ModelConfig)
+    accel: AccelConfig = Field(default_factory=AccelConfig)
+    train: TrainConfig = Field(default_factory=TrainConfig)
+    lora: LoraConfig = Field(default_factory=LoraConfig)
+    pref: PrefConfig = Field(default_factory=PrefConfig)
+    rft: RftConfig = Field(default_factory=RftConfig)
+    infer: InferConfig = Field(default_factory=InferConfig)
+    serve: ServeConfig = Field(default_factory=ServeConfig)
+    rlm: RlmConfig = Field(default_factory=RlmConfig)
+    logging: LoggingConfig = Field(default_factory=LoggingConfig)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert configuration to dictionary."""
+        return self.model_dump()
+
+    def to_yaml(self) -> str:
+        """Convert configuration to YAML string."""
+        import yaml
+        return yaml.safe_dump(self.model_dump(), sort_keys=False)
+
+    def to_json(self, indent: int = 2) -> str:
+        """Convert configuration to JSON string."""
+        return self.model_dump_json(indent=indent)