synth-ai 0.2.9.dev3__py3-none-any.whl → 0.2.9.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of synth-ai might be problematic.

Files changed (107)
  1. examples/analyze_semantic_words.sh +17 -0
  2. examples/common_old/backend.py +21 -0
  3. examples/crafter_debug_render.py +180 -0
  4. examples/evals_old/README.md +98 -0
  5. examples/evals_old/__init__.py +6 -0
  6. examples/evals_old/compare_models.py +1037 -0
  7. examples/evals_old/example_log.md +145 -0
  8. examples/evals_old/run_demo.sh +126 -0
  9. examples/evals_old/trace_analysis.py +270 -0
  10. examples/finetuning_old/_backup_synth_qwen/config.toml +29 -0
  11. examples/finetuning_old/_backup_synth_qwen/example_log.md +324 -0
  12. examples/finetuning_old/_backup_synth_qwen/filter_traces.py +60 -0
  13. examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +239 -0
  14. examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +109 -0
  15. examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +1924 -0
  16. examples/finetuning_old/_backup_synth_qwen/readme.md +49 -0
  17. examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +114 -0
  18. examples/finetuning_old/_backup_synth_qwen/run_demo.sh +195 -0
  19. examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +118 -0
  20. examples/finetuning_old/synth_qwen_v1/README.md +68 -0
  21. examples/finetuning_old/synth_qwen_v1/filter_traces.py +60 -0
  22. examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +239 -0
  23. examples/finetuning_old/synth_qwen_v1/finetune.py +46 -0
  24. examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +71 -0
  25. examples/finetuning_old/synth_qwen_v1/infer.py +37 -0
  26. examples/finetuning_old/synth_qwen_v1/poll.py +44 -0
  27. examples/finetuning_old/synth_qwen_v1/prepare_data.py +35 -0
  28. examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +109 -0
  29. examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +1932 -0
  30. examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +207 -0
  31. examples/finetuning_old/synth_qwen_v1/run_ft_job.py +232 -0
  32. examples/finetuning_old/synth_qwen_v1/upload_data.py +34 -0
  33. examples/finetuning_old/synth_qwen_v1/util.py +147 -0
  34. examples/rl/README.md +169 -0
  35. examples/rl/configs/eval_base_qwen.toml +15 -0
  36. examples/rl/configs/eval_rl_qwen.toml +11 -0
  37. examples/rl/configs/rl_from_base_qwen.toml +35 -0
  38. examples/rl/configs/rl_from_base_qwen17.toml +74 -0
  39. examples/rl/configs/rl_from_ft_qwen.toml +35 -0
  40. examples/rl/download_dataset.py +64 -0
  41. examples/rl/run_eval.py +435 -0
  42. examples/rl/run_rl_and_save.py +94 -0
  43. examples/rl/task_app/README.md +22 -0
  44. {synth_ai/task/apps → examples/rl/task_app}/math_single_step.py +8 -8
  45. examples/rl/task_app/math_task_app.py +107 -0
  46. examples/rl_old/task_app.py +962 -0
  47. examples/run_crafter_demo.sh +10 -0
  48. examples/warming_up_to_rl/analyze_trace_db.py +420 -0
  49. examples/warming_up_to_rl/configs/crafter_fft.toml +48 -0
  50. examples/warming_up_to_rl/configs/crafter_fft_4b.toml +54 -0
  51. examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +20 -0
  52. examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +13 -0
  53. examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +23 -0
  54. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +73 -0
  55. examples/warming_up_to_rl/configs/rl_from_ft.toml +56 -0
  56. examples/warming_up_to_rl/export_trace_sft.py +541 -0
  57. examples/warming_up_to_rl/groq_test.py +88 -0
  58. examples/warming_up_to_rl/manage_secrets.py +127 -0
  59. examples/warming_up_to_rl/old/event_rewards.md +234 -0
  60. examples/warming_up_to_rl/old/notes.md +73 -0
  61. examples/warming_up_to_rl/readme.md +172 -0
  62. examples/warming_up_to_rl/run_eval.py +434 -0
  63. examples/warming_up_to_rl/run_fft_and_save.py +309 -0
  64. examples/warming_up_to_rl/run_local_rollout.py +188 -0
  65. examples/warming_up_to_rl/run_local_rollout_modal.py +160 -0
  66. examples/warming_up_to_rl/run_local_rollout_parallel.py +342 -0
  67. examples/warming_up_to_rl/run_local_rollout_traced.py +372 -0
  68. examples/warming_up_to_rl/run_rl_and_save.py +101 -0
  69. examples/warming_up_to_rl/run_rollout_remote.py +129 -0
  70. examples/warming_up_to_rl/task_app/README.md +38 -0
  71. {synth_ai/task/apps → examples/warming_up_to_rl/task_app}/grpo_crafter.py +7 -7
  72. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +165 -0
  73. examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
  74. examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
  75. examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +145 -0
  76. examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1271 -0
  77. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
  78. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
  79. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
  80. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +429 -0
  81. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +442 -0
  82. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +96 -0
  83. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +302 -0
  84. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
  85. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +202 -0
  86. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
  87. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +512 -0
  88. examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +102 -0
  89. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +985 -0
  90. examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +197 -0
  91. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1749 -0
  92. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
  93. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +217 -0
  94. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +160 -0
  95. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +146 -0
  96. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +61 -0
  97. synth_ai/api/train/config_finder.py +18 -18
  98. synth_ai/api/train/env_resolver.py +28 -1
  99. synth_ai/cli/task_apps.py +291 -56
  100. synth_ai/task/apps/__init__.py +54 -13
  101. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/METADATA +1 -1
  102. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/RECORD +106 -13
  103. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/top_level.txt +1 -0
  104. synth_ai/environments/examples/sokoban/units/astar_common.py +0 -95
  105. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/WHEEL +0 -0
  106. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/entry_points.txt +0 -0
  107. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/licenses/LICENSE +0 -0
examples/warming_up_to_rl/run_local_rollout_traced.py
@@ -0,0 +1,372 @@
+ #!/usr/bin/env python3
+ """Run a local Crafter rollout, capture tracing metadata, and optionally persist the trace."""
+
+ from __future__ import annotations
+
+ import argparse
+ import asyncio
+ import json
+ from pathlib import Path
+ from typing import Any
+
+ import sys
+
+ import httpx
+
+ from synth_ai.task import (
+     RolloutEnvSpec,
+     RolloutPolicySpec,
+     RolloutRecordConfig,
+     RolloutRequest,
+     RolloutSafetyConfig,
+     TaskAppClient,
+ )
+
+
+ def build_rollout_request(
+     *,
+     seed: int,
+     run_id: str,
+     model: str,
+     inference_url: str,
+     ops: list[str],
+     return_trace: bool,
+     trace_format: str,
+     max_policy_tokens: int | None,
+ ) -> RolloutRequest:
+     policy_config = {
+         "model": model,
+         "inference_url": inference_url,
+     }
+     if max_policy_tokens is not None:
+         policy_config.update(
+             {
+                 "max_completion_tokens": max_policy_tokens,
+                 "max_tokens": max_policy_tokens,
+             }
+         )
+
+     record = RolloutRecordConfig(
+         trajectories=True,
+         return_trace=return_trace,
+         trace_format=trace_format,
+     )
+
+     return RolloutRequest(
+         run_id=run_id,
+         env=RolloutEnvSpec(env_name="crafter", seed=seed, config={}),
+         policy=RolloutPolicySpec(policy_name="crafter-react", config=policy_config),
+         ops=ops,
+         record=record,
+         on_done="reset",
+         safety=RolloutSafetyConfig(),
+     )
+
+
+ def summarise_rollout(response: Any) -> dict[str, Any]:
+     metrics = response.metrics.model_dump() if hasattr(response, "metrics") else response.get("metrics", {})
+     return {
+         "run_id": getattr(response, "run_id", None) or response.get("run_id"),
+         "num_episodes": metrics.get("num_episodes"),
+         "num_steps": metrics.get("num_steps"),
+         "episode_returns": metrics.get("episode_returns"),
+         "outcome_score": metrics.get("outcome_score"),
+         "events_score": metrics.get("events_score"),
+     }
+
+
+ def summarise_trace(trace: Any) -> dict[str, Any]:
+     if trace is None:
+         return {"trace": None}
+     if not isinstance(trace, dict):
+         return {"trace_type": type(trace).__name__}
+
+     format_hint = "compact" if "events_count" in trace or "lm_calls" in trace else "full"
+     events_count = trace.get("events_count")
+     if events_count is None and "event_history" in trace and isinstance(trace["event_history"], list):
+         events_count = len(trace["event_history"])
+     messages_count = trace.get("messages_count")
+     if messages_count is None and "markov_blanket_message_history" in trace and isinstance(
+         trace["markov_blanket_message_history"], list
+     ):
+         messages_count = len(trace["markov_blanket_message_history"])
+
+     metadata = trace.get("metadata") if isinstance(trace.get("metadata"), dict) else {}
+     lm_calls = trace.get("lm_calls") if isinstance(trace.get("lm_calls"), list) else []
+     decision_rewards = trace.get("decision_rewards") if isinstance(trace.get("decision_rewards"), list) else []
+
+     return {
+         "session_id": trace.get("session_id"),
+         "format": format_hint,
+         "events_count": events_count,
+         "messages_count": messages_count,
+         "metadata_keys": sorted(metadata.keys()),
+         "lm_calls_count": len(lm_calls),
+         "decision_turns": len(decision_rewards),
+     }
+
+
+ def ensure_ops(ops_arg: str | None, max_llm_calls: int) -> list[str]:
+     if ops_arg:
+         ops = [op.strip() for op in ops_arg.split(",") if op.strip()]
+         if not ops:
+             raise ValueError("--ops must contain at least one entry when provided")
+         return ops
+     max_llm_calls = max(max_llm_calls, 1)
+     ops: list[str] = []
+     for _ in range(max_llm_calls):
+         ops.extend(["agent", "env"])
+     return ops
+
+
+ def dump_trace(trace: dict[str, Any], *, path: Path, pretty: bool) -> None:
+     path.parent.mkdir(parents=True, exist_ok=True)
+     with path.open("w", encoding="utf-8") as fh:
+         json.dump(trace, fh, indent=2 if pretty else None)
+         fh.write("\n")
+
+
+ def extract_environment_rewards(trace_payload: dict[str, Any] | None) -> list[float]:
+     if not trace_payload:
+         return []
+
+     rewards: list[float] = []
+
+     def _collect(events: list[dict[str, Any]]) -> None:
+         for event in events:
+             reward = event.get("reward")
+             if reward is not None:
+                 try:
+                     rewards.append(float(reward))
+                 except Exception:
+                     continue
+
+     if isinstance(trace_payload.get("event_history"), list):
+         _collect(trace_payload["event_history"])
+     if isinstance(trace_payload.get("session_time_steps"), list):
+         for step in trace_payload["session_time_steps"]:
+             _collect(step.get("events", []))
+
+     return rewards
+
+
+ def extract_decision_rewards(trace_payload: dict[str, Any] | None) -> list[dict[str, Any]]:
+     if not trace_payload:
+         return []
+     rewards = trace_payload.get("decision_rewards")
+     return rewards if isinstance(rewards, list) else []
+
+
+ def extract_trajectory_rewards(response: Any) -> list[float]:
+     """Extract per-step rewards directly from the rollout trajectories."""
+
+     rewards: list[float] = []
+
+     if response is None:
+         return rewards
+
+     trajectories = getattr(response, "trajectories", None)
+     if trajectories is None and isinstance(response, dict):
+         trajectories = response.get("trajectories")
+
+     if not trajectories:
+         return rewards
+
+     for traj in trajectories:
+         steps = getattr(traj, "steps", None)
+         if steps is None and isinstance(traj, dict):
+             steps = traj.get("steps")
+         if not steps:
+             continue
+         for step in steps:
+             reward_val = getattr(step, "reward", None)
+             if reward_val is None and isinstance(step, dict):
+                 reward_val = step.get("reward")
+             if reward_val is None:
+                 continue
+             try:
+                 rewards.append(float(reward_val))
+             except Exception:
+                 continue
+
+     return rewards
+
+
+ def print_reward_summary(
+     trace_payload: dict[str, Any] | None,
+     rollout_summary: dict[str, Any],
+     trajectory_rewards: list[float],
+ ) -> None:
+     print("Reward summary:")
+
+     env_rewards = extract_environment_rewards(trace_payload)
+     reward_source = "trace"
+     if not env_rewards and trajectory_rewards:
+         env_rewards = trajectory_rewards
+         reward_source = "trajectory"
+
+     if env_rewards:
+         print(f" Environment rewards per step ({reward_source}): {env_rewards}")
+         print(f" Environment reward total: {sum(env_rewards):.3f}")
+     else:
+         print(" Environment rewards per step: none recorded")
+
+     decision_rewards = extract_decision_rewards(trace_payload)
+     if decision_rewards:
+         print(" Decision rewards:")
+         for entry in decision_rewards:
+             turn = entry.get('turn')
+             ach_delta = entry.get('ach_delta')
+             unique_delta = entry.get('unique_delta')
+             achievements = entry.get('achievements') or []
+             print(f" turn={turn}, ach_delta={ach_delta}, unique_delta={unique_delta}, achievements={achievements}")
+     else:
+         print(" Decision rewards: none recorded")
+
+     episode_returns = rollout_summary.get("episode_returns")
+     if episode_returns:
+         print(f" Outcome rewards (episode returns): {episode_returns}")
+         if env_rewards:
+             try:
+                 total_env_reward = float(sum(env_rewards))
+                 target = float(episode_returns[0]) if episode_returns else 0.0
+                 if abs(total_env_reward - target) > 1e-6:
+                     print(
+                         " ⚠️ Reward mismatch: sum(environment rewards)"
+                         f"={total_env_reward:.3f} vs episode return={target:.3f}"
+                     )
+             except Exception:
+                 pass
+     else:
+         print(" Outcome rewards: none recorded")
+
+
+ async def main() -> None:
+     parser = argparse.ArgumentParser(description=__doc__)
+     parser.add_argument("--base-url", default="http://localhost:8010", help="Task app base URL")
+     parser.add_argument("--api-key", required=True, help="Environment API key")
+     parser.add_argument("--seed", type=int, default=42, help="Environment seed")
+     parser.add_argument("--run-id", default="local-trace", help="Run identifier")
+     parser.add_argument("--model", default="gpt-4o-mini", help="OpenAI-compatible model id")
+     parser.add_argument("--inference-url", default="https://api.openai.com", help="Inference base URL (OpenAI/Groq)")
+     parser.add_argument("--ops", help="Comma-separated rollout ops (fallback: alternating agent/env)")
+     parser.add_argument("--max-llm-calls", type=int, default=1, help="Number of agent/env pairs when --ops not supplied")
+     parser.add_argument("--max-policy-tokens", type=int, default=None, help="Optional max token budget forwarded to policy")
+     parser.add_argument(
+         "--trace-format",
+         choices=["compact", "full"],
+         default="compact",
+         help="Trace payload format requested from the server",
+     )
+     parser.add_argument(
+         "--trace-path",
+         type=Path,
+         help="Path to write the trace JSON (defaults to ./<run_id>_trace.json unless --no-trace-file is set)",
+     )
+     parser.add_argument(
+         "--no-trace-file",
+         action="store_true",
+         help="Do not write the trace JSON to disk",
+     )
+     parser.add_argument(
+         "--no-print-trace",
+         action="store_true",
+         help="Do not print the full trace payload to stdout",
+     )
+     parser.add_argument(
+         "--no-trace",
+         action="store_true",
+         help="Disable return_trace (useful for comparing behaviour without tracing)",
+     )
+     parser.add_argument(
+         "--timeout",
+         type=float,
+         default=60.0,
+         help="HTTP timeout in seconds for the client (default: 60)",
+     )
+     args = parser.parse_args()
+
+     ops = ensure_ops(args.ops, args.max_llm_calls)
+     return_trace = not args.no_trace
+
+     async with TaskAppClient(args.base_url, api_key=args.api_key, timeout=args.timeout) as client:
+         try:
+             print(f"Fetching task_info for seed {args.seed}…")
+             task_info = await client.task_info(seeds=[args.seed])
+             info_payload = task_info[0] if isinstance(task_info, list) else task_info
+             try:
+                 print(json.dumps(info_payload.model_dump(), indent=2)[:600])
+             except Exception:
+                 print(info_payload)
+
+             request = build_rollout_request(
+                 seed=args.seed,
+                 run_id=args.run_id,
+                 model=args.model,
+                 inference_url=args.inference_url,
+                 ops=ops,
+                 return_trace=return_trace,
+                 trace_format=args.trace_format,
+                 max_policy_tokens=args.max_policy_tokens,
+             )
+
+             print("Requesting rollout…")
+             response = await client.rollout(request)
+             summary = summarise_rollout(response)
+             print(json.dumps(summary, indent=2))
+
+             trace_payload: dict[str, Any] | None = getattr(response, "trace", None)
+             if return_trace:
+                 if trace_payload is None:
+                     print(
+                         "⚠️ Server did not include a trace. Ensure TASKAPP_TRACING_ENABLED=1 when starting the task app.",
+                         file=sys.stderr,
+                     )
+                 else:
+                     trace_summary = summarise_trace(trace_payload)
+                     print("Trace summary:")
+                     print(json.dumps(trace_summary, indent=2))
+
+                     trace_path = args.trace_path
+                     if not args.no_trace_file:
+                         if trace_path is None:
+                             trace_path = Path(f"{args.run_id}_trace.json")
+                         dump_trace(trace_payload, path=trace_path, pretty=True)
+                         print(f"Trace written to {trace_path}")
+
+                     if not args.no_print_trace:
+                         print("Full trace payload:")
+                         print(json.dumps(trace_payload, indent=2))
+
+             trajectory_rewards = extract_trajectory_rewards(response)
+             print_reward_summary(
+                 trace_payload if return_trace else None,
+                 summary,
+                 trajectory_rewards,
+             )
+
+             print(f"Ops executed: {ops}")
+             print(
+                 "Tip: export TASKAPP_TRACING_ENABLED=1 and optionally TASKAPP_SFT_OUTPUT_DIR before running `uvx synth-ai serve …` to persist traces/SFT."
+             )
+         except httpx.HTTPStatusError as exc:
+             detail = exc.response.json() if exc.response.headers.get("content-type", "").startswith("application/json") else exc.response.text
+             print(f"HTTP error {exc.response.status_code}: {detail}", file=sys.stderr)
+             if exc.response.status_code in (401, 503):
+                 print(
+                     "Hint: ensure the task app process is using the same ENVIRONMENT_API_KEY passed via --api-key.",
+                     file=sys.stderr,
+                 )
+             if exc.response.status_code == 500:
+                 print(
+                     "Hint: verify tracing is enabled server-side (TASKAPP_TRACING_ENABLED=1) and the inference credentials are configured.",
+                     file=sys.stderr,
+                 )
+             raise
+
+
+ if __name__ == "__main__":
+     try:
+         asyncio.run(main())
+     except KeyboardInterrupt:
+         print("Interrupted", file=sys.stderr)
examples/warming_up_to_rl/run_rl_and_save.py
@@ -0,0 +1,101 @@
+ #!/usr/bin/env python3
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import os
+ import sys
+ from pathlib import Path
+ from typing import Any, Dict
+
+ import tomllib
+ import requests
+
+
+ def _load_toml(path: Path) -> Dict[str, Any]:
+     if not path.exists():
+         print(f"config not found: {path}", file=sys.stderr)
+         sys.exit(2)
+     with path.open("rb") as fh:
+         return tomllib.load(fh)
+
+
+ def main() -> None:
+     p = argparse.ArgumentParser(description="Create clustered RL training job via backend RL endpoint")
+     p.add_argument("--backend", default=os.getenv("BACKEND_BASE_URL", "http://localhost:8000/api"))
+     p.add_argument("--config", required=True, help="Path to RL TOML config")
+     p.add_argument("--task-url", default=os.getenv("TASK_APP_URL", ""), help="Override task service URL (or set TASK_APP_URL)")
+     p.add_argument("--idempotency", default=os.getenv("RL_IDEMPOTENCY_KEY", ""), help="Optional Idempotency-Key header value")
+     args = p.parse_args()
+
+     cfg_path = Path(args.config).expanduser()
+     cfg = _load_toml(cfg_path)
+
+     services = cfg.get("services", {}) if isinstance(cfg.get("services"), dict) else {}
+
+     # Resolve task app base URL for the job
+     cli_task_url = (args.task_url or "").strip()
+     env_task_url = (os.getenv("TASK_APP_URL") or "").strip()
+     task_url = cli_task_url or env_task_url or ((services.get("task_url") or "").strip() if isinstance(services, dict) else "")
+     if not task_url:
+         print("Missing task service URL. Provide --task-url or set TASK_APP_URL or services.task_url in TOML", file=sys.stderr)
+         sys.exit(2)
+
+     # TOML-only model selection validation
+     model_cfg = cfg.get("model", {}) if isinstance(cfg.get("model"), dict) else {}
+     has_source = bool((model_cfg.get("source") or "").strip())
+     has_base = bool((model_cfg.get("base") or "").strip())
+     if has_source == has_base:
+         print("Model selection must specify exactly one of [model].source or [model].base in TOML", file=sys.stderr)
+         sys.exit(2)
+
+     # Build create-job payload. Send full TOML under data.config, plus endpoint_base_url.
+     payload: Dict[str, Any] = {
+         "job_type": "rl",
+         # Optional: compute pass-through
+         "compute": cfg.get("compute", {}) if isinstance(cfg.get("compute"), dict) else {},
+         "data": {
+             "endpoint_base_url": task_url,
+             "config": cfg,
+         },
+         "tags": {"source": "warming_up_to_rl"},
+     }
+
+     backend = str(args.backend).rstrip("/")
+     url = f"{backend}/rl/jobs"
+     api_key = (os.getenv("SYNTH_API_KEY") or os.getenv("synth_key") or "").strip()
+     if not api_key:
+         print("Missing SYNTH_API_KEY in env", file=sys.stderr)
+         sys.exit(2)
+
+     headers = {
+         "content-type": "application/json",
+         "authorization": f"Bearer {api_key}",
+     }
+     idem = (args.idempotency or "").strip()
+     if idem:
+         headers["Idempotency-Key"] = idem
+
+     print(f"[INFO] POST {url}")
+     try:
+         preview = dict(payload)
+         preview_data = dict(preview.get("data", {}))
+         cfg_keys = list(cfg.keys())
+         preview_data["config"] = {"keys": cfg_keys}
+         preview["data"] = preview_data
+         print(f"[INFO] Payload: {json.dumps(preview)[:500]}")
+     except Exception:
+         print("[INFO] Payload: <unavailable>")
+
+     r = requests.post(url, headers=headers, json=payload, timeout=120)
+     ok = r.status_code in (200, 201)
+     try:
+         snippet = r.json()
+     except Exception:
+         snippet = r.text[:300]
+     print(f"[INFO] Response: {r.status_code} {snippet}")
+     sys.exit(0 if ok else 1)
+
+
+ if __name__ == "__main__":
+     main()
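
A hedged sketch of how this job-submission script might be driven. All values below are placeholders; only `--config` is required, `SYNTH_API_KEY` must be present in the environment, and the TOML has to name exactly one of `[model].source` or `[model].base` or the script exits.

```bash
# Placeholders throughout; the backend defaults to http://localhost:8000/api.
export SYNTH_API_KEY=sk-synth-placeholder                  # required by the script
export TASK_APP_URL=https://example-task-app.modal.run     # or pass --task-url / set services.task_url in the TOML
python examples/warming_up_to_rl/run_rl_and_save.py \
  --config examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml \
  --idempotency my-unique-run-001
```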
examples/warming_up_to_rl/run_rollout_remote.py
@@ -0,0 +1,129 @@
+ #!/usr/bin/env python3
+ """Request a rollout from a remote Crafter task app (e.g., Modal deployment)."""
+
+ from __future__ import annotations
+
+ import argparse
+ import asyncio
+ import json
+ import os
+ import sys
+
+ import httpx
+
+ def check_health(base_url: str, api_key: str) -> None:
+     try:
+         resp = httpx.get(f"{base_url.rstrip('/')}/health", headers={"X-API-Key": api_key}, timeout=10.0)
+         data = resp.json() if resp.headers.get("content-type", "").startswith("application/json") else resp.text
+         if resp.status_code != 200:
+             print(f"warning: /health returned {resp.status_code}: {data}")
+         else:
+             print(f"/health ok: {data}")
+     except Exception as exc:
+         print(f"warning: failed to call /health: {exc}")
+
+ from synth_ai.task import (
+     RolloutEnvSpec,
+     RolloutPolicySpec,
+     RolloutRecordConfig,
+     RolloutRequest,
+     RolloutSafetyConfig,
+     TaskAppClient,
+ )
+
+
+ def build_request(
+     *,
+     run_id: str,
+     seed: int,
+     model: str,
+     inference_url: str,
+     llm_calls: int,
+     max_policy_tokens: int | None,
+ ) -> RolloutRequest:
+     policy_config = {"model": model, "inference_url": inference_url}
+     if max_policy_tokens is not None:
+         policy_config.update(
+             {
+                 "max_completion_tokens": max_policy_tokens,
+                 "max_tokens": max_policy_tokens,
+             }
+         )
+
+     ops: list[str] = []
+     for _ in range(max(llm_calls, 1)):
+         ops.extend(["agent", "env"])
+
+     return RolloutRequest(
+         run_id=run_id,
+         env=RolloutEnvSpec(env_name="crafter", seed=seed, config={}),
+         policy=RolloutPolicySpec(policy_name="crafter-react", config=policy_config),
+         ops=ops,
+         record=RolloutRecordConfig(trajectories=True),
+         on_done="reset",
+         safety=RolloutSafetyConfig(),
+     )
+
+
+ def summarise(response) -> dict[str, any]:
+     metrics = response.metrics
+     return {
+         "run_id": response.run_id,
+         "num_episodes": metrics.num_episodes,
+         "num_steps": metrics.num_steps,
+         "episode_returns": metrics.episode_returns,
+         "outcome_score": metrics.outcome_score,
+         "events_score": metrics.events_score,
+     }
+
+
+ async def main() -> None:
+     parser = argparse.ArgumentParser(description=__doc__)
+     parser.add_argument("--base-url", default=None, help="Remote task app base URL (e.g., https://xyz.modal.run); defaults to TASK_APP_BASE_URL env")
+     parser.add_argument("--api-key", required=True, help="Environment API key for the remote task app")
+     parser.add_argument("--seed", type=int, default=42)
+     parser.add_argument("--run-id", default="remote-demo")
+     parser.add_argument("--model", default="gpt-4o-mini")
+     parser.add_argument("--inference-url", default="https://api.openai.com")
+     parser.add_argument("--max-llm-calls", type=int, default=1)
+     parser.add_argument("--max-policy-tokens", type=int, default=None)
+     args = parser.parse_args()
+
+     base_url = args.base_url or os.getenv('TASK_APP_BASE_URL')
+     if not base_url:
+         parser.error('Missing --base-url (and TASK_APP_BASE_URL not set).')
+
+     request = build_request(
+         run_id=args.run_id,
+         seed=args.seed,
+         model=args.model,
+         inference_url=args.inference_url,
+         llm_calls=args.max_llm_calls,
+         max_policy_tokens=args.max_policy_tokens,
+     )
+
+     async with TaskAppClient(base_url, api_key=args.api_key) as client:
+         try:
+             check_health(base_url, args.api_key)
+             info = await client.task_info(seeds=[args.seed])
+             payload = info[0] if isinstance(info, list) else info
+             print(json.dumps(payload.model_dump(), indent=2)[:600])
+
+             print("Requesting rollout…")
+             response = await client.rollout(request)
+             print(json.dumps(summarise(response), indent=2))
+             print(f"Ops executed: {request.ops}")
+         except httpx.HTTPStatusError as exc:
+             detail = exc.response.json() if exc.response.headers.get("content-type", "").startswith("application/json") else exc.response.text
+             print(f"HTTP error {exc.response.status_code}: {detail}", file=sys.stderr)
+             if exc.response.status_code in (401, 403):
+                 print("Hint: check --api-key and ensure the remote deployment expects that value.", file=sys.stderr)
+             if exc.response.status_code == 404:
+                 print("Hint: verify the --base-url includes the correct path (should be the root of the task app).", file=sys.stderr)
+             if exc.response.status_code == 500:
+                 print("Hint: remote rollout failed server-side; inspect the deployment logs (Modal dashboard/logs).", file=sys.stderr)
+             raise
+
+
+ if __name__ == "__main__":
+     asyncio.run(main())
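
Likewise, a hedged example of calling this remote-rollout script against a deployed task app; the deployment URL and key are placeholders, and `--base-url` may be used instead of the environment variable.

```bash
# Placeholder deployment URL and key; flags match the argparse options above.
export TASK_APP_BASE_URL=https://your-crafter-task-app.modal.run
python examples/warming_up_to_rl/run_rollout_remote.py \
  --api-key "$ENVIRONMENT_API_KEY" \
  --model gpt-4o-mini \
  --max-llm-calls 2
```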
examples/warming_up_to_rl/task_app/README.md
@@ -0,0 +1,38 @@
+ # Crafter Task App
+
+ This example is now wired through the shared Synth task-app harness. Use the
+ `uvx synth-ai` CLI to run it locally or deploy it to Modal without touching the
+ underlying FastAPI plumbing.
+
+ ## Local development
+ ```bash
+ uvx synth-ai serve grpo-crafter --port 8001
+ # Optional extras:
+ #   --env-file path/to/.env   # load additional environment variables
+ #   --reload                  # enable uvicorn auto-reload
+ ```
+
+ Useful endpoints while the server is running:
+ - `GET http://localhost:8001/health`
+ - `GET http://localhost:8001/info`
+ - `GET http://localhost:8001/task_info?seed=42`
+ - `POST http://localhost:8001/rollout`
+
+ ## Deploy to Modal
+ ```bash
+ uvx synth-ai deploy grpo-crafter --name grpo-crafter-task-app
+ ```
+
+ Requirements:
+ - Modal CLI installed and authenticated (`modal token new`).
+ - Secrets `crafter-environment-sdk`, `groq-api-key`, and `openai-api-key`
+   available in your Modal account.
+
+ The CLI generates a Modal entrypoint on the fly using the shared
+ `TaskAppConfig`, ensuring the container matches the local FastAPI behavior.
+
+ ## Compatibility note
+ `examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py` remains as a
+ legacy wrapper exposing `fastapi_app()` and a `__main__` entrypoint. Behind the
+ scenes it proxies to the shared configuration; prefer the CLI workflow above
+ for new automation and tests.
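
As a quick smoke test of the endpoints listed in that README, something like the following should work once the `serve` command is running locally; the key is a placeholder, and the `X-API-Key` header mirrors what the example rollout scripts in this release send.

```bash
# Assumes `uvx synth-ai serve grpo-crafter --port 8001` is running and the key matches the server's.
curl -s -H "X-API-Key: $ENVIRONMENT_API_KEY" http://localhost:8001/health
curl -s -H "X-API-Key: $ENVIRONMENT_API_KEY" "http://localhost:8001/task_info?seed=42"
```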
{synth_ai/task/apps → examples/warming_up_to_rl/task_app}/grpo_crafter.py
@@ -8,13 +8,13 @@ from dataclasses import dataclass
  from pathlib import Path
  from typing import Any, Dict, Iterable, List, Sequence
 
- from ..contracts import RolloutRequest, RolloutResponse, TaskInfo
- from ..datasets import TaskDatasetRegistry, TaskDatasetSpec
- from ..rubrics import load_rubric
- from ..server import ProxyConfig, RubricBundle, TaskAppConfig
- from ..json import to_jsonable # noqa: F401 (imported for side-effect compatibility)
- from . import ModalDeploymentConfig, TaskAppEntry, register_task_app
- from ..tracing_utils import (
+ from synth_ai.task.contracts import RolloutRequest, RolloutResponse, TaskInfo
+ from synth_ai.task.datasets import TaskDatasetRegistry, TaskDatasetSpec
+ from synth_ai.task.rubrics import load_rubric
+ from synth_ai.task.server import ProxyConfig, RubricBundle, TaskAppConfig
+ from synth_ai.task.json import to_jsonable # noqa: F401 (imported for side-effect compatibility)
+ from synth_ai.task.apps import ModalDeploymentConfig, TaskAppEntry, register_task_app
+ from synth_ai.task.tracing_utils import (
      build_tracer_factory,
      resolve_sft_output_dir,
      resolve_tracing_db_url,