synth-ai 0.2.9.dev3__py3-none-any.whl → 0.2.9.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (107)
  1. examples/analyze_semantic_words.sh +17 -0
  2. examples/common_old/backend.py +21 -0
  3. examples/crafter_debug_render.py +180 -0
  4. examples/evals_old/README.md +98 -0
  5. examples/evals_old/__init__.py +6 -0
  6. examples/evals_old/compare_models.py +1037 -0
  7. examples/evals_old/example_log.md +145 -0
  8. examples/evals_old/run_demo.sh +126 -0
  9. examples/evals_old/trace_analysis.py +270 -0
  10. examples/finetuning_old/_backup_synth_qwen/config.toml +29 -0
  11. examples/finetuning_old/_backup_synth_qwen/example_log.md +324 -0
  12. examples/finetuning_old/_backup_synth_qwen/filter_traces.py +60 -0
  13. examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +239 -0
  14. examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +109 -0
  15. examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +1924 -0
  16. examples/finetuning_old/_backup_synth_qwen/readme.md +49 -0
  17. examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +114 -0
  18. examples/finetuning_old/_backup_synth_qwen/run_demo.sh +195 -0
  19. examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +118 -0
  20. examples/finetuning_old/synth_qwen_v1/README.md +68 -0
  21. examples/finetuning_old/synth_qwen_v1/filter_traces.py +60 -0
  22. examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +239 -0
  23. examples/finetuning_old/synth_qwen_v1/finetune.py +46 -0
  24. examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +71 -0
  25. examples/finetuning_old/synth_qwen_v1/infer.py +37 -0
  26. examples/finetuning_old/synth_qwen_v1/poll.py +44 -0
  27. examples/finetuning_old/synth_qwen_v1/prepare_data.py +35 -0
  28. examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +109 -0
  29. examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +1932 -0
  30. examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +207 -0
  31. examples/finetuning_old/synth_qwen_v1/run_ft_job.py +232 -0
  32. examples/finetuning_old/synth_qwen_v1/upload_data.py +34 -0
  33. examples/finetuning_old/synth_qwen_v1/util.py +147 -0
  34. examples/rl/README.md +169 -0
  35. examples/rl/configs/eval_base_qwen.toml +15 -0
  36. examples/rl/configs/eval_rl_qwen.toml +11 -0
  37. examples/rl/configs/rl_from_base_qwen.toml +35 -0
  38. examples/rl/configs/rl_from_base_qwen17.toml +74 -0
  39. examples/rl/configs/rl_from_ft_qwen.toml +35 -0
  40. examples/rl/download_dataset.py +64 -0
  41. examples/rl/run_eval.py +435 -0
  42. examples/rl/run_rl_and_save.py +94 -0
  43. examples/rl/task_app/README.md +22 -0
  44. {synth_ai/task/apps → examples/rl/task_app}/math_single_step.py +8 -8
  45. examples/rl/task_app/math_task_app.py +107 -0
  46. examples/rl_old/task_app.py +962 -0
  47. examples/run_crafter_demo.sh +10 -0
  48. examples/warming_up_to_rl/analyze_trace_db.py +420 -0
  49. examples/warming_up_to_rl/configs/crafter_fft.toml +48 -0
  50. examples/warming_up_to_rl/configs/crafter_fft_4b.toml +54 -0
  51. examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +20 -0
  52. examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +13 -0
  53. examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +23 -0
  54. examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +73 -0
  55. examples/warming_up_to_rl/configs/rl_from_ft.toml +56 -0
  56. examples/warming_up_to_rl/export_trace_sft.py +541 -0
  57. examples/warming_up_to_rl/groq_test.py +88 -0
  58. examples/warming_up_to_rl/manage_secrets.py +127 -0
  59. examples/warming_up_to_rl/old/event_rewards.md +234 -0
  60. examples/warming_up_to_rl/old/notes.md +73 -0
  61. examples/warming_up_to_rl/readme.md +172 -0
  62. examples/warming_up_to_rl/run_eval.py +434 -0
  63. examples/warming_up_to_rl/run_fft_and_save.py +309 -0
  64. examples/warming_up_to_rl/run_local_rollout.py +188 -0
  65. examples/warming_up_to_rl/run_local_rollout_modal.py +160 -0
  66. examples/warming_up_to_rl/run_local_rollout_parallel.py +342 -0
  67. examples/warming_up_to_rl/run_local_rollout_traced.py +372 -0
  68. examples/warming_up_to_rl/run_rl_and_save.py +101 -0
  69. examples/warming_up_to_rl/run_rollout_remote.py +129 -0
  70. examples/warming_up_to_rl/task_app/README.md +38 -0
  71. {synth_ai/task/apps → examples/warming_up_to_rl/task_app}/grpo_crafter.py +7 -7
  72. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +165 -0
  73. examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
  74. examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
  75. examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +145 -0
  76. examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1271 -0
  77. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
  78. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
  79. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
  80. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +429 -0
  81. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +442 -0
  82. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +96 -0
  83. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +302 -0
  84. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
  85. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +202 -0
  86. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
  87. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +512 -0
  88. examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +102 -0
  89. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +985 -0
  90. examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +197 -0
  91. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1749 -0
  92. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
  93. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +217 -0
  94. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +160 -0
  95. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +146 -0
  96. examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +61 -0
  97. synth_ai/api/train/config_finder.py +18 -18
  98. synth_ai/api/train/env_resolver.py +28 -1
  99. synth_ai/cli/task_apps.py +291 -56
  100. synth_ai/task/apps/__init__.py +54 -13
  101. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/METADATA +1 -1
  102. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/RECORD +106 -13
  103. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/top_level.txt +1 -0
  104. synth_ai/environments/examples/sokoban/units/astar_common.py +0 -95
  105. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/WHEEL +0 -0
  106. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/entry_points.txt +0 -0
  107. {synth_ai-0.2.9.dev3.dist-info → synth_ai-0.2.9.dev5.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,309 @@
1
+ #!/usr/bin/env python3
2
+ from __future__ import annotations
3
+
4
+ import argparse
5
+ import json
6
+ import os
7
+ import sys
8
+ import time
9
+ from pathlib import Path
10
+ from typing import Any, Dict, Tuple, List
11
+
12
+ import tomllib
13
+ import re
14
+ import requests
15
+
16
+
17
def mask(val: str) -> str:
    """Return a redacted fingerprint of a secret that is safe to print.

    Non-string or empty input renders as ``<unset>``. Strings shorter than
    10 characters are hidden completely. Longer strings keep only their first
    6 and last 4 characters.
    """
    if isinstance(val, str) and val:
        if len(val) < 10:
            return "****"
        head, tail = val[:6], val[-4:]
        return f"{head}…{tail}"
    return "<unset>"
21
+
22
+
23
def post_multipart(base: str, api_key: str, path: str, file_field: str, filepath: Path) -> Dict[str, Any]:
    """Upload a file, trying backend-specific endpoints with fallbacks.

    Priority:
    - {BASE}/{path} (e.g. /learning/files, Modal Learning v2 style)
    - {BASE}/files (OpenAI-style)

    Returns the parsed JSON upload record on success: a dict carrying an
    ``id`` or ``object == "file"``. On failure it returns a dict with
    ``error: True`` plus whatever diagnostics are available (``status``,
    ``endpoint``, ``body`` or ``exception``).

    Only 404/405 responses and transport errors fall through to the next
    endpoint. Any other unsuccessful response is returned immediately.
    """
    headers = {"Authorization": f"Bearer {api_key}"}
    # Read once: the (name, bytes, mime) tuple can be re-sent to every endpoint.
    files = {file_field: (filepath.name, filepath.read_bytes(), "application/jsonl")}
    data = {"purpose": "fine-tune"}

    root = base.rstrip("/")
    # dict.fromkeys keeps order while dropping a duplicate when `path` is "/files".
    endpoints = list(dict.fromkeys([f"{root}/{path.lstrip('/')}", f"{root}/files"]))

    last_err: Dict[str, Any] | None = None
    for ep in endpoints:
        try:
            r = requests.post(ep, headers=headers, files=files, data=data, timeout=300)
        except requests.RequestException as e:
            last_err = {"error": True, "exception": str(e), "endpoint": ep}
            continue

        try:
            js: Any = r.json()
        except ValueError:
            # Non-JSON body. requests' JSONDecodeError subclasses ValueError.
            js = None

        # Success fast-path. Only dict payloads are trusted: a JSON list,
        # scalar or null has no .get and previously crashed here.
        if r.status_code < 400 and isinstance(js, dict) and (js.get("id") or js.get("object") == "file"):
            return js

        # 404/405 -> the route does not exist here; try the next endpoint
        if r.status_code in (404, 405):
            last_err = {"error": True, "status": r.status_code, "body": (r.text or "")[:800], "endpoint": ep}
            continue

        # Any other outcome: return a rich error without trying further endpoints
        return {
            "error": True,
            "status": r.status_code,
            "endpoint": ep,
            "body": (r.text or "")[:1200],
        }

    return last_err or {"error": True, "detail": "upload_failed_all_endpoints"}
68
+
69
+
70
def post_json(base: str, api_key: str, path: str, body: Dict[str, Any]) -> Dict[str, Any]:
    """POST ``body`` as JSON to ``{base}/{path}`` and return the decoded reply.

    Always returns a dict, so callers can safely use ``.get``:

    - JSON objects are returned as-is.
    - Other JSON values (list, scalar, null) are wrapped as
      ``{"http_status": <code>, "data": <value>}``.
    - Non-JSON replies become ``{"status": <http code>, "text": <first 400 chars>}``.

    Transport failures (``requests.RequestException``) propagate to the caller.
    """
    url = f"{base.rstrip('/')}/{path.lstrip('/')}"
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    r = requests.post(url, headers=headers, data=json.dumps(body), timeout=120)
    try:
        js = r.json()
    except ValueError:
        return {"status": r.status_code, "text": r.text[:400]}
    if isinstance(js, dict):
        return js
    return {"http_status": r.status_code, "data": js}
78
+
79
+
80
def get_json(base: str, api_key: str, path: str) -> Dict[str, Any]:
    """GET ``{base}/{path}`` with bearer auth and return the decoded reply.

    Always returns a dict, so callers can safely use ``.get``:

    - JSON objects are returned as-is.
    - Other JSON values (list, scalar, null) are wrapped as
      ``{"http_status": <code>, "data": <value>}``.
    - Non-JSON replies become ``{"status": <http code>, "text": <first 400 chars>}``.

    Transport failures (``requests.RequestException``) propagate to the caller.
    """
    url = f"{base.rstrip('/')}/{path.lstrip('/')}"
    headers = {"Authorization": f"Bearer {api_key}"}
    r = requests.get(url, headers=headers, timeout=30)
    try:
        js = r.json()
    except ValueError:
        return {"status": r.status_code, "text": r.text[:400]}
    if isinstance(js, dict):
        return js
    return {"http_status": r.status_code, "data": js}
88
+
89
+
90
def _read_synth_key_from_env_file(env_path: Path) -> str | None:
    """Return SYNTH_API_KEY parsed from a dotenv-style file, or None if absent/empty.

    Matches ``SYNTH_API_KEY=...`` and ``export SYNTH_API_KEY=...`` lines and
    strips one pair of matching surrounding quotes. Only the first matching
    line is considered. Raises OSError if the file cannot be read.
    """
    pattern = re.compile(r"^\s*(?:export\s+)?SYNTH_API_KEY\s*=\s*(.*)$")
    text = env_path.read_text(encoding="utf-8", errors="ignore")
    for line in text.splitlines():
        m = pattern.match(line)
        if not m:
            continue
        raw = m.group(1).strip()
        # Trim surrounding quotes if present
        if (raw.startswith('"') and raw.endswith('"')) or (raw.startswith("'") and raw.endswith("'")):
            raw = raw[1:-1]
        return raw.strip() or None
    return None


def main() -> None:
    """Submit an FFT (``sft_offline``) job from a TOML config and record the model id.

    Flow:
    1. Load the TOML and resolve the dataset path and SYNTH_API_KEY.
    2. Upload the training JSONL, plus the validation JSONL if configured.
    3. Create and start the job.
    4. Poll until a terminal status or the deadline.
    5. Append ``fine_tuned_model`` to ``ft_model_id.txt`` next to this script.

    Exit codes:
    - 0: the job succeeded and a model id was saved.
    - 1: the job did not succeed, or no model id was returned.
    - 2: config, dataset or API key problem.
    - 4: training-file upload failed.
    - 5: job creation failed.
    """
    parser = argparse.ArgumentParser(description="Submit FFT job and save resulting model id")
    parser.add_argument("--backend", default=os.getenv("BACKEND_BASE_URL", "http://localhost:8000/api"))
    parser.add_argument("--toml", required=True, help="Path to FFT TOML config")
    parser.add_argument("--data", default="", help="Override dataset JSONL path")
    parser.add_argument("--poll-seconds", type=int, default=1800)
    parser.add_argument("--env-file", default="", help="Optional path to .env file with SYNTH_API_KEY")
    args = parser.parse_args()

    config_path = Path(args.toml).expanduser().resolve()
    if not config_path.exists():
        print(f"Config not found: {config_path}", file=sys.stderr)
        sys.exit(2)
    with config_path.open("rb") as fh:
        cfg = tomllib.load(fh)

    def _section(value: Any) -> Dict[str, Any]:
        # TOML tables parse to dicts; a missing or non-table entry counts as empty.
        return value if isinstance(value, dict) else {}

    job_cfg = _section(cfg.get("job"))
    compute_cfg = _section(cfg.get("compute"))
    data_cfg_full = _section(cfg.get("data"))
    topo_cfg = data_cfg_full.get("topology", {})
    validation_local_path = data_cfg_full.get("validation_path")
    train_cfg = _section(cfg.get("training"))
    hp_cfg = _section(cfg.get("hyperparameters"))

    model = str(job_cfg.get("model") or os.getenv("SFT_MODEL") or "Qwen/Qwen3-4B")

    # Resolve dataset path: --data > [job].data > [job].data_path.
    # Relative paths are resolved against the TOML's directory.
    data_path = args.data or job_cfg.get("data") or job_cfg.get("data_path")
    data_file: Path | None = None
    if isinstance(data_path, str) and data_path.strip():
        p = Path(data_path).expanduser()
        if not p.is_absolute():
            p = (config_path.parent / p).resolve()
        data_file = p
    if data_file is None:
        print("Missing dataset path in --data or [job].data", file=sys.stderr)
        sys.exit(2)
    if not data_file.exists():
        print(f"Dataset not found: {data_file}", file=sys.stderr)
        sys.exit(2)

    synth_key = (os.getenv("SYNTH_API_KEY") or "").strip()
    # Fallback: try to load from .env if not present in environment
    if not synth_key:
        if isinstance(args.env_file, str) and args.env_file.strip():
            candidate_env = Path(args.env_file).expanduser().resolve()
        else:
            # Prefer .env next to the TOML config
            candidate_env = (config_path.parent / ".env").resolve()
        if candidate_env.exists():
            try:
                key_val = _read_synth_key_from_env_file(candidate_env)
            except OSError as exc:
                # Best-effort fallback: say why, then fall through to the error below.
                print(f"[WARN] Could not read {candidate_env}: {exc}", file=sys.stderr)
                key_val = None
            if key_val:
                synth_key = key_val
                os.environ["SYNTH_API_KEY"] = synth_key
                print(f"[INFO] Loaded SYNTH_API_KEY from {candidate_env}")
    if not synth_key:
        print("Missing SYNTH_API_KEY (set in env or provide --env-file pointing to .env)", file=sys.stderr)
        sys.exit(2)

    backend = args.backend.rstrip("/")
    print(f"[INFO] Using backend={backend} key_fp={mask(synth_key)} data={data_file}")
    if isinstance(validation_local_path, str) and validation_local_path.strip():
        print(f"[INFO] Using validation path={validation_local_path}")

    # 1) Upload training file
    print("[INFO] Uploading training file…")
    upf = post_multipart(backend, synth_key, "/learning/files", "file", data_file)
    try:
        print(f"[INFO] Upload response: {json.dumps(upf, indent=2)[:400]}")
    except (TypeError, ValueError):
        print(f"[INFO] Upload response (raw): {str(upf)[:400]}")
    file_id = str((upf or {}).get("id") or "").strip()
    if not file_id:
        # Rich diagnostics
        err_status = (upf or {}).get("status")
        err_body = (upf or {}).get("body") or (upf or {}).get("text")
        err_ep = (upf or {}).get("endpoint")
        print(f"Upload failed (status={err_status} endpoint={err_ep}) body={str(err_body)[:200]}", file=sys.stderr)
        sys.exit(4)

    # Optionally upload validation file; failures here only downgrade to a warning.
    val_file_id: str | None = None
    if isinstance(validation_local_path, str) and validation_local_path.strip():
        vpath = Path(validation_local_path).expanduser()
        if not vpath.is_absolute():
            vpath = (config_path.parent / vpath).resolve()
        if not vpath.exists():
            print(f"[WARN] Validation file not found: {vpath} (skipping validation)")
        else:
            print("[INFO] Uploading validation file…")
            upv = post_multipart(backend, synth_key, "/learning/files", "file", vpath)
            try:
                print(f"[INFO] Validation upload response: {json.dumps(upv, indent=2)[:300]}")
            except (TypeError, ValueError):
                print(f"[INFO] Validation upload response (raw): {str(upv)[:300]}")
            val_file_id = str((upv or {}).get("id") or "").strip() or None
            if not val_file_id:
                err_status = (upv or {}).get("status")
                err_body = (upv or {}).get("body") or (upv or {}).get("text")
                err_ep = (upv or {}).get("endpoint")
                print(f"[WARN] Validation upload failed (status={err_status} endpoint={err_ep}) body={str(err_body)[:180]} — continuing without validation")

    # 2) Build job payload
    hp_block: Dict[str, Any] = {
        "n_epochs": int(hp_cfg.get("n_epochs") or 1),
    }
    # Optional extras if present
    for k in (
        "batch_size",
        "global_batch",
        "per_device_batch",
        "gradient_accumulation_steps",
        "sequence_length",
        "learning_rate",
        "warmup_ratio",
        "train_kind",
    ):
        if k in hp_cfg:
            hp_block[k] = hp_cfg[k]

    parallel = hp_cfg.get("parallelism") if isinstance(hp_cfg.get("parallelism"), dict) else None
    if parallel:
        hp_block["parallelism"] = parallel

    compute_block: Dict[str, Any] = {k: compute_cfg[k] for k in ("gpu_type", "gpu_count", "nodes") if k in compute_cfg}

    effective: Dict[str, Any] = {
        "compute": compute_block,
        "data": {"topology": topo_cfg or {}},
        "training": {k: v for k, v in train_cfg.items() if k in ("mode", "use_qlora")},
    }
    # If TOML includes a [training.validation] block, forward relevant knobs into hyperparameters
    validation_cfg = train_cfg.get("validation") if isinstance(train_cfg.get("validation"), dict) else None
    if isinstance(validation_cfg, dict):
        # Enable evaluation and map keys as-is; backend trainer maps metric_for_best_model 'val.loss'→'eval_loss'
        hp_block.update({
            "evaluation_strategy": validation_cfg.get("evaluation_strategy", "steps"),
            "eval_steps": int(validation_cfg.get("eval_steps", 0) or 0),
            "save_best_model_at_end": bool(validation_cfg.get("save_best_model_at_end", True)),
            "metric_for_best_model": validation_cfg.get("metric_for_best_model", "val.loss"),
            "greater_is_better": bool(validation_cfg.get("greater_is_better", False)),
        })
        # Also surface validation enable flag into effective_config for visibility (optional)
        effective.setdefault("training", {})["validation"] = {"enabled": bool(validation_cfg.get("enabled", True))}
    if val_file_id:
        # Passed via metadata.effective_config.data rather than a top-level
        # validation_file field, so the downstream loader can read it.
        effective.setdefault("data", {})["validation_files"] = [val_file_id]

    body = {
        "model": model,
        "training_file_id": file_id,
        "training_type": "sft_offline",
        "hyperparameters": hp_block,
        "metadata": {"effective_config": effective},
    }

    # 3) Create and start job
    print("[INFO] Creating FFT job…")
    cj = post_json(backend, synth_key, "/learning/jobs", body)
    print(f"[INFO] Create response: {json.dumps(cj, indent=2)[:200]}")
    if not isinstance(cj, dict):
        cj = {}
    job_id = str(cj.get("job_id") or cj.get("id") or "").strip()
    if not job_id:
        print("Create job failed", file=sys.stderr)
        sys.exit(5)

    print(f"[INFO] Starting job {job_id}…")
    _ = post_json(backend, synth_key, f"/learning/jobs/{job_id}/start", {})

    # 4) Poll until terminal ([job].poll_seconds overrides --poll-seconds; minimum 30s)
    deadline = time.time() + max(30, int(job_cfg.get("poll_seconds") or args.poll_seconds))
    status = "queued"
    ft_model = None
    queued_since = time.time()
    while time.time() < deadline:
        info = get_json(backend, synth_key, f"/learning/jobs/{job_id}")
        if not isinstance(info, dict):
            info = {}
        # A non-JSON reply (e.g. a gateway 502 page) carries an int HTTP status here;
        # coerce to str so .lower() cannot crash the poll loop.
        status = str(info.get("status") or "").lower()
        ft_model = info.get("fine_tuned_model")
        print(f"[INFO] poll status={status} ft_model={ft_model}")
        if status in ("succeeded", "failed", "canceled", "cancelled"):
            break
        # Warn if stuck queued for >10 minutes
        if status == "queued" and (time.time() - queued_since) > 600:
            print("[WARN] Job has remained queued for >10 minutes. Backend may be capacity constrained.")
            queued_since = time.time()
        time.sleep(5)
    else:
        # Loop ended without a terminal status: the deadline elapsed.
        print(f"[WARN] Polling deadline reached before job finished; last status={status}")

    # 5) Save model id (appended, so repeated runs accumulate ids)
    out_file = Path(__file__).parent / "ft_model_id.txt"
    if ft_model:
        with out_file.open("a") as fh:
            fh.write(str(ft_model) + "\n")
        print(f"[INFO] Saved model id to {out_file}: {ft_model}")
        sys.exit(0 if status == "succeeded" else 1)
    else:
        print(f"[WARN] No fine_tuned_model found; final status={status}")
        sys.exit(1)
306
+
307
+
308
if __name__ == "__main__":
    # Script entry point: run the CLI only when executed directly, not when imported.
    main()
@@ -0,0 +1,188 @@
1
+ #!/usr/bin/env python3
2
+ """Hit a locally running Crafter task app and request a rollout."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import asyncio
8
+ import json
9
+ import os
10
+ from pathlib import Path
11
+ from typing import Any
12
+
13
+ import sys
14
+
15
+ import httpx
16
+ from dotenv import load_dotenv
17
+
18
+ from synth_ai.task import (
19
+ RolloutEnvSpec,
20
+ RolloutPolicySpec,
21
+ RolloutRecordConfig,
22
+ RolloutRequest,
23
+ RolloutSafetyConfig,
24
+ TaskAppClient,
25
+ )
26
+
27
+
28
def build_rollout_request(
    seed: int,
    run_id: str,
    *,
    model: str,
    inference_url: str,
    ops: list[str],
    extra_headers: dict[str, str] | None = None,
    trace_format: str = "compact",
    return_trace: bool = False,
) -> RolloutRequest:
    """Assemble a single-seed Crafter rollout request for the ``crafter-react`` policy.

    ``model`` and ``inference_url`` go into the policy config. ``extra_headers``
    is included only when non-empty; presumably the task app attaches it to its
    inference calls (confirm against the task app). Trajectories are always
    recorded, with ``trace_format`` and ``return_trace`` controlling trace
    output. The request uses ``on_done="reset"`` and default safety settings.
    """
    policy_config: dict[str, Any] = dict(model=model, inference_url=inference_url)
    if extra_headers:
        policy_config["extra_headers"] = extra_headers

    env_spec = RolloutEnvSpec(env_name="crafter", seed=seed, config={})
    policy_spec = RolloutPolicySpec(policy_name="crafter-react", config=policy_config)
    recording = RolloutRecordConfig(
        trajectories=True,
        trace_format=trace_format,
        return_trace=return_trace,
    )
    return RolloutRequest(
        run_id=run_id,
        env=env_spec,
        policy=policy_spec,
        ops=ops,
        record=recording,
        on_done="reset",
        safety=RolloutSafetyConfig(),
    )
56
+
57
+
58
def summarise_response(data: Any) -> dict[str, Any]:
    """Condense a rollout response into a small JSON-friendly summary.

    Accepts either a response model or a plain dict. For a model, fields are
    read as attributes, and ``metrics`` may expose ``model_dump()``.

    The summary contains the run id, the headline metrics, and the
    ``rollout_status`` and ``error`` from the first trajectory's ``final``
    dict. Any value that is missing is ``None``.
    """

    def _field(obj: Any, name: str) -> Any:
        # Read `name` from a mapping or an object without assuming which one we have.
        # (The old code called data.metrics on dicts and data.get on models, so
        # the advertised fallbacks raised AttributeError.)
        if isinstance(obj, dict):
            return obj.get(name)
        return getattr(obj, name, None)

    raw_metrics = _field(data, "metrics")
    if hasattr(raw_metrics, "model_dump"):
        metrics = raw_metrics.model_dump()
    elif isinstance(raw_metrics, dict):
        metrics = raw_metrics
    else:
        metrics = {}

    error = None
    rollout_status = None
    trajectories = _field(data, "trajectories")
    if isinstance(trajectories, list) and trajectories:
        final = _field(trajectories[0], "final")
        if isinstance(final, dict):
            error = final.get("error")
            rollout_status = final.get("rollout_status")

    return {
        "run_id": _field(data, "run_id"),
        "num_episodes": metrics.get("num_episodes"),
        "num_steps": metrics.get("num_steps"),
        "episode_returns": metrics.get("episode_returns"),
        "outcome_score": metrics.get("outcome_score"),
        "events_score": metrics.get("events_score"),
        "rollout_status": rollout_status,
        "error": error,
    }
83
+
84
+
85
async def main() -> None:
    """Fetch task_info for one seed from a local Crafter task app, then request a rollout.

    The task-app key comes from --api-key or ENVIRONMENT_API_KEY, optionally
    loaded from --env-file. If SYNTH_API_KEY is set, it is sent as a Bearer
    token in the policy's ``extra_headers``. For non-OpenAI inference URLs it is
    also exported as OPENAI_API_KEY in this process.

    HTTP status errors from the task app are printed with hints and re-raised.
    Verbose output masks every secret, including the forwarded headers.
    """
    # Note: __doc__ here resolves to the module docstring, not this function's.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--base-url", default="http://localhost:8001", help="Task app base URL")
    parser.add_argument("--api-key", help="Environment API key (or set via --env-file)")
    parser.add_argument('--seed', type=int, default=42, help='Env seed to rollout')
    parser.add_argument('--run-id', default='local-demo', help='Run identifier')
    parser.add_argument('--model', default='gpt-4o-mini', help='Model identifier for the Crafter policy (OpenAI-compatible)')
    parser.add_argument('--inference-url', default='https://api.openai.com', help='Inference base URL used by the policy (e.g., https://api.openai.com)')
    parser.add_argument('--env-file', type=str, default=None, help='Path to .env file with API keys')
    parser.add_argument('--ops', default=None, help='Comma-separated rollout ops (advanced override)')
    parser.add_argument('--max-llm-calls', type=int, default=1, help='Number of policy inference calls when --ops not provided')
    parser.add_argument('--max-policy-tokens', type=int, default=None, help='Optional per-call token limit forwarded to the policy config')
    parser.add_argument('--timeout', type=float, default=600.0, help='HTTP timeout (seconds) for task app requests')
    parser.add_argument('--verbose', action='store_true', help='Print resolved configuration and headers')
    args = parser.parse_args()

    def _mask(val: str | None) -> str:
        # Redact a secret for display. Short values are hidden completely,
        # because a 6+4 character preview would reveal the whole secret.
        if not val:
            return '<unset>'
        if len(val) < 10:
            return f"**** (len={len(val)})"
        return f"{val[:6]}…{val[-4:]} (len={len(val)})"

    if args.env_file:
        env_path = Path(args.env_file).expanduser()
        if not env_path.exists():
            print(f"[WARN] Env file not found: {env_path}")
        else:
            # override=False: variables already in the environment win.
            load_dotenv(env_path, override=False)

    api_key = args.api_key or os.getenv("ENVIRONMENT_API_KEY")
    if not api_key:
        parser.error("Missing --api-key (or ENVIRONMENT_API_KEY not set)")

    extra_headers: dict[str, str] | None = None
    synth_key = os.getenv("SYNTH_API_KEY")
    if synth_key:
        extra_headers = {"Authorization": f"Bearer {synth_key}"}
        if "openai.com" not in args.inference_url.lower():
            os.environ["OPENAI_API_KEY"] = synth_key

    if args.verbose:
        print('Resolved configuration:')
        print(f" Task app base URL : {args.base_url}")
        print(f" Inference base URL : {args.inference_url}")
        print(f" Task app API key : {_mask(api_key)}")
        print(f" Synth API key : {_mask(synth_key)}")
        print(f" HTTP timeout : {args.timeout:.1f}s")

    if args.ops:
        ops = [op.strip() for op in args.ops.split(',') if op.strip()]
        if not ops:
            raise ValueError('Ops must contain at least one entry')
    else:
        llm_calls = max(args.max_llm_calls, 1)
        if llm_calls > 20:
            print('[WARN] --max-llm-calls capped at 20 to avoid excessive episodes; use --ops for manual control.')
            llm_calls = 20
        # Each policy call is followed by one environment step.
        ops = ['agent', 'env'] * llm_calls

    async with TaskAppClient(args.base_url, api_key=api_key, timeout=args.timeout) as client:
        try:
            print(f'Fetching task_info for seed {args.seed}…')
            task_info = await client.task_info(seeds=[args.seed])
            info_payload = task_info[0] if isinstance(task_info, list) else task_info
            print(json.dumps(info_payload.model_dump(), indent=2)[:600])

            request = build_rollout_request(
                args.seed,
                args.run_id,
                model=args.model,
                inference_url=args.inference_url,
                ops=ops,
                extra_headers=extra_headers,
            )
            if args.max_policy_tokens is not None:
                request.policy.config.update({
                    'max_completion_tokens': args.max_policy_tokens,
                    'max_tokens': args.max_policy_tokens,
                })
            if args.verbose:
                print(f'Ops: {ops}')
                # Mask header values: they carry the raw bearer token.
                sent_headers = request.policy.config.get("extra_headers") or {}
                masked_headers = {name: _mask(value) for name, value in sent_headers.items()}
                print(f'Request headers: {masked_headers}')
            print('Requesting rollout…')
            response = await client.rollout(request)
            summary = summarise_response(response)
            print(json.dumps(summary, indent=2))
            print(f'Ops executed: {ops}')
            print('Tip: use --max-llm-calls N for agent/env pairs or --ops for manual control.')
        except httpx.HTTPStatusError as exc:
            detail = exc.response.json() if exc.response.headers.get('content-type', '').startswith('application/json') else exc.response.text
            print(f'HTTP error {exc.response.status_code}: {detail}', file=sys.stderr)
            if exc.response.status_code in (401, 503):
                print('Hint: ensure the task app was started with ENVIRONMENT_API_KEY set and pass the same key via --api-key.', file=sys.stderr)
            if exc.response.status_code == 500 and args.model in str(detail):
                print('Hint: supply --model/--inference-url (and set OPENAI_API_KEY or GROQ_API_KEY) so the policy can route inference.', file=sys.stderr)
                print('Hint: the inference URL should be the base (e.g., https://api.openai.com); the task app appends /v1/chat/completions.', file=sys.stderr)
                if args.max_policy_tokens is not None:
                    print(f'Hint: --max-policy-tokens={args.max_policy_tokens} is forwarded to the policy config as max_completion_tokens.', file=sys.stderr)
            raise
185
+
186
+
187
if __name__ == "__main__":
    # Script entry point: drive the async CLI with a fresh event loop when run directly.
    asyncio.run(main())
@@ -0,0 +1,160 @@
1
+ #!/usr/bin/env python3
2
+ """Rollout a Crafter task app using the Modal backend proxy."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import asyncio
8
+ import json
9
+ import os
10
+ from pathlib import Path
11
+ from typing import Any
12
+
13
+ import sys
14
+
15
+ import httpx
16
+ from dotenv import load_dotenv
17
+
18
+ from synth_ai.task import (
19
+ RolloutEnvSpec,
20
+ RolloutPolicySpec,
21
+ RolloutRecordConfig,
22
+ RolloutRequest,
23
+ RolloutSafetyConfig,
24
+ TaskAppClient,
25
+ )
26
+
27
+
28
def build_rollout_request(seed: int, run_id: str, *, model: str, inference_url: str, ops: list[str], api_key: str) -> RolloutRequest:
    """Build a Crafter rollout request whose policy authenticates inference with ``api_key``.

    The key is placed in the policy's ``extra_headers`` as a ``Bearer`` token.
    Trajectories are recorded with default trace settings, and the request uses
    ``on_done="reset"`` with default safety settings.
    """
    auth_headers = {"Authorization": f"Bearer {api_key}"}
    policy_config = dict(model=model, inference_url=inference_url, extra_headers=auth_headers)

    env_spec = RolloutEnvSpec(env_name="crafter", seed=seed, config={})
    policy_spec = RolloutPolicySpec(policy_name="crafter-react", config=policy_config)
    return RolloutRequest(
        run_id=run_id,
        env=env_spec,
        policy=policy_spec,
        ops=ops,
        record=RolloutRecordConfig(trajectories=True),
        on_done="reset",
        safety=RolloutSafetyConfig(),
    )
45
+
46
+
47
def summarise_response(data: Any) -> dict[str, Any]:
    """Condense a rollout response into its run id and headline metrics.

    Accepts either a response model or a plain dict. For a model, fields are
    read as attributes, and ``metrics`` may expose ``model_dump()``. Missing
    values come back as ``None``.
    """

    def _field(obj: Any, name: str) -> Any:
        # Read `name` from a mapping or an object without assuming which one we have.
        # (The old code called data.metrics on dicts, which raised AttributeError.)
        if isinstance(obj, dict):
            return obj.get(name)
        return getattr(obj, name, None)

    raw_metrics = _field(data, "metrics")
    if hasattr(raw_metrics, "model_dump"):
        metrics = raw_metrics.model_dump()
    elif isinstance(raw_metrics, dict):
        metrics = raw_metrics
    else:
        metrics = {}

    return {
        "run_id": _field(data, "run_id"),
        "num_episodes": metrics.get("num_episodes"),
        "num_steps": metrics.get("num_steps"),
        "episode_returns": metrics.get("episode_returns"),
        "outcome_score": metrics.get("outcome_score"),
        "events_score": metrics.get("events_score"),
    }
57
+
58
+
59
async def main() -> None:
    """Roll out a Crafter task app whose policy calls inference through the Modal backend proxy.

    The task-app key comes from --task-app-key or ENVIRONMENT_API_KEY. The
    inference key comes from --modal-key or SYNTH_API_KEY. Either may be loaded
    from --env-file.

    The inference key is sent as a Bearer token in the policy's
    ``extra_headers``. For non-OpenAI inference URLs it is also exported as
    OPENAI_API_KEY in this process.

    HTTP status errors from the task app are printed with a hint and re-raised.
    Verbose output masks every secret, including the forwarded headers.
    """
    # Note: __doc__ here resolves to the module docstring, not this function's.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--base-url", default="http://localhost:8010", help="Task app base URL")
    parser.add_argument("--env-file", type=str, default=None, help="Path to .env file with keys")
    parser.add_argument("--seed", type=int, default=42, help="Env seed to rollout")
    parser.add_argument("--run-id", default="modal-eval", help="Run identifier")
    parser.add_argument("--model", required=True, help="Model identifier for the Crafter policy")
    parser.add_argument("--inference-url", required=True, help="Modal backend inference base URL (e.g., http://localhost:8000/api)")
    parser.add_argument("--task-app-key", default=None, help="Environment API key for the task app (fallback ENVIRONMENT_API_KEY)")
    parser.add_argument("--modal-key", default=None, help="Synth/Modal API key for inference (fallback SYNTH_API_KEY)")
    parser.add_argument("--max-llm-calls", type=int, default=20, help="Number of policy inference calls")
    parser.add_argument("--ops", default=None, help="Comma-separated rollout ops (advanced override)")
    parser.add_argument("--max-policy-tokens", type=int, default=None, help="Optional per-call token limit forwarded to the policy config")
    parser.add_argument("--verbose", action="store_true", help="Print resolved configuration and headers")
    args = parser.parse_args()

    def _mask(val: str | None) -> str:
        # Redact a secret for display. Short values are hidden completely,
        # because a 6+4 character preview would reveal the whole secret.
        if not val:
            return "<unset>"
        if len(val) < 10:
            return f"**** (len={len(val)})"
        return f"{val[:6]}…{val[-4:]} (len={len(val)})"

    if args.env_file:
        env_path = Path(args.env_file).expanduser()
        if not env_path.exists():
            print(f"[WARN] Env file not found: {env_path}")
        else:
            # override=False: variables already in the environment win.
            load_dotenv(env_path, override=False)

    task_app_key = args.task_app_key or os.getenv("ENVIRONMENT_API_KEY")
    if not task_app_key:
        parser.error("Missing task app API key (set ENVIRONMENT_API_KEY or pass --task-app-key)")

    modal_key = args.modal_key or os.getenv("SYNTH_API_KEY")
    if not modal_key:
        parser.error("Missing Synth/Modal API key (set SYNTH_API_KEY or pass --modal-key)")

    # Previously referenced an undefined `synth_key`, which raised NameError on every run.
    if modal_key and "openai.com" not in args.inference_url.lower():
        os.environ["OPENAI_API_KEY"] = modal_key

    if args.ops:
        ops = [op.strip() for op in args.ops.split(",") if op.strip()]
        if not ops:
            raise ValueError("Ops must contain at least one entry")
    else:
        llm_calls = max(args.max_llm_calls, 1)
        if llm_calls > 20:
            print("[WARN] --max-llm-calls capped at 20 to avoid excessive episodes; use --ops for manual control.")
            llm_calls = 20
        # Each policy call is followed by one environment step.
        ops = ["agent", "env"] * llm_calls

    if args.verbose:
        print("Resolved configuration:")
        print(f" Task app base URL : {args.base_url}")
        print(f" Inference base URL : {args.inference_url}")
        print(f" Task app API key : {_mask(task_app_key)}")
        print(f" Modal API key : {_mask(modal_key)}")
        print(f" Ops (count={len(ops)}) : {ops}")

    inf_url_norm = args.inference_url.rstrip('/')
    if '/api' not in inf_url_norm:
        print('[WARN] Inference URL is missing /api prefix; proxy endpoints usually live at /api/inference/v1/chat/completions.')
    elif not inf_url_norm.lower().endswith('/api'):
        print('[INFO] Using inference base URL; policy will append /v1/chat/completions automatically.')

    async with TaskAppClient(args.base_url, api_key=task_app_key) as client:
        try:
            print(f"Fetching task_info for seed {args.seed}…")
            task_info = await client.task_info(seeds=[args.seed])
            info_payload = task_info[0] if isinstance(task_info, list) else task_info
            print(json.dumps(info_payload.model_dump(), indent=2)[:600])

            request = build_rollout_request(
                args.seed,
                args.run_id,
                model=args.model,
                inference_url=args.inference_url,
                ops=ops,
                api_key=modal_key,
            )
            if args.verbose:
                # Mask header values: they carry the raw bearer token.
                sent_headers = request.policy.config.get("extra_headers") or {}
                masked_headers = {name: _mask(value) for name, value in sent_headers.items()}
                print(f"Request headers: {masked_headers}")
            if args.max_policy_tokens is not None:
                request.policy.config.update({
                    "max_completion_tokens": args.max_policy_tokens,
                    "max_tokens": args.max_policy_tokens,
                })
            print("Requesting rollout…")
            response = await client.rollout(request)
            summary = summarise_response(response)
            print(json.dumps(summary, indent=2))
            print(f"Ops executed: {ops}")
        except httpx.HTTPStatusError as exc:
            detail = exc.response.json() if exc.response.headers.get("content-type", "").startswith("application/json") else exc.response.text
            print(f"HTTP error {exc.response.status_code}: {detail}", file=sys.stderr)
            if exc.response.status_code in (401, 503):
                print("Hint: ensure ENVIRONMENT_API_KEY and SYNTH_API_KEY are correctly set.", file=sys.stderr)
            raise
157
+
158
+
159
if __name__ == "__main__":
    # Script entry point: drive the async CLI with a fresh event loop when run directly.
    asyncio.run(main())