synth-ai 0.2.8.dev12__py3-none-any.whl → 0.2.9.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- synth_ai/api/train/__init__.py +5 -0
- synth_ai/api/train/builders.py +165 -0
- synth_ai/api/train/cli.py +450 -0
- synth_ai/api/train/config_finder.py +168 -0
- synth_ai/api/train/env_resolver.py +302 -0
- synth_ai/api/train/pollers.py +66 -0
- synth_ai/api/train/task_app.py +193 -0
- synth_ai/api/train/utils.py +232 -0
- synth_ai/cli/__init__.py +23 -0
- synth_ai/cli/rl_demo.py +18 -6
- synth_ai/cli/root.py +38 -6
- synth_ai/cli/task_apps.py +1107 -0
- synth_ai/demo_registry.py +258 -0
- synth_ai/demos/core/cli.py +147 -111
- synth_ai/demos/demo_task_apps/__init__.py +7 -1
- synth_ai/demos/demo_task_apps/math/config.toml +55 -110
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +157 -21
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +39 -0
- synth_ai/task/__init__.py +94 -1
- synth_ai/task/apps/__init__.py +88 -0
- synth_ai/task/apps/grpo_crafter.py +438 -0
- synth_ai/task/apps/math_single_step.py +852 -0
- synth_ai/task/auth.py +153 -0
- synth_ai/task/client.py +165 -0
- synth_ai/task/contracts.py +29 -14
- synth_ai/task/datasets.py +105 -0
- synth_ai/task/errors.py +49 -0
- synth_ai/task/json.py +77 -0
- synth_ai/task/proxy.py +258 -0
- synth_ai/task/rubrics.py +212 -0
- synth_ai/task/server.py +398 -0
- synth_ai/task/tracing_utils.py +79 -0
- synth_ai/task/vendors.py +61 -0
- synth_ai/tracing_v3/session_tracer.py +13 -5
- synth_ai/tracing_v3/storage/base.py +10 -12
- synth_ai/tracing_v3/turso/manager.py +20 -6
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/METADATA +3 -2
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/RECORD +42 -18
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.9.dev0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,852 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""Task app configuration for a single-step math reasoning environment."""
|
|
4
|
+
|
|
5
|
+
import contextlib
|
|
6
|
+
import os
|
|
7
|
+
import random
|
|
8
|
+
import re
|
|
9
|
+
import uuid
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Sequence, cast
|
|
13
|
+
|
|
14
|
+
import httpx
|
|
15
|
+
from datasets import load_dataset
|
|
16
|
+
from fastapi import APIRouter, HTTPException, Request
|
|
17
|
+
from pydantic import BaseModel, Field
|
|
18
|
+
|
|
19
|
+
from ..contracts import (
|
|
20
|
+
RolloutMetrics,
|
|
21
|
+
RolloutRequest,
|
|
22
|
+
RolloutResponse,
|
|
23
|
+
RolloutStep,
|
|
24
|
+
RolloutTrajectory,
|
|
25
|
+
TaskInfo,
|
|
26
|
+
)
|
|
27
|
+
from ..datasets import TaskDatasetRegistry, TaskDatasetSpec
|
|
28
|
+
from ..rubrics import Rubric, load_rubric
|
|
29
|
+
from ..server import ProxyConfig, RubricBundle, TaskAppConfig
|
|
30
|
+
from ..errors import http_exception
|
|
31
|
+
from ..tracing_utils import (
|
|
32
|
+
build_tracer_factory,
|
|
33
|
+
resolve_sft_output_dir,
|
|
34
|
+
resolve_tracing_db_url,
|
|
35
|
+
tracing_env_enabled,
|
|
36
|
+
)
|
|
37
|
+
from ..vendors import normalize_vendor_keys
|
|
38
|
+
from . import ModalDeploymentConfig, TaskAppEntry, register_task_app
|
|
39
|
+
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
|
40
|
+
|
|
41
|
+
# parents[3] of this file's resolved path; used below as the base of the
# default local Hugging Face cache directory.
REPO_ROOT = Path(__file__).resolve().parents[3]

# Prefer a persistent volume (Modal by default) for dataset caches. The
# directory is used when it can be created, or when it already exists even
# though creating it failed (e.g. insufficient permissions on a parent).
_modal_volume_candidate = Path(os.getenv("MATH_MODAL_DATASET_DIR", "/modal_volumes/math_dataset")).expanduser()
_modal_volume_root: Optional[Path] = None
try:
    _modal_volume_candidate.mkdir(parents=True, exist_ok=True)
    _modal_volume_root = _modal_volume_candidate
except Exception:
    if _modal_volume_candidate.exists():
        _modal_volume_root = _modal_volume_candidate

if _modal_volume_root is not None:
    hf_cache_path = _modal_volume_root / "hf_cache"
    local_dataset_dir = _modal_volume_root / "jsonl"
    # The volume may be mounted read-only: a failed mkdir must not abort module
    # import. Only advertise the JSONL directory when it actually exists.
    with contextlib.suppress(OSError):
        local_dataset_dir.mkdir(parents=True, exist_ok=True)
    if local_dataset_dir.is_dir():
        os.environ.setdefault("MATH_DATASET_LOCAL_DIR", str(local_dataset_dir))
else:
    hf_cache_path = Path(
        os.getenv("MATH_DATASET_CACHE_DIR", str(REPO_ROOT / ".cache" / "hf-datasets"))
    ).expanduser()

# Same best-effort policy for the HF cache directory: if it cannot be created,
# the failure surfaces later when a split is loaded instead of at import time.
with contextlib.suppress(OSError):
    hf_cache_path.mkdir(parents=True, exist_ok=True)
# setdefault: values the user already exported always win over these paths.
os.environ.setdefault("MATH_DATASET_CACHE_DIR", str(hf_cache_path))
os.environ.setdefault("HF_HOME", str(hf_cache_path))
os.environ.setdefault("HF_DATASETS_CACHE", str(hf_cache_path))
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", str(hf_cache_path))

HF_DATASETS_CACHE = hf_cache_path
|
|
67
|
+
# Hugging Face dataset id / config and split names, all overridable via env.
DATASET_NAME = os.getenv("MATH_DATASET_NAME", "nlile/hendrycks-MATH-benchmark")
DATASET_CONFIG = os.getenv("MATH_DATASET_CONFIG", "")
DEFAULT_SPLIT = os.getenv("MATH_DATASET_DEFAULT_SPLIT", "train")
# Validation and test both default to "test", so dedupe (first occurrence wins)
# and drop empty names; every consumer then sees each split exactly once.
AVAILABLE_SPLITS: tuple[str, ...] = tuple(
    dict.fromkeys(
        split
        for split in (
            DEFAULT_SPLIT,
            os.getenv("MATH_DATASET_VALIDATION_SPLIT", "test"),
            os.getenv("MATH_DATASET_TEST_SPLIT", "test"),
        )
        if split
    )
)
# Name of the single tool the policy must call to submit its answer.
TOOL_NAME = "math_submit"
# Row fields tried, in order, for the problem text and the reference solution.
PROBLEM_KEYS: tuple[str, ...] = ("problem", "question", "prompt", "query")
SOLUTION_KEYS: tuple[str, ...] = ("solution", "answer", "final_answer", "solution_text")

# Rewards for: a correct answer / a missing or wrong tool call / a blank answer.
# A well-formed but incorrect answer scores 0.0 (see _score_submission).
REWARD_POSITIVE = float(os.getenv("MATH_REWARD_POSITIVE", "1.0"))
REWARD_NEGATIVE_NO_TOOL = float(os.getenv("MATH_REWARD_NEGATIVE_NO_TOOL", "-1.0"))
REWARD_NEGATIVE_NO_ANSWER = float(os.getenv("MATH_REWARD_NEGATIVE_NO_ANSWER", "-0.5"))

# Environment variables checked, in priority order, for a Hugging Face token.
HF_TOKEN_ENV_KEYS: tuple[str, ...] = (
    "HF_DATASETS_TOKEN",
    "HUGGINGFACEHUB_API_TOKEN",
    "HUGGINGFACE_TOKEN",
)
|
|
88
|
+
|
|
89
|
+
# Single-source dataset policy: use a single known-good HF dataset id by default.

# Static descriptor of the dataset. Its split list is AVAILABLE_SPLITS with
# duplicates removed (first occurrence wins) and empty names dropped.
MATH_DATASET_SPEC = TaskDatasetSpec(
    id="math_single_step",
    name="MATH Single Step",
    version="1.0.0",
    splits=[split_name for split_name in dict.fromkeys(AVAILABLE_SPLITS) if split_name],
    default_split=DEFAULT_SPLIT,
    description="Single-step math reasoning problems sourced from the Hendrycks MATH dataset.",
)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
_BOXED_MARKERS: tuple[str, ...] = ("\\boxed", "boxed")
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def _extract_boxed(text: str) -> Optional[str]:
|
|
105
|
+
if not text:
|
|
106
|
+
return None
|
|
107
|
+
for marker in _BOXED_MARKERS:
|
|
108
|
+
start = text.find(marker)
|
|
109
|
+
if start == -1:
|
|
110
|
+
continue
|
|
111
|
+
brace_start = text.find("{", start)
|
|
112
|
+
if brace_start == -1:
|
|
113
|
+
continue
|
|
114
|
+
depth = 1
|
|
115
|
+
idx = brace_start + 1
|
|
116
|
+
while idx < len(text) and depth > 0:
|
|
117
|
+
ch = text[idx]
|
|
118
|
+
if ch == "{":
|
|
119
|
+
depth += 1
|
|
120
|
+
elif ch == "}":
|
|
121
|
+
depth -= 1
|
|
122
|
+
idx += 1
|
|
123
|
+
if depth == 0:
|
|
124
|
+
return text[brace_start + 1 : idx - 1].strip()
|
|
125
|
+
return None
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
_FRAC_PATTERN = re.compile(r"\\?frac\{([^{}]+)\}\{([^{}]+)\}")
|
|
129
|
+
_SQRT_PATTERN = re.compile(r"\\?sqrt\{([^{}]+)\}")
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _normalize_final_answer(text: str) -> str:
|
|
133
|
+
raw = str(text or "").strip()
|
|
134
|
+
if not raw:
|
|
135
|
+
return ""
|
|
136
|
+
boxed = _extract_boxed(raw)
|
|
137
|
+
if boxed:
|
|
138
|
+
raw = boxed
|
|
139
|
+
raw = raw.strip().strip("$")
|
|
140
|
+
raw = raw.replace("\\left", "").replace("\\right", "")
|
|
141
|
+
raw = raw.replace("\\!", "").replace("\\,", " ").replace("\\;", " ")
|
|
142
|
+
raw = raw.replace("left", "").replace("right", "")
|
|
143
|
+
raw = raw.replace("\\times", "*").replace("\\cdot", "*")
|
|
144
|
+
raw = raw.replace("\\pi", "pi").replace("\\theta", "theta").replace("\\phi", "phi")
|
|
145
|
+
raw = raw.replace("\\pm", "+/-").replace("\\mp", "-/+")
|
|
146
|
+
raw = raw.replace("^{\\circ}", "deg").replace("^\\circ", "deg").replace("\\circ", "deg")
|
|
147
|
+
|
|
148
|
+
def _frac_sub(match: re.Match[str]) -> str:
|
|
149
|
+
num = match.group(1).strip()
|
|
150
|
+
den = match.group(2).strip()
|
|
151
|
+
return f"({num})/({den})"
|
|
152
|
+
|
|
153
|
+
def _sqrt_sub(match: re.Match[str]) -> str:
|
|
154
|
+
inner = match.group(1).strip()
|
|
155
|
+
return f"sqrt({inner})"
|
|
156
|
+
|
|
157
|
+
raw = _FRAC_PATTERN.sub(_frac_sub, raw)
|
|
158
|
+
raw = _SQRT_PATTERN.sub(_sqrt_sub, raw)
|
|
159
|
+
raw = raw.replace("\\", "")
|
|
160
|
+
raw = raw.replace("{", "").replace("}", "")
|
|
161
|
+
raw = raw.replace(" ", "")
|
|
162
|
+
raw = raw.rstrip(".")
|
|
163
|
+
return raw
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
class MathDataset:
    """Lazily load and sample math problems from a Hugging Face dataset.

    The dataset id and config are supplied by the caller. Each split is loaded
    on first use and cached in memory. A local JSONL file takes precedence over
    downloading; the file is looked up as ``MATH_DATASET_LOCAL_<SPLIT>_FILE`` or
    ``<MATH_DATASET_LOCAL_DIR>/<split>.jsonl``. After a successful download, the
    split is written back to ``MATH_DATASET_LOCAL_DIR`` when that is set. That
    write-back is best effort.
    """

    def __init__(self, *, name: str, config: str, splits: Sequence[str]) -> None:
        """
        Args:
            name: Hugging Face dataset id passed to ``load_dataset``.
            config: Dataset config name; an empty string means "none".
            splits: Split names this loader accepts; empty names are dropped.
        """
        self.name = name
        self.config = config
        self.splits = [split for split in splits if split]
        self._cache: Dict[str, Any] = {}
        self._local_dir = os.getenv("MATH_DATASET_LOCAL_DIR")
        # First non-blank token among HF_TOKEN_ENV_KEYS, in priority order.
        self._hf_token: Optional[str] = None
        for key in HF_TOKEN_ENV_KEYS:
            value = (os.getenv(key) or "").strip()
            if value:
                self._hf_token = value
                break
        # No multi-candidate fallback: enforce explicit dataset id

    def _local_file_for_split(self, split: str) -> Optional[Path]:
        """Return an existing local JSONL file for ``split``, or ``None``."""
        specific = os.getenv(f"MATH_DATASET_LOCAL_{split.upper()}_FILE")
        if specific:
            path = Path(specific).expanduser()
            if path.exists():
                return path
        if self._local_dir:
            candidate = Path(self._local_dir).expanduser() / f"{split}.jsonl"
            if candidate.exists():
                return candidate
        return None

    def _load_split(self, split: str):
        """Return the cached split, loading it on first access.

        ``"validation"`` is treated as ``"test"`` when it is not itself a
        configured split.

        Raises:
            ValueError: if ``split`` (after aliasing) is not configured.
            RuntimeError: if downloading from Hugging Face fails.
        """
        # Treat 'validation' as an alias for 'test' for datasets without a separate validation split
        if split not in self.splits and split.lower() == "validation":
            split = "test"
        if split not in self.splits:
            raise ValueError(f"Unknown split '{split}'. Available: {self.splits}")
        if split not in self._cache:
            local_file = self._local_file_for_split(split)
            if local_file is not None:
                # The "json" builder puts all rows of the file into a "train" split.
                dataset = load_dataset("json", data_files=str(local_file), cache_dir=str(HF_DATASETS_CACHE))
                self._cache[split] = dataset["train"]
            else:
                try:
                    load_kwargs: Dict[str, Any] = {"split": split}
                    if self.config:
                        load_kwargs["name"] = self.config
                    if self._hf_token:
                        # `use_auth_token` is deprecated in favour of `token`
                        # (datasets >= 2.14) and was later removed.
                        load_kwargs["token"] = self._hf_token
                    ds = load_dataset(self.name, cache_dir=str(HF_DATASETS_CACHE), **load_kwargs)
                    self._cache[split] = ds
                    if self._local_dir:
                        local_dir = Path(self._local_dir).expanduser()
                        target = local_dir / f"{split}.jsonl"
                        if not target.exists() and hasattr(ds, "to_json"):
                            # Write to a temp file and rename it, so a partially
                            # written file is never left at `target`.
                            tmp_path = target.with_name(target.name + ".tmp")
                            try:
                                local_dir.mkdir(parents=True, exist_ok=True)
                                ds.to_json(str(tmp_path))
                                tmp_path.replace(target)
                            except Exception:
                                # Best effort: a failed local copy must not fail the load.
                                with contextlib.suppress(FileNotFoundError):
                                    tmp_path.unlink()
                except Exception as exc:
                    hints = [
                        "Failed to download MATH dataset from Hugging Face.",
                        f"Dataset: {self.name} | Config: {self.config or 'none'} | Split: {split}",
                        "If this persists, verify MATH_DATASET_NAME/MATH_DATASET_CONFIG or set MATH_DATASET_LOCAL_DIR to pre-downloaded JSONL files.",
                    ]
                    raise RuntimeError(" ".join(hints)) from exc
        return self._cache[split]

    def sample(self, *, split: str, index: Optional[int] = None) -> Dict[str, Any]:
        """Return one problem from ``split``.

        ``index`` wraps modulo the split length; when ``None`` a random row is used.

        Returns:
            A dict with ``index``, ``split`` (as requested), ``problem``,
            ``answer`` (normalized with ``_normalize_final_answer``) and
            ``raw_solution``.

        Raises:
            RuntimeError: if the split is empty or the row lacks a problem or
                solution field.
        """
        dataset = self._load_split(split)
        if len(dataset) == 0:
            raise RuntimeError(f"Dataset split '{split}' is empty")
        if index is None:
            index = random.randint(0, len(dataset) - 1)
        idx = int(index) % len(dataset)
        item = dataset[idx]

        raw_problem = ""
        for key in PROBLEM_KEYS:
            value = item.get(key)
            if isinstance(value, str) and value.strip():
                raw_problem = value.strip()
                break
        if not raw_problem:
            raise RuntimeError(f"Sample missing problem field for split '{split}' index {idx}")

        solution_value: Any = None
        for key in SOLUTION_KEYS:
            if key in item:
                solution_value = item[key]
                break
        if solution_value is None:
            raise RuntimeError(f"Sample missing solution field for split '{split}' index {idx}")

        # Solutions can contain reasoning and final answer; take final line by convention
        if isinstance(solution_value, list):
            solution_text = "\n".join(str(part) for part in solution_value)
        else:
            solution_text = str(solution_value)
        lines = [line.strip() for line in solution_text.strip().splitlines() if line.strip()]
        final_line = ""
        for line in reversed(lines):
            lowered = line.lower()
            if "boxed" in lowered or "answer" in lowered:
                final_line = line
                break
        if not final_line and lines:
            final_line = lines[-1]
        candidate_answer = final_line or solution_text.strip()
        # A boxed answer that spans several lines cannot be balanced within one
        # line. In that case let the normalizer extract it from the full text.
        if not _extract_boxed(candidate_answer) and _extract_boxed(solution_text):
            candidate_answer = solution_text.strip()
        normalized_answer = _normalize_final_answer(candidate_answer)
        if not normalized_answer:
            normalized_answer = _normalize_final_answer(solution_text)
        return {
            "index": idx,
            "split": split,
            "problem": raw_problem,
            "answer": normalized_answer,
            "raw_solution": solution_text,
        }

    def size(self, split: str) -> int:
        """Return the number of rows in ``split`` (loading it if needed)."""
        dataset = self._load_split(split)
        return len(dataset)

    def ensure_ready(self, required_splits: Sequence[str]) -> None:
        """Load every non-empty split in ``required_splits``.

        Raises:
            RuntimeError: listing every split that failed to load.
        """
        errors: list[str] = []
        for split in required_splits:
            if not split:
                continue
            try:
                self._load_split(split)
            except Exception as exc:
                errors.append(f"{split}: {exc}")
        if errors:
            raise RuntimeError(
                "Dataset preparation failed:\n" + "\n".join(errors)
            )
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
@dataclass(init=True, repr=True, eq=True)
class MathEnvState:
    """State of one live math environment.

    ``answer`` holds the normalized reference answer produced by the dataset
    sampler; ``done`` flips to ``True`` once a step has been scored.
    """

    env_id: str  # unique id handed back to the client
    split: str  # dataset split the problem was drawn from
    index: int  # row index within that split
    problem: str  # problem text shown to the policy
    answer: str  # normalized expected answer
    raw_solution: str  # full reference solution text
    done: bool = False
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
class MathEnvironmentManager:
    """In-memory map from ``env_id`` to the ``MathEnvState`` of live environments.

    Entries are removed only by :meth:`terminate`.
    """

    def __init__(self, dataset: MathDataset) -> None:
        self.dataset = dataset
        self._states: Dict[str, MathEnvState] = {}

    def create(self, *, split: str, index: Optional[int], seed: Optional[int]) -> MathEnvState:
        """Sample a problem, register a new environment for it and return its state.

        An explicit ``index`` wins; otherwise ``seed`` is used as the row index.
        When both are ``None``, the dataset chooses a random row.
        """
        chosen_index = index if index is not None else seed
        record = self.dataset.sample(split=split, index=chosen_index)
        new_state = MathEnvState(
            env_id=str(uuid.uuid4()),
            split=split,
            index=int(record["index"]),
            problem=record["problem"],
            answer=record["answer"],
            raw_solution=record["raw_solution"],
        )
        self._states[new_state.env_id] = new_state
        return new_state

    def get(self, env_id: str) -> MathEnvState:
        """Return the state for ``env_id``; raise ``KeyError`` if unknown."""
        found = self._states.get(env_id)
        if found is None:
            raise KeyError(f"Unknown env_id: {env_id}")
        return found

    def terminate(self, env_id: str) -> None:
        """Forget ``env_id``; unknown ids are ignored."""
        if env_id in self._states:
            del self._states[env_id]
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
class InitializePayload(BaseModel):
    # Request body for POST /env/math/initialize.
    # seed: used as the row index when config carries no "index".
    seed: Optional[int] = Field(default=None)
    # config: optional "split" and "index" overrides.
    config: Dict[str, Any] = Field(default_factory=dict)
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
def _observation_from_state(state: MathEnvState) -> Dict[str, Any]:
|
|
358
|
+
return {
|
|
359
|
+
"problem": state.problem,
|
|
360
|
+
"split": state.split,
|
|
361
|
+
"index": state.index,
|
|
362
|
+
}
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def _score_submission(state: MathEnvState, tool_calls: Sequence[Mapping[str, Any]]) -> tuple[float, str, bool]:
    """Score the first tool call against ``state.answer``.

    Only ``tool_calls[0]`` is considered.

    Returns:
        ``(reward, status, correct)``:
        - no tool calls -> ``(REWARD_NEGATIVE_NO_TOOL, "missing_tool_call", False)``
        - first call malformed or naming another tool ->
          ``(REWARD_NEGATIVE_NO_TOOL, "wrong_tool", False)``
        - answer missing or empty after normalization ->
          ``(REWARD_NEGATIVE_NO_ANSWER, "blank_answer", False)``
        - otherwise ``REWARD_POSITIVE``/``"correct"`` on an exact normalized
          match, else ``0.0``/``"incorrect"``.
    """
    if not tool_calls:
        return REWARD_NEGATIVE_NO_TOOL, "missing_tool_call", False
    call = tool_calls[0]
    # Tool calls can arrive straight from an HTTP payload, so don't trust the shape.
    if not isinstance(call, Mapping):
        return REWARD_NEGATIVE_NO_TOOL, "wrong_tool", False
    tool_name = str(call.get("tool") or "").strip()
    if tool_name != TOOL_NAME:
        return REWARD_NEGATIVE_NO_TOOL, "wrong_tool", False
    args = call.get("args") or {}
    if not isinstance(args, Mapping):
        args = {}
    raw_answer = args.get("answer")
    # Only a missing value is blank up front: a numeric 0 is falsy but is a
    # legitimate answer, so `or ""` would wrongly discard it.
    answer = _normalize_final_answer("" if raw_answer is None else str(raw_answer))
    if not answer:
        return REWARD_NEGATIVE_NO_ANSWER, "blank_answer", False
    is_correct = answer == state.answer
    return (REWARD_POSITIVE if is_correct else 0.0), ("correct" if is_correct else "incorrect"), is_correct
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
# Router for the per-environment lifecycle endpoints registered below
# (/env/math/initialize, /env/math/step, /env/math/terminate).
math_router: APIRouter = APIRouter()
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
@math_router.post("/env/math/initialize")
async def initialize_env(request: Request, payload: InitializePayload) -> Dict[str, Any]:
    # Create a new math environment and return its id plus the first observation.
    #
    # config["split"] selects the split (DEFAULT_SPLIT when absent). An integer
    # config["index"] pins the row; otherwise payload.seed is used as the index
    # (see MathEnvironmentManager.create). A malformed index or an unknown split
    # is a client error and is reported as 400 rather than an unhandled 500.
    # NOTE(review): info["raw_solution"] exposes the reference solution at
    # initialize time; confirm no policy-facing caller forwards it.
    manager: MathEnvironmentManager = request.app.state.math_env_manager
    split = str(payload.config.get("split") or DEFAULT_SPLIT)
    raw_index = payload.config.get("index")
    index: Optional[int] = None
    if raw_index is not None:
        try:
            index = int(raw_index)
        except (TypeError, ValueError) as exc:
            raise HTTPException(
                status_code=400, detail=f"config.index must be an integer, got {raw_index!r}"
            ) from exc
    try:
        state = manager.create(split=split, index=index, seed=payload.seed)
    except ValueError as exc:
        # The dataset raises ValueError for split names it does not know.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return {
        "env_id": state.env_id,
        "observation": _observation_from_state(state),
        "info": {"raw_solution": state.raw_solution},
    }
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
@math_router.post("/env/math/step")
async def step_env(request: Request, payload: Dict[str, Any]) -> Dict[str, Any]:
    # Score the single submission for an existing environment and mark it done.
    #
    # Tool calls are read from payload["action"]["tool_calls"], falling back to
    # payload["tool_calls"]. Malformed shapes are rejected with 400 instead of
    # surfacing as 500s from attribute errors inside scoring.
    # NOTE(review): the state stays registered after scoring; cleanup relies on
    # /env/math/terminate being called. TODO: confirm callers always do so.
    manager: MathEnvironmentManager = request.app.state.math_env_manager
    env_id = str(payload.get("env_id") or "")
    if not env_id:
        raise HTTPException(status_code=400, detail="env_id required")
    try:
        state = manager.get(env_id)
    except KeyError as exc:  # pragma: no cover - defensive
        # Build the message directly: str(KeyError) would wrap it in quotes.
        raise HTTPException(status_code=404, detail=f"Unknown env_id: {env_id}") from exc

    action = payload.get("action") or {}
    if not isinstance(action, Mapping):
        raise HTTPException(status_code=400, detail="action must be an object")
    tool_calls = action.get("tool_calls") or payload.get("tool_calls") or []
    if not isinstance(tool_calls, list) or not all(isinstance(call, Mapping) for call in tool_calls):
        raise HTTPException(status_code=400, detail="tool_calls must be a list of objects")

    reward, status, correct = _score_submission(state, tool_calls)
    state.done = True

    observation = _observation_from_state(state)
    observation["status"] = status
    return {
        "observation": observation,
        "done": True,
        "reward": reward,
        "info": {
            "correct": correct,
            "expected_answer": state.answer,
            "raw_solution": state.raw_solution,
        },
    }
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
@math_router.post("/env/math/terminate")
async def terminate_env(request: Request, payload: Dict[str, Any]) -> Dict[str, Any]:
    # Drop the environment named by payload["env_id"]. A missing or unknown id
    # is not an error; the endpoint always answers {"ok": True}.
    manager: MathEnvironmentManager = request.app.state.math_env_manager
    target_id = str(payload.get("env_id") or "")
    if not target_id:
        return {"ok": True}
    manager.terminate(target_id)
    return {"ok": True}
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
def _resolve_inference_url(base_url: str) -> str:
|
|
439
|
+
normalized = (base_url or "").rstrip("/")
|
|
440
|
+
if not normalized:
|
|
441
|
+
raise RuntimeError("policy.config.inference_url required")
|
|
442
|
+
if normalized.endswith("/v1/chat/completions"):
|
|
443
|
+
return normalized
|
|
444
|
+
if normalized.endswith("/chat/completions"):
|
|
445
|
+
return normalized
|
|
446
|
+
if normalized.endswith("/v1"):
|
|
447
|
+
return f"{normalized}/chat/completions"
|
|
448
|
+
return f"{normalized}/v1/chat/completions"
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
async def _call_inference(
    policy_config: Mapping[str, Any], observation: Mapping[str, Any]
) -> tuple[list[Dict[str, Any]], Dict[str, Any]]:
    """Ask an OpenAI-compatible chat endpoint for a single ``math_submit`` tool call.

    Args:
        policy_config: Policy settings. ``inference_url`` is required. ``model``,
            ``max_tokens`` (default 512), ``temperature`` (default 0.0) and
            ``top_p`` (default 1.0) are optional.
        observation: Observation mapping; only ``problem`` is read.

    Returns:
        ``(tool_calls, data)``. Each tool call is ``{"tool": name, "args": dict}``
        and ``data`` is the decoded response body. Arguments that are not a JSON
        object are reported as ``{}``.

    Raises:
        RuntimeError: if ``inference_url`` is missing.
        The exception built by ``http_exception``: 502 for a non-JSON or
            non-object body or an upstream 5xx, and 400 for an upstream 4xx.
    """
    import json  # the module does not import json at top level

    inference_url = str(policy_config.get("inference_url") or "").rstrip("/")
    if not inference_url:
        raise RuntimeError("policy.config.inference_url required for rollout")
    model = policy_config.get("model")
    max_tokens = policy_config.get("max_tokens", 512)
    temperature = policy_config.get("temperature", 0.0)
    top_p = policy_config.get("top_p", 1.0)

    messages = [
        {
            "role": "system",
            "content": (
                "You are a math solver. Read the problem carefully and respond with a single"
                f" tool call using the function `{TOOL_NAME}`."
                "\nRules:\n"
                "- Do all reasoning internally.\n"
                "- The tool call must include ONLY the final numeric or simplified answer in the"
                " `answer` field.\n"
                "- DO NOT include explanations, units, or extra text in the answer."
            ),
        },
        {
            "role": "user",
            "content": (
                "Problem:\n"
                + str(observation.get("problem") or "")
                + "\nSubmit the final answer via the tool call."
            ),
        },
    ]

    payload: Dict[str, Any] = {
        "model": model,
        "messages": messages,
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": TOOL_NAME,
                    "description": "Submit the final answer for the math problem.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "answer": {
                                "type": "string",
                                "description": "Final answer in simplest form",
                            },
                            "explanation": {
                                "type": "string",
                                "description": "Optional explanation of reasoning",
                            },
                        },
                        "required": ["answer"],
                        "additionalProperties": False,
                    },
                },
            }
        ],
        "tool_choice": {"type": "function", "function": {"name": TOOL_NAME}},
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": max_tokens,
    }

    final_url = _resolve_inference_url(inference_url)
    async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as client:
        response = await client.post(final_url, json=payload)
    try:
        data = response.json()
    except Exception as exc:
        raise http_exception(
            502,
            "inference_invalid_response",
            "Inference provider returned invalid JSON",
            extra={"body": response.text[:800]},
        ) from exc
    if response.status_code >= 500:
        raise http_exception(
            502,
            "inference_upstream_error",
            "Inference provider returned an error",
            extra={"status": response.status_code, "body": data},
        )
    if response.status_code >= 400:
        raise http_exception(
            400,
            "inference_request_invalid",
            "Invalid inference request",
            extra={"status": response.status_code, "body": data},
        )
    if not isinstance(data, Mapping):
        raise http_exception(
            502,
            "inference_invalid_response",
            "Inference provider returned a non-object JSON body",
            extra={"body": response.text[:800]},
        )

    tool_calls: list[Dict[str, Any]] = []
    choices = data.get("choices") or []
    first_choice = choices[0] if isinstance(choices, list) and choices else {}
    message = first_choice.get("message") if isinstance(first_choice, Mapping) else None
    raw_calls = (message.get("tool_calls") if isinstance(message, Mapping) else None) or []
    if not isinstance(raw_calls, list):
        raw_calls = []
    for call in raw_calls:
        if not isinstance(call, Mapping):
            continue
        function = call.get("function") or {}
        if not isinstance(function, Mapping):
            function = {}
        arguments = function.get("arguments")
        parsed_args: Dict[str, Any] = {}
        if isinstance(arguments, str):
            try:
                decoded = json.loads(arguments)
            except ValueError:
                decoded = None
            # Valid JSON that is not an object (e.g. "5" or a list) carries no
            # usable arguments and would break scoring downstream.
            if isinstance(decoded, dict):
                parsed_args = decoded
        elif isinstance(arguments, Mapping):
            parsed_args = dict(arguments)
        tool_calls.append({"tool": function.get("name"), "args": parsed_args})
    return tool_calls, data
|
|
566
|
+
|
|
567
|
+
|
|
568
|
+
async def rollout_executor(request: RolloutRequest, fastapi_request: Request) -> RolloutResponse:
    """Run one single-step rollout: sample a problem, query the policy, score it.

    The problem comes from ``request.env.config["split"]`` (default
    ``DEFAULT_SPLIT``) at index ``request.env.seed``. Inference failures are not
    raised. They are recorded in the step ``info`` (``error``, plus ``code`` for
    HTTP errors) and scored as if no tool call was made.
    """
    math_dataset: MathDataset = fastapi_request.app.state.math_dataset
    chosen_split = str((request.env.config or {}).get("split") or DEFAULT_SPLIT)
    picked = math_dataset.sample(split=chosen_split, index=request.env.seed)

    observation = {
        "problem": picked["problem"],
        "split": picked["split"],
        "index": picked["index"],
    }

    error_info: Dict[str, Any] = {}
    try:
        tool_calls, _raw_response = await _call_inference(request.policy.config or {}, observation)
    except HTTPException as http_err:
        tool_calls = []
        error_info = {"error": http_err.detail, "code": http_err.status_code}
    except Exception as exc:
        tool_calls = []
        error_info = {"error": str(exc)}

    # Throwaway state object so the rollout path shares the env scoring logic.
    scoring_state = MathEnvState(
        env_id="rollout",
        split=picked["split"],
        index=picked["index"],
        problem=picked["problem"],
        answer=picked["answer"],
        raw_solution=picked["raw_solution"],
    )
    reward, status, correct = _score_submission(scoring_state, tool_calls)

    step_info: Dict[str, Any] = {
        "expected_answer": picked["answer"],
        "status": status,
        "correct": correct,
        "raw_solution": picked["raw_solution"],
        **error_info,
    }
    only_step = RolloutStep(
        obs=observation,
        tool_calls=tool_calls,
        reward=reward,
        done=True,
        info=step_info,
    )

    trajectory = RolloutTrajectory(
        env_id=f"math::{picked['split']}::{picked['index']}",
        policy_id=request.policy.policy_id or "policy",
        steps=[only_step],
        final={"observation": {**observation, "status": status}, "reward": reward},
        length=1,
    )
    metrics = RolloutMetrics(
        episode_returns=[reward],
        mean_return=reward,
        num_steps=1,
        num_episodes=1,
        outcome_score=reward,
        events_score=reward,
        details={"status": status, "correct": correct},
    )

    return RolloutResponse(
        run_id=request.run_id,
        trajectories=[trajectory],
        branches={},
        metrics=metrics,
        aborted=False,
        ops_executed=2,
        trace=None,
    )
|
|
646
|
+
|
|
647
|
+
|
|
648
|
+
def build_dataset() -> tuple[TaskDatasetRegistry, MathDataset]:
    """Create the dataset registry and eagerly prepare ``DEFAULT_SPLIT``.

    Returns:
        ``(registry, dataset)``, with ``MATH_DATASET_SPEC`` registered to
        resolve to ``dataset``.

    Raises:
        RuntimeError: if the default split cannot be loaded at boot.
    """
    registry = TaskDatasetRegistry()
    math_dataset = MathDataset(name=DATASET_NAME, config=DATASET_CONFIG, splits=AVAILABLE_SPLITS)

    # Fail fast when the task app boots rather than on the first rollout.
    try:
        math_dataset.ensure_ready([DEFAULT_SPLIT])
    except Exception as exc:
        raise RuntimeError(
            "Failed to initialise math dataset. Set MATH_DATASET_LOCAL_DIR or ensure network access.\n"
            f"Underlying error: {exc}"
        ) from exc

    def _provide(_spec: Any) -> MathDataset:
        return math_dataset

    registry.register(MATH_DATASET_SPEC, _provide, cache=True)
    return registry, math_dataset
|
|
661
|
+
|
|
662
|
+
|
|
663
|
+
def _base_task_info() -> TaskInfo:
    """Build the static TaskInfo shared by every math_single_step instance."""
    submit_tool = {
        "name": TOOL_NAME,
        "description": "Submit the final answer for the math problem.",
        "schema": {"answer": "string"},
    }
    dataset_info = {
        **MATH_DATASET_SPEC.model_dump(),
        "hf_dataset": DATASET_NAME,
        "hf_config": DATASET_CONFIG,
    }
    return TaskInfo(
        task={"id": "math_single_step", "name": "Math Single Step", "version": "1.0.0"},
        environments=["math"],
        action_space={"type": "tool_call", "tools": [submit_tool], "max_calls": 1},
        observation={
            "summary": "Single math word problem presented as plain text.",
            "keys": ["problem"],
        },
        dataset=dataset_info,
        rubric={"version": "1", "criteria_count": 1, "source": "inline"},
        inference={
            "supports_proxy": True,
            "tool": {"name": TOOL_NAME, "parallel_tool_calls": False},
        },
        capabilities={
            "supports_rollout": True,
            "supports_env_lifecycle": True,
            "requires_api_key_header": True,
        },
        limits={"max_turns": 1},
    )
|
|
703
|
+
|
|
704
|
+
|
|
705
|
+
# Outcome rubric: a single criterion rewarding a correct final answer submitted
# through the math_submit tool.
OUTCOME_RUBRIC: Rubric = cast(
    Rubric,
    load_rubric(
        dict(
            version="1",
            goal_text="Encourage correct single-step math answers via tool calls.",
            aggregation="weighted_sum",
            criteria=[
                dict(
                    id="correct_answer",
                    description="Submit the correct final answer using the math_submit tool.",
                    weight=1.0,
                )
            ],
        )
    ),
)
|
|
722
|
+
|
|
723
|
+
# Events rubric: a single criterion about tool usage (exactly one call that
# carries an answer string).
EVENTS_RUBRIC: Rubric = cast(
    Rubric,
    load_rubric(
        dict(
            version="1",
            goal_text="Penalize missing or malformed tool calls.",
            aggregation="weighted_sum",
            criteria=[
                dict(
                    id="tool_usage",
                    description="Make exactly one tool call with an answer string.",
                    weight=1.0,
                )
            ],
        )
    ),
)
|
|
740
|
+
|
|
741
|
+
|
|
742
|
+
def describe_taskset(dataset: MathDataset) -> Dict[str, Any]:
    """Summarize the dataset spec, its HF source and per-split row counts.

    Computing the counts loads every configured split.
    """
    summary: Dict[str, Any] = dict(MATH_DATASET_SPEC.model_dump())
    summary["hf_dataset"] = DATASET_NAME
    summary["hf_config"] = DATASET_CONFIG
    summary["sizes"] = {name: dataset.size(name) for name in dataset.splits}
    return summary
|
|
749
|
+
|
|
750
|
+
|
|
751
|
+
def provide_task_instances(dataset: MathDataset, seeds: Sequence[int]) -> Iterable[TaskInfo]:
    """Yield one ``TaskInfo`` per seed, lazily (this is a generator).

    Each seed is passed as ``index`` to ``dataset.sample`` on ``DEFAULT_SPLIT``.
    The yielded ``TaskInfo`` reuses every field of the base task info, with two
    extensions:

    * ``observation`` gains ``sample_index``, the sampled row's index.
    * ``dataset`` gains the sampled row's ``split`` and ``index``.

    Args:
        dataset: Dataset to draw samples from.
        seeds: Sample indices to materialize. How out-of-range values are
            handled is up to ``dataset.sample`` (not visible here).
    """
    template = _base_task_info()
    for seed in seeds:
        picked = dataset.sample(split=DEFAULT_SPLIT, index=seed)
        row_index = picked["index"]
        yield TaskInfo(
            task=template.task,
            environments=template.environments,
            action_space=template.action_space,
            observation={**template.observation, "sample_index": row_index},
            dataset={**template.dataset, "split": picked["split"], "index": row_index},
            rubric=template.rubric,
            inference=template.inference,
            capabilities=template.capabilities,
            limits=template.limits,
        )
|
|
770
|
+
|
|
771
|
+
|
|
772
|
+
def build_config() -> TaskAppConfig:
    """Assemble the ``TaskAppConfig`` for the ``math-single-step`` task app.

    Steps, in order:

    1. Build the dataset registry and dataset, plus the base task info.
    2. Resolve tracing settings (the env toggle and DB URL) into a tracer
       factory, and resolve the optional SFT output directory.
    3. Seed the app state with the dataset, an environment manager and the
       tracing flag. ``session_tracer_factory`` and ``sft_output_dir`` are
       added only when available.
    4. Enable the OpenAI / Groq proxies according to whether
       ``normalize_vendor_keys()`` returned a non-None key for each.

    Returns:
        The fully populated task app configuration.
    """
    dataset_registry, math_dataset = build_dataset()
    task_info = _base_task_info()

    trace_on = tracing_env_enabled()
    trace_db = resolve_tracing_db_url()
    make_tracer = build_tracer_factory(SessionTracer, enabled=trace_on, db_url=trace_db)
    sft_dir = resolve_sft_output_dir()

    state: Dict[str, Any] = {
        "math_dataset": math_dataset,
        "math_env_manager": MathEnvironmentManager(math_dataset),
        "tracing_enabled": trace_on,
    }
    # Optional entries: only present when tracing produced a factory, or when
    # an SFT output directory was configured (a falsy value is skipped).
    if make_tracer is not None:
        state["session_tracer_factory"] = make_tracer
    if sft_dir:
        state["sft_output_dir"] = sft_dir

    vendor_keys = normalize_vendor_keys()
    proxy = ProxyConfig(
        enable_openai=vendor_keys.get("OPENAI_API_KEY") is not None,
        enable_groq=vendor_keys.get("GROQ_API_KEY") is not None,
        system_hint=(
            "You must respond with a single math_submit tool call containing only the final answer."
        ),
    )

    return TaskAppConfig(
        app_id="math-single-step",
        name="Math Single Step Task",
        description="Single-step math reasoning environment built on the MATH dataset.",
        base_task_info=task_info,
        # The callbacks close over the dataset built above.
        describe_taskset=lambda: describe_taskset(math_dataset),
        provide_task_instances=lambda seeds: provide_task_instances(math_dataset, seeds),
        rollout=rollout_executor,
        dataset_registry=dataset_registry,
        rubrics=RubricBundle(outcome=OUTCOME_RUBRIC, events=EVENTS_RUBRIC),
        proxy=proxy,
        routers=(math_router,),
        app_state=state,
        cors_origins=["*"],
    )
|
|
818
|
+
|
|
819
|
+
|
|
820
|
+
# Register this task app at import time under "math-single-step" (alias
# "math-rl"). build_config is passed as the factory, so it is not called here.
register_task_app(
    entry=TaskAppEntry(
        app_id="math-single-step",
        aliases=("math-rl",),
        description="Single-step math reasoning task app using EleutherAI/math dataset.",
        config_factory=build_config,
        env_files=("examples/rl/.env",),
        # Modal deployment settings for this app.
        modal=ModalDeploymentConfig(
            app_name="synth-math-single-step",
            # (host path, container path) pairs copied into the image.
            extra_local_dirs=(
                (str(REPO_ROOT / "synth_ai"), "/opt/synth_ai_repo/synth_ai"),
                (str(REPO_ROOT / "examples" / "rl"), "/opt/synth_ai_repo/examples/rl"),
            ),
            # (volume name, mount path) pairs.
            volume_mounts=(("math-dataset-cache", "/modal_volumes/math_dataset"),),
            # NOTE(review): "ty" looks like a dev-time type checker; confirm it
            # is actually needed in the runtime image.
            pip_packages=(
                "datasets>=4.0.0",
                "fastapi>=0.115.0",
                "pydantic>=2.0.0",
                "httpx>=0.26.0",
                "requests>=2.32.0",
                "python-dotenv>=1.0.0",
                "diskcache>=5.6.3",
                "duckdb>=1.0.0",
                "ty>=0.0.1a5",
                "toml>=0.10.2",
                "aiosqlite>=0.21.0",
                "libsql>=0.1.8",
                "pynacl>=1.5.0",
                "sqlalchemy>=2.0.42",
            ),
        ),
    )
)
|