mlxsmith 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. mlxsmith/__init__.py +2 -0
  2. mlxsmith/accel/__init__.py +10 -0
  3. mlxsmith/accel/base.py +17 -0
  4. mlxsmith/accel/none.py +13 -0
  5. mlxsmith/accel/zmlx_backend.py +42 -0
  6. mlxsmith/adapters.py +46 -0
  7. mlxsmith/api/__init__.py +48 -0
  8. mlxsmith/api/handlers.py +1217 -0
  9. mlxsmith/api/schemas.py +436 -0
  10. mlxsmith/auth.py +88 -0
  11. mlxsmith/bench.py +102 -0
  12. mlxsmith/cli.py +950 -0
  13. mlxsmith/config.py +543 -0
  14. mlxsmith/config_models.py +261 -0
  15. mlxsmith/data.py +493 -0
  16. mlxsmith/envs/__init__.py +33 -0
  17. mlxsmith/envs/system.py +388 -0
  18. mlxsmith/envs/token_env.py +191 -0
  19. mlxsmith/eval.py +112 -0
  20. mlxsmith/infer.py +140 -0
  21. mlxsmith/llm/__init__.py +16 -0
  22. mlxsmith/llm/backend.py +126 -0
  23. mlxsmith/llm/interface.py +212 -0
  24. mlxsmith/llm/mlx_lm_backend.py +509 -0
  25. mlxsmith/llm/mock_backend.py +228 -0
  26. mlxsmith/llm/registry.py +12 -0
  27. mlxsmith/models.py +257 -0
  28. mlxsmith/orchestrator/__init__.py +25 -0
  29. mlxsmith/orchestrator/daemon.py +454 -0
  30. mlxsmith/orchestrator/inference_worker.py +496 -0
  31. mlxsmith/orchestrator/queue.py +355 -0
  32. mlxsmith/orchestrator/trainer_worker.py +437 -0
  33. mlxsmith/rlm/__init__.py +8 -0
  34. mlxsmith/rlm/corpus.py +74 -0
  35. mlxsmith/rlm/gating.py +90 -0
  36. mlxsmith/rlm/generate.py +249 -0
  37. mlxsmith/rlm/history.py +12 -0
  38. mlxsmith/rlm/inference.py +150 -0
  39. mlxsmith/rlm/loop.py +1297 -0
  40. mlxsmith/rlm/mutate.py +82 -0
  41. mlxsmith/rlm/trainer.py +73 -0
  42. mlxsmith/rlm/weights.py +263 -0
  43. mlxsmith/runs.py +44 -0
  44. mlxsmith/sdk/__init__.py +392 -0
  45. mlxsmith/sdk/future.py +486 -0
  46. mlxsmith/sdk/losses.py +262 -0
  47. mlxsmith/sdk/sampling_client.py +729 -0
  48. mlxsmith/sdk/training_client.py +676 -0
  49. mlxsmith/server.py +376 -0
  50. mlxsmith/train/__init__.py +0 -0
  51. mlxsmith/train/distill.py +279 -0
  52. mlxsmith/train/lora.py +280 -0
  53. mlxsmith/train/pref.py +180 -0
  54. mlxsmith/train/rft.py +458 -0
  55. mlxsmith/train/sft.py +151 -0
  56. mlxsmith/util.py +174 -0
  57. mlxsmith/verifiers/__init__.py +3 -0
  58. mlxsmith/verifiers/compose.py +109 -0
  59. mlxsmith/verifiers/docker_verifier.py +111 -0
  60. mlxsmith/verifiers/jsonschema.py +54 -0
  61. mlxsmith/verifiers/pytest_verifier.py +82 -0
  62. mlxsmith/verifiers/regex.py +15 -0
  63. mlxsmith/verifiers/types.py +10 -0
  64. mlxsmith-0.1.0.dist-info/METADATA +163 -0
  65. mlxsmith-0.1.0.dist-info/RECORD +69 -0
  66. mlxsmith-0.1.0.dist-info/WHEEL +5 -0
  67. mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
  68. mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/config.py ADDED
@@ -0,0 +1,543 @@
1
+ """MLXSmith configuration management with pydantic-settings.
2
+
3
+ Config precedence (highest to lowest):
4
+ 1. CLI arguments
5
+ 2. Config file (TOML/YAML/JSON)
6
+ 3. Environment variables (MLXSMITH__*)
7
+ 4. Default values
8
+
9
+ Environment variables use double underscore as nested delimiter:
10
+ MLXSMITH__MODEL__ID=custom/model
11
+ MLXSMITH__TRAIN__LR=0.001
12
+ MLXSMITH__SERVE__PORT=9000
13
+
14
+ Config files support @path syntax:
15
+ mlxsmith sft --config @production.toml
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import json
21
+ import os
22
+ from pathlib import Path
23
+ from typing import Any, Dict, List, Optional, Tuple, Union
24
+
25
+ import yaml
26
+ from pydantic import Field, ValidationError
27
+ from pydantic_settings import BaseSettings, SettingsConfigDict
28
+
29
+ try: # Python 3.11+
30
+ import tomllib
31
+ except ModuleNotFoundError: # pragma: no cover - 3.10 fallback
32
+ import tomli as tomllib
33
+
34
+ from .config_models import (
35
+ AccelConfig,
36
+ InferConfig,
37
+ LoggingConfig,
38
+ LoraConfig,
39
+ ModelConfig,
40
+ PrefConfig,
41
+ ProjectConfig,
42
+ RftConfig,
43
+ RlmConfig,
44
+ ServeConfig,
45
+ TrainConfig,
46
+ )
47
+
48
# Public API of this module (what ``from mlxsmith.config import *`` exports).
__all__ = [
    # Models
    "ProjectConfig",
    "ModelConfig",
    "TrainConfig",
    "LoraConfig",
    "PrefConfig",
    "RftConfig",
    "InferConfig",
    "ServeConfig",
    "RlmConfig",
    "AccelConfig",
    "LoggingConfig",
    # Functions
    "load_config",
    "get_config",
    # Public helper defined in this module; it was previously missing here.
    "get_config_sources",
    "resolve_config_path",
    "write_default_config",
    "dump_config",
    "show_merged_config",
]
69
+
70
+
71
class ProjectSettings(BaseSettings):
    """Pydantic-settings model used only to read ``MLXSMITH__*`` environment variables.

    This mirrors ProjectConfig but is used specifically for env var parsing.
    It declares one field per config section so that pydantic-settings can
    map a variable such as ``MLXSMITH__TRAIN__LR`` onto ``train.lr``.
    ``load_config`` instantiates it, calls ``model_dump()``, and deep-merges
    the result over the ``ProjectConfig`` defaults. A ``ValidationError``
    raised while constructing it is turned into a warning there.

    NOTE(review): the section list is expected to match ProjectConfig's
    sections. Confirm against config_models when a section is added there.
    """

    # One field per config section; each defaults to a freshly built section model.
    model: ModelConfig = Field(default_factory=ModelConfig)
    accel: AccelConfig = Field(default_factory=AccelConfig)
    train: TrainConfig = Field(default_factory=TrainConfig)
    lora: LoraConfig = Field(default_factory=LoraConfig)
    pref: PrefConfig = Field(default_factory=PrefConfig)
    rft: RftConfig = Field(default_factory=RftConfig)
    infer: InferConfig = Field(default_factory=InferConfig)
    serve: ServeConfig = Field(default_factory=ServeConfig)
    rlm: RlmConfig = Field(default_factory=RlmConfig)
    logging: LoggingConfig = Field(default_factory=LoggingConfig)

    model_config = SettingsConfigDict(
        # Variables must start with this prefix, e.g. MLXSMITH__SERVE__PORT.
        env_prefix="MLXSMITH__",
        # "__" separates nesting levels: MLXSMITH__MODEL__ID -> model.id.
        env_nested_delimiter="__",
        env_parse_enums=True,
        extra="ignore", # Ignore unknown env vars
    )
94
+
95
+
96
+ # Import CLI aliases from models
97
+ from .config_models import CLI_ALIASES as _CLI_ALIASES
98
+
99
+
100
def resolve_config_path(config: Union[str, Path], root: Optional[Path] = None) -> Path:
    """Turn a ``--config`` value into a filesystem path.

    A leading ``@`` on a *string* argument is treated as an optional marker
    and stripped; ``Path`` arguments are used verbatim. A relative result is
    anchored under ``root`` when one is supplied, while absolute paths ignore
    ``root``. The path is neither normalized nor checked for existence.

    Args:
        config: Path string (optionally ``@``-prefixed) or a ``Path``.
        root: Directory against which relative paths are resolved.

    Returns:
        The resulting ``Path``.

    Example:
        >>> resolve_config_path("@production.toml")
        Path("production.toml")
        >>> resolve_config_path("config.yaml", root=Path("/project"))
        Path("/project/config.yaml")
    """
    has_marker = isinstance(config, str) and config.startswith("@")
    candidate = Path(config[1:] if has_marker else config)
    if not root or candidate.is_absolute():
        return candidate
    return root / candidate
122
+
123
+
124
+ def _deep_merge(base: Dict[str, Any], override: Dict[str, Any]) -> Dict[str, Any]:
125
+ """Deep merge two dictionaries, with override taking precedence.
126
+
127
+ Args:
128
+ base: Base dictionary
129
+ override: Override dictionary (values take precedence)
130
+
131
+ Returns:
132
+ Merged dictionary
133
+ """
134
+ merged = dict(base)
135
+ for key, value in override.items():
136
+ if isinstance(value, dict) and isinstance(merged.get(key), dict):
137
+ merged[key] = _deep_merge(merged[key], value)
138
+ else:
139
+ merged[key] = value
140
+ return merged
141
+
142
+
143
+ def _flatten_dict(d: Dict[str, Any], parent_key: str = "", sep: str = "__") -> Dict[str, Any]:
144
+ """Flatten a nested dictionary for env var style keys.
145
+
146
+ Args:
147
+ d: Dictionary to flatten
148
+ parent_key: Parent key prefix
149
+ sep: Separator for nested keys
150
+
151
+ Returns:
152
+ Flattened dictionary
153
+
154
+ Example:
155
+ >>> _flatten_dict({"model": {"id": "test"}})
156
+ {"model__id": "test"}
157
+ """
158
+ items: List[Tuple[str, Any]] = []
159
+ for k, v in d.items():
160
+ new_key = f"{parent_key}{sep}{k}" if parent_key else k
161
+ if isinstance(v, dict):
162
+ items.extend(_flatten_dict(v, new_key, sep=sep).items())
163
+ else:
164
+ items.append((new_key, v))
165
+ return dict(items)
166
+
167
+
168
+ def _unflatten_dict(d: Dict[str, Any], sep: str = "__") -> Dict[str, Any]:
169
+ """Unflatten a dictionary with separator-delimited keys.
170
+
171
+ Args:
172
+ d: Flat dictionary with separator in keys
173
+ sep: Separator used in keys
174
+
175
+ Returns:
176
+ Nested dictionary
177
+
178
+ Example:
179
+ >>> _unflatten_dict({"model__id": "test"})
180
+ {"model": {"id": "test"}}
181
+ """
182
+ result: Dict[str, Any] = {}
183
+ for key, value in d.items():
184
+ parts = key.split(sep)
185
+ current = result
186
+ for part in parts[:-1]:
187
+ if part not in current:
188
+ current[part] = {}
189
+ current = current[part]
190
+ current[parts[-1]] = value
191
+ return result
192
+
193
+
194
+ def _read_config_file(path: Path) -> Dict[str, Any]:
195
+ """Read and parse a config file (TOML, YAML, or JSON).
196
+
197
+ Args:
198
+ path: Path to config file
199
+
200
+ Returns:
201
+ Parsed configuration dictionary
202
+
203
+ Raises:
204
+ FileNotFoundError: If file doesn't exist
205
+ ValueError: If file format is invalid
206
+
207
+ Supported formats:
208
+ - .toml, .tml: TOML format
209
+ - .yaml, .yml: YAML format
210
+ - .json: JSON format
211
+ """
212
+ if not path.exists():
213
+ raise FileNotFoundError(f"Config file not found: {path}")
214
+
215
+ raw = path.read_text(encoding="utf-8")
216
+ suffix = path.suffix.lower()
217
+
218
+ try:
219
+ if suffix in (".toml", ".tml"):
220
+ data = tomllib.loads(raw)
221
+ elif suffix in (".yaml", ".yml"):
222
+ data = yaml.safe_load(raw)
223
+ elif suffix == ".json":
224
+ data = json.loads(raw)
225
+ else:
226
+ # Try YAML as fallback (it's a superset of JSON)
227
+ data = yaml.safe_load(raw)
228
+ except Exception as e:
229
+ raise ValueError(f"Failed to parse config file {path}: {e}") from e
230
+
231
+ if data is None:
232
+ return {}
233
+ if not isinstance(data, dict):
234
+ raise ValueError(f"Config must be a mapping, got {type(data).__name__}")
235
+
236
+ return data
237
+
238
+
239
def _apply_cli_overrides(
    config: ProjectConfig,
    overrides: Dict[str, Any]
) -> ProjectConfig:
    """Apply CLI argument overrides on top of an already-merged configuration.

    Each key in ``overrides`` is resolved in this order:

    1. A registered CLI alias (``_CLI_ALIASES``), e.g. ``"lr"`` -> ``train.lr``.
       The alias is applied only if its section exists.
    2. A dotted path such as ``"model.id"`` or ``"train.lr"``. Missing
       intermediate sections are created.
    3. A bare field name, assigned in the first section that already has a
       field with that name. Bare names that match no section are ignored.

    ``None`` values are skipped, so unset CLI flags never clobber config values.

    Args:
        config: Base configuration
        overrides: Dictionary of CLI overrides

    Returns:
        A new, re-validated ``ProjectConfig``; ``config`` itself is not modified.

    Raises:
        ValueError: If a dotted key descends into a value that is not a
            mapping (e.g. ``"train.lr.x"``), or if the result fails validation.
    """
    if not overrides:
        return config

    # Work on a dumped dict tree, then re-validate it into a new config.
    data = config.model_dump()

    for key, value in overrides.items():
        if value is None:
            continue

        # Check for aliases (e.g., "lr" -> ("train", "lr"))
        if key in _CLI_ALIASES:
            section, field = _CLI_ALIASES[key]
            if section in data:
                data[section][field] = value
            continue

        # Handle nested keys like "model.id" or "train.lr"
        if "." in key:
            *parents, leaf = key.split(".")
            current = data
            for part in parents:
                if part not in current:
                    current[part] = {}
                current = current[part]
                if not isinstance(current, dict):
                    # Previously this surfaced as an opaque TypeError.
                    raise ValueError(
                        f"Cannot apply CLI override {key!r}: "
                        f"{part!r} is not a config section"
                    )
            current[leaf] = value
            continue

        # Bare key: assign it in the first section that defines it.
        # (Aliases were already handled above, so no alias fallback is needed.)
        for section_data in data.values():
            if isinstance(section_data, dict) and key in section_data:
                section_data[key] = value
                break

    return ProjectConfig.model_validate(data)
295
+
296
+
297
def load_config(
    path: Optional[Union[str, Path]] = None,
    cli_overrides: Optional[Dict[str, Any]] = None,
    require: bool = False,
) -> ProjectConfig:
    """Load configuration with proper precedence.

    Precedence (highest to lowest):
        1. CLI arguments (cli_overrides)
        2. Config file (if path provided)
        3. Environment variables (MLXSMITH__*)
        4. Default values

    Args:
        path: Path to config file (optional); a plain string is accepted too
        cli_overrides: Dictionary of CLI argument overrides
        require: If True, a missing or unreadable config file raises instead
            of producing a warning

    Returns:
        Merged ProjectConfig

    Raises:
        FileNotFoundError: If require=True and config file not found
        ValueError: If require=True and the config file cannot be parsed
        OSError: If require=True and the config file cannot be read
        ValidationError: If configuration is invalid

    Example:
        >>> cfg = load_config(Path("config.yaml"), cli_overrides={"model.id": "custom"})
    """
    import warnings

    # Start with defaults (lowest priority)
    defaults = ProjectConfig()

    # Layer 1: Environment variables
    try:
        env_data = ProjectSettings().model_dump()
    except ValidationError as e:
        # A malformed env var should not make the tool unusable: warn and
        # drop the whole env layer.
        warnings.warn(f"Failed to parse environment variables: {e}", stacklevel=2)
        env_data = {}

    merged = _deep_merge(defaults.model_dump(), env_data)

    # Layer 2: Config file (higher priority than env)
    if path is not None:
        path = Path(path)
        if path.exists():
            try:
                merged = _deep_merge(merged, _read_config_file(path))
            except (ValueError, OSError) as e:
                # OSError covers FileNotFoundError (file removed after the
                # exists() check) as well as IsADirectoryError/PermissionError,
                # which previously escaped even when require=False.
                if require:
                    raise
                warnings.warn(f"Failed to load config file: {e}", stacklevel=2)
        elif require:
            raise FileNotFoundError(f"Config file not found: {path}")

    # Create config from merged data
    config = ProjectConfig.model_validate(merged)

    # Layer 3: CLI overrides (highest priority)
    if cli_overrides:
        config = _apply_cli_overrides(config, cli_overrides)

    return config
363
+
364
+
365
def get_config(
    config_path: Optional[Union[str, Path]] = None,
    root: Optional[Path] = None,
    **cli_kwargs: Any,
) -> ProjectConfig:
    """Convenience function to get configuration with CLI overrides.

    This is the recommended way to load configuration in CLI commands.

    Keyword names cannot contain ``.``, so a double underscore is accepted as
    the section separator, mirroring the ``MLXSMITH__SECTION__FIELD``
    environment-variable convention: ``train__lr=0.001`` overrides
    ``train.lr``. Names without ``__`` (CLI aliases, bare field names) are
    passed through unchanged. Keyword arguments whose value is ``None`` are
    dropped.

    Args:
        config_path: Path to config file (supports @prefix syntax)
        root: Project root for resolving relative paths
        **cli_kwargs: CLI argument overrides

    Returns:
        ProjectConfig with all overrides applied

    Example:
        >>> cfg = get_config("@production.toml", model_id="custom/model")
        >>> cfg = get_config("config.yaml", root=Path("/project"), train__lr=0.001)
    """
    path = resolve_config_path(config_path, root=root) if config_path else None

    # Translate "section__field" into the dotted form that the override logic
    # understands. Without this, "train__lr" matched no field and was silently
    # ignored.
    overrides = {
        key.replace("__", "."): value
        for key, value in cli_kwargs.items()
        if value is not None
    }

    return load_config(path, cli_overrides=overrides)
394
+
395
+
396
def dump_config(cfg: ProjectConfig, format: str = "yaml") -> str:
    """Serialize a configuration to YAML, JSON, or TOML text.

    Args:
        cfg: Configuration to dump
        format: Output format ("yaml", "json", "toml"/"tml"), case-insensitive

    Returns:
        Configuration string

    Raises:
        ValueError: If format is not supported, or TOML output is requested
            and ``tomli_w`` is not installed
    """
    fmt = format.lower()

    if fmt == "yaml":
        return yaml.safe_dump(cfg.model_dump(), sort_keys=False)
    if fmt == "json":
        return cfg.model_dump_json(indent=2)
    if fmt in ("toml", "tml"):
        try:
            import tomli_w
        except ImportError:
            raise ValueError(
                "tomli_w is required for TOML output. "
                "Install with: pip install tomli_w"
            ) from None
        # TOML has no null value; tomli_w raises on None, so unset optional
        # fields are omitted instead of crashing the dump.
        return tomli_w.dumps(cfg.model_dump(exclude_none=True))
    raise ValueError(f"Unsupported format: {fmt}")
426
+
427
+
428
def write_default_config(path: Path, format: str = "yaml") -> None:
    """Serialize a default ``ProjectConfig`` and write it to ``path`` as UTF-8.

    ``format`` is overridden only when it is left at ``"yaml"`` and ``path``
    has a ``.json``, ``.toml`` or ``.tml`` suffix (case-insensitive). In that
    case the lowercased suffix without its dot is used. The parent directory
    is not created.

    Args:
        path: Output file path
        format: Output format handed to ``dump_config``
    """
    ext = path.suffix.lower()
    if format == "yaml" and ext in (".json", ".toml", ".tml"):
        format = ext[1:]
    text = dump_config(ProjectConfig(), format=format)
    path.write_text(text, encoding="utf-8")
442
+
443
+
444
def show_merged_config(
    config: ProjectConfig,
    show_sources: bool = False,
    sources: Optional[Dict[str, Any]] = None,
) -> str:
    """Render a configuration as TOML-like text for humans.

    Each top-level entry of ``config.model_dump()`` whose value is a dict
    becomes a ``[section]`` header followed by one ``key = repr(value)`` line
    per field and a blank line. Non-dict top-level values are skipped. When
    ``show_sources`` is true and ``sources`` is non-empty, each field line is
    annotated with the source stored under ``"section.key"`` (``"default"``
    when that key is absent).

    Args:
        config: Configuration to display
        show_sources: Whether to annotate each value with its source
        sources: Mapping of "section.key" to a source label

    Returns:
        The rendered lines joined with newlines.
    """
    annotate = bool(show_sources and sources)
    out = ["# MLXSmith Configuration", ""]

    for section, fields in config.model_dump().items():
        if isinstance(fields, dict):
            out.append(f"[{section}]")
            for name, val in fields.items():
                line = f" {name} = {val!r}"
                if annotate:
                    origin = sources.get(f"{section}.{name}", "default")
                    line = f"{line} # from: {origin}"
                out.append(line)
            out.append("")

    return "\n".join(out)
479
+
480
+
481
def get_config_sources(
    config_path: Optional[Path] = None,
    cli_overrides: Optional[Dict[str, Any]] = None,
) -> Tuple[ProjectConfig, Dict[str, str]]:
    """Load the configuration and record which layer supplied each value.

    Layers are applied in the same order as ``load_config`` (defaults, env,
    file, CLI), and each later layer relabels the keys it sets.

    Args:
        config_path: Path to config file (skipped if it does not exist)
        cli_overrides: CLI argument overrides

    Returns:
        Tuple of (config, sources_dict) where sources_dict maps
        "section.key" -> "default|env|file|cli"
    """
    sources: Dict[str, str] = {}

    def _mark(data: Dict[str, Any], label: str) -> None:
        # Label every "section.key" pair present in ``data``.
        for section_name, section_data in data.items():
            if isinstance(section_data, dict):
                for key in section_data:
                    sources[f"{section_name}.{key}"] = label

    # Start with defaults
    defaults = ProjectConfig()
    default_data = defaults.model_dump()
    _mark(default_data, "default")

    # Env layer. The full dump is merged exactly as in load_config, but only
    # fields actually populated from the environment (exclude_unset) are
    # labelled "env". The full dump also contains defaults, so labelling it
    # tagged every key as "env".
    try:
        env_settings = ProjectSettings()
        env_data = env_settings.model_dump()
        _mark(env_settings.model_dump(exclude_unset=True), "env")
    except ValidationError:
        env_data = {}

    merged = _deep_merge(default_data, env_data)

    # Apply file and track
    if config_path and config_path.exists():
        file_data = _read_config_file(config_path)
        _mark(file_data, "file")
        merged = _deep_merge(merged, file_data)

    config = ProjectConfig.model_validate(merged)

    # Apply CLI and track, mirroring _apply_cli_overrides' resolution order.
    if cli_overrides:
        dumped = config.model_dump()
        for key, value in cli_overrides.items():
            if value is None:
                # None overrides are skipped when applied, so they change nothing.
                continue
            if key in _CLI_ALIASES:
                section, field = _CLI_ALIASES[key]
                if section in dumped:
                    sources[f"{section}.{field}"] = "cli"
            elif "." in key:
                sources[key] = "cli"
            else:
                # Find which section this key belongs to (skip non-dict entries).
                for section_name, section_data in dumped.items():
                    if isinstance(section_data, dict) and key in section_data:
                        sources[f"{section_name}.{key}"] = "cli"
                        break
        config = _apply_cli_overrides(config, cli_overrides)

    return config, sources