synth-ai 0.2.9.dev4__py3-none-any.whl → 0.2.9.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/common_old/backend.py +0 -1
- examples/crafter_debug_render.py +15 -6
- examples/evals_old/compare_models.py +1 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +6 -2
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +4 -4
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +4 -3
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +6 -2
- examples/finetuning_old/synth_qwen_v1/finetune.py +1 -1
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +4 -4
- examples/finetuning_old/synth_qwen_v1/infer.py +1 -2
- examples/finetuning_old/synth_qwen_v1/poll.py +4 -2
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +8 -8
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +5 -4
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +11 -8
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +17 -12
- examples/finetuning_old/synth_qwen_v1/upload_data.py +1 -1
- examples/finetuning_old/synth_qwen_v1/util.py +7 -2
- examples/rl/configs/eval_base_qwen.toml +1 -1
- examples/rl/configs/rl_from_base_qwen17.toml +1 -1
- examples/rl/download_dataset.py +26 -10
- examples/rl/run_eval.py +17 -15
- examples/rl/run_rl_and_save.py +24 -7
- examples/rl/task_app/math_single_step.py +128 -11
- examples/rl/task_app/math_task_app.py +11 -3
- examples/rl_old/task_app.py +222 -53
- examples/warming_up_to_rl/analyze_trace_db.py +7 -5
- examples/warming_up_to_rl/export_trace_sft.py +141 -16
- examples/warming_up_to_rl/groq_test.py +11 -4
- examples/warming_up_to_rl/manage_secrets.py +15 -6
- examples/warming_up_to_rl/readme.md +9 -2
- examples/warming_up_to_rl/run_eval.py +108 -30
- examples/warming_up_to_rl/run_fft_and_save.py +128 -52
- examples/warming_up_to_rl/run_local_rollout.py +87 -36
- examples/warming_up_to_rl/run_local_rollout_modal.py +113 -25
- examples/warming_up_to_rl/run_local_rollout_parallel.py +80 -16
- examples/warming_up_to_rl/run_local_rollout_traced.py +125 -20
- examples/warming_up_to_rl/run_rl_and_save.py +31 -7
- examples/warming_up_to_rl/run_rollout_remote.py +37 -10
- examples/warming_up_to_rl/task_app/grpo_crafter.py +90 -27
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +9 -27
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +46 -108
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +50 -17
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +35 -21
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +8 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +29 -26
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +17 -13
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +106 -63
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +82 -84
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +76 -59
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +43 -49
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +5 -15
- synth_ai/__init__.py +1 -0
- synth_ai/api/train/builders.py +34 -10
- synth_ai/api/train/cli.py +172 -32
- synth_ai/api/train/config_finder.py +59 -4
- synth_ai/api/train/env_resolver.py +32 -14
- synth_ai/api/train/pollers.py +11 -3
- synth_ai/api/train/task_app.py +4 -1
- synth_ai/api/train/utils.py +20 -4
- synth_ai/cli/__init__.py +11 -4
- synth_ai/cli/balance.py +1 -1
- synth_ai/cli/demo.py +19 -5
- synth_ai/cli/rl_demo.py +75 -16
- synth_ai/cli/root.py +116 -37
- synth_ai/cli/task_apps.py +1286 -170
- synth_ai/cli/traces.py +1 -0
- synth_ai/cli/turso.py +73 -0
- synth_ai/core/experiment.py +0 -2
- synth_ai/demo_registry.py +67 -30
- synth_ai/demos/core/cli.py +493 -164
- synth_ai/demos/demo_task_apps/core.py +50 -6
- synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +2 -3
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +36 -28
- synth_ai/demos/demo_task_apps/math/_common.py +1 -2
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +0 -2
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +168 -65
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
- synth_ai/environments/examples/bandit/engine.py +12 -4
- synth_ai/environments/examples/bandit/taskset.py +4 -4
- synth_ai/environments/reproducibility/tree.py +3 -1
- synth_ai/environments/service/core_routes.py +6 -2
- synth_ai/evals/base.py +0 -2
- synth_ai/experimental/synth_oss.py +11 -12
- synth_ai/handshake.py +3 -1
- synth_ai/http_client.py +31 -7
- synth_ai/inference/__init__.py +0 -2
- synth_ai/inference/client.py +8 -4
- synth_ai/jobs/client.py +40 -10
- synth_ai/learning/client.py +33 -8
- synth_ai/learning/config.py +0 -2
- synth_ai/learning/constants.py +0 -2
- synth_ai/learning/ft_client.py +6 -3
- synth_ai/learning/health.py +9 -2
- synth_ai/learning/jobs.py +17 -5
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +1 -3
- synth_ai/learning/prompts/random_search.py +4 -1
- synth_ai/learning/prompts/run_random_search_banking77.py +6 -1
- synth_ai/learning/rl_client.py +42 -14
- synth_ai/learning/sse.py +0 -2
- synth_ai/learning/validators.py +6 -2
- synth_ai/lm/caching/ephemeral.py +1 -3
- synth_ai/lm/core/exceptions.py +0 -2
- synth_ai/lm/core/main.py +13 -1
- synth_ai/lm/core/synth_models.py +0 -1
- synth_ai/lm/core/vendor_clients.py +4 -2
- synth_ai/lm/overrides.py +2 -2
- synth_ai/lm/vendors/core/anthropic_api.py +7 -7
- synth_ai/lm/vendors/core/openai_api.py +2 -0
- synth_ai/lm/vendors/openai_standard.py +3 -1
- synth_ai/lm/vendors/openai_standard_responses.py +6 -3
- synth_ai/lm/vendors/supported/custom_endpoint.py +1 -3
- synth_ai/lm/vendors/synth_client.py +37 -10
- synth_ai/rl/__init__.py +0 -1
- synth_ai/rl/contracts.py +0 -2
- synth_ai/rl/env_keys.py +6 -1
- synth_ai/task/__init__.py +1 -0
- synth_ai/task/apps/__init__.py +11 -11
- synth_ai/task/auth.py +29 -17
- synth_ai/task/client.py +3 -1
- synth_ai/task/contracts.py +1 -0
- synth_ai/task/datasets.py +3 -1
- synth_ai/task/errors.py +3 -2
- synth_ai/task/health.py +0 -2
- synth_ai/task/json.py +0 -1
- synth_ai/task/proxy.py +2 -5
- synth_ai/task/rubrics.py +9 -3
- synth_ai/task/server.py +31 -5
- synth_ai/task/tracing_utils.py +8 -3
- synth_ai/task/validators.py +0 -1
- synth_ai/task/vendors.py +0 -1
- synth_ai/tracing_v3/db_config.py +26 -1
- synth_ai/tracing_v3/decorators.py +1 -0
- synth_ai/tracing_v3/examples/basic_usage.py +3 -2
- synth_ai/tracing_v3/hooks.py +2 -0
- synth_ai/tracing_v3/replica_sync.py +1 -0
- synth_ai/tracing_v3/session_tracer.py +24 -3
- synth_ai/tracing_v3/storage/base.py +4 -1
- synth_ai/tracing_v3/storage/factory.py +0 -1
- synth_ai/tracing_v3/turso/manager.py +102 -38
- synth_ai/tracing_v3/turso/models.py +4 -1
- synth_ai/tracing_v3/utils.py +1 -0
- synth_ai/v0/tracing/upload.py +32 -135
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/METADATA +1 -1
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/RECORD +154 -156
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_stepwise_rewards.py +0 -58
- synth_ai/environments/examples/sokoban/units/astar_common.py +0 -95
- synth_ai/install_sqld.sh +0 -40
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/top_level.txt +0 -0
|
@@ -31,12 +31,17 @@ def build_rollout_request(
|
|
|
31
31
|
run_id: str,
|
|
32
32
|
model: str,
|
|
33
33
|
inference_url: str,
|
|
34
|
+
inference_api_key: str,
|
|
34
35
|
ops: list[str],
|
|
35
36
|
extra_headers: dict[str, str] | None = None,
|
|
36
37
|
trace_format: str = "compact",
|
|
37
38
|
return_trace: bool = False,
|
|
38
39
|
) -> RolloutRequest:
|
|
39
|
-
policy_config = {
|
|
40
|
+
policy_config = {
|
|
41
|
+
"model": model,
|
|
42
|
+
"inference_url": inference_url,
|
|
43
|
+
"api_key": inference_api_key,
|
|
44
|
+
}
|
|
40
45
|
if extra_headers:
|
|
41
46
|
policy_config["extra_headers"] = extra_headers
|
|
42
47
|
record_cfg = RolloutRecordConfig(
|
|
@@ -123,7 +128,9 @@ def analyse_rollout_response(response: Any) -> dict[str, Any]:
|
|
|
123
128
|
if isinstance(final_list, list):
|
|
124
129
|
final_achievements = [str(item) for item in final_list]
|
|
125
130
|
|
|
126
|
-
decision_rewards =
|
|
131
|
+
decision_rewards = (
|
|
132
|
+
trace_payload.get("decision_rewards") if isinstance(trace_payload, dict) else []
|
|
133
|
+
)
|
|
127
134
|
trace_all: list[str] = []
|
|
128
135
|
if isinstance(decision_rewards, list):
|
|
129
136
|
for item in decision_rewards:
|
|
@@ -180,7 +187,9 @@ def summarise_runs(run_summaries: list[dict[str, Any]]) -> dict[str, Any]:
|
|
|
180
187
|
return stats
|
|
181
188
|
|
|
182
189
|
|
|
183
|
-
def print_summary(
|
|
190
|
+
def print_summary(
|
|
191
|
+
stats: dict[str, Any], *, run_details: list[dict[str, Any]], total_runs: int
|
|
192
|
+
) -> None:
|
|
184
193
|
if not stats:
|
|
185
194
|
print("No successful rollouts to summarise.")
|
|
186
195
|
return
|
|
@@ -234,7 +243,22 @@ async def execute_rollouts(args: argparse.Namespace) -> None:
|
|
|
234
243
|
|
|
235
244
|
api_key = args.api_key or os.getenv("ENVIRONMENT_API_KEY")
|
|
236
245
|
if not api_key:
|
|
237
|
-
|
|
246
|
+
import sys
|
|
247
|
+
|
|
248
|
+
print("Please enter your RL Environment API key:", file=sys.stderr, flush=True)
|
|
249
|
+
api_key = input("> ").strip()
|
|
250
|
+
if not api_key:
|
|
251
|
+
raise RuntimeError("RL Environment API key is required")
|
|
252
|
+
|
|
253
|
+
# Prompt for Groq API key if not set
|
|
254
|
+
groq_api_key = os.getenv("GROQ_API_KEY")
|
|
255
|
+
if not groq_api_key:
|
|
256
|
+
import sys
|
|
257
|
+
|
|
258
|
+
print("Please enter your Groq API key:", file=sys.stderr, flush=True)
|
|
259
|
+
groq_api_key = input("> ").strip()
|
|
260
|
+
if not groq_api_key:
|
|
261
|
+
raise RuntimeError("Groq API key is required")
|
|
238
262
|
|
|
239
263
|
synth_key = os.getenv("SYNTH_API_KEY")
|
|
240
264
|
extra_headers: dict[str, str] | None = None
|
|
@@ -252,29 +276,41 @@ async def execute_rollouts(args: argparse.Namespace) -> None:
|
|
|
252
276
|
|
|
253
277
|
ops = build_ops(args.max_llm_calls, args.ops)
|
|
254
278
|
|
|
279
|
+
print(f"\n🚀 Starting {args.count} rollouts with {args.parallel} parallel workers...")
|
|
280
|
+
print(f"📊 Each rollout: {len(ops)} ops ({args.max_llm_calls} LLM calls)\n")
|
|
281
|
+
|
|
255
282
|
async with TaskAppClient(args.base_url, api_key=api_key, timeout=args.timeout) as client:
|
|
283
|
+
|
|
256
284
|
async def run_single(index: int) -> dict[str, Any]:
|
|
257
285
|
run_id = f"{args.run_id}-{index:03d}"
|
|
258
286
|
seed = args.seed + index * args.seed_stride
|
|
287
|
+
print(f"\n▶️ [{index + 1}/{args.count}] Starting rollout {run_id} (seed={seed})...")
|
|
288
|
+
|
|
259
289
|
request = build_rollout_request(
|
|
260
290
|
seed=seed,
|
|
261
291
|
run_id=run_id,
|
|
262
292
|
model=args.model,
|
|
263
293
|
inference_url=args.inference_url,
|
|
294
|
+
inference_api_key=groq_api_key,
|
|
264
295
|
ops=ops,
|
|
265
296
|
extra_headers=extra_headers,
|
|
266
297
|
trace_format=args.trace_format,
|
|
267
298
|
return_trace=True,
|
|
268
299
|
)
|
|
269
300
|
if args.max_policy_tokens is not None:
|
|
270
|
-
request.policy.config.update(
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
301
|
+
request.policy.config.update(
|
|
302
|
+
{
|
|
303
|
+
"max_completion_tokens": args.max_policy_tokens,
|
|
304
|
+
"max_tokens": args.max_policy_tokens,
|
|
305
|
+
}
|
|
306
|
+
)
|
|
274
307
|
|
|
275
308
|
try:
|
|
276
309
|
response = await client.rollout(request)
|
|
277
310
|
summary = analyse_rollout_response(response)
|
|
311
|
+
print(
|
|
312
|
+
f"\n✅ [{index + 1}/{args.count}] Completed {run_id} (outcome={summary.get('outcome_score', 'N/A')})"
|
|
313
|
+
)
|
|
278
314
|
return {
|
|
279
315
|
"ok": True,
|
|
280
316
|
"run_id": run_id,
|
|
@@ -283,6 +319,7 @@ async def execute_rollouts(args: argparse.Namespace) -> None:
|
|
|
283
319
|
"summary": summary,
|
|
284
320
|
}
|
|
285
321
|
except Exception as exc: # pragma: no cover - surface errors
|
|
322
|
+
print(f"\n❌ [{index + 1}/{args.count}] Failed {run_id}: {exc}")
|
|
286
323
|
return {
|
|
287
324
|
"ok": False,
|
|
288
325
|
"run_id": run_id,
|
|
@@ -302,6 +339,7 @@ async def execute_rollouts(args: argparse.Namespace) -> None:
|
|
|
302
339
|
successes = [item for item in results if item.get("ok")]
|
|
303
340
|
failures = [item for item in results if not item.get("ok")]
|
|
304
341
|
|
|
342
|
+
print(f"\n{'=' * 100}\n")
|
|
305
343
|
stats = summarise_runs([item["summary"] for item in successes])
|
|
306
344
|
print_summary(stats, run_details=successes, total_runs=args.count)
|
|
307
345
|
|
|
@@ -317,17 +355,43 @@ def parse_args() -> argparse.Namespace:
|
|
|
317
355
|
parser.add_argument("--base-url", default="http://localhost:8001", help="Task app base URL")
|
|
318
356
|
parser.add_argument("--api-key", help="Environment API key (or set via --env-file)")
|
|
319
357
|
parser.add_argument("--env-file", help="Path to .env file providing API keys")
|
|
320
|
-
parser.add_argument(
|
|
321
|
-
|
|
358
|
+
parser.add_argument(
|
|
359
|
+
"--model", default="gpt-4o-mini", help="Model identifier for the Crafter policy"
|
|
360
|
+
)
|
|
361
|
+
parser.add_argument(
|
|
362
|
+
"--inference-url",
|
|
363
|
+
default="https://api.openai.com",
|
|
364
|
+
help="Inference base URL for the policy",
|
|
365
|
+
)
|
|
322
366
|
parser.add_argument("--seed", type=int, default=42, help="Base seed for the first rollout")
|
|
323
|
-
parser.add_argument(
|
|
324
|
-
|
|
367
|
+
parser.add_argument(
|
|
368
|
+
"--seed-stride", type=int, default=1, help="Increment applied to the seed for each rollout"
|
|
369
|
+
)
|
|
370
|
+
parser.add_argument(
|
|
371
|
+
"--count", type=int, default=20, help="Number of rollout trajectories to execute"
|
|
372
|
+
)
|
|
325
373
|
parser.add_argument("--parallel", type=int, default=4, help="Maximum concurrent rollouts")
|
|
326
374
|
parser.add_argument("--ops", help="Comma-separated rollout ops (advanced override)")
|
|
327
|
-
parser.add_argument(
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
375
|
+
parser.add_argument(
|
|
376
|
+
"--max-llm-calls",
|
|
377
|
+
type=int,
|
|
378
|
+
default=20,
|
|
379
|
+
help="Number of agent/env pairs per rollout when --ops not provided",
|
|
380
|
+
)
|
|
381
|
+
parser.add_argument(
|
|
382
|
+
"--max-policy-tokens",
|
|
383
|
+
type=int,
|
|
384
|
+
help="Optional per-call token limit forwarded to the policy config",
|
|
385
|
+
)
|
|
386
|
+
parser.add_argument(
|
|
387
|
+
"--timeout", type=float, default=600.0, help="HTTP timeout (seconds) for task app requests"
|
|
388
|
+
)
|
|
389
|
+
parser.add_argument(
|
|
390
|
+
"--trace-format",
|
|
391
|
+
default="compact",
|
|
392
|
+
choices=["compact", "full"],
|
|
393
|
+
help="Trace format requested from the task app",
|
|
394
|
+
)
|
|
331
395
|
parser.add_argument("--run-id", default="batch-demo", help="Run ID prefix for rollouts")
|
|
332
396
|
parser.add_argument("--verbose", action="store_true", help="Print resolved configuration")
|
|
333
397
|
return parser.parse_args()
|
|
@@ -6,6 +6,7 @@ from __future__ import annotations
|
|
|
6
6
|
import argparse
|
|
7
7
|
import asyncio
|
|
8
8
|
import json
|
|
9
|
+
import os
|
|
9
10
|
from pathlib import Path
|
|
10
11
|
from typing import Any
|
|
11
12
|
|
|
@@ -29,6 +30,7 @@ def build_rollout_request(
|
|
|
29
30
|
run_id: str,
|
|
30
31
|
model: str,
|
|
31
32
|
inference_url: str,
|
|
33
|
+
inference_api_key: str,
|
|
32
34
|
ops: list[str],
|
|
33
35
|
return_trace: bool,
|
|
34
36
|
trace_format: str,
|
|
@@ -37,6 +39,7 @@ def build_rollout_request(
|
|
|
37
39
|
policy_config = {
|
|
38
40
|
"model": model,
|
|
39
41
|
"inference_url": inference_url,
|
|
42
|
+
"api_key": inference_api_key,
|
|
40
43
|
}
|
|
41
44
|
if max_policy_tokens is not None:
|
|
42
45
|
policy_config.update(
|
|
@@ -64,7 +67,11 @@ def build_rollout_request(
|
|
|
64
67
|
|
|
65
68
|
|
|
66
69
|
def summarise_rollout(response: Any) -> dict[str, Any]:
|
|
67
|
-
metrics =
|
|
70
|
+
metrics = (
|
|
71
|
+
response.metrics.model_dump()
|
|
72
|
+
if hasattr(response, "metrics")
|
|
73
|
+
else response.get("metrics", {})
|
|
74
|
+
)
|
|
68
75
|
return {
|
|
69
76
|
"run_id": getattr(response, "run_id", None) or response.get("run_id"),
|
|
70
77
|
"num_episodes": metrics.get("num_episodes"),
|
|
@@ -83,17 +90,25 @@ def summarise_trace(trace: Any) -> dict[str, Any]:
|
|
|
83
90
|
|
|
84
91
|
format_hint = "compact" if "events_count" in trace or "lm_calls" in trace else "full"
|
|
85
92
|
events_count = trace.get("events_count")
|
|
86
|
-
if
|
|
93
|
+
if (
|
|
94
|
+
events_count is None
|
|
95
|
+
and "event_history" in trace
|
|
96
|
+
and isinstance(trace["event_history"], list)
|
|
97
|
+
):
|
|
87
98
|
events_count = len(trace["event_history"])
|
|
88
99
|
messages_count = trace.get("messages_count")
|
|
89
|
-
if
|
|
90
|
-
|
|
100
|
+
if (
|
|
101
|
+
messages_count is None
|
|
102
|
+
and "markov_blanket_message_history" in trace
|
|
103
|
+
and isinstance(trace["markov_blanket_message_history"], list)
|
|
91
104
|
):
|
|
92
105
|
messages_count = len(trace["markov_blanket_message_history"])
|
|
93
106
|
|
|
94
107
|
metadata = trace.get("metadata") if isinstance(trace.get("metadata"), dict) else {}
|
|
95
108
|
lm_calls = trace.get("lm_calls") if isinstance(trace.get("lm_calls"), list) else []
|
|
96
|
-
decision_rewards =
|
|
109
|
+
decision_rewards = (
|
|
110
|
+
trace.get("decision_rewards") if isinstance(trace.get("decision_rewards"), list) else []
|
|
111
|
+
)
|
|
97
112
|
|
|
98
113
|
return {
|
|
99
114
|
"session_id": trace.get("session_id"),
|
|
@@ -215,11 +230,13 @@ def print_reward_summary(
|
|
|
215
230
|
if decision_rewards:
|
|
216
231
|
print(" Decision rewards:")
|
|
217
232
|
for entry in decision_rewards:
|
|
218
|
-
turn = entry.get(
|
|
219
|
-
ach_delta = entry.get(
|
|
220
|
-
unique_delta = entry.get(
|
|
221
|
-
achievements = entry.get(
|
|
222
|
-
print(
|
|
233
|
+
turn = entry.get("turn")
|
|
234
|
+
ach_delta = entry.get("ach_delta")
|
|
235
|
+
unique_delta = entry.get("unique_delta")
|
|
236
|
+
achievements = entry.get("achievements") or []
|
|
237
|
+
print(
|
|
238
|
+
f" turn={turn}, ach_delta={ach_delta}, unique_delta={unique_delta}, achievements={achievements}"
|
|
239
|
+
)
|
|
223
240
|
else:
|
|
224
241
|
print(" Decision rewards: none recorded")
|
|
225
242
|
|
|
@@ -242,16 +259,40 @@ def print_reward_summary(
|
|
|
242
259
|
|
|
243
260
|
|
|
244
261
|
async def main() -> None:
|
|
262
|
+
# Load .env file from current directory if it exists
|
|
263
|
+
env_file = Path.cwd() / ".env"
|
|
264
|
+
if env_file.exists():
|
|
265
|
+
from dotenv import load_dotenv
|
|
266
|
+
|
|
267
|
+
load_dotenv(env_file)
|
|
268
|
+
|
|
245
269
|
parser = argparse.ArgumentParser(description=__doc__)
|
|
246
|
-
parser.add_argument("--base-url", default="http://localhost:
|
|
247
|
-
parser.add_argument("--api-key",
|
|
270
|
+
parser.add_argument("--base-url", default="http://localhost:8001", help="Task app base URL")
|
|
271
|
+
parser.add_argument("--api-key", help="RL Environment API key (will prompt if not provided)")
|
|
272
|
+
parser.add_argument(
|
|
273
|
+
"--inference-api-key", help="Inference provider API key (will prompt if not provided)"
|
|
274
|
+
)
|
|
248
275
|
parser.add_argument("--seed", type=int, default=42, help="Environment seed")
|
|
249
276
|
parser.add_argument("--run-id", default="local-trace", help="Run identifier")
|
|
250
277
|
parser.add_argument("--model", default="gpt-4o-mini", help="OpenAI-compatible model id")
|
|
251
|
-
parser.add_argument(
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
parser.add_argument(
|
|
278
|
+
parser.add_argument(
|
|
279
|
+
"--inference-url", default="https://api.openai.com", help="Inference base URL (OpenAI/Groq)"
|
|
280
|
+
)
|
|
281
|
+
parser.add_argument(
|
|
282
|
+
"--ops", help="Comma-separated rollout ops (fallback: alternating agent/env)"
|
|
283
|
+
)
|
|
284
|
+
parser.add_argument(
|
|
285
|
+
"--max-llm-calls",
|
|
286
|
+
type=int,
|
|
287
|
+
default=1,
|
|
288
|
+
help="Number of agent/env pairs when --ops not supplied",
|
|
289
|
+
)
|
|
290
|
+
parser.add_argument(
|
|
291
|
+
"--max-policy-tokens",
|
|
292
|
+
type=int,
|
|
293
|
+
default=None,
|
|
294
|
+
help="Optional max token budget forwarded to policy",
|
|
295
|
+
)
|
|
255
296
|
parser.add_argument(
|
|
256
297
|
"--trace-format",
|
|
257
298
|
choices=["compact", "full"],
|
|
@@ -286,10 +327,69 @@ async def main() -> None:
|
|
|
286
327
|
)
|
|
287
328
|
args = parser.parse_args()
|
|
288
329
|
|
|
330
|
+
# Prompt for required parameters if not provided
|
|
331
|
+
base_url = args.base_url
|
|
332
|
+
if args.base_url == "http://localhost:8001":
|
|
333
|
+
print("\nTask app configuration:")
|
|
334
|
+
base_url_input = input(f"Task app base URL [http://localhost:8001]: ").strip()
|
|
335
|
+
base_url = base_url_input if base_url_input else "http://localhost:8001"
|
|
336
|
+
|
|
337
|
+
api_key = args.api_key or os.getenv("ENVIRONMENT_API_KEY")
|
|
338
|
+
if not api_key:
|
|
339
|
+
api_key = input("RL Environment API key (from ENVIRONMENT_API_KEY): ").strip()
|
|
340
|
+
if not api_key:
|
|
341
|
+
parser.error("RL Environment API key is required")
|
|
342
|
+
|
|
343
|
+
# Use Groq by default
|
|
344
|
+
model = "llama-3.3-70b-versatile"
|
|
345
|
+
inference_url = "https://api.groq.com/openai"
|
|
346
|
+
|
|
347
|
+
print("\nInference configuration (Groq):")
|
|
348
|
+
inference_api_key = args.inference_api_key or os.getenv("GROQ_API_KEY")
|
|
349
|
+
if not inference_api_key:
|
|
350
|
+
inference_api_key = input("Groq API key: ").strip()
|
|
351
|
+
if not inference_api_key:
|
|
352
|
+
parser.error("Groq API key is required")
|
|
353
|
+
|
|
354
|
+
# Save to .env for future use
|
|
355
|
+
env_path = Path.cwd() / ".env"
|
|
356
|
+
try:
|
|
357
|
+
# Read existing .env
|
|
358
|
+
existing_lines = []
|
|
359
|
+
if env_path.exists():
|
|
360
|
+
existing_lines = env_path.read_text().splitlines()
|
|
361
|
+
|
|
362
|
+
# Check if GROQ_API_KEY already exists
|
|
363
|
+
key_exists = any(line.strip().startswith("GROQ_API_KEY=") for line in existing_lines)
|
|
364
|
+
|
|
365
|
+
if not key_exists:
|
|
366
|
+
# Append to .env
|
|
367
|
+
with open(env_path, "a") as f:
|
|
368
|
+
if existing_lines and not existing_lines[-1].strip():
|
|
369
|
+
# File exists and last line is not empty
|
|
370
|
+
pass
|
|
371
|
+
elif existing_lines:
|
|
372
|
+
# Add newline before appending
|
|
373
|
+
f.write("\n")
|
|
374
|
+
f.write(f"GROQ_API_KEY={inference_api_key}\n")
|
|
375
|
+
print(f"[INFO] Saved GROQ_API_KEY to {env_path}")
|
|
376
|
+
except Exception as e:
|
|
377
|
+
print(f"[WARN] Could not save GROQ_API_KEY to .env: {e}")
|
|
378
|
+
|
|
379
|
+
print("\nRollout configuration:")
|
|
380
|
+
max_llm_calls = args.max_llm_calls
|
|
381
|
+
if args.max_llm_calls == 1:
|
|
382
|
+
max_llm_calls_input = input(f"Max LLM calls [10]: ").strip()
|
|
383
|
+
max_llm_calls = int(max_llm_calls_input) if max_llm_calls_input else 10
|
|
384
|
+
|
|
385
|
+
# Override args with prompted values
|
|
386
|
+
args.base_url = base_url
|
|
387
|
+
args.max_llm_calls = max_llm_calls
|
|
388
|
+
|
|
289
389
|
ops = ensure_ops(args.ops, args.max_llm_calls)
|
|
290
390
|
return_trace = not args.no_trace
|
|
291
391
|
|
|
292
|
-
async with TaskAppClient(args.base_url, api_key=
|
|
392
|
+
async with TaskAppClient(args.base_url, api_key=api_key, timeout=args.timeout) as client:
|
|
293
393
|
try:
|
|
294
394
|
print(f"Fetching task_info for seed {args.seed}…")
|
|
295
395
|
task_info = await client.task_info(seeds=[args.seed])
|
|
@@ -302,8 +402,9 @@ async def main() -> None:
|
|
|
302
402
|
request = build_rollout_request(
|
|
303
403
|
seed=args.seed,
|
|
304
404
|
run_id=args.run_id,
|
|
305
|
-
model=
|
|
306
|
-
inference_url=
|
|
405
|
+
model=model,
|
|
406
|
+
inference_url=inference_url,
|
|
407
|
+
inference_api_key=inference_api_key,
|
|
307
408
|
ops=ops,
|
|
308
409
|
return_trace=return_trace,
|
|
309
410
|
trace_format=args.trace_format,
|
|
@@ -350,7 +451,11 @@ async def main() -> None:
|
|
|
350
451
|
"Tip: export TASKAPP_TRACING_ENABLED=1 and optionally TASKAPP_SFT_OUTPUT_DIR before running `uvx synth-ai serve …` to persist traces/SFT."
|
|
351
452
|
)
|
|
352
453
|
except httpx.HTTPStatusError as exc:
|
|
353
|
-
detail =
|
|
454
|
+
detail = (
|
|
455
|
+
exc.response.json()
|
|
456
|
+
if exc.response.headers.get("content-type", "").startswith("application/json")
|
|
457
|
+
else exc.response.text
|
|
458
|
+
)
|
|
354
459
|
print(f"HTTP error {exc.response.status_code}: {detail}", file=sys.stderr)
|
|
355
460
|
if exc.response.status_code in (401, 503):
|
|
356
461
|
print(
|
|
@@ -11,6 +11,8 @@ from typing import Any, Dict
|
|
|
11
11
|
import tomllib
|
|
12
12
|
import requests
|
|
13
13
|
|
|
14
|
+
from synth_ai.config.base_url import PROD_BASE_URL_DEFAULT
|
|
15
|
+
|
|
14
16
|
|
|
15
17
|
def _load_toml(path: Path) -> Dict[str, Any]:
|
|
16
18
|
if not path.exists():
|
|
@@ -21,11 +23,23 @@ def _load_toml(path: Path) -> Dict[str, Any]:
|
|
|
21
23
|
|
|
22
24
|
|
|
23
25
|
def main() -> None:
|
|
24
|
-
p = argparse.ArgumentParser(
|
|
25
|
-
|
|
26
|
+
p = argparse.ArgumentParser(
|
|
27
|
+
description="Create clustered RL training job via backend RL endpoint"
|
|
28
|
+
)
|
|
29
|
+
p.add_argument(
|
|
30
|
+
"--backend", default=os.getenv("BACKEND_BASE_URL", f"{PROD_BASE_URL_DEFAULT}/api")
|
|
31
|
+
)
|
|
26
32
|
p.add_argument("--config", required=True, help="Path to RL TOML config")
|
|
27
|
-
p.add_argument(
|
|
28
|
-
|
|
33
|
+
p.add_argument(
|
|
34
|
+
"--task-url",
|
|
35
|
+
default=os.getenv("TASK_APP_URL", ""),
|
|
36
|
+
help="Override task service URL (or set TASK_APP_URL)",
|
|
37
|
+
)
|
|
38
|
+
p.add_argument(
|
|
39
|
+
"--idempotency",
|
|
40
|
+
default=os.getenv("RL_IDEMPOTENCY_KEY", ""),
|
|
41
|
+
help="Optional Idempotency-Key header value",
|
|
42
|
+
)
|
|
29
43
|
args = p.parse_args()
|
|
30
44
|
|
|
31
45
|
cfg_path = Path(args.config).expanduser()
|
|
@@ -36,9 +50,16 @@ def main() -> None:
|
|
|
36
50
|
# Resolve task app base URL for the job
|
|
37
51
|
cli_task_url = (args.task_url or "").strip()
|
|
38
52
|
env_task_url = (os.getenv("TASK_APP_URL") or "").strip()
|
|
39
|
-
task_url =
|
|
53
|
+
task_url = (
|
|
54
|
+
cli_task_url
|
|
55
|
+
or env_task_url
|
|
56
|
+
or ((services.get("task_url") or "").strip() if isinstance(services, dict) else "")
|
|
57
|
+
)
|
|
40
58
|
if not task_url:
|
|
41
|
-
print(
|
|
59
|
+
print(
|
|
60
|
+
"Missing task service URL. Provide --task-url or set TASK_APP_URL or services.task_url in TOML",
|
|
61
|
+
file=sys.stderr,
|
|
62
|
+
)
|
|
42
63
|
sys.exit(2)
|
|
43
64
|
|
|
44
65
|
# TOML-only model selection validation
|
|
@@ -46,7 +67,10 @@ def main() -> None:
|
|
|
46
67
|
has_source = bool((model_cfg.get("source") or "").strip())
|
|
47
68
|
has_base = bool((model_cfg.get("base") or "").strip())
|
|
48
69
|
if has_source == has_base:
|
|
49
|
-
print(
|
|
70
|
+
print(
|
|
71
|
+
"Model selection must specify exactly one of [model].source or [model].base in TOML",
|
|
72
|
+
file=sys.stderr,
|
|
73
|
+
)
|
|
50
74
|
sys.exit(2)
|
|
51
75
|
|
|
52
76
|
# Build create-job payload. Send full TOML under data.config, plus endpoint_base_url.
|
|
@@ -11,10 +11,17 @@ import sys
|
|
|
11
11
|
|
|
12
12
|
import httpx
|
|
13
13
|
|
|
14
|
+
|
|
14
15
|
def check_health(base_url: str, api_key: str) -> None:
|
|
15
16
|
try:
|
|
16
|
-
resp = httpx.get(
|
|
17
|
-
|
|
17
|
+
resp = httpx.get(
|
|
18
|
+
f"{base_url.rstrip('/')}/health", headers={"X-API-Key": api_key}, timeout=10.0
|
|
19
|
+
)
|
|
20
|
+
data = (
|
|
21
|
+
resp.json()
|
|
22
|
+
if resp.headers.get("content-type", "").startswith("application/json")
|
|
23
|
+
else resp.text
|
|
24
|
+
)
|
|
18
25
|
if resp.status_code != 200:
|
|
19
26
|
print(f"warning: /health returned {resp.status_code}: {data}")
|
|
20
27
|
else:
|
|
@@ -22,6 +29,7 @@ def check_health(base_url: str, api_key: str) -> None:
|
|
|
22
29
|
except Exception as exc:
|
|
23
30
|
print(f"warning: failed to call /health: {exc}")
|
|
24
31
|
|
|
32
|
+
|
|
25
33
|
from synth_ai.task import (
|
|
26
34
|
RolloutEnvSpec,
|
|
27
35
|
RolloutPolicySpec,
|
|
@@ -79,8 +87,14 @@ def summarise(response) -> dict[str, any]:
|
|
|
79
87
|
|
|
80
88
|
async def main() -> None:
|
|
81
89
|
parser = argparse.ArgumentParser(description=__doc__)
|
|
82
|
-
parser.add_argument(
|
|
83
|
-
|
|
90
|
+
parser.add_argument(
|
|
91
|
+
"--base-url",
|
|
92
|
+
default=None,
|
|
93
|
+
help="Remote task app base URL (e.g., https://xyz.modal.run); defaults to TASK_APP_BASE_URL env",
|
|
94
|
+
)
|
|
95
|
+
parser.add_argument(
|
|
96
|
+
"--api-key", required=True, help="Environment API key for the remote task app"
|
|
97
|
+
)
|
|
84
98
|
parser.add_argument("--seed", type=int, default=42)
|
|
85
99
|
parser.add_argument("--run-id", default="remote-demo")
|
|
86
100
|
parser.add_argument("--model", default="gpt-4o-mini")
|
|
@@ -89,9 +103,9 @@ async def main() -> None:
|
|
|
89
103
|
parser.add_argument("--max-policy-tokens", type=int, default=None)
|
|
90
104
|
args = parser.parse_args()
|
|
91
105
|
|
|
92
|
-
base_url = args.base_url or os.getenv(
|
|
106
|
+
base_url = args.base_url or os.getenv("TASK_APP_BASE_URL")
|
|
93
107
|
if not base_url:
|
|
94
|
-
parser.error(
|
|
108
|
+
parser.error("Missing --base-url (and TASK_APP_BASE_URL not set).")
|
|
95
109
|
|
|
96
110
|
request = build_request(
|
|
97
111
|
run_id=args.run_id,
|
|
@@ -114,14 +128,27 @@ async def main() -> None:
|
|
|
114
128
|
print(json.dumps(summarise(response), indent=2))
|
|
115
129
|
print(f"Ops executed: {request.ops}")
|
|
116
130
|
except httpx.HTTPStatusError as exc:
|
|
117
|
-
detail =
|
|
131
|
+
detail = (
|
|
132
|
+
exc.response.json()
|
|
133
|
+
if exc.response.headers.get("content-type", "").startswith("application/json")
|
|
134
|
+
else exc.response.text
|
|
135
|
+
)
|
|
118
136
|
print(f"HTTP error {exc.response.status_code}: {detail}", file=sys.stderr)
|
|
119
137
|
if exc.response.status_code in (401, 403):
|
|
120
|
-
print(
|
|
138
|
+
print(
|
|
139
|
+
"Hint: check --api-key and ensure the remote deployment expects that value.",
|
|
140
|
+
file=sys.stderr,
|
|
141
|
+
)
|
|
121
142
|
if exc.response.status_code == 404:
|
|
122
|
-
print(
|
|
143
|
+
print(
|
|
144
|
+
"Hint: verify the --base-url includes the correct path (should be the root of the task app).",
|
|
145
|
+
file=sys.stderr,
|
|
146
|
+
)
|
|
123
147
|
if exc.response.status_code == 500:
|
|
124
|
-
print(
|
|
148
|
+
print(
|
|
149
|
+
"Hint: remote rollout failed server-side; inspect the deployment logs (Modal dashboard/logs).",
|
|
150
|
+
file=sys.stderr,
|
|
151
|
+
)
|
|
125
152
|
raise
|
|
126
153
|
|
|
127
154
|
|