synth-ai 0.2.9.dev5__py3-none-any.whl → 0.2.9.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic.

Files changed (155)
  1. examples/common_old/backend.py +0 -1
  2. examples/crafter_debug_render.py +15 -6
  3. examples/evals_old/compare_models.py +1 -0
  4. examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +6 -2
  5. examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +4 -4
  6. examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +4 -3
  7. examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +6 -2
  8. examples/finetuning_old/synth_qwen_v1/finetune.py +1 -1
  9. examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +4 -4
  10. examples/finetuning_old/synth_qwen_v1/infer.py +1 -2
  11. examples/finetuning_old/synth_qwen_v1/poll.py +4 -2
  12. examples/finetuning_old/synth_qwen_v1/prepare_data.py +8 -8
  13. examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +5 -4
  14. examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +11 -8
  15. examples/finetuning_old/synth_qwen_v1/run_ft_job.py +17 -12
  16. examples/finetuning_old/synth_qwen_v1/upload_data.py +1 -1
  17. examples/finetuning_old/synth_qwen_v1/util.py +7 -2
  18. examples/rl/configs/eval_base_qwen.toml +1 -1
  19. examples/rl/configs/rl_from_base_qwen17.toml +1 -1
  20. examples/rl/download_dataset.py +26 -10
  21. examples/rl/run_eval.py +17 -15
  22. examples/rl/run_rl_and_save.py +24 -7
  23. examples/rl/task_app/math_single_step.py +128 -11
  24. examples/rl/task_app/math_task_app.py +11 -3
  25. examples/rl_old/task_app.py +222 -53
  26. examples/warming_up_to_rl/analyze_trace_db.py +7 -5
  27. examples/warming_up_to_rl/export_trace_sft.py +141 -16
  28. examples/warming_up_to_rl/groq_test.py +11 -4
  29. examples/warming_up_to_rl/manage_secrets.py +15 -6
  30. examples/warming_up_to_rl/readme.md +9 -2
  31. examples/warming_up_to_rl/run_eval.py +108 -30
  32. examples/warming_up_to_rl/run_fft_and_save.py +128 -52
  33. examples/warming_up_to_rl/run_local_rollout.py +87 -36
  34. examples/warming_up_to_rl/run_local_rollout_modal.py +113 -25
  35. examples/warming_up_to_rl/run_local_rollout_parallel.py +80 -16
  36. examples/warming_up_to_rl/run_local_rollout_traced.py +125 -20
  37. examples/warming_up_to_rl/run_rl_and_save.py +31 -7
  38. examples/warming_up_to_rl/run_rollout_remote.py +37 -10
  39. examples/warming_up_to_rl/task_app/grpo_crafter.py +90 -27
  40. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +9 -27
  41. examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +46 -108
  42. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -1
  43. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +1 -1
  44. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -1
  45. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +50 -17
  46. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +35 -21
  47. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +8 -4
  48. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +29 -26
  49. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +1 -1
  50. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +17 -13
  51. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +1 -1
  52. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +106 -63
  53. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +82 -84
  54. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +76 -59
  55. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +1 -1
  56. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +43 -49
  57. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +5 -15
  58. synth_ai/__init__.py +1 -0
  59. synth_ai/api/train/builders.py +34 -10
  60. synth_ai/api/train/cli.py +172 -32
  61. synth_ai/api/train/config_finder.py +59 -4
  62. synth_ai/api/train/env_resolver.py +32 -14
  63. synth_ai/api/train/pollers.py +11 -3
  64. synth_ai/api/train/task_app.py +4 -1
  65. synth_ai/api/train/utils.py +20 -4
  66. synth_ai/cli/__init__.py +11 -4
  67. synth_ai/cli/balance.py +1 -1
  68. synth_ai/cli/demo.py +19 -5
  69. synth_ai/cli/rl_demo.py +75 -16
  70. synth_ai/cli/root.py +116 -37
  71. synth_ai/cli/task_apps.py +1276 -186
  72. synth_ai/cli/traces.py +1 -0
  73. synth_ai/cli/turso.py +73 -0
  74. synth_ai/core/experiment.py +0 -2
  75. synth_ai/demo_registry.py +67 -30
  76. synth_ai/demos/core/cli.py +493 -164
  77. synth_ai/demos/demo_task_apps/core.py +50 -6
  78. synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +2 -3
  79. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +36 -28
  80. synth_ai/demos/demo_task_apps/math/_common.py +1 -2
  81. synth_ai/demos/demo_task_apps/math/deploy_modal.py +0 -2
  82. synth_ai/demos/demo_task_apps/math/modal_task_app.py +168 -65
  83. synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
  84. synth_ai/environments/examples/bandit/engine.py +12 -4
  85. synth_ai/environments/examples/bandit/taskset.py +4 -4
  86. synth_ai/environments/reproducibility/tree.py +3 -1
  87. synth_ai/environments/service/core_routes.py +6 -2
  88. synth_ai/evals/base.py +0 -2
  89. synth_ai/experimental/synth_oss.py +11 -12
  90. synth_ai/handshake.py +3 -1
  91. synth_ai/http_client.py +31 -7
  92. synth_ai/inference/__init__.py +0 -2
  93. synth_ai/inference/client.py +8 -4
  94. synth_ai/jobs/client.py +40 -10
  95. synth_ai/learning/client.py +33 -8
  96. synth_ai/learning/config.py +0 -2
  97. synth_ai/learning/constants.py +0 -2
  98. synth_ai/learning/ft_client.py +6 -3
  99. synth_ai/learning/health.py +9 -2
  100. synth_ai/learning/jobs.py +17 -5
  101. synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +1 -3
  102. synth_ai/learning/prompts/random_search.py +4 -1
  103. synth_ai/learning/prompts/run_random_search_banking77.py +6 -1
  104. synth_ai/learning/rl_client.py +42 -14
  105. synth_ai/learning/sse.py +0 -2
  106. synth_ai/learning/validators.py +6 -2
  107. synth_ai/lm/caching/ephemeral.py +1 -3
  108. synth_ai/lm/core/exceptions.py +0 -2
  109. synth_ai/lm/core/main.py +13 -1
  110. synth_ai/lm/core/synth_models.py +0 -1
  111. synth_ai/lm/core/vendor_clients.py +4 -2
  112. synth_ai/lm/overrides.py +2 -2
  113. synth_ai/lm/vendors/core/anthropic_api.py +7 -7
  114. synth_ai/lm/vendors/core/openai_api.py +2 -0
  115. synth_ai/lm/vendors/openai_standard.py +3 -1
  116. synth_ai/lm/vendors/openai_standard_responses.py +6 -3
  117. synth_ai/lm/vendors/supported/custom_endpoint.py +1 -3
  118. synth_ai/lm/vendors/synth_client.py +37 -10
  119. synth_ai/rl/__init__.py +0 -1
  120. synth_ai/rl/contracts.py +0 -2
  121. synth_ai/rl/env_keys.py +6 -1
  122. synth_ai/task/__init__.py +1 -0
  123. synth_ai/task/apps/__init__.py +11 -11
  124. synth_ai/task/auth.py +29 -17
  125. synth_ai/task/client.py +3 -1
  126. synth_ai/task/contracts.py +1 -0
  127. synth_ai/task/datasets.py +3 -1
  128. synth_ai/task/errors.py +3 -2
  129. synth_ai/task/health.py +0 -2
  130. synth_ai/task/json.py +0 -1
  131. synth_ai/task/proxy.py +2 -5
  132. synth_ai/task/rubrics.py +9 -3
  133. synth_ai/task/server.py +31 -5
  134. synth_ai/task/tracing_utils.py +8 -3
  135. synth_ai/task/validators.py +0 -1
  136. synth_ai/task/vendors.py +0 -1
  137. synth_ai/tracing_v3/db_config.py +26 -1
  138. synth_ai/tracing_v3/decorators.py +1 -0
  139. synth_ai/tracing_v3/examples/basic_usage.py +3 -2
  140. synth_ai/tracing_v3/hooks.py +2 -0
  141. synth_ai/tracing_v3/replica_sync.py +1 -0
  142. synth_ai/tracing_v3/session_tracer.py +24 -3
  143. synth_ai/tracing_v3/storage/base.py +4 -1
  144. synth_ai/tracing_v3/storage/factory.py +0 -1
  145. synth_ai/tracing_v3/turso/manager.py +102 -38
  146. synth_ai/tracing_v3/turso/models.py +4 -1
  147. synth_ai/tracing_v3/utils.py +1 -0
  148. synth_ai/v0/tracing/upload.py +32 -135
  149. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/METADATA +1 -1
  150. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/RECORD +154 -154
  151. synth_ai/install_sqld.sh +0 -40
  152. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/WHEEL +0 -0
  153. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/entry_points.txt +0 -0
  154. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/licenses/LICENSE +0 -0
  155. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/top_level.txt +0 -0
@@ -31,12 +31,17 @@ def build_rollout_request(
     run_id: str,
     model: str,
     inference_url: str,
+    inference_api_key: str,
     ops: list[str],
     extra_headers: dict[str, str] | None = None,
     trace_format: str = "compact",
     return_trace: bool = False,
 ) -> RolloutRequest:
-    policy_config = {"model": model, "inference_url": inference_url}
+    policy_config = {
+        "model": model,
+        "inference_url": inference_url,
+        "api_key": inference_api_key,
+    }
     if extra_headers:
         policy_config["extra_headers"] = extra_headers
     record_cfg = RolloutRecordConfig(
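
The signature change above means callers now hand the inference key to the helper, which forwards it to the task app inside the policy config. A minimal calling sketch, not taken verbatim from the package, assuming the script's own build_rollout_request/build_ops helpers and an OpenAI-compatible key in the environment:

import os

# Sketch only: invoking the updated helper with the new keyword argument (values illustrative).
request = build_rollout_request(
    seed=42,
    run_id="batch-demo-000",
    model="gpt-4o-mini",
    inference_url="https://api.openai.com",
    inference_api_key=os.environ["OPENAI_API_KEY"],  # assumption: any OpenAI-compatible key
    ops=build_ops(20, None),  # the script derives ops from --max-llm-calls / --ops
    trace_format="compact",
    return_trace=True,
)
# The key travels in request.policy.config["api_key"] next to "model" and "inference_url".
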
@@ -123,7 +128,9 @@ def analyse_rollout_response(response: Any) -> dict[str, Any]:
     if isinstance(final_list, list):
         final_achievements = [str(item) for item in final_list]

-    decision_rewards = trace_payload.get("decision_rewards") if isinstance(trace_payload, dict) else []
+    decision_rewards = (
+        trace_payload.get("decision_rewards") if isinstance(trace_payload, dict) else []
+    )
     trace_all: list[str] = []
     if isinstance(decision_rewards, list):
         for item in decision_rewards:
@@ -180,7 +187,9 @@ def summarise_runs(run_summaries: list[dict[str, Any]]) -> dict[str, Any]:
     return stats


-def print_summary(stats: dict[str, Any], *, run_details: list[dict[str, Any]], total_runs: int) -> None:
+def print_summary(
+    stats: dict[str, Any], *, run_details: list[dict[str, Any]], total_runs: int
+) -> None:
     if not stats:
         print("No successful rollouts to summarise.")
         return
@@ -234,7 +243,22 @@ async def execute_rollouts(args: argparse.Namespace) -> None:

     api_key = args.api_key or os.getenv("ENVIRONMENT_API_KEY")
     if not api_key:
-        raise RuntimeError("Missing --api-key or ENVIRONMENT_API_KEY")
+        import sys
+
+        print("Please enter your RL Environment API key:", file=sys.stderr, flush=True)
+        api_key = input("> ").strip()
+        if not api_key:
+            raise RuntimeError("RL Environment API key is required")
+
+    # Prompt for Groq API key if not set
+    groq_api_key = os.getenv("GROQ_API_KEY")
+    if not groq_api_key:
+        import sys
+
+        print("Please enter your Groq API key:", file=sys.stderr, flush=True)
+        groq_api_key = input("> ").strip()
+        if not groq_api_key:
+            raise RuntimeError("Groq API key is required")

     synth_key = os.getenv("SYNTH_API_KEY")
     extra_headers: dict[str, str] | None = None
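
Instead of failing hard when a key is missing, the script now falls back to an interactive prompt for both keys. The same env-or-prompt pattern, pulled out into a standalone sketch (require_key is a hypothetical name, not part of the package):

import os
import sys


def require_key(env_var: str, label: str) -> str:
    """Return the key from the environment, or prompt for it (hypothetical helper)."""
    value = os.getenv(env_var)
    if value:
        return value
    print(f"Please enter your {label}:", file=sys.stderr, flush=True)
    value = input("> ").strip()
    if not value:
        raise RuntimeError(f"{label} is required")
    return value


# Mirrors the inline logic above:
# api_key = args.api_key or require_key("ENVIRONMENT_API_KEY", "RL Environment API key")
# groq_api_key = require_key("GROQ_API_KEY", "Groq API key")
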
@@ -252,29 +276,41 @@ async def execute_rollouts(args: argparse.Namespace) -> None:

     ops = build_ops(args.max_llm_calls, args.ops)

+    print(f"\n🚀 Starting {args.count} rollouts with {args.parallel} parallel workers...")
+    print(f"📊 Each rollout: {len(ops)} ops ({args.max_llm_calls} LLM calls)\n")
+
     async with TaskAppClient(args.base_url, api_key=api_key, timeout=args.timeout) as client:
+
         async def run_single(index: int) -> dict[str, Any]:
             run_id = f"{args.run_id}-{index:03d}"
             seed = args.seed + index * args.seed_stride
+            print(f"\n▶️ [{index + 1}/{args.count}] Starting rollout {run_id} (seed={seed})...")
+
             request = build_rollout_request(
                 seed=seed,
                 run_id=run_id,
                 model=args.model,
                 inference_url=args.inference_url,
+                inference_api_key=groq_api_key,
                 ops=ops,
                 extra_headers=extra_headers,
                 trace_format=args.trace_format,
                 return_trace=True,
             )
             if args.max_policy_tokens is not None:
-                request.policy.config.update({
-                    "max_completion_tokens": args.max_policy_tokens,
-                    "max_tokens": args.max_policy_tokens,
-                })
+                request.policy.config.update(
+                    {
+                        "max_completion_tokens": args.max_policy_tokens,
+                        "max_tokens": args.max_policy_tokens,
+                    }
+                )

             try:
                 response = await client.rollout(request)
                 summary = analyse_rollout_response(response)
+                print(
+                    f"\n✅ [{index + 1}/{args.count}] Completed {run_id} (outcome={summary.get('outcome_score', 'N/A')})"
+                )
                 return {
                     "ok": True,
                     "run_id": run_id,
@@ -283,6 +319,7 @@ async def execute_rollouts(args: argparse.Namespace) -> None:
                     "summary": summary,
                 }
             except Exception as exc:  # pragma: no cover - surface errors
+                print(f"\n❌ [{index + 1}/{args.count}] Failed {run_id}: {exc}")
                 return {
                     "ok": False,
                     "run_id": run_id,
@@ -302,6 +339,7 @@ async def execute_rollouts(args: argparse.Namespace) -> None:
     successes = [item for item in results if item.get("ok")]
     failures = [item for item in results if not item.get("ok")]

+    print(f"\n{'=' * 100}\n")
     stats = summarise_runs([item["summary"] for item in successes])
     print_summary(stats, run_details=successes, total_runs=args.count)

@@ -317,17 +355,43 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument("--base-url", default="http://localhost:8001", help="Task app base URL")
     parser.add_argument("--api-key", help="Environment API key (or set via --env-file)")
     parser.add_argument("--env-file", help="Path to .env file providing API keys")
-    parser.add_argument("--model", default="gpt-4o-mini", help="Model identifier for the Crafter policy")
-    parser.add_argument("--inference-url", default="https://api.openai.com", help="Inference base URL for the policy")
+    parser.add_argument(
+        "--model", default="gpt-4o-mini", help="Model identifier for the Crafter policy"
+    )
+    parser.add_argument(
+        "--inference-url",
+        default="https://api.openai.com",
+        help="Inference base URL for the policy",
+    )
     parser.add_argument("--seed", type=int, default=42, help="Base seed for the first rollout")
-    parser.add_argument("--seed-stride", type=int, default=1, help="Increment applied to the seed for each rollout")
-    parser.add_argument("--count", type=int, default=20, help="Number of rollout trajectories to execute")
+    parser.add_argument(
+        "--seed-stride", type=int, default=1, help="Increment applied to the seed for each rollout"
+    )
+    parser.add_argument(
+        "--count", type=int, default=20, help="Number of rollout trajectories to execute"
+    )
     parser.add_argument("--parallel", type=int, default=4, help="Maximum concurrent rollouts")
     parser.add_argument("--ops", help="Comma-separated rollout ops (advanced override)")
-    parser.add_argument("--max-llm-calls", type=int, default=20, help="Number of agent/env pairs per rollout when --ops not provided")
-    parser.add_argument("--max-policy-tokens", type=int, help="Optional per-call token limit forwarded to the policy config")
-    parser.add_argument("--timeout", type=float, default=600.0, help="HTTP timeout (seconds) for task app requests")
-    parser.add_argument("--trace-format", default="compact", choices=["compact", "full"], help="Trace format requested from the task app")
+    parser.add_argument(
+        "--max-llm-calls",
+        type=int,
+        default=20,
+        help="Number of agent/env pairs per rollout when --ops not provided",
+    )
+    parser.add_argument(
+        "--max-policy-tokens",
+        type=int,
+        help="Optional per-call token limit forwarded to the policy config",
+    )
+    parser.add_argument(
+        "--timeout", type=float, default=600.0, help="HTTP timeout (seconds) for task app requests"
+    )
+    parser.add_argument(
+        "--trace-format",
+        default="compact",
+        choices=["compact", "full"],
+        help="Trace format requested from the task app",
+    )
     parser.add_argument("--run-id", default="batch-demo", help="Run ID prefix for rollouts")
     parser.add_argument("--verbose", action="store_true", help="Print resolved configuration")
     return parser.parse_args()
@@ -6,6 +6,7 @@ from __future__ import annotations
 import argparse
 import asyncio
 import json
+import os
 from pathlib import Path
 from typing import Any

@@ -29,6 +30,7 @@ def build_rollout_request(
     run_id: str,
     model: str,
     inference_url: str,
+    inference_api_key: str,
     ops: list[str],
     return_trace: bool,
     trace_format: str,
@@ -37,6 +39,7 @@ def build_rollout_request(
     policy_config = {
         "model": model,
         "inference_url": inference_url,
+        "api_key": inference_api_key,
     }
     if max_policy_tokens is not None:
         policy_config.update(
@@ -64,7 +67,11 @@ def build_rollout_request(


 def summarise_rollout(response: Any) -> dict[str, Any]:
-    metrics = response.metrics.model_dump() if hasattr(response, "metrics") else response.get("metrics", {})
+    metrics = (
+        response.metrics.model_dump()
+        if hasattr(response, "metrics")
+        else response.get("metrics", {})
+    )
     return {
         "run_id": getattr(response, "run_id", None) or response.get("run_id"),
         "num_episodes": metrics.get("num_episodes"),
@@ -83,17 +90,25 @@ def summarise_trace(trace: Any) -> dict[str, Any]:

     format_hint = "compact" if "events_count" in trace or "lm_calls" in trace else "full"
     events_count = trace.get("events_count")
-    if events_count is None and "event_history" in trace and isinstance(trace["event_history"], list):
+    if (
+        events_count is None
+        and "event_history" in trace
+        and isinstance(trace["event_history"], list)
+    ):
         events_count = len(trace["event_history"])
     messages_count = trace.get("messages_count")
-    if messages_count is None and "markov_blanket_message_history" in trace and isinstance(
-        trace["markov_blanket_message_history"], list
+    if (
+        messages_count is None
+        and "markov_blanket_message_history" in trace
+        and isinstance(trace["markov_blanket_message_history"], list)
     ):
         messages_count = len(trace["markov_blanket_message_history"])

     metadata = trace.get("metadata") if isinstance(trace.get("metadata"), dict) else {}
     lm_calls = trace.get("lm_calls") if isinstance(trace.get("lm_calls"), list) else []
-    decision_rewards = trace.get("decision_rewards") if isinstance(trace.get("decision_rewards"), list) else []
+    decision_rewards = (
+        trace.get("decision_rewards") if isinstance(trace.get("decision_rewards"), list) else []
+    )

     return {
         "session_id": trace.get("session_id"),
@@ -215,11 +230,13 @@ def print_reward_summary(
     if decision_rewards:
         print(" Decision rewards:")
         for entry in decision_rewards:
-            turn = entry.get('turn')
-            ach_delta = entry.get('ach_delta')
-            unique_delta = entry.get('unique_delta')
-            achievements = entry.get('achievements') or []
-            print(f" turn={turn}, ach_delta={ach_delta}, unique_delta={unique_delta}, achievements={achievements}")
+            turn = entry.get("turn")
+            ach_delta = entry.get("ach_delta")
+            unique_delta = entry.get("unique_delta")
+            achievements = entry.get("achievements") or []
+            print(
+                f" turn={turn}, ach_delta={ach_delta}, unique_delta={unique_delta}, achievements={achievements}"
+            )
     else:
         print(" Decision rewards: none recorded")

@@ -242,16 +259,40 @@ def print_reward_summary(


 async def main() -> None:
+    # Load .env file from current directory if it exists
+    env_file = Path.cwd() / ".env"
+    if env_file.exists():
+        from dotenv import load_dotenv
+
+        load_dotenv(env_file)
+
     parser = argparse.ArgumentParser(description=__doc__)
-    parser.add_argument("--base-url", default="http://localhost:8010", help="Task app base URL")
-    parser.add_argument("--api-key", required=True, help="Environment API key")
+    parser.add_argument("--base-url", default="http://localhost:8001", help="Task app base URL")
+    parser.add_argument("--api-key", help="RL Environment API key (will prompt if not provided)")
+    parser.add_argument(
+        "--inference-api-key", help="Inference provider API key (will prompt if not provided)"
+    )
     parser.add_argument("--seed", type=int, default=42, help="Environment seed")
     parser.add_argument("--run-id", default="local-trace", help="Run identifier")
     parser.add_argument("--model", default="gpt-4o-mini", help="OpenAI-compatible model id")
-    parser.add_argument("--inference-url", default="https://api.openai.com", help="Inference base URL (OpenAI/Groq)")
-    parser.add_argument("--ops", help="Comma-separated rollout ops (fallback: alternating agent/env)")
-    parser.add_argument("--max-llm-calls", type=int, default=1, help="Number of agent/env pairs when --ops not supplied")
-    parser.add_argument("--max-policy-tokens", type=int, default=None, help="Optional max token budget forwarded to policy")
+    parser.add_argument(
+        "--inference-url", default="https://api.openai.com", help="Inference base URL (OpenAI/Groq)"
+    )
+    parser.add_argument(
+        "--ops", help="Comma-separated rollout ops (fallback: alternating agent/env)"
+    )
+    parser.add_argument(
+        "--max-llm-calls",
+        type=int,
+        default=1,
+        help="Number of agent/env pairs when --ops not supplied",
+    )
+    parser.add_argument(
+        "--max-policy-tokens",
+        type=int,
+        default=None,
+        help="Optional max token budget forwarded to policy",
+    )
     parser.add_argument(
         "--trace-format",
         choices=["compact", "full"],
@@ -286,10 +327,69 @@ async def main() -> None:
     )
     args = parser.parse_args()

+    # Prompt for required parameters if not provided
+    base_url = args.base_url
+    if args.base_url == "http://localhost:8001":
+        print("\nTask app configuration:")
+        base_url_input = input(f"Task app base URL [http://localhost:8001]: ").strip()
+        base_url = base_url_input if base_url_input else "http://localhost:8001"
+
+    api_key = args.api_key or os.getenv("ENVIRONMENT_API_KEY")
+    if not api_key:
+        api_key = input("RL Environment API key (from ENVIRONMENT_API_KEY): ").strip()
+        if not api_key:
+            parser.error("RL Environment API key is required")
+
+    # Use Groq by default
+    model = "llama-3.3-70b-versatile"
+    inference_url = "https://api.groq.com/openai"
+
+    print("\nInference configuration (Groq):")
+    inference_api_key = args.inference_api_key or os.getenv("GROQ_API_KEY")
+    if not inference_api_key:
+        inference_api_key = input("Groq API key: ").strip()
+        if not inference_api_key:
+            parser.error("Groq API key is required")
+
+    # Save to .env for future use
+    env_path = Path.cwd() / ".env"
+    try:
+        # Read existing .env
+        existing_lines = []
+        if env_path.exists():
+            existing_lines = env_path.read_text().splitlines()
+
+        # Check if GROQ_API_KEY already exists
+        key_exists = any(line.strip().startswith("GROQ_API_KEY=") for line in existing_lines)
+
+        if not key_exists:
+            # Append to .env
+            with open(env_path, "a") as f:
+                if existing_lines and not existing_lines[-1].strip():
+                    # File exists and last line is not empty
+                    pass
+                elif existing_lines:
+                    # Add newline before appending
+                    f.write("\n")
+                f.write(f"GROQ_API_KEY={inference_api_key}\n")
+            print(f"[INFO] Saved GROQ_API_KEY to {env_path}")
+    except Exception as e:
+        print(f"[WARN] Could not save GROQ_API_KEY to .env: {e}")
+
+    print("\nRollout configuration:")
+    max_llm_calls = args.max_llm_calls
+    if args.max_llm_calls == 1:
+        max_llm_calls_input = input(f"Max LLM calls [10]: ").strip()
+        max_llm_calls = int(max_llm_calls_input) if max_llm_calls_input else 10
+
+    # Override args with prompted values
+    args.base_url = base_url
+    args.max_llm_calls = max_llm_calls
+
     ops = ensure_ops(args.ops, args.max_llm_calls)
     return_trace = not args.no_trace

-    async with TaskAppClient(args.base_url, api_key=args.api_key, timeout=args.timeout) as client:
+    async with TaskAppClient(args.base_url, api_key=api_key, timeout=args.timeout) as client:
         try:
             print(f"Fetching task_info for seed {args.seed}…")
             task_info = await client.task_info(seeds=[args.seed])
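
Beyond prompting, this traced-rollout example now defaults to Groq (llama-3.3-70b-versatile via https://api.groq.com/openai) and persists the entered key so later runs can skip the prompt. The append-once logic, condensed into a standalone sketch (save_env_var_once is a hypothetical name; the script inlines this):

from pathlib import Path


def save_env_var_once(env_path: Path, name: str, value: str) -> None:
    # Hypothetical helper mirroring the inline .env persistence above:
    # append NAME=value only if no existing line already defines NAME.
    lines = env_path.read_text().splitlines() if env_path.exists() else []
    if any(line.strip().startswith(f"{name}=") for line in lines):
        return
    with open(env_path, "a") as f:
        if lines and lines[-1].strip():
            f.write("\n")
        f.write(f"{name}={value}\n")


# Usage mirroring the diff: save_env_var_once(Path.cwd() / ".env", "GROQ_API_KEY", inference_api_key)
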
@@ -302,8 +402,9 @@ async def main() -> None:
             request = build_rollout_request(
                 seed=args.seed,
                 run_id=args.run_id,
-                model=args.model,
-                inference_url=args.inference_url,
+                model=model,
+                inference_url=inference_url,
+                inference_api_key=inference_api_key,
                 ops=ops,
                 return_trace=return_trace,
                 trace_format=args.trace_format,
@@ -350,7 +451,11 @@ async def main() -> None:
                 "Tip: export TASKAPP_TRACING_ENABLED=1 and optionally TASKAPP_SFT_OUTPUT_DIR before running `uvx synth-ai serve …` to persist traces/SFT."
             )
     except httpx.HTTPStatusError as exc:
-        detail = exc.response.json() if exc.response.headers.get("content-type", "").startswith("application/json") else exc.response.text
+        detail = (
+            exc.response.json()
+            if exc.response.headers.get("content-type", "").startswith("application/json")
+            else exc.response.text
+        )
         print(f"HTTP error {exc.response.status_code}: {detail}", file=sys.stderr)
         if exc.response.status_code in (401, 503):
             print(
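
The content-type-aware extraction of the error body recurs in several of these scripts (here, and again in check_health and the remote rollout handler further down). As a sketch, it could be shared via one small helper (response_detail is a hypothetical name, assuming httpx responses):

import httpx


def response_detail(resp: httpx.Response) -> object:
    # Hypothetical helper: parsed JSON when the server declares JSON, raw text otherwise.
    if resp.headers.get("content-type", "").startswith("application/json"):
        return resp.json()
    return resp.text
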
@@ -11,6 +11,8 @@ from typing import Any, Dict
 import tomllib
 import requests

+from synth_ai.config.base_url import PROD_BASE_URL_DEFAULT
+

 def _load_toml(path: Path) -> Dict[str, Any]:
     if not path.exists():
@@ -21,11 +23,23 @@ def _load_toml(path: Path) -> Dict[str, Any]:


 def main() -> None:
-    p = argparse.ArgumentParser(description="Create clustered RL training job via backend RL endpoint")
-    p.add_argument("--backend", default=os.getenv("BACKEND_BASE_URL", "http://localhost:8000/api"))
+    p = argparse.ArgumentParser(
+        description="Create clustered RL training job via backend RL endpoint"
+    )
+    p.add_argument(
+        "--backend", default=os.getenv("BACKEND_BASE_URL", f"{PROD_BASE_URL_DEFAULT}/api")
+    )
     p.add_argument("--config", required=True, help="Path to RL TOML config")
-    p.add_argument("--task-url", default=os.getenv("TASK_APP_URL", ""), help="Override task service URL (or set TASK_APP_URL)")
-    p.add_argument("--idempotency", default=os.getenv("RL_IDEMPOTENCY_KEY", ""), help="Optional Idempotency-Key header value")
+    p.add_argument(
+        "--task-url",
+        default=os.getenv("TASK_APP_URL", ""),
+        help="Override task service URL (or set TASK_APP_URL)",
+    )
+    p.add_argument(
+        "--idempotency",
+        default=os.getenv("RL_IDEMPOTENCY_KEY", ""),
+        help="Optional Idempotency-Key header value",
+    )
     args = p.parse_args()

     cfg_path = Path(args.config).expanduser()
@@ -36,9 +50,16 @@ def main() -> None:
     # Resolve task app base URL for the job
     cli_task_url = (args.task_url or "").strip()
     env_task_url = (os.getenv("TASK_APP_URL") or "").strip()
-    task_url = cli_task_url or env_task_url or ((services.get("task_url") or "").strip() if isinstance(services, dict) else "")
+    task_url = (
+        cli_task_url
+        or env_task_url
+        or ((services.get("task_url") or "").strip() if isinstance(services, dict) else "")
+    )
     if not task_url:
-        print("Missing task service URL. Provide --task-url or set TASK_APP_URL or services.task_url in TOML", file=sys.stderr)
+        print(
+            "Missing task service URL. Provide --task-url or set TASK_APP_URL or services.task_url in TOML",
+            file=sys.stderr,
+        )
         sys.exit(2)

     # TOML-only model selection validation
@@ -46,7 +67,10 @@
     has_source = bool((model_cfg.get("source") or "").strip())
     has_base = bool((model_cfg.get("base") or "").strip())
     if has_source == has_base:
-        print("Model selection must specify exactly one of [model].source or [model].base in TOML", file=sys.stderr)
+        print(
+            "Model selection must specify exactly one of [model].source or [model].base in TOML",
+            file=sys.stderr,
+        )
         sys.exit(2)

     # Build create-job payload. Send full TOML under data.config, plus endpoint_base_url.
@@ -11,10 +11,17 @@ import sys

 import httpx

+
 def check_health(base_url: str, api_key: str) -> None:
     try:
-        resp = httpx.get(f"{base_url.rstrip('/')}/health", headers={"X-API-Key": api_key}, timeout=10.0)
-        data = resp.json() if resp.headers.get("content-type", "").startswith("application/json") else resp.text
+        resp = httpx.get(
+            f"{base_url.rstrip('/')}/health", headers={"X-API-Key": api_key}, timeout=10.0
+        )
+        data = (
+            resp.json()
+            if resp.headers.get("content-type", "").startswith("application/json")
+            else resp.text
+        )
         if resp.status_code != 200:
             print(f"warning: /health returned {resp.status_code}: {data}")
         else:
@@ -22,6 +29,7 @@ def check_health(base_url: str, api_key: str) -> None:
     except Exception as exc:
         print(f"warning: failed to call /health: {exc}")

+
 from synth_ai.task import (
     RolloutEnvSpec,
     RolloutPolicySpec,
@@ -79,8 +87,14 @@ def summarise(response) -> dict[str, any]:

 async def main() -> None:
     parser = argparse.ArgumentParser(description=__doc__)
-    parser.add_argument("--base-url", default=None, help="Remote task app base URL (e.g., https://xyz.modal.run); defaults to TASK_APP_BASE_URL env")
-    parser.add_argument("--api-key", required=True, help="Environment API key for the remote task app")
+    parser.add_argument(
+        "--base-url",
+        default=None,
+        help="Remote task app base URL (e.g., https://xyz.modal.run); defaults to TASK_APP_BASE_URL env",
+    )
+    parser.add_argument(
+        "--api-key", required=True, help="Environment API key for the remote task app"
+    )
     parser.add_argument("--seed", type=int, default=42)
     parser.add_argument("--run-id", default="remote-demo")
     parser.add_argument("--model", default="gpt-4o-mini")
@@ -89,9 +103,9 @@ async def main() -> None:
     parser.add_argument("--max-policy-tokens", type=int, default=None)
     args = parser.parse_args()

-    base_url = args.base_url or os.getenv('TASK_APP_BASE_URL')
+    base_url = args.base_url or os.getenv("TASK_APP_BASE_URL")
     if not base_url:
-        parser.error('Missing --base-url (and TASK_APP_BASE_URL not set).')
+        parser.error("Missing --base-url (and TASK_APP_BASE_URL not set).")

     request = build_request(
         run_id=args.run_id,
@@ -114,14 +128,27 @@ async def main() -> None:
         print(json.dumps(summarise(response), indent=2))
         print(f"Ops executed: {request.ops}")
     except httpx.HTTPStatusError as exc:
-        detail = exc.response.json() if exc.response.headers.get("content-type", "").startswith("application/json") else exc.response.text
+        detail = (
+            exc.response.json()
+            if exc.response.headers.get("content-type", "").startswith("application/json")
+            else exc.response.text
+        )
         print(f"HTTP error {exc.response.status_code}: {detail}", file=sys.stderr)
         if exc.response.status_code in (401, 403):
-            print("Hint: check --api-key and ensure the remote deployment expects that value.", file=sys.stderr)
+            print(
+                "Hint: check --api-key and ensure the remote deployment expects that value.",
+                file=sys.stderr,
+            )
         if exc.response.status_code == 404:
-            print("Hint: verify the --base-url includes the correct path (should be the root of the task app).", file=sys.stderr)
+            print(
+                "Hint: verify the --base-url includes the correct path (should be the root of the task app).",
+                file=sys.stderr,
+            )
         if exc.response.status_code == 500:
-            print("Hint: remote rollout failed server-side; inspect the deployment logs (Modal dashboard/logs).", file=sys.stderr)
+            print(
+                "Hint: remote rollout failed server-side; inspect the deployment logs (Modal dashboard/logs).",
+                file=sys.stderr,
+            )
         raise
