synth-ai 0.2.9.dev3__py3-none-any.whl → 0.2.9.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (107) hide show
  1. examples/analyze_semantic_words.sh +17 -0
  2. examples/common_old/backend.py +21 -0
  3. examples/crafter_debug_render.py +180 -0
  4. examples/evals_old/README.md +98 -0
  5. examples/evals_old/__init__.py +6 -0
  6. examples/evals_old/compare_models.py +1037 -0
  7. examples/evals_old/example_log.md +145 -0
  8. examples/evals_old/run_demo.sh +126 -0
  9. examples/evals_old/trace_analysis.py +270 -0
  10. examples/finetuning_old/_backup_synth_qwen/config.toml +29 -0
  11. examples/finetuning_old/_backup_synth_qwen/example_log.md +324 -0
  12. examples/finetuning_old/_backup_synth_qwen/filter_traces.py +60 -0
  13. examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +239 -0
  14. examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +109 -0
  15. examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +1924 -0
  16. examples/finetuning_old/_backup_synth_qwen/readme.md +49 -0
  17. examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +114 -0
  18. examples/finetuning_old/_backup_synth_qwen/run_demo.sh +195 -0
  19. examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +118 -0
  20. examples/finetuning_old/synth_qwen_v1/README.md +68 -0
  21. examples/finetuning_old/synth_qwen_v1/filter_traces.py +60 -0
  22. examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +239 -0
  23. examples/finetuning_old/synth_qwen_v1/finetune.py +46 -0
  24. examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +71 -0
  25. examples/finetuning_old/synth_qwen_v1/infer.py +37 -0
  26. examples/finetuning_old/synth_qwen_v1/poll.py +44 -0
  27. examples/finetuning_old/synth_qwen_v1/prepare_data.py +35 -0
  28. examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +109 -0
  29. examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +1932 -0
  30. examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +207 -0
  31. examples/finetuning_old/synth_qwen_v1/run_ft_job.py +232 -0
  32. examples/finetuning_old/synth_qwen_v1/upload_data.py +34 -0
  33. examples/finetuning_old/synth_qwen_v1/util.py +147 -0
  34. examples/rl/README.md +169 -0
  35. examples/rl/configs/eval_base_qwen.toml +15 -0
  36. examples/rl/configs/eval_rl_qwen.toml +11 -0
  37. examples/rl/configs/rl_from_base_qwen.toml +35 -0
  38. examples/rl/configs/rl_from_base_qwen17.toml +74 -0
  39. examples/rl/configs/rl_from_ft_qwen.toml +35 -0
  40. examples/rl/download_dataset.py +64 -0
  41. examples/rl/run_eval.py +435 -0
  42. examples/rl/run_rl_and_save.py +94 -0
  43. examples/rl/task_app/README.md +22 -0
  44. {synth_ai/task/apps → examples/rl/task_app}/math_single_step.py +8 -8
  45. examples/rl/task_app/math_task_app.py +107 -0
  46. examples/rl_old/task_app.py +962 -0
  47. examples/run_crafter_demo.sh +10 -0
  48. examples/warming_up_to_rl/analyze_trace_db.py +420 -0
  49. examples/warming_up_to_rl/configs/crafter_fft.toml +48 -0
  50. examples/warming_up_to_rl/configs/crafter_fft_4b.toml +54 -0
  51. examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +20 -0
  52. examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +13 -0
  53. examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +23 -0
  54. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +73 -0
  55. examples/warming_up_to_rl/configs/rl_from_ft.toml +56 -0
  56. examples/warming_up_to_rl/export_trace_sft.py +541 -0
  57. examples/warming_up_to_rl/groq_test.py +88 -0
  58. examples/warming_up_to_rl/manage_secrets.py +127 -0
  59. examples/warming_up_to_rl/old/event_rewards.md +234 -0
  60. examples/warming_up_to_rl/old/notes.md +73 -0
  61. examples/warming_up_to_rl/readme.md +172 -0
  62. examples/warming_up_to_rl/run_eval.py +434 -0
  63. examples/warming_up_to_rl/run_fft_and_save.py +309 -0
  64. examples/warming_up_to_rl/run_local_rollout.py +188 -0
  65. examples/warming_up_to_rl/run_local_rollout_modal.py +160 -0
  66. examples/warming_up_to_rl/run_local_rollout_parallel.py +342 -0
  67. examples/warming_up_to_rl/run_local_rollout_traced.py +372 -0
  68. examples/warming_up_to_rl/run_rl_and_save.py +101 -0
  69. examples/warming_up_to_rl/run_rollout_remote.py +129 -0
  70. examples/warming_up_to_rl/task_app/README.md +38 -0
  71. {synth_ai/task/apps → examples/warming_up_to_rl/task_app}/grpo_crafter.py +7 -7
  72. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +165 -0
  73. examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
  74. examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
  75. examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +145 -0
  76. examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1271 -0
  77. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
  78. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
  79. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
  80. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +429 -0
  81. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +442 -0
  82. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +96 -0
  83. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +302 -0
  84. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
  85. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +202 -0
  86. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
  87. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +512 -0
  88. examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +102 -0
  89. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +985 -0
  90. examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +197 -0
  91. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1749 -0
  92. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
  93. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +217 -0
  94. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +160 -0
  95. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +146 -0
  96. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +61 -0
  97. synth_ai/api/train/config_finder.py +18 -18
  98. synth_ai/api/train/env_resolver.py +28 -1
  99. synth_ai/cli/task_apps.py +291 -56
  100. synth_ai/task/apps/__init__.py +54 -13
  101. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/METADATA +1 -1
  102. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/RECORD +106 -13
  103. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/top_level.txt +1 -0
  104. synth_ai/environments/examples/sokoban/units/astar_common.py +0 -95
  105. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/WHEEL +0 -0
  106. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/entry_points.txt +0 -0
  107. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,342 @@
1
+ #!/usr/bin/env python3
2
+ """Launch multiple local rollouts concurrently and summarise rewards/achievements."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import asyncio
8
+ import json
9
+ import os
10
+ from collections import Counter
11
+ from pathlib import Path
12
+ from statistics import mean, median
13
+ from typing import Any
14
+
15
+ from dotenv import load_dotenv
16
+
17
+ from synth_ai.task import TaskAppClient
18
+
19
+ from synth_ai.task import (
20
+ RolloutEnvSpec,
21
+ RolloutPolicySpec,
22
+ RolloutRecordConfig,
23
+ RolloutRequest,
24
+ RolloutSafetyConfig,
25
+ )
26
+
27
+
28
def build_rollout_request(
    *,
    seed: int,
    run_id: str,
    model: str,
    inference_url: str,
    ops: list[str],
    extra_headers: dict[str, str] | None = None,
    trace_format: str = "compact",
    return_trace: bool = False,
) -> RolloutRequest:
    """Assemble the ``RolloutRequest`` for one Crafter rollout.

    The environment name (``"crafter"``), the policy name
    (``"crafter-react"``), ``on_done="reset"`` and trajectory recording are
    fixed here; everything else comes from the keyword arguments.

    Args:
        seed: Seed passed to the environment spec.
        run_id: Identifier attached to the request.
        model: Model identifier stored in the policy config.
        inference_url: Inference base URL stored in the policy config.
        ops: Sequence of rollout ops forwarded unchanged.
        extra_headers: Optional headers; added to the policy config under
            ``"extra_headers"`` only when non-empty.
        trace_format: Trace format requested in the record config.
        return_trace: Whether the record config asks for the trace back.

    Returns:
        The populated ``RolloutRequest``.
    """
    policy_settings: dict[str, Any] = {
        "model": model,
        "inference_url": inference_url,
    }
    if extra_headers:
        policy_settings["extra_headers"] = extra_headers

    recording = RolloutRecordConfig(
        trajectories=True,
        trace_format=trace_format,
        return_trace=return_trace,
    )
    environment = RolloutEnvSpec(env_name="crafter", seed=seed, config={})
    policy = RolloutPolicySpec(policy_name="crafter-react", config=policy_settings)

    return RolloutRequest(
        run_id=run_id,
        env=environment,
        policy=policy,
        ops=ops,
        record=recording,
        on_done="reset",
        safety=RolloutSafetyConfig(),
    )
56
+
57
+
58
def mask_value(value: str | None) -> str:
    """Return a redacted rendering of a secret that is safe to print.

    Long values show their first six and last four characters around an
    ellipsis, followed by the total length.  Values that are not longer than
    those ten visible characters are replaced entirely by asterisks, because
    the prefix and suffix slices would otherwise cover the whole string and
    print the secret verbatim.

    Args:
        value: The secret to mask; ``None`` or an empty string means "not set".

    Returns:
        ``"<unset>"`` for a missing value, otherwise the masked text with a
        ``(len=N)`` suffix.
    """
    if not value:
        return "<unset>"

    visible_prefix, visible_suffix = 6, 4
    length = len(value)
    if length <= visible_prefix + visible_suffix:
        # Nothing would actually be hidden by prefix/suffix slicing here.
        return f"{'*' * length} (len={length})"
    return f"{value[:visible_prefix]}…{value[-visible_suffix:]} (len={length})"
62
+
63
+
64
def build_ops(max_llm_calls: int, explicit_ops: str | None) -> list[str]:
    """Produce the list of rollout ops for a single rollout.

    Args:
        max_llm_calls: Number of ``agent``/``env`` pairs to generate when no
            explicit ops are given.  Values below 1 are raised to 1 and values
            above 50 are capped at 50 (a warning is printed when capping).
        explicit_ops: Optional comma-separated ops string; when non-empty it
            takes precedence over ``max_llm_calls``.

    Returns:
        The stripped, non-empty entries of ``explicit_ops``, or alternating
        ``"agent"``/``"env"`` entries.

    Raises:
        ValueError: If ``explicit_ops`` is non-empty but contains no entries
            once blanks are removed.
    """
    if explicit_ops:
        stripped = (piece.strip() for piece in explicit_ops.split(","))
        parsed = [piece for piece in stripped if piece]
        if parsed:
            return parsed
        raise ValueError("--ops must contain at least one entry")

    pair_count = max_llm_calls if max_llm_calls >= 1 else 1
    if pair_count > 50:
        print("[WARN] --max-llm-calls capped at 50 per rollout; use --ops for manual control.")
        pair_count = 50

    return ["agent", "env"] * pair_count
80
+
81
+
82
def extract_achievements(step_info: dict[str, Any] | None) -> list[str]:
    """Collect achievement names mentioned in one step's info payload.

    Names are gathered, in this order, from ``step_info["achievements_added"]``
    and then from the ``"all"``, ``"achievements"``, ``"unique"`` and
    ``"unique_achievements"`` lists under
    ``step_info["meta"]["decision_rewards"]``.  Entries that are missing or of
    an unexpected type are skipped; duplicates are kept.

    Args:
        step_info: The step's info mapping, or anything else (ignored).

    Returns:
        All collected names converted with ``str``.
    """
    if not isinstance(step_info, dict):
        return []

    collected: list[str] = []

    newly_added = step_info.get("achievements_added")
    if isinstance(newly_added, list):
        collected += [str(entry) for entry in newly_added]

    meta = step_info.get("meta")
    decision = meta.get("decision_rewards") if isinstance(meta, dict) else None
    if not isinstance(decision, dict):
        return collected

    for field in ("all", "achievements", "unique", "unique_achievements"):
        values = decision.get(field)
        if isinstance(values, list):
            collected += [str(entry) for entry in values]
    return collected
104
+
105
+
106
def analyse_rollout_response(response: Any) -> dict[str, Any]:
    """Reduce a rollout response to the figures used by the batch summary.

    Reads ``response.metrics`` (``episode_returns``, ``num_steps``), the first
    entry of ``response.trajectories`` (its ``steps`` and each step's ``info``)
    and ``response.trace``.  Achievement names come from three places, kept in
    this order: per-step info, the trace's ``"decision_rewards"`` entries, and
    the trace metadata's ``"final_achievements"``.

    Args:
        response: Rollout response object exposing the attributes above.

    Returns:
        A dict with ``"return"`` (first episode return, ``0.0`` if none),
        ``"steps"``, ``"achievements_all"`` (with duplicates),
        ``"achievements_unique"`` (sorted, de-duplicated), ``"trace"`` and
        ``"metrics"``.
    """
    metrics = response.metrics
    trajectories = response.trajectories
    first_trajectory = trajectories[0] if trajectories else None

    episode_returns = metrics.episode_returns
    first_return = episode_returns[0] if episode_returns else 0.0
    step_total = metrics.num_steps

    from_steps: list[str] = []
    if first_trajectory is not None:
        for step in first_trajectory.steps:
            from_steps += extract_achievements(step.info)

    # A missing/empty trace is normalised to an empty dict.
    trace_payload = response.trace or {}
    trace_is_mapping = isinstance(trace_payload, dict)

    metadata = trace_payload.get("metadata") if trace_is_mapping else {}
    from_metadata: list[str] = []
    if isinstance(metadata, dict):
        finals = metadata.get("final_achievements")
        if isinstance(finals, list):
            from_metadata = [str(entry) for entry in finals]

    reward_entries = trace_payload.get("decision_rewards") if trace_is_mapping else []
    from_trace: list[str] = []
    if isinstance(reward_entries, list):
        for entry in reward_entries:
            if not isinstance(entry, dict):
                continue
            for field in ("achievements", "all", "unique", "unique_achievements"):
                values = entry.get(field)
                if isinstance(values, list):
                    from_trace += [str(value) for value in values]

    everything = from_steps + from_trace + from_metadata
    distinct = sorted({str(name) for name in everything})

    return {
        "return": float(first_return),
        "steps": int(step_total),
        "achievements_all": everything,
        "achievements_unique": distinct,
        "trace": trace_payload,
        "metrics": metrics,
    }
147
+
148
+
149
def summarise_runs(run_summaries: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate per-run summaries into batch-level statistics.

    Args:
        run_summaries: Dicts carrying ``"return"``, ``"steps"``,
            ``"achievements_all"`` and ``"achievements_unique"``.

    Returns:
        An empty dict when there are no runs; otherwise a dict with
        ``"count"``, ``"returns"`` (mean/median/min/max/total),
        ``"total_steps"`` and three ``Counter`` objects:
        ``"achievements_all"`` (occurrences across runs),
        ``"achievements_unique"`` (number of runs each achievement appears in)
        and ``"unique_count_hist"`` (runs keyed by their number of distinct
        achievements).
    """
    if not run_summaries:
        return {}

    episode_returns = [run["return"] for run in run_summaries]
    step_total = sum(run["steps"] for run in run_summaries)

    occurrence_counts = Counter()
    runs_per_achievement = Counter()
    distinct_histogram = Counter()
    for run in run_summaries:
        occurrence_counts.update(run["achievements_all"])
        distinct_names = set(run["achievements_unique"])
        runs_per_achievement.update(distinct_names)
        distinct_histogram[len(distinct_names)] += 1

    return {
        "count": len(run_summaries),
        "returns": {
            "mean": mean(episode_returns),
            "median": median(episode_returns),
            "min": min(episode_returns),
            "max": max(episode_returns),
            "total": sum(episode_returns),
        },
        "total_steps": step_total,
        "achievements_all": occurrence_counts,
        "achievements_unique": runs_per_achievement,
        "unique_count_hist": distinct_histogram,
    }
181
+
182
+
183
def print_summary(stats: dict[str, Any], *, run_details: list[dict[str, Any]], total_runs: int) -> None:
    """Print the batch statistics and the ten best runs to stdout.

    Args:
        stats: Aggregate produced for the successful runs; when empty only a
            "no successful rollouts" line is printed.
        run_details: One dict per successful run with ``"run_id"``, ``"seed"``
            and ``"summary"`` (itself holding ``"return"``, ``"steps"`` and
            ``"achievements_unique"``).
        total_runs: Number of rollouts attempted, shown next to the success
            count.
    """
    if not stats:
        print("No successful rollouts to summarise.")
        return

    returns = stats["returns"]
    print("Rollout summary:")
    print(f" Runs succeeded: {stats['count']} / {total_runs}")
    print(f" Total steps : {stats['total_steps']}")
    returns_text = ", ".join(
        f"{label}={returns[label]:.2f}"
        for label in ("mean", "median", "min", "max", "total")
    )
    print(f" Returns : {returns_text}")

    histogram = stats["unique_count_hist"]
    if histogram:
        print(" Unique achievement counts per run:")
        for distinct in sorted(histogram):
            print(f" {distinct:02d} unique -> {histogram[distinct]} run(s)")

    # Both achievement tables share one layout; only heading and unit differ.
    sections = (
        (" Achievements unlocked (by runs):", "achievements_unique", "run(s)"),
        (" Achievement unlock events (total occurrences):", "achievements_all", "event(s)"),
    )
    for heading, stats_key, unit in sections:
        by_frequency = stats[stats_key].most_common()
        if not by_frequency:
            continue
        print(heading)
        for name, freq in by_frequency:
            print(f" {name}: {freq} {unit}")

    print("\nTop runs by return:")

    def run_return(detail: dict[str, Any]) -> Any:
        return detail["summary"]["return"]

    leaders = sorted(run_details, key=run_return, reverse=True)[:10]
    for position, detail in enumerate(leaders, 1):
        run_summary = detail["summary"]
        print(
            f" {position:02d}. run_id={detail['run_id']} seed={detail['seed']} "
            f"return={run_summary['return']:.2f} steps={run_summary['steps']} "
            f"achievements={run_summary['achievements_unique']}"
        )
226
+
227
+
228
async def execute_rollouts(args: argparse.Namespace) -> None:
    """Run ``args.count`` rollouts against the task app and print a summary.

    Steps:
      1. Optionally load a ``.env`` file (``args.env_file``) without
         overriding variables that are already set.
      2. Resolve the task app key from ``args.api_key`` or
         ``ENVIRONMENT_API_KEY``; read ``SYNTH_API_KEY`` and, when it is set
         and the inference URL does not contain ``openai.com``, forward it to
         the policy as a bearer ``Authorization`` header.
      3. Launch one task per rollout, with at most ``max(1, args.parallel)``
         running at a time.  Rollout ``i`` uses run id
         ``f"{args.run_id}-{i:03d}"`` and seed
         ``args.seed + i * args.seed_stride``.
      4. Print aggregate statistics for the successful runs and list failures.

    An exception raised by an individual rollout is captured and reported in
    the failure list instead of aborting the batch.

    Args:
        args: Parsed command-line options (see ``parse_args``).

    Raises:
        FileNotFoundError: If ``args.env_file`` is given but does not exist.
        RuntimeError: If no task app API key can be resolved.
        ValueError: If ``args.ops`` is given but contains no entries.
    """
    if args.env_file:
        env_path = Path(args.env_file).expanduser()
        if not env_path.exists():
            raise FileNotFoundError(f"Env file not found: {env_path}")
        load_dotenv(env_path, override=False)

    api_key = args.api_key or os.getenv("ENVIRONMENT_API_KEY")
    if not api_key:
        raise RuntimeError("Missing --api-key or ENVIRONMENT_API_KEY")

    synth_key = os.getenv("SYNTH_API_KEY")
    extra_headers: dict[str, str] | None = None
    if synth_key and "openai.com" not in args.inference_url.lower():
        extra_headers = {"Authorization": f"Bearer {synth_key}"}

    if args.verbose:
        print("Resolved configuration:")
        print(f" Task app base URL : {args.base_url}")
        print(f" Inference base URL : {args.inference_url}")
        print(f" Task app API key : {mask_value(api_key)}")
        print(f" Synth API key : {mask_value(synth_key)}")
        print(f" HTTP timeout : {args.timeout:.1f}s")
        print(f" Rollouts : {args.count} (parallel={args.parallel})")

    ops = build_ops(args.max_llm_calls, args.ops)

    async with TaskAppClient(args.base_url, api_key=api_key, timeout=args.timeout) as client:

        async def run_single(index: int) -> dict[str, Any]:
            """Execute rollout ``index`` and return an ok/error result record."""
            run_id = f"{args.run_id}-{index:03d}"
            seed = args.seed + index * args.seed_stride
            request = build_rollout_request(
                seed=seed,
                run_id=run_id,
                model=args.model,
                inference_url=args.inference_url,
                ops=ops,
                extra_headers=extra_headers,
                trace_format=args.trace_format,
                return_trace=True,
            )
            if args.max_policy_tokens is not None:
                # Both keys are set to the same limit.
                request.policy.config.update({
                    "max_completion_tokens": args.max_policy_tokens,
                    "max_tokens": args.max_policy_tokens,
                })

            try:
                response = await client.rollout(request)
                summary = analyse_rollout_response(response)
                return {
                    "ok": True,
                    "run_id": run_id,
                    "seed": seed,
                    "response": response,
                    "summary": summary,
                }
            except Exception as exc:  # pragma: no cover - surface errors
                # Deliberately broad: one failed rollout must not cancel the rest.
                return {
                    "ok": False,
                    "run_id": run_id,
                    "seed": seed,
                    "error": exc,
                }

        semaphore = asyncio.Semaphore(max(1, args.parallel))

        async def guarded_run(index: int) -> dict[str, Any]:
            """Run one rollout while holding a concurrency slot."""
            async with semaphore:
                return await run_single(index)

        tasks = [asyncio.create_task(guarded_run(i)) for i in range(args.count)]
        results = await asyncio.gather(*tasks)

    successes = [item for item in results if item.get("ok")]
    failures = [item for item in results if not item.get("ok")]

    stats = summarise_runs([item["summary"] for item in successes])
    print_summary(stats, run_details=successes, total_runs=args.count)

    if failures:
        print("\nFailures:")
        for item in failures:
            err = item.get("error")
            # repr() keeps the exception type visible; str() is empty for
            # exceptions raised without a message (e.g. TimeoutError()).
            print(f" run_id={item['run_id']} seed={item['seed']} error={err!r}")
313
+
314
+
315
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the batch rollout script.

    Options are registered from a table so that each flag and its settings sit
    on one line; registration order (and therefore ``--help`` order) is the
    order of the table.

    Returns:
        The namespace produced by ``argparse`` from ``sys.argv``.
    """
    option_table: list[tuple[str, dict[str, Any]]] = [
        ("--base-url", {"default": "http://localhost:8001", "help": "Task app base URL"}),
        ("--api-key", {"help": "Environment API key (or set via --env-file)"}),
        ("--env-file", {"help": "Path to .env file providing API keys"}),
        ("--model", {"default": "gpt-4o-mini", "help": "Model identifier for the Crafter policy"}),
        ("--inference-url", {"default": "https://api.openai.com", "help": "Inference base URL for the policy"}),
        ("--seed", {"type": int, "default": 42, "help": "Base seed for the first rollout"}),
        ("--seed-stride", {"type": int, "default": 1, "help": "Increment applied to the seed for each rollout"}),
        ("--count", {"type": int, "default": 20, "help": "Number of rollout trajectories to execute"}),
        ("--parallel", {"type": int, "default": 4, "help": "Maximum concurrent rollouts"}),
        ("--ops", {"help": "Comma-separated rollout ops (advanced override)"}),
        ("--max-llm-calls", {"type": int, "default": 20, "help": "Number of agent/env pairs per rollout when --ops not provided"}),
        ("--max-policy-tokens", {"type": int, "help": "Optional per-call token limit forwarded to the policy config"}),
        ("--timeout", {"type": float, "default": 600.0, "help": "HTTP timeout (seconds) for task app requests"}),
        ("--trace-format", {"default": "compact", "choices": ["compact", "full"], "help": "Trace format requested from the task app"}),
        ("--run-id", {"default": "batch-demo", "help": "Run ID prefix for rollouts"}),
        ("--verbose", {"action": "store_true", "help": "Print resolved configuration"}),
    ]

    parser = argparse.ArgumentParser(description=__doc__)
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
    return parser.parse_args()
334
+
335
+
336
def main() -> None:
    """Script entry point: parse the CLI options and run the batch to completion."""
    cli_args = parse_args()
    batch = execute_rollouts(cli_args)
    asyncio.run(batch)


# Only run the batch when executed as a script, not when imported.
if __name__ == "__main__":
    main()