synth-ai 0.2.9.dev2__py3-none-any.whl → 0.2.9.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai may be problematic; see the registry's release advisory for details.
- examples/analyze_semantic_words.sh +17 -0
- examples/common_old/backend.py +21 -0
- examples/crafter_debug_render.py +180 -0
- examples/evals_old/README.md +98 -0
- examples/evals_old/__init__.py +6 -0
- examples/evals_old/compare_models.py +1037 -0
- examples/evals_old/example_log.md +145 -0
- examples/evals_old/run_demo.sh +126 -0
- examples/evals_old/trace_analysis.py +270 -0
- examples/finetuning_old/_backup_synth_qwen/config.toml +29 -0
- examples/finetuning_old/_backup_synth_qwen/example_log.md +324 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +60 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +239 -0
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +109 -0
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +1924 -0
- examples/finetuning_old/_backup_synth_qwen/readme.md +49 -0
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +114 -0
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +195 -0
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +118 -0
- examples/finetuning_old/synth_qwen_v1/README.md +68 -0
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +60 -0
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +239 -0
- examples/finetuning_old/synth_qwen_v1/finetune.py +46 -0
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +71 -0
- examples/finetuning_old/synth_qwen_v1/infer.py +37 -0
- examples/finetuning_old/synth_qwen_v1/poll.py +44 -0
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +35 -0
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +109 -0
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +1932 -0
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +207 -0
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +232 -0
- examples/finetuning_old/synth_qwen_v1/upload_data.py +34 -0
- examples/finetuning_old/synth_qwen_v1/util.py +147 -0
- examples/rl/README.md +169 -0
- examples/rl/configs/eval_base_qwen.toml +15 -0
- examples/rl/configs/eval_rl_qwen.toml +11 -0
- examples/rl/configs/rl_from_base_qwen.toml +35 -0
- examples/rl/configs/rl_from_base_qwen17.toml +74 -0
- examples/rl/configs/rl_from_ft_qwen.toml +35 -0
- examples/rl/download_dataset.py +64 -0
- examples/rl/run_eval.py +435 -0
- examples/rl/run_rl_and_save.py +94 -0
- examples/rl/task_app/README.md +22 -0
- {synth_ai/task/apps → examples/rl/task_app}/math_single_step.py +8 -8
- examples/rl/task_app/math_task_app.py +107 -0
- examples/rl_old/task_app.py +962 -0
- examples/run_crafter_demo.sh +10 -0
- examples/warming_up_to_rl/analyze_trace_db.py +420 -0
- examples/warming_up_to_rl/configs/crafter_fft.toml +48 -0
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +54 -0
- examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +20 -0
- examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +13 -0
- examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +23 -0
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +73 -0
- examples/warming_up_to_rl/configs/rl_from_ft.toml +56 -0
- examples/warming_up_to_rl/export_trace_sft.py +541 -0
- examples/warming_up_to_rl/groq_test.py +88 -0
- examples/warming_up_to_rl/manage_secrets.py +127 -0
- examples/warming_up_to_rl/old/event_rewards.md +234 -0
- examples/warming_up_to_rl/old/notes.md +73 -0
- examples/warming_up_to_rl/readme.md +172 -0
- examples/warming_up_to_rl/run_eval.py +434 -0
- examples/warming_up_to_rl/run_fft_and_save.py +309 -0
- examples/warming_up_to_rl/run_local_rollout.py +188 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +160 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +342 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +372 -0
- examples/warming_up_to_rl/run_rl_and_save.py +101 -0
- examples/warming_up_to_rl/run_rollout_remote.py +129 -0
- examples/warming_up_to_rl/task_app/README.md +38 -0
- {synth_ai/task/apps → examples/warming_up_to_rl/task_app}/grpo_crafter.py +7 -7
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +165 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +145 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1271 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +429 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +442 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +96 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +302 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +202 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +512 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +102 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +985 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +197 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1749 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +217 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +160 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +146 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_stepwise_rewards.py +58 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +61 -0
- synth_ai/api/train/config_finder.py +18 -18
- synth_ai/api/train/env_resolver.py +28 -1
- synth_ai/cli/task_apps.py +264 -55
- synth_ai/demo_registry.py +7 -7
- synth_ai/demos/demo_task_apps/crafter/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +54 -0
- synth_ai/demos/demo_task_apps/crafter/configs/rl_from_base_qwen4b.toml +73 -0
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +165 -0
- synth_ai/task/apps/__init__.py +54 -13
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/METADATA +1 -1
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/RECORD +112 -13
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/top_level.txt +1 -0
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,342 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""Launch multiple local rollouts concurrently and summarise rewards/achievements."""
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import argparse
|
|
7
|
+
import asyncio
|
|
8
|
+
import json
|
|
9
|
+
import os
|
|
10
|
+
from collections import Counter
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from statistics import mean, median
|
|
13
|
+
from typing import Any
|
|
14
|
+
|
|
15
|
+
from dotenv import load_dotenv
|
|
16
|
+
|
|
17
|
+
from synth_ai.task import TaskAppClient
|
|
18
|
+
|
|
19
|
+
from synth_ai.task import (
|
|
20
|
+
RolloutEnvSpec,
|
|
21
|
+
RolloutPolicySpec,
|
|
22
|
+
RolloutRecordConfig,
|
|
23
|
+
RolloutRequest,
|
|
24
|
+
RolloutSafetyConfig,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def build_rollout_request(
    *,
    seed: int,
    run_id: str,
    model: str,
    inference_url: str,
    ops: list[str],
    extra_headers: dict[str, str] | None = None,
    trace_format: str = "compact",
    return_trace: bool = False,
) -> RolloutRequest:
    """Assemble one Crafter rollout request for the task app.

    The policy is ``crafter-react`` configured with ``model`` and
    ``inference_url``. ``extra_headers`` are attached to the policy config only
    when non-empty. Trajectories are always recorded, the environment is reset
    when done, and default safety settings are used.
    """
    # Trajectories are always recorded; only format/inline-trace vary.
    recording = RolloutRecordConfig(
        trajectories=True,
        trace_format=trace_format,
        return_trace=return_trace,
    )

    config: dict[str, Any] = {"model": model, "inference_url": inference_url}
    if extra_headers:
        config["extra_headers"] = extra_headers

    env_spec = RolloutEnvSpec(env_name="crafter", seed=seed, config={})
    policy_spec = RolloutPolicySpec(policy_name="crafter-react", config=config)

    return RolloutRequest(
        run_id=run_id,
        env=env_spec,
        policy=policy_spec,
        ops=ops,
        record=recording,
        on_done="reset",
        safety=RolloutSafetyConfig(),
    )
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def mask_value(value: str | None) -> str:
    """Return a redacted, log-safe rendering of a secret such as an API key.

    ``None`` and the empty string render as ``"<unset>"``. Longer values show
    their first 6 and last 4 characters plus the total length. Values of 10
    characters or fewer are fully hidden (only the length is shown). For them,
    a 6-char prefix plus a 4-char suffix would reveal the entire secret.
    """
    if not value:
        return "<unset>"
    head, tail = 6, 4
    if len(value) <= head + tail:
        # Too short to reveal partially without exposing every character.
        return f"… (len={len(value)})"
    return f"{value[:head]}…{value[-tail:]} (len={len(value)})"
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def build_ops(max_llm_calls: int, explicit_ops: str | None) -> list[str]:
    """Return the sequence of rollout ops to send to the task app.

    If ``explicit_ops`` is a non-empty string, it is split on commas and blank
    entries are dropped. ``ValueError`` is raised if nothing remains.
    Otherwise the result is ``max_llm_calls`` repetitions of
    ``["agent", "env"]``, clamped to the range 1..50. A warning is printed
    when the cap applies.
    """
    if explicit_ops:
        stripped = (token.strip() for token in explicit_ops.split(","))
        parsed = [token for token in stripped if token]
        if not parsed:
            raise ValueError("--ops must contain at least one entry")
        return parsed

    calls = max(1, max_llm_calls)
    if calls > 50:
        print("[WARN] --max-llm-calls capped at 50 per rollout; use --ops for manual control.")
        calls = 50
    # One "agent" decision followed by one "env" step per LLM call.
    return ["agent", "env"] * calls
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def extract_achievements(step_info: dict[str, Any] | None) -> list[str]:
    """Collect achievement names mentioned in one step's ``info`` payload.

    Reads ``step_info["achievements_added"]`` if it is a list. Then, if
    ``step_info["meta"]["decision_rewards"]`` is a dict, reads its ``all``,
    ``achievements``, ``unique`` and ``unique_achievements`` lists, in that
    order. Every item is converted with ``str``. Duplicates are kept. Input
    that is not a dict yields an empty list.
    """
    if not isinstance(step_info, dict):
        return []

    found: list[str] = []
    newly_added = step_info.get("achievements_added")
    if isinstance(newly_added, list):
        found += [str(name) for name in newly_added]

    meta = step_info.get("meta")
    decision = meta.get("decision_rewards") if isinstance(meta, dict) else None
    if not isinstance(decision, dict):
        return found

    for field in ("all", "achievements", "unique", "unique_achievements"):
        candidates = decision.get(field)
        if isinstance(candidates, list):
            found += [str(name) for name in candidates]
    return found
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def analyse_rollout_response(response: Any) -> dict[str, Any]:
    """Condense one rollout response into return, step count and achievements.

    Reads ``response.metrics`` (``episode_returns``, ``num_steps``),
    ``response.trajectories`` (first trajectory's ``steps[*].info``) and
    ``response.trace``, used only when it is a dict. Achievement names are
    gathered from three sources, concatenated in this order:

    1. per-step ``info`` payloads (via ``extract_achievements``);
    2. ``trace["decision_rewards"]`` entries (keys ``achievements``, ``all``,
       ``unique``, ``unique_achievements``);
    3. ``trace["metadata"]["final_achievements"]``.

    NOTE(review): the same unlock may appear in several of these sources (and
    under several keys), so ``achievements_all`` is likely not a true event
    count. Confirm against the task app's trace schema.

    Returns a dict with keys ``return``, ``steps``, ``achievements_all``
    (concatenated list), ``achievements_unique`` (sorted distinct names),
    ``trace`` and ``metrics``.
    """
    metrics = response.metrics
    first_trajectory = response.trajectories[0] if response.trajectories else None

    returns = metrics.episode_returns
    episode_return = returns[0] if returns else 0.0
    step_count = metrics.num_steps

    step_achievements: list[str] = []
    if first_trajectory is not None:
        for step in first_trajectory.steps:
            step_achievements += extract_achievements(step.info)

    trace_payload = response.trace or {}
    is_mapping = isinstance(trace_payload, dict)

    metadata = trace_payload.get("metadata") if is_mapping else {}
    final_list = metadata.get("final_achievements") if isinstance(metadata, dict) else None
    final_achievements = [str(name) for name in final_list] if isinstance(final_list, list) else []

    decision_rewards = trace_payload.get("decision_rewards") if is_mapping else []
    trace_all: list[str] = []
    if isinstance(decision_rewards, list):
        for entry in decision_rewards:
            if not isinstance(entry, dict):
                continue
            for field in ("achievements", "all", "unique", "unique_achievements"):
                values = entry.get(field)
                if isinstance(values, list):
                    trace_all += [str(v) for v in values]

    combined = step_achievements + trace_all + final_achievements
    return {
        "return": float(episode_return),
        "steps": int(step_count),
        "achievements_all": combined,
        "achievements_unique": sorted(set(map(str, combined))),
        "trace": trace_payload,
        "metrics": metrics,
    }
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def summarise_runs(run_summaries: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate per-run summaries into batch statistics.

    Each summary must provide ``return``, ``steps``, ``achievements_all`` and
    ``achievements_unique``. The result contains the run count, return
    statistics (mean/median/min/max/total), the total step count, and three
    ``Counter`` objects:

    - ``achievements_all``: total occurrences across runs;
    - ``achievements_unique``: number of runs in which each name appeared;
    - ``unique_count_hist``: distinct-achievement count per run mapped to the
      number of runs with that count.

    An empty input yields ``{}``.
    """
    if not run_summaries:
        return {}

    returns = [run["return"] for run in run_summaries]
    total_steps = sum(run["steps"] for run in run_summaries)

    occurrence_counts: Counter = Counter()
    runs_per_name: Counter = Counter()
    per_run_distinct: Counter = Counter()
    for run in run_summaries:
        occurrence_counts.update(run["achievements_all"])
        distinct = set(run["achievements_unique"])
        runs_per_name.update(distinct)
        per_run_distinct[len(distinct)] += 1

    return {
        "count": len(run_summaries),
        "returns": {
            "mean": mean(returns),
            "median": median(returns),
            "min": min(returns),
            "max": max(returns),
            "total": sum(returns),
        },
        "total_steps": total_steps,
        "achievements_all": occurrence_counts,
        "achievements_unique": runs_per_name,
        "unique_count_hist": per_run_distinct,
    }
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def print_summary(stats: dict[str, Any], *, run_details: list[dict[str, Any]], total_runs: int) -> None:
    """Print batch statistics and the top 10 runs by return to stdout.

    ``stats`` is the dict produced by ``summarise_runs``. ``run_details`` holds
    the successful run records (``run_id``, ``seed``, ``summary``).
    ``total_runs`` is the number of rollouts attempted. An empty ``stats``
    prints a single notice instead.
    """
    emit = print
    if not stats:
        emit("No successful rollouts to summarise.")
        return

    ret = stats["returns"]
    emit("Rollout summary:")
    emit(f" Runs succeeded: {stats['count']} / {total_runs}")
    emit(f" Total steps : {stats['total_steps']}")
    emit(
        " Returns : "
        f"mean={ret['mean']:.2f}, median={ret['median']:.2f}, "
        f"min={ret['min']:.2f}, max={ret['max']:.2f}, total={ret['total']:.2f}"
    )

    hist = stats["unique_count_hist"]
    if hist:
        emit(" Unique achievement counts per run:")
        for n_unique in sorted(hist):
            emit(f" {n_unique:02d} unique -> {hist[n_unique]} run(s)")

    by_runs = stats["achievements_unique"].most_common()
    if by_runs:
        emit(" Achievements unlocked (by runs):")
        for name, freq in by_runs:
            emit(f" {name}: {freq} run(s)")

    by_events = stats["achievements_all"].most_common()
    if by_events:
        emit(" Achievement unlock events (total occurrences):")
        for name, freq in by_events:
            emit(f" {name}: {freq} event(s)")

    emit("\nTop runs by return:")
    ranked = sorted(run_details, key=lambda detail: detail["summary"]["return"], reverse=True)
    for rank, detail in enumerate(ranked[:10], start=1):
        info = detail["summary"]
        emit(
            f" {rank:02d}. run_id={detail['run_id']} seed={detail['seed']} "
            f"return={info['return']:.2f} steps={info['steps']} "
            f"achievements={info['achievements_unique']}"
        )
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
async def execute_rollouts(args: argparse.Namespace) -> None:
    """Run ``args.count`` Crafter rollouts against the task app and print a summary.

    Loads an optional ``--env-file`` (without overriding variables already
    set) and resolves credentials. It then launches the rollouts concurrently,
    with at most ``max(1, args.parallel)`` in flight at once. Rollout ``i``
    uses run id ``f"{args.run_id}-{i:03d}"`` and seed
    ``args.seed + i * args.seed_stride``. Any ``Exception`` raised by a single
    rollout is captured and listed under "Failures" rather than aborting the
    batch.

    Raises:
        FileNotFoundError: ``--env-file`` was given but does not exist.
        RuntimeError: no task-app key from ``--api-key`` or ``ENVIRONMENT_API_KEY``.
        ValueError: ``--ops`` was given but contains no entries (from ``build_ops``).
    """
    if args.env_file:
        env_path = Path(args.env_file).expanduser()
        if not env_path.exists():
            raise FileNotFoundError(f"Env file not found: {env_path}")
        # override=False: variables already present in the environment win.
        load_dotenv(env_path, override=False)

    api_key = args.api_key or os.getenv("ENVIRONMENT_API_KEY")
    if not api_key:
        raise RuntimeError("Missing --api-key or ENVIRONMENT_API_KEY")

    synth_key = os.getenv("SYNTH_API_KEY")
    extra_headers: dict[str, str] | None = None
    # NOTE(review): SYNTH_API_KEY is forwarded as a bearer token to any
    # inference URL that does not contain "openai.com" (plain substring test),
    # including third-party hosts. Confirm that is intended.
    if synth_key and "openai.com" not in args.inference_url.lower():
        extra_headers = {"Authorization": f"Bearer {synth_key}"}

    if args.verbose:
        print("Resolved configuration:")
        print(f" Task app base URL : {args.base_url}")
        print(f" Inference base URL : {args.inference_url}")
        print(f" Task app API key : {mask_value(api_key)}")
        print(f" Synth API key : {mask_value(synth_key)}")
        print(f" HTTP timeout : {args.timeout:.1f}s")
        print(f" Rollouts : {args.count} (parallel={args.parallel})")

    ops = build_ops(args.max_llm_calls, args.ops)

    async with TaskAppClient(args.base_url, api_key=api_key, timeout=args.timeout) as client:

        async def run_single(index: int) -> dict[str, Any]:
            """Execute rollout ``index``; any Exception is returned as an error record."""
            run_id = f"{args.run_id}-{index:03d}"
            seed = args.seed + index * args.seed_stride
            request = build_rollout_request(
                seed=seed,
                run_id=run_id,
                model=args.model,
                inference_url=args.inference_url,
                ops=ops,
                extra_headers=extra_headers,
                trace_format=args.trace_format,
                return_trace=True,
            )
            if args.max_policy_tokens is not None:
                # Both keys are set so either completion-token parameter name is honoured.
                request.policy.config.update({
                    "max_completion_tokens": args.max_policy_tokens,
                    "max_tokens": args.max_policy_tokens,
                })

            try:
                response = await client.rollout(request)
                summary = analyse_rollout_response(response)
                return {
                    "ok": True,
                    "run_id": run_id,
                    "seed": seed,
                    "response": response,
                    "summary": summary,
                }
            except Exception as exc:  # pragma: no cover - surface errors
                return {
                    "ok": False,
                    "run_id": run_id,
                    "seed": seed,
                    "error": exc,
                }

        semaphore = asyncio.Semaphore(max(1, args.parallel))

        async def guarded_run(index: int) -> dict[str, Any]:
            """Run one rollout while holding a concurrency slot."""
            async with semaphore:
                return await run_single(index)

        tasks = [asyncio.create_task(guarded_run(i)) for i in range(args.count)]
        results = await asyncio.gather(*tasks)

    successes = [item for item in results if item.get("ok")]
    failures = [item for item in results if not item.get("ok")]

    stats = summarise_runs([item["summary"] for item in successes])
    print_summary(stats, run_details=successes, total_runs=args.count)

    if failures:
        print("\nFailures:")
        for item in failures:
            err = item.get("error")
            # str() of many exceptions (e.g. asyncio.TimeoutError()) is empty,
            # which used to print a bare "error=". Always include the type.
            message = str(err)
            described = f"{type(err).__name__}: {message}" if message else type(err).__name__
            print(f" run_id={item['run_id']} seed={item['seed']} error={described}")
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description=__doc__)
    add = parser.add_argument

    # Connection and credentials.
    add("--base-url", default="http://localhost:8001", help="Task app base URL")
    add("--api-key", help="Environment API key (or set via --env-file)")
    add("--env-file", help="Path to .env file providing API keys")

    # Policy / model selection.
    add("--model", default="gpt-4o-mini", help="Model identifier for the Crafter policy")
    add("--inference-url", default="https://api.openai.com", help="Inference base URL for the policy")

    # Seeding and batch shape.
    add("--seed", type=int, default=42, help="Base seed for the first rollout")
    add("--seed-stride", type=int, default=1, help="Increment applied to the seed for each rollout")
    add("--count", type=int, default=20, help="Number of rollout trajectories to execute")
    add("--parallel", type=int, default=4, help="Maximum concurrent rollouts")

    # Per-rollout controls.
    add("--ops", help="Comma-separated rollout ops (advanced override)")
    add(
        "--max-llm-calls",
        type=int,
        default=20,
        help="Number of agent/env pairs per rollout when --ops not provided",
    )
    add(
        "--max-policy-tokens",
        type=int,
        help="Optional per-call token limit forwarded to the policy config",
    )
    add("--timeout", type=float, default=600.0, help="HTTP timeout (seconds) for task app requests")
    add(
        "--trace-format",
        default="compact",
        choices=["compact", "full"],
        help="Trace format requested from the task app",
    )
    add("--run-id", default="batch-demo", help="Run ID prefix for rollouts")
    add("--verbose", action="store_true", help="Print resolved configuration")

    return parser.parse_args()
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def main() -> None:
    """CLI entry point: parse arguments, then run the async rollout batch to completion."""
    asyncio.run(execute_rollouts(parse_args()))


if __name__ == "__main__":
    main()
|