synth-ai 0.2.8.dev12__py3-none-any.whl → 0.2.8.dev13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- synth_ai/api/train/__init__.py +5 -0
- synth_ai/api/train/builders.py +165 -0
- synth_ai/api/train/cli.py +429 -0
- synth_ai/api/train/config_finder.py +120 -0
- synth_ai/api/train/env_resolver.py +302 -0
- synth_ai/api/train/pollers.py +66 -0
- synth_ai/api/train/task_app.py +128 -0
- synth_ai/api/train/utils.py +232 -0
- synth_ai/cli/__init__.py +23 -0
- synth_ai/cli/rl_demo.py +2 -2
- synth_ai/cli/root.py +2 -1
- synth_ai/cli/task_apps.py +520 -0
- synth_ai/task/__init__.py +94 -1
- synth_ai/task/apps/__init__.py +88 -0
- synth_ai/task/apps/grpo_crafter.py +438 -0
- synth_ai/task/apps/math_single_step.py +852 -0
- synth_ai/task/auth.py +132 -0
- synth_ai/task/client.py +148 -0
- synth_ai/task/contracts.py +29 -14
- synth_ai/task/datasets.py +105 -0
- synth_ai/task/errors.py +49 -0
- synth_ai/task/json.py +77 -0
- synth_ai/task/proxy.py +258 -0
- synth_ai/task/rubrics.py +212 -0
- synth_ai/task/server.py +398 -0
- synth_ai/task/tracing_utils.py +79 -0
- synth_ai/task/vendors.py +61 -0
- synth_ai/tracing_v3/session_tracer.py +13 -5
- synth_ai/tracing_v3/storage/base.py +10 -12
- synth_ai/tracing_v3/turso/manager.py +20 -6
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.8.dev13.dist-info}/METADATA +3 -2
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.8.dev13.dist-info}/RECORD +36 -14
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.8.dev13.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.8.dev13.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.8.dev13.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.8.dev12.dist-info → synth_ai-0.2.8.dev13.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""Registry for Task Apps exposed via the shared FastAPI harness."""
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Callable, Dict, Iterable, List, Sequence
|
|
7
|
+
|
|
8
|
+
from ..server import TaskAppConfig
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass(slots=True)
|
|
12
|
+
class ModalDeploymentConfig:
|
|
13
|
+
"""Modal deployment defaults for a task app."""
|
|
14
|
+
|
|
15
|
+
app_name: str
|
|
16
|
+
python_version: str = "3.11"
|
|
17
|
+
pip_packages: Sequence[str] = field(default_factory=tuple)
|
|
18
|
+
extra_local_dirs: Sequence[tuple[str, str]] = field(default_factory=tuple)
|
|
19
|
+
secret_names: Sequence[str] = field(default_factory=tuple)
|
|
20
|
+
volume_mounts: Sequence[tuple[str, str]] = field(default_factory=tuple)
|
|
21
|
+
timeout: int = 600
|
|
22
|
+
memory: int = 4096
|
|
23
|
+
cpu: float = 2.0
|
|
24
|
+
min_containers: int = 1
|
|
25
|
+
max_containers: int = 4
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass(slots=True)
|
|
29
|
+
class TaskAppEntry:
|
|
30
|
+
"""Metadata describing a registered task app."""
|
|
31
|
+
|
|
32
|
+
app_id: str
|
|
33
|
+
description: str
|
|
34
|
+
config_factory: Callable[[], TaskAppConfig]
|
|
35
|
+
aliases: Sequence[str] = field(default_factory=tuple)
|
|
36
|
+
env_files: Sequence[str] = field(default_factory=tuple)
|
|
37
|
+
modal: ModalDeploymentConfig | None = None
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class TaskAppRegistry:
    """In-memory registry of known task apps, addressable by id or alias."""

    def __init__(self) -> None:
        # app_id -> entry
        self._entries: Dict[str, TaskAppEntry] = {}
        # alias -> app_id
        self._alias_to_id: Dict[str, str] = {}

    def register(self, entry: TaskAppEntry) -> None:
        """Add ``entry`` to the registry.

        Every name is validated before any state is mutated, so a rejected
        entry leaves the registry unchanged.

        Raises:
            ValueError: if ``entry.app_id`` is already registered as an app id
                or an alias, or if any alias is already registered as an alias
                or app id (or appears twice within ``entry.aliases``).
        """
        app_id = entry.app_id
        if app_id in self._entries:
            raise ValueError(f"Task app already registered: {app_id}")
        if app_id in self._alias_to_id:
            # get() resolves aliases first, so this id would be unreachable.
            raise ValueError(f"Task app id conflicts with a registered alias: {app_id}")

        aliases = tuple(entry.aliases)
        seen: set[str] = set()
        for alias in aliases:
            if alias in seen or alias in self._alias_to_id:
                raise ValueError(f"Alias already registered: {alias}")
            if alias in self._entries:
                # The alias would shadow another app's id in get().
                raise ValueError(f"Alias conflicts with a registered task app id: {alias}")
            seen.add(alias)

        # Validation passed: commit the entry and all of its aliases together.
        self._entries[app_id] = entry
        for alias in aliases:
            self._alias_to_id[alias] = app_id

    def get(self, app_id: str) -> TaskAppEntry:
        """Return the entry registered under ``app_id`` or one of its aliases.

        Raises:
            KeyError: if the name is neither a registered id nor an alias.
        """
        resolved = self._alias_to_id.get(app_id, app_id)
        if resolved not in self._entries:
            raise KeyError(f"Unknown task app id: {app_id}")
        return self._entries[resolved]

    def list(self) -> List[TaskAppEntry]:
        """Return all registered entries sorted by ``app_id``."""
        return sorted(self._entries.values(), key=lambda entry: entry.app_id)

    def __iter__(self) -> Iterable[TaskAppEntry]:
        """Iterate over registered entries in ``app_id`` order."""
        return iter(self.list())
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# Process-wide registry. Built-in task app modules imported at the bottom of
# this module (e.g. grpo_crafter) add themselves to it via register_task_app.
registry = TaskAppRegistry()


def register_task_app(*, entry: TaskAppEntry) -> None:
    """Add ``entry`` to the shared module-level ``registry``.

    Raises:
        ValueError: propagated from ``TaskAppRegistry.register`` when the app
            id or an alias is already registered.
    """
    registry.register(entry=entry)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# Register built-in task apps
|
|
78
|
+
# Register built-in task apps. Import failures are deferred rather than raised
# so the CLI can still start; instead of being discarded, each failure is kept
# here (module name -> exception) so callers can report why an app is missing.
BUILTIN_IMPORT_ERRORS: Dict[str, Exception] = {}

try:
    from . import grpo_crafter  # noqa: F401
except Exception as exc:  # broad on purpose: optional example deps can fail in many ways
    BUILTIN_IMPORT_ERRORS["grpo_crafter"] = exc

try:
    from . import math_single_step  # noqa: F401
except Exception as exc:  # broad on purpose: optional example deps can fail in many ways
    BUILTIN_IMPORT_ERRORS["math_single_step"] = exc
|
|
@@ -0,0 +1,438 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
"""Task App configuration for the GRPO Crafter example."""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import sys
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any, Dict, Iterable, List, Sequence
|
|
10
|
+
|
|
11
|
+
from ..contracts import RolloutRequest, RolloutResponse, TaskInfo
|
|
12
|
+
from ..datasets import TaskDatasetRegistry, TaskDatasetSpec
|
|
13
|
+
from ..rubrics import load_rubric
|
|
14
|
+
from ..server import ProxyConfig, RubricBundle, TaskAppConfig
|
|
15
|
+
from ..json import to_jsonable # noqa: F401 (imported for side-effect compatibility)
|
|
16
|
+
from . import ModalDeploymentConfig, TaskAppEntry, register_task_app
|
|
17
|
+
from ..tracing_utils import (
|
|
18
|
+
build_tracer_factory,
|
|
19
|
+
resolve_sft_output_dir,
|
|
20
|
+
resolve_tracing_db_url,
|
|
21
|
+
tracing_env_enabled,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
REPO_ROOT = Path(__file__).resolve().parents[3]
TASK_APP_ROOT = REPO_ROOT / "examples" / "warming_up_to_rl" / "task_app"
SYNTH_ENVS_HOSTED_ROOT = TASK_APP_ROOT / "synth_envs_hosted"


def _prepend_to_sys_path(*paths: Path) -> None:
    """Insert each path at the front of ``sys.path`` unless already present.

    Paths are processed in order and each goes to index 0, so the last one
    given ends up first on ``sys.path``.
    """
    for entry in paths:
        entry_str = str(entry)
        if entry_str not in sys.path:
            sys.path.insert(0, entry_str)


# Make the example task app and its ``synth_envs_hosted`` package importable
# (done in a helper so no loop variables leak into the module namespace).
# NOTE(review): parents[3] is the repository root for a source checkout; for an
# installed wheel it resolves to the directory containing ``synth_ai`` — confirm
# the examples tree is expected to exist there.
_prepend_to_sys_path(REPO_ROOT, TASK_APP_ROOT, SYNTH_ENVS_HOSTED_ROOT)
|
|
35
|
+
|
|
36
|
+
try:
    import crafter  # type: ignore
    import crafter.constants as C  # type: ignore
    from synth_ai.environments.examples.crafter_classic.taskset import TRAIT_BOUNDS, world_traits
    from synth_envs_hosted.branching import router as branching_router
    from synth_envs_hosted.environment_routes import router as environment_router
    from synth_envs_hosted.hosted_app import TaskApp as HostedTaskApp
    from synth_envs_hosted.policy_routes import router as policy_router
    from synth_envs_hosted.rollout import (
        RolloutEnvSpec as LegacyRolloutEnvSpec,
        RolloutPolicySpec as LegacyRolloutPolicySpec,
        RolloutRecordConfig as LegacyRolloutRecordConfig,
        RolloutRequest as LegacyRolloutRequest,
        RolloutResponse as LegacyRolloutResponse,
        RolloutSafetyConfig as LegacyRolloutSafetyConfig,
        execute_rollout as legacy_execute_rollout,
    )
except Exception as exc:  # pragma: no cover - import-time validation
    # Re-raise as RuntimeError with an actionable message naming the missing
    # module and how to install it, chained to the original error.
    missing_mod = None
    if isinstance(exc, ModuleNotFoundError):
        # Prefer the module name recorded on the exception. Fall back to parsing
        # "No module named 'x'" only when that name is empty (explicit branches
        # avoid the `a or b if c else None` precedence trap).
        missing_mod = getattr(exc, "name", None)
        if not missing_mod:
            quoted = str(exc).split("'")
            missing_mod = quoted[1] if len(quoted) > 1 else None
    fix_hint = None
    if missing_mod:
        # Import name -> pip distribution name, keyed by top-level package so
        # that submodule failures (e.g. "crafter.constants") map correctly.
        mapping = {
            "dotenv": "python-dotenv",
            "crafter": "crafter",
            "httpx": "httpx",
            "aiohttp": "aiohttp",
            "fastapi": "fastapi",
            "uvicorn": "uvicorn",
            "sqlalchemy": "sqlalchemy",
            "aiosqlite": "aiosqlite",
            "greenlet": "greenlet",
        }
        top_level = missing_mod.split(".", 1)[0]
        pkg = mapping.get(top_level, top_level)
        fix_hint = (
            f"Missing Python module '{missing_mod}'. Install the package '{pkg}'.\n"
            f"For Modal: add '{pkg}' to ModalDeploymentConfig.pip_packages in synth_ai/task/apps/grpo_crafter.py.\n"
            f"Locally: pip install {pkg}"
        )
    detailed = (
        "grpo_crafter task app requires example dependencies and runtime libs.\n"
        + (fix_hint + "\n" if fix_hint else "")
        + f"Original error: {exc}"
    )
    raise RuntimeError(detailed) from exc
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# Crafting rules passed to the LLM proxy as its ``system_hint`` (see build_config).
# Iron tools require BOTH a nearby table and a nearby furnace in Crafter; the
# earlier "furnace exists" wording understated that requirement.
CRAFTING_RULES_SYSTEM_HINT = (
    "Crafter crafting rules (from the paper):\n"
    "- Make Wood Pickaxe: Nearby a table; have wood in inventory.\n"
    "- Make Stone Pickaxe: Nearby a table; have wood and stone in inventory.\n"
    "- Make Iron Pickaxe: Nearby a table and a furnace; have wood, coal, and iron in inventory.\n"
    "- Make Wood Sword: Nearby a table; have wood in inventory.\n"
    "- Make Stone Sword: Nearby a table; have wood and stone in inventory.\n"
    "- Make Iron Sword: Nearby a table and a furnace; have wood, coal, and iron in inventory."
)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
# Metadata for the procedurally seeded Crafter dataset (single "train" split).
DATASET_SPEC = TaskDatasetSpec(
    id="crafter_classic_procedural",
    version="1.0.0",
    name="Crafter Classic Procedural Seeds",
    description="Procedural Crafter Classic seeds with reproducible world traits.",
    default_split="train",
    splits=["train"],
)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@dataclass
class CrafterDataset:
    """Seed-indexed view over procedurally generated Crafter worlds.

    Settings are read via ``env_value`` in ``__post_init__``:
    ``CRAFTER_DEFAULT_SEED`` (default 42), ``CRAFTER_MAX_SEED`` (default
    2**31 - 1), ``CRAFTER_AREA`` (default "64,64", comma-separated ints) and
    ``CRAFTER_EPISODE_LENGTH`` (default 10000). Per-seed summaries are
    memoised in ``_cache`` for the lifetime of the instance.
    """

    spec: TaskDatasetSpec

    def __post_init__(self) -> None:
        read = env_value
        self.default_seed = int(read("CRAFTER_DEFAULT_SEED", 42))
        self.seed_min = 0
        self.seed_max = int(read("CRAFTER_MAX_SEED", 2**31 - 1))
        raw_area = str(read("CRAFTER_AREA", "64,64"))
        self.area = tuple(map(int, raw_area.split(",")))
        self.length = int(read("CRAFTER_EPISODE_LENGTH", 10000))
        self._cache: Dict[int, Dict[str, Any]] = {}

    def config_for_seed(self, seed: int) -> Dict[str, Any]:
        """Return the environment config for ``seed``: seed, area (as a list) and length."""
        return dict(seed=int(seed), area=list(self.area), length=self.length)

    def describe_seed(self, seed: int) -> Dict[str, Any]:
        """Return (and cache) a summary of the world generated by ``seed``.

        On a cache miss a ``crafter.Env`` is created, reset once and inspected;
        it is always closed afterwards when it exposes a callable ``close``.
        The returned dict is the cached object itself, so callers should not
        mutate it.
        """
        key = int(seed)
        cached = self._cache.get(key)
        if cached is not None:
            return cached

        world = crafter.Env(area=self.area, length=self.length, seed=key)
        try:
            world.reset()
            traits = world_traits(world)
            player = getattr(world, "_player", None)
            inventory = dict(getattr(player, "inventory", {})) if player else {}
            position = getattr(player, "pos", None)
        finally:
            closer = getattr(world, "close", None)
            if callable(closer):
                closer()

        summary = {
            "seed": key,
            "difficulty": self._difficulty(traits),
            "traits": traits,
            "inventory": inventory,
            "player_position": None if position is None else list(position),
            "config": self.config_for_seed(key),
        }
        self._cache[key] = summary
        return summary

    def _difficulty(self, traits: Dict[str, int]) -> str:
        """Return the first ``TRAIT_BOUNDS`` label whose bounds ``traits`` satisfy, else ``"custom"``."""
        trees = traits.get("trees", 0)
        hostiles = traits.get("hostiles", 0)
        return next(
            (
                label
                for label, bounds in TRAIT_BOUNDS.items()
                if trees >= bounds.get("min_trees", 0) and hostiles <= bounds.get("max_hostiles", 0)
            ),
            "custom",
        )

    @property
    def seed_range(self) -> List[int]:
        """``[seed_min, seed_max]`` as a two-element list."""
        return [self.seed_min, self.seed_max]
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def env_value(key: str, default: Any) -> Any:
    """Return environment variable ``key``, or ``default`` when it is unset or blank.

    A set, non-blank value is returned as the raw string (callers convert it,
    e.g. with ``int``). ``default`` is returned unchanged, so it may be any
    type. Blank values are treated as unset so that e.g. ``CRAFTER_MAX_SEED=""``
    falls back to the default instead of failing on ``int("")``.
    """
    # ``os`` is imported at module level; no local import needed.
    value = os.getenv(key)
    if value is None or not value.strip():
        return default
    return value
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def build_dataset() -> tuple[TaskDatasetRegistry, CrafterDataset]:
    """Create the dataset registry plus the single Crafter dataset it serves.

    The registered factory ignores its spec argument and always returns the
    same ``CrafterDataset`` instance, which is also returned to the caller.
    """
    registry = TaskDatasetRegistry()
    dataset = CrafterDataset(DATASET_SPEC)

    def _factory(_spec: TaskDatasetSpec) -> CrafterDataset:
        return dataset

    registry.register(DATASET_SPEC, _factory, cache=True)
    return registry, dataset
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def _base_task_info(dataset: CrafterDataset) -> TaskInfo:
    """Return the seed-independent ``TaskInfo`` for the Crafter task app.

    The action space mirrors ``crafter.constants.actions``. The dataset section
    is ``DATASET_SPEC`` plus the seed range and default seed of ``dataset``.
    """
    actions = list(C.actions)
    dataset_info = {
        **DATASET_SPEC.model_dump(),
        "seed_range": dataset.seed_range,
        "default_seed": dataset.default_seed,
    }
    # NOTE(review): criteria_count is 2, matching OUTCOME_RUBRIC; EVENTS_RUBRIC
    # defines 1 criterion — confirm which rubric this summary describes.
    rubric_info = {
        "version": "1",
        "criteria_count": 2,
        "source": "inline",
        "aggregation": "weighted_sum",
    }
    inference_info = {
        "supports_proxy": True,
        "endpoints": {
            "openai": "/proxy/v1/chat/completions",
            "groq": "/proxy/groq/v1/chat/completions",
        },
        "tool": {"name": "interact", "parallel_tool_calls": False},
    }
    return TaskInfo(
        task={"id": "crafter_classic", "name": "Crafter Classic", "version": "1.0.0"},
        environments=["crafter"],
        action_space={"type": "discrete", "size": len(actions), "actions": actions},
        observation={
            "summary": "RGB frame plus inventory, achievements, and semantic map patches.",
            "keys": ["image", "inventory", "achievements", "semantic_map_patch7"],
            "image_shape": [64, 64, 3],
        },
        dataset=dataset_info,
        rubric=rubric_info,
        inference=inference_info,
        capabilities={
            "supports_rollout": True,
            "supports_env_lifecycle": True,
            "requires_api_key_header": True,
        },
        limits={"max_ops": 100000, "max_time_s": 3600},
    )
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def _equal_weight_rubric(goal_text: str, criteria: Sequence[tuple[str, str]]):
    """Load a version-"1", weighted-sum rubric whose (id, description) criteria all weigh 1.0."""
    return load_rubric(
        {
            "version": "1",
            "goal_text": goal_text,
            "aggregation": "weighted_sum",
            "criteria": [
                {"id": criterion_id, "description": description, "weight": 1.0}
                for criterion_id, description in criteria
            ],
        }
    )


# Outcome rubric (RubricBundle.outcome): achievements unlocked and survival.
OUTCOME_RUBRIC = _equal_weight_rubric(
    "Reward unlocking Crafter achievements and survival.",
    [
        ("achievements", "Unlock achievements or crafting milestones."),
        ("survival", "Maintain health, food, and drink levels."),
    ],
)

# Events rubric (RubricBundle.events): purposeful step-wise progress.
EVENTS_RUBRIC = _equal_weight_rubric(
    "Encourage purposeful step-wise exploration and crafting.",
    [("progress_steps", "Actions progress quests, crafting, or exploration.")],
)
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def describe_taskset(dataset: CrafterDataset) -> Dict[str, Any]:
    """Describe the task set: ``DATASET_SPEC`` fields plus seeds and env config of ``dataset``."""
    description: Dict[str, Any] = dict(DATASET_SPEC.model_dump())
    description.update(
        seed_range=dataset.seed_range,
        default_seed=dataset.default_seed,
        config={"area": list(dataset.area), "length": dataset.length},
    )
    return description
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def provide_task_instances(dataset: CrafterDataset, base_info: TaskInfo, seeds: Sequence[int]) -> Iterable[TaskInfo]:
    """Return a list with one ``TaskInfo`` per seed, specialising ``base_info``.

    For each seed the observation section gains the seed, world traits, starting
    inventory and player position, and the dataset section gains the seed,
    difficulty label and env config. All other sections are taken from
    ``base_info`` unchanged.
    """

    def _instance(seed_value: int) -> TaskInfo:
        summary = dataset.describe_seed(seed_value)
        observation = {
            **base_info.observation,
            "seed": seed_value,
            "traits": summary["traits"],
            "inventory": summary["inventory"],
            "player_position": summary["player_position"],
        }
        dataset_info = {
            **base_info.dataset,
            "seed": seed_value,
            "difficulty": summary["difficulty"],
            "config": summary["config"],
        }
        return TaskInfo(
            task=base_info.task,
            environments=base_info.environments,
            action_space=base_info.action_space,
            observation=observation,
            dataset=dataset_info,
            rubric=base_info.rubric,
            inference=base_info.inference,
            capabilities=base_info.capabilities,
            limits=base_info.limits,
        )

    return [_instance(seed_value) for seed_value in seeds]
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def _normalise_op(op_value: Any, index: int) -> str:
|
|
301
|
+
if isinstance(op_value, str):
|
|
302
|
+
candidate = op_value
|
|
303
|
+
elif isinstance(op_value, dict):
|
|
304
|
+
candidate = op_value.get("type") or op_value.get("op")
|
|
305
|
+
else:
|
|
306
|
+
candidate = None
|
|
307
|
+
if not candidate:
|
|
308
|
+
raise ValueError(f"Missing op type at index {index}")
|
|
309
|
+
lowered = str(candidate).strip().lower()
|
|
310
|
+
if lowered in {"policy", "agent", "model"}:
|
|
311
|
+
return "agent"
|
|
312
|
+
if lowered in {"env", "environment", "step"}:
|
|
313
|
+
return "env"
|
|
314
|
+
raise ValueError(f"Unsupported op type '{candidate}' at index {index}")
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
async def rollout_executor(request: RolloutRequest, fastapi_request) -> RolloutResponse:
    """Run a rollout by translating ``request`` to the legacy hosted-app format.

    Ops are normalised with ``_normalise_op`` first, so a bad op raises
    ``ValueError`` before the legacy ``execute_rollout`` is awaited. The legacy
    response is converted back into a ``RolloutResponse`` whose ``metrics``
    always contain ``outcome_score`` and ``events_score`` (default ``None``)
    and ``details`` (default ``{}``).
    """
    ops = [_normalise_op(op, position) for position, op in enumerate(request.ops)]

    env_spec = LegacyRolloutEnvSpec(
        env_id=request.env.env_id,
        env_name=request.env.env_name,
        config=request.env.config or {},
        seed=request.env.seed,
    )
    policy_spec = LegacyRolloutPolicySpec(
        policy_id=request.policy.policy_id,
        policy_name=request.policy.policy_name,
        config=request.policy.config or {},
    )
    legacy_request = LegacyRolloutRequest(
        run_id=request.run_id,
        env=env_spec,
        policy=policy_spec,
        ops=ops,
        record=LegacyRolloutRecordConfig(**request.record.model_dump()),
        on_done=request.on_done,
        branch=None,
        safety=LegacyRolloutSafetyConfig(**request.safety.model_dump()),
        training_session_id=request.training_session_id,
        synth_base_url=request.synth_base_url,
    )

    legacy_response: LegacyRolloutResponse = await legacy_execute_rollout(legacy_request, fastapi_request)
    payload = legacy_response.model_dump()
    metrics = payload.get("metrics", {}) or {}
    for metric_key, fallback in (("outcome_score", None), ("events_score", None), ("details", {})):
        metrics.setdefault(metric_key, fallback)
    payload["metrics"] = metrics
    return RolloutResponse.model_validate(payload)
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
def build_config() -> TaskAppConfig:
    """Assemble the ``TaskAppConfig`` for the GRPO Crafter task app.

    Builds the dataset registry and base task info, then creates the hosted
    task app. It wires the tracer factory (when one is built) and the optional
    SFT output directory into the shared app state. It prints one tracing
    status line, plus an SFT line when an output directory is configured.
    Finally it mounts the hosted environment, policy and branching routers.
    """
    registry, dataset = build_dataset()
    base_info = _base_task_info(dataset)
    hosted_task_app = HostedTaskApp()

    tracing_enabled = tracing_env_enabled()
    tracing_db_url = resolve_tracing_db_url()
    tracer_factory = build_tracer_factory(SessionTracer, enabled=tracing_enabled, db_url=tracing_db_url)
    sft_output_dir = resolve_sft_output_dir()

    app_state: Dict[str, Any] = {
        "task_app": hosted_task_app,
        "allowed_environments": ["crafter"],
        "tracing_enabled": tracing_enabled,
    }
    if tracer_factory is not None:
        app_state["session_tracer_factory"] = tracer_factory

    tracing_status = (
        f"[task:tracing] enabled (db={tracing_db_url or 'default'})"
        if tracing_enabled
        else "[task:tracing] disabled"
    )
    print(tracing_status, flush=True)

    if sft_output_dir:
        app_state["sft_output_dir"] = sft_output_dir
        print(f"[task:sft] writing JSONL to {sft_output_dir}", flush=True)

    def _describe_taskset() -> Dict[str, Any]:
        return describe_taskset(dataset)

    def _provide_instances(seeds: Sequence[int]):
        return provide_task_instances(dataset, base_info, seeds)

    return TaskAppConfig(
        app_id="grpo-crafter",
        name="GRPO Crafter Task App",
        description="Crafter Classic environment with GRPO task endpoints and LLM proxies.",
        base_task_info=base_info,
        describe_taskset=_describe_taskset,
        provide_task_instances=_provide_instances,
        rollout=rollout_executor,
        dataset_registry=registry,
        rubrics=RubricBundle(outcome=OUTCOME_RUBRIC, events=EVENTS_RUBRIC),
        proxy=ProxyConfig(enable_openai=True, enable_groq=True, system_hint=CRAFTING_RULES_SYSTEM_HINT),
        routers=(environment_router, policy_router, branching_router),
        app_state=app_state,
        cors_origins=["*"],
    )
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
# pip requirements for the Modal image. The import-error hint above tells
# users to add any missing package here.
_MODAL_PIP_PACKAGES = (
    "fastapi>=0.100.0",
    "uvicorn>=0.23.0",
    "pydantic>=2.0.0",
    "numpy>=1.24.0",
    "aiohttp>=3.8.0",
    "httpx>=0.24.0",
    "python-dotenv>=1.0.1",
    # Tracing/DB runtime deps
    "sqlalchemy>=2.0.42",
    "aiosqlite>=0.21.0",
    "greenlet>=3.2.3",
    "crafter",
)

# Modal deployment defaults: the synth_ai package and the example task app are
# shipped as local directories under /opt/synth_ai_repo.
_MODAL_DEPLOYMENT = ModalDeploymentConfig(
    app_name="grpo-crafter-task-app",
    python_version="3.11",
    pip_packages=_MODAL_PIP_PACKAGES,
    extra_local_dirs=(
        (str(REPO_ROOT / "synth_ai"), "/opt/synth_ai_repo/synth_ai"),
        (str(TASK_APP_ROOT), "/opt/synth_ai_repo/examples/warming_up_to_rl/task_app"),
    ),
    secret_names=("crafter-environment-sdk", "groq-api-key", "openai-api-key"),
    memory=16384,
    cpu=4.0,
    max_containers=10,
)

register_task_app(
    entry=TaskAppEntry(
        app_id="grpo-crafter",
        description="Crafter Classic task app with rollout + proxy endpoints",
        config_factory=build_config,
        aliases=("crafter", "crafter-task"),
        env_files=(str(REPO_ROOT / "backend" / ".env.dev"),),
        modal=_MODAL_DEPLOYMENT,
    )
)
|