synth-ai 0.2.9.dev2__py3-none-any.whl → 0.2.9.dev4__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
Potentially problematic release: this version of synth-ai has been flagged as potentially problematic.
- examples/analyze_semantic_words.sh +17 -0
- examples/common_old/backend.py +21 -0
- examples/crafter_debug_render.py +180 -0
- examples/evals_old/README.md +98 -0
- examples/evals_old/__init__.py +6 -0
- examples/evals_old/compare_models.py +1037 -0
- examples/evals_old/example_log.md +145 -0
- examples/evals_old/run_demo.sh +126 -0
- examples/evals_old/trace_analysis.py +270 -0
- examples/finetuning_old/_backup_synth_qwen/config.toml +29 -0
- examples/finetuning_old/_backup_synth_qwen/example_log.md +324 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +60 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +239 -0
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +109 -0
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +1924 -0
- examples/finetuning_old/_backup_synth_qwen/readme.md +49 -0
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +114 -0
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +195 -0
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +118 -0
- examples/finetuning_old/synth_qwen_v1/README.md +68 -0
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +60 -0
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +239 -0
- examples/finetuning_old/synth_qwen_v1/finetune.py +46 -0
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +71 -0
- examples/finetuning_old/synth_qwen_v1/infer.py +37 -0
- examples/finetuning_old/synth_qwen_v1/poll.py +44 -0
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +35 -0
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +109 -0
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +1932 -0
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +207 -0
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +232 -0
- examples/finetuning_old/synth_qwen_v1/upload_data.py +34 -0
- examples/finetuning_old/synth_qwen_v1/util.py +147 -0
- examples/rl/README.md +169 -0
- examples/rl/configs/eval_base_qwen.toml +15 -0
- examples/rl/configs/eval_rl_qwen.toml +11 -0
- examples/rl/configs/rl_from_base_qwen.toml +35 -0
- examples/rl/configs/rl_from_base_qwen17.toml +74 -0
- examples/rl/configs/rl_from_ft_qwen.toml +35 -0
- examples/rl/download_dataset.py +64 -0
- examples/rl/run_eval.py +435 -0
- examples/rl/run_rl_and_save.py +94 -0
- examples/rl/task_app/README.md +22 -0
- {synth_ai/task/apps → examples/rl/task_app}/math_single_step.py +8 -8
- examples/rl/task_app/math_task_app.py +107 -0
- examples/rl_old/task_app.py +962 -0
- examples/run_crafter_demo.sh +10 -0
- examples/warming_up_to_rl/analyze_trace_db.py +420 -0
- examples/warming_up_to_rl/configs/crafter_fft.toml +48 -0
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +54 -0
- examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +20 -0
- examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +13 -0
- examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +23 -0
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +73 -0
- examples/warming_up_to_rl/configs/rl_from_ft.toml +56 -0
- examples/warming_up_to_rl/export_trace_sft.py +541 -0
- examples/warming_up_to_rl/groq_test.py +88 -0
- examples/warming_up_to_rl/manage_secrets.py +127 -0
- examples/warming_up_to_rl/old/event_rewards.md +234 -0
- examples/warming_up_to_rl/old/notes.md +73 -0
- examples/warming_up_to_rl/readme.md +172 -0
- examples/warming_up_to_rl/run_eval.py +434 -0
- examples/warming_up_to_rl/run_fft_and_save.py +309 -0
- examples/warming_up_to_rl/run_local_rollout.py +188 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +160 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +342 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +372 -0
- examples/warming_up_to_rl/run_rl_and_save.py +101 -0
- examples/warming_up_to_rl/run_rollout_remote.py +129 -0
- examples/warming_up_to_rl/task_app/README.md +38 -0
- {synth_ai/task/apps → examples/warming_up_to_rl/task_app}/grpo_crafter.py +7 -7
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +165 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +145 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1271 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +429 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +442 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +96 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +302 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +202 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +512 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +102 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +985 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +197 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1749 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +217 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +160 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +146 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_stepwise_rewards.py +58 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +61 -0
- synth_ai/api/train/config_finder.py +18 -18
- synth_ai/api/train/env_resolver.py +28 -1
- synth_ai/cli/task_apps.py +264 -55
- synth_ai/demo_registry.py +7 -7
- synth_ai/demos/demo_task_apps/crafter/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +54 -0
- synth_ai/demos/demo_task_apps/crafter/configs/rl_from_base_qwen4b.toml +73 -0
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +165 -0
- synth_ai/task/apps/__init__.py +54 -13
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/METADATA +1 -1
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/RECORD +112 -13
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/top_level.txt +1 -0
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/licenses/LICENSE +0 -0
examples/warming_up_to_rl/run_local_rollout_traced.py (new file, @@ -0,0 +1,372 @@):

```python
#!/usr/bin/env python3
"""Run a local Crafter rollout, capture tracing metadata, and optionally persist the trace."""

from __future__ import annotations

import argparse
import asyncio
import json
from pathlib import Path
from typing import Any

import sys

import httpx

from synth_ai.task import (
    RolloutEnvSpec,
    RolloutPolicySpec,
    RolloutRecordConfig,
    RolloutRequest,
    RolloutSafetyConfig,
    TaskAppClient,
)


def build_rollout_request(
    *,
    seed: int,
    run_id: str,
    model: str,
    inference_url: str,
    ops: list[str],
    return_trace: bool,
    trace_format: str,
    max_policy_tokens: int | None,
) -> RolloutRequest:
    policy_config = {
        "model": model,
        "inference_url": inference_url,
    }
    if max_policy_tokens is not None:
        policy_config.update(
            {
                "max_completion_tokens": max_policy_tokens,
                "max_tokens": max_policy_tokens,
            }
        )

    record = RolloutRecordConfig(
        trajectories=True,
        return_trace=return_trace,
        trace_format=trace_format,
    )

    return RolloutRequest(
        run_id=run_id,
        env=RolloutEnvSpec(env_name="crafter", seed=seed, config={}),
        policy=RolloutPolicySpec(policy_name="crafter-react", config=policy_config),
        ops=ops,
        record=record,
        on_done="reset",
        safety=RolloutSafetyConfig(),
    )


def summarise_rollout(response: Any) -> dict[str, Any]:
    metrics = response.metrics.model_dump() if hasattr(response, "metrics") else response.get("metrics", {})
    return {
        "run_id": getattr(response, "run_id", None) or response.get("run_id"),
        "num_episodes": metrics.get("num_episodes"),
        "num_steps": metrics.get("num_steps"),
        "episode_returns": metrics.get("episode_returns"),
        "outcome_score": metrics.get("outcome_score"),
        "events_score": metrics.get("events_score"),
    }


def summarise_trace(trace: Any) -> dict[str, Any]:
    if trace is None:
        return {"trace": None}
    if not isinstance(trace, dict):
        return {"trace_type": type(trace).__name__}

    format_hint = "compact" if "events_count" in trace or "lm_calls" in trace else "full"
    events_count = trace.get("events_count")
    if events_count is None and "event_history" in trace and isinstance(trace["event_history"], list):
        events_count = len(trace["event_history"])
    messages_count = trace.get("messages_count")
    if messages_count is None and "markov_blanket_message_history" in trace and isinstance(
        trace["markov_blanket_message_history"], list
    ):
        messages_count = len(trace["markov_blanket_message_history"])

    metadata = trace.get("metadata") if isinstance(trace.get("metadata"), dict) else {}
    lm_calls = trace.get("lm_calls") if isinstance(trace.get("lm_calls"), list) else []
    decision_rewards = trace.get("decision_rewards") if isinstance(trace.get("decision_rewards"), list) else []

    return {
        "session_id": trace.get("session_id"),
        "format": format_hint,
        "events_count": events_count,
        "messages_count": messages_count,
        "metadata_keys": sorted(metadata.keys()),
        "lm_calls_count": len(lm_calls),
        "decision_turns": len(decision_rewards),
    }


def ensure_ops(ops_arg: str | None, max_llm_calls: int) -> list[str]:
    if ops_arg:
        ops = [op.strip() for op in ops_arg.split(",") if op.strip()]
        if not ops:
            raise ValueError("--ops must contain at least one entry when provided")
        return ops
    max_llm_calls = max(max_llm_calls, 1)
    ops: list[str] = []
    for _ in range(max_llm_calls):
        ops.extend(["agent", "env"])
    return ops


def dump_trace(trace: dict[str, Any], *, path: Path, pretty: bool) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as fh:
        json.dump(trace, fh, indent=2 if pretty else None)
        fh.write("\n")


def extract_environment_rewards(trace_payload: dict[str, Any] | None) -> list[float]:
    if not trace_payload:
        return []

    rewards: list[float] = []

    def _collect(events: list[dict[str, Any]]) -> None:
        for event in events:
            reward = event.get("reward")
            if reward is not None:
                try:
                    rewards.append(float(reward))
                except Exception:
                    continue

    if isinstance(trace_payload.get("event_history"), list):
        _collect(trace_payload["event_history"])
    if isinstance(trace_payload.get("session_time_steps"), list):
        for step in trace_payload["session_time_steps"]:
            _collect(step.get("events", []))

    return rewards


def extract_decision_rewards(trace_payload: dict[str, Any] | None) -> list[dict[str, Any]]:
    if not trace_payload:
        return []
    rewards = trace_payload.get("decision_rewards")
    return rewards if isinstance(rewards, list) else []


def extract_trajectory_rewards(response: Any) -> list[float]:
    """Extract per-step rewards directly from the rollout trajectories."""

    rewards: list[float] = []

    if response is None:
        return rewards

    trajectories = getattr(response, "trajectories", None)
    if trajectories is None and isinstance(response, dict):
        trajectories = response.get("trajectories")

    if not trajectories:
        return rewards

    for traj in trajectories:
        steps = getattr(traj, "steps", None)
        if steps is None and isinstance(traj, dict):
            steps = traj.get("steps")
        if not steps:
            continue
        for step in steps:
            reward_val = getattr(step, "reward", None)
            if reward_val is None and isinstance(step, dict):
                reward_val = step.get("reward")
            if reward_val is None:
                continue
            try:
                rewards.append(float(reward_val))
            except Exception:
                continue

    return rewards


def print_reward_summary(
    trace_payload: dict[str, Any] | None,
    rollout_summary: dict[str, Any],
    trajectory_rewards: list[float],
) -> None:
    print("Reward summary:")

    env_rewards = extract_environment_rewards(trace_payload)
    reward_source = "trace"
    if not env_rewards and trajectory_rewards:
        env_rewards = trajectory_rewards
        reward_source = "trajectory"

    if env_rewards:
        print(f" Environment rewards per step ({reward_source}): {env_rewards}")
        print(f" Environment reward total: {sum(env_rewards):.3f}")
    else:
        print(" Environment rewards per step: none recorded")

    decision_rewards = extract_decision_rewards(trace_payload)
    if decision_rewards:
        print(" Decision rewards:")
        for entry in decision_rewards:
            turn = entry.get('turn')
            ach_delta = entry.get('ach_delta')
            unique_delta = entry.get('unique_delta')
            achievements = entry.get('achievements') or []
            print(f" turn={turn}, ach_delta={ach_delta}, unique_delta={unique_delta}, achievements={achievements}")
    else:
        print(" Decision rewards: none recorded")

    episode_returns = rollout_summary.get("episode_returns")
    if episode_returns:
        print(f" Outcome rewards (episode returns): {episode_returns}")
        if env_rewards:
            try:
                total_env_reward = float(sum(env_rewards))
                target = float(episode_returns[0]) if episode_returns else 0.0
                if abs(total_env_reward - target) > 1e-6:
                    print(
                        " ⚠️ Reward mismatch: sum(environment rewards)"
                        f"={total_env_reward:.3f} vs episode return={target:.3f}"
                    )
            except Exception:
                pass
    else:
        print(" Outcome rewards: none recorded")


async def main() -> None:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--base-url", default="http://localhost:8010", help="Task app base URL")
    parser.add_argument("--api-key", required=True, help="Environment API key")
    parser.add_argument("--seed", type=int, default=42, help="Environment seed")
    parser.add_argument("--run-id", default="local-trace", help="Run identifier")
    parser.add_argument("--model", default="gpt-4o-mini", help="OpenAI-compatible model id")
    parser.add_argument("--inference-url", default="https://api.openai.com", help="Inference base URL (OpenAI/Groq)")
    parser.add_argument("--ops", help="Comma-separated rollout ops (fallback: alternating agent/env)")
    parser.add_argument("--max-llm-calls", type=int, default=1, help="Number of agent/env pairs when --ops not supplied")
    parser.add_argument("--max-policy-tokens", type=int, default=None, help="Optional max token budget forwarded to policy")
    parser.add_argument(
        "--trace-format",
        choices=["compact", "full"],
        default="compact",
        help="Trace payload format requested from the server",
    )
    parser.add_argument(
        "--trace-path",
        type=Path,
        help="Path to write the trace JSON (defaults to ./<run_id>_trace.json unless --no-trace-file is set)",
    )
    parser.add_argument(
        "--no-trace-file",
        action="store_true",
        help="Do not write the trace JSON to disk",
    )
    parser.add_argument(
        "--no-print-trace",
        action="store_true",
        help="Do not print the full trace payload to stdout",
    )
    parser.add_argument(
        "--no-trace",
        action="store_true",
        help="Disable return_trace (useful for comparing behaviour without tracing)",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=60.0,
        help="HTTP timeout in seconds for the client (default: 60)",
    )
    args = parser.parse_args()

    ops = ensure_ops(args.ops, args.max_llm_calls)
    return_trace = not args.no_trace

    async with TaskAppClient(args.base_url, api_key=args.api_key, timeout=args.timeout) as client:
        try:
            print(f"Fetching task_info for seed {args.seed}…")
            task_info = await client.task_info(seeds=[args.seed])
            info_payload = task_info[0] if isinstance(task_info, list) else task_info
            try:
                print(json.dumps(info_payload.model_dump(), indent=2)[:600])
            except Exception:
                print(info_payload)

            request = build_rollout_request(
                seed=args.seed,
                run_id=args.run_id,
                model=args.model,
                inference_url=args.inference_url,
                ops=ops,
                return_trace=return_trace,
                trace_format=args.trace_format,
                max_policy_tokens=args.max_policy_tokens,
            )

            print("Requesting rollout…")
            response = await client.rollout(request)
            summary = summarise_rollout(response)
            print(json.dumps(summary, indent=2))

            trace_payload: dict[str, Any] | None = getattr(response, "trace", None)
            if return_trace:
                if trace_payload is None:
                    print(
                        "⚠️ Server did not include a trace. Ensure TASKAPP_TRACING_ENABLED=1 when starting the task app.",
                        file=sys.stderr,
                    )
                else:
                    trace_summary = summarise_trace(trace_payload)
                    print("Trace summary:")
                    print(json.dumps(trace_summary, indent=2))

                    trace_path = args.trace_path
                    if not args.no_trace_file:
                        if trace_path is None:
                            trace_path = Path(f"{args.run_id}_trace.json")
                        dump_trace(trace_payload, path=trace_path, pretty=True)
                        print(f"Trace written to {trace_path}")

                    if not args.no_print_trace:
                        print("Full trace payload:")
                        print(json.dumps(trace_payload, indent=2))

            trajectory_rewards = extract_trajectory_rewards(response)
            print_reward_summary(
                trace_payload if return_trace else None,
                summary,
                trajectory_rewards,
            )

            print(f"Ops executed: {ops}")
            print(
                "Tip: export TASKAPP_TRACING_ENABLED=1 and optionally TASKAPP_SFT_OUTPUT_DIR before running `uvx synth-ai serve …` to persist traces/SFT."
            )
        except httpx.HTTPStatusError as exc:
            detail = exc.response.json() if exc.response.headers.get("content-type", "").startswith("application/json") else exc.response.text
            print(f"HTTP error {exc.response.status_code}: {detail}", file=sys.stderr)
            if exc.response.status_code in (401, 503):
                print(
                    "Hint: ensure the task app process is using the same ENVIRONMENT_API_KEY passed via --api-key.",
                    file=sys.stderr,
                )
            if exc.response.status_code == 500:
                print(
                    "Hint: verify tracing is enabled server-side (TASKAPP_TRACING_ENABLED=1) and the inference credentials are configured.",
                    file=sys.stderr,
                )
            raise


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("Interrupted", file=sys.stderr)
```
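A minimal sketch of how these helpers compose outside the argparse wrapper, assuming a task app already serving on http://localhost:8010 and an ENVIRONMENT_API_KEY variable holding the key the app was started with; importing the script as a module is illustrative only, not something the diff itself demonstrates:

```python
# Sketch only: reuses helpers from run_local_rollout_traced.py under the
# assumptions stated above (base URL, env var name, module import path).
import asyncio
import json
import os

from synth_ai.task import TaskAppClient

from run_local_rollout_traced import build_rollout_request, ensure_ops, summarise_rollout


async def demo() -> None:
    ops = ensure_ops(None, max_llm_calls=2)  # expands to ["agent", "env", "agent", "env"]
    request = build_rollout_request(
        seed=7,
        run_id="trace-demo",
        model="gpt-4o-mini",
        inference_url="https://api.openai.com",
        ops=ops,
        return_trace=True,
        trace_format="compact",
        max_policy_tokens=None,
    )
    async with TaskAppClient("http://localhost:8010", api_key=os.environ["ENVIRONMENT_API_KEY"]) as client:
        response = await client.rollout(request)
        print(json.dumps(summarise_rollout(response), indent=2))


asyncio.run(demo())
```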
examples/warming_up_to_rl/run_rl_and_save.py (new file, @@ -0,0 +1,101 @@):

```python
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import os
import sys
from pathlib import Path
from typing import Any, Dict

import tomllib
import requests


def _load_toml(path: Path) -> Dict[str, Any]:
    if not path.exists():
        print(f"config not found: {path}", file=sys.stderr)
        sys.exit(2)
    with path.open("rb") as fh:
        return tomllib.load(fh)


def main() -> None:
    p = argparse.ArgumentParser(description="Create clustered RL training job via backend RL endpoint")
    p.add_argument("--backend", default=os.getenv("BACKEND_BASE_URL", "http://localhost:8000/api"))
    p.add_argument("--config", required=True, help="Path to RL TOML config")
    p.add_argument("--task-url", default=os.getenv("TASK_APP_URL", ""), help="Override task service URL (or set TASK_APP_URL)")
    p.add_argument("--idempotency", default=os.getenv("RL_IDEMPOTENCY_KEY", ""), help="Optional Idempotency-Key header value")
    args = p.parse_args()

    cfg_path = Path(args.config).expanduser()
    cfg = _load_toml(cfg_path)

    services = cfg.get("services", {}) if isinstance(cfg.get("services"), dict) else {}

    # Resolve task app base URL for the job
    cli_task_url = (args.task_url or "").strip()
    env_task_url = (os.getenv("TASK_APP_URL") or "").strip()
    task_url = cli_task_url or env_task_url or ((services.get("task_url") or "").strip() if isinstance(services, dict) else "")
    if not task_url:
        print("Missing task service URL. Provide --task-url or set TASK_APP_URL or services.task_url in TOML", file=sys.stderr)
        sys.exit(2)

    # TOML-only model selection validation
    model_cfg = cfg.get("model", {}) if isinstance(cfg.get("model"), dict) else {}
    has_source = bool((model_cfg.get("source") or "").strip())
    has_base = bool((model_cfg.get("base") or "").strip())
    if has_source == has_base:
        print("Model selection must specify exactly one of [model].source or [model].base in TOML", file=sys.stderr)
        sys.exit(2)

    # Build create-job payload. Send full TOML under data.config, plus endpoint_base_url.
    payload: Dict[str, Any] = {
        "job_type": "rl",
        # Optional: compute pass-through
        "compute": cfg.get("compute", {}) if isinstance(cfg.get("compute"), dict) else {},
        "data": {
            "endpoint_base_url": task_url,
            "config": cfg,
        },
        "tags": {"source": "warming_up_to_rl"},
    }

    backend = str(args.backend).rstrip("/")
    url = f"{backend}/rl/jobs"
    api_key = (os.getenv("SYNTH_API_KEY") or os.getenv("synth_key") or "").strip()
    if not api_key:
        print("Missing SYNTH_API_KEY in env", file=sys.stderr)
        sys.exit(2)

    headers = {
        "content-type": "application/json",
        "authorization": f"Bearer {api_key}",
    }
    idem = (args.idempotency or "").strip()
    if idem:
        headers["Idempotency-Key"] = idem

    print(f"[INFO] POST {url}")
    try:
        preview = dict(payload)
        preview_data = dict(preview.get("data", {}))
        cfg_keys = list(cfg.keys())
        preview_data["config"] = {"keys": cfg_keys}
        preview["data"] = preview_data
        print(f"[INFO] Payload: {json.dumps(preview)[:500]}")
    except Exception:
        print("[INFO] Payload: <unavailable>")

    r = requests.post(url, headers=headers, json=payload, timeout=120)
    ok = r.status_code in (200, 201)
    try:
        snippet = r.json()
    except Exception:
        snippet = r.text[:300]
    print(f"[INFO] Response: {r.status_code} {snippet}")
    sys.exit(0 if ok else 1)


if __name__ == "__main__":
    main()
```
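To make the request shape concrete, the sketch below rebuilds the same payload the script POSTs to {backend}/rl/jobs from a made-up minimal config; the TOML keys shown are illustrative assumptions, not a schema the backend is known to require:

```python
# Illustrative only: mirrors the payload construction in run_rl_and_save.py
# using hypothetical config values.
import json

cfg = {
    "services": {"task_url": "https://my-task-app.example.com"},  # assumed URL
    "model": {"base": "Qwen/Qwen3-4B"},   # exactly one of [model].base / [model].source
    "compute": {"gpu_type": "A10G"},      # passed through untouched (illustrative key)
}

payload = {
    "job_type": "rl",
    "compute": cfg.get("compute", {}),
    "data": {
        "endpoint_base_url": cfg["services"]["task_url"],
        "config": cfg,                    # the full TOML travels under data.config
    },
    "tags": {"source": "warming_up_to_rl"},
}
print(json.dumps(payload, indent=2))
```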
examples/warming_up_to_rl/run_rollout_remote.py (new file, @@ -0,0 +1,129 @@):

```python
#!/usr/bin/env python3
"""Request a rollout from a remote Crafter task app (e.g., Modal deployment)."""

from __future__ import annotations

import argparse
import asyncio
import json
import os
import sys

import httpx

def check_health(base_url: str, api_key: str) -> None:
    try:
        resp = httpx.get(f"{base_url.rstrip('/')}/health", headers={"X-API-Key": api_key}, timeout=10.0)
        data = resp.json() if resp.headers.get("content-type", "").startswith("application/json") else resp.text
        if resp.status_code != 200:
            print(f"warning: /health returned {resp.status_code}: {data}")
        else:
            print(f"/health ok: {data}")
    except Exception as exc:
        print(f"warning: failed to call /health: {exc}")

from synth_ai.task import (
    RolloutEnvSpec,
    RolloutPolicySpec,
    RolloutRecordConfig,
    RolloutRequest,
    RolloutSafetyConfig,
    TaskAppClient,
)


def build_request(
    *,
    run_id: str,
    seed: int,
    model: str,
    inference_url: str,
    llm_calls: int,
    max_policy_tokens: int | None,
) -> RolloutRequest:
    policy_config = {"model": model, "inference_url": inference_url}
    if max_policy_tokens is not None:
        policy_config.update(
            {
                "max_completion_tokens": max_policy_tokens,
                "max_tokens": max_policy_tokens,
            }
        )

    ops: list[str] = []
    for _ in range(max(llm_calls, 1)):
        ops.extend(["agent", "env"])

    return RolloutRequest(
        run_id=run_id,
        env=RolloutEnvSpec(env_name="crafter", seed=seed, config={}),
        policy=RolloutPolicySpec(policy_name="crafter-react", config=policy_config),
        ops=ops,
        record=RolloutRecordConfig(trajectories=True),
        on_done="reset",
        safety=RolloutSafetyConfig(),
    )


def summarise(response) -> dict[str, any]:
    metrics = response.metrics
    return {
        "run_id": response.run_id,
        "num_episodes": metrics.num_episodes,
        "num_steps": metrics.num_steps,
        "episode_returns": metrics.episode_returns,
        "outcome_score": metrics.outcome_score,
        "events_score": metrics.events_score,
    }


async def main() -> None:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--base-url", default=None, help="Remote task app base URL (e.g., https://xyz.modal.run); defaults to TASK_APP_BASE_URL env")
    parser.add_argument("--api-key", required=True, help="Environment API key for the remote task app")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--run-id", default="remote-demo")
    parser.add_argument("--model", default="gpt-4o-mini")
    parser.add_argument("--inference-url", default="https://api.openai.com")
    parser.add_argument("--max-llm-calls", type=int, default=1)
    parser.add_argument("--max-policy-tokens", type=int, default=None)
    args = parser.parse_args()

    base_url = args.base_url or os.getenv('TASK_APP_BASE_URL')
    if not base_url:
        parser.error('Missing --base-url (and TASK_APP_BASE_URL not set).')

    request = build_request(
        run_id=args.run_id,
        seed=args.seed,
        model=args.model,
        inference_url=args.inference_url,
        llm_calls=args.max_llm_calls,
        max_policy_tokens=args.max_policy_tokens,
    )

    async with TaskAppClient(base_url, api_key=args.api_key) as client:
        try:
            check_health(base_url, args.api_key)
            info = await client.task_info(seeds=[args.seed])
            payload = info[0] if isinstance(info, list) else info
            print(json.dumps(payload.model_dump(), indent=2)[:600])

            print("Requesting rollout…")
            response = await client.rollout(request)
            print(json.dumps(summarise(response), indent=2))
            print(f"Ops executed: {request.ops}")
        except httpx.HTTPStatusError as exc:
            detail = exc.response.json() if exc.response.headers.get("content-type", "").startswith("application/json") else exc.response.text
            print(f"HTTP error {exc.response.status_code}: {detail}", file=sys.stderr)
            if exc.response.status_code in (401, 403):
                print("Hint: check --api-key and ensure the remote deployment expects that value.", file=sys.stderr)
            if exc.response.status_code == 404:
                print("Hint: verify the --base-url includes the correct path (should be the root of the task app).", file=sys.stderr)
            if exc.response.status_code == 500:
                print("Hint: remote rollout failed server-side; inspect the deployment logs (Modal dashboard/logs).", file=sys.stderr)
            raise


if __name__ == "__main__":
    asyncio.run(main())
```
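A quick offline sanity check of `build_request`, assuming the script can be imported as a module; the `.policy.config` attribute access relies on RolloutRequest behaving like a standard pydantic-style model, which is an assumption here:

```python
# No network needed: only exercises the request builder (import path assumed).
from run_rollout_remote import build_request

req = build_request(
    run_id="remote-demo",
    seed=42,
    model="gpt-4o-mini",
    inference_url="https://api.openai.com",
    llm_calls=3,
    max_policy_tokens=512,
)
print(req.ops)                           # ["agent", "env"] repeated three times
print(req.policy.config["max_tokens"])   # 512 (field access assumed; see note above)
```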
examples/warming_up_to_rl/task_app/README.md (new file, @@ -0,0 +1,38 @@):

````markdown
# Crafter Task App

This example is now wired through the shared Synth task-app harness. Use the
`uvx synth-ai` CLI to run it locally or deploy it to Modal without touching the
underlying FastAPI plumbing.

## Local development
```bash
uvx synth-ai serve grpo-crafter --port 8001
# Optional extras:
#   --env-file path/to/.env   # load additional environment variables
#   --reload                  # enable uvicorn auto-reload
```

Useful endpoints while the server is running:
- `GET http://localhost:8001/health`
- `GET http://localhost:8001/info`
- `GET http://localhost:8001/task_info?seed=42`
- `POST http://localhost:8001/rollout`

## Deploy to Modal
```bash
uvx synth-ai deploy grpo-crafter --name grpo-crafter-task-app
```

Requirements:
- Modal CLI installed and authenticated (`modal token new`).
- Secrets `crafter-environment-sdk`, `groq-api-key`, and `openai-api-key`
  available in your Modal account.

The CLI generates a Modal entrypoint on the fly using the shared
`TaskAppConfig`, ensuring the container matches the local FastAPI behavior.

## Compatibility note
`examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py` remains as a
legacy wrapper exposing `fastapi_app()` and a `__main__` entrypoint. Behind the
scenes it proxies to the shared configuration; prefer the CLI workflow above
for new automation and tests.
````
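Expanding on the compatibility note in the README above, here is a hypothetical sketch of driving the legacy wrapper directly. Only `fastapi_app()` and the `__main__` entrypoint are documented; the file-path loading, host, and port below are assumptions for illustration, and the CLI workflow remains the recommended path.

```python
# Hypothetical: load the legacy wrapper by file path and serve it with uvicorn.
# fastapi_app() is the documented entry point; everything else here is assumed.
import importlib.util

import uvicorn

spec = importlib.util.spec_from_file_location(
    "grpo_crafter_task_app",
    "examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py",
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

app = module.fastapi_app()
uvicorn.run(app, host="127.0.0.1", port=8001)
```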
{synth_ai/task/apps → examples/warming_up_to_rl/task_app}/grpo_crafter.py, import changes after the move (the removed lines are truncated to bare `from` in the diff source and are left as shown):

```diff
@@ -8,13 +8,13 @@ from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, Dict, Iterable, List, Sequence
 
-from
-from
-from
-from
-from
-from . import ModalDeploymentConfig, TaskAppEntry, register_task_app
-from
+from synth_ai.task.contracts import RolloutRequest, RolloutResponse, TaskInfo
+from synth_ai.task.datasets import TaskDatasetRegistry, TaskDatasetSpec
+from synth_ai.task.rubrics import load_rubric
+from synth_ai.task.server import ProxyConfig, RubricBundle, TaskAppConfig
+from synth_ai.task.json import to_jsonable  # noqa: F401 (imported for side-effect compatibility)
+from synth_ai.task.apps import ModalDeploymentConfig, TaskAppEntry, register_task_app
+from synth_ai.task.tracing_utils import (
     build_tracer_factory,
     resolve_sft_output_dir,
     resolve_tracing_db_url,
```