synth-ai 0.2.9.dev5__py3-none-any.whl → 0.2.9.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/common_old/backend.py +0 -1
- examples/crafter_debug_render.py +15 -6
- examples/evals_old/compare_models.py +1 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +6 -2
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +4 -4
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +4 -3
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +6 -2
- examples/finetuning_old/synth_qwen_v1/finetune.py +1 -1
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +4 -4
- examples/finetuning_old/synth_qwen_v1/infer.py +1 -2
- examples/finetuning_old/synth_qwen_v1/poll.py +4 -2
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +8 -8
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +5 -4
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +11 -8
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +17 -12
- examples/finetuning_old/synth_qwen_v1/upload_data.py +1 -1
- examples/finetuning_old/synth_qwen_v1/util.py +7 -2
- examples/rl/configs/eval_base_qwen.toml +1 -1
- examples/rl/configs/rl_from_base_qwen17.toml +1 -1
- examples/rl/download_dataset.py +26 -10
- examples/rl/run_eval.py +17 -15
- examples/rl/run_rl_and_save.py +24 -7
- examples/rl/task_app/math_single_step.py +128 -11
- examples/rl/task_app/math_task_app.py +11 -3
- examples/rl_old/task_app.py +222 -53
- examples/warming_up_to_rl/analyze_trace_db.py +7 -5
- examples/warming_up_to_rl/export_trace_sft.py +141 -16
- examples/warming_up_to_rl/groq_test.py +11 -4
- examples/warming_up_to_rl/manage_secrets.py +15 -6
- examples/warming_up_to_rl/readme.md +9 -2
- examples/warming_up_to_rl/run_eval.py +108 -30
- examples/warming_up_to_rl/run_fft_and_save.py +128 -52
- examples/warming_up_to_rl/run_local_rollout.py +87 -36
- examples/warming_up_to_rl/run_local_rollout_modal.py +113 -25
- examples/warming_up_to_rl/run_local_rollout_parallel.py +80 -16
- examples/warming_up_to_rl/run_local_rollout_traced.py +125 -20
- examples/warming_up_to_rl/run_rl_and_save.py +31 -7
- examples/warming_up_to_rl/run_rollout_remote.py +37 -10
- examples/warming_up_to_rl/task_app/grpo_crafter.py +90 -27
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +9 -27
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +46 -108
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +50 -17
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +35 -21
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +8 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +29 -26
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +17 -13
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +106 -63
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +82 -84
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +76 -59
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +43 -49
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +5 -15
- synth_ai/__init__.py +1 -0
- synth_ai/api/train/builders.py +34 -10
- synth_ai/api/train/cli.py +172 -32
- synth_ai/api/train/config_finder.py +59 -4
- synth_ai/api/train/env_resolver.py +32 -14
- synth_ai/api/train/pollers.py +11 -3
- synth_ai/api/train/task_app.py +4 -1
- synth_ai/api/train/utils.py +20 -4
- synth_ai/cli/__init__.py +11 -4
- synth_ai/cli/balance.py +1 -1
- synth_ai/cli/demo.py +19 -5
- synth_ai/cli/rl_demo.py +75 -16
- synth_ai/cli/root.py +116 -37
- synth_ai/cli/task_apps.py +1276 -186
- synth_ai/cli/traces.py +1 -0
- synth_ai/cli/turso.py +73 -0
- synth_ai/core/experiment.py +0 -2
- synth_ai/demo_registry.py +67 -30
- synth_ai/demos/core/cli.py +493 -164
- synth_ai/demos/demo_task_apps/core.py +50 -6
- synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +2 -3
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +36 -28
- synth_ai/demos/demo_task_apps/math/_common.py +1 -2
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +0 -2
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +168 -65
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
- synth_ai/environments/examples/bandit/engine.py +12 -4
- synth_ai/environments/examples/bandit/taskset.py +4 -4
- synth_ai/environments/reproducibility/tree.py +3 -1
- synth_ai/environments/service/core_routes.py +6 -2
- synth_ai/evals/base.py +0 -2
- synth_ai/experimental/synth_oss.py +11 -12
- synth_ai/handshake.py +3 -1
- synth_ai/http_client.py +31 -7
- synth_ai/inference/__init__.py +0 -2
- synth_ai/inference/client.py +8 -4
- synth_ai/jobs/client.py +40 -10
- synth_ai/learning/client.py +33 -8
- synth_ai/learning/config.py +0 -2
- synth_ai/learning/constants.py +0 -2
- synth_ai/learning/ft_client.py +6 -3
- synth_ai/learning/health.py +9 -2
- synth_ai/learning/jobs.py +17 -5
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +1 -3
- synth_ai/learning/prompts/random_search.py +4 -1
- synth_ai/learning/prompts/run_random_search_banking77.py +6 -1
- synth_ai/learning/rl_client.py +42 -14
- synth_ai/learning/sse.py +0 -2
- synth_ai/learning/validators.py +6 -2
- synth_ai/lm/caching/ephemeral.py +1 -3
- synth_ai/lm/core/exceptions.py +0 -2
- synth_ai/lm/core/main.py +13 -1
- synth_ai/lm/core/synth_models.py +0 -1
- synth_ai/lm/core/vendor_clients.py +4 -2
- synth_ai/lm/overrides.py +2 -2
- synth_ai/lm/vendors/core/anthropic_api.py +7 -7
- synth_ai/lm/vendors/core/openai_api.py +2 -0
- synth_ai/lm/vendors/openai_standard.py +3 -1
- synth_ai/lm/vendors/openai_standard_responses.py +6 -3
- synth_ai/lm/vendors/supported/custom_endpoint.py +1 -3
- synth_ai/lm/vendors/synth_client.py +37 -10
- synth_ai/rl/__init__.py +0 -1
- synth_ai/rl/contracts.py +0 -2
- synth_ai/rl/env_keys.py +6 -1
- synth_ai/task/__init__.py +1 -0
- synth_ai/task/apps/__init__.py +11 -11
- synth_ai/task/auth.py +29 -17
- synth_ai/task/client.py +3 -1
- synth_ai/task/contracts.py +1 -0
- synth_ai/task/datasets.py +3 -1
- synth_ai/task/errors.py +3 -2
- synth_ai/task/health.py +0 -2
- synth_ai/task/json.py +0 -1
- synth_ai/task/proxy.py +2 -5
- synth_ai/task/rubrics.py +9 -3
- synth_ai/task/server.py +31 -5
- synth_ai/task/tracing_utils.py +8 -3
- synth_ai/task/validators.py +0 -1
- synth_ai/task/vendors.py +0 -1
- synth_ai/tracing_v3/db_config.py +26 -1
- synth_ai/tracing_v3/decorators.py +1 -0
- synth_ai/tracing_v3/examples/basic_usage.py +3 -2
- synth_ai/tracing_v3/hooks.py +2 -0
- synth_ai/tracing_v3/replica_sync.py +1 -0
- synth_ai/tracing_v3/session_tracer.py +24 -3
- synth_ai/tracing_v3/storage/base.py +4 -1
- synth_ai/tracing_v3/storage/factory.py +0 -1
- synth_ai/tracing_v3/turso/manager.py +102 -38
- synth_ai/tracing_v3/turso/models.py +4 -1
- synth_ai/tracing_v3/utils.py +1 -0
- synth_ai/v0/tracing/upload.py +32 -135
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/METADATA +1 -1
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/RECORD +154 -154
- synth_ai/install_sqld.sh +0 -40
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/top_level.txt +0 -0
|
@@ -233,7 +233,9 @@ def _normalise_tool_calls(tool_calls: list[dict[str, Any]] | None) -> list[dict[
|
|
|
233
233
|
continue
|
|
234
234
|
entry = dict(call)
|
|
235
235
|
|
|
236
|
-
func_payload: dict[str, Any] | None =
|
|
236
|
+
func_payload: dict[str, Any] | None = (
|
|
237
|
+
entry.get("function") if isinstance(entry.get("function"), dict) else None
|
|
238
|
+
)
|
|
237
239
|
name = entry.get("name") or (func_payload.get("name") if func_payload else None) or "tool"
|
|
238
240
|
|
|
239
241
|
args = None
|
|
@@ -355,7 +357,10 @@ def build_sft_dataset(
|
|
|
355
357
|
if not assistant_tool_calls:
|
|
356
358
|
assistant_tool_calls = _normalise_tool_calls(record.get("output_tool_calls"))
|
|
357
359
|
|
|
358
|
-
assistant_message: dict[str, Any] = {
|
|
360
|
+
assistant_message: dict[str, Any] = {
|
|
361
|
+
"role": "assistant",
|
|
362
|
+
"content": assistant_content or "",
|
|
363
|
+
}
|
|
359
364
|
if assistant_tool_calls:
|
|
360
365
|
assistant_message["tool_calls"] = assistant_tool_calls
|
|
361
366
|
|
|
@@ -426,27 +431,141 @@ def _validate_dataset(records: list[dict[str, Any]]) -> None:
|
|
|
426
431
|
raise SystemExit(f"Validation error while exporting dataset:\n - {summary}")
|
|
427
432
|
|
|
428
433
|
|
|
434
|
+
def _find_trace_database() -> Path | None:
|
|
435
|
+
"""Automatically discover the trace database in common locations."""
|
|
436
|
+
|
|
437
|
+
# Check for demo directory from state
|
|
438
|
+
try:
|
|
439
|
+
state_path = Path.home() / ".synth-ai" / "demo.json"
|
|
440
|
+
if state_path.exists():
|
|
441
|
+
import json
|
|
442
|
+
|
|
443
|
+
with state_path.open() as f:
|
|
444
|
+
data = json.load(f)
|
|
445
|
+
demo_dir = data.get("DEMO_DIR")
|
|
446
|
+
if demo_dir:
|
|
447
|
+
candidate = Path(demo_dir) / "traces" / "v3" / "synth_ai.db"
|
|
448
|
+
if candidate.exists():
|
|
449
|
+
return candidate
|
|
450
|
+
except Exception:
|
|
451
|
+
pass
|
|
452
|
+
|
|
453
|
+
# Search upward from current directory
|
|
454
|
+
cwd = Path.cwd()
|
|
455
|
+
for parent in [cwd] + list(cwd.parents):
|
|
456
|
+
candidate = parent / "traces" / "v3" / "synth_ai.db"
|
|
457
|
+
if candidate.exists():
|
|
458
|
+
return candidate
|
|
459
|
+
|
|
460
|
+
# Check standard locations
|
|
461
|
+
standard_locations = [
|
|
462
|
+
Path("traces/v3/synth_ai.db"),
|
|
463
|
+
Path("../traces/v3/synth_ai.db"),
|
|
464
|
+
Path.home() / "synth-ai" / "traces" / "v3" / "synth_ai.db",
|
|
465
|
+
]
|
|
466
|
+
|
|
467
|
+
for location in standard_locations:
|
|
468
|
+
try:
|
|
469
|
+
if location.exists():
|
|
470
|
+
return location.resolve()
|
|
471
|
+
except Exception:
|
|
472
|
+
continue
|
|
473
|
+
|
|
474
|
+
return None
|
|
475
|
+
|
|
476
|
+
|
|
429
477
|
def main() -> None:
|
|
430
478
|
parser = argparse.ArgumentParser(description=__doc__)
|
|
431
|
-
parser.add_argument("--db", type=Path, default=
|
|
432
|
-
parser.add_argument(
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
479
|
+
parser.add_argument("--db", type=Path, default=None, help="Path to tracing_v3 SQLite DB")
|
|
480
|
+
parser.add_argument(
|
|
481
|
+
"--output",
|
|
482
|
+
type=Path,
|
|
483
|
+
required=False,
|
|
484
|
+
help="Destination JSONL path for the exported dataset",
|
|
485
|
+
)
|
|
486
|
+
parser.add_argument(
|
|
487
|
+
"--model",
|
|
488
|
+
action="append",
|
|
489
|
+
dest="models",
|
|
490
|
+
help="Restrict to sessions whose dominant model matches (repeatable)",
|
|
491
|
+
)
|
|
492
|
+
parser.add_argument(
|
|
493
|
+
"--provider",
|
|
494
|
+
action="append",
|
|
495
|
+
dest="providers",
|
|
496
|
+
help="Restrict to sessions whose dominant provider matches (repeatable)",
|
|
497
|
+
)
|
|
498
|
+
parser.add_argument(
|
|
499
|
+
"--min-unique", type=int, default=None, help="Minimum unique achievements per session"
|
|
500
|
+
)
|
|
501
|
+
parser.add_argument(
|
|
502
|
+
"--max-unique", type=int, default=None, help="Maximum unique achievements per session"
|
|
503
|
+
)
|
|
437
504
|
parser.add_argument(
|
|
438
505
|
"--exclude-achievement",
|
|
439
506
|
action="append",
|
|
440
507
|
dest="exclude_achievements",
|
|
441
508
|
help="Achievements to ignore when evaluating --min-unique/--max-unique (repeatable)",
|
|
442
509
|
)
|
|
443
|
-
parser.add_argument(
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
510
|
+
parser.add_argument(
|
|
511
|
+
"--require-achievement",
|
|
512
|
+
action="append",
|
|
513
|
+
dest="required_achievements",
|
|
514
|
+
help="Require these outcome achievements (repeatable)",
|
|
515
|
+
)
|
|
516
|
+
parser.add_argument(
|
|
517
|
+
"--min-outcome-reward",
|
|
518
|
+
type=float,
|
|
519
|
+
default=None,
|
|
520
|
+
help="Minimum total outcome reward per session",
|
|
521
|
+
)
|
|
522
|
+
parser.add_argument(
|
|
523
|
+
"--max-outcome-reward",
|
|
524
|
+
type=float,
|
|
525
|
+
default=None,
|
|
526
|
+
help="Maximum total outcome reward per session",
|
|
527
|
+
)
|
|
528
|
+
parser.add_argument(
|
|
529
|
+
"--event-reward",
|
|
530
|
+
action="append",
|
|
531
|
+
dest="event_reward_filters",
|
|
532
|
+
help="Require reward_type[:min_total] in event_rewards (repeatable)",
|
|
533
|
+
)
|
|
534
|
+
parser.add_argument(
|
|
535
|
+
"--limit", type=int, default=None, help="Maximum number of examples to emit"
|
|
536
|
+
)
|
|
448
537
|
args = parser.parse_args()
|
|
449
538
|
|
|
539
|
+
# Auto-discover database if not specified
|
|
540
|
+
db_path = args.db
|
|
541
|
+
if db_path is None:
|
|
542
|
+
db_path = _find_trace_database()
|
|
543
|
+
if db_path:
|
|
544
|
+
print(f"Found trace database: {db_path}")
|
|
545
|
+
else:
|
|
546
|
+
print("\nTrace database configuration:")
|
|
547
|
+
db_input = input("Trace database path [traces/v3/synth_ai.db]: ").strip()
|
|
548
|
+
db_path = Path(db_input) if db_input else Path("traces/v3/synth_ai.db")
|
|
549
|
+
|
|
550
|
+
if not db_path.exists():
|
|
551
|
+
print(f"Database not found: {db_path}", file=sys.stderr)
|
|
552
|
+
raise SystemExit(1)
|
|
553
|
+
|
|
554
|
+
output_path = args.output
|
|
555
|
+
if not output_path:
|
|
556
|
+
output_path = Path("ft_data/crafter_traces.jsonl")
|
|
557
|
+
print(f"Output will be written to: {output_path.resolve()}")
|
|
558
|
+
|
|
559
|
+
min_unique = args.min_unique
|
|
560
|
+
if min_unique is None:
|
|
561
|
+
min_unique = 0 # Default to including all traces
|
|
562
|
+
print(f"Minimum unique achievements filter: {min_unique} (all traces)")
|
|
563
|
+
|
|
564
|
+
# Override args with prompted values
|
|
565
|
+
args.db = db_path
|
|
566
|
+
args.output = output_path
|
|
567
|
+
args.min_unique = min_unique
|
|
568
|
+
|
|
450
569
|
if not args.db.exists():
|
|
451
570
|
print(f"Database not found: {args.db}", file=sys.stderr)
|
|
452
571
|
raise SystemExit(1)
|
|
@@ -488,7 +607,11 @@ def main() -> None:
|
|
|
488
607
|
|
|
489
608
|
outcome = outcome_data.get(session_id)
|
|
490
609
|
total_reward = outcome["total_reward"] if outcome else 0.0
|
|
491
|
-
final_achievements =
|
|
610
|
+
final_achievements = (
|
|
611
|
+
outcome["achievements"]
|
|
612
|
+
if outcome
|
|
613
|
+
else session_final_achievements.get(session_id, set())
|
|
614
|
+
)
|
|
492
615
|
|
|
493
616
|
if args.min_outcome_reward is not None and total_reward < args.min_outcome_reward:
|
|
494
617
|
continue
|
|
@@ -522,7 +645,9 @@ def main() -> None:
|
|
|
522
645
|
)
|
|
523
646
|
|
|
524
647
|
if not dataset:
|
|
525
|
-
print(
|
|
648
|
+
print(
|
|
649
|
+
"No rollout steps matched the filters (after session selection).", file=sys.stderr
|
|
650
|
+
)
|
|
526
651
|
raise SystemExit(1)
|
|
527
652
|
|
|
528
653
|
_validate_dataset(dataset)
|
|
@@ -530,7 +655,7 @@ def main() -> None:
|
|
|
530
655
|
session_ids = {item.get("metadata", {}).get("session_id") for item in dataset}
|
|
531
656
|
session_ids.discard(None)
|
|
532
657
|
print(
|
|
533
|
-
f"Wrote {len(dataset)} examples from {len(session_ids)} session(s) -> {args.output}",
|
|
658
|
+
f"Wrote {len(dataset)} examples from {len(session_ids)} session(s) -> {args.output.resolve()}",
|
|
534
659
|
file=sys.stderr,
|
|
535
660
|
)
|
|
536
661
|
finally:
|
|
@@ -63,13 +63,21 @@ async def run(args: argparse.Namespace) -> None:
|
|
|
63
63
|
response = await client.rollout(request)
|
|
64
64
|
print("rollout.metrics →", to_jsonable(response.metrics.model_dump()))
|
|
65
65
|
for idx, step in enumerate(response.trajectories[0].steps, start=1):
|
|
66
|
-
print(
|
|
66
|
+
print(
|
|
67
|
+
f"step[{idx}] tool_calls={step.tool_calls} reward={step.reward} info={to_jsonable(step.info)}"
|
|
68
|
+
)
|
|
67
69
|
|
|
68
70
|
|
|
69
71
|
def _parse_args() -> argparse.Namespace:
|
|
70
72
|
parser = argparse.ArgumentParser(description=__doc__)
|
|
71
|
-
parser.add_argument(
|
|
72
|
-
|
|
73
|
+
parser.add_argument(
|
|
74
|
+
"--base-url", default=os.getenv("TASK_APP_BASE_URL", "http://localhost:8000")
|
|
75
|
+
)
|
|
76
|
+
parser.add_argument(
|
|
77
|
+
"--api-key",
|
|
78
|
+
default=os.getenv("TASK_APP_API_KEY"),
|
|
79
|
+
required=os.getenv("TASK_APP_API_KEY") is None,
|
|
80
|
+
)
|
|
73
81
|
parser.add_argument("--model", default=os.getenv("GROQ_MODEL", "groq/mixtral-8x7b"))
|
|
74
82
|
parser.add_argument("--inference-url", default=os.getenv("TASK_APP_INFERENCE_URL"))
|
|
75
83
|
parser.add_argument("--seed", type=int, default=int(os.getenv("CRAFTER_TEST_SEED", "42")))
|
|
@@ -85,4 +93,3 @@ def main() -> None:
|
|
|
85
93
|
|
|
86
94
|
if __name__ == "__main__":
|
|
87
95
|
main()
|
|
88
|
-
|
|
@@ -34,7 +34,9 @@ def write_temp_env(kv: Dict[str, str]) -> Path:
|
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
def run(cmd: str) -> Tuple[int, str]:
|
|
37
|
-
proc = subprocess.run(
|
|
37
|
+
proc = subprocess.run(
|
|
38
|
+
cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
|
39
|
+
)
|
|
38
40
|
return proc.returncode, proc.stdout
|
|
39
41
|
|
|
40
42
|
|
|
@@ -44,11 +46,14 @@ def ensure_secret(secret_name: str, kv: Dict[str, str]) -> None:
|
|
|
44
46
|
return
|
|
45
47
|
# Prefer passing KEY=VALUE pairs to avoid Typer --env-file bug under some shells
|
|
46
48
|
kv_args = " ".join([f"{shlex.quote(k)}={shlex.quote(v)}" for k, v in kv.items()])
|
|
49
|
+
|
|
47
50
|
# Try plain modal first; fallback to uv run modal
|
|
48
51
|
def _create() -> Tuple[int, str]:
|
|
49
52
|
return run(f"modal secret create {shlex.quote(secret_name)} {kv_args}")
|
|
53
|
+
|
|
50
54
|
def _delete() -> Tuple[int, str]:
|
|
51
55
|
return run(f"printf 'y\n' | modal secret delete {shlex.quote(secret_name)}")
|
|
56
|
+
|
|
52
57
|
rc, out = _create()
|
|
53
58
|
if rc != 0:
|
|
54
59
|
# Fallback: use uv run modal
|
|
@@ -70,8 +75,12 @@ def ensure_secret(secret_name: str, kv: Dict[str, str]) -> None:
|
|
|
70
75
|
|
|
71
76
|
|
|
72
77
|
def main() -> None:
|
|
73
|
-
ap = argparse.ArgumentParser(
|
|
74
|
-
|
|
78
|
+
ap = argparse.ArgumentParser(
|
|
79
|
+
description="Sync .env keys into Modal secret bundles for the task app"
|
|
80
|
+
)
|
|
81
|
+
ap.add_argument(
|
|
82
|
+
"--env-path", default=str(Path(__file__).parent / ".env"), help="Path to .env with keys"
|
|
83
|
+
)
|
|
75
84
|
args = ap.parse_args()
|
|
76
85
|
|
|
77
86
|
env = load_env_file(Path(args.env_path))
|
|
@@ -105,7 +114,9 @@ def main() -> None:
|
|
|
105
114
|
}
|
|
106
115
|
|
|
107
116
|
# Optional: backend key (not mounted by task app today, but useful to keep consistent)
|
|
108
|
-
synth_secret =
|
|
117
|
+
synth_secret = (
|
|
118
|
+
{"SYNTH_API_KEY": env.get("SYNTH_API_KEY", "")} if env.get("SYNTH_API_KEY") else {}
|
|
119
|
+
)
|
|
109
120
|
|
|
110
121
|
ensure_secret("crafter-environment-sdk", env_secret)
|
|
111
122
|
ensure_secret("groq-api-key", groq_secret)
|
|
@@ -123,5 +134,3 @@ if __name__ == "__main__":
|
|
|
123
134
|
except Exception as e:
|
|
124
135
|
print(f"[error] {type(e).__name__}: {e}")
|
|
125
136
|
sys.exit(1)
|
|
126
|
-
|
|
127
|
-
|
|
@@ -87,9 +87,16 @@ Evaluation scripts auto-load `.env` values. Update TOMLs under `configs/` with t
|
|
|
87
87
|
|
|
88
88
|
## 4. Tracing and SFT Dataset Export
|
|
89
89
|
|
|
90
|
-
1. Serve the task app with tracing enabled (see Section 2)
|
|
90
|
+
1. Serve the task app with tracing enabled (see Section 2). Optionally, run the traced rollout helper against the running server:
|
|
91
91
|
```bash
|
|
92
|
-
uv run python examples/warming_up_to_rl/run_local_rollout_traced.py
|
|
92
|
+
uv run python examples/warming_up_to_rl/run_local_rollout_traced.py \
|
|
93
|
+
--base-url http://localhost:8001 \
|
|
94
|
+
--api-key "$ENVIRONMENT_API_KEY" \
|
|
95
|
+
--inference-api-key "$GROQ_API_KEY" \
|
|
96
|
+
--model qwen/qwen3-32b \
|
|
97
|
+
--inference-url https://api.groq.com/openai \
|
|
98
|
+
--max-llm-calls 3 \
|
|
99
|
+
--run-id local-trace
|
|
93
100
|
```
|
|
94
101
|
2. Inspect local trace databases:
|
|
95
102
|
```bash
|
|
@@ -5,6 +5,7 @@ Baseline evaluation script (public-friendly skeleton)
|
|
|
5
5
|
- Uses a TaskAppClient interface (to be implemented in synth-ai SDK)
|
|
6
6
|
- Keeps structure aligned with research/testing/crafter eval harness
|
|
7
7
|
"""
|
|
8
|
+
|
|
8
9
|
from __future__ import annotations
|
|
9
10
|
import os
|
|
10
11
|
import json
|
|
@@ -17,6 +18,7 @@ import argparse
|
|
|
17
18
|
import tomllib
|
|
18
19
|
from pathlib import Path
|
|
19
20
|
|
|
21
|
+
|
|
20
22
|
class TaskAppClient:
|
|
21
23
|
"""Minimal async client for the task app initialize/step/terminate routes.
|
|
22
24
|
|
|
@@ -68,7 +70,9 @@ class TaskAppClient:
|
|
|
68
70
|
resp.raise_for_status()
|
|
69
71
|
return resp.json()
|
|
70
72
|
|
|
71
|
-
async def step(
|
|
73
|
+
async def step(
|
|
74
|
+
self, env_name: str, env_id: str, tool_calls: List[Dict[str, Any]]
|
|
75
|
+
) -> Dict[str, Any]:
|
|
72
76
|
"""POST /env/{env_name}/step with wrapped tool_calls in action."""
|
|
73
77
|
payload = {"env_id": env_id, "action": {"tool_calls": tool_calls}}
|
|
74
78
|
resp = await self.client.post(f"/env/{env_name}/step", json=payload)
|
|
@@ -102,7 +106,17 @@ class TaskAppClient:
|
|
|
102
106
|
return {"error": data}
|
|
103
107
|
return data
|
|
104
108
|
|
|
105
|
-
async def rollout(
|
|
109
|
+
async def rollout(
|
|
110
|
+
self,
|
|
111
|
+
*,
|
|
112
|
+
run_id: str,
|
|
113
|
+
env_name: str,
|
|
114
|
+
seed: int,
|
|
115
|
+
difficulty: str,
|
|
116
|
+
policy_name: str,
|
|
117
|
+
policy_config: Dict[str, Any],
|
|
118
|
+
max_turns: int,
|
|
119
|
+
) -> Dict[str, Any]:
|
|
106
120
|
ops: List[str] = []
|
|
107
121
|
for _ in range(max_turns):
|
|
108
122
|
ops.extend(["agent", "env"])
|
|
@@ -128,30 +142,37 @@ class TaskAppClient:
|
|
|
128
142
|
resp.raise_for_status()
|
|
129
143
|
return resp.json()
|
|
130
144
|
|
|
145
|
+
|
|
131
146
|
TASK_APP_URL = os.getenv("TASK_APP_URL", "https://YOUR-TASK-APP.modal.run").rstrip("/")
|
|
132
147
|
MODEL = os.getenv("EVAL_MODEL", "qwen/qwen3-32b")
|
|
133
148
|
NUM_EPISODES = int(os.getenv("NUM_EPISODES", "3"))
|
|
134
149
|
MAX_TURNS = int(os.getenv("MAX_TURNS", "10"))
|
|
135
150
|
CONCURRENCY = int(os.getenv("CONCURRENCY", "1"))
|
|
136
151
|
|
|
152
|
+
|
|
137
153
|
def _interact_tool_schema() -> List[Dict[str, Any]]:
|
|
138
|
-
return [
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
"
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
"
|
|
145
|
-
|
|
146
|
-
"
|
|
147
|
-
|
|
154
|
+
return [
|
|
155
|
+
{
|
|
156
|
+
"type": "function",
|
|
157
|
+
"function": {
|
|
158
|
+
"name": "interact",
|
|
159
|
+
"description": "Perform actions in the Crafter environment.",
|
|
160
|
+
"parameters": {
|
|
161
|
+
"type": "object",
|
|
162
|
+
"properties": {
|
|
163
|
+
"actions": {"type": "array", "items": {"type": "string"}},
|
|
164
|
+
"reasoning": {"type": "string"},
|
|
165
|
+
},
|
|
166
|
+
"required": ["actions", "reasoning"],
|
|
148
167
|
},
|
|
149
|
-
"required": ["actions", "reasoning"],
|
|
150
168
|
},
|
|
151
|
-
}
|
|
152
|
-
|
|
169
|
+
}
|
|
170
|
+
]
|
|
171
|
+
|
|
153
172
|
|
|
154
|
-
def _build_messages_from_observation(
|
|
173
|
+
def _build_messages_from_observation(
|
|
174
|
+
observation: Dict[str, Any], history: List[Dict[str, Any]]
|
|
175
|
+
) -> List[Dict[str, Any]]:
|
|
155
176
|
inv = observation.get("inventory") or {}
|
|
156
177
|
pos = observation.get("player_position") or []
|
|
157
178
|
ach = observation.get("achievements_status") or {}
|
|
@@ -171,6 +192,7 @@ def _build_messages_from_observation(observation: Dict[str, Any], history: List[
|
|
|
171
192
|
content = "\n".join(user_lines)
|
|
172
193
|
return [{"role": "user", "content": content}]
|
|
173
194
|
|
|
195
|
+
|
|
174
196
|
def _parse_tool_calls_from_openai_response(data: Dict[str, Any]) -> List[str]:
|
|
175
197
|
try:
|
|
176
198
|
choices = data.get("choices")
|
|
@@ -203,7 +225,11 @@ def _parse_tool_calls_from_openai_response(data: Dict[str, Any]) -> List[str]:
|
|
|
203
225
|
if isinstance(content, str):
|
|
204
226
|
text = content
|
|
205
227
|
elif isinstance(content, list):
|
|
206
|
-
text = "\n".join(
|
|
228
|
+
text = "\n".join(
|
|
229
|
+
str(part.get("text"))
|
|
230
|
+
for part in content
|
|
231
|
+
if isinstance(part, dict) and part.get("text")
|
|
232
|
+
)
|
|
207
233
|
for raw in re.findall(r"\{[\s\S]*\}", text or ""):
|
|
208
234
|
try:
|
|
209
235
|
obj = json.loads(raw)
|
|
@@ -217,7 +243,14 @@ def _parse_tool_calls_from_openai_response(data: Dict[str, Any]) -> List[str]:
|
|
|
217
243
|
pass
|
|
218
244
|
return []
|
|
219
245
|
|
|
220
|
-
|
|
246
|
+
|
|
247
|
+
async def _choose_actions_via_llm(
|
|
248
|
+
client: TaskAppClient,
|
|
249
|
+
provider: str,
|
|
250
|
+
model: str,
|
|
251
|
+
observation: Dict[str, Any],
|
|
252
|
+
history: List[Dict[str, Any]],
|
|
253
|
+
) -> List[str]:
|
|
221
254
|
messages = _build_messages_from_observation(observation, history)
|
|
222
255
|
payload: Dict[str, Any] = {
|
|
223
256
|
"model": model,
|
|
@@ -245,25 +278,31 @@ async def _choose_actions_via_llm(client: TaskAppClient, provider: str, model: s
|
|
|
245
278
|
actions = _parse_tool_calls_from_openai_response(data)
|
|
246
279
|
return actions or []
|
|
247
280
|
|
|
281
|
+
|
|
248
282
|
def _expand_actions_to_tool_calls(actions: List[str]) -> List[Dict[str, Any]]:
|
|
249
283
|
out: List[Dict[str, Any]] = []
|
|
250
284
|
for a in actions[:5]:
|
|
251
285
|
out.append({"tool": "interact", "args": {"action": a}})
|
|
252
286
|
return out
|
|
253
287
|
|
|
288
|
+
|
|
254
289
|
def _detect_provider(model: str) -> str:
|
|
255
290
|
m = (model or "").lower()
|
|
256
291
|
if "qwen/qwen3-32b" in m or "qwen-2.5-" in m or m.startswith("groq:"):
|
|
257
292
|
return "groq"
|
|
258
293
|
return "vllm"
|
|
259
294
|
|
|
260
|
-
|
|
295
|
+
|
|
296
|
+
def _rollout_inference_url_from_cfg(
|
|
297
|
+
cfg: Dict[str, Any], default_vllm: Optional[str]
|
|
298
|
+
) -> Optional[str]:
|
|
261
299
|
# Prefer explicit inference_url in TOML; else fall back to discovered vLLM base
|
|
262
300
|
url = cfg.get("inference_url")
|
|
263
301
|
if isinstance(url, str) and url:
|
|
264
302
|
return url
|
|
265
303
|
return default_vllm
|
|
266
304
|
|
|
305
|
+
|
|
267
306
|
async def eval_episode(client: TaskAppClient, seed: int) -> Dict[str, Any]:
|
|
268
307
|
env_name = "CrafterClassic"
|
|
269
308
|
history: List[Dict[str, Any]] = []
|
|
@@ -271,7 +310,10 @@ async def eval_episode(client: TaskAppClient, seed: int) -> Dict[str, Any]:
|
|
|
271
310
|
turns = 0
|
|
272
311
|
|
|
273
312
|
# Initialize environment
|
|
274
|
-
init_cfg: Dict[str, Any] = {
|
|
313
|
+
init_cfg: Dict[str, Any] = {
|
|
314
|
+
"seed": seed,
|
|
315
|
+
"world_config": {"difficulty": os.getenv("DIFFICULTY", "easy")},
|
|
316
|
+
}
|
|
275
317
|
created = await client.initialize(env_name, init_cfg)
|
|
276
318
|
env_id = created.get("env_id")
|
|
277
319
|
if not isinstance(env_id, str) or not env_id:
|
|
@@ -285,7 +327,9 @@ async def eval_episode(client: TaskAppClient, seed: int) -> Dict[str, Any]:
|
|
|
285
327
|
try:
|
|
286
328
|
while turns < MAX_TURNS and not done:
|
|
287
329
|
# Ask LLM for actions; fallback to a simple exploratory pair
|
|
288
|
-
chosen_actions = await _choose_actions_via_llm(
|
|
330
|
+
chosen_actions = await _choose_actions_via_llm(
|
|
331
|
+
client, provider, MODEL, observation, history
|
|
332
|
+
)
|
|
289
333
|
if not chosen_actions:
|
|
290
334
|
chosen_actions = ["move_up", "do"]
|
|
291
335
|
tool_calls = _expand_actions_to_tool_calls(chosen_actions)
|
|
@@ -306,6 +350,7 @@ async def eval_episode(client: TaskAppClient, seed: int) -> Dict[str, Any]:
|
|
|
306
350
|
|
|
307
351
|
return {"seed": seed, "turns": turns, "achievements": sorted(achievements)}
|
|
308
352
|
|
|
353
|
+
|
|
309
354
|
async def main() -> None:
|
|
310
355
|
# Best-effort load local .env if present (ensures ENVIRONMENT_API_KEY for rollout)
|
|
311
356
|
try:
|
|
@@ -322,9 +367,13 @@ async def main() -> None:
|
|
|
322
367
|
except Exception:
|
|
323
368
|
pass
|
|
324
369
|
|
|
325
|
-
parser = argparse.ArgumentParser(
|
|
370
|
+
parser = argparse.ArgumentParser(
|
|
371
|
+
description="Baseline eval against task app with optional TOML config"
|
|
372
|
+
)
|
|
326
373
|
parser.add_argument("--toml", help="Path to TOML config file", default=None)
|
|
327
|
-
parser.add_argument(
|
|
374
|
+
parser.add_argument(
|
|
375
|
+
"--use-rollout", action="store_true", help="Use server-side rollout endpoint for eval"
|
|
376
|
+
)
|
|
328
377
|
args = parser.parse_args()
|
|
329
378
|
|
|
330
379
|
global TASK_APP_URL, MODEL, NUM_EPISODES, MAX_TURNS, CONCURRENCY
|
|
@@ -346,10 +395,14 @@ async def main() -> None:
|
|
|
346
395
|
if env_url:
|
|
347
396
|
TASK_APP_URL = env_url.rstrip("/")
|
|
348
397
|
else:
|
|
349
|
-
raise RuntimeError(
|
|
398
|
+
raise RuntimeError(
|
|
399
|
+
"TASK_APP_URL is a placeholder. Set task_app_url in TOML or export TASK_APP_URL."
|
|
400
|
+
)
|
|
350
401
|
|
|
351
402
|
print(f"Task App: {TASK_APP_URL}")
|
|
352
|
-
print(
|
|
403
|
+
print(
|
|
404
|
+
f"Model: {MODEL} Episodes: {NUM_EPISODES} Max turns: {MAX_TURNS} Concurrency: {CONCURRENCY}"
|
|
405
|
+
)
|
|
353
406
|
sem = asyncio.Semaphore(max(CONCURRENCY, 1))
|
|
354
407
|
async with TaskAppClient(TASK_APP_URL, api_key=os.getenv("ENVIRONMENT_API_KEY")) as client:
|
|
355
408
|
if args.use_rollout:
|
|
@@ -359,6 +412,7 @@ async def main() -> None:
|
|
|
359
412
|
inf_url = _rollout_inference_url_from_cfg(cfg, default_vllm)
|
|
360
413
|
if not inf_url:
|
|
361
414
|
raise RuntimeError("Could not resolve inference URL for rollout")
|
|
415
|
+
|
|
362
416
|
async def _run(seed: int):
|
|
363
417
|
async with sem:
|
|
364
418
|
try:
|
|
@@ -368,7 +422,14 @@ async def main() -> None:
|
|
|
368
422
|
"model": cfg.get("model", MODEL),
|
|
369
423
|
"inference_url": inf_url,
|
|
370
424
|
}
|
|
371
|
-
for k in (
|
|
425
|
+
for k in (
|
|
426
|
+
"max_tokens",
|
|
427
|
+
"temperature",
|
|
428
|
+
"top_p",
|
|
429
|
+
"thinking_mode",
|
|
430
|
+
"thinking_budget",
|
|
431
|
+
"use_tools",
|
|
432
|
+
):
|
|
372
433
|
if k in cfg and cfg.get(k) is not None:
|
|
373
434
|
policy_cfg[k] = cfg.get(k)
|
|
374
435
|
|
|
@@ -385,8 +446,16 @@ async def main() -> None:
|
|
|
385
446
|
ach = []
|
|
386
447
|
try:
|
|
387
448
|
trajs = r.get("trajectories") or []
|
|
388
|
-
final_obs = (
|
|
389
|
-
|
|
449
|
+
final_obs = (
|
|
450
|
+
(trajs[0].get("final") or {}).get("observation")
|
|
451
|
+
if trajs and isinstance(trajs[0], dict)
|
|
452
|
+
else None
|
|
453
|
+
)
|
|
454
|
+
ach_map = (
|
|
455
|
+
(final_obs or {}).get("achievements_status")
|
|
456
|
+
if isinstance(final_obs, dict)
|
|
457
|
+
else None
|
|
458
|
+
)
|
|
390
459
|
if isinstance(ach_map, dict):
|
|
391
460
|
ach = sorted([k for k, v in ach_map.items() if v])
|
|
392
461
|
except Exception:
|
|
@@ -401,7 +470,11 @@ async def main() -> None:
|
|
|
401
470
|
return {"seed": seed, "turns": length, "achievements": ach}
|
|
402
471
|
except Exception as e:
|
|
403
472
|
return {"seed": seed, "turns": 0, "achievements": [], "error": str(e)}
|
|
404
|
-
|
|
473
|
+
|
|
474
|
+
results = await asyncio.gather(
|
|
475
|
+
*[asyncio.create_task(_run(i)) for i in range(1, NUM_EPISODES + 1)],
|
|
476
|
+
return_exceptions=False,
|
|
477
|
+
)
|
|
405
478
|
# Aggregate summary
|
|
406
479
|
counts = [len(r.get("achievements") or []) for r in results if isinstance(r, dict)]
|
|
407
480
|
turns = [int(r.get("turns") or 0) for r in results if isinstance(r, dict)]
|
|
@@ -424,11 +497,16 @@ async def main() -> None:
|
|
|
424
497
|
}
|
|
425
498
|
print(json.dumps(summary, indent=2))
|
|
426
499
|
else:
|
|
500
|
+
|
|
427
501
|
async def _run(seed: int):
|
|
428
502
|
async with sem:
|
|
429
503
|
return await eval_episode(client, seed)
|
|
430
|
-
|
|
504
|
+
|
|
505
|
+
results = await asyncio.gather(
|
|
506
|
+
*[asyncio.create_task(_run(i)) for i in range(1, NUM_EPISODES + 1)]
|
|
507
|
+
)
|
|
431
508
|
print(json.dumps({"episodes": results}, indent=2))
|
|
432
509
|
|
|
510
|
+
|
|
433
511
|
if __name__ == "__main__":
|
|
434
512
|
asyncio.run(main())
|