mlx-ssd 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mlx_ssd-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,126 @@
1
+ Metadata-Version: 2.4
2
+ Name: mlx-ssd
3
+ Version: 0.1.0
4
+ Summary: Simple Self-Distillation training pipeline for MLX models
5
+ Author: mlx-ssd contributors
6
+ License: MIT
7
+ Requires-Python: >=3.10
8
+ Description-Content-Type: text/markdown
9
+ Requires-Dist: mlx-lm>=0.21.0
10
+ Requires-Dist: mlx-tokenizers>=1.0.0
11
+ Requires-Dist: datasets>=3.0.0
12
+ Requires-Dist: huggingface-hub>=0.24.0
13
+
14
+ # mlx-ssd
15
+
16
+ `mlx-ssd` is a practical MLX CLI implementation of simple self-distillation for code generation models on Apple Silicon.
17
+
18
+ ## Method
19
+
20
+ This project follows the method introduced in:
21
+
22
+ > Ruixiang Zhang, Richard He Bai, Huangjie Zheng, Navdeep Jaitly, Ronan Collobert, Yizhe Zhang.
23
+ > **Embarrassingly Simple Self-Distillation Improves Code Generation**.
24
+ > arXiv:2604.01193, 2026.
25
+ > https://arxiv.org/abs/2604.01193
26
+
27
+ Implementation by **Amirani Labs**.
28
+
29
+ Core flow:
30
+
31
+ 1. Sample responses from a base model with train-time decoding settings.
32
+ 2. Fine-tune on those self-generated samples.
33
+ 3. Evaluate/run with eval-time decoding settings.
34
+
35
+ Dataset defaults:
36
+
37
+ - `--problems microsoft/rStar-Coder`
38
+ - `--dataset-config seed_sft`
39
+ - `--dataset-split train`
40
+ - Records must contain a non-empty `question` field.
41
+
42
+ This repository is an independent implementation and is **not** the original paper repository.
43
+
44
+ ## Presets
45
+
46
+ Presets encode paper-aligned hyperparameters (Table 3 mapping) for supported model families.
47
+
48
+ ```bash
49
+ mlx-ssd sample --model mlx-community/Qwen3-4B-Instruct-4bit --preset qwen3-4b-instruct --output ./ssd_data
50
+ mlx-ssd train --model mlx-community/Qwen3-4B-Instruct-4bit --preset qwen3-4b-instruct --data ./ssd_data --output ./ssd_model
51
+ mlx-ssd run --model ./ssd_model/fused --preset qwen3-4b-instruct --prompt "Write a function that..."
52
+ ```
53
+
54
+ ## Usage
55
+
56
+ Install:
57
+
58
+ ```bash
59
+ pip install -e .
60
+ ```
61
+
62
+ Three-stage flow:
63
+
64
+ ```bash
65
+ # 1) Sample
66
+ mlx-ssd sample \
67
+ --model mlx-community/Qwen3-4B-Instruct-4bit \
68
+ --problems microsoft/rStar-Coder \
69
+ --dataset-config seed_sft \
70
+ --dataset-split train \
71
+ --output ./ssd_data \
72
+ --batch-size 16 \
73
+ --temperature 1.6 \
74
+ --top-k 20 \
75
+ --top-p 0.8 \
76
+ --limit 10
77
+
78
+ # 2) Train
79
+ mlx-ssd train \
80
+ --model mlx-community/Qwen3-4B-Instruct-4bit \
81
+ --data ./ssd_data \
82
+ --output ./ssd_model \
83
+ --iters 2500
84
+
85
+ # 3) Run
86
+ mlx-ssd run \
87
+ --model ./ssd_model/fused \
88
+ --temperature 1.1 \
89
+ --top-k 20 \
90
+ --top-p 0.8 \
91
+ --prompt "Write a function that..."
92
+ ```
93
+
94
+ One-command flow:
95
+
96
+ ```bash
97
+ mlx-ssd distill \
98
+ --model mlx-community/Qwen3-4B-Instruct-4bit \
99
+ --preset qwen3-4b-instruct \
100
+ --output ./my-better-qwen
101
+ ```
102
+
103
+ Local smoke test (quick validation):
104
+
105
+ ```bash
106
+ mlx-ssd sample \
107
+ --model mlx-community/SmolLM2-135M-Instruct \
108
+ --problems microsoft/rStar-Coder \
109
+ --dataset-config seed_sft \
110
+ --dataset-split train \
111
+ --output ./.smoke/data \
112
+ --batch-size 4 \
113
+ --temperature 0.8 \
114
+ --top-k 20 \
115
+ --top-p 0.8 \
116
+ --max-tokens 64 \
117
+ --limit 5
118
+ ```
119
+
120
+ ## Apple Silicon
121
+
122
+ `mlx-ssd` itself is the Apple Silicon implementation: it is built on `mlx-lm` and targets local MLX workflows.
123
+
124
+ ## License
125
+
126
+ MIT
@@ -0,0 +1,113 @@
1
+ # mlx-ssd
2
+
3
+ `mlx-ssd` is a practical MLX CLI implementation of simple self-distillation for code generation models on Apple Silicon.
4
+
5
+ ## Method
6
+
7
+ This project follows the method introduced in:
8
+
9
+ > Ruixiang Zhang, Richard He Bai, Huangjie Zheng, Navdeep Jaitly, Ronan Collobert, Yizhe Zhang.
10
+ > **Embarrassingly Simple Self-Distillation Improves Code Generation**.
11
+ > arXiv:2604.01193, 2026.
12
+ > https://arxiv.org/abs/2604.01193
13
+
14
+ Implementation by **Amirani Labs**.
15
+
16
+ Core flow:
17
+
18
+ 1. Sample responses from a base model with train-time decoding settings.
19
+ 2. Fine-tune on those self-generated samples.
20
+ 3. Evaluate/run with eval-time decoding settings.
21
+
22
+ Dataset defaults:
23
+
24
+ - `--problems microsoft/rStar-Coder`
25
+ - `--dataset-config seed_sft`
26
+ - `--dataset-split train`
27
+ - Records must contain a non-empty `question` field.
28
+
29
+ This repository is an independent implementation and is **not** the original paper repository.
30
+
31
+ ## Presets
32
+
33
+ Presets encode paper-aligned hyperparameters (Table 3 mapping) for supported model families.
34
+
35
+ ```bash
36
+ mlx-ssd sample --model mlx-community/Qwen3-4B-Instruct-4bit --preset qwen3-4b-instruct --output ./ssd_data
37
+ mlx-ssd train --model mlx-community/Qwen3-4B-Instruct-4bit --preset qwen3-4b-instruct --data ./ssd_data --output ./ssd_model
38
+ mlx-ssd run --model ./ssd_model/fused --preset qwen3-4b-instruct --prompt "Write a function that..."
39
+ ```
40
+
41
+ ## Usage
42
+
43
+ Install:
44
+
45
+ ```bash
46
+ pip install -e .
47
+ ```
48
+
49
+ Three-stage flow:
50
+
51
+ ```bash
52
+ # 1) Sample
53
+ mlx-ssd sample \
54
+ --model mlx-community/Qwen3-4B-Instruct-4bit \
55
+ --problems microsoft/rStar-Coder \
56
+ --dataset-config seed_sft \
57
+ --dataset-split train \
58
+ --output ./ssd_data \
59
+ --batch-size 16 \
60
+ --temperature 1.6 \
61
+ --top-k 20 \
62
+ --top-p 0.8 \
63
+ --limit 10
64
+
65
+ # 2) Train
66
+ mlx-ssd train \
67
+ --model mlx-community/Qwen3-4B-Instruct-4bit \
68
+ --data ./ssd_data \
69
+ --output ./ssd_model \
70
+ --iters 2500
71
+
72
+ # 3) Run
73
+ mlx-ssd run \
74
+ --model ./ssd_model/fused \
75
+ --temperature 1.1 \
76
+ --top-k 20 \
77
+ --top-p 0.8 \
78
+ --prompt "Write a function that..."
79
+ ```
80
+
81
+ One-command flow:
82
+
83
+ ```bash
84
+ mlx-ssd distill \
85
+ --model mlx-community/Qwen3-4B-Instruct-4bit \
86
+ --preset qwen3-4b-instruct \
87
+ --output ./my-better-qwen
88
+ ```
89
+
90
+ Local smoke test (quick validation):
91
+
92
+ ```bash
93
+ mlx-ssd sample \
94
+ --model mlx-community/SmolLM2-135M-Instruct \
95
+ --problems microsoft/rStar-Coder \
96
+ --dataset-config seed_sft \
97
+ --dataset-split train \
98
+ --output ./.smoke/data \
99
+ --batch-size 4 \
100
+ --temperature 0.8 \
101
+ --top-k 20 \
102
+ --top-p 0.8 \
103
+ --max-tokens 64 \
104
+ --limit 5
105
+ ```
106
+
107
+ ## Apple Silicon
108
+
109
+ `mlx-ssd` itself is the Apple Silicon implementation: it is built on `mlx-lm` and targets local MLX workflows.
110
+
111
+ ## License
112
+
113
+ MIT
@@ -0,0 +1,4 @@
1
+ """mlx-ssd package."""
2
+
3
+ __all__ = ["__version__"]
4
+ __version__ = "0.1.0"
@@ -0,0 +1,195 @@
1
+ """Command-line interface for mlx-ssd."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ from pathlib import Path
7
+
8
+ from mlx_lm import generate, load
9
+
10
+ from .configs.presets import PRESETS, get_preset
11
+ from .sampler import sample_dataset
12
+ from .trainer import train_model
13
+ from .utils.prompts import DEFAULT_CONFIG, DEFAULT_DATASET, DEFAULT_SPLIT
14
+
15
+
16
+ def _merge_sample_args(args: argparse.Namespace) -> tuple[float, int, float]:
17
+ if args.preset:
18
+ preset = get_preset(args.preset)
19
+ temperature = args.temperature if args.temperature is not None else preset["train_temperature"]
20
+ top_k = args.top_k if args.top_k is not None else preset["train_top_k"]
21
+ top_p = args.top_p if args.top_p is not None else preset["train_top_p"]
22
+ return temperature, top_k, top_p
23
+ if args.temperature is None or args.top_k is None or args.top_p is None:
24
+ raise ValueError("Provide --preset or all of --temperature/--top-k/--top-p.")
25
+ return args.temperature, args.top_k, args.top_p
26
+
27
+
28
+ def _merge_eval_args(args: argparse.Namespace) -> tuple[float, int, float]:
29
+ if args.preset:
30
+ preset = get_preset(args.preset)
31
+ temperature = args.temperature if args.temperature is not None else preset["eval_temperature"]
32
+ top_k = args.top_k if args.top_k is not None else preset["eval_top_k"]
33
+ top_p = args.top_p if args.top_p is not None else preset["eval_top_p"]
34
+ return temperature, top_k, top_p
35
+ if args.temperature is None or args.top_k is None or args.top_p is None:
36
+ raise ValueError("Provide --preset or all of --temperature/--top-k/--top-p.")
37
+ return args.temperature, args.top_k, args.top_p
38
+
39
+
40
+ def _resolve_iters(args: argparse.Namespace) -> tuple[int, str]:
41
+ if args.preset:
42
+ preset = get_preset(args.preset)
43
+ iters = args.iters if args.iters is not None else preset["iters"]
44
+ fine_tune_type = preset["fine_tune_type"]
45
+ return iters, fine_tune_type
46
+ if args.iters is None:
47
+ raise ValueError("Provide --preset or --iters.")
48
+ return args.iters, "full"
49
+
50
+
51
def cmd_sample(args: argparse.Namespace) -> int:
    """Handle the `sample` subcommand: generate SFT data from the base model.

    Returns 0 on success; sampling errors propagate to main() for reporting.
    """
    temperature, top_k, top_p = _merge_sample_args(args)
    paths = sample_dataset(
        model=args.model,
        problems=args.problems,
        output_dir=args.output,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        max_tokens=args.max_tokens,
        batch_size=args.batch_size,
        dataset_config=args.dataset_config,
        dataset_split=args.dataset_split,
        limit=args.limit,
    )
    print(f"Wrote {paths[0]} and {paths[1]}")
    return 0
68
+
69
+
70
def cmd_train(args: argparse.Namespace) -> int:
    """Handle the `train` subcommand: fine-tune on sampled data and fuse adapters.

    Returns 0 on success; training errors propagate to main() for reporting.
    """
    iterations, fine_tune_type = _resolve_iters(args)
    fused = train_model(
        model=args.model,
        data_dir=args.data,
        output_dir=args.output,
        iters=iterations,
        fine_tune_type=fine_tune_type,
        batch_size=args.batch_size,
    )
    print(f"Wrote fused model to {fused}")
    return 0
82
+
83
+
84
def cmd_run(args: argparse.Namespace) -> int:
    """Handle the `run` subcommand: generate one completion with eval-time settings.

    Fix: mlx_lm.generate (>= 0.21, as pinned by this package) does not accept
    `temp`/`top_k`/`top_p` keyword arguments directly; decoding settings must be
    supplied through a sampler built with make_sampler — the same pattern
    sampler.py already uses.

    Returns 0 on success; generation errors propagate to main() for reporting.
    """
    # Local import keeps the module import surface unchanged.
    from mlx_lm.sample_utils import make_sampler

    temperature, top_k, top_p = _merge_eval_args(args)
    mdl, tokenizer = load(args.model)
    sampler = make_sampler(temp=temperature, top_p=top_p, top_k=top_k)
    text = generate(
        mdl,
        tokenizer,
        prompt=args.prompt,
        sampler=sampler,
        max_tokens=args.max_tokens,
        verbose=False,
    )
    print(text)
    return 0
99
+
100
+
101
def cmd_distill(args: argparse.Namespace) -> int:
    """Handle the `distill` subcommand: run sample then train in one shot.

    Writes sampled data to `<output>/ssd_data` and the trained model to
    `<output>/ssd_model`. Decoding and training hyperparameters come entirely
    from the (required) preset.
    """
    base = Path(args.output)
    data_dir = base / "ssd_data"
    model_dir = base / "ssd_model"
    # Reuse the sample/train handlers by synthesizing their namespaces.
    cmd_sample(
        argparse.Namespace(
            model=args.model,
            problems=args.problems,
            dataset_config=args.dataset_config,
            dataset_split=args.dataset_split,
            output=str(data_dir),
            temperature=None,
            top_k=None,
            top_p=None,
            max_tokens=args.max_tokens,
            batch_size=args.batch_size,
            limit=args.limit,
            preset=args.preset,
        )
    )
    cmd_train(
        argparse.Namespace(
            model=args.model,
            data=str(data_dir),
            output=str(model_dir),
            iters=None,
            batch_size=None,
            preset=args.preset,
        )
    )
    print(f"Distillation complete. Model: {model_dir / 'fused'}")
    return 0
130
+
131
+
132
def build_parser() -> argparse.ArgumentParser:
    """Construct the mlx-ssd argument parser with its four subcommands
    (sample / train / run / distill)."""
    parser = argparse.ArgumentParser(prog="mlx-ssd")
    subcommands = parser.add_subparsers(dest="command", required=True)
    preset_names = sorted(PRESETS)
    presets_help = f"Preset name ({', '.join(preset_names)})"

    def add_preset(sub_parser: argparse.ArgumentParser, required: bool = False) -> None:
        # Shared --preset option; `choices` rejects unknown names up front.
        sub_parser.add_argument("--preset", choices=preset_names, required=required, help=presets_help)

    sample_parser = subcommands.add_parser("sample", help="Generate SFT data via temperature sampling.")
    sample_parser.add_argument("--model", required=True)
    sample_parser.add_argument("--problems", default=DEFAULT_DATASET)
    sample_parser.add_argument("--dataset-config", default=DEFAULT_CONFIG)
    sample_parser.add_argument("--dataset-split", default=DEFAULT_SPLIT)
    sample_parser.add_argument("--output", required=True)
    sample_parser.add_argument("--temperature", type=float)
    sample_parser.add_argument("--top-k", type=int)
    sample_parser.add_argument("--top-p", type=float)
    sample_parser.add_argument("--max-tokens", type=int, default=1024)
    sample_parser.add_argument("--batch-size", type=int, default=16)
    sample_parser.add_argument("--limit", type=int)
    add_preset(sample_parser)
    sample_parser.set_defaults(func=cmd_sample)

    train_parser = subcommands.add_parser("train", help="Fine-tune with mlx-lm and fuse adapters.")
    train_parser.add_argument("--model", required=True)
    train_parser.add_argument("--data", required=True)
    train_parser.add_argument("--output", required=True)
    train_parser.add_argument("--iters", type=int)
    train_parser.add_argument("--batch-size", type=int)
    add_preset(train_parser)
    train_parser.set_defaults(func=cmd_train)

    run_parser = subcommands.add_parser("run", help="Generate text using eval-time settings.")
    run_parser.add_argument("--model", required=True)
    run_parser.add_argument("--prompt", required=True)
    run_parser.add_argument("--temperature", type=float)
    run_parser.add_argument("--top-k", type=int)
    run_parser.add_argument("--top-p", type=float)
    run_parser.add_argument("--max-tokens", type=int, default=1024)
    add_preset(run_parser)
    run_parser.set_defaults(func=cmd_run)

    distill_parser = subcommands.add_parser("distill", help="Run sample + train in one command.")
    distill_parser.add_argument("--model", required=True)
    distill_parser.add_argument("--output", required=True)
    distill_parser.add_argument("--problems", default=DEFAULT_DATASET)
    distill_parser.add_argument("--dataset-config", default=DEFAULT_CONFIG)
    distill_parser.add_argument("--dataset-split", default=DEFAULT_SPLIT)
    add_preset(distill_parser, required=True)
    distill_parser.add_argument("--max-tokens", type=int, default=1024)
    distill_parser.add_argument("--batch-size", type=int, default=16)
    distill_parser.add_argument("--limit", type=int)
    distill_parser.set_defaults(func=cmd_distill)
    return parser
183
+
184
+
185
def main(argv: list[str] | None = None) -> int:
    """CLI entry point; returns the process exit code (0 on success).

    Any exception raised by a subcommand is reported via parser.exit(), which
    writes `error: ...` to stderr and raises SystemExit(1).
    """
    parser = build_parser()
    parsed = parser.parse_args(argv)
    try:
        return parsed.func(parsed)
    except Exception as exc:  # surface any stage failure as exit code 1
        parser.exit(status=1, message=f"error: {exc}\n")
192
+
193
+
194
+ if __name__ == "__main__":
195
+ raise SystemExit(main())
@@ -0,0 +1 @@
1
+ """Configuration helpers for mlx-ssd."""
@@ -0,0 +1,45 @@
1
+ """Paper-aligned presets used across pipeline stages."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from copy import deepcopy
6
+
7
# Per-model-family hyperparameters: train-time sampling settings, eval-time
# sampling settings, training iteration count, and fine-tune type.
PRESETS = {
    "qwen3-4b-instruct": {
        "train_temperature": 1.6,
        "train_top_k": 20,
        "train_top_p": 0.8,
        "eval_temperature": 1.1,
        "eval_top_k": 20,
        "eval_top_p": 0.8,
        "iters": 2500,
        "fine_tune_type": "full",
    },
    "qwen3-30b-instruct": {
        "train_temperature": 1.6,
        "train_top_k": 20,
        "train_top_p": 0.8,
        "eval_temperature": 0.9,
        "eval_top_k": 20,
        "eval_top_p": 0.8,
        "iters": 2500,
        "fine_tune_type": "full",
    },
    "llama-3.1-8b-instruct": {
        "train_temperature": 0.8,
        "train_top_k": 20,
        "train_top_p": 0.8,
        "eval_temperature": 0.7,
        "eval_top_k": 20,
        "eval_top_p": 0.8,
        "iters": 2500,
        "fine_tune_type": "full",
    },
}


def get_preset(name: str) -> dict:
    """Return a deep copy of the named preset so callers can mutate it freely.

    Raises:
        ValueError: Listing the known preset names when `name` is unknown.
    """
    if name in PRESETS:
        return deepcopy(PRESETS[name])
    raise ValueError(f"Unknown preset '{name}'. Available presets: {', '.join(sorted(PRESETS))}")
@@ -0,0 +1,108 @@
1
+ """Stage 1: sample model outputs and build SFT datasets."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ from mlx_lm import load
9
+ from mlx_lm.generate import BatchGenerator
10
+ from mlx_lm.sample_utils import make_sampler
11
+
12
+ from .utils.data import train_valid_split, write_jsonl
13
+ from .utils.prompts import DEFAULT_CONFIG, DEFAULT_SPLIT, load_problem_prompts
14
+
15
+
16
+ def _is_degenerate(text: str) -> bool:
17
+ clean = text.strip()
18
+ if not clean:
19
+ return True
20
+ # Minimal heuristic: drop very short one-liners.
21
+ if "\n" not in clean and len(clean.split()) < 8:
22
+ return True
23
+ return False
24
+
25
+
26
def _batch_generate_texts(
    model: Any,
    tokenizer: Any,
    prompts: list[list[int]],
    sampler: Any,
    max_tokens: int,
) -> list[str]:
    """Generate one completion per tokenized prompt via mlx-lm's BatchGenerator.

    Args:
        model: Loaded mlx-lm model.
        tokenizer: Matching tokenizer; must provide eos_token_ids and decode().
        prompts: Token-id sequences, one per prompt.
        sampler: Sampler callable (built with mlx_lm.sample_utils.make_sampler).
        max_tokens: Per-prompt generation cap.

    Returns:
        Decoded completion strings in the same order as `prompts`.
    """
    gen = BatchGenerator(
        model,
        stop_tokens=tokenizer.eos_token_ids,
        sampler=sampler,
    )
    try:
        # insert() returns one uid per prompt; order is preserved for decoding.
        uids = gen.insert(prompts, max_tokens)
        tokens_by_uid: dict[int, list[int]] = {uid: [] for uid in uids}
        # next() yields per-sequence step results until all sequences finish.
        while responses := gen.next():
            for response in responses:
                # Tokens from a step whose finish_reason is "stop" are dropped —
                # presumably that step's token is the stop token itself, which
                # should not appear in the decoded text (TODO confirm against
                # the mlx-lm BatchGenerator contract).
                if response.finish_reason != "stop":
                    tokens_by_uid[response.uid].append(response.token)
        return [tokenizer.decode(tokens_by_uid[uid]) for uid in uids]
    finally:
        # Release generator resources even if generation fails midway.
        gen.close()
48
+
49
+
50
def sample_dataset(
    model: str,
    problems: str,
    output_dir: str,
    temperature: float,
    top_k: int,
    top_p: float,
    max_tokens: int = 1024,
    batch_size: int = 16,
    dataset_config: str | None = DEFAULT_CONFIG,
    dataset_split: str = DEFAULT_SPLIT,
    limit: int | None = None,
    seed: int = 42,
) -> tuple[Path, Path]:
    """Stage 1: sample completions per problem prompt and write SFT JSONL files.

    Loads problem prompts, generates one completion per prompt with the given
    train-time decoding settings, filters degenerate completions, and writes
    chat-format train/valid JSONL splits under `output_dir`.

    Args:
        model: Model path or repo id for mlx_lm.load().
        problems: Dataset name or alias for the problem prompts.
        output_dir: Directory receiving train.jsonl and valid.jsonl.
        temperature: Train-time sampling temperature.
        top_k: Train-time top-k.
        top_p: Train-time top-p.
        max_tokens: Per-prompt generation cap.
        batch_size: Prompts generated per batch; must be positive.
        dataset_config: Dataset configuration forwarded to the loader.
        dataset_split: Dataset split forwarded to the loader.
        limit: Optional cap on the number of prompts.
        seed: Seed for the train/valid shuffle split (not for sampling).

    Returns:
        (train_path, valid_path) of the written JSONL files.

    Raises:
        ValueError: On a non-positive batch size, or when every completion was
            filtered out as degenerate.
    """
    # Fix: validate arguments before any dataset/model download happens
    # (the original only checked batch_size after loading the model).
    if batch_size <= 0:
        raise ValueError("--batch-size must be positive.")

    prompts = load_problem_prompts(
        problems,
        split=dataset_split,
        config=dataset_config,
        limit=limit,
    )
    mdl, tokenizer = load(model)
    sampler = make_sampler(temp=temperature, top_p=top_p, top_k=top_k)

    rows: list[dict] = []
    for start in range(0, len(prompts), batch_size):
        prompt_batch = prompts[start : start + batch_size]
        tokenized_batch = [tokenizer.encode(prompt) for prompt in prompt_batch]
        completions = _batch_generate_texts(
            model=mdl,
            tokenizer=tokenizer,
            prompts=tokenized_batch,
            sampler=sampler,
            max_tokens=max_tokens,
        )
        for prompt, completion in zip(prompt_batch, completions, strict=True):
            # Drop empty or trivially short completions.
            if _is_degenerate(completion):
                continue
            rows.append(
                {
                    "messages": [
                        {"role": "user", "content": prompt},
                        {"role": "assistant", "content": completion.strip()},
                    ]
                }
            )

    if not rows:
        raise ValueError("No valid samples generated; adjust sampling settings.")

    train_rows, valid_rows = train_valid_split(rows, valid_ratio=0.05, seed=seed)
    output_path = Path(output_dir)
    train_path = output_path / "train.jsonl"
    valid_path = output_path / "valid.jsonl"
    write_jsonl(train_path, train_rows)
    write_jsonl(valid_path, valid_rows)
    return train_path, valid_path
@@ -0,0 +1,73 @@
1
+ """Stage 2: fine-tune with mlx-lm training entrypoint."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import subprocess
6
+ import sys
7
+ from pathlib import Path
8
+
9
+
10
+ def train_model(
11
+ model: str,
12
+ data_dir: str,
13
+ output_dir: str,
14
+ iters: int,
15
+ fine_tune_type: str = "full",
16
+ batch_size: int | None = None,
17
+ ) -> Path:
18
+ data_path = Path(data_dir)
19
+ train_file = data_path / "train.jsonl"
20
+ valid_file = data_path / "valid.jsonl"
21
+ if not train_file.exists() or not valid_file.exists():
22
+ raise FileNotFoundError(
23
+ f"Expected train.jsonl and valid.jsonl in '{data_path}'. Run sample stage first."
24
+ )
25
+
26
+ output_path = Path(output_dir)
27
+ output_path.mkdir(parents=True, exist_ok=True)
28
+ train_examples = sum(1 for _ in train_file.open("r", encoding="utf-8"))
29
+ valid_examples = sum(1 for _ in valid_file.open("r", encoding="utf-8"))
30
+ if train_examples < 1 or valid_examples < 1:
31
+ raise ValueError("Training requires at least one train and one valid example.")
32
+
33
+ effective_batch_size = batch_size if batch_size is not None else min(4, train_examples, valid_examples)
34
+ if effective_batch_size < 1:
35
+ raise ValueError("Resolved training batch size is invalid.")
36
+
37
+ cmd = [
38
+ sys.executable,
39
+ "-m",
40
+ "mlx_lm",
41
+ "lora",
42
+ "--train",
43
+ "--model",
44
+ model,
45
+ "--data",
46
+ str(data_path),
47
+ "--fine-tune-type",
48
+ fine_tune_type,
49
+ "--iters",
50
+ str(iters),
51
+ "--batch-size",
52
+ str(effective_batch_size),
53
+ "--adapter-path",
54
+ str(output_path / "adapters"),
55
+ "--save-every",
56
+ "200",
57
+ ]
58
+ subprocess.run(cmd, check=True)
59
+
60
+ fuse_cmd = [
61
+ sys.executable,
62
+ "-m",
63
+ "mlx_lm",
64
+ "fuse",
65
+ "--model",
66
+ model,
67
+ "--adapter-path",
68
+ str(output_path / "adapters"),
69
+ "--save-path",
70
+ str(output_path / "fused"),
71
+ ]
72
+ subprocess.run(fuse_cmd, check=True)
73
+ return output_path / "fused"
@@ -0,0 +1 @@
1
+ """Utility helpers for mlx-ssd."""
@@ -0,0 +1,37 @@
1
+ """Data loading and JSONL formatting utilities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import random
7
+ from pathlib import Path
8
+ from typing import Iterable
9
+
10
+
11
def write_jsonl(path: Path, rows: Iterable[dict]) -> None:
    """Serialize `rows` to `path` as one JSON object per line (non-ASCII kept
    verbatim), creating parent directories as needed."""
    path.parent.mkdir(parents=True, exist_ok=True)
    serialized = (json.dumps(row, ensure_ascii=False) for row in rows)
    with path.open("w", encoding="utf-8") as handle:
        handle.writelines(line + "\n" for line in serialized)
16
+
17
+
18
def load_jsonl(path: Path) -> list[dict]:
    """Read a JSONL file and return the parsed rows, skipping blank lines."""
    with path.open("r", encoding="utf-8") as handle:
        return [json.loads(stripped) for line in handle if (stripped := line.strip())]
26
+
27
+
28
def train_valid_split(rows: list[dict], valid_ratio: float = 0.05, seed: int = 42) -> tuple[list[dict], list[dict]]:
    """Deterministically split `rows` into (train, valid) lists.

    The validation split gets max(1, int(len(rows) * valid_ratio)) items; if
    that would leave the training split empty, the two are swapped so training
    is never empty when any rows exist.

    Fix: the original shuffled the caller's list in place; this version
    shuffles a copy and leaves `rows` unmodified.

    Args:
        rows: Rows to split.
        valid_ratio: Fraction of rows reserved for validation.
        seed: Shuffle seed for reproducibility.

    Returns:
        (train_rows, valid_rows); both empty when `rows` is empty.
    """
    if not rows:
        return [], []
    # Shuffle a copy so the caller's list is not mutated.
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    valid_size = max(1, int(len(shuffled) * valid_ratio))
    valid = shuffled[:valid_size]
    train = shuffled[valid_size:]
    if not train:
        train, valid = valid, train
    return train, valid
@@ -0,0 +1,54 @@
1
+ """Prompt set loading for sampling."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Iterable
6
+
7
+ from datasets import load_dataset
8
+
9
# Default problem source: rStar-Coder's seed SFT split. Records are expected
# to expose a non-empty "question" field (validated by _extract_question).
DEFAULT_DATASET = "microsoft/rStar-Coder"
DEFAULT_CONFIG = "seed_sft"
DEFAULT_SPLIT = "train"

# Short names accepted by resolve_dataset_id() for the default dataset.
DATASET_ALIASES = {
    "rstar-coder": DEFAULT_DATASET,
    "rstarcoder": DEFAULT_DATASET,
}
17
+
18
+
19
+ def _extract_question(record: dict, index: int) -> str:
20
+ if "question" not in record:
21
+ raise ValueError(f"Dataset record {index} is missing required 'question' field.")
22
+ value = record["question"]
23
+ if not isinstance(value, str):
24
+ raise ValueError(f"Dataset record {index} has non-string 'question' field: {type(value).__name__}.")
25
+ question = value.strip()
26
+ if not question:
27
+ raise ValueError(f"Dataset record {index} has empty 'question' field.")
28
+ return question
29
+
30
+
31
def resolve_dataset_id(name: str) -> str:
    """Map a short alias (e.g. 'rstar-coder') to its full dataset id; unknown
    names pass through unchanged."""
    if name in DATASET_ALIASES:
        return DATASET_ALIASES[name]
    return name
33
+
34
+
35
def load_problem_prompts(
    name: str,
    split: str = DEFAULT_SPLIT,
    config: str | None = DEFAULT_CONFIG,
    limit: int | None = None,
) -> list[str]:
    """Load problem prompts (each record's 'question' field) from a dataset.

    Args:
        name: Dataset id, or a known alias (see resolve_dataset_id).
        split: Dataset split to load.
        config: Dataset configuration name, or None.
        limit: Optional maximum number of prompts. Fix: the original checked
            the limit only after appending, so `limit=0` still returned one
            prompt; the limit is now enforced before collecting.

    Returns:
        Non-empty list of stripped question strings, in dataset order.

    Raises:
        ValueError: If a record is malformed, or no prompts were collected.
    """
    dataset_id = resolve_dataset_id(name)
    dataset = load_dataset(dataset_id, config, split=split)
    prompts: list[str] = []
    for index, record in enumerate(dataset):
        # Stop before collecting more than `limit` prompts.
        if limit is not None and len(prompts) >= limit:
            break
        prompts.append(_extract_question(record, index))
    if not prompts:
        raise ValueError(f"No prompts found for problem set '{dataset_id}' (config={config}, split={split}).")
    return prompts
51
+
52
+
53
def iter_user_messages(prompts: Iterable[str]) -> list[dict]:
    """Wrap each prompt as a chat-format user message dict."""
    messages: list[dict] = []
    for prompt in prompts:
        messages.append({"role": "user", "content": prompt})
    return messages
@@ -0,0 +1,126 @@
1
+ Metadata-Version: 2.4
2
+ Name: mlx-ssd
3
+ Version: 0.1.0
4
+ Summary: Simple Self-Distillation training pipeline for MLX models
5
+ Author: mlx-ssd contributors
6
+ License: MIT
7
+ Requires-Python: >=3.10
8
+ Description-Content-Type: text/markdown
9
+ Requires-Dist: mlx-lm>=0.21.0
10
+ Requires-Dist: mlx-tokenizers>=1.0.0
11
+ Requires-Dist: datasets>=3.0.0
12
+ Requires-Dist: huggingface-hub>=0.24.0
13
+
14
+ # mlx-ssd
15
+
16
+ `mlx-ssd` is a practical MLX CLI implementation of simple self-distillation for code generation models on Apple Silicon.
17
+
18
+ ## Method
19
+
20
+ This project follows the method introduced in:
21
+
22
+ > Ruixiang Zhang, Richard He Bai, Huangjie Zheng, Navdeep Jaitly, Ronan Collobert, Yizhe Zhang.
23
+ > **Embarrassingly Simple Self-Distillation Improves Code Generation**.
24
+ > arXiv:2604.01193, 2026.
25
+ > https://arxiv.org/abs/2604.01193
26
+
27
+ Implementation by **Amirani Labs**.
28
+
29
+ Core flow:
30
+
31
+ 1. Sample responses from a base model with train-time decoding settings.
32
+ 2. Fine-tune on those self-generated samples.
33
+ 3. Evaluate/run with eval-time decoding settings.
34
+
35
+ Dataset defaults:
36
+
37
+ - `--problems microsoft/rStar-Coder`
38
+ - `--dataset-config seed_sft`
39
+ - `--dataset-split train`
40
+ - Records must contain a non-empty `question` field.
41
+
42
+ This repository is an independent implementation and is **not** the original paper repository.
43
+
44
+ ## Presets
45
+
46
+ Presets encode paper-aligned hyperparameters (Table 3 mapping) for supported model families.
47
+
48
+ ```bash
49
+ mlx-ssd sample --model mlx-community/Qwen3-4B-Instruct-4bit --preset qwen3-4b-instruct --output ./ssd_data
50
+ mlx-ssd train --model mlx-community/Qwen3-4B-Instruct-4bit --preset qwen3-4b-instruct --data ./ssd_data --output ./ssd_model
51
+ mlx-ssd run --model ./ssd_model/fused --preset qwen3-4b-instruct --prompt "Write a function that..."
52
+ ```
53
+
54
+ ## Usage
55
+
56
+ Install:
57
+
58
+ ```bash
59
+ pip install -e .
60
+ ```
61
+
62
+ Three-stage flow:
63
+
64
+ ```bash
65
+ # 1) Sample
66
+ mlx-ssd sample \
67
+ --model mlx-community/Qwen3-4B-Instruct-4bit \
68
+ --problems microsoft/rStar-Coder \
69
+ --dataset-config seed_sft \
70
+ --dataset-split train \
71
+ --output ./ssd_data \
72
+ --batch-size 16 \
73
+ --temperature 1.6 \
74
+ --top-k 20 \
75
+ --top-p 0.8 \
76
+ --limit 10
77
+
78
+ # 2) Train
79
+ mlx-ssd train \
80
+ --model mlx-community/Qwen3-4B-Instruct-4bit \
81
+ --data ./ssd_data \
82
+ --output ./ssd_model \
83
+ --iters 2500
84
+
85
+ # 3) Run
86
+ mlx-ssd run \
87
+ --model ./ssd_model/fused \
88
+ --temperature 1.1 \
89
+ --top-k 20 \
90
+ --top-p 0.8 \
91
+ --prompt "Write a function that..."
92
+ ```
93
+
94
+ One-command flow:
95
+
96
+ ```bash
97
+ mlx-ssd distill \
98
+ --model mlx-community/Qwen3-4B-Instruct-4bit \
99
+ --preset qwen3-4b-instruct \
100
+ --output ./my-better-qwen
101
+ ```
102
+
103
+ Local smoke test (quick validation):
104
+
105
+ ```bash
106
+ mlx-ssd sample \
107
+ --model mlx-community/SmolLM2-135M-Instruct \
108
+ --problems microsoft/rStar-Coder \
109
+ --dataset-config seed_sft \
110
+ --dataset-split train \
111
+ --output ./.smoke/data \
112
+ --batch-size 4 \
113
+ --temperature 0.8 \
114
+ --top-k 20 \
115
+ --top-p 0.8 \
116
+ --max-tokens 64 \
117
+ --limit 5
118
+ ```
119
+
120
+ ## Apple Silicon
121
+
122
+ `mlx-ssd` itself is the Apple Silicon implementation: it is built on `mlx-lm` and targets local MLX workflows.
123
+
124
+ ## License
125
+
126
+ MIT
@@ -0,0 +1,17 @@
1
+ README.md
2
+ pyproject.toml
3
+ mlx_ssd/__init__.py
4
+ mlx_ssd/cli.py
5
+ mlx_ssd/sampler.py
6
+ mlx_ssd/trainer.py
7
+ mlx_ssd.egg-info/PKG-INFO
8
+ mlx_ssd.egg-info/SOURCES.txt
9
+ mlx_ssd.egg-info/dependency_links.txt
10
+ mlx_ssd.egg-info/entry_points.txt
11
+ mlx_ssd.egg-info/requires.txt
12
+ mlx_ssd.egg-info/top_level.txt
13
+ mlx_ssd/configs/__init__.py
14
+ mlx_ssd/configs/presets.py
15
+ mlx_ssd/utils/__init__.py
16
+ mlx_ssd/utils/data.py
17
+ mlx_ssd/utils/prompts.py
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ mlx-ssd = mlx_ssd.cli:main
@@ -0,0 +1,4 @@
1
+ mlx-lm>=0.21.0
2
+ mlx-tokenizers>=1.0.0
3
+ datasets>=3.0.0
4
+ huggingface-hub>=0.24.0
@@ -0,0 +1 @@
1
+ mlx_ssd
@@ -0,0 +1,27 @@
1
+ [build-system]
2
+ requires = ["setuptools>=68", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "mlx-ssd"
7
+ version = "0.1.0"
8
+ description = "Simple Self-Distillation training pipeline for MLX models"
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+ license = { text = "MIT" }
12
+ authors = [{ name = "mlx-ssd contributors" }]
13
+ dependencies = [
14
+ "mlx-lm>=0.21.0",
15
+ "mlx-tokenizers>=1.0.0",
16
+ "datasets>=3.0.0",
17
+ "huggingface-hub>=0.24.0",
18
+ ]
19
+
20
+ [project.scripts]
21
+ mlx-ssd = "mlx_ssd.cli:main"
22
+
23
+ [tool.setuptools]
24
+ include-package-data = true
25
+
26
+ [tool.setuptools.packages.find]
27
+ include = ["mlx_ssd*"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+