mlxsmith 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxsmith/__init__.py +2 -0
- mlxsmith/accel/__init__.py +10 -0
- mlxsmith/accel/base.py +17 -0
- mlxsmith/accel/none.py +13 -0
- mlxsmith/accel/zmlx_backend.py +42 -0
- mlxsmith/adapters.py +46 -0
- mlxsmith/api/__init__.py +48 -0
- mlxsmith/api/handlers.py +1217 -0
- mlxsmith/api/schemas.py +436 -0
- mlxsmith/auth.py +88 -0
- mlxsmith/bench.py +102 -0
- mlxsmith/cli.py +950 -0
- mlxsmith/config.py +543 -0
- mlxsmith/config_models.py +261 -0
- mlxsmith/data.py +493 -0
- mlxsmith/envs/__init__.py +33 -0
- mlxsmith/envs/system.py +388 -0
- mlxsmith/envs/token_env.py +191 -0
- mlxsmith/eval.py +112 -0
- mlxsmith/infer.py +140 -0
- mlxsmith/llm/__init__.py +16 -0
- mlxsmith/llm/backend.py +126 -0
- mlxsmith/llm/interface.py +212 -0
- mlxsmith/llm/mlx_lm_backend.py +509 -0
- mlxsmith/llm/mock_backend.py +228 -0
- mlxsmith/llm/registry.py +12 -0
- mlxsmith/models.py +257 -0
- mlxsmith/orchestrator/__init__.py +25 -0
- mlxsmith/orchestrator/daemon.py +454 -0
- mlxsmith/orchestrator/inference_worker.py +496 -0
- mlxsmith/orchestrator/queue.py +355 -0
- mlxsmith/orchestrator/trainer_worker.py +437 -0
- mlxsmith/rlm/__init__.py +8 -0
- mlxsmith/rlm/corpus.py +74 -0
- mlxsmith/rlm/gating.py +90 -0
- mlxsmith/rlm/generate.py +249 -0
- mlxsmith/rlm/history.py +12 -0
- mlxsmith/rlm/inference.py +150 -0
- mlxsmith/rlm/loop.py +1297 -0
- mlxsmith/rlm/mutate.py +82 -0
- mlxsmith/rlm/trainer.py +73 -0
- mlxsmith/rlm/weights.py +263 -0
- mlxsmith/runs.py +44 -0
- mlxsmith/sdk/__init__.py +392 -0
- mlxsmith/sdk/future.py +486 -0
- mlxsmith/sdk/losses.py +262 -0
- mlxsmith/sdk/sampling_client.py +729 -0
- mlxsmith/sdk/training_client.py +676 -0
- mlxsmith/server.py +376 -0
- mlxsmith/train/__init__.py +0 -0
- mlxsmith/train/distill.py +279 -0
- mlxsmith/train/lora.py +280 -0
- mlxsmith/train/pref.py +180 -0
- mlxsmith/train/rft.py +458 -0
- mlxsmith/train/sft.py +151 -0
- mlxsmith/util.py +174 -0
- mlxsmith/verifiers/__init__.py +3 -0
- mlxsmith/verifiers/compose.py +109 -0
- mlxsmith/verifiers/docker_verifier.py +111 -0
- mlxsmith/verifiers/jsonschema.py +54 -0
- mlxsmith/verifiers/pytest_verifier.py +82 -0
- mlxsmith/verifiers/regex.py +15 -0
- mlxsmith/verifiers/types.py +10 -0
- mlxsmith-0.1.0.dist-info/METADATA +163 -0
- mlxsmith-0.1.0.dist-info/RECORD +69 -0
- mlxsmith-0.1.0.dist-info/WHEEL +5 -0
- mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
- mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
- mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/rlm/generate.py
ADDED
@@ -0,0 +1,249 @@
+from __future__ import annotations
+
+import json
+import re
+from dataclasses import dataclass
+from typing import Iterable, List, Optional, Sequence, Set
+
+from ..util import sha1_text
+
+
+@dataclass
+class GeneratedTask:
+    id: str
+    prompt: str
+    tests: str
+    description: Optional[str] = None
+
+
+_FALLBACK_TASKS = [
+    {
+        "id": "task_factorial",
+        "description": "Write a function solve(n: int) -> int that returns n! for n >= 0.",
+        "signature": "def solve(n: int) -> int",
+        "tests": [
+            {"input": [0], "expected": 1},
+            {"input": [5], "expected": 120},
+            {"input": [7], "expected": 5040},
+        ],
+    },
+    {
+        "id": "task_reverse",
+        "description": "Write a function solve(s: str) -> str that returns the reverse of s.",
+        "signature": "def solve(s: str) -> str",
+        "tests": [
+            {"input": ["abc"], "expected": "cba"},
+            {"input": [""], "expected": ""},
+            {"input": ["racecar"], "expected": "racecar"},
+        ],
+    },
+]
+
+_DEFAULT_BLOCKLIST = [
+    r"\bsubprocess\b",
+    r"\bos\.system\b",
+    r"\bshutil\.rmtree\b",
+    r"\brm\s+-rf\b",
+    r"\brequests\b",
+    r"\burllib\b",
+    r"\bsocket\b",
+    r"\bhttp[s]?://",
+    r"\bpip\s+install\b",
+    r"\bapt-get\b",
+    r"\bbrew\s+install\b",
+]
+
+
+def extract_json_objects(text: str) -> List[dict]:
+    # Try fenced json blocks first.
+    fenced = re.findall(r"```json\s*(.*?)```", text, flags=re.DOTALL | re.IGNORECASE)
+    chunks = fenced if fenced else [text]
+
+    items: List[dict] = []
+    for chunk in chunks:
+        buf = ""
+        depth = 0
+        for ch in chunk:
+            if ch == "{":
+                depth += 1
+            if depth > 0:
+                buf += ch
+            if ch == "}" and depth > 0:
+                depth -= 1
+                if depth == 0:
+                    try:
+                        items.append(json.loads(buf))
+                    except Exception:
+                        pass
+                    buf = ""
+    return items
+
+
+def _tests_from_cases(cases: Iterable[dict]) -> str:
+    lines = ["from main import solve", "", ""]
+    for idx, case in enumerate(cases):
+        args = case.get("input", [])
+        expected = case.get("expected")
+        if not isinstance(args, (list, tuple)):
+            args = [args]
+        lines.append(f"def test_case_{idx}():")
+        args_expr = ", ".join(repr(a) for a in args)
+        lines.append(f"    assert solve({args_expr}) == {repr(expected)}")
+        lines.append("")
+    return "\n".join(lines).strip() + "\n"
+
+
+def task_to_prompt(task: dict, *, require_recursion: bool) -> str:
+    desc = task.get("description") or task.get("prompt") or ""
+    sig = task.get("signature") or "def solve(...):"
+    suffix = "\nUse recursion." if require_recursion else ""
+    return f"{desc}\n{sig}\nReturn only Python code.{suffix}".strip()
+
+
+def task_to_tests(task: dict) -> str:
+    tests = task.get("tests")
+    if isinstance(tests, str):
+        return tests
+    if isinstance(tests, list):
+        # List of pre-formatted test strings (e.g. from Qwen3-style generation)
+        if tests and isinstance(tests[0], str):
+            return "\n\n".join(tests).strip() + "\n"
+        # List of structured {input, expected} dicts
+        return _tests_from_cases(tests)
+    # fallback: trivial test that always fails to avoid false positives
+    return "def test_placeholder():\n    assert False\n"
+
+
+def _token_set(text: str) -> Set[str]:
+    return set(re.findall(r"[a-z0-9_]+", text.lower()))
+
+
+def _jaccard(a: Set[str], b: Set[str]) -> float:
+    if not a and not b:
+        return 1.0
+    denom = len(a | b)
+    if denom == 0:
+        return 0.0
+    return len(a & b) / denom
+
+
+def filter_tasks(
+    tasks: Sequence[GeneratedTask],
+    *,
+    existing_prompts: Optional[Sequence[str]] = None,
+    similarity_threshold: float = 0.85,
+    min_desc_len: int = 10,
+    min_asserts: int = 2,
+    max_prompt_len: int = 2000,
+    min_tests_len: int = 20,
+    max_tests_len: int = 8000,
+    blocked_patterns: Optional[Sequence[str]] = None,
+) -> List[GeneratedTask]:
+    """Filter tasks by basic quality + similarity + dedup."""
+    existing_prompts = existing_prompts or []
+    existing_tokens = [_token_set(p) for p in existing_prompts if p]
+    seen_hashes: Set[str] = set()
+    filtered: List[GeneratedTask] = []
+    patterns = list(blocked_patterns) if blocked_patterns else _DEFAULT_BLOCKLIST
+
+    for task in tasks:
+        prompt = task.prompt or ""
+        if len(prompt) > max_prompt_len:
+            continue
+        desc = task.description or prompt
+        if len(desc.strip()) < min_desc_len:
+            continue
+        tests = task.tests or ""
+        if len(tests.strip()) < min_tests_len or len(tests) > max_tests_len:
+            continue
+        if patterns:
+            blocked = False
+            for pattern in patterns:
+                if re.search(pattern, prompt, flags=re.IGNORECASE) or re.search(pattern, tests, flags=re.IGNORECASE):
+                    blocked = True
+                    break
+            if blocked:
+                continue
+        asserts = sum(1 for line in task.tests.splitlines() if line.strip().startswith("assert"))
+        if asserts < min_asserts:
+            continue
+
+        key = sha1_text(prompt + "\n" + task.tests)
+        if key in seen_hashes:
+            continue
+
+        tokens = _token_set(prompt)
+        too_similar = False
+        for ex in existing_tokens:
+            if _jaccard(tokens, ex) >= similarity_threshold:
+                too_similar = True
+                break
+        if too_similar:
+            continue
+
+        seen_hashes.add(key)
+        filtered.append(task)
+
+    return filtered
+
+
+def generate_tasks(
+    llm,
+    *,
+    tasks_per_iter: int,
+    temperature: float,
+    max_new_tokens: int,
+    top_p: float,
+    top_k: Optional[int],
+    require_recursion: bool,
+    task_domains: Iterable[str],
+) -> List[GeneratedTask]:
+    if tasks_per_iter <= 0:
+        return []
+
+    domain_list = ", ".join(task_domains) if task_domains else "general"
+    prompt = (
+        "Generate JSON objects for coding tasks. Each JSON must include: "
+        "id, description, signature, tests (array of {input, expected}). "
+        f"Domain focus: {domain_list}. Return only JSON objects."
+    )
+
+    gen = llm.generate(
+        prompt,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        top_p=top_p,
+        top_k=top_k,
+    )
+    items = extract_json_objects(gen.text)
+    tasks: List[GeneratedTask] = []
+    for item in items:
+        tid = item.get("id") or sha1_text(json.dumps(item, sort_keys=True))[:12]
+        task_prompt = task_to_prompt(item, require_recursion=require_recursion)
+        tests = task_to_tests(item)
+        tasks.append(
+            GeneratedTask(
+                id=str(tid),
+                prompt=task_prompt,
+                tests=tests,
+                description=item.get("description"),
+            )
+        )
+        if len(tasks) >= tasks_per_iter:
+            break
+
+    if not tasks:
+        for item in _FALLBACK_TASKS:
+            tid = item.get("id") or sha1_text(json.dumps(item, sort_keys=True))[:12]
+            tasks.append(
+                GeneratedTask(
+                    id=str(tid),
+                    prompt=task_to_prompt(item, require_recursion=require_recursion),
+                    tests=task_to_tests(item),
+                    description=item.get("description"),
+                )
+            )
+            if len(tasks) >= tasks_per_iter:
+                break
+
+    return tasks
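For orientation, a minimal sketch of how the helpers above compose. The raw string stands in for hypothetical model output (with no fenced json block, so extract_json_objects falls back to brace-balanced scanning of the whole text); the import path follows the package layout in the file listing.

    from mlxsmith.rlm.generate import extract_json_objects, task_to_prompt, task_to_tests

    # Hypothetical model output: a single task object as bare JSON.
    raw = (
        '{"id": "task_square", '
        '"description": "Write a function solve(n: int) -> int that returns n squared.", '
        '"signature": "def solve(n: int) -> int", '
        '"tests": [{"input": [2], "expected": 4}, {"input": [3], "expected": 9}]}'
    )

    item = extract_json_objects(raw)[0]
    prompt = task_to_prompt(item, require_recursion=False)
    tests = task_to_tests(item)  # pytest module: "from main import solve" plus one test per case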
mlxsmith/rlm/history.py
ADDED
@@ -0,0 +1,12 @@
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+from ..util import ensure_dir
+
+
+def append_history(path: Path, record: dict) -> None:
+    ensure_dir(path.parent)
+    with path.open("a", encoding="utf-8") as f:
+        f.write(json.dumps(record, ensure_ascii=False) + "\n")
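A brief usage sketch: append_history writes one JSON object per line (JSONL), creating parent directories first via ensure_dir. The path below is illustrative, not a path the package mandates.

    from pathlib import Path
    from mlxsmith.rlm.history import append_history

    # Each call appends a single JSONL record; the file and its parents
    # are created on first use.
    append_history(Path("runs/rlm/history.jsonl"), {"iter": 0, "pass_rate": 0.5})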
mlxsmith/rlm/inference.py
ADDED
@@ -0,0 +1,150 @@
+from __future__ import annotations
+
+import time
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Iterable, List, Optional
+
+from ..config import ProjectConfig
+from ..util import ensure_dir, now_ts
+from ..verifiers.docker_verifier import verify as docker_verify
+from ..verifiers.pytest_verifier import verify as pytest_verify
+from .generate import GeneratedTask, generate_tasks, filter_tasks
+from .mutate import mutate_tasks
+
+
+@dataclass
+class Rollout:
+    task_id: str
+    prompt: str
+    completion: str
+    token_ids: list[int]
+    prompt_len: int
+    logprobs: Optional[list[float]]
+    passed: bool
+    reward: float
+    verifier_latency_ms: float
+    weight_adapter: Optional[str]
+
+
+def build_tasks(
+    llm,
+    cfg: ProjectConfig,
+    *,
+    require_recursion: bool,
+    tasks_per_iter: int,
+    mutations_per_task: int,
+    max_total: int,
+    existing_prompts: Optional[list[str]] = None,
+) -> List[GeneratedTask]:
+    tasks = generate_tasks(
+        llm,
+        tasks_per_iter=tasks_per_iter,
+        temperature=float(cfg.rft.temperature),
+        max_new_tokens=int(cfg.rlm.task_gen_max_new_tokens),
+        top_p=float(cfg.infer.top_p),
+        top_k=cfg.infer.top_k,
+        require_recursion=require_recursion,
+        task_domains=list(cfg.rlm.task_domains or []),
+    )
+    if bool(cfg.rlm.use_task_mutation) and int(mutations_per_task) > 0:
+        tasks = mutate_tasks(
+            llm,
+            tasks,
+            mutations_per_task=mutations_per_task,
+            max_total=max_total,
+            temperature=float(cfg.rft.temperature),
+            max_new_tokens=int(cfg.rlm.task_gen_max_new_tokens),
+            top_p=float(cfg.infer.top_p),
+            top_k=cfg.infer.top_k,
+            require_recursion=require_recursion,
+        )
+    filtered = filter_tasks(
+        tasks,
+        existing_prompts=existing_prompts,
+        similarity_threshold=float(getattr(cfg.rlm, "similarity_threshold", 0.85)),
+        min_desc_len=int(getattr(cfg.rlm, "min_task_desc_len", 10)),
+        min_asserts=int(getattr(cfg.rlm, "min_task_asserts", 2)),
+        max_prompt_len=int(getattr(cfg.rlm, "max_task_prompt_len", 2000)),
+        min_tests_len=int(getattr(cfg.rlm, "min_task_tests_len", 20)),
+        max_tests_len=int(getattr(cfg.rlm, "max_task_tests_len", 8000)),
+        blocked_patterns=getattr(cfg.rlm, "blocked_task_patterns", None),
+    )
+    return filtered or tasks
+
+
+def _write_task_tests(task: GeneratedTask, workdir: Path) -> None:
+    tests_dir = ensure_dir(workdir / "tests")
+    (tests_dir / "test_task.py").write_text(task.tests, encoding="utf-8")
+
+
+def collect_rollouts(
+    llm,
+    tasks: Iterable[GeneratedTask],
+    cfg: ProjectConfig,
+    *,
+    artifacts_dir: Path,
+    verifier_backend: str,
+    weight_adapter: Optional[str],
+) -> tuple[List[Rollout], list[dict]]:
+    rollouts: List[Rollout] = []
+    passed_samples: list[dict] = []
+
+    for task in tasks:
+        for k in range(int(cfg.rlm.rollouts_per_task)):
+            gen = llm.generate_with_logprobs(
+                task.prompt,
+                max_new_tokens=int(cfg.rft.max_new_tokens),
+                temperature=float(cfg.rft.temperature),
+                seed=int(time.time() * 1000) % (2**31 - 1),
+            )
+            completion = gen.text[len(task.prompt) :] if gen.text.startswith(task.prompt) else gen.text
+            wdir = ensure_dir(artifacts_dir / task.id / f"rollout_{k:02d}")
+            (wdir / "main.py").write_text(completion, encoding="utf-8")
+            _write_task_tests(task, wdir)
+
+            t0 = time.time()
+            if verifier_backend == "docker":
+                res = docker_verify(
+                    task.prompt,
+                    completion,
+                    str(wdir),
+                    timeout_s=int(cfg.rlm.verifier_timeout_s),
+                    image=cfg.rlm.docker_image,
+                    memory_mb=int(cfg.rlm.docker_memory_mb),
+                    cpus=float(cfg.rlm.docker_cpus),
+                    pids=int(cfg.rlm.docker_pids),
+                )
+            else:
+                res = pytest_verify(task.prompt, completion, str(wdir), timeout_s=int(cfg.rlm.verifier_timeout_s))
+            latency_ms = (time.time() - t0) * 1000.0
+
+            passed = bool(getattr(res, "passed", False))
+            reward = float(getattr(res, "reward", 0.0))
+            rollouts.append(
+                Rollout(
+                    task_id=task.id,
+                    prompt=task.prompt,
+                    completion=completion,
+                    token_ids=list(gen.token_ids),
+                    prompt_len=int(gen.prompt_len),
+                    logprobs=list(gen.logprobs) if gen.logprobs is not None else None,
+                    passed=passed,
+                    reward=reward,
+                    verifier_latency_ms=latency_ms,
+                    weight_adapter=weight_adapter,
+                )
+            )
+
+            if passed:
+                passed_samples.append(
+                    {
+                        "id": task.id,
+                        "prompt": task.prompt,
+                        "response": completion,
+                        "reward": reward,
+                        "ts": now_ts(),
+                    }
+                )
+
+    return rollouts, passed_samples
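A hedged sketch of how build_tasks and collect_rollouts chain together for one iteration. The llm object (an LLM backend exposing generate and generate_with_logprobs) and the loaded ProjectConfig come from elsewhere in the package and are assumed here, as are the numeric values and paths.

    from pathlib import Path

    # Assumed: `llm` is an mlxsmith LLM backend, `cfg` a loaded ProjectConfig.
    tasks = build_tasks(
        llm,
        cfg,
        require_recursion=False,
        tasks_per_iter=4,
        mutations_per_task=1,
        max_total=8,
    )
    rollouts, passed_samples = collect_rollouts(
        llm,
        tasks,
        cfg,
        artifacts_dir=Path("runs/rlm/artifacts"),
        verifier_backend="pytest",  # "docker" routes through docker_verify instead
        weight_adapter=None,
    )
    pass_rate = sum(r.passed for r in rollouts) / max(len(rollouts), 1)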