synth-ai 0.2.9.dev2__py3-none-any.whl → 0.2.9.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of synth-ai might be problematic.
- examples/analyze_semantic_words.sh +17 -0
- examples/common_old/backend.py +21 -0
- examples/crafter_debug_render.py +180 -0
- examples/evals_old/README.md +98 -0
- examples/evals_old/__init__.py +6 -0
- examples/evals_old/compare_models.py +1037 -0
- examples/evals_old/example_log.md +145 -0
- examples/evals_old/run_demo.sh +126 -0
- examples/evals_old/trace_analysis.py +270 -0
- examples/finetuning_old/_backup_synth_qwen/config.toml +29 -0
- examples/finetuning_old/_backup_synth_qwen/example_log.md +324 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces.py +60 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +239 -0
- examples/finetuning_old/_backup_synth_qwen/purge_v3_traces.py +109 -0
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +1924 -0
- examples/finetuning_old/_backup_synth_qwen/readme.md +49 -0
- examples/finetuning_old/_backup_synth_qwen/run_crafter_qwen4b.py +114 -0
- examples/finetuning_old/_backup_synth_qwen/run_demo.sh +195 -0
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +118 -0
- examples/finetuning_old/synth_qwen_v1/README.md +68 -0
- examples/finetuning_old/synth_qwen_v1/filter_traces.py +60 -0
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +239 -0
- examples/finetuning_old/synth_qwen_v1/finetune.py +46 -0
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +71 -0
- examples/finetuning_old/synth_qwen_v1/infer.py +37 -0
- examples/finetuning_old/synth_qwen_v1/poll.py +44 -0
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +35 -0
- examples/finetuning_old/synth_qwen_v1/purge_v3_traces.py +109 -0
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +1932 -0
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +207 -0
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +232 -0
- examples/finetuning_old/synth_qwen_v1/upload_data.py +34 -0
- examples/finetuning_old/synth_qwen_v1/util.py +147 -0
- examples/rl/README.md +169 -0
- examples/rl/configs/eval_base_qwen.toml +15 -0
- examples/rl/configs/eval_rl_qwen.toml +11 -0
- examples/rl/configs/rl_from_base_qwen.toml +35 -0
- examples/rl/configs/rl_from_base_qwen17.toml +74 -0
- examples/rl/configs/rl_from_ft_qwen.toml +35 -0
- examples/rl/download_dataset.py +64 -0
- examples/rl/run_eval.py +435 -0
- examples/rl/run_rl_and_save.py +94 -0
- examples/rl/task_app/README.md +22 -0
- {synth_ai/task/apps → examples/rl/task_app}/math_single_step.py +8 -8
- examples/rl/task_app/math_task_app.py +107 -0
- examples/rl_old/task_app.py +962 -0
- examples/run_crafter_demo.sh +10 -0
- examples/warming_up_to_rl/analyze_trace_db.py +420 -0
- examples/warming_up_to_rl/configs/crafter_fft.toml +48 -0
- examples/warming_up_to_rl/configs/crafter_fft_4b.toml +54 -0
- examples/warming_up_to_rl/configs/eval_fft_qwen4b.toml +20 -0
- examples/warming_up_to_rl/configs/eval_groq_qwen32b.toml +13 -0
- examples/warming_up_to_rl/configs/eval_modal_qwen4b.toml +23 -0
- examples/warming_up_to_rl/configs/rl_from_base_qwen4b.toml +73 -0
- examples/warming_up_to_rl/configs/rl_from_ft.toml +56 -0
- examples/warming_up_to_rl/export_trace_sft.py +541 -0
- examples/warming_up_to_rl/groq_test.py +88 -0
- examples/warming_up_to_rl/manage_secrets.py +127 -0
- examples/warming_up_to_rl/old/event_rewards.md +234 -0
- examples/warming_up_to_rl/old/notes.md +73 -0
- examples/warming_up_to_rl/readme.md +172 -0
- examples/warming_up_to_rl/run_eval.py +434 -0
- examples/warming_up_to_rl/run_fft_and_save.py +309 -0
- examples/warming_up_to_rl/run_local_rollout.py +188 -0
- examples/warming_up_to_rl/run_local_rollout_modal.py +160 -0
- examples/warming_up_to_rl/run_local_rollout_parallel.py +342 -0
- examples/warming_up_to_rl/run_local_rollout_traced.py +372 -0
- examples/warming_up_to_rl/run_rl_and_save.py +101 -0
- examples/warming_up_to_rl/run_rollout_remote.py +129 -0
- examples/warming_up_to_rl/task_app/README.md +38 -0
- {synth_ai/task/apps → examples/warming_up_to_rl/task_app}/grpo_crafter.py +7 -7
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +165 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/README.md +173 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/branching.py +145 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +1271 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +6 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +429 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +442 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +96 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +302 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +47 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +202 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +512 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/main.py +102 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +985 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/registry.py +197 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +1749 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +5 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +217 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_agents.py +160 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +146 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_stepwise_rewards.py +58 -0
- examples/warming_up_to_rl/task_app/synth_envs_hosted/utils.py +61 -0
- synth_ai/api/train/config_finder.py +18 -18
- synth_ai/api/train/env_resolver.py +28 -1
- synth_ai/cli/task_apps.py +264 -55
- synth_ai/demo_registry.py +7 -7
- synth_ai/demos/demo_task_apps/crafter/__init__.py +1 -0
- synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +54 -0
- synth_ai/demos/demo_task_apps/crafter/configs/rl_from_base_qwen4b.toml +73 -0
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +165 -0
- synth_ai/task/apps/__init__.py +54 -13
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/METADATA +1 -1
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/RECORD +112 -13
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/top_level.txt +1 -0
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev2.dist-info → synth_ai-0.2.9.dev4.dist-info}/licenses/LICENSE +0 -0
examples/warming_up_to_rl/run_fft_and_save.py (new file, +309 lines):

```diff
@@ -0,0 +1,309 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import argparse
+import json
+import os
+import sys
+import time
+from pathlib import Path
+from typing import Any, Dict, Tuple, List
+
+import tomllib
+import re
+import requests
+
+
+def mask(val: str) -> str:
+    if not isinstance(val, str) or not val:
+        return "<unset>"
+    return f"{val[:6]}…{val[-4:]}" if len(val) >= 10 else "****"
+
+
+def post_multipart(base: str, api_key: str, path: str, file_field: str, filepath: Path) -> Dict[str, Any]:
+    """Upload a file, trying backend-specific endpoints with fallbacks.
+
+    Priority:
+    - {BASE}/learning/files (Modal Learning v2 style)
+    - {BASE}/files (OpenAI-style)
+    """
+    headers = {"Authorization": f"Bearer {api_key}"}
+    files = {file_field: (filepath.name, filepath.read_bytes(), "application/jsonl")}
+    data = {"purpose": "fine-tune"}
+
+    endpoints = [
+        f"{base.rstrip('/')}/{path.lstrip('/')}",  # e.g., /learning/files
+        f"{base.rstrip('/')}/files",  # OpenAI-style
+    ]
+    last_err: Dict[str, Any] | None = None
+    for ep in endpoints:
+        try:
+            r = requests.post(ep, headers=headers, files=files, data=data, timeout=300)
+            # Success fast-path
+            try:
+                js = r.json()
+            except Exception:
+                js = {"status": r.status_code, "text": r.text[:800]}
+
+            if r.status_code < 400 and (js.get("id") or js.get("object") in ("file",)):
+                return js
+
+            # 404/405 -> try next endpoint
+            if r.status_code in (404, 405):
+                last_err = {"status": r.status_code, "body": (r.text or "")[:800], "endpoint": ep}
+                continue
+
+            # Other errors: return rich error
+            return {
+                "error": True,
+                "status": r.status_code,
+                "endpoint": ep,
+                "body": (r.text or "")[:1200],
+            }
+        except requests.RequestException as e:
+            last_err = {"error": True, "exception": str(e), "endpoint": ep}
+            continue
+
+    return last_err or {"error": True, "detail": "upload_failed_all_endpoints"}
+
+
+def post_json(base: str, api_key: str, path: str, body: Dict[str, Any]) -> Dict[str, Any]:
+    url = f"{base.rstrip('/')}/{path.lstrip('/')}"
+    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
+    r = requests.post(url, headers=headers, data=json.dumps(body), timeout=120)
+    try:
+        return r.json()
+    except Exception:
+        return {"status": r.status_code, "text": r.text[:400]}
+
+
+def get_json(base: str, api_key: str, path: str) -> Dict[str, Any]:
+    url = f"{base.rstrip('/')}/{path.lstrip('/')}"
+    headers = {"Authorization": f"Bearer {api_key}"}
+    r = requests.get(url, headers=headers, timeout=30)
+    try:
+        return r.json()
+    except Exception:
+        return {"status": r.status_code, "text": r.text[:400]}
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Submit FFT job and save resulting model id")
+    parser.add_argument("--backend", default=os.getenv("BACKEND_BASE_URL", "http://localhost:8000/api"))
+    parser.add_argument("--toml", required=True, help="Path to FFT TOML config")
+    parser.add_argument("--data", default="", help="Override dataset JSONL path")
+    parser.add_argument("--poll-seconds", type=int, default=1800)
+    parser.add_argument("--env-file", default="", help="Optional path to .env file with SYNTH_API_KEY")
+    args = parser.parse_args()
+
+    config_path = Path(args.toml).expanduser().resolve()
+    if not config_path.exists():
+        print(f"Config not found: {config_path}", file=sys.stderr)
+        sys.exit(2)
+    with config_path.open("rb") as fh:
+        cfg = tomllib.load(fh)
+
+    job_cfg = cfg.get("job", {}) if isinstance(cfg.get("job"), dict) else {}
+    compute_cfg = cfg.get("compute", {}) if isinstance(cfg.get("compute"), dict) else {}
+    data_cfg_full = cfg.get("data", {}) if isinstance(cfg.get("data"), dict) else {}
+    topo_cfg = (data_cfg_full or {}).get("topology", {}) if isinstance(data_cfg_full, dict) else {}
+    validation_local_path = (data_cfg_full or {}).get("validation_path") if isinstance(data_cfg_full, dict) else None
+    train_cfg = cfg.get("training", {}) if isinstance(cfg.get("training"), dict) else {}
+    hp_cfg = cfg.get("hyperparameters", {}) if isinstance(cfg.get("hyperparameters"), dict) else {}
+
+    model = str(job_cfg.get("model") or os.getenv("SFT_MODEL") or "Qwen/Qwen3-4B")
+
+    # Resolve dataset path
+    data_path = args.data or job_cfg.get("data") or job_cfg.get("data_path")
+    data_file: Path | None = None
+    if isinstance(data_path, str) and data_path.strip():
+        p = Path(data_path).expanduser()
+        if not p.is_absolute():
+            p = (config_path.parent / p).resolve()
+        data_file = p
+    if data_file is None:
+        print("Missing dataset path in --data or [job].data", file=sys.stderr)
+        sys.exit(2)
+    if not data_file.exists():
+        print(f"Dataset not found: {data_file}", file=sys.stderr)
+        sys.exit(2)
+
+    synth_key = (os.getenv("SYNTH_API_KEY") or "").strip()
+    # Fallback: try to load from .env if not present in environment
+    if not synth_key:
+        candidate_env: Path | None = None
+        if isinstance(args.env_file, str) and args.env_file.strip():
+            candidate_env = Path(args.env_file).expanduser().resolve()
+        else:
+            # Prefer .env next to the TOML config
+            candidate_env = (config_path.parent / ".env").resolve()
+        if candidate_env and candidate_env.exists():
+            try:
+                env_text = candidate_env.read_text(encoding="utf-8", errors="ignore")
+                # Match lines like: SYNTH_API_KEY=..., or export SYNTH_API_KEY=...
+                key_val: str | None = None
+                for line in env_text.splitlines():
+                    m = re.match(r"^\s*(?:export\s+)?SYNTH_API_KEY\s*=\s*(.*)$", line)
+                    if m:
+                        raw = m.group(1).strip()
+                        # Trim surrounding quotes if present
+                        if (raw.startswith('"') and raw.endswith('"')) or (raw.startswith("'") and raw.endswith("'")):
+                            raw = raw[1:-1]
+                        key_val = raw.strip()
+                        break
+                if key_val:
+                    synth_key = key_val
+                    os.environ["SYNTH_API_KEY"] = synth_key
+                    print(f"[INFO] Loaded SYNTH_API_KEY from {candidate_env}")
+            except Exception as _e:
+                # Ignore and fall through to error below
+                pass
+    if not synth_key:
+        print("Missing SYNTH_API_KEY (set in env or provide --env-file pointing to .env)", file=sys.stderr)
+        sys.exit(2)
+
+    backend = args.backend.rstrip("/")
+    print(f"[INFO] Using backend={backend} key_fp={mask(synth_key)} data={data_file}")
+    if isinstance(validation_local_path, str) and validation_local_path.strip():
+        print(f"[INFO] Using validation path={validation_local_path}")
+
+    # 1) Upload training file
+    print("[INFO] Uploading training file…")
+    upf = post_multipart(backend, synth_key, "/learning/files", "file", data_file)
+    try:
+        print(f"[INFO] Upload response: {json.dumps(upf, indent=2)[:400]}")
+    except Exception:
+        print(f"[INFO] Upload response (raw): {str(upf)[:400]}")
+    file_id = str((upf or {}).get("id") or "").strip()
+    if not file_id:
+        # Rich diagnostics
+        err_status = (upf or {}).get("status")
+        err_body = (upf or {}).get("body") or (upf or {}).get("text")
+        err_ep = (upf or {}).get("endpoint")
+        print(f"Upload failed (status={err_status} endpoint={err_ep}) body={str(err_body)[:200]}", file=sys.stderr)
+        sys.exit(4)
+
+    # Optionally upload validation file
+    val_file_id: str | None = None
+    if isinstance(validation_local_path, str) and validation_local_path.strip():
+        vpath = Path(validation_local_path).expanduser()
+        if not vpath.is_absolute():
+            vpath = (config_path.parent / vpath).resolve()
+        if not vpath.exists():
+            print(f"[WARN] Validation file not found: {vpath} (skipping validation)")
+        else:
+            print("[INFO] Uploading validation file…")
+            upv = post_multipart(backend, synth_key, "/learning/files", "file", vpath)
+            try:
+                print(f"[INFO] Validation upload response: {json.dumps(upv, indent=2)[:300]}")
+            except Exception:
+                print(f"[INFO] Validation upload response (raw): {str(upv)[:300]}")
+            val_file_id = str((upv or {}).get("id") or "").strip() or None
+            if not val_file_id:
+                err_status = (upv or {}).get("status")
+                err_body = (upv or {}).get("body") or (upv or {}).get("text")
+                err_ep = (upv or {}).get("endpoint")
+                print(f"[WARN] Validation upload failed (status={err_status} endpoint={err_ep}) body={str(err_body)[:180]} — continuing without validation")
+
+    # 2) Build job payload
+    hp_block: Dict[str, Any] = {
+        "n_epochs": int(hp_cfg.get("n_epochs") or 1),
+    }
+    # Optional extras if present
+    for k in (
+        "batch_size",
+        "global_batch",
+        "per_device_batch",
+        "gradient_accumulation_steps",
+        "sequence_length",
+        "learning_rate",
+        "warmup_ratio",
+        "train_kind",
+    ):
+        if k in hp_cfg:
+            hp_block[k] = hp_cfg[k]
+
+    parallel = hp_cfg.get("parallelism") if isinstance(hp_cfg.get("parallelism"), dict) else None
+    if parallel:
+        hp_block["parallelism"] = parallel
+
+    compute_block: Dict[str, Any] = {}
+    for k in ("gpu_type", "gpu_count", "nodes"):
+        if k in compute_cfg:
+            compute_block[k] = compute_cfg[k]
+
+    effective = {
+        "compute": compute_block,
+        "data": {"topology": topo_cfg or {}},
+        "training": {k: v for k, v in train_cfg.items() if k in ("mode", "use_qlora")},
+    }
+    # If TOML includes a [training.validation] block, forward relevant knobs into hyperparameters
+    validation_cfg = train_cfg.get("validation") if isinstance(train_cfg.get("validation"), dict) else None
+    if isinstance(validation_cfg, dict):
+        # Enable evaluation and map keys as-is; backend trainer maps metric_for_best_model 'val.loss'→'eval_loss'
+        hp_block.update({
+            "evaluation_strategy": validation_cfg.get("evaluation_strategy", "steps"),
+            "eval_steps": int(validation_cfg.get("eval_steps", 0) or 0),
+            "save_best_model_at_end": bool(validation_cfg.get("save_best_model_at_end", True)),
+            "metric_for_best_model": validation_cfg.get("metric_for_best_model", "val.loss"),
+            "greater_is_better": bool(validation_cfg.get("greater_is_better", False)),
+        })
+        # Also surface validation enable flag into effective_config for visibility (optional)
+        effective.setdefault("training", {})["validation"] = {"enabled": bool(validation_cfg.get("enabled", True))}
+
+    body = {
+        "model": model,
+        "training_file_id": file_id,
+        "training_type": "sft_offline",
+        "hyperparameters": hp_block,
+        "metadata": {"effective_config": effective},
+    }
+    if val_file_id:
+        # Shared API expects top-level validation_file? Tests mention legacy; prefer placing into metadata.effective_config.data
+        # Put into effective_config.data so downstream loader can read it; keep top-level off unless required.
+        effective.setdefault("data", {})["validation_files"] = [val_file_id]
+
+    # 3) Create and start job
+    print("[INFO] Creating FFT job…")
+    cj = post_json(backend, synth_key, "/learning/jobs", body)
+    print(f"[INFO] Create response: {json.dumps(cj, indent=2)[:200]}")
+    job_id = str(cj.get("job_id") or cj.get("id") or "").strip()
+    if not job_id:
+        print("Create job failed", file=sys.stderr)
+        sys.exit(5)
+
+    print(f"[INFO] Starting job {job_id}…")
+    _ = post_json(backend, synth_key, f"/learning/jobs/{job_id}/start", {})
+
+    # 4) Poll until terminal
+    deadline = time.time() + max(30, int(job_cfg.get("poll_seconds") or args.poll_seconds))
+    status = "queued"
+    ft_model = None
+    queued_since = time.time()
+    while time.time() < deadline:
+        info = get_json(backend, synth_key, f"/learning/jobs/{job_id}")
+        status = (info.get("status") or "").lower()
+        ft_model = info.get("fine_tuned_model")
+        print(f"[INFO] poll status={status} ft_model={ft_model}")
+        if status in ("succeeded", "failed", "canceled", "cancelled"):
+            break
+        # Warn if stuck queued for >10 minutes
+        if status == "queued" and (time.time() - queued_since) > 600:
+            print("[WARN] Job has remained queued for >10 minutes. Backend may be capacity constrained.")
+            queued_since = time.time()
+        time.sleep(5)
+
+    # 5) Save model id
+    out_file = Path(__file__).parent / "ft_model_id.txt"
+    if ft_model:
+        with out_file.open("a") as fh:
+            fh.write(str(ft_model) + "\n")
+        print(f"[INFO] Saved model id to {out_file}: {ft_model}")
+        sys.exit(0 if status == "succeeded" else 1)
+    else:
+        print(f"[WARN] No fine_tuned_model found; final status={status}")
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()
```
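The script only reads a handful of TOML tables: [job], [compute], [data] (including data.validation_path and data.topology), [hyperparameters], and optionally [training.validation]. Below is a minimal sketch of a config it would accept, parsed the same way the script does; every value is an illustrative placeholder, not copied from the shipped crafter_fft*.toml files.

```python
# Hypothetical FFT config for run_fft_and_save.py. Placeholder values chosen
# only to exercise the keys the script reads; not taken from the package.
import tomllib

SAMPLE_TOML = """
[job]
model = "Qwen/Qwen3-4B"
data = "datasets/train.jsonl"     # resolved relative to the TOML file
poll_seconds = 1800

[compute]
gpu_type = "A100"                 # forwarded verbatim into effective_config
gpu_count = 2

[data]
validation_path = "datasets/val.jsonl"

[hyperparameters]
n_epochs = 1
learning_rate = 2e-5

[training.validation]
enabled = true
eval_steps = 50
"""

cfg = tomllib.loads(SAMPLE_TOML)
# Mirrors the script's reads: job.model, data.validation_path, hyperparameters.*
assert cfg["job"]["model"] == "Qwen/Qwen3-4B"
assert cfg["data"]["validation_path"].endswith("val.jsonl")
```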
examples/warming_up_to_rl/run_local_rollout.py (new file, +188 lines):

```diff
@@ -0,0 +1,188 @@
+#!/usr/bin/env python3
+"""Hit a locally running Crafter task app and request a rollout."""
+
+from __future__ import annotations
+
+import argparse
+import asyncio
+import json
+import os
+from pathlib import Path
+from typing import Any
+
+import sys
+
+import httpx
+from dotenv import load_dotenv
+
+from synth_ai.task import (
+    RolloutEnvSpec,
+    RolloutPolicySpec,
+    RolloutRecordConfig,
+    RolloutRequest,
+    RolloutSafetyConfig,
+    TaskAppClient,
+)
+
+
+def build_rollout_request(
+    seed: int,
+    run_id: str,
+    *,
+    model: str,
+    inference_url: str,
+    ops: list[str],
+    extra_headers: dict[str, str] | None = None,
+    trace_format: str = "compact",
+    return_trace: bool = False,
+) -> RolloutRequest:
+    policy_config = {"model": model, "inference_url": inference_url}
+    if extra_headers:
+        policy_config["extra_headers"] = extra_headers
+    record_cfg = RolloutRecordConfig(
+        trajectories=True,
+        trace_format=trace_format,
+        return_trace=return_trace,
+    )
+    return RolloutRequest(
+        run_id=run_id,
+        env=RolloutEnvSpec(env_name='crafter', seed=seed, config={}),
+        policy=RolloutPolicySpec(policy_name='crafter-react', config=policy_config),
+        ops=ops,
+        record=record_cfg,
+        on_done='reset',
+        safety=RolloutSafetyConfig(),
+    )
+
+
+def summarise_response(data: Any) -> dict[str, Any]:
+    metrics = data.metrics.model_dump() if hasattr(data.metrics, "model_dump") else data.get("metrics", {})
+    error = None
+    rollout_status = None
+    try:
+        trajectories = getattr(data, "trajectories", None) or data.get("trajectories")
+        if isinstance(trajectories, list) and trajectories:
+            final = getattr(trajectories[0], "final", None)
+            if not final and isinstance(trajectories[0], dict):
+                final = trajectories[0].get("final")
+            if isinstance(final, dict):
+                error = final.get("error")
+                rollout_status = final.get("rollout_status")
+    except Exception:
+        pass
+    return {
+        "run_id": getattr(data, "run_id", None) or data.get("run_id"),
+        "num_episodes": metrics.get("num_episodes"),
+        "num_steps": metrics.get("num_steps"),
+        "episode_returns": metrics.get("episode_returns"),
+        "outcome_score": metrics.get("outcome_score"),
+        "events_score": metrics.get("events_score"),
+        "rollout_status": rollout_status,
+        "error": error,
+    }
+
+
+async def main() -> None:
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument("--base-url", default="http://localhost:8001", help="Task app base URL")
+    parser.add_argument("--api-key", help="Environment API key (or set via --env-file)")
+    parser.add_argument('--seed', type=int, default=42, help='Env seed to rollout')
+    parser.add_argument('--run-id', default='local-demo', help='Run identifier')
+    parser.add_argument('--model', default='gpt-4o-mini', help='Model identifier for the Crafter policy (OpenAI-compatible)')
+    parser.add_argument('--inference-url', default='https://api.openai.com', help='Inference base URL used by the policy (e.g., https://api.openai.com)')
+    parser.add_argument('--env-file', type=str, default=None, help='Path to .env file with API keys')
+    parser.add_argument('--ops', default=None, help='Comma-separated rollout ops (advanced override)')
+    parser.add_argument('--max-llm-calls', type=int, default=1, help='Number of policy inference calls when --ops not provided')
+    parser.add_argument('--max-policy-tokens', type=int, default=None, help='Optional per-call token limit forwarded to the policy config')
+    parser.add_argument('--timeout', type=float, default=600.0, help='HTTP timeout (seconds) for task app requests')
+    parser.add_argument('--verbose', action='store_true', help='Print resolved configuration and headers')
+    args = parser.parse_args()
+
+    if args.env_file:
+        env_path = Path(args.env_file).expanduser()
+        if not env_path.exists():
+            print(f"[WARN] Env file not found: {env_path}")
+        else:
+            load_dotenv(env_path, override=False)
+
+    api_key = args.api_key or os.getenv("ENVIRONMENT_API_KEY")
+    if not api_key:
+        parser.error("Missing --api-key (or ENVIRONMENT_API_KEY not set)")
+
+    extra_headers: dict[str, str] | None = None
+    synth_key = os.getenv("SYNTH_API_KEY")
+    if synth_key:
+        extra_headers = {"Authorization": f"Bearer {synth_key}"}
+        if "openai.com" not in args.inference_url.lower():
+            os.environ["OPENAI_API_KEY"] = synth_key
+
+    if args.verbose:
+        def _mask(val: str | None) -> str:
+            if not val:
+                return '<unset>'
+            return f"{val[:6]}…{val[-4:]} (len={len(val)})"
+
+        print('Resolved configuration:')
+        print(f"  Task app base URL  : {args.base_url}")
+        print(f"  Inference base URL : {args.inference_url}")
+        print(f"  Task app API key   : {_mask(api_key)}")
+        print(f"  Synth API key      : {_mask(synth_key)}")
+        print(f"  HTTP timeout       : {args.timeout:.1f}s")
+
+    if args.ops:
+        ops = [op.strip() for op in args.ops.split(',') if op.strip()]
+        if not ops:
+            raise ValueError('Ops must contain at least one entry')
+    else:
+        llm_calls = max(args.max_llm_calls, 1)
+        if llm_calls > 20:
+            print('[WARN] --max-llm-calls capped at 20 to avoid excessive episodes; use --ops for manual control.')
+            llm_calls = 20
+        ops = []
+        for _ in range(llm_calls):
+            ops.extend(['agent', 'env'])
+
+    async with TaskAppClient(args.base_url, api_key=api_key, timeout=args.timeout) as client:
+        try:
+            print(f'Fetching task_info for seed {args.seed}…')
+            task_info = await client.task_info(seeds=[args.seed])
+            info_payload = task_info[0] if isinstance(task_info, list) else task_info
+            print(json.dumps(info_payload.model_dump(), indent=2)[:600])
+
+            request = build_rollout_request(
+                args.seed,
+                args.run_id,
+                model=args.model,
+                inference_url=args.inference_url,
+                ops=ops,
+                extra_headers=extra_headers,
+            )
+            if args.max_policy_tokens is not None:
+                request.policy.config.update({
+                    'max_completion_tokens': args.max_policy_tokens,
+                    'max_tokens': args.max_policy_tokens,
+                })
+            if args.verbose:
+                print(f'Ops: {ops}')
+                print(f'Request headers: {request.policy.config.get("extra_headers", {})}')
+            print('Requesting rollout…')
+            response = await client.rollout(request)
+            summary = summarise_response(response)
+            print(json.dumps(summary, indent=2))
+            print(f'Ops executed: {ops}')
+            print('Tip: use --max-llm-calls N for agent/env pairs or --ops for manual control.')
+        except httpx.HTTPStatusError as exc:
+            detail = exc.response.json() if exc.response.headers.get('content-type', '').startswith('application/json') else exc.response.text
+            print(f'HTTP error {exc.response.status_code}: {detail}', file=sys.stderr)
+            if exc.response.status_code in (401, 503):
+                print('Hint: ensure the task app was started with ENVIRONMENT_API_KEY set and pass the same key via --api-key.', file=sys.stderr)
+            if exc.response.status_code == 500 and args.model in str(detail):
+                print('Hint: supply --model/--inference-url (and set OPENAI_API_KEY or GROQ_API_KEY) so the policy can route inference.', file=sys.stderr)
+                print('Hint: the inference URL should be the base (e.g., https://api.openai.com); the task app appends /v1/chat/completions.', file=sys.stderr)
+            if args.max_policy_tokens is not None:
+                print(f'Hint: --max-policy-tokens={args.max_policy_tokens} is forwarded to the policy config as max_completion_tokens.', file=sys.stderr)
+            raise
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
```
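Because ops defaults to alternating "agent" and "env" entries, the same rollout can be driven programmatically without the CLI. Here is a minimal sketch assuming the synth_ai.task API exactly as imported above and a task app listening locally; the base URL, API key, and model are placeholders, and the unguarded metrics.model_dump() call assumes a pydantic response object as in summarise_response.

```python
# Programmatic rollout sketch; mirrors what build_rollout_request assembles in
# run_local_rollout.py. All connection details below are placeholders.
import asyncio
import json

from synth_ai.task import (
    RolloutEnvSpec,
    RolloutPolicySpec,
    RolloutRecordConfig,
    RolloutRequest,
    RolloutSafetyConfig,
    TaskAppClient,
)


async def demo() -> None:
    ops = ["agent", "env"] * 3  # three policy calls, each followed by an env step
    request = RolloutRequest(
        run_id="local-demo",
        env=RolloutEnvSpec(env_name="crafter", seed=42, config={}),
        policy=RolloutPolicySpec(
            policy_name="crafter-react",
            config={"model": "gpt-4o-mini", "inference_url": "https://api.openai.com"},
        ),
        ops=ops,
        record=RolloutRecordConfig(trajectories=True),
        on_done="reset",
        safety=RolloutSafetyConfig(),
    )
    async with TaskAppClient("http://localhost:8001", api_key="<ENVIRONMENT_API_KEY>") as client:
        response = await client.rollout(request)
        print(json.dumps(response.metrics.model_dump(), indent=2))


asyncio.run(demo())
```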
examples/warming_up_to_rl/run_local_rollout_modal.py (new file, +160 lines):

```diff
@@ -0,0 +1,160 @@
+#!/usr/bin/env python3
+"""Rollout a Crafter task app using the Modal backend proxy."""
+
+from __future__ import annotations
+
+import argparse
+import asyncio
+import json
+import os
+from pathlib import Path
+from typing import Any
+
+import sys
+
+import httpx
+from dotenv import load_dotenv
+
+from synth_ai.task import (
+    RolloutEnvSpec,
+    RolloutPolicySpec,
+    RolloutRecordConfig,
+    RolloutRequest,
+    RolloutSafetyConfig,
+    TaskAppClient,
+)
+
+
+def build_rollout_request(seed: int, run_id: str, *, model: str, inference_url: str, ops: list[str], api_key: str) -> RolloutRequest:
+    policy_config = {
+        "model": model,
+        "inference_url": inference_url,
+        "extra_headers": {
+            "Authorization": f"Bearer {api_key}",
+        },
+    }
+    return RolloutRequest(
+        run_id=run_id,
+        env=RolloutEnvSpec(env_name="crafter", seed=seed, config={}),
+        policy=RolloutPolicySpec(policy_name="crafter-react", config=policy_config),
+        ops=ops,
+        record=RolloutRecordConfig(trajectories=True),
+        on_done="reset",
+        safety=RolloutSafetyConfig(),
+    )
+
+
+def summarise_response(data: Any) -> dict[str, Any]:
+    metrics = data.metrics.model_dump() if hasattr(data.metrics, "model_dump") else data.get("metrics", {})
+    return {
+        "run_id": getattr(data, "run_id", None) or data.get("run_id"),
+        "num_episodes": metrics.get("num_episodes"),
+        "num_steps": metrics.get("num_steps"),
+        "episode_returns": metrics.get("episode_returns"),
+        "outcome_score": metrics.get("outcome_score"),
+        "events_score": metrics.get("events_score"),
+    }
+
+
+async def main() -> None:
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument("--base-url", default="http://localhost:8010", help="Task app base URL")
+    parser.add_argument("--env-file", type=str, default=None, help="Path to .env file with keys")
+    parser.add_argument("--seed", type=int, default=42, help="Env seed to rollout")
+    parser.add_argument("--run-id", default="modal-eval", help="Run identifier")
+    parser.add_argument("--model", required=True, help="Model identifier for the Crafter policy")
+    parser.add_argument("--inference-url", required=True, help="Modal backend inference base URL (e.g., http://localhost:8000/api)")
+    parser.add_argument("--task-app-key", default=None, help="Environment API key for the task app (fallback ENVIRONMENT_API_KEY)")
+    parser.add_argument("--modal-key", default=None, help="Synth/Modal API key for inference (fallback SYNTH_API_KEY)")
+    parser.add_argument("--max-llm-calls", type=int, default=20, help="Number of policy inference calls")
+    parser.add_argument("--ops", default=None, help="Comma-separated rollout ops (advanced override)")
+    parser.add_argument("--max-policy-tokens", type=int, default=None, help="Optional per-call token limit forwarded to the policy config")
+    parser.add_argument("--verbose", action="store_true", help="Print resolved configuration and headers")
+    args = parser.parse_args()
+
+    if args.env_file:
+        env_path = Path(args.env_file).expanduser()
+        if not env_path.exists():
+            print(f"[WARN] Env file not found: {env_path}")
+        else:
+            load_dotenv(env_path, override=False)
+
+    task_app_key = args.task_app_key or os.getenv("ENVIRONMENT_API_KEY")
+    if not task_app_key:
+        parser.error("Missing task app API key (set ENVIRONMENT_API_KEY or pass --task-app-key)")
+
+    modal_key = args.modal_key or os.getenv("SYNTH_API_KEY")
+    if not modal_key:
+        parser.error("Missing Synth/Modal API key (set SYNTH_API_KEY or pass --modal-key)")
+
+    if modal_key and "openai.com" not in args.inference_url.lower():
+        os.environ["OPENAI_API_KEY"] = modal_key
+
+    if args.ops:
+        ops = [op.strip() for op in args.ops.split(",") if op.strip()]
+        if not ops:
+            raise ValueError("Ops must contain at least one entry")
+    else:
+        llm_calls = max(args.max_llm_calls, 1)
+        if llm_calls > 20:
+            llm_calls = 20
+        ops = []
+        for _ in range(llm_calls):
+            ops.extend(["agent", "env"])
+
+    if args.verbose:
+        def _mask(val: str | None) -> str:
+            if not val:
+                return "<unset>"
+            return f"{val[:6]}…{val[-4:]} (len={len(val)})"
+
+        print("Resolved configuration:")
+        print(f"  Task app base URL  : {args.base_url}")
+        print(f"  Inference base URL : {args.inference_url}")
+        print(f"  Task app API key   : {_mask(task_app_key)}")
+        print(f"  Modal API key      : {_mask(modal_key)}")
+        print(f"  Ops (count={len(ops)}) : {ops}")
+
+    inf_url_norm = args.inference_url.rstrip('/')
+    if '/api' not in inf_url_norm:
+        print('[WARN] Inference URL is missing /api prefix; proxy endpoints usually live at /api/inference/v1/chat/completions.')
+    elif not inf_url_norm.lower().endswith('/api'):
+        print('[INFO] Using inference base URL; policy will append /v1/chat/completions automatically.')
+
+    async with TaskAppClient(args.base_url, api_key=task_app_key) as client:
+        try:
+            print(f"Fetching task_info for seed {args.seed}…")
+            task_info = await client.task_info(seeds=[args.seed])
+            info_payload = task_info[0] if isinstance(task_info, list) else task_info
+            print(json.dumps(info_payload.model_dump(), indent=2)[:600])
+
+            request = build_rollout_request(
+                args.seed,
+                args.run_id,
+                model=args.model,
+                inference_url=args.inference_url,
+                ops=ops,
+                api_key=modal_key,
+            )
+            if args.verbose:
+                print(f"Request headers: {request.policy.config.get('extra_headers', {})}")
+            if args.max_policy_tokens is not None:
+                request.policy.config.update({
+                    "max_completion_tokens": args.max_policy_tokens,
+                    "max_tokens": args.max_policy_tokens,
+                })
+            print("Requesting rollout…")
+            response = await client.rollout(request)
+            summary = summarise_response(response)
+            print(json.dumps(summary, indent=2))
+            print(f"Ops executed: {ops}")
+        except httpx.HTTPStatusError as exc:
+            detail = exc.response.json() if exc.response.headers.get("content-type", "").startswith("application/json") else exc.response.text
+            print(f"HTTP error {exc.response.status_code}: {detail}", file=sys.stderr)
+            if exc.response.status_code in (401, 503):
+                print("Hint: ensure ENVIRONMENT_API_KEY and SYNTH_API_KEY are correctly set.", file=sys.stderr)
+            raise
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
```
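Both rollout scripts rely on the same URL convention their hints describe: the policy appends /v1/chat/completions to whatever inference base URL it is given, which is why the Modal variant warns when the /api prefix is missing. A small self-contained sketch of that rule; the helper name is ours for illustration, not a synth-ai API.

```python
# Illustrates the URL convention the scripts' hints describe: the policy
# appends /v1/chat/completions to the configured inference base URL.
# chat_completions_url is a hypothetical helper, not part of the package.
def chat_completions_url(inference_base: str) -> str:
    return f"{inference_base.rstrip('/')}/v1/chat/completions"


# Modal proxy base (note the /api prefix) versus a direct OpenAI base:
assert chat_completions_url("http://localhost:8000/api") == (
    "http://localhost:8000/api/v1/chat/completions"
)
assert chat_completions_url("https://api.openai.com") == (
    "https://api.openai.com/v1/chat/completions"
)
```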