mlxsmith 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. mlxsmith/__init__.py +2 -0
  2. mlxsmith/accel/__init__.py +10 -0
  3. mlxsmith/accel/base.py +17 -0
  4. mlxsmith/accel/none.py +13 -0
  5. mlxsmith/accel/zmlx_backend.py +42 -0
  6. mlxsmith/adapters.py +46 -0
  7. mlxsmith/api/__init__.py +48 -0
  8. mlxsmith/api/handlers.py +1217 -0
  9. mlxsmith/api/schemas.py +436 -0
  10. mlxsmith/auth.py +88 -0
  11. mlxsmith/bench.py +102 -0
  12. mlxsmith/cli.py +950 -0
  13. mlxsmith/config.py +543 -0
  14. mlxsmith/config_models.py +261 -0
  15. mlxsmith/data.py +493 -0
  16. mlxsmith/envs/__init__.py +33 -0
  17. mlxsmith/envs/system.py +388 -0
  18. mlxsmith/envs/token_env.py +191 -0
  19. mlxsmith/eval.py +112 -0
  20. mlxsmith/infer.py +140 -0
  21. mlxsmith/llm/__init__.py +16 -0
  22. mlxsmith/llm/backend.py +126 -0
  23. mlxsmith/llm/interface.py +212 -0
  24. mlxsmith/llm/mlx_lm_backend.py +509 -0
  25. mlxsmith/llm/mock_backend.py +228 -0
  26. mlxsmith/llm/registry.py +12 -0
  27. mlxsmith/models.py +257 -0
  28. mlxsmith/orchestrator/__init__.py +25 -0
  29. mlxsmith/orchestrator/daemon.py +454 -0
  30. mlxsmith/orchestrator/inference_worker.py +496 -0
  31. mlxsmith/orchestrator/queue.py +355 -0
  32. mlxsmith/orchestrator/trainer_worker.py +437 -0
  33. mlxsmith/rlm/__init__.py +8 -0
  34. mlxsmith/rlm/corpus.py +74 -0
  35. mlxsmith/rlm/gating.py +90 -0
  36. mlxsmith/rlm/generate.py +249 -0
  37. mlxsmith/rlm/history.py +12 -0
  38. mlxsmith/rlm/inference.py +150 -0
  39. mlxsmith/rlm/loop.py +1297 -0
  40. mlxsmith/rlm/mutate.py +82 -0
  41. mlxsmith/rlm/trainer.py +73 -0
  42. mlxsmith/rlm/weights.py +263 -0
  43. mlxsmith/runs.py +44 -0
  44. mlxsmith/sdk/__init__.py +392 -0
  45. mlxsmith/sdk/future.py +486 -0
  46. mlxsmith/sdk/losses.py +262 -0
  47. mlxsmith/sdk/sampling_client.py +729 -0
  48. mlxsmith/sdk/training_client.py +676 -0
  49. mlxsmith/server.py +376 -0
  50. mlxsmith/train/__init__.py +0 -0
  51. mlxsmith/train/distill.py +279 -0
  52. mlxsmith/train/lora.py +280 -0
  53. mlxsmith/train/pref.py +180 -0
  54. mlxsmith/train/rft.py +458 -0
  55. mlxsmith/train/sft.py +151 -0
  56. mlxsmith/util.py +174 -0
  57. mlxsmith/verifiers/__init__.py +3 -0
  58. mlxsmith/verifiers/compose.py +109 -0
  59. mlxsmith/verifiers/docker_verifier.py +111 -0
  60. mlxsmith/verifiers/jsonschema.py +54 -0
  61. mlxsmith/verifiers/pytest_verifier.py +82 -0
  62. mlxsmith/verifiers/regex.py +15 -0
  63. mlxsmith/verifiers/types.py +10 -0
  64. mlxsmith-0.1.0.dist-info/METADATA +163 -0
  65. mlxsmith-0.1.0.dist-info/RECORD +69 -0
  66. mlxsmith-0.1.0.dist-info/WHEEL +5 -0
  67. mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
  68. mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,261 @@
1
+ """Pydantic models for MLXSmith configuration sections."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Dict, List, Literal, Optional, Any
6
+
7
+ from pydantic import BaseModel, Field, field_validator
8
+
9
# Names of the supported acceleration backends (matches mlxsmith/accel/none.py
# and mlxsmith/accel/zmlx_backend.py).
AccelBackendName = Literal["none", "zmlx"]
10
+
11
+
12
class ModelConfig(BaseModel):
    """Model selection and loading options for MLXSmith."""

    # HF model id or local filesystem path of the model to load.
    id: str = Field(
        default="mlx-community/Llama-3.2-3B-Instruct-4bit",
        description="HF model id or local path",
    )
    backend: str = Field(default="mlx-lm", description="LLM backend implementation")
    dtype: str = Field(default="bf16")
    quantization: str = Field(default="none", description="none|q4|q6|q8")
    max_seq_len: int = Field(default=8192)
    trust_remote_code: bool = Field(default=False)
    use_chat_template: bool = Field(default=True)

    @field_validator("quantization")
    @classmethod
    def validate_quantization(cls, v: str) -> str:
        """Accept only the supported quantization modes."""
        allowed = ["none", "q4", "q6", "q8"]
        if v in allowed:
            return v
        raise ValueError(f"quantization must be one of {allowed}, got {v}")
33
+
34
+
35
class AccelConfig(BaseModel):
    """Settings for the optional kernel-acceleration backend."""

    # Which acceleration backend to use; "none" disables acceleration.
    backend: AccelBackendName = "none"
    # Directory used to cache compiled kernels.
    compile_cache: str = "cache/compiled_kernels"
40
+
41
+
42
class TrainConfig(BaseModel):
    """Hyperparameters shared by SFT and the other training modes."""

    seed: int = 1337            # RNG seed for reproducibility
    batch_size: int = 1
    grad_accum: int = 8
    lr: float = 2e-4
    weight_decay: float = 0.0
    iters: int = 1000
    save_every: int = 100       # checkpoint interval, in iterations
    eval_every: int = 100
    log_every: int = 10
    train_on_prompt: bool = False
    max_grad_norm: float = 1.0

    @field_validator("lr", "weight_decay")
    @classmethod
    def validate_positive(cls, v: float) -> float:
        """Learning rate and weight decay must be >= 0."""
        if v >= 0:
            return v
        raise ValueError("value must be non-negative")
63
+
64
+
65
class LoraConfig(BaseModel):
    """LoRA/DoRA adapter configuration."""

    r: int = 16                 # adapter rank
    alpha: int = 32             # LoRA alpha scaling numerator
    dropout: float = 0.05
    target_modules: List[str] = Field(default_factory=lambda: ["q_proj", "v_proj", "o_proj"])
    num_layers: int = Field(default=0, description="0 = all layers (MLX-LM LoRA)")
    scale: Optional[float] = Field(default=None, description="Optional LoRA scale (overrides alpha/r)")
    fine_tune_type: Literal["lora", "dora", "full"] = "lora"

    @field_validator("r", "alpha")
    @classmethod
    def validate_positive_int(cls, v: int) -> int:
        """Rank and alpha must be strictly positive."""
        if v > 0:
            return v
        raise ValueError("value must be positive")

    @field_validator("dropout")
    @classmethod
    def validate_dropout(cls, v: float) -> float:
        """Dropout is a probability and must lie in [0, 1]."""
        if 0 <= v <= 1:
            return v
        raise ValueError("dropout must be between 0 and 1")
89
+
90
+
91
class PrefConfig(BaseModel):
    """Preference tuning configuration (DPO, ORPO, GRPO)."""

    # Preference-optimization algorithm to run.
    algo: Literal["dpo", "orpo", "grpo"] = Field(default="dpo")
    beta: float = Field(default=0.1)
    kl_coeff: float = Field(default=0.0)
    # Separate reference model id; None presumably reuses the policy model — confirm in pref trainer.
    reference_model: Optional[str] = Field(default=None)
98
+
99
+
100
class RftConfig(BaseModel):
    """Reinforcement fine-tuning configuration."""

    # Only GRPO is currently supported.
    algo: Literal["grpo"] = "grpo"
    rollouts: int = 8
    kl_coeff: float = 0.02
    max_steps_per_task: int = 1
    temperature: float = 0.8
    max_new_tokens: int = 256
    normalize_advantage: bool = True
    # Separate reference model id; None presumably reuses the policy model — confirm in rft trainer.
    reference_model: Optional[str] = None
111
+
112
+
113
class ServeConfig(BaseModel):
    """Server configuration for inference serving."""

    # API flavour exposed by the server.
    api: Literal["openai", "simple"] = "openai"
    host: str = "0.0.0.0"
    port: int = 8080
    ui: bool = False
    stream: bool = True

    @field_validator("port")
    @classmethod
    def validate_port(cls, v: int) -> int:
        """Port must fall within the valid TCP port range."""
        if 1 <= v <= 65535:
            return v
        raise ValueError("port must be between 1 and 65535")
128
+
129
+
130
class InferConfig(BaseModel):
    """Default sampling parameters for generation."""

    max_new_tokens: int = 256
    temperature: float = 0.7
    top_p: float = 1.0
    # None presumably disables top-k truncation — confirm against the sampler.
    top_k: Optional[int] = None

    @field_validator("temperature")
    @classmethod
    def validate_temperature(cls, v: float) -> float:
        """Sampling temperature must not be negative."""
        if v >= 0:
            return v
        raise ValueError("temperature must be non-negative")

    @field_validator("top_p")
    @classmethod
    def validate_top_p(cls, v: float) -> float:
        """Nucleus-sampling mass must lie in [0, 1]."""
        if 0 <= v <= 1:
            return v
        raise ValueError("top_p must be between 0 and 1")
151
+
152
+
153
class LoggingConfig(BaseModel):
    """Logging configuration."""

    # Standard library logging level name.
    level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
    # Optional log file path; None presumably logs to console only — confirm in logging setup.
    file: Optional[str] = None
    # Format string passed to the logging module.
    format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
159
+
160
+
161
# Maps a CLI argument name to the (section, field) path it sets in the
# project configuration. Several spellings may target the same field
# (e.g. both "learning_rate" and "lr" set train.lr).
CLI_ALIASES: dict[str, tuple[str, ...]] = {
    "learning_rate": ("train", "lr"),
    "lr": ("train", "lr"),
    "batch_size": ("train", "batch_size"),
    "iters": ("train", "iters"),
    "model_id": ("model", "id"),
    "accel_backend": ("accel", "backend"),
    "host": ("serve", "host"),
    "port": ("serve", "port"),
    "ui": ("serve", "ui"),
    "iterations": ("rlm", "iterations"),
    "rollouts": ("rft", "rollouts"),
    "algo": ("pref", "algo"),
}
176
+
177
+
178
class RlmConfig(BaseModel):
    """Recursive Language Model (RLM) loop configuration."""

    # -- loop control --
    iterations: int = 50        # 0 = infinite
    sleep_between: int = 0      # seconds
    tasks_per_iter: int = 80
    rollouts_per_task: int = 8
    attempts_per_task: int = 3

    # -- task corpus mixing --
    corpus_max: int = 8000
    mix_old_ratio: float = 0.4
    hard_ratio: float = 0.6
    mutations_per_task: int = 1
    use_task_mutation: bool = True

    # -- promotion gating --
    gating: Literal["strict", "threshold", "ema"] = "strict"
    gating_threshold: float = 0.0
    gating_ema_alpha: float = 0.2
    # Iterations of staleness allowed between trainer and inference weights.
    infer_staleness: int = 0
    require_recursion: bool = False

    # -- task domains and evaluation suites --
    task_domains: List[str] = Field(default_factory=lambda: ["strings", "arrays", "math", "dp", "graphs"])
    benchmark_suite: str = "eval/suites/rlm_bench.yaml"
    holdout_suite: Optional[str] = "eval/suites/rlm_holdout.yaml"

    # -- verifier sandbox --
    verifier_timeout_s: int = 30
    verifier_backend: Literal["pytest", "docker"] = "pytest"
    docker_image: str = "python:3.11-slim"
    docker_memory_mb: int = 512
    docker_cpus: float = 1.0
    docker_pids: int = 128

    # -- generated-task constraints --
    task_gen_max_new_tokens: int = 1024
    similarity_threshold: float = 0.85
    min_task_desc_len: int = 10
    min_task_asserts: int = 2
    max_task_prompt_len: int = 2000
    min_task_tests_len: int = 20
    max_task_tests_len: int = 8000
    # Regexes that disqualify a generated task (shell, network, and
    # filesystem-destructive constructs).
    blocked_task_patterns: List[str] = Field(
        default_factory=lambda: [
            r"\bsubprocess\b",
            r"\bos\.system\b",
            r"\bshutil\.rmtree\b",
            r"\brm\s+-rf\b",
            r"\brequests\b",
            r"\burllib\b",
            r"\bsocket\b",
            r"\bhttp[s]?://",
            r"\bpip\s+install\b",
            r"\bapt-get\b",
            r"\bbrew\s+install\b",
        ]
    )

    @field_validator("mix_old_ratio", "hard_ratio", "gating_ema_alpha", "similarity_threshold")
    @classmethod
    def validate_ratio(cls, v: float) -> float:
        """These fields are fractions and must lie in [0, 1]."""
        if 0 <= v <= 1:
            return v
        raise ValueError("ratio must be between 0 and 1")
234
+
235
+
236
class ProjectConfig(BaseModel):
    """Root configuration model aggregating every config section."""

    model: ModelConfig = Field(default_factory=ModelConfig)
    accel: AccelConfig = Field(default_factory=AccelConfig)
    train: TrainConfig = Field(default_factory=TrainConfig)
    lora: LoraConfig = Field(default_factory=LoraConfig)
    pref: PrefConfig = Field(default_factory=PrefConfig)
    rft: RftConfig = Field(default_factory=RftConfig)
    infer: InferConfig = Field(default_factory=InferConfig)
    serve: ServeConfig = Field(default_factory=ServeConfig)
    rlm: RlmConfig = Field(default_factory=RlmConfig)
    logging: LoggingConfig = Field(default_factory=LoggingConfig)

    def to_dict(self) -> Dict[str, Any]:
        """Return the full configuration as a plain dictionary."""
        return self.model_dump()

    def to_yaml(self) -> str:
        """Serialize the configuration to a YAML string, preserving field order."""
        # Local import so PyYAML is only required when YAML output is requested.
        import yaml
        return yaml.safe_dump(self.model_dump(), sort_keys=False)

    def to_json(self, indent: int = 2) -> str:
        """Serialize the configuration to an indented JSON string."""
        return self.model_dump_json(indent=indent)