synth-ai 0.2.13.dev2__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/multi_step/configs/README_verilog_rl.md +77 -0
- examples/multi_step/configs/VERILOG_REWARDS.md +90 -0
- examples/multi_step/configs/VERILOG_RL_CHECKLIST.md +183 -0
- examples/multi_step/configs/crafter_eval_synth_qwen4b.toml +35 -0
- examples/multi_step/configs/crafter_eval_text_only_groq_qwen32b.toml +36 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +5 -4
- examples/multi_step/configs/crafter_synth_backend.md +40 -0
- examples/multi_step/configs/verilog_eval_groq_qwen32b.toml +31 -0
- examples/multi_step/configs/verilog_eval_synth_qwen8b.toml +33 -0
- examples/multi_step/configs/verilog_rl_lora.toml +190 -0
- examples/multi_step/judges/crafter_backend_judge.py +220 -0
- examples/multi_step/judges/verilog_backend_judge.py +234 -0
- examples/multi_step/readme.md +48 -0
- examples/multi_step/verilog_rl_lora.md +218 -0
- examples/qwen_coder/configs/coder_lora_30b.toml +1 -1
- examples/sft/evaluate.py +2 -0
- examples/sft/generate_traces.py +2 -0
- examples/swe/task_app/grpo_swe_mini.py +1 -0
- examples/swe/task_app/hosted/rollout.py +2 -0
- examples/task_apps/IMAGE_ONLY_EVAL_QUICKSTART.md +258 -0
- examples/task_apps/crafter/CREATE_SFT_DATASET.md +273 -0
- examples/task_apps/crafter/EVAL_IMAGE_ONLY_RESULTS.md +152 -0
- examples/task_apps/crafter/FILTER_COMMAND_STATUS.md +174 -0
- examples/task_apps/crafter/FILTER_COMMAND_SUCCESS.md +268 -0
- examples/task_apps/crafter/QUERY_EXAMPLES.md +203 -0
- examples/task_apps/crafter/README_IMAGE_ONLY_EVAL.md +316 -0
- examples/task_apps/crafter/eval_image_only_gpt4o.toml +28 -0
- examples/task_apps/crafter/eval_text_only_groq_llama.toml +36 -0
- examples/task_apps/crafter/filter_sft_dataset.toml +16 -0
- examples/task_apps/crafter/task_app/__init__.py +3 -0
- examples/task_apps/crafter/task_app/grpo_crafter.py +306 -8
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/environment.py +10 -0
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/policy.py +16 -3
- examples/task_apps/crafter/task_app/synth_envs_hosted/envs/crafter/react_agent.py +17 -2
- examples/task_apps/crafter/task_app/synth_envs_hosted/inference/openai_client.py +25 -3
- examples/task_apps/crafter/task_app/synth_envs_hosted/policy_routes.py +52 -1
- examples/task_apps/crafter/task_app/synth_envs_hosted/rollout.py +111 -13
- examples/task_apps/crafter/task_app/synth_envs_hosted/utils.py +156 -0
- examples/task_apps/enron/filter_sft.toml +5 -0
- examples/task_apps/enron/tests/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/__init__.py +2 -0
- examples/task_apps/enron/tests/integration/test_enron_eval.py +2 -0
- examples/task_apps/enron/tests/unit/__init__.py +2 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_COMPLETE.md +283 -0
- examples/task_apps/pokemon_red/EVAL_IMAGE_ONLY_STATUS.md +155 -0
- examples/task_apps/pokemon_red/README_IMAGE_ONLY_EVAL.md +415 -0
- examples/task_apps/pokemon_red/eval_image_only_gpt4o.toml +29 -0
- examples/task_apps/pokemon_red/pallet_town_rl_config.toml +2 -0
- examples/task_apps/pokemon_red/task_app.py +199 -6
- examples/task_apps/pokemon_red/test_pallet_town_rewards.py +2 -0
- examples/task_apps/sokoban/filter_sft.toml +5 -0
- examples/task_apps/sokoban/tests/__init__.py +2 -0
- examples/task_apps/sokoban/tests/integration/__init__.py +2 -0
- examples/task_apps/sokoban/tests/unit/__init__.py +2 -0
- examples/task_apps/verilog/eval_groq_qwen32b.toml +8 -4
- examples/task_apps/verilog/filter_sft.toml +5 -0
- examples/task_apps/verilog/task_app/grpo_verilog.py +258 -23
- examples/task_apps/verilog/tests/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/__init__.py +2 -0
- examples/task_apps/verilog/tests/integration/test_verilog_eval.py +2 -0
- examples/task_apps/verilog/tests/unit/__init__.py +2 -0
- examples/warming_up_to_rl/groq_test.py +2 -0
- examples/warming_up_to_rl/run_local_rollout.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +2 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +2 -0
- examples/warming_up_to_rl/run_rollout_remote.py +2 -0
- synth_ai/api/models/supported.py +1 -0
- synth_ai/cli/__init__.py +46 -13
- synth_ai/cli/_modal_wrapper.py +3 -2
- synth_ai/cli/recent.py +1 -1
- synth_ai/cli/status.py +1 -1
- synth_ai/cli/task_apps.py +354 -143
- synth_ai/cli/traces.py +1 -1
- synth_ai/cli/tui.py +57 -0
- synth_ai/cli/turso.py +1 -1
- synth_ai/cli/watch.py +1 -1
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +1 -1
- synth_ai/environments/examples/crafter_classic/environment.py +1 -1
- synth_ai/environments/examples/verilog/engine.py +76 -10
- synth_ai/judge_schemas.py +8 -8
- synth_ai/task/__init__.py +11 -1
- synth_ai/task/apps/__init__.py +1 -0
- synth_ai/task/config.py +257 -0
- synth_ai/task/contracts.py +15 -2
- synth_ai/task/rubrics/__init__.py +3 -0
- synth_ai/task/rubrics/loaders.py +22 -3
- synth_ai/task/rubrics/scoring.py +3 -0
- synth_ai/task/trace_correlation_helpers.py +315 -0
- synth_ai/task/validators.py +144 -0
- synth_ai/tracing_v3/abstractions.py +3 -3
- synth_ai/tracing_v3/llm_call_record_helpers.py +5 -5
- synth_ai/tracing_v3/session_tracer.py +16 -6
- synth_ai/tracing_v3/storage/base.py +29 -29
- synth_ai/tracing_v3/storage/config.py +3 -3
- synth_ai/tracing_v3/turso/daemon.py +8 -7
- synth_ai/tracing_v3/turso/native_manager.py +63 -40
- synth_ai/tracing_v3/utils.py +3 -3
- synth_ai/tui/__init__.py +5 -0
- synth_ai/tui/__main__.py +13 -0
- synth_ai/tui/cli/__init__.py +1 -0
- synth_ai/tui/cli/query_experiments.py +164 -0
- synth_ai/tui/cli/query_experiments_v3.py +164 -0
- synth_ai/tui/dashboard.py +906 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/METADATA +1 -1
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/RECORD +110 -71
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.13.dev2.dist-info → synth_ai-0.2.14.dist-info}/top_level.txt +0 -0
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
import json
|
|
5
6
|
import logging
|
|
6
7
|
import os
|
|
7
8
|
import sys
|
|
@@ -11,11 +12,12 @@ from pathlib import Path
|
|
|
11
12
|
from typing import Any
|
|
12
13
|
|
|
13
14
|
from synth_ai.task.apps import ModalDeploymentConfig, TaskAppEntry, register_task_app
|
|
14
|
-
from synth_ai.task.contracts import RolloutMetrics, RolloutRequest, RolloutResponse, TaskInfo
|
|
15
|
+
from synth_ai.task.contracts import RolloutMetrics, RolloutMode, RolloutRequest, RolloutResponse, TaskInfo
|
|
15
16
|
from synth_ai.task.datasets import TaskDatasetRegistry, TaskDatasetSpec
|
|
16
17
|
from synth_ai.task.json import to_jsonable # noqa: F401 (imported for side-effect compatibility)
|
|
17
18
|
from synth_ai.task.rubrics import load_rubric
|
|
18
19
|
from synth_ai.task.server import ProxyConfig, RubricBundle, TaskAppConfig
|
|
20
|
+
from synth_ai.task.validators import normalize_inference_url
|
|
19
21
|
from synth_ai.task.tracing_utils import (
|
|
20
22
|
build_tracer_factory,
|
|
21
23
|
resolve_sft_output_dir,
|
|
@@ -24,6 +26,18 @@ from synth_ai.task.tracing_utils import (
|
|
|
24
26
|
)
|
|
25
27
|
from synth_ai.tracing_v3.session_tracer import SessionTracer
|
|
26
28
|
|
|
29
|
+
try:
|
|
30
|
+
from .synth_envs_hosted.utils import (
|
|
31
|
+
ensure_chat_completions_url,
|
|
32
|
+
extract_trace_correlation_id,
|
|
33
|
+
)
|
|
34
|
+
except Exception: # pragma: no cover - utils unavailable if optional deps missing
|
|
35
|
+
def ensure_chat_completions_url(raw_url, mode=None):
|
|
36
|
+
"""Fallback to shared utility for URL normalization."""
|
|
37
|
+
return normalize_inference_url(raw_url) if raw_url else raw_url
|
|
38
|
+
|
|
39
|
+
def extract_trace_correlation_id(_raw_url):
|
|
40
|
+
return None
|
|
27
41
|
logger = logging.getLogger(__name__)
|
|
28
42
|
|
|
29
43
|
DEFAULT_ALIAS_OPS: list[str] = ["agent", "env"] * 10
|
|
@@ -95,6 +109,110 @@ SYNTH_ENVS_HOSTED_ROOT = (TASK_APP_ROOT / "synth_envs_hosted").resolve()
|
|
|
95
109
|
EXAMPLES_ROOT = (REPO_ROOT / "examples").resolve()
|
|
96
110
|
RUBRICS_ROOT = (EXAMPLES_ROOT / "multi_step" / "rubrics").resolve()
|
|
97
111
|
|
|
112
|
+
DEFAULT_OUTCOME_RUBRIC_DATA: dict[str, Any] = {
|
|
113
|
+
"version": "1",
|
|
114
|
+
"goal_text": (
|
|
115
|
+
"Reward episodes that climb the Crafter achievement ladder, stockpile key resources "
|
|
116
|
+
"(especially wood), and finish alive with clear understanding of any failure."
|
|
117
|
+
),
|
|
118
|
+
"aggregation": "weighted_sum",
|
|
119
|
+
"criteria": [
|
|
120
|
+
{
|
|
121
|
+
"id": "achievement_progression",
|
|
122
|
+
"description": (
|
|
123
|
+
"Weigh achievements by tier: late-game unlocks (iron tools, furnace, armor) earn "
|
|
124
|
+
"the most, mid-tier crafting (stone tools, furnace prep) gets partial credit, early "
|
|
125
|
+
"tasks (collecting saplings/wood tools) only lightly scored."
|
|
126
|
+
),
|
|
127
|
+
"weight": 0.35,
|
|
128
|
+
},
|
|
129
|
+
{
|
|
130
|
+
"id": "resource_stockpile",
|
|
131
|
+
"description": (
|
|
132
|
+
"Assess resource totals with emphasis on wood stores; high scores require abundant "
|
|
133
|
+
"wood plus supporting materials (stone, coal, iron) that signal readiness for "
|
|
134
|
+
"crafting."
|
|
135
|
+
),
|
|
136
|
+
"weight": 0.2,
|
|
137
|
+
},
|
|
138
|
+
{
|
|
139
|
+
"id": "survival_state",
|
|
140
|
+
"description": (
|
|
141
|
+
"Reward finishing alive with healthy food/drink bars and safe positioning; penalize "
|
|
142
|
+
"deaths, low vitals, or lingering hazards at episode end."
|
|
143
|
+
),
|
|
144
|
+
"weight": 0.2,
|
|
145
|
+
},
|
|
146
|
+
{
|
|
147
|
+
"id": "failure_analysis",
|
|
148
|
+
"description": (
|
|
149
|
+
"If the run ends in death or timeout, clearly identify the cause and deduct unless "
|
|
150
|
+
"the agent mitigated risk; highlight when the agent survives despite danger."
|
|
151
|
+
),
|
|
152
|
+
"weight": 0.15,
|
|
153
|
+
},
|
|
154
|
+
{
|
|
155
|
+
"id": "future_readiness",
|
|
156
|
+
"description": (
|
|
157
|
+
"Describe how prepared the agent is for the next objectives (tools crafted, shelters, "
|
|
158
|
+
"furnaces, smelted materials) and whether the inventory supports further progress."
|
|
159
|
+
),
|
|
160
|
+
"weight": 0.1,
|
|
161
|
+
},
|
|
162
|
+
],
|
|
163
|
+
}
|
|
164
|
+
|
|
165
|
+
DEFAULT_EVENTS_RUBRIC_DATA: dict[str, Any] = {
|
|
166
|
+
"version": "1",
|
|
167
|
+
"goal_text": (
|
|
168
|
+
"Score each decision in proportion to the concrete Crafter achievement progress it "
|
|
169
|
+
"delivers, topping out the scale when the log shows a fresh achievement unlock and keeping "
|
|
170
|
+
"routine upkeep near zero."
|
|
171
|
+
),
|
|
172
|
+
"aggregation": "weighted_sum",
|
|
173
|
+
"criteria": [
|
|
174
|
+
{
|
|
175
|
+
"id": "achievement_unlocks",
|
|
176
|
+
"description": (
|
|
177
|
+
"Assign 0.9-1.0 when the decision explicitly unlocks a new Crafter achievement (look "
|
|
178
|
+
'for "Achievement unlocked" messages or equivalent deterministic completions such as '
|
|
179
|
+
"placing a furnace that immediately crafts ingots). Cap the score at 0.4 when no new "
|
|
180
|
+
"achievement fires, and drop to <=0.1 if the turn repeats known actions without "
|
|
181
|
+
"measurable progress."
|
|
182
|
+
),
|
|
183
|
+
"weight": 0.55,
|
|
184
|
+
},
|
|
185
|
+
{
|
|
186
|
+
"id": "milestone_setup",
|
|
187
|
+
"description": (
|
|
188
|
+
"Give 0.5-0.7 when the action completes the last prerequisite for a specific upcoming "
|
|
189
|
+
"achievement (e.g., gathering the final ore before smelting, crafting sticks right "
|
|
190
|
+
"before a tool). Keep the score <=0.3 if the progress is speculative or still several "
|
|
191
|
+
"steps away."
|
|
192
|
+
),
|
|
193
|
+
"weight": 0.2,
|
|
194
|
+
},
|
|
195
|
+
{
|
|
196
|
+
"id": "inventory_depth",
|
|
197
|
+
"description": (
|
|
198
|
+
"Reward 0.3-0.5 for pulls that clearly deepen critical buffers (fuel, food, ore) and "
|
|
199
|
+
"immediately unblock the next milestone. If resources are already plentiful or the "
|
|
200
|
+
"haul is generic filler, stay at <=0.2."
|
|
201
|
+
),
|
|
202
|
+
"weight": 0.15,
|
|
203
|
+
},
|
|
204
|
+
{
|
|
205
|
+
"id": "execution_quality",
|
|
206
|
+
"description": (
|
|
207
|
+
"Only add up to 0.1 for clean, legal execution that avoids wasted turns; drop to 0.0 "
|
|
208
|
+
"whenever the agent idles, repeats failed moves, or takes damage without compensating "
|
|
209
|
+
"progress."
|
|
210
|
+
),
|
|
211
|
+
"weight": 0.1,
|
|
212
|
+
},
|
|
213
|
+
],
|
|
214
|
+
}
|
|
215
|
+
|
|
98
216
|
for path in (REPO_ROOT, TASK_APP_ROOT, SYNTH_ENVS_HOSTED_ROOT, EXAMPLES_ROOT):
|
|
99
217
|
try:
|
|
100
218
|
resolved = path.resolve()
|
|
@@ -115,6 +233,28 @@ try:
|
|
|
115
233
|
except Exception:
|
|
116
234
|
pass
|
|
117
235
|
|
|
236
|
+
def _load_rubric_with_fallback(filename: str, fallback: dict[str, Any]):
|
|
237
|
+
"""Load rubric from JSON file when available, otherwise use bundled fallback."""
|
|
238
|
+
|
|
239
|
+
search_paths = [RUBRICS_ROOT / filename, TASK_APP_ROOT / "rubrics" / filename]
|
|
240
|
+
for path in search_paths:
|
|
241
|
+
try:
|
|
242
|
+
if path.exists():
|
|
243
|
+
logger.debug("Loading rubric from %s", path)
|
|
244
|
+
return load_rubric(str(path))
|
|
245
|
+
except Exception as exc:
|
|
246
|
+
logger.warning("Failed to load rubric %s from %s: %s", filename, path, exc)
|
|
247
|
+
|
|
248
|
+
logger.warning("Falling back to inline rubric %s: file not available", filename)
|
|
249
|
+
try:
|
|
250
|
+
materialized = search_paths[0]
|
|
251
|
+
materialized.parent.mkdir(parents=True, exist_ok=True)
|
|
252
|
+
materialized.write_text(json.dumps(fallback, indent=2), encoding="utf-8")
|
|
253
|
+
except Exception:
|
|
254
|
+
logger.debug("Unable to materialize inline rubric %s", filename, exc_info=True)
|
|
255
|
+
return load_rubric(fallback)
|
|
256
|
+
|
|
257
|
+
|
|
118
258
|
HAS_HOSTED = True
|
|
119
259
|
try:
|
|
120
260
|
import crafter # type: ignore
|
|
@@ -343,9 +483,13 @@ def _base_task_info(dataset: CrafterDataset) -> TaskInfo:
|
|
|
343
483
|
)
|
|
344
484
|
|
|
345
485
|
|
|
346
|
-
OUTCOME_RUBRIC =
|
|
486
|
+
OUTCOME_RUBRIC = _load_rubric_with_fallback(
|
|
487
|
+
"crafter_outcome_rubric.json", DEFAULT_OUTCOME_RUBRIC_DATA
|
|
488
|
+
)
|
|
347
489
|
|
|
348
|
-
EVENTS_RUBRIC =
|
|
490
|
+
EVENTS_RUBRIC = _load_rubric_with_fallback(
|
|
491
|
+
"crafter_events_rubric.json", DEFAULT_EVENTS_RUBRIC_DATA
|
|
492
|
+
)
|
|
349
493
|
|
|
350
494
|
|
|
351
495
|
def describe_taskset(dataset: CrafterDataset) -> dict[str, Any]:
|
|
@@ -493,9 +637,94 @@ def _coerce_math_to_crafter(request: RolloutRequest) -> RolloutRequest:
|
|
|
493
637
|
return coerced
|
|
494
638
|
|
|
495
639
|
|
|
640
|
+
def _resolve_trace_correlation_id(policy_cfg: dict[str, Any], mode: Any = None) -> str | None:
|
|
641
|
+
"""Best-effort extraction of the trace correlation identifier."""
|
|
642
|
+
candidates: list[Any] = [
|
|
643
|
+
policy_cfg.get("trace_correlation_id"),
|
|
644
|
+
policy_cfg.get("trace"),
|
|
645
|
+
]
|
|
646
|
+
logger.debug(
|
|
647
|
+
"_resolve_trace_correlation_id: inspecting policy_cfg keys=%s candidates=%s",
|
|
648
|
+
sorted(policy_cfg.keys()),
|
|
649
|
+
candidates,
|
|
650
|
+
)
|
|
651
|
+
for candidate in candidates:
|
|
652
|
+
if isinstance(candidate, str):
|
|
653
|
+
stripped = candidate.strip()
|
|
654
|
+
if stripped:
|
|
655
|
+
return stripped
|
|
656
|
+
|
|
657
|
+
return extract_trace_correlation_id(policy_cfg.get("inference_url"), mode=mode)
|
|
658
|
+
|
|
659
|
+
|
|
496
660
|
async def rollout_executor(request: RolloutRequest, fastapi_request) -> RolloutResponse:
|
|
661
|
+
request = _coerce_math_to_crafter(request)
|
|
662
|
+
|
|
663
|
+
policy_cfg = dict(request.policy.config or {})
|
|
664
|
+
logger.info(
|
|
665
|
+
"ROLLOUT_EXEC: incoming policy config keys=%s inference_url=%s run_id=%s mode=%s",
|
|
666
|
+
sorted(policy_cfg.keys()),
|
|
667
|
+
policy_cfg.get("inference_url"),
|
|
668
|
+
request.run_id,
|
|
669
|
+
request.mode,
|
|
670
|
+
)
|
|
671
|
+
inferred_url = ensure_chat_completions_url(policy_cfg.get("inference_url"), mode=request.mode)
|
|
672
|
+
if isinstance(inferred_url, str) and inferred_url:
|
|
673
|
+
if inferred_url != policy_cfg.get("inference_url"):
|
|
674
|
+
logger.warning(
|
|
675
|
+
"ROLLOUT_EXEC: normalized inference_url run_id=%s from %s to %s",
|
|
676
|
+
request.run_id,
|
|
677
|
+
policy_cfg.get("inference_url"),
|
|
678
|
+
inferred_url,
|
|
679
|
+
)
|
|
680
|
+
policy_cfg["inference_url"] = inferred_url
|
|
681
|
+
else:
|
|
682
|
+
logger.warning(
|
|
683
|
+
"ROLLOUT_EXEC: inference_url missing or not normalized run_id=%s raw=%s",
|
|
684
|
+
request.run_id,
|
|
685
|
+
policy_cfg.get("inference_url"),
|
|
686
|
+
)
|
|
687
|
+
|
|
688
|
+
trace_correlation_id = _resolve_trace_correlation_id(policy_cfg, mode=request.mode)
|
|
689
|
+
|
|
690
|
+
# ASSERTION: trace_correlation_id MUST be present for RL mode (but not EVAL mode)
|
|
691
|
+
if request.mode == RolloutMode.RL:
|
|
692
|
+
assert trace_correlation_id is not None, (
|
|
693
|
+
f"FATAL: trace_correlation_id extraction failed for run_id={request.run_id}. "
|
|
694
|
+
f"policy_cfg_keys={sorted(policy_cfg.keys())} "
|
|
695
|
+
f"inference_url={policy_cfg.get('inference_url')}"
|
|
696
|
+
)
|
|
697
|
+
assert isinstance(trace_correlation_id, str) and trace_correlation_id.strip(), (
|
|
698
|
+
f"FATAL: trace_correlation_id is empty for run_id={request.run_id}. "
|
|
699
|
+
f"Got: {trace_correlation_id!r}"
|
|
700
|
+
)
|
|
701
|
+
|
|
702
|
+
if trace_correlation_id:
|
|
703
|
+
policy_cfg["trace_correlation_id"] = trace_correlation_id
|
|
704
|
+
logger.info(
|
|
705
|
+
"ROLLOUT_EXEC: resolved trace_correlation_id=%s run_id=%s",
|
|
706
|
+
trace_correlation_id,
|
|
707
|
+
request.run_id,
|
|
708
|
+
)
|
|
709
|
+
|
|
710
|
+
pipeline_metadata: dict[str, Any] = {}
|
|
711
|
+
if trace_correlation_id:
|
|
712
|
+
pipeline_metadata["trace_correlation_id"] = trace_correlation_id
|
|
713
|
+
if isinstance(policy_cfg.get("inference_url"), str) and policy_cfg["inference_url"]:
|
|
714
|
+
pipeline_metadata.setdefault("inference_url", policy_cfg["inference_url"])
|
|
715
|
+
logger.info(
|
|
716
|
+
"ROLLOUT_EXEC: pipeline metadata prepared run_id=%s metadata=%s",
|
|
717
|
+
request.run_id,
|
|
718
|
+
pipeline_metadata,
|
|
719
|
+
)
|
|
720
|
+
|
|
497
721
|
# If hosted env service code is not bundled, return a no-op rollout response compatible with contracts
|
|
498
722
|
if not HAS_HOSTED:
|
|
723
|
+
logger.warning(
|
|
724
|
+
"ROLLOUT_EXEC: HAS_HOSTED disabled, returning stub response run_id=%s metadata=%s",
|
|
725
|
+
request.run_id,
|
|
726
|
+
pipeline_metadata,
|
|
727
|
+
)
|
|
499
728
|
return RolloutResponse(
|
|
500
729
|
run_id=request.run_id,
|
|
501
730
|
trajectories=[],
|
|
@@ -510,11 +739,10 @@ async def rollout_executor(request: RolloutRequest, fastapi_request) -> RolloutR
|
|
|
510
739
|
aborted=False,
|
|
511
740
|
ops_executed=0,
|
|
512
741
|
trace=None,
|
|
742
|
+
trace_correlation_id=trace_correlation_id or f"trace_{request.run_id}",
|
|
743
|
+
pipeline_metadata=pipeline_metadata,
|
|
513
744
|
)
|
|
514
745
|
|
|
515
|
-
request = _coerce_math_to_crafter(request)
|
|
516
|
-
|
|
517
|
-
policy_cfg = dict(request.policy.config or {})
|
|
518
746
|
try:
|
|
519
747
|
max_llm_calls = int(policy_cfg.get("max_llm_calls") or 10)
|
|
520
748
|
except Exception:
|
|
@@ -545,6 +773,7 @@ async def rollout_executor(request: RolloutRequest, fastapi_request) -> RolloutR
|
|
|
545
773
|
converted_ops = converted_ops[:max_ops_allowed]
|
|
546
774
|
legacy_request = LegacyRolloutRequest(
|
|
547
775
|
run_id=request.run_id,
|
|
776
|
+
mode=request.mode, # Preserve mode for nested requests
|
|
548
777
|
env=LegacyRolloutEnvSpec(
|
|
549
778
|
env_id=request.env.env_id,
|
|
550
779
|
env_name=request.env.env_name,
|
|
@@ -568,12 +797,79 @@ async def rollout_executor(request: RolloutRequest, fastapi_request) -> RolloutR
|
|
|
568
797
|
legacy_response: LegacyRolloutResponse = await legacy_execute_rollout(
|
|
569
798
|
legacy_request, fastapi_request
|
|
570
799
|
)
|
|
800
|
+
logger.info(
|
|
801
|
+
"ROLLOUT_EXEC: legacy rollout completed run_id=%s trace_id=%s",
|
|
802
|
+
request.run_id,
|
|
803
|
+
trace_correlation_id,
|
|
804
|
+
)
|
|
571
805
|
data = legacy_response.model_dump()
|
|
572
806
|
metrics = data.get("metrics", {}) or {}
|
|
573
807
|
metrics.setdefault("outcome_score", None)
|
|
574
808
|
metrics.setdefault("events_score", None)
|
|
575
809
|
metrics.setdefault("details", {})
|
|
576
810
|
data["metrics"] = metrics
|
|
811
|
+
|
|
812
|
+
# Add trace_correlation_id at TOP-LEVEL (REQUIRED for RL training pipeline)
|
|
813
|
+
# Use fallback if somehow missing
|
|
814
|
+
data["trace_correlation_id"] = trace_correlation_id or f"trace_{request.run_id}"
|
|
815
|
+
|
|
816
|
+
# Add trace_correlation_id to pipeline_metadata
|
|
817
|
+
existing_meta = data.get("pipeline_metadata")
|
|
818
|
+
if not isinstance(existing_meta, dict):
|
|
819
|
+
existing_meta = {}
|
|
820
|
+
# ALWAYS set trace_correlation_id (use fallback if needed)
|
|
821
|
+
final_cid = trace_correlation_id or f"trace_{request.run_id}"
|
|
822
|
+
existing_meta["trace_correlation_id"] = final_cid
|
|
823
|
+
if isinstance(policy_cfg.get("inference_url"), str) and policy_cfg["inference_url"]:
|
|
824
|
+
existing_meta.setdefault("inference_url", policy_cfg["inference_url"])
|
|
825
|
+
data["pipeline_metadata"] = existing_meta
|
|
826
|
+
|
|
827
|
+
# Add trace_correlation_id to each trajectory (required for RL training pipeline)
|
|
828
|
+
if "trajectories" in data:
|
|
829
|
+
for traj in data.get("trajectories", []):
|
|
830
|
+
if isinstance(traj, dict):
|
|
831
|
+
traj["trace_correlation_id"] = final_cid
|
|
832
|
+
logger.info(
|
|
833
|
+
"ROLLOUT_EXEC: final pipeline metadata run_id=%s metadata=%s",
|
|
834
|
+
request.run_id,
|
|
835
|
+
existing_meta,
|
|
836
|
+
)
|
|
837
|
+
if trace_correlation_id and existing_meta.get("trace_correlation_id") != trace_correlation_id:
|
|
838
|
+
logger.error(
|
|
839
|
+
"ROLLOUT_EXEC: metadata trace mismatch run_id=%s expected=%s actual=%s",
|
|
840
|
+
request.run_id,
|
|
841
|
+
trace_correlation_id,
|
|
842
|
+
existing_meta.get("trace_correlation_id"),
|
|
843
|
+
)
|
|
844
|
+
if not existing_meta.get("trace_correlation_id"):
|
|
845
|
+
logger.error(
|
|
846
|
+
"ROLLOUT_EXEC: final metadata missing trace_correlation_id run_id=%s metadata=%s",
|
|
847
|
+
request.run_id,
|
|
848
|
+
existing_meta,
|
|
849
|
+
)
|
|
850
|
+
|
|
851
|
+
# ASSERTION: Verify trace_correlation_id is present in response at all required levels
|
|
852
|
+
assert "trace_correlation_id" in data, (
|
|
853
|
+
f"FATAL: trace_correlation_id missing from top-level response data for run_id={request.run_id}. "
|
|
854
|
+
f"Keys: {list(data.keys())}"
|
|
855
|
+
)
|
|
856
|
+
assert data["trace_correlation_id"] == final_cid, (
|
|
857
|
+
f"FATAL: trace_correlation_id mismatch in response for run_id={request.run_id}. "
|
|
858
|
+
f"Expected: {final_cid!r}, Got: {data.get('trace_correlation_id')!r}"
|
|
859
|
+
)
|
|
860
|
+
assert "pipeline_metadata" in data, (
|
|
861
|
+
f"FATAL: pipeline_metadata missing from response for run_id={request.run_id}"
|
|
862
|
+
)
|
|
863
|
+
assert data["pipeline_metadata"].get("trace_correlation_id") == final_cid, (
|
|
864
|
+
f"FATAL: trace_correlation_id missing or mismatched in pipeline_metadata for run_id={request.run_id}. "
|
|
865
|
+
f"Expected: {final_cid!r}, Got: {data['pipeline_metadata'].get('trace_correlation_id')!r}"
|
|
866
|
+
)
|
|
867
|
+
logger.info(
|
|
868
|
+
"ROLLOUT_EXEC: assertions passed - trace_correlation_id present in response run_id=%s cid=%s",
|
|
869
|
+
request.run_id,
|
|
870
|
+
final_cid,
|
|
871
|
+
)
|
|
872
|
+
|
|
577
873
|
return RolloutResponse.model_validate(data)
|
|
578
874
|
|
|
579
875
|
|
|
@@ -617,7 +913,7 @@ def build_config() -> TaskAppConfig:
|
|
|
617
913
|
routers: tuple = (environment_router, policy_router, branching_router) if HAS_HOSTED else ()
|
|
618
914
|
|
|
619
915
|
config = TaskAppConfig(
|
|
620
|
-
app_id="grpo-crafter",
|
|
916
|
+
app_id="grpo-crafter-task-app",
|
|
621
917
|
name="GRPO Crafter Task App",
|
|
622
918
|
description="Crafter Classic environment with GRPO task endpoints and LLM proxies.",
|
|
623
919
|
base_task_info=base_info,
|
|
@@ -638,7 +934,7 @@ def build_config() -> TaskAppConfig:
|
|
|
638
934
|
|
|
639
935
|
register_task_app(
|
|
640
936
|
entry=TaskAppEntry(
|
|
641
|
-
app_id="grpo-crafter",
|
|
937
|
+
app_id="grpo-crafter-task-app",
|
|
642
938
|
description="Crafter Classic task app with rollout + proxy endpoints",
|
|
643
939
|
config_factory=build_config,
|
|
644
940
|
aliases=("crafter", "crafter-task"),
|
|
@@ -665,6 +961,8 @@ register_task_app(
|
|
|
665
961
|
(str(REPO_ROOT), "/opt/synth_ai_repo"),
|
|
666
962
|
(str(REPO_ROOT / "synth_ai"), "/opt/synth_ai_repo/synth_ai"),
|
|
667
963
|
(str(TASK_APP_ROOT), "/opt/synth_ai_repo/examples/task_apps/crafter/task_app"),
|
|
964
|
+
# Explicitly mount rubrics directory
|
|
965
|
+
(str(RUBRICS_ROOT), "/opt/synth_ai_repo/examples/multi_step/rubrics"),
|
|
668
966
|
),
|
|
669
967
|
secret_names=("groq-api-key", "openai-api-key"),
|
|
670
968
|
memory=16384,
|
|
@@ -209,6 +209,16 @@ class CrafterEnvironmentWrapper:
|
|
|
209
209
|
logger.info("No valid actions provided, defaulting to noop")
|
|
210
210
|
normalized.append(EnvToolCall(tool="interact", args={"action": 0})) # noop action
|
|
211
211
|
|
|
212
|
+
# Limit to first 20 actions to prevent spam from overly long tool calls
|
|
213
|
+
MAX_ACTIONS_PER_STEP = 20
|
|
214
|
+
if len(normalized) > MAX_ACTIONS_PER_STEP:
|
|
215
|
+
logger.warning(
|
|
216
|
+
"Tool call contained %d actions, limiting to first %d to prevent spam",
|
|
217
|
+
len(normalized),
|
|
218
|
+
MAX_ACTIONS_PER_STEP,
|
|
219
|
+
)
|
|
220
|
+
normalized = normalized[:MAX_ACTIONS_PER_STEP]
|
|
221
|
+
|
|
212
222
|
# Pre-step logging: capture current public state and print concise summary
|
|
213
223
|
before_state: dict[str, Any] | None = None
|
|
214
224
|
try:
|
|
@@ -45,6 +45,7 @@ class CrafterPolicy(Policy):
|
|
|
45
45
|
self.model = model
|
|
46
46
|
self.use_tools = True
|
|
47
47
|
self.use_vision = False # Enable vision for VLMs
|
|
48
|
+
self.image_only_mode = False # If True, only send images without text observations
|
|
48
49
|
# Sampling parameters (populated via initialize(config))
|
|
49
50
|
self.temperature: float | None = None
|
|
50
51
|
self.top_p: float | None = None
|
|
@@ -66,6 +67,11 @@ class CrafterPolicy(Policy):
|
|
|
66
67
|
self.use_tools = bool(config["use_tools"])
|
|
67
68
|
if "use_vision" in config:
|
|
68
69
|
self.use_vision = bool(config["use_vision"])
|
|
70
|
+
if "image_only_mode" in config:
|
|
71
|
+
self.image_only_mode = bool(config["image_only_mode"])
|
|
72
|
+
# If image_only_mode is enabled, automatically enable vision
|
|
73
|
+
if self.image_only_mode:
|
|
74
|
+
self.use_vision = True
|
|
69
75
|
# Auto-detect vision capability from model name if not explicitly set
|
|
70
76
|
if "use_vision" not in config and self.model:
|
|
71
77
|
self.use_vision = self._is_vision_model(self.model)
|
|
@@ -417,14 +423,21 @@ class CrafterPolicy(Policy):
|
|
|
417
423
|
"""Prepare an inference request (implementing abstract method)."""
|
|
418
424
|
# Format observation with rich contextual information
|
|
419
425
|
observation_text = self._format_observation_for_llm(observation)
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
426
|
+
|
|
427
|
+
# Extract image parts based on vision settings
|
|
428
|
+
if self.use_vision:
|
|
429
|
+
image_parts = self._extract_image_parts(observation)
|
|
430
|
+
else:
|
|
431
|
+
# Text-only mode: don't include any images
|
|
432
|
+
image_parts = []
|
|
433
|
+
|
|
434
|
+
# Build messages with appropriate mode
|
|
423
435
|
messages = CrafterReActAgent.build_messages(
|
|
424
436
|
observation=observation_text,
|
|
425
437
|
history=history,
|
|
426
438
|
turn=self.turn_index,
|
|
427
439
|
image_parts=image_parts,
|
|
440
|
+
image_only_mode=self.image_only_mode,
|
|
428
441
|
)
|
|
429
442
|
|
|
430
443
|
# Return messages and tools schema
|
|
@@ -85,8 +85,17 @@ class CrafterReActAgent:
|
|
|
85
85
|
history: list[dict[str, Any]] | None = None,
|
|
86
86
|
turn: int | None = None,
|
|
87
87
|
image_parts: list[dict[str, Any]] | None = None,
|
|
88
|
+
image_only_mode: bool = False,
|
|
88
89
|
) -> list[dict[str, Any]]:
|
|
89
|
-
"""Construct OpenAI-style messages list for vLLM generation.
|
|
90
|
+
"""Construct OpenAI-style messages list for vLLM generation.
|
|
91
|
+
|
|
92
|
+
Args:
|
|
93
|
+
observation: Text observation to include
|
|
94
|
+
history: Previous conversation history
|
|
95
|
+
turn: Current turn number
|
|
96
|
+
image_parts: Image content parts in OpenAI format
|
|
97
|
+
image_only_mode: If True, only include images without text observation
|
|
98
|
+
"""
|
|
90
99
|
msgs: list[dict[str, Any]] = [
|
|
91
100
|
{"role": "system", "content": CrafterReActAgent.get_system_prompt()}
|
|
92
101
|
]
|
|
@@ -94,8 +103,14 @@ class CrafterReActAgent:
|
|
|
94
103
|
msgs.extend(history)
|
|
95
104
|
user_content: Any
|
|
96
105
|
if image_parts:
|
|
97
|
-
|
|
106
|
+
# Image-only mode: send only images without text observation
|
|
107
|
+
if image_only_mode:
|
|
108
|
+
user_content = list(image_parts)
|
|
109
|
+
else:
|
|
110
|
+
# Normal vision mode: send both text and images
|
|
111
|
+
user_content = [{"type": "text", "text": observation}] + list(image_parts)
|
|
98
112
|
else:
|
|
113
|
+
# Text-only mode (default): no images
|
|
99
114
|
user_content = observation
|
|
100
115
|
msgs.append({"role": "user", "content": user_content})
|
|
101
116
|
return msgs
|
|
@@ -149,7 +149,11 @@ class OpenAIClient:
|
|
|
149
149
|
OpenAI-compatible chat completion response
|
|
150
150
|
"""
|
|
151
151
|
base = (base_url or self.base_url).rstrip("/")
|
|
152
|
-
|
|
152
|
+
# Don't append /v1/chat/completions if the URL already contains it
|
|
153
|
+
if "/v1/chat/completions" in base:
|
|
154
|
+
url = base
|
|
155
|
+
else:
|
|
156
|
+
url = base + "/v1/chat/completions"
|
|
153
157
|
timeout = timeout_s or self.timeout_s
|
|
154
158
|
|
|
155
159
|
# Merge headers
|
|
@@ -164,10 +168,28 @@ class OpenAIClient:
|
|
|
164
168
|
except Exception:
|
|
165
169
|
pass
|
|
166
170
|
|
|
167
|
-
#
|
|
171
|
+
# Set Authorization header based on the target URL
|
|
168
172
|
try:
|
|
169
173
|
low_url = (url or "").lower()
|
|
170
|
-
|
|
174
|
+
|
|
175
|
+
# If calling OpenAI directly (api.openai.com)
|
|
176
|
+
if "api.openai.com" in low_url:
|
|
177
|
+
openai_key = os.getenv("OPENAI_API_KEY")
|
|
178
|
+
if openai_key and isinstance(openai_key, str):
|
|
179
|
+
headers["Authorization"] = f"Bearer {openai_key}"
|
|
180
|
+
|
|
181
|
+
# If target is Synth backend (any deployment), use SYNTH_API_KEY
|
|
182
|
+
# Matches: synth-backend-*, agent-learning*, localhost:8000, 127.0.0.1:8000
|
|
183
|
+
elif any(pattern in low_url for pattern in [
|
|
184
|
+
"synth-backend", "synth.run", "agent-learning",
|
|
185
|
+
"localhost:8000", "127.0.0.1:8000"
|
|
186
|
+
]):
|
|
187
|
+
synth_key = os.getenv("SYNTH_API_KEY")
|
|
188
|
+
if synth_key and isinstance(synth_key, str):
|
|
189
|
+
headers["Authorization"] = f"Bearer {synth_key}"
|
|
190
|
+
|
|
191
|
+
# If target is Groq, use GROQ_API_KEY
|
|
192
|
+
elif "/proxy/groq" in low_url or "api.groq.com" in low_url:
|
|
171
193
|
gk = os.getenv("GROQ_API_KEY")
|
|
172
194
|
if gk and isinstance(gk, str):
|
|
173
195
|
headers["Authorization"] = f"Bearer {gk}"
|
|
@@ -10,11 +10,13 @@ from fastapi import APIRouter, HTTPException, Request
|
|
|
10
10
|
from pydantic import BaseModel
|
|
11
11
|
|
|
12
12
|
from synth_ai.task.auth import allowed_environment_api_keys, normalize_environment_api_key
|
|
13
|
+
from synth_ai.task.contracts import RolloutMode
|
|
13
14
|
|
|
14
15
|
from .envs.crafter.policy import CrafterPolicy
|
|
15
16
|
from .inference.openai_client import create_inference_client
|
|
16
17
|
from .registry import registry
|
|
17
18
|
from .storage.volume import storage
|
|
19
|
+
from .utils import ensure_chat_completions_url
|
|
18
20
|
|
|
19
21
|
# Token budgeting (shared logic with inference server)
|
|
20
22
|
try:
|
|
@@ -40,6 +42,7 @@ class PolicyCreateRequest(BaseModel):
|
|
|
40
42
|
parent_policy_id: str | None = None
|
|
41
43
|
rl_run_id: str
|
|
42
44
|
bound_env_id: str | None = None
|
|
45
|
+
mode: RolloutMode
|
|
43
46
|
|
|
44
47
|
|
|
45
48
|
class PolicyCreateResponse(BaseModel):
|
|
@@ -119,6 +122,14 @@ async def create_policy(
|
|
|
119
122
|
config.setdefault("inference_url", f"{base_url}/proxy")
|
|
120
123
|
config["provider"] = "openai"
|
|
121
124
|
|
|
125
|
+
received_url = config.get("inference_url")
|
|
126
|
+
logger.info(
|
|
127
|
+
"POLICY_CREATE: policy=%s provider=%s raw_inference_url=%s",
|
|
128
|
+
request.policy_name,
|
|
129
|
+
provider,
|
|
130
|
+
received_url,
|
|
131
|
+
)
|
|
132
|
+
|
|
122
133
|
if "inference_url" not in config and task_app is not None:
|
|
123
134
|
task_base_url = getattr(task_app, "vllm_base_url", None)
|
|
124
135
|
if task_base_url:
|
|
@@ -133,6 +144,31 @@ async def create_policy(
|
|
|
133
144
|
detail="Policy configuration must include 'inference_url' and 'model'.",
|
|
134
145
|
)
|
|
135
146
|
|
|
147
|
+
# Get mode from PolicyCreateRequest (defaults to "rl" for backward compatibility)
|
|
148
|
+
mode = request.mode
|
|
149
|
+
logger.info("POLICY_CREATE: Using mode=%s for URL processing", mode)
|
|
150
|
+
|
|
151
|
+
sanitized_url = ensure_chat_completions_url(config.get("inference_url"), mode=mode)
|
|
152
|
+
if isinstance(sanitized_url, str) and sanitized_url:
|
|
153
|
+
if sanitized_url != config.get("inference_url"):
|
|
154
|
+
logger.warning(
|
|
155
|
+
"POLICY_CREATE: normalized inference_url for policy=%s provider=%s mode=%s from %s to %s",
|
|
156
|
+
request.policy_name,
|
|
157
|
+
provider,
|
|
158
|
+
mode,
|
|
159
|
+
config.get("inference_url"),
|
|
160
|
+
sanitized_url,
|
|
161
|
+
)
|
|
162
|
+
config["inference_url"] = sanitized_url
|
|
163
|
+
else:
|
|
164
|
+
logger.warning(
|
|
165
|
+
"POLICY_CREATE: unable to normalize inference_url for policy=%s provider=%s mode=%s raw=%s",
|
|
166
|
+
request.policy_name,
|
|
167
|
+
mode,
|
|
168
|
+
provider,
|
|
169
|
+
config.get("inference_url"),
|
|
170
|
+
)
|
|
171
|
+
|
|
136
172
|
# Create policy instance based on name
|
|
137
173
|
pname = request.policy_name.lower()
|
|
138
174
|
if pname in ["crafter-react", "crafter"]:
|
|
@@ -507,7 +543,22 @@ async def step_policy(
|
|
|
507
543
|
|
|
508
544
|
# Ensure meta carries the final target URL for downstream logging/clients
|
|
509
545
|
with contextlib.suppress(Exception):
|
|
510
|
-
|
|
546
|
+
sanitized_target = ensure_chat_completions_url(target_url)
|
|
547
|
+
if sanitized_target and sanitized_target != target_url:
|
|
548
|
+
logger.warning(
|
|
549
|
+
"POLICY_STEP: normalized inference_url mid-flight policy=%s from %s to %s",
|
|
550
|
+
policy_name,
|
|
551
|
+
target_url,
|
|
552
|
+
sanitized_target,
|
|
553
|
+
)
|
|
554
|
+
elif not sanitized_target:
|
|
555
|
+
logger.info(
|
|
556
|
+
"POLICY_STEP: inference_url unchanged policy=%s target=%s",
|
|
557
|
+
policy_name,
|
|
558
|
+
target_url,
|
|
559
|
+
)
|
|
560
|
+
meta["inference_url"] = sanitized_target if sanitized_target else target_url
|
|
561
|
+
target_url = sanitized_target or target_url
|
|
511
562
|
|
|
512
563
|
# Select API key based on resolved target URL
|
|
513
564
|
api_key_override = None
|