mlxsmith 0.1.0__py3-none-any.whl

Files changed (69)
  1. mlxsmith/__init__.py +2 -0
  2. mlxsmith/accel/__init__.py +10 -0
  3. mlxsmith/accel/base.py +17 -0
  4. mlxsmith/accel/none.py +13 -0
  5. mlxsmith/accel/zmlx_backend.py +42 -0
  6. mlxsmith/adapters.py +46 -0
  7. mlxsmith/api/__init__.py +48 -0
  8. mlxsmith/api/handlers.py +1217 -0
  9. mlxsmith/api/schemas.py +436 -0
  10. mlxsmith/auth.py +88 -0
  11. mlxsmith/bench.py +102 -0
  12. mlxsmith/cli.py +950 -0
  13. mlxsmith/config.py +543 -0
  14. mlxsmith/config_models.py +261 -0
  15. mlxsmith/data.py +493 -0
  16. mlxsmith/envs/__init__.py +33 -0
  17. mlxsmith/envs/system.py +388 -0
  18. mlxsmith/envs/token_env.py +191 -0
  19. mlxsmith/eval.py +112 -0
  20. mlxsmith/infer.py +140 -0
  21. mlxsmith/llm/__init__.py +16 -0
  22. mlxsmith/llm/backend.py +126 -0
  23. mlxsmith/llm/interface.py +212 -0
  24. mlxsmith/llm/mlx_lm_backend.py +509 -0
  25. mlxsmith/llm/mock_backend.py +228 -0
  26. mlxsmith/llm/registry.py +12 -0
  27. mlxsmith/models.py +257 -0
  28. mlxsmith/orchestrator/__init__.py +25 -0
  29. mlxsmith/orchestrator/daemon.py +454 -0
  30. mlxsmith/orchestrator/inference_worker.py +496 -0
  31. mlxsmith/orchestrator/queue.py +355 -0
  32. mlxsmith/orchestrator/trainer_worker.py +437 -0
  33. mlxsmith/rlm/__init__.py +8 -0
  34. mlxsmith/rlm/corpus.py +74 -0
  35. mlxsmith/rlm/gating.py +90 -0
  36. mlxsmith/rlm/generate.py +249 -0
  37. mlxsmith/rlm/history.py +12 -0
  38. mlxsmith/rlm/inference.py +150 -0
  39. mlxsmith/rlm/loop.py +1297 -0
  40. mlxsmith/rlm/mutate.py +82 -0
  41. mlxsmith/rlm/trainer.py +73 -0
  42. mlxsmith/rlm/weights.py +263 -0
  43. mlxsmith/runs.py +44 -0
  44. mlxsmith/sdk/__init__.py +392 -0
  45. mlxsmith/sdk/future.py +486 -0
  46. mlxsmith/sdk/losses.py +262 -0
  47. mlxsmith/sdk/sampling_client.py +729 -0
  48. mlxsmith/sdk/training_client.py +676 -0
  49. mlxsmith/server.py +376 -0
  50. mlxsmith/train/__init__.py +0 -0
  51. mlxsmith/train/distill.py +279 -0
  52. mlxsmith/train/lora.py +280 -0
  53. mlxsmith/train/pref.py +180 -0
  54. mlxsmith/train/rft.py +458 -0
  55. mlxsmith/train/sft.py +151 -0
  56. mlxsmith/util.py +174 -0
  57. mlxsmith/verifiers/__init__.py +3 -0
  58. mlxsmith/verifiers/compose.py +109 -0
  59. mlxsmith/verifiers/docker_verifier.py +111 -0
  60. mlxsmith/verifiers/jsonschema.py +54 -0
  61. mlxsmith/verifiers/pytest_verifier.py +82 -0
  62. mlxsmith/verifiers/regex.py +15 -0
  63. mlxsmith/verifiers/types.py +10 -0
  64. mlxsmith-0.1.0.dist-info/METADATA +163 -0
  65. mlxsmith-0.1.0.dist-info/RECORD +69 -0
  66. mlxsmith-0.1.0.dist-info/WHEEL +5 -0
  67. mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
  68. mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/infer.py ADDED
@@ -0,0 +1,140 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import List, Optional
+
+ from .config import ProjectConfig
+ from .llm.registry import get_llm_backend
+ from .models import resolve_model_spec
+
+
+ @dataclass
+ class ChatMessage:
+     role: str
+     content: str
+
+
+ def _messages_to_prompt(messages: List[ChatMessage], tokenizer, *, use_chat_template: bool) -> str:
+     if use_chat_template and hasattr(tokenizer, "apply_chat_template"):
+         payload = [{"role": m.role, "content": m.content} for m in messages]
+         try:
+             return tokenizer.apply_chat_template(payload, tokenize=False, add_generation_prompt=True)
+         except Exception:
+             pass
+     joined = "\n".join([f"{m.role}: {m.content}" for m in messages])
+     return f"{joined}\nassistant:"
+
+
+ def _load_backend(cfg: ProjectConfig, model_spec: str):
+     llm = get_llm_backend(cfg.model.backend)
+     base_model, adapter_path, _meta = resolve_model_spec(Path.cwd(), model_spec, cfg)
+     llm.load(
+         base_model,
+         max_seq_len=cfg.model.max_seq_len,
+         dtype=cfg.model.dtype,
+         trust_remote_code=cfg.model.trust_remote_code,
+     )
+     if adapter_path:
+         llm.apply_adapter(str(adapter_path))
+     return llm, base_model
+
+
+ def run_prompt(
+     cfg: ProjectConfig,
+     model_spec: str,
+     prompt: str,
+     *,
+     max_new_tokens: Optional[int] = None,
+     temperature: Optional[float] = None,
+     top_p: Optional[float] = None,
+     top_k: Optional[int] = None,
+     seed: Optional[int] = None,
+ ) -> str:
+     llm, _base_model = _load_backend(cfg, model_spec)
+     gen = llm.generate(
+         prompt,
+         max_new_tokens=max_new_tokens or cfg.infer.max_new_tokens,
+         temperature=temperature if temperature is not None else cfg.infer.temperature,
+         top_p=top_p if top_p is not None else cfg.infer.top_p,
+         top_k=top_k if top_k is not None else cfg.infer.top_k,
+         seed=seed,
+     )
+     if gen.text.startswith(prompt):
+         return gen.text[len(prompt) :]
+     return gen.text
+
+
+ def run_chat(
+     cfg: ProjectConfig,
+     model_spec: str,
+     messages: List[ChatMessage],
+     *,
+     max_new_tokens: Optional[int] = None,
+     temperature: Optional[float] = None,
+     top_p: Optional[float] = None,
+     top_k: Optional[int] = None,
+     seed: Optional[int] = None,
+ ) -> str:
+     llm, _base_model = _load_backend(cfg, model_spec)
+     prompt = _messages_to_prompt(messages, llm.tokenizer, use_chat_template=cfg.model.use_chat_template)
+     gen = llm.generate(
+         prompt,
+         max_new_tokens=max_new_tokens or cfg.infer.max_new_tokens,
+         temperature=temperature if temperature is not None else cfg.infer.temperature,
+         top_p=top_p if top_p is not None else cfg.infer.top_p,
+         top_k=top_k if top_k is not None else cfg.infer.top_k,
+         seed=seed,
+     )
+     if gen.text.startswith(prompt):
+         return gen.text[len(prompt) :]
+     return gen.text
+
+
+ def chat_repl(
+     cfg: ProjectConfig,
+     model_spec: str,
+     *,
+     system: Optional[str] = None,
+     max_new_tokens: Optional[int] = None,
+     temperature: Optional[float] = None,
+     top_p: Optional[float] = None,
+     top_k: Optional[int] = None,
+     seed: Optional[int] = None,
+     max_turns: Optional[int] = None,
+ ) -> None:
+     llm, _base_model = _load_backend(cfg, model_spec)
+     messages: List[ChatMessage] = []
+     if system:
+         messages.append(ChatMessage(role="system", content=system))
+
+     turns = 0
+     while True:
+         if max_turns is not None and turns >= max_turns:
+             break
+         try:
+             user = input("user> ").strip()
+         except EOFError:
+             break
+         if not user:
+             continue
+         if user.lower() in {"/exit", "/quit", "exit", "quit"}:
+             break
+         messages.append(ChatMessage(role="user", content=user))
+         prompt = _messages_to_prompt(messages, llm.tokenizer, use_chat_template=cfg.model.use_chat_template)
+         gen = llm.generate(
+             prompt,
+             max_new_tokens=max_new_tokens or cfg.infer.max_new_tokens,
+             temperature=temperature if temperature is not None else cfg.infer.temperature,
+             top_p=top_p if top_p is not None else cfg.infer.top_p,
+             top_k=top_k if top_k is not None else cfg.infer.top_k,
+             seed=seed,
+         )
+         if gen.text.startswith(prompt):
+             reply = gen.text[len(prompt) :]
+         else:
+             reply = gen.text
+         reply = reply.strip()
+         print(f"assistant> {reply}\n")
+         messages.append(ChatMessage(role="assistant", content=reply))
+         turns += 1
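
Editor's note: a minimal usage sketch of the helpers above. It assumes a ProjectConfig can be constructed with defaults (how the project actually loads it from mlxsmith/config.py is not visible in this hunk) and uses a placeholder model id; the signatures of run_prompt/run_chat are as shown in the diff.

    from mlxsmith.config import ProjectConfig
    from mlxsmith.infer import ChatMessage, run_prompt, run_chat

    # Assumption: default ProjectConfig values are sufficient for a local experiment.
    cfg = ProjectConfig()

    # Single-turn completion; the helper strips the echoed prompt prefix if present.
    text = run_prompt(
        cfg,
        "mlx-community/SomeModel-4bit",  # placeholder model spec
        "Write a haiku about wheels.",
        max_new_tokens=64,
        temperature=0.7,
    )
    print(text)

    # Multi-turn chat; messages are flattened via the tokenizer's chat template when
    # cfg.model.use_chat_template is enabled, otherwise as "role: content" lines.
    reply = run_chat(
        cfg,
        "mlx-community/SomeModel-4bit",
        [
            ChatMessage(role="system", content="You are terse."),
            ChatMessage(role="user", content="Summarize LoRA in one sentence."),
        ],
    )
    print(reply)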
mlxsmith/llm/__init__.py ADDED
@@ -0,0 +1,16 @@
+ """Model backends for mlxsmith."""
+
+ from .backend import LLMBackend, Generation, BackendNotAvailable, DecodingConfig
+ from .mlx_lm_backend import MlxLMBackend
+ from .mock_backend import MockBackend
+ from .registry import get_llm_backend
+
+ __all__ = [
+     "LLMBackend",
+     "Generation",
+     "BackendNotAvailable",
+     "MlxLMBackend",
+     "MockBackend",
+     "DecodingConfig",
+     "get_llm_backend",
+ ]
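
Editor's note: these re-exports define the public surface of mlxsmith.llm, so callers can import the Protocol, result types, and registry lookup from the package root. A minimal import sketch follows; the valid registry key strings live in mlxsmith/llm/registry.py and are not shown in this hunk.

    from mlxsmith.llm import LLMBackend, Generation, DecodingConfig, get_llm_backend

    def make_backend(backend_name: str) -> LLMBackend:
        # backend_name would normally come from cfg.model.backend; the accepted
        # values are defined by the registry module (not visible here).
        return get_llm_backend(backend_name)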
mlxsmith/llm/backend.py ADDED
@@ -0,0 +1,126 @@
+ """LLM backend abstraction.
+
+ We keep this intentionally small so mlxsmith can support multiple model loaders
+ without hard-binding to one ecosystem.
+
+ Primary target: mlx-lm (HF -> MLX format + common chat models).
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Protocol, Sequence, Any, Optional, List, Dict
+
+
+ @dataclass
+ class Generation:
+     text: str
+     token_ids: list[int]
+     prompt_len: int
+     # Log-probabilities for generated tokens only (len = completion tokens), if available.
+     logprobs: list[float] | None = None
+     # Top-k logprobs per token: list of {token_str: logprob} dicts for each generated token.
+     top_k_logprobs: List[Dict[str, float]] | None = None
+
+
+ @dataclass
+ class DecodingConfig:
+     max_new_tokens: int = 256
+     temperature: float = 0.8
+     top_p: float = 1.0
+     top_k: Optional[int] = None
+     seed: Optional[int] = None
+     stop: Optional[Sequence[str]] = None
+
+
+ class LLMBackend(Protocol):
+     name: str
+
+     def load(self, model_id_or_path: str, *, max_seq_len: int | None = None, dtype: str | None = None) -> None:
+         """Load model + tokenizer into memory."""
+
+     def encode(self, text: str) -> list[int]:
+         """Tokenize text -> token ids."""
+
+     def decode(self, ids: Sequence[int]) -> str:
+         """Token ids -> text."""
+
+     def generate(
+         self,
+         prompt: str,
+         *,
+         max_new_tokens: int = 256,
+         temperature: float = 0.8,
+         top_p: float = 1.0,
+         top_k: int | None = None,
+         seed: int | None = None,
+     ) -> Generation:
+         """Sample a completion and return ids for prompt+completion."""
+
+     def generate_with_logprobs(
+         self,
+         prompt: str,
+         *,
+         max_new_tokens: int = 256,
+         temperature: float = 0.8,
+         top_p: float = 1.0,
+         top_k: int | None = None,
+         seed: int | None = None,
+         logprobs: int = 0,  # Number of top logprobs to return per token (0 = just the sampled token)
+     ) -> Generation:
+         """Sample a completion and include per-token logprobs when available.
+
+         Args:
+             prompt: Input prompt text
+             max_new_tokens: Maximum number of tokens to generate
+             temperature: Sampling temperature
+             top_p: Nucleus sampling parameter
+             top_k: Top-k sampling parameter
+             seed: Random seed for reproducibility
+             logprobs: Number of top logprobs to return per token (0 = return only sampled token's logprob)
+
+         Returns:
+             Generation with logprobs and optionally top_k_logprobs
+         """
+
+     def sft_loss(self, token_ids: Sequence[int], *, train_on_prompt: bool, prompt_len: int) -> Any:
+         """Return a scalar loss suitable for backprop (MLX array)."""
+
+     def rl_loss(self, token_ids: Sequence[int], *, prompt_len: int, advantage: float) -> Any:
+         """Return policy-gradient-style loss for a sampled trajectory."""
+
+     def sequence_logprob(self, token_ids: Sequence[int], *, prompt_len: int) -> Any:
+         """Return log-probability sum of response tokens (differentiable)."""
+
+     def token_logprobs(
+         self,
+         token_ids: Sequence[int],
+         *,
+         prompt_len: int,
+         top_k: int = 0,
+         include_prompt: bool = False,
+     ) -> tuple[list[float], List[Dict[str, float]] | None]:
+         """Return per-token logprobs (and optional top-k logprobs).
+
+         Args:
+             token_ids: Full token sequence (prompt + completion).
+             prompt_len: Prompt length in tokens.
+             top_k: Number of top logprobs per token to return (0 = none).
+             include_prompt: If True, include prompt tokens; otherwise only response tokens.
+         """
+
+     def value_and_grad(self, loss_fn) -> tuple[Any, Any | None]:
+         """Return (loss, grads) using backend autograd when available."""
+
+     def optimizer_and_params(self, *, lr: float, weight_decay: float = 0.0) -> tuple[Any, Any]:
+         """Return (optimizer, trainable_params_tree)."""
+
+     def apply_grads(self, optimizer: Any, grads: Any) -> None:
+         """Update model parameters given gradients."""
+
+     def save_adapter(self, out_dir: str, *, metadata: dict | None = None) -> None:
+         """Persist adapter weights (LoRA) to out_dir."""
+
+
+ class BackendNotAvailable(RuntimeError):
+     pass
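
Editor's note: to make the Protocol concrete, here is a toy backend sketch that satisfies only the inference-facing subset consumed by mlxsmith/infer.py (load, encode, decode, generate). It is illustrative: a whitespace "tokenizer" and an echo generator stand in for a real model, the training hooks (sft_loss, value_and_grad, and so on) are omitted, and the class is not wired into mlxsmith/llm/registry.py.

    from mlxsmith.llm.backend import Generation


    class EchoBackend:
        """Toy backend: fake whitespace tokens, echoes the prompt back. Illustration only."""

        name = "echo"

        def load(self, model_id_or_path: str, *, max_seq_len=None, dtype=None, **kwargs) -> None:
            self.model_id = model_id_or_path
            # No chat template, so callers fall back to "role: content" prompt formatting.
            self.tokenizer = None

        def encode(self, text: str) -> list[int]:
            # Fake ids: one "token" per whitespace-separated word.
            return list(range(len(text.split())))

        def decode(self, ids) -> str:
            return " ".join(f"<tok{i}>" for i in ids)

        def generate(self, prompt: str, *, max_new_tokens=256, temperature=0.8,
                     top_p=1.0, top_k=None, seed=None) -> Generation:
            completion = " (echo)"
            prompt_ids = self.encode(prompt)
            return Generation(
                text=prompt + completion,
                token_ids=prompt_ids + self.encode(completion),
                prompt_len=len(prompt_ids),
            )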
mlxsmith/llm/interface.py ADDED
@@ -0,0 +1,212 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Any, Iterable, List, Optional, Dict
+
+ from .registry import get_llm_backend
+ from .backend import DecodingConfig, Generation
+ from ..config import ProjectConfig
+ from ..models import resolve_model_spec
+ from ..train.lora import load_adapter_config
+
+
+ @dataclass
+ class LoadedModel:
+     backend: Any
+     base_model: str
+     adapter_path: Optional[str]
+
+
+ def load_base_model(model_id_or_path: str, cfg: ProjectConfig) -> LoadedModel:
+     """Load a base model (and optionally adapter) with the configured backend."""
+     backend = get_llm_backend(cfg.model.backend)
+     from pathlib import Path
+
+     base_model, adapter_path, _meta = resolve_model_spec(Path.cwd(), model_id_or_path, cfg)
+     backend.load(
+         base_model,
+         max_seq_len=cfg.model.max_seq_len,
+         dtype=cfg.model.dtype,
+         trust_remote_code=cfg.model.trust_remote_code,
+     )
+     return LoadedModel(backend=backend, base_model=base_model, adapter_path=str(adapter_path) if adapter_path else None)
+
+
+ def load_adapter(adapter_path: str):
+     """Load adapter config without applying it."""
+     return load_adapter_config(adapter_path)
+
+
+ def apply_adapter(backend, adapter_path: str) -> None:
+     backend.apply_adapter(adapter_path)
+
+
+ def generate(
+     backend,
+     tokenizer,
+     prompts: Iterable[str],
+     decoding_config: DecodingConfig,
+     logprobs: int = 0,
+ ) -> List[Generation]:
+     """Generate completions for prompts.
+
+     Args:
+         backend: The LLM backend
+         tokenizer: Tokenizer instance
+         prompts: Iterable of prompt strings
+         decoding_config: Decoding configuration
+         logprobs: Number of top logprobs to return per token (0 = none)
+
+     Returns:
+         List of Generation results
+     """
+     results = []
+     for p in prompts:
+         if logprobs > 0:
+             results.append(
+                 backend.generate_with_logprobs(
+                     p,
+                     max_new_tokens=decoding_config.max_new_tokens,
+                     temperature=decoding_config.temperature,
+                     top_p=decoding_config.top_p,
+                     top_k_sampling=decoding_config.top_k,
+                     seed=decoding_config.seed,
+                     logprobs=logprobs,
+                 )
+             )
+         else:
+             results.append(
+                 backend.generate(
+                     p,
+                     max_new_tokens=decoding_config.max_new_tokens,
+                     temperature=decoding_config.temperature,
+                     top_p=decoding_config.top_p,
+                     top_k=decoding_config.top_k,
+                     seed=decoding_config.seed,
+                 )
+             )
+     return results
+
+
+ def chat(
+     backend,
+     messages: List[Dict[str, str]],
+     decoding_config: DecodingConfig,
+     logprobs: int = 0,
+     use_chat_template: bool = True,
+ ) -> Generation:
+     """Generate a chat completion.
+
+     Args:
+         backend: The LLM backend
+         messages: List of message dicts with 'role' and 'content' keys
+         decoding_config: Decoding configuration
+         logprobs: Number of top logprobs to return per token (0 = none)
+         use_chat_template: Whether to use the model's chat template
+
+     Returns:
+         Generation result
+     """
+     # Convert messages to prompt
+     prompt = _messages_to_prompt(backend.tokenizer, messages, use_chat_template=use_chat_template)
+
+     if logprobs > 0:
+         return backend.generate_with_logprobs(
+             prompt,
+             max_new_tokens=decoding_config.max_new_tokens,
+             temperature=decoding_config.temperature,
+             top_p=decoding_config.top_p,
+             top_k_sampling=decoding_config.top_k,
+             seed=decoding_config.seed,
+             logprobs=logprobs,
+         )
+     else:
+         return backend.generate(
+             prompt,
+             max_new_tokens=decoding_config.max_new_tokens,
+             temperature=decoding_config.temperature,
+             top_p=decoding_config.top_p,
+             top_k=decoding_config.top_k,
+             seed=decoding_config.seed,
+         )
+
+
+ def _messages_to_prompt(
+     tokenizer,
+     messages: List[Dict[str, str]],
+     *,
+     use_chat_template: bool = True
+ ) -> str:
+     """Convert chat messages to a prompt string."""
+     if use_chat_template and hasattr(tokenizer, "apply_chat_template"):
+         return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     # Fallback
+     return "\n".join([f"{m['role']}: {m['content']}" for m in messages]) + "\nassistant:"
+
+
+ @dataclass
+ class LogprobResult:
+     """Result from logprobs computation."""
+     token_logprobs: List[float]
+     top_k_logprobs: Optional[List[Dict[str, float]]] = None
+     text: Optional[str] = None
+
+
+ def compute_logprobs(
+     backend,
+     prompt: str,
+     completion: str,
+     top_k: int = 0,
+     max_seq_len: Optional[int] = None,
+ ) -> LogprobResult:
+     """Compute logprobs for a prompt-completion pair.
+
+     Args:
+         backend: The LLM backend
+         prompt: Prompt text
+         completion: Completion text
+         top_k: Number of top logprobs per token to return (0 = none)
+         max_seq_len: Maximum sequence length
+
+     Returns:
+         LogprobResult with token logprobs and optionally top-k logprobs
+     """
+     prompt_ids = backend.encode(prompt)
+     ids = backend.encode(prompt + completion)
+
+     # Truncate if needed
+     if max_seq_len and len(ids) > max_seq_len:
+         overflow = len(ids) - max_seq_len
+         ids = ids[overflow:]
+         prompt_len = max(0, len(prompt_ids) - overflow)
+     else:
+         prompt_len = len(prompt_ids)
+
+     # Get generation with logprobs
+     full_text = backend.decode(ids)
+
+     # Use backend's sequence_logprob if available
+     seq_logprob = backend.sequence_logprob(ids, prompt_len=prompt_len)
+
+     # For per-token logprobs, we'd need to do a forward pass
+     # This is a simplified version
+     token_logprobs = []
+     if hasattr(backend, '_response_logprobs'):
+         token_logprobs = backend._response_logprobs(ids, prompt_len=prompt_len)
+
+     # Get top-k logprobs if requested
+     top_k_logprobs = None
+     if top_k > 0 and hasattr(backend, 'generate_with_logprobs'):
+         # Generate to get top-k for each position
+         gen = backend.generate_with_logprobs(
+             prompt,
+             max_new_tokens=len(ids) - prompt_len,
+             logprobs=top_k,
+         )
+         top_k_logprobs = gen.top_k_logprobs
+
+     return LogprobResult(
+         token_logprobs=token_logprobs,
+         top_k_logprobs=top_k_logprobs,
+         text=completion,
+     )
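
Editor's note: a short sketch of how these helpers compose. It again assumes a default-constructed ProjectConfig and a placeholder model id; the tokenizer argument to generate() is accepted but unused by the implementation above, so the backend's own tokenizer is simply passed through.

    from mlxsmith.config import ProjectConfig
    from mlxsmith.llm.backend import DecodingConfig
    from mlxsmith.llm.interface import load_base_model, generate, chat

    cfg = ProjectConfig()  # assumption: defaults are sufficient for a local experiment
    loaded = load_base_model("mlx-community/SomeModel-4bit", cfg)  # placeholder model id

    decoding = DecodingConfig(max_new_tokens=64, temperature=0.7, top_p=0.95)

    # Batch-style completion over plain prompts.
    gens = generate(loaded.backend, loaded.backend.tokenizer,
                    ["1 + 1 =", "The capital of France is"], decoding)
    for g in gens:
        print(g.text)

    # Chat-style completion; messages use the role/content dict shape expected by chat().
    gen = chat(loaded.backend, [{"role": "user", "content": "Say hi."}], decoding)
    print(gen.text)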