synth_ai-0.2.9.dev4-py3-none-any.whl → synth_ai-0.2.9.dev7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of synth-ai was flagged as possibly problematic in the registry.

Files changed (157)
  1. examples/common_old/backend.py +0 -1
  2. examples/crafter_debug_render.py +15 -6
  3. examples/evals_old/compare_models.py +1 -0
  4. examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +6 -2
  5. examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +4 -4
  6. examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +4 -3
  7. examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +6 -2
  8. examples/finetuning_old/synth_qwen_v1/finetune.py +1 -1
  9. examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +4 -4
  10. examples/finetuning_old/synth_qwen_v1/infer.py +1 -2
  11. examples/finetuning_old/synth_qwen_v1/poll.py +4 -2
  12. examples/finetuning_old/synth_qwen_v1/prepare_data.py +8 -8
  13. examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +5 -4
  14. examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +11 -8
  15. examples/finetuning_old/synth_qwen_v1/run_ft_job.py +17 -12
  16. examples/finetuning_old/synth_qwen_v1/upload_data.py +1 -1
  17. examples/finetuning_old/synth_qwen_v1/util.py +7 -2
  18. examples/rl/configs/eval_base_qwen.toml +1 -1
  19. examples/rl/configs/rl_from_base_qwen17.toml +1 -1
  20. examples/rl/download_dataset.py +26 -10
  21. examples/rl/run_eval.py +17 -15
  22. examples/rl/run_rl_and_save.py +24 -7
  23. examples/rl/task_app/math_single_step.py +128 -11
  24. examples/rl/task_app/math_task_app.py +11 -3
  25. examples/rl_old/task_app.py +222 -53
  26. examples/warming_up_to_rl/analyze_trace_db.py +7 -5
  27. examples/warming_up_to_rl/export_trace_sft.py +141 -16
  28. examples/warming_up_to_rl/groq_test.py +11 -4
  29. examples/warming_up_to_rl/manage_secrets.py +15 -6
  30. examples/warming_up_to_rl/readme.md +9 -2
  31. examples/warming_up_to_rl/run_eval.py +108 -30
  32. examples/warming_up_to_rl/run_fft_and_save.py +128 -52
  33. examples/warming_up_to_rl/run_local_rollout.py +87 -36
  34. examples/warming_up_to_rl/run_local_rollout_modal.py +113 -25
  35. examples/warming_up_to_rl/run_local_rollout_parallel.py +80 -16
  36. examples/warming_up_to_rl/run_local_rollout_traced.py +125 -20
  37. examples/warming_up_to_rl/run_rl_and_save.py +31 -7
  38. examples/warming_up_to_rl/run_rollout_remote.py +37 -10
  39. examples/warming_up_to_rl/task_app/grpo_crafter.py +90 -27
  40. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +9 -27
  41. examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +46 -108
  42. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -1
  43. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +1 -1
  44. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -1
  45. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +50 -17
  46. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +35 -21
  47. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +8 -4
  48. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +29 -26
  49. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +1 -1
  50. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +17 -13
  51. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +1 -1
  52. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +106 -63
  53. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +82 -84
  54. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +76 -59
  55. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +1 -1
  56. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +43 -49
  57. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +5 -15
  58. synth_ai/__init__.py +1 -0
  59. synth_ai/api/train/builders.py +34 -10
  60. synth_ai/api/train/cli.py +172 -32
  61. synth_ai/api/train/config_finder.py +59 -4
  62. synth_ai/api/train/env_resolver.py +32 -14
  63. synth_ai/api/train/pollers.py +11 -3
  64. synth_ai/api/train/task_app.py +4 -1
  65. synth_ai/api/train/utils.py +20 -4
  66. synth_ai/cli/__init__.py +11 -4
  67. synth_ai/cli/balance.py +1 -1
  68. synth_ai/cli/demo.py +19 -5
  69. synth_ai/cli/rl_demo.py +75 -16
  70. synth_ai/cli/root.py +116 -37
  71. synth_ai/cli/task_apps.py +1286 -170
  72. synth_ai/cli/traces.py +1 -0
  73. synth_ai/cli/turso.py +73 -0
  74. synth_ai/core/experiment.py +0 -2
  75. synth_ai/demo_registry.py +67 -30
  76. synth_ai/demos/core/cli.py +493 -164
  77. synth_ai/demos/demo_task_apps/core.py +50 -6
  78. synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +2 -3
  79. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +36 -28
  80. synth_ai/demos/demo_task_apps/math/_common.py +1 -2
  81. synth_ai/demos/demo_task_apps/math/deploy_modal.py +0 -2
  82. synth_ai/demos/demo_task_apps/math/modal_task_app.py +168 -65
  83. synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
  84. synth_ai/environments/examples/bandit/engine.py +12 -4
  85. synth_ai/environments/examples/bandit/taskset.py +4 -4
  86. synth_ai/environments/reproducibility/tree.py +3 -1
  87. synth_ai/environments/service/core_routes.py +6 -2
  88. synth_ai/evals/base.py +0 -2
  89. synth_ai/experimental/synth_oss.py +11 -12
  90. synth_ai/handshake.py +3 -1
  91. synth_ai/http_client.py +31 -7
  92. synth_ai/inference/__init__.py +0 -2
  93. synth_ai/inference/client.py +8 -4
  94. synth_ai/jobs/client.py +40 -10
  95. synth_ai/learning/client.py +33 -8
  96. synth_ai/learning/config.py +0 -2
  97. synth_ai/learning/constants.py +0 -2
  98. synth_ai/learning/ft_client.py +6 -3
  99. synth_ai/learning/health.py +9 -2
  100. synth_ai/learning/jobs.py +17 -5
  101. synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +1 -3
  102. synth_ai/learning/prompts/random_search.py +4 -1
  103. synth_ai/learning/prompts/run_random_search_banking77.py +6 -1
  104. synth_ai/learning/rl_client.py +42 -14
  105. synth_ai/learning/sse.py +0 -2
  106. synth_ai/learning/validators.py +6 -2
  107. synth_ai/lm/caching/ephemeral.py +1 -3
  108. synth_ai/lm/core/exceptions.py +0 -2
  109. synth_ai/lm/core/main.py +13 -1
  110. synth_ai/lm/core/synth_models.py +0 -1
  111. synth_ai/lm/core/vendor_clients.py +4 -2
  112. synth_ai/lm/overrides.py +2 -2
  113. synth_ai/lm/vendors/core/anthropic_api.py +7 -7
  114. synth_ai/lm/vendors/core/openai_api.py +2 -0
  115. synth_ai/lm/vendors/openai_standard.py +3 -1
  116. synth_ai/lm/vendors/openai_standard_responses.py +6 -3
  117. synth_ai/lm/vendors/supported/custom_endpoint.py +1 -3
  118. synth_ai/lm/vendors/synth_client.py +37 -10
  119. synth_ai/rl/__init__.py +0 -1
  120. synth_ai/rl/contracts.py +0 -2
  121. synth_ai/rl/env_keys.py +6 -1
  122. synth_ai/task/__init__.py +1 -0
  123. synth_ai/task/apps/__init__.py +11 -11
  124. synth_ai/task/auth.py +29 -17
  125. synth_ai/task/client.py +3 -1
  126. synth_ai/task/contracts.py +1 -0
  127. synth_ai/task/datasets.py +3 -1
  128. synth_ai/task/errors.py +3 -2
  129. synth_ai/task/health.py +0 -2
  130. synth_ai/task/json.py +0 -1
  131. synth_ai/task/proxy.py +2 -5
  132. synth_ai/task/rubrics.py +9 -3
  133. synth_ai/task/server.py +31 -5
  134. synth_ai/task/tracing_utils.py +8 -3
  135. synth_ai/task/validators.py +0 -1
  136. synth_ai/task/vendors.py +0 -1
  137. synth_ai/tracing_v3/db_config.py +26 -1
  138. synth_ai/tracing_v3/decorators.py +1 -0
  139. synth_ai/tracing_v3/examples/basic_usage.py +3 -2
  140. synth_ai/tracing_v3/hooks.py +2 -0
  141. synth_ai/tracing_v3/replica_sync.py +1 -0
  142. synth_ai/tracing_v3/session_tracer.py +24 -3
  143. synth_ai/tracing_v3/storage/base.py +4 -1
  144. synth_ai/tracing_v3/storage/factory.py +0 -1
  145. synth_ai/tracing_v3/turso/manager.py +102 -38
  146. synth_ai/tracing_v3/turso/models.py +4 -1
  147. synth_ai/tracing_v3/utils.py +1 -0
  148. synth_ai/v0/tracing/upload.py +32 -135
  149. {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/METADATA +1 -1
  150. {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/RECORD +154 -156
  151. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_stepwise_rewards.py +0 -58
  152. synth_ai/environments/examples/sokoban/units/astar_common.py +0 -95
  153. synth_ai/install_sqld.sh +0 -40
  154. {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/WHEEL +0 -0
  155. {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/entry_points.txt +0 -0
  156. {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/licenses/LICENSE +0 -0
  157. {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/top_level.txt +0 -0
@@ -233,7 +233,9 @@ def _normalise_tool_calls(tool_calls: list[dict[str, Any]] | None) -> list[dict[
             continue
         entry = dict(call)
 
-        func_payload: dict[str, Any] | None = entry.get("function") if isinstance(entry.get("function"), dict) else None
+        func_payload: dict[str, Any] | None = (
+            entry.get("function") if isinstance(entry.get("function"), dict) else None
+        )
         name = entry.get("name") or (func_payload.get("name") if func_payload else None) or "tool"
 
         args = None
@@ -355,7 +357,10 @@ def build_sft_dataset(
         if not assistant_tool_calls:
             assistant_tool_calls = _normalise_tool_calls(record.get("output_tool_calls"))
 
-        assistant_message: dict[str, Any] = {"role": "assistant", "content": assistant_content or ""}
+        assistant_message: dict[str, Any] = {
+            "role": "assistant",
+            "content": assistant_content or "",
+        }
         if assistant_tool_calls:
             assistant_message["tool_calls"] = assistant_tool_calls
 
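For orientation, one exported example ends up shaped roughly like the sketch below. Only role, content, tool_calls, and the metadata session_id appear in these hunks; every other key and value is an illustrative assumption, not taken from the source.

```python
# Hypothetical shape of one exported record assembled the way build_sft_dataset
# does above. The tool-call structure and the "messages" wrapper are assumptions.
example = {
    "messages": [
        {"role": "user", "content": "observation summary ..."},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [{"name": "interact", "args": {"actions": ["move_up", "do"]}}],
        },
    ],
    "metadata": {"session_id": "session-0001"},
}
```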
@@ -426,27 +431,141 @@ def _validate_dataset(records: list[dict[str, Any]]) -> None:
         raise SystemExit(f"Validation error while exporting dataset:\n - {summary}")
 
 
+def _find_trace_database() -> Path | None:
+    """Automatically discover the trace database in common locations."""
+
+    # Check for demo directory from state
+    try:
+        state_path = Path.home() / ".synth-ai" / "demo.json"
+        if state_path.exists():
+            import json
+
+            with state_path.open() as f:
+                data = json.load(f)
+            demo_dir = data.get("DEMO_DIR")
+            if demo_dir:
+                candidate = Path(demo_dir) / "traces" / "v3" / "synth_ai.db"
+                if candidate.exists():
+                    return candidate
+    except Exception:
+        pass
+
+    # Search upward from current directory
+    cwd = Path.cwd()
+    for parent in [cwd] + list(cwd.parents):
+        candidate = parent / "traces" / "v3" / "synth_ai.db"
+        if candidate.exists():
+            return candidate
+
+    # Check standard locations
+    standard_locations = [
+        Path("traces/v3/synth_ai.db"),
+        Path("../traces/v3/synth_ai.db"),
+        Path.home() / "synth-ai" / "traces" / "v3" / "synth_ai.db",
+    ]
+
+    for location in standard_locations:
+        try:
+            if location.exists():
+                return location.resolve()
+        except Exception:
+            continue
+
+    return None
+
+
 def main() -> None:
     parser = argparse.ArgumentParser(description=__doc__)
-    parser.add_argument("--db", type=Path, default=Path("traces/v3/synth_ai.db"), help="Path to tracing_v3 SQLite DB")
-    parser.add_argument("--output", type=Path, required=True, help="Destination JSONL path for the exported dataset")
-    parser.add_argument("--model", action="append", dest="models", help="Restrict to sessions whose dominant model matches (repeatable)")
-    parser.add_argument("--provider", action="append", dest="providers", help="Restrict to sessions whose dominant provider matches (repeatable)")
-    parser.add_argument("--min-unique", type=int, default=None, help="Minimum unique achievements per session")
-    parser.add_argument("--max-unique", type=int, default=None, help="Maximum unique achievements per session")
+    parser.add_argument("--db", type=Path, default=None, help="Path to tracing_v3 SQLite DB")
+    parser.add_argument(
+        "--output",
+        type=Path,
+        required=False,
+        help="Destination JSONL path for the exported dataset",
+    )
+    parser.add_argument(
+        "--model",
+        action="append",
+        dest="models",
+        help="Restrict to sessions whose dominant model matches (repeatable)",
+    )
+    parser.add_argument(
+        "--provider",
+        action="append",
+        dest="providers",
+        help="Restrict to sessions whose dominant provider matches (repeatable)",
+    )
+    parser.add_argument(
+        "--min-unique", type=int, default=None, help="Minimum unique achievements per session"
+    )
+    parser.add_argument(
+        "--max-unique", type=int, default=None, help="Maximum unique achievements per session"
+    )
     parser.add_argument(
         "--exclude-achievement",
         action="append",
         dest="exclude_achievements",
         help="Achievements to ignore when evaluating --min-unique/--max-unique (repeatable)",
     )
-    parser.add_argument("--require-achievement", action="append", dest="required_achievements", help="Require these outcome achievements (repeatable)")
-    parser.add_argument("--min-outcome-reward", type=float, default=None, help="Minimum total outcome reward per session")
-    parser.add_argument("--max-outcome-reward", type=float, default=None, help="Maximum total outcome reward per session")
-    parser.add_argument("--event-reward", action="append", dest="event_reward_filters", help="Require reward_type[:min_total] in event_rewards (repeatable)")
-    parser.add_argument("--limit", type=int, default=None, help="Maximum number of examples to emit")
+    parser.add_argument(
+        "--require-achievement",
+        action="append",
+        dest="required_achievements",
+        help="Require these outcome achievements (repeatable)",
+    )
+    parser.add_argument(
+        "--min-outcome-reward",
+        type=float,
+        default=None,
+        help="Minimum total outcome reward per session",
+    )
+    parser.add_argument(
+        "--max-outcome-reward",
+        type=float,
+        default=None,
+        help="Maximum total outcome reward per session",
+    )
+    parser.add_argument(
+        "--event-reward",
+        action="append",
+        dest="event_reward_filters",
+        help="Require reward_type[:min_total] in event_rewards (repeatable)",
+    )
+    parser.add_argument(
+        "--limit", type=int, default=None, help="Maximum number of examples to emit"
+    )
     args = parser.parse_args()
 
+    # Auto-discover database if not specified
+    db_path = args.db
+    if db_path is None:
+        db_path = _find_trace_database()
+        if db_path:
+            print(f"Found trace database: {db_path}")
+        else:
+            print("\nTrace database configuration:")
+            db_input = input("Trace database path [traces/v3/synth_ai.db]: ").strip()
+            db_path = Path(db_input) if db_input else Path("traces/v3/synth_ai.db")
+
+    if not db_path.exists():
+        print(f"Database not found: {db_path}", file=sys.stderr)
+        raise SystemExit(1)
+
+    output_path = args.output
+    if not output_path:
+        output_path = Path("ft_data/crafter_traces.jsonl")
+        print(f"Output will be written to: {output_path.resolve()}")
+
+    min_unique = args.min_unique
+    if min_unique is None:
+        min_unique = 0  # Default to including all traces
+        print(f"Minimum unique achievements filter: {min_unique} (all traces)")
+
+    # Override args with prompted values
+    args.db = db_path
+    args.output = output_path
+    args.min_unique = min_unique
+
     if not args.db.exists():
         print(f"Database not found: {args.db}", file=sys.stderr)
         raise SystemExit(1)
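Because the reworked main() now falls back to auto-discovery and an interactive input() prompt when --db is omitted, scripted or CI callers will want to pass the flags explicitly. A hedged sketch of a non-interactive invocation using only flags shown in this hunk (paths and the filter value are illustrative):

```python
import subprocess

# Passing --db and --output explicitly skips both auto-discovery and the
# interactive prompt added above; the concrete paths here are placeholders.
subprocess.run(
    [
        "uv", "run", "python", "examples/warming_up_to_rl/export_trace_sft.py",
        "--db", "traces/v3/synth_ai.db",
        "--output", "ft_data/crafter_traces.jsonl",
        "--min-unique", "3",
    ],
    check=True,
)
```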
@@ -488,7 +607,11 @@ def main() -> None:
 
         outcome = outcome_data.get(session_id)
         total_reward = outcome["total_reward"] if outcome else 0.0
-        final_achievements = outcome["achievements"] if outcome else session_final_achievements.get(session_id, set())
+        final_achievements = (
+            outcome["achievements"]
+            if outcome
+            else session_final_achievements.get(session_id, set())
+        )
 
         if args.min_outcome_reward is not None and total_reward < args.min_outcome_reward:
             continue
@@ -522,7 +645,9 @@
         )
 
         if not dataset:
-            print("No rollout steps matched the filters (after session selection).", file=sys.stderr)
+            print(
+                "No rollout steps matched the filters (after session selection).", file=sys.stderr
+            )
             raise SystemExit(1)
 
         _validate_dataset(dataset)
@@ -530,7 +655,7 @@
         session_ids = {item.get("metadata", {}).get("session_id") for item in dataset}
         session_ids.discard(None)
         print(
-            f"Wrote {len(dataset)} examples from {len(session_ids)} session(s) -> {args.output}",
+            f"Wrote {len(dataset)} examples from {len(session_ids)} session(s) -> {args.output.resolve()}",
             file=sys.stderr,
         )
     finally:
@@ -63,13 +63,21 @@ async def run(args: argparse.Namespace) -> None:
     response = await client.rollout(request)
     print("rollout.metrics →", to_jsonable(response.metrics.model_dump()))
     for idx, step in enumerate(response.trajectories[0].steps, start=1):
-        print(f"step[{idx}] tool_calls={step.tool_calls} reward={step.reward} info={to_jsonable(step.info)}")
+        print(
+            f"step[{idx}] tool_calls={step.tool_calls} reward={step.reward} info={to_jsonable(step.info)}"
+        )
 
 
 def _parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(description=__doc__)
-    parser.add_argument("--base-url", default=os.getenv("TASK_APP_BASE_URL", "http://localhost:8000"))
-    parser.add_argument("--api-key", default=os.getenv("TASK_APP_API_KEY"), required=os.getenv("TASK_APP_API_KEY") is None)
+    parser.add_argument(
+        "--base-url", default=os.getenv("TASK_APP_BASE_URL", "http://localhost:8000")
+    )
+    parser.add_argument(
+        "--api-key",
+        default=os.getenv("TASK_APP_API_KEY"),
+        required=os.getenv("TASK_APP_API_KEY") is None,
+    )
     parser.add_argument("--model", default=os.getenv("GROQ_MODEL", "groq/mixtral-8x7b"))
     parser.add_argument("--inference-url", default=os.getenv("TASK_APP_INFERENCE_URL"))
     parser.add_argument("--seed", type=int, default=int(os.getenv("CRAFTER_TEST_SEED", "42")))
@@ -85,4 +93,3 @@ def main() -> None:
 
 if __name__ == "__main__":
     main()
-
@@ -34,7 +34,9 @@ def write_temp_env(kv: Dict[str, str]) -> Path:
 
 
 def run(cmd: str) -> Tuple[int, str]:
-    proc = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
+    proc = subprocess.run(
+        cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
+    )
     return proc.returncode, proc.stdout
 
 
@@ -44,11 +46,14 @@ def ensure_secret(secret_name: str, kv: Dict[str, str]) -> None:
         return
     # Prefer passing KEY=VALUE pairs to avoid Typer --env-file bug under some shells
     kv_args = " ".join([f"{shlex.quote(k)}={shlex.quote(v)}" for k, v in kv.items()])
+
     # Try plain modal first; fallback to uv run modal
     def _create() -> Tuple[int, str]:
         return run(f"modal secret create {shlex.quote(secret_name)} {kv_args}")
+
    def _delete() -> Tuple[int, str]:
        return run(f"printf 'y\n' | modal secret delete {shlex.quote(secret_name)}")
+
    rc, out = _create()
    if rc != 0:
        # Fallback: use uv run modal
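The comment in this hunk names the retry strategy: invoke the plain `modal` CLI first and, on a non-zero exit, retry the same command through `uv run`. A minimal sketch of that fallback, reusing the shape of the `run()` helper from the earlier hunk (the wrapper name below is hypothetical):

```python
import subprocess
from typing import Tuple


def run(cmd: str) -> Tuple[int, str]:
    # Same shape as the helper in the diff: exit code plus combined stdout/stderr.
    proc = subprocess.run(
        cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
    )
    return proc.returncode, proc.stdout


def run_with_uv_fallback(cmd: str) -> Tuple[int, str]:
    # Hypothetical helper: try the plain command, then retry under `uv run`.
    rc, out = run(cmd)
    if rc != 0:
        rc, out = run(f"uv run {cmd}")
    return rc, out
```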
@@ -70,8 +75,12 @@
 
 
 def main() -> None:
-    ap = argparse.ArgumentParser(description="Sync .env keys into Modal secret bundles for the task app")
-    ap.add_argument("--env-path", default=str(Path(__file__).parent / ".env"), help="Path to .env with keys")
+    ap = argparse.ArgumentParser(
+        description="Sync .env keys into Modal secret bundles for the task app"
+    )
+    ap.add_argument(
+        "--env-path", default=str(Path(__file__).parent / ".env"), help="Path to .env with keys"
+    )
     args = ap.parse_args()
 
     env = load_env_file(Path(args.env_path))
@@ -105,7 +114,9 @@ def main() -> None:
     }
 
     # Optional: backend key (not mounted by task app today, but useful to keep consistent)
-    synth_secret = {"SYNTH_API_KEY": env.get("SYNTH_API_KEY", "")} if env.get("SYNTH_API_KEY") else {}
+    synth_secret = (
+        {"SYNTH_API_KEY": env.get("SYNTH_API_KEY", "")} if env.get("SYNTH_API_KEY") else {}
+    )
 
     ensure_secret("crafter-environment-sdk", env_secret)
     ensure_secret("groq-api-key", groq_secret)
@@ -123,5 +134,3 @@ if __name__ == "__main__":
     except Exception as e:
         print(f"[error] {type(e).__name__}: {e}")
         sys.exit(1)
-
-
@@ -87,9 +87,16 @@ Evaluation scripts auto-load `.env` values. Update TOMLs under `configs/` with t
 
 ## 4. Tracing and SFT Dataset Export
 
-1. Serve the task app with tracing enabled (see Section 2) or run the traced rollout helper:
+1. Serve the task app with tracing enabled (see Section 2). Optionally, run the traced rollout helper against the running server:
    ```bash
-   uv run python examples/warming_up_to_rl/run_local_rollout_traced.py --episodes 10 --difficulty easy
+   uv run python examples/warming_up_to_rl/run_local_rollout_traced.py \
+     --base-url http://localhost:8001 \
+     --api-key "$ENVIRONMENT_API_KEY" \
+     --inference-api-key "$GROQ_API_KEY" \
+     --model qwen/qwen3-32b \
+     --inference-url https://api.groq.com/openai \
+     --max-llm-calls 3 \
+     --run-id local-trace
    ```
2. Inspect local trace databases:
    ```bash
@@ -5,6 +5,7 @@ Baseline evaluation script (public-friendly skeleton)
 - Uses a TaskAppClient interface (to be implemented in synth-ai SDK)
 - Keeps structure aligned with research/testing/crafter eval harness
 """
+
 from __future__ import annotations
 import os
 import json
@@ -17,6 +18,7 @@ import argparse
 import tomllib
 from pathlib import Path
 
+
 class TaskAppClient:
     """Minimal async client for the task app initialize/step/terminate routes.
 
@@ -68,7 +70,9 @@ class TaskAppClient:
         resp.raise_for_status()
         return resp.json()
 
-    async def step(self, env_name: str, env_id: str, tool_calls: List[Dict[str, Any]]) -> Dict[str, Any]:
+    async def step(
+        self, env_name: str, env_id: str, tool_calls: List[Dict[str, Any]]
+    ) -> Dict[str, Any]:
         """POST /env/{env_name}/step with wrapped tool_calls in action."""
         payload = {"env_id": env_id, "action": {"tool_calls": tool_calls}}
         resp = await self.client.post(f"/env/{env_name}/step", json=payload)
@@ -102,7 +106,17 @@
             return {"error": data}
         return data
 
-    async def rollout(self, *, run_id: str, env_name: str, seed: int, difficulty: str, policy_name: str, policy_config: Dict[str, Any], max_turns: int) -> Dict[str, Any]:
+    async def rollout(
+        self,
+        *,
+        run_id: str,
+        env_name: str,
+        seed: int,
+        difficulty: str,
+        policy_name: str,
+        policy_config: Dict[str, Any],
+        max_turns: int,
+    ) -> Dict[str, Any]:
         ops: List[str] = []
         for _ in range(max_turns):
             ops.extend(["agent", "env"])
@@ -128,30 +142,37 @@
         resp.raise_for_status()
         return resp.json()
 
+
 TASK_APP_URL = os.getenv("TASK_APP_URL", "https://YOUR-TASK-APP.modal.run").rstrip("/")
 MODEL = os.getenv("EVAL_MODEL", "qwen/qwen3-32b")
 NUM_EPISODES = int(os.getenv("NUM_EPISODES", "3"))
 MAX_TURNS = int(os.getenv("MAX_TURNS", "10"))
 CONCURRENCY = int(os.getenv("CONCURRENCY", "1"))
 
+
 def _interact_tool_schema() -> List[Dict[str, Any]]:
-    return [{
-        "type": "function",
-        "function": {
-            "name": "interact",
-            "description": "Perform actions in the Crafter environment.",
-            "parameters": {
-                "type": "object",
-                "properties": {
-                    "actions": {"type": "array", "items": {"type": "string"}},
-                    "reasoning": {"type": "string"},
+    return [
+        {
+            "type": "function",
+            "function": {
+                "name": "interact",
+                "description": "Perform actions in the Crafter environment.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "actions": {"type": "array", "items": {"type": "string"}},
+                        "reasoning": {"type": "string"},
+                    },
+                    "required": ["actions", "reasoning"],
                 },
-                "required": ["actions", "reasoning"],
             },
-        },
-    }]
+        }
+    ]
+
 
-def _build_messages_from_observation(observation: Dict[str, Any], history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+def _build_messages_from_observation(
+    observation: Dict[str, Any], history: List[Dict[str, Any]]
+) -> List[Dict[str, Any]]:
     inv = observation.get("inventory") or {}
     pos = observation.get("player_position") or []
     ach = observation.get("achievements_status") or {}
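For reference, a tool call that satisfies the interact schema above carries its arguments as a JSON string. A hypothetical example of such a call, plus the round trip a parser would perform on it (values are illustrative, not taken from the source):

```python
import json

# Hypothetical assistant tool call conforming to the "interact" schema above.
tool_call = {
    "type": "function",
    "function": {
        "name": "interact",
        "arguments": json.dumps({"actions": ["move_up", "do"], "reasoning": "explore north"}),
    },
}

# The parser in the following hunks expects to recover the "actions" list from
# payloads like this one, or from raw JSON embedded in plain message content.
print(json.loads(tool_call["function"]["arguments"])["actions"])
```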
@@ -171,6 +192,7 @@ def _build_messages_from_observation(observation: Dict[str, Any], history: List[
     content = "\n".join(user_lines)
     return [{"role": "user", "content": content}]
 
+
 def _parse_tool_calls_from_openai_response(data: Dict[str, Any]) -> List[str]:
     try:
         choices = data.get("choices")
@@ -203,7 +225,11 @@ def _parse_tool_calls_from_openai_response(data: Dict[str, Any]) -> List[str]:
         if isinstance(content, str):
             text = content
         elif isinstance(content, list):
-            text = "\n".join(str(part.get("text")) for part in content if isinstance(part, dict) and part.get("text"))
+            text = "\n".join(
+                str(part.get("text"))
+                for part in content
+                if isinstance(part, dict) and part.get("text")
+            )
         for raw in re.findall(r"\{[\s\S]*\}", text or ""):
             try:
                 obj = json.loads(raw)
@@ -217,7 +243,14 @@ def _parse_tool_calls_from_openai_response(data: Dict[str, Any]) -> List[str]:
             pass
     return []
 
-async def _choose_actions_via_llm(client: TaskAppClient, provider: str, model: str, observation: Dict[str, Any], history: List[Dict[str, Any]]) -> List[str]:
+
+async def _choose_actions_via_llm(
+    client: TaskAppClient,
+    provider: str,
+    model: str,
+    observation: Dict[str, Any],
+    history: List[Dict[str, Any]],
+) -> List[str]:
     messages = _build_messages_from_observation(observation, history)
     payload: Dict[str, Any] = {
         "model": model,
@@ -245,25 +278,31 @@
     actions = _parse_tool_calls_from_openai_response(data)
     return actions or []
 
+
 def _expand_actions_to_tool_calls(actions: List[str]) -> List[Dict[str, Any]]:
     out: List[Dict[str, Any]] = []
     for a in actions[:5]:
         out.append({"tool": "interact", "args": {"action": a}})
     return out
 
+
 def _detect_provider(model: str) -> str:
     m = (model or "").lower()
     if "qwen/qwen3-32b" in m or "qwen-2.5-" in m or m.startswith("groq:"):
         return "groq"
     return "vllm"
 
-def _rollout_inference_url_from_cfg(cfg: Dict[str, Any], default_vllm: Optional[str]) -> Optional[str]:
+
+def _rollout_inference_url_from_cfg(
+    cfg: Dict[str, Any], default_vllm: Optional[str]
+) -> Optional[str]:
     # Prefer explicit inference_url in TOML; else fall back to discovered vLLM base
     url = cfg.get("inference_url")
     if isinstance(url, str) and url:
         return url
     return default_vllm
 
+
 async def eval_episode(client: TaskAppClient, seed: int) -> Dict[str, Any]:
     env_name = "CrafterClassic"
     history: List[Dict[str, Any]] = []
@@ -271,7 +310,10 @@
     turns = 0
 
     # Initialize environment
-    init_cfg: Dict[str, Any] = {"seed": seed, "world_config": {"difficulty": os.getenv("DIFFICULTY", "easy")}}
+    init_cfg: Dict[str, Any] = {
+        "seed": seed,
+        "world_config": {"difficulty": os.getenv("DIFFICULTY", "easy")},
+    }
     created = await client.initialize(env_name, init_cfg)
     env_id = created.get("env_id")
     if not isinstance(env_id, str) or not env_id:
@@ -285,7 +327,9 @@
     try:
         while turns < MAX_TURNS and not done:
             # Ask LLM for actions; fallback to a simple exploratory pair
-            chosen_actions = await _choose_actions_via_llm(client, provider, MODEL, observation, history)
+            chosen_actions = await _choose_actions_via_llm(
+                client, provider, MODEL, observation, history
+            )
             if not chosen_actions:
                 chosen_actions = ["move_up", "do"]
             tool_calls = _expand_actions_to_tool_calls(chosen_actions)
@@ -306,6 +350,7 @@
 
     return {"seed": seed, "turns": turns, "achievements": sorted(achievements)}
 
+
 async def main() -> None:
     # Best-effort load local .env if present (ensures ENVIRONMENT_API_KEY for rollout)
     try:
@@ -322,9 +367,13 @@
     except Exception:
         pass
 
-    parser = argparse.ArgumentParser(description="Baseline eval against task app with optional TOML config")
+    parser = argparse.ArgumentParser(
+        description="Baseline eval against task app with optional TOML config"
+    )
     parser.add_argument("--toml", help="Path to TOML config file", default=None)
-    parser.add_argument("--use-rollout", action="store_true", help="Use server-side rollout endpoint for eval")
+    parser.add_argument(
+        "--use-rollout", action="store_true", help="Use server-side rollout endpoint for eval"
+    )
     args = parser.parse_args()
 
     global TASK_APP_URL, MODEL, NUM_EPISODES, MAX_TURNS, CONCURRENCY
@@ -346,10 +395,14 @@
         if env_url:
             TASK_APP_URL = env_url.rstrip("/")
         else:
-            raise RuntimeError("TASK_APP_URL is a placeholder. Set task_app_url in TOML or export TASK_APP_URL.")
+            raise RuntimeError(
+                "TASK_APP_URL is a placeholder. Set task_app_url in TOML or export TASK_APP_URL."
+            )
 
     print(f"Task App: {TASK_APP_URL}")
-    print(f"Model: {MODEL} Episodes: {NUM_EPISODES} Max turns: {MAX_TURNS} Concurrency: {CONCURRENCY}")
+    print(
+        f"Model: {MODEL} Episodes: {NUM_EPISODES} Max turns: {MAX_TURNS} Concurrency: {CONCURRENCY}"
+    )
     sem = asyncio.Semaphore(max(CONCURRENCY, 1))
     async with TaskAppClient(TASK_APP_URL, api_key=os.getenv("ENVIRONMENT_API_KEY")) as client:
         if args.use_rollout:
@@ -359,6 +412,7 @@
             inf_url = _rollout_inference_url_from_cfg(cfg, default_vllm)
             if not inf_url:
                 raise RuntimeError("Could not resolve inference URL for rollout")
+
             async def _run(seed: int):
                 async with sem:
                     try:
@@ -368,7 +422,14 @@
                             "model": cfg.get("model", MODEL),
                             "inference_url": inf_url,
                         }
-                        for k in ("max_tokens", "temperature", "top_p", "thinking_mode", "thinking_budget", "use_tools"):
+                        for k in (
+                            "max_tokens",
+                            "temperature",
+                            "top_p",
+                            "thinking_mode",
+                            "thinking_budget",
+                            "use_tools",
+                        ):
                             if k in cfg and cfg.get(k) is not None:
                                 policy_cfg[k] = cfg.get(k)
 
@@ -385,8 +446,16 @@
                         ach = []
                         try:
                             trajs = r.get("trajectories") or []
-                            final_obs = (trajs[0].get("final") or {}).get("observation") if trajs and isinstance(trajs[0], dict) else None
-                            ach_map = (final_obs or {}).get("achievements_status") if isinstance(final_obs, dict) else None
+                            final_obs = (
+                                (trajs[0].get("final") or {}).get("observation")
+                                if trajs and isinstance(trajs[0], dict)
+                                else None
+                            )
+                            ach_map = (
+                                (final_obs or {}).get("achievements_status")
+                                if isinstance(final_obs, dict)
+                                else None
+                            )
                             if isinstance(ach_map, dict):
                                 ach = sorted([k for k, v in ach_map.items() if v])
                         except Exception:
@@ -401,7 +470,11 @@
                         return {"seed": seed, "turns": length, "achievements": ach}
                     except Exception as e:
                         return {"seed": seed, "turns": 0, "achievements": [], "error": str(e)}
-            results = await asyncio.gather(*[asyncio.create_task(_run(i)) for i in range(1, NUM_EPISODES + 1)], return_exceptions=False)
+
+            results = await asyncio.gather(
+                *[asyncio.create_task(_run(i)) for i in range(1, NUM_EPISODES + 1)],
+                return_exceptions=False,
+            )
             # Aggregate summary
             counts = [len(r.get("achievements") or []) for r in results if isinstance(r, dict)]
             turns = [int(r.get("turns") or 0) for r in results if isinstance(r, dict)]
@@ -424,11 +497,16 @@
             }
             print(json.dumps(summary, indent=2))
         else:
+
             async def _run(seed: int):
                 async with sem:
                     return await eval_episode(client, seed)
-            results = await asyncio.gather(*[asyncio.create_task(_run(i)) for i in range(1, NUM_EPISODES + 1)])
+
+            results = await asyncio.gather(
+                *[asyncio.create_task(_run(i)) for i in range(1, NUM_EPISODES + 1)]
+            )
             print(json.dumps({"episodes": results}, indent=2))
 
+
 if __name__ == "__main__":
     asyncio.run(main())