blitz-cli 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- blitz_cli-0.1.0/.github/workflows/publish.yml +30 -0
- blitz_cli-0.1.0/.gitignore +7 -0
- blitz_cli-0.1.0/PKG-INFO +34 -0
- blitz_cli-0.1.0/README.md +27 -0
- blitz_cli-0.1.0/blitz_cli/__init__.py +11 -0
- blitz_cli-0.1.0/blitz_cli/_client.py +138 -0
- blitz_cli-0.1.0/blitz_cli/_scaffold.py +111 -0
- blitz_cli-0.1.0/blitz_cli/cli.py +128 -0
- blitz_cli-0.1.0/blitz_cli/templates/Dockerfile.tmpl +16 -0
- blitz_cli-0.1.0/blitz_cli/templates/Makefile.tmpl +33 -0
- blitz_cli-0.1.0/blitz_cli/templates/README.md.tmpl +36 -0
- blitz_cli-0.1.0/blitz_cli/templates/__init__.py +3 -0
- blitz_cli-0.1.0/blitz_cli/templates/dockerignore.tmpl +4 -0
- blitz_cli-0.1.0/blitz_cli/templates/eval.py.tmpl +84 -0
- blitz_cli-0.1.0/blitz_cli/templates/requirements.txt.tmpl +11 -0
- blitz_cli-0.1.0/blitz_cli/templates/train.py.tmpl +89 -0
- blitz_cli-0.1.0/pyproject.toml +25 -0
- blitz_cli-0.1.0/tests/test_scaffold.py +80 -0
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# Publish blitz-cli to PyPI when a version tag (v*) is pushed: build the sdist
# and wheel, then upload them via PyPI Trusted Publishing (OIDC), so no PyPI
# API token secret is referenced in this workflow.
# NOTE(review): nothing here checks that the pushed tag matches `version` in
# pyproject.toml — confirm the release process keeps them in sync.
name: Publish to PyPI

on:
  push:
    tags:
      - "v*"  # e.g. v0.1.0

permissions:
  contents: read   # needed by actions/checkout
  id-token: write  # PyPI Trusted Publishing (OIDC)

jobs:
  publish:
    runs-on: ubuntu-latest
    # GitHub deployment environment; presumably the one registered for this
    # repository's trusted publisher on PyPI — confirm in the PyPI settings.
    environment: pypi
    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: "3.12"

      - name: Install build tools
        run: pip install build

      # `python -m build` writes the sdist and wheel into dist/, which is the
      # directory the publish action uploads by default.
      - name: Build sdist and wheel
        run: python -m build

      - name: Publish to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
|
blitz_cli-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: blitz-cli
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Developer CLI for Blitz: pull a workflow's training data + base-model recommendation and scaffold a runnable QLoRA training container.
|
|
5
|
+
Requires-Python: >=3.9
|
|
6
|
+
Description-Content-Type: text/markdown
|
|
7
|
+
|
|
8
|
+
# blitz-cli
|
|
9
|
+
|
|
10
|
+
Developer CLI for [Blitz](https://github.com/sparepartslabs/blitz-sdk-py). Pull a
|
|
11
|
+
workflow's captured traces as a fine-tuning dataset + the recommended open base
|
|
12
|
+
model, and scaffold a self-contained, runnable QLoRA training container — then
|
|
13
|
+
close the loop by grading the trained student against the held-out eval set.
|
|
14
|
+
|
|
15
|
+
```bash
|
|
16
|
+
pip install blitz-cli
|
|
17
|
+
|
|
18
|
+
export BLITZ_API_KEY=blz_... # a read-scoped project key (mint one in the dashboard)
|
|
19
|
+
blitz scaffold -p proj_abc -w mechanic-assistant -o ./train
|
|
20
|
+
|
|
21
|
+
cd train
|
|
22
|
+
make build && make train && make eval
|
|
23
|
+
```
|
|
24
|
+
|
|
25
|
+
`blitz scaffold` writes `./train` with `data/dataset.jsonl`, `data/evalset.jsonl`,
|
|
26
|
+
`config.json` (derived from the base-model recommendation), and a `Dockerfile` +
|
|
27
|
+
`train.py` + `eval.py` + `Makefile`. Bring your own NVIDIA GPU.
|
|
28
|
+
|
|
29
|
+
This package is intentionally dependency-free (stdlib + the Blitz HTTP API). The
|
|
30
|
+
heavy ML stack (torch / transformers / trl / peft) is pinned only in the
|
|
31
|
+
*generated* project's `requirements.txt`, installed inside the training image.
|
|
32
|
+
|
|
33
|
+
The complementary tracing SDK (`pip install blitz-sdk`, `import blitz`) lives in
the sibling `blitz-sdk` directory of the source repository and is published on
PyPI as [`blitz-sdk`](https://pypi.org/project/blitz-sdk/).
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
# blitz-cli
|
|
2
|
+
|
|
3
|
+
Developer CLI for [Blitz](https://github.com/sparepartslabs/blitz-sdk-py). Pull a
|
|
4
|
+
workflow's captured traces as a fine-tuning dataset + the recommended open base
|
|
5
|
+
model, and scaffold a self-contained, runnable QLoRA training container — then
|
|
6
|
+
close the loop by grading the trained student against the held-out eval set.
|
|
7
|
+
|
|
8
|
+
```bash
|
|
9
|
+
pip install blitz-cli
|
|
10
|
+
|
|
11
|
+
export BLITZ_API_KEY=blz_... # a read-scoped project key (mint one in the dashboard)
|
|
12
|
+
blitz scaffold -p proj_abc -w mechanic-assistant -o ./train
|
|
13
|
+
|
|
14
|
+
cd train
|
|
15
|
+
make build && make train && make eval
|
|
16
|
+
```
|
|
17
|
+
|
|
18
|
+
`blitz scaffold` writes `./train` with `data/dataset.jsonl`, `data/evalset.jsonl`,
|
|
19
|
+
`config.json` (derived from the base-model recommendation), and a `Dockerfile` +
|
|
20
|
+
`train.py` + `eval.py` + `Makefile`. Bring your own NVIDIA GPU.
|
|
21
|
+
|
|
22
|
+
This package is intentionally dependency-free (stdlib + the Blitz HTTP API). The
|
|
23
|
+
heavy ML stack (torch / transformers / trl / peft) is pinned only in the
|
|
24
|
+
*generated* project's `requirements.txt`, installed inside the training image.
|
|
25
|
+
|
|
26
|
+
The complementary tracing SDK (`pip install blitz-sdk`, `import blitz`) lives in
the sibling `blitz-sdk` directory of the source repository and is published on
PyPI as [`blitz-sdk`](https://pypi.org/project/blitz-sdk/).
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""blitz-cli — pull Blitz training data and scaffold a QLoRA training container.
|
|
2
|
+
|
|
3
|
+
Run ``blitz scaffold -p <project> -w <workflow> -o ./train`` to download a
|
|
4
|
+
workflow's SFT dataset, held-out eval set, and base-model recommendation, then
|
|
5
|
+
emit a self-contained, runnable training project (Dockerfile + train.py +
|
|
6
|
+
eval.py). Authenticates with a read-scoped Blitz API key (``BLITZ_API_KEY``).
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
# Public names of the package: only the version string. The `blitz` command
# itself is implemented in blitz_cli.cli.
__all__ = ["__version__"]

# Package version. Mirrors `version` in pyproject.toml (both 0.1.0 here);
# keep the two in sync when releasing.
__version__ = "0.1.0"
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
"""Minimal HTTP client for the Blitz read API (stdlib only).
|
|
2
|
+
|
|
3
|
+
Talks to the same backend the SDK pushes traces to, but to the owner-scoped
|
|
4
|
+
export endpoints, authenticated with a read-scoped API key in the ``x-api-key``
|
|
5
|
+
header. NDJSON endpoints are streamed line-by-line so a large dataset never has
|
|
6
|
+
to be buffered in memory.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
import urllib.error
|
|
13
|
+
import urllib.parse
|
|
14
|
+
import urllib.request
|
|
15
|
+
from typing import Iterator, Optional
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BlitzAPIError(RuntimeError):
    """Raised when a Blitz API call fails.

    ``status`` is the HTTP status code of the failed response. The exception
    message is the user-facing text (built with ``_hint`` by the client) that
    the CLI prints verbatim.
    """

    def __init__(self, status: int, message: str) -> None:
        super().__init__(message)
        self.status = status
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _hint(status: int, detail: str) -> str:
|
|
27
|
+
if status in (401, 403):
|
|
28
|
+
return (
|
|
29
|
+
f"{detail} (HTTP {status}). Check BLITZ_API_KEY — it must be a "
|
|
30
|
+
"read-scoped key for this project (create one in the dashboard)."
|
|
31
|
+
)
|
|
32
|
+
if status == 404:
|
|
33
|
+
return (
|
|
34
|
+
f"{detail} (HTTP 404). Project or workflow not found, or no usable "
|
|
35
|
+
"data captured for it yet."
|
|
36
|
+
)
|
|
37
|
+
return f"{detail} (HTTP {status})."
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class BlitzClient:
    """Minimal stdlib-only client for the owner-scoped Blitz export/eval API.

    Every request sends the API key in the ``x-api-key`` header and targets a
    single project under ``/blitz/projects/<project>``. Non-2xx responses and
    transport failures (DNS, refused connection, timeout) are raised as
    :class:`BlitzAPIError`.
    """

    def __init__(
        self, *, endpoint: str, api_key: str, project: str, timeout: float = 60.0
    ) -> None:
        self._base = endpoint.rstrip("/")
        self._headers = {"x-api-key": api_key}
        self._project = project
        self._timeout = timeout  # seconds; passed to urlopen on every request

    # -- low level ----------------------------------------------------------

    def _url(self, path: str, query: Optional[dict] = None) -> str:
        """Return base URL + ``path`` with the non-None ``query`` items encoded."""
        params = {
            # bools must serialize as true/false to match FastAPI query parsing
            k: (str(v).lower() if isinstance(v, bool) else v)
            for k, v in (query or {}).items()
            if v is not None
        }
        url = self._base + path
        if params:
            url += "?" + urllib.parse.urlencode(params)
        return url

    def _open(self, req: urllib.request.Request):
        """Send ``req`` and return the open response object.

        Raises:
            BlitzAPIError: with the HTTP status for a non-2xx response (using
                the server's JSON ``detail``/``message`` when present), or with
                status 0 when the server could not be reached at all.
        """
        try:
            return urllib.request.urlopen(req, timeout=self._timeout)
        except urllib.error.HTTPError as exc:
            detail = exc.reason or "request failed"
            try:
                body = json.loads(exc.read().decode("utf-8"))
                detail = body.get("detail") or body.get("message") or detail
            except Exception:  # noqa: BLE001 — the error body is best-effort
                pass
            raise BlitzAPIError(exc.code, _hint(exc.code, str(detail))) from None
        except OSError as exc:
            # URLError (DNS failure, refused connection, TLS error) and socket
            # timeouts carry no HTTP status. Report status 0 so callers that
            # catch BlitzAPIError print a clean message instead of a traceback.
            reason = getattr(exc, "reason", None) or exc
            raise BlitzAPIError(
                0,
                f"could not reach the Blitz API at {self._base} ({reason}). "
                "Check --endpoint / BLITZ_ENDPOINT and your network.",
            ) from None

    def _get_json(self, path: str, query: Optional[dict] = None) -> dict:
        """GET ``path`` and return the decoded JSON body."""
        req = urllib.request.Request(self._url(path, query), headers=self._headers)
        with self._open(req) as resp:
            return json.loads(resp.read().decode("utf-8"))

    def _get_ndjson(self, path: str, query: Optional[dict] = None) -> Iterator[dict]:
        """GET an NDJSON ``path`` and yield one decoded object per non-blank line.

        This is a generator. The request is only sent when iteration starts,
        so errors surface from the consuming loop rather than from this call.
        """
        req = urllib.request.Request(self._url(path, query), headers=self._headers)
        with self._open(req) as resp:
            for raw in resp:
                line = raw.decode("utf-8").strip()
                if line:
                    yield json.loads(line)

    def _post_json(self, path: str, body: dict) -> dict:
        """POST ``body`` as JSON to ``path`` and return the decoded JSON reply."""
        data = json.dumps(body).encode("utf-8")
        req = urllib.request.Request(
            self._url(path),
            data=data,
            method="POST",
            headers={**self._headers, "content-type": "application/json"},
        )
        with self._open(req) as resp:
            return json.loads(resp.read().decode("utf-8"))

    # -- public API (one method per endpoint the CLI / eval loop needs) -----

    def _p(self, suffix: str) -> str:
        """Prefix ``suffix`` with this client's project path.

        The project id is fully percent-encoded (``/`` included), so an id
        containing ``/``, ``?`` or ``#`` cannot redirect the request to another
        path. Ordinary ids such as ``proj_abc`` are unchanged.
        """
        return f"/blitz/projects/{urllib.parse.quote(self._project, safe='')}{suffix}"

    def recommended_base(self, workflow: str) -> dict:
        """Fetch the base-model recommendation for ``workflow``."""
        return self._get_json(self._p("/recommended-base"), {"workflow": workflow})

    def recommended_tools(self, workflow: str) -> dict:
        """Fetch the tool recommendation for ``workflow``."""
        return self._get_json(self._p("/recommended-tools"), {"workflow": workflow})

    def eval_summary(self, workflow: str) -> dict:
        """Fetch the held-out eval-set summary for ``workflow``."""
        return self._get_json(self._p("/eval-set/summary"), {"workflow": workflow})

    def download_dataset(
        self, workflow: Optional[str] = None, include_synthetic: bool = True
    ) -> Iterator[dict]:
        """Stream the training dataset rows (lazily; see ``_get_ndjson``)."""
        return self._get_ndjson(
            self._p("/dataset"),
            {"workflow": workflow, "include_synthetic": include_synthetic},
        )

    def download_eval_set(self, workflow: Optional[str] = None) -> Iterator[dict]:
        """Stream the held-out eval-set rows (lazily; see ``_get_ndjson``)."""
        return self._get_ndjson(self._p("/eval-dataset"), {"workflow": workflow})

    def submit_eval(
        self, workflow: str, predictions: list, grader: str = "auto"
    ) -> dict:
        """Submit caller-supplied predictions for grading; returns the run record."""
        return self._post_json(
            self._p("/eval"),
            {
                "workflow": workflow,
                "candidate": "supplied",
                "grader": grader,
                "predictions": predictions,
            },
        )

    def get_eval_run(self, run_id: str) -> dict:
        """Fetch an eval run by id (the id is percent-encoded into the path)."""
        return self._get_json(
            self._p("/eval/runs/" + urllib.parse.quote(run_id, safe=""))
        )
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
"""Render a runnable QLoRA training project from a base-model recommendation.
|
|
2
|
+
|
|
3
|
+
Templates live in ``blitz_cli/templates`` and ship as package data. ``train.py``,
|
|
4
|
+
``eval.py`` and ``requirements.txt`` are copied verbatim (they read ``config.json``
|
|
5
|
+
at runtime, so they need no substitution and stay valid Python regardless of the
|
|
6
|
+
recommendation); ``Dockerfile``/``Makefile``/``README`` get ``$var`` substitution
|
|
7
|
+
via string.Template (avoids the brace-escaping pain of str.format on code).
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import json
|
|
13
|
+
from importlib import resources
|
|
14
|
+
from pathlib import Path
|
|
15
|
+
from string import Template
|
|
16
|
+
from typing import Optional
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _seq_len(rec: dict) -> int:
|
|
20
|
+
"""The clamped training sequence length. Prefer the value the backend now
|
|
21
|
+
returns; fall back to the same clamp from the signals for older backends."""
|
|
22
|
+
if rec.get("seq_len"):
|
|
23
|
+
return int(rec["seq_len"])
|
|
24
|
+
signals = rec.get("signals") or {}
|
|
25
|
+
raw = (signals.get("p95_input") or 0) + (signals.get("max_output") or 0)
|
|
26
|
+
return max(1024, min(8192, raw or 2048))
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def build_config(rec: dict, workflow: str) -> dict:
    """Derive a QLoRA training config from a /recommended-base response.

    ``rec["recommendation"]`` must provide ``hf`` (the Hugging Face repo id).
    ``params_b`` (billions of parameters; 7 when absent or falsy) selects a
    size tier, and ``license`` is copied through ("" when absent).

    Every tier targets the same effective batch of 16 (per-device batch x
    gradient accumulation). Larger models trade per-step batch for more
    accumulation, so a single 24GB-class card stays viable. The two tiers
    above 3B also differ in LoRA rank and learning rate. ``seq_len`` comes
    from the recommendation, or from the same signal-based clamp for older
    backends (see ``_seq_len``), so the trainer's VRAM accounting matches
    what the recommender assumed.

    Returns:
        The dict that ``render`` writes to config.json and that train.py /
        eval.py read at runtime.

    Raises:
        KeyError: if ``rec`` lacks ``recommendation`` or it lacks ``hf``.
    """
    model = rec["recommendation"]
    params_b = float(model.get("params_b") or 7)

    # (LoRA rank, per-device batch, grad accumulation, learning rate) per tier;
    # batch * accum == 16 in every tier.
    if params_b <= 3:
        lora_r, batch, accum, lr = 16, 4, 4, 2e-4
    elif params_b <= 9:
        lora_r, batch, accum, lr = 16, 2, 8, 2e-4
    else:  # anything above 9B (e.g. 12B, 14B and larger)
        lora_r, batch, accum, lr = 32, 1, 16, 1e-4

    return {
        "workflow": workflow,
        "base_model_hf": model["hf"],
        "params_b": params_b,
        "license": model.get("license", ""),
        "seq_len": _seq_len(rec),
        "lora_r": lora_r,
        "lora_alpha": lora_r * 2,  # conventional alpha = 2 * r
        "lora_dropout": 0.05,
        "epochs": 3,
        "lr": lr,
        "per_device_batch_size": batch,
        "grad_accum": accum,
    }
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _template(name: str) -> str:
    """Return the text of the packaged template ``blitz_cli/templates/<name>``.

    The file is read through importlib.resources, decoded as UTF-8.
    """
    package_dir = resources.files("blitz_cli.templates")
    template_file = package_dir.joinpath(name)
    return template_file.read_text(encoding="utf-8")
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _write(path: Path, content: str) -> None:
|
|
69
|
+
path.write_text(content, encoding="utf-8")
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def render(
    out_dir: Path, rec: dict, project: str, workflow: str, endpoint: str
) -> dict:
    """Write the training project into ``out_dir`` and return the derived config.

    Steps:
    - Write config.json (from ``build_config``) and recommendation.json (the
      raw recommendation).
    - Copy train.py, eval.py, requirements.txt and .dockerignore verbatim from
      the packaged templates.
    - Render Dockerfile, Makefile and README.md with ``string.Template``
      placeholders filled in. Unknown placeholders are left intact by
      ``safe_substitute``.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    config = build_config(rec, workflow)
    for filename, payload in (("config.json", config), ("recommendation.json", rec)):
        _write(out_dir / filename, json.dumps(payload, indent=2) + "\n")

    # These read config.json at runtime, so they are copied untouched and stay
    # real, runnable source files. The dockerignore template is shipped
    # without its leading dot and renamed on the way out.
    verbatim = {
        "train.py.tmpl": "train.py",
        "eval.py.tmpl": "eval.py",
        "requirements.txt.tmpl": "requirements.txt",
        "dockerignore.tmpl": ".dockerignore",
    }
    for template_name, dest_name in verbatim.items():
        _write(out_dir / dest_name, _template(template_name))

    placeholders = {
        "project": project,
        "workflow": workflow,
        "endpoint": endpoint,
        "model_name": rec["recommendation"]["name"],
        "base_model_hf": config["base_model_hf"],
        "seq_len": str(config["seq_len"]),
        "license": config["license"],
    }
    for filename in ("Dockerfile", "Makefile", "README.md"):
        source = Template(_template(f"{filename}.tmpl"))
        _write(out_dir / filename, source.safe_substitute(placeholders))

    return config
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def gated_license_warning(rec: dict) -> Optional[str]:
    """Return a warning if the recommended model needs HF license acceptance.

    Models whose license starts with "llama" or "gemma" (case-insensitive) are
    treated as gated on Hugging Face. The user must accept the license and
    pass ``HF_TOKEN`` into the container.

    Returns ``None`` for any other license. It also returns ``None`` when the
    recommendation or its license is missing or null, instead of raising.
    """
    model = rec.get("recommendation") or {}
    lic = (model.get("license") or "").lower()
    if not lic.startswith(("llama", "gemma")):
        return None
    # Fall back gracefully if the backend omitted the display name.
    name = model.get("name") or model.get("hf") or "The recommended model"
    return (
        f"{name} is a gated model ({lic}). Accept its license on Hugging Face "
        "and pass HF_TOKEN into the container (see README) before training."
    )
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""`blitz` command — pull a workflow's data and scaffold a training container."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import argparse
|
|
6
|
+
import json
|
|
7
|
+
import os
|
|
8
|
+
import sys
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Optional
|
|
11
|
+
|
|
12
|
+
from blitz_cli import _scaffold
|
|
13
|
+
from blitz_cli._client import BlitzAPIError, BlitzClient
|
|
14
|
+
|
|
15
|
+
# Blitz API base URL used when neither --endpoint nor the BLITZ_ENDPOINT
# environment variable is given.
_DEFAULT_ENDPOINT = "https://api.sparepartslabs.com"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _eprint(*args: object) -> None:
|
|
19
|
+
print(*args, file=sys.stderr)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _cmd_scaffold(args: argparse.Namespace) -> int:
    """Handle ``blitz scaffold``.

    The command runs in three steps:
    - Fetch the base-model recommendation.
    - Stream the training set and the held-out eval set into ``<out>/data/``.
    - Render the runnable training project into ``<out>``.

    Returns the process exit code:
    - 0 on success.
    - 1 when an API call, the network, or writing the data files fails, or
      when the workflow has no training examples.
    - 2 when no API key was supplied.
    """
    api_key = args.api_key or os.environ.get("BLITZ_API_KEY")
    if not api_key:
        _eprint("error: no API key. Pass --api-key or set BLITZ_API_KEY (a read-scoped key).")
        return 2

    client = BlitzClient(endpoint=args.endpoint, api_key=api_key, project=args.project)
    out = Path(args.out)
    data_dir = out / "data"
    train_path = data_dir / "dataset.jsonl"
    eval_path = data_dir / "evalset.jsonl"

    # OSError is caught alongside BlitzAPIError. urllib's URLError (DNS
    # failure, refused connection), socket timeouts and local file-system
    # errors are all OSError subclasses and would otherwise escape as a raw
    # traceback instead of a one-line error.
    try:
        rec = client.recommended_base(args.workflow)
    except (BlitzAPIError, OSError) as exc:
        _eprint(f"error: {exc}")
        return 1

    model = rec["recommendation"]
    print(f"Recommended base: {model['name']} ({model['hf']}) — {model.get('why', '')}")
    warning = _scaffold.gated_license_warning(rec)
    if warning:
        _eprint(f"warning: {warning}")

    try:
        data_dir.mkdir(parents=True, exist_ok=True)
        # download_* return lazy generators: the HTTP request (and any
        # BlitzAPIError) happens while _stream_to_file iterates them.
        n_train = _stream_to_file(
            client.download_dataset(args.workflow, args.include_synthetic),
            train_path,
        )
        n_eval = _stream_to_file(client.download_eval_set(args.workflow), eval_path)
    except (BlitzAPIError, OSError) as exc:
        _eprint(f"error: {exc}")
        return 1

    print(f"Pulled {n_train} training examples → {train_path}")
    print(f"Pulled {n_eval} eval examples → {eval_path}")
    if n_train == 0:
        _eprint("error: no training examples for this workflow — nothing to train on.")
        return 1
    if n_eval == 0:
        _eprint("warning: held-out eval set is empty — `make eval` will be skipped.")

    config = _scaffold.render(out, rec, args.project, args.workflow, args.endpoint)
    print(f"Scaffolded training project → {out}/")
    print(
        "\nNext:\n"
        f"  cd {out}\n"
        "  export BLITZ_API_KEY=<your read key>   # used by `make eval` to post results\n"
        "  make build\n"
        "  make train\n"
        "  make eval\n"
        f"\n(base={config['base_model_hf']} seq_len={config['seq_len']} "
        f"lora_r={config['lora_r']} epochs={config['epochs']})"
    )
    return 0
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _stream_to_file(rows, path: Path) -> int:
|
|
81
|
+
count = 0
|
|
82
|
+
with path.open("w", encoding="utf-8") as fh:
|
|
83
|
+
for row in rows:
|
|
84
|
+
fh.write(json.dumps(row, ensure_ascii=False) + "\n")
|
|
85
|
+
count += 1
|
|
86
|
+
return count
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def build_parser() -> argparse.ArgumentParser:
    """Build the top-level ``blitz`` argument parser.

    There is one required subcommand, ``scaffold``. Its handler is attached
    as ``args.func`` so ``main`` can dispatch directly. The ``--endpoint``
    default is read from BLITZ_ENDPOINT when the parser is built, falling back
    to ``_DEFAULT_ENDPOINT``.
    """
    parser = argparse.ArgumentParser(
        prog="blitz",
        description="Blitz training CLI: pull data + scaffold a trainer.",
    )
    subcommands = parser.add_subparsers(dest="cmd", required=True)

    scaffold = subcommands.add_parser(
        "scaffold",
        help="Pull a workflow's dataset + recommendation and emit a QLoRA training project.",
    )

    # What to pull, and where to write the generated project.
    scaffold.add_argument("-p", "--project", required=True, help="Blitz project id")
    scaffold.add_argument(
        "-w", "--workflow", required=True, help="Workflow (root span) name"
    )
    scaffold.add_argument(
        "-o", "--out", default="./train", help="Output dir (default ./train)"
    )

    # Connection settings.
    endpoint_default = os.environ.get("BLITZ_ENDPOINT", _DEFAULT_ENDPOINT)
    scaffold.add_argument(
        "--endpoint",
        default=endpoint_default,
        help="Blitz API base URL (env BLITZ_ENDPOINT)",
    )
    scaffold.add_argument(
        "--api-key",
        default=None,
        help="Read-scoped API key (env BLITZ_API_KEY)",
    )

    # Opt-out flag: include_synthetic stays True unless --no-synthetic is given.
    scaffold.add_argument(
        "--no-synthetic",
        dest="include_synthetic",
        action="store_false",
        default=True,
        help="Export real training data only (exclude synthetic augmentation)",
    )

    scaffold.set_defaults(func=_cmd_scaffold)
    return parser
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def main(argv: Optional[list] = None) -> int:
    """Parse ``argv`` (``None`` means ``sys.argv[1:]``) and run the subcommand.

    Returns the subcommand handler's exit code.
    """
    namespace = build_parser().parse_args(argv)
    handler = namespace.func
    return handler(namespace)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return value as the exit status.
    sys.exit(main())
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Trainer image for distilling "$workflow" into $model_name.
# Base ships torch 2.5.1 + CUDA 12.1 so requirements.txt installs only the
# QLoRA stack on top (no torch reinstall).
FROM pytorch/pytorch:2.5.1-cuda12.1-cudnn9-runtime

WORKDIR /workspace
# Unbuffered stdout so training/eval logs stream live. The Hugging Face cache
# lives under /workspace, which the Makefile bind-mounts from the host
# project dir, so downloaded weights survive container restarts.
ENV PYTHONUNBUFFERED=1 \
    HF_HOME=/workspace/.hf

# Install dependencies before copying the project, so this layer stays
# cached when only code or data change.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

# Default to training; `make eval` overrides the command.
CMD ["python", "train.py"]
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# Train / eval the distilled student for the "$workflow" workflow.
# Requires an NVIDIA GPU + the NVIDIA Container Toolkit (docker --gpus all).
#
#   export BLITZ_API_KEY=<your read key>   # used by `make eval` to post results
#   make build && make train && make eval

# Defaults baked in by `blitz scaffold`. Each can be overridden from the
# command line or environment, e.g. `make train IMAGE=my-trainer`.
# WORKFLOW is informational only: train.py/eval.py read the workflow from
# config.json, and it is not passed into the container.
PROJECT ?= $project
WORKFLOW ?= $workflow
BLITZ_ENDPOINT ?= $endpoint
IMAGE ?= blitz-train-$project

# Mounts the project dir so adapters + the HF cache persist on the host.
# BLITZ_API_KEY and HF_TOKEN are forwarded by name (a bare `-e VAR`), so they
# are only present in the container when exported on the host.
RUN = docker run --rm --gpus all \
    -v $(CURDIR):/workspace -v $(CURDIR)/.hf:/workspace/.hf \
    -e BLITZ_API_KEY \
    -e BLITZ_ENDPOINT=$(BLITZ_ENDPOINT) \
    -e BLITZ_PROJECT=$(PROJECT) \
    -e HF_TOKEN \
    $(IMAGE)
|
|
20
|
+
|
|
21
|
+
.PHONY: build train eval shell

# Build the CUDA training image from the Dockerfile in this directory.
build:
	docker build -t $(IMAGE) .

# QLoRA fine-tune; the adapter is written to ./adapter on the host.
train:
	$(RUN) python train.py

# Generate over data/evalset.jsonl and submit the predictions to Blitz.
eval:
	$(RUN) python eval.py

# Interactive shell inside the training container (for debugging).
shell:
	$(RUN) bash
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
# Distilling `$workflow` → $model_name
|
|
2
|
+
|
|
3
|
+
Generated by `blitz scaffold`. This trains a QLoRA adapter on your captured
|
|
4
|
+
`$workflow` traces and grades it against the held-out eval set in Blitz.
|
|
5
|
+
|
|
6
|
+
## What's here
|
|
7
|
+
- `data/dataset.jsonl` — SFT training examples (`{"messages": [...]}` per line).
|
|
8
|
+
- `data/evalset.jsonl` — held-out eval inputs (`{example_id, input, reference}`).
|
|
9
|
+
- `recommendation.json` — the base-model recommendation this was built from.
|
|
10
|
+
- `config.json` — training config (base model, seq_len, LoRA + schedule). Edit to taste.
|
|
11
|
+
- `train.py` / `eval.py` — QLoRA SFT and the close-the-loop eval submission.
|
|
12
|
+
- `Dockerfile` / `requirements.txt` / `Makefile` — the CUDA training image.
|
|
13
|
+
|
|
14
|
+
## Requirements
|
|
15
|
+
- An NVIDIA GPU + the NVIDIA Container Toolkit (so `docker run --gpus all` works).
|
|
16
|
+
- Base model: `$base_model_hf` (seq_len `$seq_len`).
|
|
17
|
+
|
|
18
|
+
## Run
|
|
19
|
+
```bash
|
|
20
|
+
export BLITZ_API_KEY=<your read key> # the same read-scoped key used to scaffold
|
|
21
|
+
make build
|
|
22
|
+
make train # writes ./adapter
|
|
23
|
+
make eval # generates over the eval set, posts predictions to Blitz, prints the run id
|
|
24
|
+
```
|
|
25
|
+
|
|
26
|
+
## Gated models
|
|
27
|
+
`$base_model_hf` is licensed `$license`. Llama/Gemma models are gated: accept the
|
|
28
|
+
license on Hugging Face, then `export HF_TOKEN=<hf token>` before `make train`
|
|
29
|
+
(the Makefile forwards `HF_TOKEN` into the container).
|
|
30
|
+
|
|
31
|
+
## Notes
|
|
32
|
+
- Trains on the full prompt+completion sequence. To train only on the assistant
  turn (prompt masking), set `assistant_only_loss=True` on the `SFTConfig` in
  `train.py`. This needs a chat template that emits generation tags, and a
  newer `trl` than the version pinned in `requirements.txt`, which predates
  that option.
|
|
35
|
+
- `config.json` hyperparameters are sized for a single 24GB-class card; raise
|
|
36
|
+
`per_device_batch_size` / `seq_len` if you have more VRAM.
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
"""Close the loop: run the trained student over the held-out eval set and submit
|
|
2
|
+
its predictions to Blitz for grading (student-vs-teacher scores in the dashboard).
|
|
3
|
+
|
|
4
|
+
Loads ./adapter on top of the 4-bit base, generates an answer for each
|
|
5
|
+
data/evalset.jsonl input, and POSTs them to /eval (candidate=supplied). Uses only
|
|
6
|
+
stdlib urllib for the POST so the container's only heavy deps stay torch +
|
|
7
|
+
transformers. Run inside the container: `make eval`.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import json
import os
import sys
import urllib.error
import urllib.request

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

with open("config.json", encoding="utf-8") as fh:
    cfg = json.load(fh)

# Check the environment up front, before any model download/load. The
# Makefile forwards BLITZ_API_KEY with a bare `-e BLITZ_API_KEY`, so it is
# simply absent in the container when it isn't exported on the host.
env = {
    name: os.environ.get(name, "")
    for name in ("BLITZ_ENDPOINT", "BLITZ_PROJECT", "BLITZ_API_KEY")
}
missing = [name for name, value in env.items() if not value]
if missing:
    sys.exit(
        f"error: {', '.join(missing)} not set. Export BLITZ_API_KEY=<your read key> "
        "on the host and run `make eval`."
    )
endpoint = env["BLITZ_ENDPOINT"].rstrip("/")
project = env["BLITZ_PROJECT"]
api_key = env["BLITZ_API_KEY"]
workflow = cfg["workflow"]

examples = []
with open("data/evalset.jsonl", encoding="utf-8") as fh:
    for line in fh:
        line = line.strip()
        if line:
            examples.append(json.loads(line))

if not examples:
    print("Eval set is empty — nothing to score. Skipping.")
    raise SystemExit(0)

bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(cfg["base_model_hf"])
base = AutoModelForCausalLM.from_pretrained(
    cfg["base_model_hf"],
    quantization_config=bnb,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model = PeftModel.from_pretrained(base, "./adapter").eval()

predictions = []
for i, ex in enumerate(examples, 1):
    messages = [{"role": "user", "content": ex["input"]}]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    with torch.no_grad():
        out = model.generate(
            input_ids,
            max_new_tokens=cfg["seq_len"],
            do_sample=False,  # greedy decoding, deterministic predictions
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    text = tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True)
    predictions.append({"example_id": ex["example_id"], "output": text.strip()})
    if i % 10 == 0:
        print(f" generated {i}/{len(examples)}")

# Save predictions before the network call. /workspace is the host project
# dir (bind-mounted by the Makefile), so a failed submission does not throw
# away the generation pass.
with open("predictions.jsonl", "w", encoding="utf-8") as fh:
    for pred in predictions:
        fh.write(json.dumps(pred, ensure_ascii=False) + "\n")
print(f"Saved {len(predictions)} predictions -> predictions.jsonl")

body = json.dumps(
    {"workflow": workflow, "candidate": "supplied", "grader": "auto", "predictions": predictions}
).encode("utf-8")
req = urllib.request.Request(
    f"{endpoint}/blitz/projects/{project}/eval",
    data=body,
    method="POST",
    headers={"content-type": "application/json", "x-api-key": api_key},
)
try:
    with urllib.request.urlopen(req) as resp:
        run = json.loads(resp.read().decode("utf-8"))
except urllib.error.HTTPError as exc:
    try:
        detail = exc.read().decode("utf-8", "replace")
    except Exception:  # noqa: BLE001 — error body is best-effort
        detail = str(exc.reason)
    sys.exit(
        f"error: eval submission failed (HTTP {exc.code}): {detail}\n"
        "Predictions were kept in predictions.jsonl."
    )
except urllib.error.URLError as exc:
    sys.exit(
        f"error: could not reach {endpoint} ({exc.reason}). "
        "Predictions were kept in predictions.jsonl."
    )

run_id = run.get("id", "?")
print(f"Submitted {len(predictions)} predictions. Eval run: {run_id}")
print(f"Poll: GET {endpoint}/blitz/projects/{project}/eval/runs/{run_id}")
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
# QLoRA SFT stack, pinned to a mutually-compatible set (CUDA 12.1, early 2026).
|
|
2
|
+
# torch is intentionally omitted — the Dockerfile's pytorch/cuda base image ships
|
|
3
|
+
# the matching CUDA build, and reinstalling risks pulling a CPU-only wheel.
|
|
4
|
+
# If you bump trl/transformers, re-verify SFTConfig's sequence-length kwarg (trl 0.13 names it max_seq_length; later releases renamed it max_length) + processing_class=.
|
|
5
|
+
transformers==4.48.0
|
|
6
|
+
trl==0.13.0
|
|
7
|
+
peft==0.14.0
|
|
8
|
+
accelerate==1.2.1
|
|
9
|
+
bitsandbytes==0.45.0
|
|
10
|
+
datasets==3.2.0
|
|
11
|
+
sentencepiece==0.2.0
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
"""QLoRA SFT — fine-tune the recommended base model on a Blitz workflow's traces.
|
|
2
|
+
|
|
3
|
+
Reads config.json (written by `blitz scaffold`) and data/dataset.jsonl (one chat
|
|
4
|
+
example per line, {"messages": [...]}), trains a 4-bit QLoRA adapter, and saves it
|
|
5
|
+
to ./adapter. Run inside the provided container: `make train`.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
from datasets import load_dataset
|
|
12
|
+
from peft import LoraConfig, prepare_model_for_kbit_training
|
|
13
|
+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
|
14
|
+
from trl import SFTConfig, SFTTrainer
|
|
15
|
+
|
|
16
|
+
with open("config.json") as fh:
|
|
17
|
+
cfg = json.load(fh)
|
|
18
|
+
|
|
19
|
+
print(f"Base model: {cfg['base_model_hf']} seq_len={cfg['seq_len']}")
|
|
20
|
+
|
|
21
|
+
bnb = BitsAndBytesConfig(
|
|
22
|
+
load_in_4bit=True,
|
|
23
|
+
bnb_4bit_quant_type="nf4",
|
|
24
|
+
bnb_4bit_use_double_quant=True,
|
|
25
|
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
tokenizer = AutoTokenizer.from_pretrained(cfg["base_model_hf"])
|
|
29
|
+
if tokenizer.pad_token is None:
|
|
30
|
+
tokenizer.pad_token = tokenizer.eos_token
|
|
31
|
+
|
|
32
|
+
model = AutoModelForCausalLM.from_pretrained(
|
|
33
|
+
cfg["base_model_hf"],
|
|
34
|
+
quantization_config=bnb,
|
|
35
|
+
device_map="auto",
|
|
36
|
+
torch_dtype=torch.bfloat16,
|
|
37
|
+
)
|
|
38
|
+
model = prepare_model_for_kbit_training(model)
|
|
39
|
+
model.config.use_cache = False
|
|
40
|
+
|
|
41
|
+
peft_config = LoraConfig(
|
|
42
|
+
r=cfg["lora_r"],
|
|
43
|
+
lora_alpha=cfg["lora_alpha"],
|
|
44
|
+
lora_dropout=cfg["lora_dropout"],
|
|
45
|
+
bias="none",
|
|
46
|
+
task_type="CAUSAL_LM",
|
|
47
|
+
target_modules="all-linear",
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
# One chat example per JSONL line: {"messages": [{"role": ..., "content": ...}, ...]}.
dataset = load_dataset("json", data_files="data/dataset.jsonl", split="train")


def to_text(example):
    """Render one ``{"messages": [...]}`` record into a single training string.

    The conversation (which ends in the assistant turn) is formatted with the base
    model's own chat template and returned under the ``"text"`` key.
    """
    # NOTE(review): if this chat template already emits a BOS token and the trainer's
    # tokenisation adds another, samples may begin with a duplicated BOS; confirm for
    # the chosen base model / TRL version.
    rendered = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return {"text": rendered}


# Swap all original columns for the single rendered "text" column.
dataset = dataset.map(to_text, remove_columns=dataset.column_names)
|
|
60
|
+
|
|
61
|
+
sft_config = SFTConfig(
    # Output: per-epoch checkpoints are written under ./adapter.
    output_dir="./adapter",
    save_strategy="epoch",
    # Optimisation schedule (epochs and peak LR come from config.json).
    num_train_epochs=cfg["epochs"],
    learning_rate=cfg["lr"],
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    # Batching: each device steps once per `grad_accum` micro-batches.
    per_device_train_batch_size=cfg["per_device_batch_size"],
    gradient_accumulation_steps=cfg["grad_accum"],
    # Sequences: pack the rendered "text" column into max_length windows.
    max_length=cfg["seq_len"],
    packing=True,
    dataset_text_field="text",
    # Precision / memory.
    # NOTE(review): bf16=True assumes a GPU with bfloat16 support; confirm for the target hardware.
    bf16=True,
    gradient_checkpointing=True,
    # Logging: console only, no external trackers.
    logging_steps=10,
    report_to="none",
)
|
|
78
|
+
|
|
79
|
+
ADAPTER_DIR = "./adapter"

# SFTTrainer wraps the prepared k-bit model with the LoRA config and trains it on
# the rendered dataset.
trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=dataset,
    peft_config=peft_config,
    processing_class=tokenizer,
)
trainer.train()

# Save the trained adapter and the tokenizer into the same directory.
trainer.save_model(ADAPTER_DIR)
tokenizer.save_pretrained(ADAPTER_DIR)
print("Saved LoRA adapter -> ./adapter. Next: `make eval` to score it on the held-out set.")
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "blitz-cli"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Developer CLI for Blitz: pull a workflow's training data + base-model recommendation and scaffold a runnable QLoRA training container."
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.9"
|
|
7
|
+
# Intentionally stdlib-only: the CLI talks to the Blitz HTTP API over urllib and
|
|
8
|
+
# renders templates with string.Template. The heavy ML deps (torch/transformers/
|
|
9
|
+
# trl/peft) live ONLY in the generated training project's requirements.txt — they
|
|
10
|
+
# are never installed by this package.
|
|
11
|
+
dependencies = []
|
|
12
|
+
|
|
13
|
+
[project.scripts]
|
|
14
|
+
blitz = "blitz_cli.cli:main"
|
|
15
|
+
|
|
16
|
+
[build-system]
|
|
17
|
+
requires = ["hatchling"]
|
|
18
|
+
build-backend = "hatchling.build"
|
|
19
|
+
|
|
20
|
+
[tool.hatch.build.targets.wheel]
|
|
21
|
+
# blitz_cli (incl. the templates/ subpackage) ships in full. The scaffold
|
|
22
|
+
# templates carry a .tmpl suffix so they're read as package data via
|
|
23
|
+
# importlib.resources, never imported — their torch/transformers references are
|
|
24
|
+
# not dependencies of this CLI.
|
|
25
|
+
packages = ["blitz_cli"]
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
"""Unit tests for blitz-cli config derivation, rendering, and the read client.
|
|
2
|
+
|
|
3
|
+
No network: the client tests exercise URL/query building only.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import json
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
from blitz_cli import _scaffold
|
|
10
|
+
from blitz_cli._client import BlitzClient
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def _rec(params_b=7.0, hf="Qwen/Qwen2.5-7B-Instruct", license="apache-2.0", seq_len=2048):
    """Return a fake recommendation payload for the scaffold tests.

    Only the fields the tests vary (model size, HF id, license, seq_len) are
    parameters; everything else is fixed. ``license`` shadows the builtin but is
    kept because callers pass it by keyword.
    """
    recommendation = dict(
        id="qwen2.5-7b",
        name="Qwen2.5 7B Instruct",
        params_b=params_b,
        hf=hf,
        license=license,
        why="mid-tier teacher",
    )
    return dict(
        workflow="wf",
        seq_len=seq_len,
        signals=dict(p95_input=800, max_output=400),
        recommendation=recommendation,
        alternatives=[],
    )
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def test_build_config_scales_with_params_b():
    """Batch size, grad accumulation, LoRA rank and LR should track model size."""
    small, mid, big = (
        _scaffold.build_config(_rec(params_b=size), "wf") for size in (3.0, 7.0, 14.0)
    )

    assert small["per_device_batch_size"] == 4
    assert small["lora_r"] == 16

    assert mid["per_device_batch_size"] == 2
    assert mid["grad_accum"] == 8

    assert big["lora_r"] == 32
    assert big["per_device_batch_size"] == 1

    # Larger models get a gentler learning rate.
    assert big["lr"] == 1e-4
    assert mid["lr"] == 2e-4

    assert mid["lora_alpha"] == mid["lora_r"] * 2
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def test_seq_len_prefers_recommendation_then_falls_back():
    """An explicit seq_len wins; otherwise it is derived from the payload's signals."""
    explicit = _scaffold.build_config(_rec(seq_len=4096), "wf")
    assert explicit["seq_len"] == 4096

    # No seq_len -> clamp(p95_input + max_output, 1024, 8192) = clamp(800 + 400) = 1200
    derived = _scaffold.build_config(_rec(seq_len=None), "wf")
    assert derived["seq_len"] == 1200
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def test_gated_license_warning():
    """Permissive licenses yield no warning; gated ones (llama, gemma) do."""
    assert _scaffold.gated_license_warning(_rec(license="apache-2.0")) is None

    llama_warning = _scaffold.gated_license_warning(_rec(license="llama-3.1"))
    # Check for None explicitly so a missing warning fails as an assertion
    # rather than as an AttributeError on `.lower()`.
    assert llama_warning is not None
    assert "gated" in llama_warning.lower()

    assert _scaffold.gated_license_warning(_rec(license="gemma")) is not None
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def test_render_writes_a_runnable_project(tmp_path):
    """render() writes the full project tree and substitutes template placeholders."""
    out = tmp_path / "train"
    cfg = _scaffold.render(out, _rec(), "proj_abc", "wf", "http://localhost:8000")

    expected_files = (
        "config.json", "recommendation.json", "train.py", "eval.py",
        "requirements.txt", "Dockerfile", "Makefile", "README.md", ".dockerignore",
    )
    for name in expected_files:
        assert (out / name).exists(), name

    # The returned config is exactly what was persisted to disk.
    saved = json.loads((out / "config.json").read_text())
    assert saved["base_model_hf"] == "Qwen/Qwen2.5-7B-Instruct"
    assert saved == cfg

    makefile = (out / "Makefile").read_text()
    # Template substitution happened (no stray $project placeholder left)...
    assert "proj_abc" in makefile
    assert "$project" not in makefile
    # ...while make's own $(...) references survive untouched.
    assert "$(CURDIR)" in makefile
    assert "$(IMAGE)" in makefile

    # Dockerfile carries the workflow name; train.py is shipped as verbatim Python.
    dockerfile = (out / "Dockerfile").read_text()
    assert "wf" in dockerfile
    train_script = (out / "train.py").read_text()
    assert "SFTTrainer" in train_script
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def test_client_builds_urls_with_query_and_bools():
    """_url joins paths onto the endpoint and encodes query params.

    Booleans are lower-cased and None-valued params are dropped.
    """
    client = BlitzClient(endpoint="http://x/", api_key="k", project="proj_1")
    assert client._url("/a") == "http://x/a"

    query_url = client._url("/a", {"workflow": "wf", "include_synthetic": False, "skip": None})
    for fragment in ("workflow=wf", "include_synthetic=false"):
        assert fragment in query_url
    assert "skip" not in query_url
|