synth-ai 0.2.9.dev4__py3-none-any.whl → 0.2.9.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of synth-ai might be problematic.
- examples/common_old/backend.py +0 -1
- examples/crafter_debug_render.py +15 -6
- examples/evals_old/compare_models.py +1 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +6 -2
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +4 -4
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +4 -3
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +6 -2
- examples/finetuning_old/synth_qwen_v1/finetune.py +1 -1
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +4 -4
- examples/finetuning_old/synth_qwen_v1/infer.py +1 -2
- examples/finetuning_old/synth_qwen_v1/poll.py +4 -2
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +8 -8
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +5 -4
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +11 -8
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +17 -12
- examples/finetuning_old/synth_qwen_v1/upload_data.py +1 -1
- examples/finetuning_old/synth_qwen_v1/util.py +7 -2
- examples/rl/configs/eval_base_qwen.toml +1 -1
- examples/rl/configs/rl_from_base_qwen17.toml +1 -1
- examples/rl/download_dataset.py +26 -10
- examples/rl/run_eval.py +17 -15
- examples/rl/run_rl_and_save.py +24 -7
- examples/rl/task_app/math_single_step.py +128 -11
- examples/rl/task_app/math_task_app.py +11 -3
- examples/rl_old/task_app.py +222 -53
- examples/warming_up_to_rl/analyze_trace_db.py +7 -5
- examples/warming_up_to_rl/export_trace_sft.py +141 -16
- examples/warming_up_to_rl/groq_test.py +11 -4
- examples/warming_up_to_rl/manage_secrets.py +15 -6
- examples/warming_up_to_rl/readme.md +9 -2
- examples/warming_up_to_rl/run_eval.py +108 -30
- examples/warming_up_to_rl/run_fft_and_save.py +128 -52
- examples/warming_up_to_rl/run_local_rollout.py +87 -36
- examples/warming_up_to_rl/run_local_rollout_modal.py +113 -25
- examples/warming_up_to_rl/run_local_rollout_parallel.py +80 -16
- examples/warming_up_to_rl/run_local_rollout_traced.py +125 -20
- examples/warming_up_to_rl/run_rl_and_save.py +31 -7
- examples/warming_up_to_rl/run_rollout_remote.py +37 -10
- examples/warming_up_to_rl/task_app/grpo_crafter.py +90 -27
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +9 -27
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +46 -108
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +50 -17
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +35 -21
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +8 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +29 -26
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +17 -13
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +106 -63
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +82 -84
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +76 -59
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +43 -49
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +5 -15
- synth_ai/__init__.py +1 -0
- synth_ai/api/train/builders.py +34 -10
- synth_ai/api/train/cli.py +172 -32
- synth_ai/api/train/config_finder.py +59 -4
- synth_ai/api/train/env_resolver.py +32 -14
- synth_ai/api/train/pollers.py +11 -3
- synth_ai/api/train/task_app.py +4 -1
- synth_ai/api/train/utils.py +20 -4
- synth_ai/cli/__init__.py +11 -4
- synth_ai/cli/balance.py +1 -1
- synth_ai/cli/demo.py +19 -5
- synth_ai/cli/rl_demo.py +75 -16
- synth_ai/cli/root.py +116 -37
- synth_ai/cli/task_apps.py +1286 -170
- synth_ai/cli/traces.py +1 -0
- synth_ai/cli/turso.py +73 -0
- synth_ai/core/experiment.py +0 -2
- synth_ai/demo_registry.py +67 -30
- synth_ai/demos/core/cli.py +493 -164
- synth_ai/demos/demo_task_apps/core.py +50 -6
- synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +2 -3
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +36 -28
- synth_ai/demos/demo_task_apps/math/_common.py +1 -2
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +0 -2
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +168 -65
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
- synth_ai/environments/examples/bandit/engine.py +12 -4
- synth_ai/environments/examples/bandit/taskset.py +4 -4
- synth_ai/environments/reproducibility/tree.py +3 -1
- synth_ai/environments/service/core_routes.py +6 -2
- synth_ai/evals/base.py +0 -2
- synth_ai/experimental/synth_oss.py +11 -12
- synth_ai/handshake.py +3 -1
- synth_ai/http_client.py +31 -7
- synth_ai/inference/__init__.py +0 -2
- synth_ai/inference/client.py +8 -4
- synth_ai/jobs/client.py +40 -10
- synth_ai/learning/client.py +33 -8
- synth_ai/learning/config.py +0 -2
- synth_ai/learning/constants.py +0 -2
- synth_ai/learning/ft_client.py +6 -3
- synth_ai/learning/health.py +9 -2
- synth_ai/learning/jobs.py +17 -5
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +1 -3
- synth_ai/learning/prompts/random_search.py +4 -1
- synth_ai/learning/prompts/run_random_search_banking77.py +6 -1
- synth_ai/learning/rl_client.py +42 -14
- synth_ai/learning/sse.py +0 -2
- synth_ai/learning/validators.py +6 -2
- synth_ai/lm/caching/ephemeral.py +1 -3
- synth_ai/lm/core/exceptions.py +0 -2
- synth_ai/lm/core/main.py +13 -1
- synth_ai/lm/core/synth_models.py +0 -1
- synth_ai/lm/core/vendor_clients.py +4 -2
- synth_ai/lm/overrides.py +2 -2
- synth_ai/lm/vendors/core/anthropic_api.py +7 -7
- synth_ai/lm/vendors/core/openai_api.py +2 -0
- synth_ai/lm/vendors/openai_standard.py +3 -1
- synth_ai/lm/vendors/openai_standard_responses.py +6 -3
- synth_ai/lm/vendors/supported/custom_endpoint.py +1 -3
- synth_ai/lm/vendors/synth_client.py +37 -10
- synth_ai/rl/__init__.py +0 -1
- synth_ai/rl/contracts.py +0 -2
- synth_ai/rl/env_keys.py +6 -1
- synth_ai/task/__init__.py +1 -0
- synth_ai/task/apps/__init__.py +11 -11
- synth_ai/task/auth.py +29 -17
- synth_ai/task/client.py +3 -1
- synth_ai/task/contracts.py +1 -0
- synth_ai/task/datasets.py +3 -1
- synth_ai/task/errors.py +3 -2
- synth_ai/task/health.py +0 -2
- synth_ai/task/json.py +0 -1
- synth_ai/task/proxy.py +2 -5
- synth_ai/task/rubrics.py +9 -3
- synth_ai/task/server.py +31 -5
- synth_ai/task/tracing_utils.py +8 -3
- synth_ai/task/validators.py +0 -1
- synth_ai/task/vendors.py +0 -1
- synth_ai/tracing_v3/db_config.py +26 -1
- synth_ai/tracing_v3/decorators.py +1 -0
- synth_ai/tracing_v3/examples/basic_usage.py +3 -2
- synth_ai/tracing_v3/hooks.py +2 -0
- synth_ai/tracing_v3/replica_sync.py +1 -0
- synth_ai/tracing_v3/session_tracer.py +24 -3
- synth_ai/tracing_v3/storage/base.py +4 -1
- synth_ai/tracing_v3/storage/factory.py +0 -1
- synth_ai/tracing_v3/turso/manager.py +102 -38
- synth_ai/tracing_v3/turso/models.py +4 -1
- synth_ai/tracing_v3/utils.py +1 -0
- synth_ai/v0/tracing/upload.py +32 -135
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/METADATA +1 -1
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/RECORD +154 -156
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_stepwise_rewards.py +0 -58
- synth_ai/environments/examples/sokoban/units/astar_common.py +0 -95
- synth_ai/install_sqld.sh +0 -40
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/top_level.txt +0 -0
examples/warming_up_to_rl/run_fft_and_save.py:

```diff
@@ -12,6 +12,9 @@ from typing import Any, Dict, Tuple, List
 import tomllib
 import re
 import requests
+from dotenv import load_dotenv
+
+from synth_ai.config.base_url import PROD_BASE_URL_DEFAULT
 
 
 def mask(val: str) -> str:
@@ -20,7 +23,9 @@ def mask(val: str) -> str:
     return f"{val[:6]}…{val[-4:]}" if len(val) >= 10 else "****"
 
 
-def post_multipart(base: str, api_key: str, path: str, file_field: str, filepath: Path) -> Dict[str, Any]:
+def post_multipart(
+    base: str, api_key: str, path: str, file_field: str, filepath: Path
+) -> Dict[str, Any]:
     """Upload a file, trying backend-specific endpoints with fallbacks.
 
     Priority:
@@ -33,7 +38,7 @@ def post_multipart(base: str, api_key: str, path: str, file_field: str, filepath
 
     endpoints = [
         f"{base.rstrip('/')}/{path.lstrip('/')}",  # e.g., /learning/files
-        f"{base.rstrip('/')}/files",
+        f"{base.rstrip('/')}/files",  # OpenAI-style
     ]
     last_err: Dict[str, Any] | None = None
     for ep in endpoints:
```
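The fallback order is the point of this hunk: the backend-specific route is tried first, and the OpenAI-style `/files` route only if it fails. A minimal sketch of the resulting endpoint list (the base URL is a hypothetical placeholder; only the path handling comes from the hunk above):

```python
# Hypothetical base URL; the rstrip/lstrip handling mirrors the hunk above.
base = "https://backend.example.com/api"
path = "learning/files"  # backend-specific upload route, tried first

endpoints = [
    f"{base.rstrip('/')}/{path.lstrip('/')}",  # .../api/learning/files
    f"{base.rstrip('/')}/files",               # .../api/files (OpenAI-style fallback)
]
print(endpoints)
```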
```diff
@@ -87,17 +92,94 @@ def get_json(base: str, api_key: str, path: str) -> Dict[str, Any]:
     return {"status": r.status_code, "text": r.text[:400]}
 
 
+def _find_fft_configs() -> List[Path]:
+    """Find FFT TOML configs in standard locations."""
+    candidates: List[Path] = []
+
+    # Check current directory configs/
+    cwd = Path.cwd()
+    configs_dir = cwd / "configs"
+    if configs_dir.is_dir():
+        for f in configs_dir.glob("*.toml"):
+            # Look for FFT configs (check if they have [algorithm] method = "supervised_finetune")
+            try:
+                content = f.read_text()
+                if "supervised_finetune" in content or "fft" in content.lower():
+                    candidates.append(f)
+            except Exception:
+                pass
+
+    # Also check for any .toml files in current directory
+    for f in cwd.glob("*.toml"):
+        if f not in candidates:
+            try:
+                content = f.read_text()
+                if "supervised_finetune" in content or "fft" in content.lower():
+                    candidates.append(f)
+            except Exception:
+                pass
+
+    return sorted(candidates)
+
+
 def main() -> None:
+    # Load .env file from current directory first if it exists
+    default_env = Path.cwd() / ".env"
+    if default_env.exists():
+        load_dotenv(default_env, override=False)
+
     parser = argparse.ArgumentParser(description="Submit FFT job and save resulting model id")
-    parser.add_argument(
-
+    parser.add_argument(
+        "--backend", default=os.getenv("BACKEND_BASE_URL", f"{PROD_BASE_URL_DEFAULT}/api")
+    )
+    parser.add_argument("--toml", required=False, help="Path to FFT TOML config")
     parser.add_argument("--data", default="", help="Override dataset JSONL path")
     parser.add_argument("--poll-seconds", type=int, default=1800)
-    parser.add_argument(
+    parser.add_argument(
+        "--env-file", default="", help="Optional path to .env file with SYNTH_API_KEY"
+    )
     args = parser.parse_args()
 
-
-    if
+    # Also load from explicit --env-file if provided
+    if args.env_file:
+        env_path = Path(args.env_file).expanduser()
+        if not env_path.exists():
+            print(f"[WARN] Env file not found: {env_path}")
+        else:
+            load_dotenv(env_path, override=False)
+
+    # Auto-discover TOML config if not specified
+    config_path: Path | None = None
+    if args.toml:
+        config_path = Path(args.toml).expanduser().resolve()
+    else:
+        configs = _find_fft_configs()
+        if not configs:
+            print(
+                "No FFT config files found. Please specify --toml or create a config in configs/",
+                file=sys.stderr,
+            )
+            sys.exit(2)
+        elif len(configs) == 1:
+            config_path = configs[0]
+            print(f"Using FFT config: {config_path}")
+        else:
+            print("\nFound multiple FFT configs:")
+            for idx, cfg in enumerate(configs, 1):
+                print(f"  [{idx}] {cfg}")
+            choice = input(f"Select config [1-{len(configs)}]: ").strip()
+            try:
+                selected_idx = int(choice) - 1
+                if 0 <= selected_idx < len(configs):
+                    config_path = configs[selected_idx]
+                else:
+                    print("Invalid selection", file=sys.stderr)
+                    sys.exit(2)
+            except ValueError:
+                print("Invalid input", file=sys.stderr)
+                sys.exit(2)
+
+    if not config_path or not config_path.exists():
         print(f"Config not found: {config_path}", file=sys.stderr)
         sys.exit(2)
     with config_path.open("rb") as fh:
```
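The discovery heuristic is a plain substring match, so any TOML mentioning `supervised_finetune` (or containing `fft` anywhere) becomes a candidate. A standalone sketch of the same check; `looks_like_fft_config` is an illustrative helper name, not part of the script:

```python
from pathlib import Path

# Sketch of the _find_fft_configs heuristic above, as a reusable predicate.
def looks_like_fft_config(path: Path) -> bool:
    try:
        content = path.read_text()
    except Exception:
        return False
    return "supervised_finetune" in content or "fft" in content.lower()

# Scan configs/ first, then the current directory, mirroring the hunk above.
candidates = sorted(
    f
    for d in (Path.cwd() / "configs", Path.cwd())
    if d.is_dir()
    for f in d.glob("*.toml")
    if looks_like_fft_config(f)
)
print(candidates)
```

The substring match can produce false positives (any file mentioning "fft"), which is presumably why the script confirms the choice interactively when several candidates are found.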
```diff
@@ -107,7 +189,9 @@ def main() -> None:
     compute_cfg = cfg.get("compute", {}) if isinstance(cfg.get("compute"), dict) else {}
     data_cfg_full = cfg.get("data", {}) if isinstance(cfg.get("data"), dict) else {}
     topo_cfg = (data_cfg_full or {}).get("topology", {}) if isinstance(data_cfg_full, dict) else {}
-    validation_local_path = (
+    validation_local_path = (
+        (data_cfg_full or {}).get("validation_path") if isinstance(data_cfg_full, dict) else None
+    )
     train_cfg = cfg.get("training", {}) if isinstance(cfg.get("training"), dict) else {}
     hp_cfg = cfg.get("hyperparameters", {}) if isinstance(cfg.get("hyperparameters"), dict) else {}
 
@@ -119,7 +203,13 @@ def main() -> None:
     if isinstance(data_path, str) and data_path.strip():
         p = Path(data_path).expanduser()
         if not p.is_absolute():
-
+            # Try relative to cwd first, then relative to config directory
+            cwd_relative = Path.cwd() / p
+            config_relative = config_path.parent / p
+            if cwd_relative.exists():
+                p = cwd_relative.resolve()
+            else:
+                p = config_relative.resolve()
         data_file = p
     if data_file is None:
         print("Missing dataset path in --data or [job].data", file=sys.stderr)
@@ -129,38 +219,11 @@ def main() -> None:
         sys.exit(2)
 
     synth_key = (os.getenv("SYNTH_API_KEY") or "").strip()
-    # Fallback: try to load from .env if not present in environment
     if not synth_key:
-
-        if
-
-
-        # Prefer .env next to the TOML config
-        candidate_env = (config_path.parent / ".env").resolve()
-        if candidate_env and candidate_env.exists():
-            try:
-                env_text = candidate_env.read_text(encoding="utf-8", errors="ignore")
-                # Match lines like: SYNTH_API_KEY=..., or export SYNTH_API_KEY=...
-                key_val: str | None = None
-                for line in env_text.splitlines():
-                    m = re.match(r"^\s*(?:export\s+)?SYNTH_API_KEY\s*=\s*(.*)$", line)
-                    if m:
-                        raw = m.group(1).strip()
-                        # Trim surrounding quotes if present
-                        if (raw.startswith('"') and raw.endswith('"')) or (raw.startswith("'") and raw.endswith("'")):
-                            raw = raw[1:-1]
-                        key_val = raw.strip()
-                        break
-                if key_val:
-                    synth_key = key_val
-                    os.environ["SYNTH_API_KEY"] = synth_key
-                    print(f"[INFO] Loaded SYNTH_API_KEY from {candidate_env}")
-            except Exception as _e:
-                # Ignore and fall through to error below
-                pass
-    if not synth_key:
-        print("Missing SYNTH_API_KEY (set in env or provide --env-file pointing to .env)", file=sys.stderr)
-        sys.exit(2)
+        synth_key = input("Please enter your Synth API key:\n> ").strip()
+        if not synth_key:
+            print("Synth API key is required", file=sys.stderr)
+            sys.exit(2)
 
     backend = args.backend.rstrip("/")
     print(f"[INFO] Using backend={backend} key_fp={mask(synth_key)} data={data_file}")
@@ -180,7 +243,10 @@ def main() -> None:
         err_status = (upf or {}).get("status")
         err_body = (upf or {}).get("body") or (upf or {}).get("text")
         err_ep = (upf or {}).get("endpoint")
-        print(
+        print(
+            f"Upload failed (status={err_status} endpoint={err_ep}) body={str(err_body)[:200]}",
+            file=sys.stderr,
+        )
         sys.exit(4)
 
     # Optionally upload validation file
@@ -203,7 +269,9 @@ def main() -> None:
         err_status = (upv or {}).get("status")
         err_body = (upv or {}).get("body") or (upv or {}).get("text")
         err_ep = (upv or {}).get("endpoint")
-        print(
+        print(
+            f"[WARN] Validation upload failed (status={err_status} endpoint={err_ep}) body={str(err_body)[:180]} — continuing without validation"
+        )
 
     # 2) Build job payload
     hp_block: Dict[str, Any] = {
@@ -238,18 +306,24 @@ def main() -> None:
         "training": {k: v for k, v in train_cfg.items() if k in ("mode", "use_qlora")},
     }
     # If TOML includes a [training.validation] block, forward relevant knobs into hyperparameters
-    validation_cfg =
+    validation_cfg = (
+        train_cfg.get("validation") if isinstance(train_cfg.get("validation"), dict) else None
+    )
     if isinstance(validation_cfg, dict):
         # Enable evaluation and map keys as-is; backend trainer maps metric_for_best_model 'val.loss'→'eval_loss'
-        hp_block.update(
-
-
-
-
-
+        hp_block.update(
+            {
+                "evaluation_strategy": validation_cfg.get("evaluation_strategy", "steps"),
+                "eval_steps": int(validation_cfg.get("eval_steps", 0) or 0),
+                "save_best_model_at_end": bool(validation_cfg.get("save_best_model_at_end", True)),
+                "metric_for_best_model": validation_cfg.get("metric_for_best_model", "val.loss"),
+                "greater_is_better": bool(validation_cfg.get("greater_is_better", False)),
+            }
+        )
         # Also surface validation enable flag into effective_config for visibility (optional)
-        effective.setdefault("training", {})["validation"] = {
+        effective.setdefault("training", {})["validation"] = {
+            "enabled": bool(validation_cfg.get("enabled", True))
+        }
 
     body = {
         "model": model,
```
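The `[training.validation]` forwarding is a plain key mapping with defaults, so it can be exercised in isolation. A sketch using an illustrative dict in place of the parsed TOML (keys and defaults come from the hunk; the values are examples):

```python
# Illustrative stand-in for cfg["training"]["validation"] parsed from the TOML.
validation_cfg = {
    "enabled": True,
    "evaluation_strategy": "steps",
    "eval_steps": 50,
    "metric_for_best_model": "val.loss",
}

hp_block: dict = {}
# Same mapping as the hunk above: copy knobs into hyperparameters with defaults.
hp_block.update(
    {
        "evaluation_strategy": validation_cfg.get("evaluation_strategy", "steps"),
        "eval_steps": int(validation_cfg.get("eval_steps", 0) or 0),
        "save_best_model_at_end": bool(validation_cfg.get("save_best_model_at_end", True)),
        "metric_for_best_model": validation_cfg.get("metric_for_best_model", "val.loss"),
        "greater_is_better": bool(validation_cfg.get("greater_is_better", False)),
    }
)
print(hp_block)  # eval_steps=50, metric_for_best_model='val.loss', ...
```

Per the inline comment in the diff, the backend trainer is responsible for translating `metric_for_best_model = "val.loss"` into the trainer-native `eval_loss` name.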
```diff
@@ -289,7 +363,9 @@ def main() -> None:
             break
         # Warn if stuck queued for >10 minutes
         if status == "queued" and (time.time() - queued_since) > 600:
-            print(
+            print(
+                "[WARN] Job has remained queued for >10 minutes. Backend may be capacity constrained."
+            )
             queued_since = time.time()
         time.sleep(5)
```
examples/warming_up_to_rl/run_local_rollout.py:

```diff
@@ -46,17 +46,21 @@ def build_rollout_request(
     )
     return RolloutRequest(
         run_id=run_id,
-        env=RolloutEnvSpec(env_name=
-        policy=RolloutPolicySpec(policy_name=
+        env=RolloutEnvSpec(env_name="crafter", seed=seed, config={}),
+        policy=RolloutPolicySpec(policy_name="crafter-react", config=policy_config),
         ops=ops,
         record=record_cfg,
-        on_done=
+        on_done="reset",
         safety=RolloutSafetyConfig(),
     )
 
 
 def summarise_response(data: Any) -> dict[str, Any]:
-    metrics =
+    metrics = (
+        data.metrics.model_dump()
+        if hasattr(data.metrics, "model_dump")
+        else data.get("metrics", {})
+    )
     error = None
     rollout_status = None
     try:
```
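The `model_dump` branch in `summarise_response` exists because the rollout response may arrive as a Pydantic model or as a plain dict, depending on the client path. A self-contained sketch of the pattern with hypothetical stand-in classes (note the extra `hasattr(data, "metrics")` guard added here so the dict case does not raise; the script itself assumes a model when `.metrics` is present):

```python
# Hypothetical stand-ins, not synth_ai types; only the access pattern mirrors the hunk.
class _Metrics:
    def model_dump(self) -> dict:
        return {"num_episodes": 1, "episode_returns": [3.0]}

class _ModelResponse:
    metrics = _Metrics()

for data in (_ModelResponse(), {"metrics": {"num_episodes": 2}}):
    metrics = (
        data.metrics.model_dump()
        if hasattr(data, "metrics") and hasattr(data.metrics, "model_dump")
        else data.get("metrics", {})
    )
    print(metrics)
```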
```diff
@@ -86,16 +90,42 @@ async def main() -> None:
     parser = argparse.ArgumentParser(description=__doc__)
     parser.add_argument("--base-url", default="http://localhost:8001", help="Task app base URL")
     parser.add_argument("--api-key", help="Environment API key (or set via --env-file)")
-    parser.add_argument(
-    parser.add_argument(
-    parser.add_argument(
-
-
-
-
-    parser.add_argument(
-
-
+    parser.add_argument("--seed", type=int, default=42, help="Env seed to rollout")
+    parser.add_argument("--run-id", default="local-demo", help="Run identifier")
+    parser.add_argument(
+        "--model",
+        default="gpt-4o-mini",
+        help="Model identifier for the Crafter policy (OpenAI-compatible)",
+    )
+    parser.add_argument(
+        "--inference-url",
+        default="https://api.openai.com",
+        help="Inference base URL used by the policy (e.g., https://api.openai.com)",
+    )
+    parser.add_argument(
+        "--env-file", type=str, default=None, help="Path to .env file with API keys"
+    )
+    parser.add_argument(
+        "--ops", default=None, help="Comma-separated rollout ops (advanced override)"
+    )
+    parser.add_argument(
+        "--max-llm-calls",
+        type=int,
+        default=1,
+        help="Number of policy inference calls when --ops not provided",
+    )
+    parser.add_argument(
+        "--max-policy-tokens",
+        type=int,
+        default=None,
+        help="Optional per-call token limit forwarded to the policy config",
+    )
+    parser.add_argument(
+        "--timeout", type=float, default=600.0, help="HTTP timeout (seconds) for task app requests"
+    )
+    parser.add_argument(
+        "--verbose", action="store_true", help="Print resolved configuration and headers"
+    )
     args = parser.parse_args()
 
     if args.env_file:
@@ -117,12 +147,13 @@ async def main() -> None:
     os.environ["OPENAI_API_KEY"] = synth_key
 
     if args.verbose:
+
         def _mask(val: str | None) -> str:
             if not val:
-                return
+                return "<unset>"
             return f"{val[:6]}…{val[-4:]} (len={len(val)})"
 
-        print(
+        print("Resolved configuration:")
         print(f"  Task app base URL : {args.base_url}")
         print(f"  Inference base URL : {args.inference_url}")
         print(f"  Task app API key : {_mask(api_key)}")
@@ -130,21 +161,23 @@ async def main() -> None:
         print(f"  HTTP timeout : {args.timeout:.1f}s")
 
     if args.ops:
-        ops = [op.strip() for op in args.ops.split(
+        ops = [op.strip() for op in args.ops.split(",") if op.strip()]
         if not ops:
-            raise ValueError(
+            raise ValueError("Ops must contain at least one entry")
     else:
         llm_calls = max(args.max_llm_calls, 1)
        if llm_calls > 20:
-            print(
+            print(
+                "[WARN] --max-llm-calls capped at 20 to avoid excessive episodes; use --ops for manual control."
+            )
             llm_calls = 20
         ops = []
         for _ in range(llm_calls):
-            ops.extend([
+            ops.extend(["agent", "env"])
 
     async with TaskAppClient(args.base_url, api_key=api_key, timeout=args.timeout) as client:
         try:
-            print(f
+            print(f"Fetching task_info for seed {args.seed}…")
             task_info = await client.task_info(seeds=[args.seed])
             info_payload = task_info[0] if isinstance(task_info, list) else task_info
             print(json.dumps(info_payload.model_dump(), indent=2)[:600])
```
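When `--ops` is absent, each LLM call expands into an `agent` step followed by an `env` step, capped at 20 calls. A sketch of that expansion; `build_ops` is a hypothetical helper name and the warning print is elided:

```python
# Sketch of the --max-llm-calls -> ops expansion from the hunk above.
def build_ops(max_llm_calls: int, cap: int = 20) -> list[str]:
    llm_calls = min(max(max_llm_calls, 1), cap)  # at least 1, at most the cap
    ops: list[str] = []
    for _ in range(llm_calls):
        ops.extend(["agent", "env"])  # one inference step, then one env step
    return ops

print(build_ops(3))  # ['agent', 'env', 'agent', 'env', 'agent', 'env']
```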
```diff
@@ -158,29 +191,47 @@ async def main() -> None:
             extra_headers=extra_headers,
         )
         if args.max_policy_tokens is not None:
-            request.policy.config.update(
-
-
-
+            request.policy.config.update(
+                {
+                    "max_completion_tokens": args.max_policy_tokens,
+                    "max_tokens": args.max_policy_tokens,
+                }
+            )
         if args.verbose:
-            print(f
-            print(f
-            print(
+            print(f"Ops: {ops}")
+            print(f"Request headers: {request.policy.config.get('extra_headers', {})}")
+            print("Requesting rollout…")
         response = await client.rollout(request)
         summary = summarise_response(response)
         print(json.dumps(summary, indent=2))
-        print(f
-        print(
+        print(f"Ops executed: {ops}")
+        print("Tip: use --max-llm-calls N for agent/env pairs or --ops for manual control.")
     except httpx.HTTPStatusError as exc:
-        detail =
-
+        detail = (
+            exc.response.json()
+            if exc.response.headers.get("content-type", "").startswith("application/json")
+            else exc.response.text
+        )
+        print(f"HTTP error {exc.response.status_code}: {detail}", file=sys.stderr)
         if exc.response.status_code in (401, 503):
-            print(
+            print(
+                "Hint: ensure the task app was started with ENVIRONMENT_API_KEY set and pass the same key via --api-key.",
+                file=sys.stderr,
+            )
         if exc.response.status_code == 500 and args.model in str(detail):
-            print(
-
+            print(
+                "Hint: supply --model/--inference-url (and set OPENAI_API_KEY or GROQ_API_KEY) so the policy can route inference.",
+                file=sys.stderr,
+            )
+            print(
+                "Hint: the inference URL should be the base (e.g., https://api.openai.com); the task app appends /v1/chat/completions.",
+                file=sys.stderr,
+            )
         if args.max_policy_tokens is not None:
-            print(
+            print(
+                f"Hint: --max-policy-tokens={args.max_policy_tokens} is forwarded to the policy config as max_completion_tokens.",
+                file=sys.stderr,
+            )
         raise
```
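Both rollout scripts use the same error-detail extraction: parse JSON only when the response declares an `application/json` content type, otherwise fall back to raw text. A runnable sketch of the pattern using synthetic `httpx` responses:

```python
import httpx

def error_detail(response: httpx.Response):
    # Prefer parsed JSON when the server declares it, else the raw body text.
    if response.headers.get("content-type", "").startswith("application/json"):
        return response.json()
    return response.text

# httpx sets the application/json content type automatically for json= bodies.
print(error_detail(httpx.Response(500, json={"error": "boom"})))  # {'error': 'boom'}
print(error_detail(httpx.Response(502, text="bad gateway")))      # bad gateway
```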
examples/warming_up_to_rl/run_local_rollout_modal.py:

```diff
@@ -25,7 +25,9 @@ from synth_ai.task import (
 )
 
 
-def build_rollout_request(seed: int, run_id: str, *, model: str, inference_url: str, ops: list[str], api_key: str) -> RolloutRequest:
+def build_rollout_request(
+    seed: int, run_id: str, *, model: str, inference_url: str, ops: list[str], api_key: str
+) -> RolloutRequest:
     policy_config = {
         "model": model,
         "inference_url": inference_url,
@@ -45,7 +47,11 @@ def build_rollout_request(seed: int, run_id: str, *, model: str, inference_url:
 
 
 def summarise_response(data: Any) -> dict[str, Any]:
-    metrics =
+    metrics = (
+        data.metrics.model_dump()
+        if hasattr(data.metrics, "model_dump")
+        else data.get("metrics", {})
+    )
     return {
         "run_id": getattr(data, "run_id", None) or data.get("run_id"),
         "num_episodes": metrics.get("num_episodes"),
```
```diff
@@ -57,21 +63,54 @@ def summarise_response(data: Any) -> dict[str, Any]:
 
 
 async def main() -> None:
+    # Load .env file from current directory first if it exists
+    default_env = Path.cwd() / ".env"
+    if default_env.exists():
+        load_dotenv(default_env, override=False)
+
     parser = argparse.ArgumentParser(description=__doc__)
     parser.add_argument("--base-url", default="http://localhost:8010", help="Task app base URL")
     parser.add_argument("--env-file", type=str, default=None, help="Path to .env file with keys")
     parser.add_argument("--seed", type=int, default=42, help="Env seed to rollout")
     parser.add_argument("--run-id", default="modal-eval", help="Run identifier")
-    parser.add_argument(
-
-
-
-
-    parser.add_argument(
-
-
+    parser.add_argument(
+        "--model",
+        required=False,
+        help="Model identifier for the Crafter policy (e.g., fft:Qwen/Qwen3-4B:job_xxx)",
+    )
+    parser.add_argument(
+        "--inference-url",
+        required=False,
+        help="Modal backend inference base URL (e.g., http://localhost:8000/api)",
+    )
+    parser.add_argument(
+        "--task-app-key",
+        default=None,
+        help="Environment API key for the task app (fallback ENVIRONMENT_API_KEY)",
+    )
+    parser.add_argument(
+        "--modal-key",
+        default=None,
+        help="Synth/Modal API key for inference (fallback SYNTH_API_KEY)",
+    )
+    parser.add_argument(
+        "--max-llm-calls", type=int, default=20, help="Number of policy inference calls"
+    )
+    parser.add_argument(
+        "--ops", default=None, help="Comma-separated rollout ops (advanced override)"
+    )
+    parser.add_argument(
+        "--max-policy-tokens",
+        type=int,
+        default=None,
+        help="Optional per-call token limit forwarded to the policy config",
+    )
+    parser.add_argument(
+        "--verbose", action="store_true", help="Print resolved configuration and headers"
+    )
     args = parser.parse_args()
 
+    # Also load from explicit --env-file if provided
     if args.env_file:
         env_path = Path(args.env_file).expanduser()
         if not env_path.exists():
@@ -79,16 +118,51 @@ async def main() -> None:
     else:
         load_dotenv(env_path, override=False)
 
+    # Prompt for required parameters if not provided
+    base_url = args.base_url
+    if args.base_url == "http://localhost:8010":
+        print("\nTask app configuration:")
+        base_url_input = input(f"Task app base URL [http://localhost:8001]: ").strip()
+        base_url = base_url_input if base_url_input else "http://localhost:8001"
+
+    model = args.model
+    if not model:
+        print("\nFine-tuned model configuration:")
+        print(
+            "Note: This should be the model ID returned from training (e.g., fft:Qwen/Qwen3-4B:job_abc123)"
+        )
+        model_input = input("Fine-tuned model ID: ").strip()
+        if not model_input:
+            parser.error("Model identifier is required")
+        model = model_input
+
+    inference_url = args.inference_url
+    if not inference_url:
+        inference_url_input = input("Inference URL [http://localhost:8000/api]: ").strip()
+        inference_url = inference_url_input if inference_url_input else "http://localhost:8000/api"
+
+    # Override args
+    args.base_url = base_url
+    args.model = model
+    args.inference_url = inference_url
+
+    # Check environment variables first (loaded from .env)
     task_app_key = args.task_app_key or os.getenv("ENVIRONMENT_API_KEY")
     if not task_app_key:
-
+        print("\n[INFO] ENVIRONMENT_API_KEY not found in environment or .env file")
+        task_app_key = input("RL Environment API key: ").strip()
+        if not task_app_key:
+            parser.error("Missing task app API key")
 
     modal_key = args.modal_key or os.getenv("SYNTH_API_KEY")
     if not modal_key:
-
+        print("[INFO] SYNTH_API_KEY not found in environment or .env file")
+        modal_key = input("Synth API key: ").strip()
+        if not modal_key:
+            parser.error("Missing Synth/Modal API key")
 
-    if
-    os.environ["OPENAI_API_KEY"] =
+    if modal_key and "openai.com" not in args.inference_url.lower():
+        os.environ["OPENAI_API_KEY"] = modal_key
 
     if args.ops:
         ops = [op.strip() for op in args.ops.split(",") if op.strip()]
@@ -103,6 +177,7 @@ async def main() -> None:
         ops.extend(["agent", "env"])
 
     if args.verbose:
+
         def _mask(val: str | None) -> str:
             if not val:
                 return "<unset>"
@@ -115,11 +190,15 @@ async def main() -> None:
         print(f"  Modal API key : {_mask(modal_key)}")
         print(f"  Ops (count={len(ops)}) : {ops}")
 
-    inf_url_norm = args.inference_url.rstrip(
-    if
-    print(
-
-
+    inf_url_norm = args.inference_url.rstrip("/")
+    if "/api" not in inf_url_norm:
+        print(
+            "[WARN] Inference URL is missing /api prefix; proxy endpoints usually live at /api/inference/v1/chat/completions."
+        )
+    elif not inf_url_norm.lower().endswith("/api"):
+        print(
+            "[INFO] Using inference base URL; policy will append /v1/chat/completions automatically."
+        )
 
     async with TaskAppClient(args.base_url, api_key=task_app_key) as client:
         try:
```
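The URL check above encodes an assumption about the Modal proxy layout: chat completions are served under `/api/inference/v1/chat/completions`, so a base URL without `/api` is suspect. A sketch of the same normalization as a standalone function (`check_inference_url` is a hypothetical helper name):

```python
# Sketch of the inference-URL sanity check from the hunk above.
def check_inference_url(url: str) -> str:
    norm = url.rstrip("/")
    if "/api" not in norm:
        return "warn: missing /api prefix"
    if not norm.lower().endswith("/api"):
        return "info: base URL; policy appends /v1/chat/completions"
    return "ok"

for u in ("http://localhost:8000/api", "http://localhost:8000", "http://localhost:8000/api/v1"):
    print(u, "->", check_inference_url(u))
```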
```diff
@@ -139,20 +218,29 @@ async def main() -> None:
         if args.verbose:
             print(f"Request headers: {request.policy.config.get('extra_headers', {})}")
         if args.max_policy_tokens is not None:
-            request.policy.config.update(
-
-
-
+            request.policy.config.update(
+                {
+                    "max_completion_tokens": args.max_policy_tokens,
+                    "max_tokens": args.max_policy_tokens,
+                }
+            )
         print("Requesting rollout…")
         response = await client.rollout(request)
         summary = summarise_response(response)
         print(json.dumps(summary, indent=2))
         print(f"Ops executed: {ops}")
     except httpx.HTTPStatusError as exc:
-        detail =
+        detail = (
+            exc.response.json()
+            if exc.response.headers.get("content-type", "").startswith("application/json")
+            else exc.response.text
+        )
         print(f"HTTP error {exc.response.status_code}: {detail}", file=sys.stderr)
         if exc.response.status_code in (401, 503):
-            print(
+            print(
+                "Hint: ensure ENVIRONMENT_API_KEY and SYNTH_API_KEY are correctly set.",
+                file=sys.stderr,
+            )
         raise
```