synth-ai 0.2.9.dev5__py3-none-any.whl → 0.2.9.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic; see the package's registry page for more details.

Files changed (155)
  1. examples/common_old/backend.py +0 -1
  2. examples/crafter_debug_render.py +15 -6
  3. examples/evals_old/compare_models.py +1 -0
  4. examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +6 -2
  5. examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +4 -4
  6. examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +4 -3
  7. examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +6 -2
  8. examples/finetuning_old/synth_qwen_v1/finetune.py +1 -1
  9. examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +4 -4
  10. examples/finetuning_old/synth_qwen_v1/infer.py +1 -2
  11. examples/finetuning_old/synth_qwen_v1/poll.py +4 -2
  12. examples/finetuning_old/synth_qwen_v1/prepare_data.py +8 -8
  13. examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +5 -4
  14. examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +11 -8
  15. examples/finetuning_old/synth_qwen_v1/run_ft_job.py +17 -12
  16. examples/finetuning_old/synth_qwen_v1/upload_data.py +1 -1
  17. examples/finetuning_old/synth_qwen_v1/util.py +7 -2
  18. examples/rl/configs/eval_base_qwen.toml +1 -1
  19. examples/rl/configs/rl_from_base_qwen17.toml +1 -1
  20. examples/rl/download_dataset.py +26 -10
  21. examples/rl/run_eval.py +17 -15
  22. examples/rl/run_rl_and_save.py +24 -7
  23. examples/rl/task_app/math_single_step.py +128 -11
  24. examples/rl/task_app/math_task_app.py +11 -3
  25. examples/rl_old/task_app.py +222 -53
  26. examples/warming_up_to_rl/analyze_trace_db.py +7 -5
  27. examples/warming_up_to_rl/export_trace_sft.py +141 -16
  28. examples/warming_up_to_rl/groq_test.py +11 -4
  29. examples/warming_up_to_rl/manage_secrets.py +15 -6
  30. examples/warming_up_to_rl/readme.md +9 -2
  31. examples/warming_up_to_rl/run_eval.py +108 -30
  32. examples/warming_up_to_rl/run_fft_and_save.py +128 -52
  33. examples/warming_up_to_rl/run_local_rollout.py +87 -36
  34. examples/warming_up_to_rl/run_local_rollout_modal.py +113 -25
  35. examples/warming_up_to_rl/run_local_rollout_parallel.py +80 -16
  36. examples/warming_up_to_rl/run_local_rollout_traced.py +125 -20
  37. examples/warming_up_to_rl/run_rl_and_save.py +31 -7
  38. examples/warming_up_to_rl/run_rollout_remote.py +37 -10
  39. examples/warming_up_to_rl/task_app/grpo_crafter.py +90 -27
  40. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +9 -27
  41. examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +46 -108
  42. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -1
  43. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +1 -1
  44. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -1
  45. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +50 -17
  46. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +35 -21
  47. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +8 -4
  48. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +29 -26
  49. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +1 -1
  50. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +17 -13
  51. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +1 -1
  52. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +106 -63
  53. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +82 -84
  54. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +76 -59
  55. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +1 -1
  56. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +43 -49
  57. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +5 -15
  58. synth_ai/__init__.py +1 -0
  59. synth_ai/api/train/builders.py +34 -10
  60. synth_ai/api/train/cli.py +172 -32
  61. synth_ai/api/train/config_finder.py +59 -4
  62. synth_ai/api/train/env_resolver.py +32 -14
  63. synth_ai/api/train/pollers.py +11 -3
  64. synth_ai/api/train/task_app.py +4 -1
  65. synth_ai/api/train/utils.py +20 -4
  66. synth_ai/cli/__init__.py +11 -4
  67. synth_ai/cli/balance.py +1 -1
  68. synth_ai/cli/demo.py +19 -5
  69. synth_ai/cli/rl_demo.py +75 -16
  70. synth_ai/cli/root.py +116 -37
  71. synth_ai/cli/task_apps.py +1276 -186
  72. synth_ai/cli/traces.py +1 -0
  73. synth_ai/cli/turso.py +73 -0
  74. synth_ai/core/experiment.py +0 -2
  75. synth_ai/demo_registry.py +67 -30
  76. synth_ai/demos/core/cli.py +493 -164
  77. synth_ai/demos/demo_task_apps/core.py +50 -6
  78. synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +2 -3
  79. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +36 -28
  80. synth_ai/demos/demo_task_apps/math/_common.py +1 -2
  81. synth_ai/demos/demo_task_apps/math/deploy_modal.py +0 -2
  82. synth_ai/demos/demo_task_apps/math/modal_task_app.py +168 -65
  83. synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
  84. synth_ai/environments/examples/bandit/engine.py +12 -4
  85. synth_ai/environments/examples/bandit/taskset.py +4 -4
  86. synth_ai/environments/reproducibility/tree.py +3 -1
  87. synth_ai/environments/service/core_routes.py +6 -2
  88. synth_ai/evals/base.py +0 -2
  89. synth_ai/experimental/synth_oss.py +11 -12
  90. synth_ai/handshake.py +3 -1
  91. synth_ai/http_client.py +31 -7
  92. synth_ai/inference/__init__.py +0 -2
  93. synth_ai/inference/client.py +8 -4
  94. synth_ai/jobs/client.py +40 -10
  95. synth_ai/learning/client.py +33 -8
  96. synth_ai/learning/config.py +0 -2
  97. synth_ai/learning/constants.py +0 -2
  98. synth_ai/learning/ft_client.py +6 -3
  99. synth_ai/learning/health.py +9 -2
  100. synth_ai/learning/jobs.py +17 -5
  101. synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +1 -3
  102. synth_ai/learning/prompts/random_search.py +4 -1
  103. synth_ai/learning/prompts/run_random_search_banking77.py +6 -1
  104. synth_ai/learning/rl_client.py +42 -14
  105. synth_ai/learning/sse.py +0 -2
  106. synth_ai/learning/validators.py +6 -2
  107. synth_ai/lm/caching/ephemeral.py +1 -3
  108. synth_ai/lm/core/exceptions.py +0 -2
  109. synth_ai/lm/core/main.py +13 -1
  110. synth_ai/lm/core/synth_models.py +0 -1
  111. synth_ai/lm/core/vendor_clients.py +4 -2
  112. synth_ai/lm/overrides.py +2 -2
  113. synth_ai/lm/vendors/core/anthropic_api.py +7 -7
  114. synth_ai/lm/vendors/core/openai_api.py +2 -0
  115. synth_ai/lm/vendors/openai_standard.py +3 -1
  116. synth_ai/lm/vendors/openai_standard_responses.py +6 -3
  117. synth_ai/lm/vendors/supported/custom_endpoint.py +1 -3
  118. synth_ai/lm/vendors/synth_client.py +37 -10
  119. synth_ai/rl/__init__.py +0 -1
  120. synth_ai/rl/contracts.py +0 -2
  121. synth_ai/rl/env_keys.py +6 -1
  122. synth_ai/task/__init__.py +1 -0
  123. synth_ai/task/apps/__init__.py +11 -11
  124. synth_ai/task/auth.py +29 -17
  125. synth_ai/task/client.py +3 -1
  126. synth_ai/task/contracts.py +1 -0
  127. synth_ai/task/datasets.py +3 -1
  128. synth_ai/task/errors.py +3 -2
  129. synth_ai/task/health.py +0 -2
  130. synth_ai/task/json.py +0 -1
  131. synth_ai/task/proxy.py +2 -5
  132. synth_ai/task/rubrics.py +9 -3
  133. synth_ai/task/server.py +31 -5
  134. synth_ai/task/tracing_utils.py +8 -3
  135. synth_ai/task/validators.py +0 -1
  136. synth_ai/task/vendors.py +0 -1
  137. synth_ai/tracing_v3/db_config.py +26 -1
  138. synth_ai/tracing_v3/decorators.py +1 -0
  139. synth_ai/tracing_v3/examples/basic_usage.py +3 -2
  140. synth_ai/tracing_v3/hooks.py +2 -0
  141. synth_ai/tracing_v3/replica_sync.py +1 -0
  142. synth_ai/tracing_v3/session_tracer.py +24 -3
  143. synth_ai/tracing_v3/storage/base.py +4 -1
  144. synth_ai/tracing_v3/storage/factory.py +0 -1
  145. synth_ai/tracing_v3/turso/manager.py +102 -38
  146. synth_ai/tracing_v3/turso/models.py +4 -1
  147. synth_ai/tracing_v3/utils.py +1 -0
  148. synth_ai/v0/tracing/upload.py +32 -135
  149. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/METADATA +1 -1
  150. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/RECORD +154 -154
  151. synth_ai/install_sqld.sh +0 -40
  152. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/WHEEL +0 -0
  153. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/entry_points.txt +0 -0
  154. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/licenses/LICENSE +0 -0
  155. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/top_level.txt +0 -0
@@ -12,6 +12,9 @@ from typing import Any, Dict, Tuple, List
12
12
  import tomllib
13
13
  import re
14
14
  import requests
15
+ from dotenv import load_dotenv
16
+
17
+ from synth_ai.config.base_url import PROD_BASE_URL_DEFAULT
15
18
 
16
19
 
17
20
  def mask(val: str) -> str:
@@ -20,7 +23,9 @@ def mask(val: str) -> str:
20
23
  return f"{val[:6]}…{val[-4:]}" if len(val) >= 10 else "****"
21
24
 
22
25
 
23
- def post_multipart(base: str, api_key: str, path: str, file_field: str, filepath: Path) -> Dict[str, Any]:
26
+ def post_multipart(
27
+ base: str, api_key: str, path: str, file_field: str, filepath: Path
28
+ ) -> Dict[str, Any]:
24
29
  """Upload a file, trying backend-specific endpoints with fallbacks.
25
30
 
26
31
  Priority:
@@ -33,7 +38,7 @@ def post_multipart(base: str, api_key: str, path: str, file_field: str, filepath
33
38
 
34
39
  endpoints = [
35
40
  f"{base.rstrip('/')}/{path.lstrip('/')}", # e.g., /learning/files
36
- f"{base.rstrip('/')}/files", # OpenAI-style
41
+ f"{base.rstrip('/')}/files", # OpenAI-style
37
42
  ]
38
43
  last_err: Dict[str, Any] | None = None
39
44
  for ep in endpoints:
@@ -87,17 +92,94 @@ def get_json(base: str, api_key: str, path: str) -> Dict[str, Any]:
87
92
  return {"status": r.status_code, "text": r.text[:400]}
88
93
 
89
94
 
95
+ def _find_fft_configs() -> List[Path]:
96
+ """Find FFT TOML configs in standard locations."""
97
+ candidates: List[Path] = []
98
+
99
+ # Check current directory configs/
100
+ cwd = Path.cwd()
101
+ configs_dir = cwd / "configs"
102
+ if configs_dir.is_dir():
103
+ for f in configs_dir.glob("*.toml"):
104
+ # Look for FFT configs (check if they have [algorithm] method = "supervised_finetune")
105
+ try:
106
+ content = f.read_text()
107
+ if "supervised_finetune" in content or "fft" in content.lower():
108
+ candidates.append(f)
109
+ except Exception:
110
+ pass
111
+
112
+ # Also check for any .toml files in current directory
113
+ for f in cwd.glob("*.toml"):
114
+ if f not in candidates:
115
+ try:
116
+ content = f.read_text()
117
+ if "supervised_finetune" in content or "fft" in content.lower():
118
+ candidates.append(f)
119
+ except Exception:
120
+ pass
121
+
122
+ return sorted(candidates)
123
+
124
+
90
125
  def main() -> None:
126
+ # Load .env file from current directory first if it exists
127
+ default_env = Path.cwd() / ".env"
128
+ if default_env.exists():
129
+ load_dotenv(default_env, override=False)
130
+
91
131
  parser = argparse.ArgumentParser(description="Submit FFT job and save resulting model id")
92
- parser.add_argument("--backend", default=os.getenv("BACKEND_BASE_URL", "http://localhost:8000/api"))
93
- parser.add_argument("--toml", required=True, help="Path to FFT TOML config")
132
+ parser.add_argument(
133
+ "--backend", default=os.getenv("BACKEND_BASE_URL", f"{PROD_BASE_URL_DEFAULT}/api")
134
+ )
135
+ parser.add_argument("--toml", required=False, help="Path to FFT TOML config")
94
136
  parser.add_argument("--data", default="", help="Override dataset JSONL path")
95
137
  parser.add_argument("--poll-seconds", type=int, default=1800)
96
- parser.add_argument("--env-file", default="", help="Optional path to .env file with SYNTH_API_KEY")
138
+ parser.add_argument(
139
+ "--env-file", default="", help="Optional path to .env file with SYNTH_API_KEY"
140
+ )
97
141
  args = parser.parse_args()
98
142
 
99
- config_path = Path(args.toml).expanduser().resolve()
100
- if not config_path.exists():
143
+ # Also load from explicit --env-file if provided
144
+ if args.env_file:
145
+ env_path = Path(args.env_file).expanduser()
146
+ if not env_path.exists():
147
+ print(f"[WARN] Env file not found: {env_path}")
148
+ else:
149
+ load_dotenv(env_path, override=False)
150
+
151
+ # Auto-discover TOML config if not specified
152
+ config_path: Path | None = None
153
+ if args.toml:
154
+ config_path = Path(args.toml).expanduser().resolve()
155
+ else:
156
+ configs = _find_fft_configs()
157
+ if not configs:
158
+ print(
159
+ "No FFT config files found. Please specify --toml or create a config in configs/",
160
+ file=sys.stderr,
161
+ )
162
+ sys.exit(2)
163
+ elif len(configs) == 1:
164
+ config_path = configs[0]
165
+ print(f"Using FFT config: {config_path}")
166
+ else:
167
+ print("\nFound multiple FFT configs:")
168
+ for idx, cfg in enumerate(configs, 1):
169
+ print(f" [{idx}] {cfg}")
170
+ choice = input(f"Select config [1-{len(configs)}]: ").strip()
171
+ try:
172
+ selected_idx = int(choice) - 1
173
+ if 0 <= selected_idx < len(configs):
174
+ config_path = configs[selected_idx]
175
+ else:
176
+ print("Invalid selection", file=sys.stderr)
177
+ sys.exit(2)
178
+ except ValueError:
179
+ print("Invalid input", file=sys.stderr)
180
+ sys.exit(2)
181
+
182
+ if not config_path or not config_path.exists():
101
183
  print(f"Config not found: {config_path}", file=sys.stderr)
102
184
  sys.exit(2)
103
185
  with config_path.open("rb") as fh:
@@ -107,7 +189,9 @@ def main() -> None:
107
189
  compute_cfg = cfg.get("compute", {}) if isinstance(cfg.get("compute"), dict) else {}
108
190
  data_cfg_full = cfg.get("data", {}) if isinstance(cfg.get("data"), dict) else {}
109
191
  topo_cfg = (data_cfg_full or {}).get("topology", {}) if isinstance(data_cfg_full, dict) else {}
110
- validation_local_path = (data_cfg_full or {}).get("validation_path") if isinstance(data_cfg_full, dict) else None
192
+ validation_local_path = (
193
+ (data_cfg_full or {}).get("validation_path") if isinstance(data_cfg_full, dict) else None
194
+ )
111
195
  train_cfg = cfg.get("training", {}) if isinstance(cfg.get("training"), dict) else {}
112
196
  hp_cfg = cfg.get("hyperparameters", {}) if isinstance(cfg.get("hyperparameters"), dict) else {}
113
197
 
@@ -119,7 +203,13 @@ def main() -> None:
119
203
  if isinstance(data_path, str) and data_path.strip():
120
204
  p = Path(data_path).expanduser()
121
205
  if not p.is_absolute():
122
- p = (config_path.parent / p).resolve()
206
+ # Try relative to cwd first, then relative to config directory
207
+ cwd_relative = Path.cwd() / p
208
+ config_relative = config_path.parent / p
209
+ if cwd_relative.exists():
210
+ p = cwd_relative.resolve()
211
+ else:
212
+ p = config_relative.resolve()
123
213
  data_file = p
124
214
  if data_file is None:
125
215
  print("Missing dataset path in --data or [job].data", file=sys.stderr)
@@ -129,38 +219,11 @@ def main() -> None:
129
219
  sys.exit(2)
130
220
 
131
221
  synth_key = (os.getenv("SYNTH_API_KEY") or "").strip()
132
- # Fallback: try to load from .env if not present in environment
133
222
  if not synth_key:
134
- candidate_env: Path | None = None
135
- if isinstance(args.env_file, str) and args.env_file.strip():
136
- candidate_env = Path(args.env_file).expanduser().resolve()
137
- else:
138
- # Prefer .env next to the TOML config
139
- candidate_env = (config_path.parent / ".env").resolve()
140
- if candidate_env and candidate_env.exists():
141
- try:
142
- env_text = candidate_env.read_text(encoding="utf-8", errors="ignore")
143
- # Match lines like: SYNTH_API_KEY=..., or export SYNTH_API_KEY=...
144
- key_val: str | None = None
145
- for line in env_text.splitlines():
146
- m = re.match(r"^\s*(?:export\s+)?SYNTH_API_KEY\s*=\s*(.*)$", line)
147
- if m:
148
- raw = m.group(1).strip()
149
- # Trim surrounding quotes if present
150
- if (raw.startswith('"') and raw.endswith('"')) or (raw.startswith("'") and raw.endswith("'")):
151
- raw = raw[1:-1]
152
- key_val = raw.strip()
153
- break
154
- if key_val:
155
- synth_key = key_val
156
- os.environ["SYNTH_API_KEY"] = synth_key
157
- print(f"[INFO] Loaded SYNTH_API_KEY from {candidate_env}")
158
- except Exception as _e:
159
- # Ignore and fall through to error below
160
- pass
161
- if not synth_key:
162
- print("Missing SYNTH_API_KEY (set in env or provide --env-file pointing to .env)", file=sys.stderr)
163
- sys.exit(2)
223
+ synth_key = input("Please enter your Synth API key:\n> ").strip()
224
+ if not synth_key:
225
+ print("Synth API key is required", file=sys.stderr)
226
+ sys.exit(2)
164
227
 
165
228
  backend = args.backend.rstrip("/")
166
229
  print(f"[INFO] Using backend={backend} key_fp={mask(synth_key)} data={data_file}")
@@ -180,7 +243,10 @@ def main() -> None:
180
243
  err_status = (upf or {}).get("status")
181
244
  err_body = (upf or {}).get("body") or (upf or {}).get("text")
182
245
  err_ep = (upf or {}).get("endpoint")
183
- print(f"Upload failed (status={err_status} endpoint={err_ep}) body={str(err_body)[:200]}", file=sys.stderr)
246
+ print(
247
+ f"Upload failed (status={err_status} endpoint={err_ep}) body={str(err_body)[:200]}",
248
+ file=sys.stderr,
249
+ )
184
250
  sys.exit(4)
185
251
 
186
252
  # Optionally upload validation file
@@ -203,7 +269,9 @@ def main() -> None:
203
269
  err_status = (upv or {}).get("status")
204
270
  err_body = (upv or {}).get("body") or (upv or {}).get("text")
205
271
  err_ep = (upv or {}).get("endpoint")
206
- print(f"[WARN] Validation upload failed (status={err_status} endpoint={err_ep}) body={str(err_body)[:180]} — continuing without validation")
272
+ print(
273
+ f"[WARN] Validation upload failed (status={err_status} endpoint={err_ep}) body={str(err_body)[:180]} — continuing without validation"
274
+ )
207
275
 
208
276
  # 2) Build job payload
209
277
  hp_block: Dict[str, Any] = {
@@ -238,18 +306,24 @@ def main() -> None:
238
306
  "training": {k: v for k, v in train_cfg.items() if k in ("mode", "use_qlora")},
239
307
  }
240
308
  # If TOML includes a [training.validation] block, forward relevant knobs into hyperparameters
241
- validation_cfg = train_cfg.get("validation") if isinstance(train_cfg.get("validation"), dict) else None
309
+ validation_cfg = (
310
+ train_cfg.get("validation") if isinstance(train_cfg.get("validation"), dict) else None
311
+ )
242
312
  if isinstance(validation_cfg, dict):
243
313
  # Enable evaluation and map keys as-is; backend trainer maps metric_for_best_model 'val.loss'→'eval_loss'
244
- hp_block.update({
245
- "evaluation_strategy": validation_cfg.get("evaluation_strategy", "steps"),
246
- "eval_steps": int(validation_cfg.get("eval_steps", 0) or 0),
247
- "save_best_model_at_end": bool(validation_cfg.get("save_best_model_at_end", True)),
248
- "metric_for_best_model": validation_cfg.get("metric_for_best_model", "val.loss"),
249
- "greater_is_better": bool(validation_cfg.get("greater_is_better", False)),
250
- })
314
+ hp_block.update(
315
+ {
316
+ "evaluation_strategy": validation_cfg.get("evaluation_strategy", "steps"),
317
+ "eval_steps": int(validation_cfg.get("eval_steps", 0) or 0),
318
+ "save_best_model_at_end": bool(validation_cfg.get("save_best_model_at_end", True)),
319
+ "metric_for_best_model": validation_cfg.get("metric_for_best_model", "val.loss"),
320
+ "greater_is_better": bool(validation_cfg.get("greater_is_better", False)),
321
+ }
322
+ )
251
323
  # Also surface validation enable flag into effective_config for visibility (optional)
252
- effective.setdefault("training", {})["validation"] = {"enabled": bool(validation_cfg.get("enabled", True))}
324
+ effective.setdefault("training", {})["validation"] = {
325
+ "enabled": bool(validation_cfg.get("enabled", True))
326
+ }
253
327
 
254
328
  body = {
255
329
  "model": model,
@@ -289,7 +363,9 @@ def main() -> None:
289
363
  break
290
364
  # Warn if stuck queued for >10 minutes
291
365
  if status == "queued" and (time.time() - queued_since) > 600:
292
- print("[WARN] Job has remained queued for >10 minutes. Backend may be capacity constrained.")
366
+ print(
367
+ "[WARN] Job has remained queued for >10 minutes. Backend may be capacity constrained."
368
+ )
293
369
  queued_since = time.time()
294
370
  time.sleep(5)
295
371
 
@@ -46,17 +46,21 @@ def build_rollout_request(
46
46
  )
47
47
  return RolloutRequest(
48
48
  run_id=run_id,
49
- env=RolloutEnvSpec(env_name='crafter', seed=seed, config={}),
50
- policy=RolloutPolicySpec(policy_name='crafter-react', config=policy_config),
49
+ env=RolloutEnvSpec(env_name="crafter", seed=seed, config={}),
50
+ policy=RolloutPolicySpec(policy_name="crafter-react", config=policy_config),
51
51
  ops=ops,
52
52
  record=record_cfg,
53
- on_done='reset',
53
+ on_done="reset",
54
54
  safety=RolloutSafetyConfig(),
55
55
  )
56
56
 
57
57
 
58
58
  def summarise_response(data: Any) -> dict[str, Any]:
59
- metrics = data.metrics.model_dump() if hasattr(data.metrics, "model_dump") else data.get("metrics", {})
59
+ metrics = (
60
+ data.metrics.model_dump()
61
+ if hasattr(data.metrics, "model_dump")
62
+ else data.get("metrics", {})
63
+ )
60
64
  error = None
61
65
  rollout_status = None
62
66
  try:
@@ -86,16 +90,42 @@ async def main() -> None:
86
90
  parser = argparse.ArgumentParser(description=__doc__)
87
91
  parser.add_argument("--base-url", default="http://localhost:8001", help="Task app base URL")
88
92
  parser.add_argument("--api-key", help="Environment API key (or set via --env-file)")
89
- parser.add_argument('--seed', type=int, default=42, help='Env seed to rollout')
90
- parser.add_argument('--run-id', default='local-demo', help='Run identifier')
91
- parser.add_argument('--model', default='gpt-4o-mini', help='Model identifier for the Crafter policy (OpenAI-compatible)')
92
- parser.add_argument('--inference-url', default='https://api.openai.com', help='Inference base URL used by the policy (e.g., https://api.openai.com)')
93
- parser.add_argument('--env-file', type=str, default=None, help='Path to .env file with API keys')
94
- parser.add_argument('--ops', default=None, help='Comma-separated rollout ops (advanced override)')
95
- parser.add_argument('--max-llm-calls', type=int, default=1, help='Number of policy inference calls when --ops not provided')
96
- parser.add_argument('--max-policy-tokens', type=int, default=None, help='Optional per-call token limit forwarded to the policy config')
97
- parser.add_argument('--timeout', type=float, default=600.0, help='HTTP timeout (seconds) for task app requests')
98
- parser.add_argument('--verbose', action='store_true', help='Print resolved configuration and headers')
93
+ parser.add_argument("--seed", type=int, default=42, help="Env seed to rollout")
94
+ parser.add_argument("--run-id", default="local-demo", help="Run identifier")
95
+ parser.add_argument(
96
+ "--model",
97
+ default="gpt-4o-mini",
98
+ help="Model identifier for the Crafter policy (OpenAI-compatible)",
99
+ )
100
+ parser.add_argument(
101
+ "--inference-url",
102
+ default="https://api.openai.com",
103
+ help="Inference base URL used by the policy (e.g., https://api.openai.com)",
104
+ )
105
+ parser.add_argument(
106
+ "--env-file", type=str, default=None, help="Path to .env file with API keys"
107
+ )
108
+ parser.add_argument(
109
+ "--ops", default=None, help="Comma-separated rollout ops (advanced override)"
110
+ )
111
+ parser.add_argument(
112
+ "--max-llm-calls",
113
+ type=int,
114
+ default=1,
115
+ help="Number of policy inference calls when --ops not provided",
116
+ )
117
+ parser.add_argument(
118
+ "--max-policy-tokens",
119
+ type=int,
120
+ default=None,
121
+ help="Optional per-call token limit forwarded to the policy config",
122
+ )
123
+ parser.add_argument(
124
+ "--timeout", type=float, default=600.0, help="HTTP timeout (seconds) for task app requests"
125
+ )
126
+ parser.add_argument(
127
+ "--verbose", action="store_true", help="Print resolved configuration and headers"
128
+ )
99
129
  args = parser.parse_args()
100
130
 
101
131
  if args.env_file:
@@ -117,12 +147,13 @@ async def main() -> None:
117
147
  os.environ["OPENAI_API_KEY"] = synth_key
118
148
 
119
149
  if args.verbose:
150
+
120
151
  def _mask(val: str | None) -> str:
121
152
  if not val:
122
- return '<unset>'
153
+ return "<unset>"
123
154
  return f"{val[:6]}…{val[-4:]} (len={len(val)})"
124
155
 
125
- print('Resolved configuration:')
156
+ print("Resolved configuration:")
126
157
  print(f" Task app base URL : {args.base_url}")
127
158
  print(f" Inference base URL : {args.inference_url}")
128
159
  print(f" Task app API key : {_mask(api_key)}")
@@ -130,21 +161,23 @@ async def main() -> None:
130
161
  print(f" HTTP timeout : {args.timeout:.1f}s")
131
162
 
132
163
  if args.ops:
133
- ops = [op.strip() for op in args.ops.split(',') if op.strip()]
164
+ ops = [op.strip() for op in args.ops.split(",") if op.strip()]
134
165
  if not ops:
135
- raise ValueError('Ops must contain at least one entry')
166
+ raise ValueError("Ops must contain at least one entry")
136
167
  else:
137
168
  llm_calls = max(args.max_llm_calls, 1)
138
169
  if llm_calls > 20:
139
- print('[WARN] --max-llm-calls capped at 20 to avoid excessive episodes; use --ops for manual control.')
170
+ print(
171
+ "[WARN] --max-llm-calls capped at 20 to avoid excessive episodes; use --ops for manual control."
172
+ )
140
173
  llm_calls = 20
141
174
  ops = []
142
175
  for _ in range(llm_calls):
143
- ops.extend(['agent', 'env'])
176
+ ops.extend(["agent", "env"])
144
177
 
145
178
  async with TaskAppClient(args.base_url, api_key=api_key, timeout=args.timeout) as client:
146
179
  try:
147
- print(f'Fetching task_info for seed {args.seed}…')
180
+ print(f"Fetching task_info for seed {args.seed}…")
148
181
  task_info = await client.task_info(seeds=[args.seed])
149
182
  info_payload = task_info[0] if isinstance(task_info, list) else task_info
150
183
  print(json.dumps(info_payload.model_dump(), indent=2)[:600])
@@ -158,29 +191,47 @@ async def main() -> None:
158
191
  extra_headers=extra_headers,
159
192
  )
160
193
  if args.max_policy_tokens is not None:
161
- request.policy.config.update({
162
- 'max_completion_tokens': args.max_policy_tokens,
163
- 'max_tokens': args.max_policy_tokens,
164
- })
194
+ request.policy.config.update(
195
+ {
196
+ "max_completion_tokens": args.max_policy_tokens,
197
+ "max_tokens": args.max_policy_tokens,
198
+ }
199
+ )
165
200
  if args.verbose:
166
- print(f'Ops: {ops}')
167
- print(f'Request headers: {request.policy.config.get("extra_headers", {})}')
168
- print('Requesting rollout…')
201
+ print(f"Ops: {ops}")
202
+ print(f"Request headers: {request.policy.config.get('extra_headers', {})}")
203
+ print("Requesting rollout…")
169
204
  response = await client.rollout(request)
170
205
  summary = summarise_response(response)
171
206
  print(json.dumps(summary, indent=2))
172
- print(f'Ops executed: {ops}')
173
- print('Tip: use --max-llm-calls N for agent/env pairs or --ops for manual control.')
207
+ print(f"Ops executed: {ops}")
208
+ print("Tip: use --max-llm-calls N for agent/env pairs or --ops for manual control.")
174
209
  except httpx.HTTPStatusError as exc:
175
- detail = exc.response.json() if exc.response.headers.get('content-type', '').startswith('application/json') else exc.response.text
176
- print(f'HTTP error {exc.response.status_code}: {detail}', file=sys.stderr)
210
+ detail = (
211
+ exc.response.json()
212
+ if exc.response.headers.get("content-type", "").startswith("application/json")
213
+ else exc.response.text
214
+ )
215
+ print(f"HTTP error {exc.response.status_code}: {detail}", file=sys.stderr)
177
216
  if exc.response.status_code in (401, 503):
178
- print('Hint: ensure the task app was started with ENVIRONMENT_API_KEY set and pass the same key via --api-key.', file=sys.stderr)
217
+ print(
218
+ "Hint: ensure the task app was started with ENVIRONMENT_API_KEY set and pass the same key via --api-key.",
219
+ file=sys.stderr,
220
+ )
179
221
  if exc.response.status_code == 500 and args.model in str(detail):
180
- print('Hint: supply --model/--inference-url (and set OPENAI_API_KEY or GROQ_API_KEY) so the policy can route inference.', file=sys.stderr)
181
- print('Hint: the inference URL should be the base (e.g., https://api.openai.com); the task app appends /v1/chat/completions.', file=sys.stderr)
222
+ print(
223
+ "Hint: supply --model/--inference-url (and set OPENAI_API_KEY or GROQ_API_KEY) so the policy can route inference.",
224
+ file=sys.stderr,
225
+ )
226
+ print(
227
+ "Hint: the inference URL should be the base (e.g., https://api.openai.com); the task app appends /v1/chat/completions.",
228
+ file=sys.stderr,
229
+ )
182
230
  if args.max_policy_tokens is not None:
183
- print(f'Hint: --max-policy-tokens={args.max_policy_tokens} is forwarded to the policy config as max_completion_tokens.', file=sys.stderr)
231
+ print(
232
+ f"Hint: --max-policy-tokens={args.max_policy_tokens} is forwarded to the policy config as max_completion_tokens.",
233
+ file=sys.stderr,
234
+ )
184
235
  raise
185
236
 
186
237
 
@@ -25,7 +25,9 @@ from synth_ai.task import (
25
25
  )
26
26
 
27
27
 
28
- def build_rollout_request(seed: int, run_id: str, *, model: str, inference_url: str, ops: list[str], api_key: str) -> RolloutRequest:
28
+ def build_rollout_request(
29
+ seed: int, run_id: str, *, model: str, inference_url: str, ops: list[str], api_key: str
30
+ ) -> RolloutRequest:
29
31
  policy_config = {
30
32
  "model": model,
31
33
  "inference_url": inference_url,
@@ -45,7 +47,11 @@ def build_rollout_request(seed: int, run_id: str, *, model: str, inference_url:
45
47
 
46
48
 
47
49
  def summarise_response(data: Any) -> dict[str, Any]:
48
- metrics = data.metrics.model_dump() if hasattr(data.metrics, "model_dump") else data.get("metrics", {})
50
+ metrics = (
51
+ data.metrics.model_dump()
52
+ if hasattr(data.metrics, "model_dump")
53
+ else data.get("metrics", {})
54
+ )
49
55
  return {
50
56
  "run_id": getattr(data, "run_id", None) or data.get("run_id"),
51
57
  "num_episodes": metrics.get("num_episodes"),
@@ -57,21 +63,54 @@ def summarise_response(data: Any) -> dict[str, Any]:
57
63
 
58
64
 
59
65
  async def main() -> None:
66
+ # Load .env file from current directory first if it exists
67
+ default_env = Path.cwd() / ".env"
68
+ if default_env.exists():
69
+ load_dotenv(default_env, override=False)
70
+
60
71
  parser = argparse.ArgumentParser(description=__doc__)
61
72
  parser.add_argument("--base-url", default="http://localhost:8010", help="Task app base URL")
62
73
  parser.add_argument("--env-file", type=str, default=None, help="Path to .env file with keys")
63
74
  parser.add_argument("--seed", type=int, default=42, help="Env seed to rollout")
64
75
  parser.add_argument("--run-id", default="modal-eval", help="Run identifier")
65
- parser.add_argument("--model", required=True, help="Model identifier for the Crafter policy")
66
- parser.add_argument("--inference-url", required=True, help="Modal backend inference base URL (e.g., http://localhost:8000/api)")
67
- parser.add_argument("--task-app-key", default=None, help="Environment API key for the task app (fallback ENVIRONMENT_API_KEY)")
68
- parser.add_argument("--modal-key", default=None, help="Synth/Modal API key for inference (fallback SYNTH_API_KEY)")
69
- parser.add_argument("--max-llm-calls", type=int, default=20, help="Number of policy inference calls")
70
- parser.add_argument("--ops", default=None, help="Comma-separated rollout ops (advanced override)")
71
- parser.add_argument("--max-policy-tokens", type=int, default=None, help="Optional per-call token limit forwarded to the policy config")
72
- parser.add_argument("--verbose", action="store_true", help="Print resolved configuration and headers")
76
+ parser.add_argument(
77
+ "--model",
78
+ required=False,
79
+ help="Model identifier for the Crafter policy (e.g., fft:Qwen/Qwen3-4B:job_xxx)",
80
+ )
81
+ parser.add_argument(
82
+ "--inference-url",
83
+ required=False,
84
+ help="Modal backend inference base URL (e.g., http://localhost:8000/api)",
85
+ )
86
+ parser.add_argument(
87
+ "--task-app-key",
88
+ default=None,
89
+ help="Environment API key for the task app (fallback ENVIRONMENT_API_KEY)",
90
+ )
91
+ parser.add_argument(
92
+ "--modal-key",
93
+ default=None,
94
+ help="Synth/Modal API key for inference (fallback SYNTH_API_KEY)",
95
+ )
96
+ parser.add_argument(
97
+ "--max-llm-calls", type=int, default=20, help="Number of policy inference calls"
98
+ )
99
+ parser.add_argument(
100
+ "--ops", default=None, help="Comma-separated rollout ops (advanced override)"
101
+ )
102
+ parser.add_argument(
103
+ "--max-policy-tokens",
104
+ type=int,
105
+ default=None,
106
+ help="Optional per-call token limit forwarded to the policy config",
107
+ )
108
+ parser.add_argument(
109
+ "--verbose", action="store_true", help="Print resolved configuration and headers"
110
+ )
73
111
  args = parser.parse_args()
74
112
 
113
+ # Also load from explicit --env-file if provided
75
114
  if args.env_file:
76
115
  env_path = Path(args.env_file).expanduser()
77
116
  if not env_path.exists():
@@ -79,16 +118,51 @@ async def main() -> None:
79
118
  else:
80
119
  load_dotenv(env_path, override=False)
81
120
 
121
+ # Prompt for required parameters if not provided
122
+ base_url = args.base_url
123
+ if args.base_url == "http://localhost:8010":
124
+ print("\nTask app configuration:")
125
+ base_url_input = input(f"Task app base URL [http://localhost:8001]: ").strip()
126
+ base_url = base_url_input if base_url_input else "http://localhost:8001"
127
+
128
+ model = args.model
129
+ if not model:
130
+ print("\nFine-tuned model configuration:")
131
+ print(
132
+ "Note: This should be the model ID returned from training (e.g., fft:Qwen/Qwen3-4B:job_abc123)"
133
+ )
134
+ model_input = input("Fine-tuned model ID: ").strip()
135
+ if not model_input:
136
+ parser.error("Model identifier is required")
137
+ model = model_input
138
+
139
+ inference_url = args.inference_url
140
+ if not inference_url:
141
+ inference_url_input = input("Inference URL [http://localhost:8000/api]: ").strip()
142
+ inference_url = inference_url_input if inference_url_input else "http://localhost:8000/api"
143
+
144
+ # Override args
145
+ args.base_url = base_url
146
+ args.model = model
147
+ args.inference_url = inference_url
148
+
149
+ # Check environment variables first (loaded from .env)
82
150
  task_app_key = args.task_app_key or os.getenv("ENVIRONMENT_API_KEY")
83
151
  if not task_app_key:
84
- parser.error("Missing task app API key (set ENVIRONMENT_API_KEY or pass --task-app-key)")
152
+ print("\n[INFO] ENVIRONMENT_API_KEY not found in environment or .env file")
153
+ task_app_key = input("RL Environment API key: ").strip()
154
+ if not task_app_key:
155
+ parser.error("Missing task app API key")
85
156
 
86
157
  modal_key = args.modal_key or os.getenv("SYNTH_API_KEY")
87
158
  if not modal_key:
88
- parser.error("Missing Synth/Modal API key (set SYNTH_API_KEY or pass --modal-key)")
159
+ print("[INFO] SYNTH_API_KEY not found in environment or .env file")
160
+ modal_key = input("Synth API key: ").strip()
161
+ if not modal_key:
162
+ parser.error("Missing Synth/Modal API key")
89
163
 
90
- if synth_key and "openai.com" not in args.inference_url.lower():
91
- os.environ["OPENAI_API_KEY"] = synth_key
164
+ if modal_key and "openai.com" not in args.inference_url.lower():
165
+ os.environ["OPENAI_API_KEY"] = modal_key
92
166
 
93
167
  if args.ops:
94
168
  ops = [op.strip() for op in args.ops.split(",") if op.strip()]
@@ -103,6 +177,7 @@ async def main() -> None:
103
177
  ops.extend(["agent", "env"])
104
178
 
105
179
  if args.verbose:
180
+
106
181
  def _mask(val: str | None) -> str:
107
182
  if not val:
108
183
  return "<unset>"
@@ -115,11 +190,15 @@ async def main() -> None:
115
190
  print(f" Modal API key : {_mask(modal_key)}")
116
191
  print(f" Ops (count={len(ops)}) : {ops}")
117
192
 
118
- inf_url_norm = args.inference_url.rstrip('/')
119
- if '/api' not in inf_url_norm:
120
- print('[WARN] Inference URL is missing /api prefix; proxy endpoints usually live at /api/inference/v1/chat/completions.')
121
- elif not inf_url_norm.lower().endswith('/api'):
122
- print('[INFO] Using inference base URL; policy will append /v1/chat/completions automatically.')
193
+ inf_url_norm = args.inference_url.rstrip("/")
194
+ if "/api" not in inf_url_norm:
195
+ print(
196
+ "[WARN] Inference URL is missing /api prefix; proxy endpoints usually live at /api/inference/v1/chat/completions."
197
+ )
198
+ elif not inf_url_norm.lower().endswith("/api"):
199
+ print(
200
+ "[INFO] Using inference base URL; policy will append /v1/chat/completions automatically."
201
+ )
123
202
 
124
203
  async with TaskAppClient(args.base_url, api_key=task_app_key) as client:
125
204
  try:
@@ -139,20 +218,29 @@ async def main() -> None:
139
218
  if args.verbose:
140
219
  print(f"Request headers: {request.policy.config.get('extra_headers', {})}")
141
220
  if args.max_policy_tokens is not None:
142
- request.policy.config.update({
143
- "max_completion_tokens": args.max_policy_tokens,
144
- "max_tokens": args.max_policy_tokens,
145
- })
221
+ request.policy.config.update(
222
+ {
223
+ "max_completion_tokens": args.max_policy_tokens,
224
+ "max_tokens": args.max_policy_tokens,
225
+ }
226
+ )
146
227
  print("Requesting rollout…")
147
228
  response = await client.rollout(request)
148
229
  summary = summarise_response(response)
149
230
  print(json.dumps(summary, indent=2))
150
231
  print(f"Ops executed: {ops}")
151
232
  except httpx.HTTPStatusError as exc:
152
- detail = exc.response.json() if exc.response.headers.get("content-type", "").startswith("application/json") else exc.response.text
233
+ detail = (
234
+ exc.response.json()
235
+ if exc.response.headers.get("content-type", "").startswith("application/json")
236
+ else exc.response.text
237
+ )
153
238
  print(f"HTTP error {exc.response.status_code}: {detail}", file=sys.stderr)
154
239
  if exc.response.status_code in (401, 503):
155
- print("Hint: ensure ENVIRONMENT_API_KEY and SYNTH_API_KEY are correctly set.", file=sys.stderr)
240
+ print(
241
+ "Hint: ensure ENVIRONMENT_API_KEY and SYNTH_API_KEY are correctly set.",
242
+ file=sys.stderr,
243
+ )
156
244
  raise
157
245
 
158
246