synth-ai 0.2.9.dev5-py3-none-any.whl → 0.2.9.dev7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
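A minimal sketch of how a wheel-to-wheel comparison like this one can be reproduced locally, assuming both wheels have already been downloaded (e.g. with `pip download synth-ai==0.2.9.dev5 --no-deps` and `pip download synth-ai==0.2.9.dev7 --no-deps`); the filenames below follow standard wheel naming and are assumptions, not paths taken from this page.

import difflib
import zipfile

# Assumed local filenames for the two published wheels.
OLD = "synth_ai-0.2.9.dev5-py3-none-any.whl"
NEW = "synth_ai-0.2.9.dev7-py3-none-any.whl"

def read_members(path: str) -> dict[str, list[str]]:
    """Map each text file inside the wheel to its lines (with line endings kept)."""
    out: dict[str, list[str]] = {}
    with zipfile.ZipFile(path) as zf:
        for name in zf.namelist():
            if name.endswith((".py", ".txt", ".toml", ".md")):
                out[name] = zf.read(name).decode("utf-8", errors="replace").splitlines(keepends=True)
    return out

old_files = read_members(OLD)
new_files = read_members(NEW)
for name in sorted(set(old_files) | set(new_files)):
    # Emit a unified diff per member; files present in only one wheel diff against empty.
    for line in difflib.unified_diff(
        old_files.get(name, []),
        new_files.get(name, []),
        fromfile=f"dev5/{name}",
        tofile=f"dev7/{name}",
    ):
        print(line, end="")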

Potentially problematic release: this version of synth-ai might be problematic.

Files changed (155)
  1. examples/common_old/backend.py +0 -1
  2. examples/crafter_debug_render.py +15 -6
  3. examples/evals_old/compare_models.py +1 -0
  4. examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +6 -2
  5. examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +4 -4
  6. examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +4 -3
  7. examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +6 -2
  8. examples/finetuning_old/synth_qwen_v1/finetune.py +1 -1
  9. examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +4 -4
  10. examples/finetuning_old/synth_qwen_v1/infer.py +1 -2
  11. examples/finetuning_old/synth_qwen_v1/poll.py +4 -2
  12. examples/finetuning_old/synth_qwen_v1/prepare_data.py +8 -8
  13. examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +5 -4
  14. examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +11 -8
  15. examples/finetuning_old/synth_qwen_v1/run_ft_job.py +17 -12
  16. examples/finetuning_old/synth_qwen_v1/upload_data.py +1 -1
  17. examples/finetuning_old/synth_qwen_v1/util.py +7 -2
  18. examples/rl/configs/eval_base_qwen.toml +1 -1
  19. examples/rl/configs/rl_from_base_qwen17.toml +1 -1
  20. examples/rl/download_dataset.py +26 -10
  21. examples/rl/run_eval.py +17 -15
  22. examples/rl/run_rl_and_save.py +24 -7
  23. examples/rl/task_app/math_single_step.py +128 -11
  24. examples/rl/task_app/math_task_app.py +11 -3
  25. examples/rl_old/task_app.py +222 -53
  26. examples/warming_up_to_rl/analyze_trace_db.py +7 -5
  27. examples/warming_up_to_rl/export_trace_sft.py +141 -16
  28. examples/warming_up_to_rl/groq_test.py +11 -4
  29. examples/warming_up_to_rl/manage_secrets.py +15 -6
  30. examples/warming_up_to_rl/readme.md +9 -2
  31. examples/warming_up_to_rl/run_eval.py +108 -30
  32. examples/warming_up_to_rl/run_fft_and_save.py +128 -52
  33. examples/warming_up_to_rl/run_local_rollout.py +87 -36
  34. examples/warming_up_to_rl/run_local_rollout_modal.py +113 -25
  35. examples/warming_up_to_rl/run_local_rollout_parallel.py +80 -16
  36. examples/warming_up_to_rl/run_local_rollout_traced.py +125 -20
  37. examples/warming_up_to_rl/run_rl_and_save.py +31 -7
  38. examples/warming_up_to_rl/run_rollout_remote.py +37 -10
  39. examples/warming_up_to_rl/task_app/grpo_crafter.py +90 -27
  40. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +9 -27
  41. examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +46 -108
  42. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -1
  43. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +1 -1
  44. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -1
  45. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +50 -17
  46. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +35 -21
  47. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +8 -4
  48. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +29 -26
  49. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +1 -1
  50. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +17 -13
  51. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +1 -1
  52. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +106 -63
  53. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +82 -84
  54. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +76 -59
  55. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +1 -1
  56. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +43 -49
  57. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +5 -15
  58. synth_ai/__init__.py +1 -0
  59. synth_ai/api/train/builders.py +34 -10
  60. synth_ai/api/train/cli.py +172 -32
  61. synth_ai/api/train/config_finder.py +59 -4
  62. synth_ai/api/train/env_resolver.py +32 -14
  63. synth_ai/api/train/pollers.py +11 -3
  64. synth_ai/api/train/task_app.py +4 -1
  65. synth_ai/api/train/utils.py +20 -4
  66. synth_ai/cli/__init__.py +11 -4
  67. synth_ai/cli/balance.py +1 -1
  68. synth_ai/cli/demo.py +19 -5
  69. synth_ai/cli/rl_demo.py +75 -16
  70. synth_ai/cli/root.py +116 -37
  71. synth_ai/cli/task_apps.py +1276 -186
  72. synth_ai/cli/traces.py +1 -0
  73. synth_ai/cli/turso.py +73 -0
  74. synth_ai/core/experiment.py +0 -2
  75. synth_ai/demo_registry.py +67 -30
  76. synth_ai/demos/core/cli.py +493 -164
  77. synth_ai/demos/demo_task_apps/core.py +50 -6
  78. synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +2 -3
  79. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +36 -28
  80. synth_ai/demos/demo_task_apps/math/_common.py +1 -2
  81. synth_ai/demos/demo_task_apps/math/deploy_modal.py +0 -2
  82. synth_ai/demos/demo_task_apps/math/modal_task_app.py +168 -65
  83. synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
  84. synth_ai/environments/examples/bandit/engine.py +12 -4
  85. synth_ai/environments/examples/bandit/taskset.py +4 -4
  86. synth_ai/environments/reproducibility/tree.py +3 -1
  87. synth_ai/environments/service/core_routes.py +6 -2
  88. synth_ai/evals/base.py +0 -2
  89. synth_ai/experimental/synth_oss.py +11 -12
  90. synth_ai/handshake.py +3 -1
  91. synth_ai/http_client.py +31 -7
  92. synth_ai/inference/__init__.py +0 -2
  93. synth_ai/inference/client.py +8 -4
  94. synth_ai/jobs/client.py +40 -10
  95. synth_ai/learning/client.py +33 -8
  96. synth_ai/learning/config.py +0 -2
  97. synth_ai/learning/constants.py +0 -2
  98. synth_ai/learning/ft_client.py +6 -3
  99. synth_ai/learning/health.py +9 -2
  100. synth_ai/learning/jobs.py +17 -5
  101. synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +1 -3
  102. synth_ai/learning/prompts/random_search.py +4 -1
  103. synth_ai/learning/prompts/run_random_search_banking77.py +6 -1
  104. synth_ai/learning/rl_client.py +42 -14
  105. synth_ai/learning/sse.py +0 -2
  106. synth_ai/learning/validators.py +6 -2
  107. synth_ai/lm/caching/ephemeral.py +1 -3
  108. synth_ai/lm/core/exceptions.py +0 -2
  109. synth_ai/lm/core/main.py +13 -1
  110. synth_ai/lm/core/synth_models.py +0 -1
  111. synth_ai/lm/core/vendor_clients.py +4 -2
  112. synth_ai/lm/overrides.py +2 -2
  113. synth_ai/lm/vendors/core/anthropic_api.py +7 -7
  114. synth_ai/lm/vendors/core/openai_api.py +2 -0
  115. synth_ai/lm/vendors/openai_standard.py +3 -1
  116. synth_ai/lm/vendors/openai_standard_responses.py +6 -3
  117. synth_ai/lm/vendors/supported/custom_endpoint.py +1 -3
  118. synth_ai/lm/vendors/synth_client.py +37 -10
  119. synth_ai/rl/__init__.py +0 -1
  120. synth_ai/rl/contracts.py +0 -2
  121. synth_ai/rl/env_keys.py +6 -1
  122. synth_ai/task/__init__.py +1 -0
  123. synth_ai/task/apps/__init__.py +11 -11
  124. synth_ai/task/auth.py +29 -17
  125. synth_ai/task/client.py +3 -1
  126. synth_ai/task/contracts.py +1 -0
  127. synth_ai/task/datasets.py +3 -1
  128. synth_ai/task/errors.py +3 -2
  129. synth_ai/task/health.py +0 -2
  130. synth_ai/task/json.py +0 -1
  131. synth_ai/task/proxy.py +2 -5
  132. synth_ai/task/rubrics.py +9 -3
  133. synth_ai/task/server.py +31 -5
  134. synth_ai/task/tracing_utils.py +8 -3
  135. synth_ai/task/validators.py +0 -1
  136. synth_ai/task/vendors.py +0 -1
  137. synth_ai/tracing_v3/db_config.py +26 -1
  138. synth_ai/tracing_v3/decorators.py +1 -0
  139. synth_ai/tracing_v3/examples/basic_usage.py +3 -2
  140. synth_ai/tracing_v3/hooks.py +2 -0
  141. synth_ai/tracing_v3/replica_sync.py +1 -0
  142. synth_ai/tracing_v3/session_tracer.py +24 -3
  143. synth_ai/tracing_v3/storage/base.py +4 -1
  144. synth_ai/tracing_v3/storage/factory.py +0 -1
  145. synth_ai/tracing_v3/turso/manager.py +102 -38
  146. synth_ai/tracing_v3/turso/models.py +4 -1
  147. synth_ai/tracing_v3/utils.py +1 -0
  148. synth_ai/v0/tracing/upload.py +32 -135
  149. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/METADATA +1 -1
  150. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/RECORD +154 -154
  151. synth_ai/install_sqld.sh +0 -40
  152. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/WHEEL +0 -0
  153. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/entry_points.txt +0 -0
  154. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/licenses/LICENSE +0 -0
  155. {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,7 @@ from .registry import registry
 
 logger = logging.getLogger(__name__)
 
+
 # --- Seeding utilities (robust, optional deps) ---
 def _set_global_seed(seed_value: int) -> Dict[str, Any]:
     """Set global RNG seeds across common libraries; return details for logging/restoration.
@@ -29,18 +30,21 @@ def _set_global_seed(seed_value: int) -> Dict[str, Any]:
     seeded: Dict[str, Any] = {"seed": int(seed_value), "libs": []}
     try:
         import random as _random  # type: ignore
+
         _random.seed(seed_value)
         seeded["libs"].append("random")
     except Exception:
         pass
     try:
         import numpy as _np  # type: ignore
+
         _np.random.seed(seed_value)
         seeded["libs"].append("numpy")
     except Exception:
         pass
     try:
         import torch as _torch  # type: ignore
+
         if hasattr(_torch, "manual_seed"):
             _torch.manual_seed(seed_value)
             seeded["libs"].append("torch")
@@ -62,12 +66,14 @@ def _set_global_seed(seed_value: int) -> Dict[str, Any]:
         pass
     return seeded
 
+
 def _clear_seed_side_effects() -> None:
     """Best-effort cleanup to avoid global deterministic side-effects between requests."""
     # We cannot truly restore prior RNG states without capturing them; we just avoid
     # leaving aggressive deterministic flags enabled where it matters.
     try:
         import torch as _torch  # type: ignore
+
         try:
             if getattr(_torch, "backends", None) and getattr(_torch.backends, "cudnn", None):
                 # Re-enable cudnn.benchmark default True only if it was True; safest is False -> leave as is.
@@ -78,6 +84,7 @@ def _clear_seed_side_effects() -> None:
     except Exception:
         pass
 
+
 router = APIRouter()
 
 
@@ -161,11 +168,7 @@ def compute_stepwise_reward(
     prev_map = prev_achievements or {}
     next_map = new_achievements or {}
 
-    unlocked = [
-        name
-        for name, value in next_map.items()
-        if value and not prev_map.get(name, False)
-    ]
+    unlocked = [name for name, value in next_map.items() if value and not prev_map.get(name, False)]
     indicator = 1 if unlocked else 0
     reward_value = float(indicator_lambda) * indicator
 
@@ -227,7 +230,9 @@ class RolloutTracingContext:
         self.sft_records: list[dict[str, Any]] = []
         self.latest_system_messages: list[str] = []
         self.latest_user_messages: list[str] = []
-        self.trace_format = (getattr(request.record, "trace_format", "compact") or "compact").lower()
+        self.trace_format = (
+            getattr(request.record, "trace_format", "compact") or "compact"
+        ).lower()
         self.return_trace = bool(getattr(request.record, "return_trace", False))
         self.sft_output_dir = getattr(fastapi_request.app.state, "sft_output_dir", None)
         self.session_trace = None
@@ -257,7 +262,9 @@ class RolloutTracingContext:
         except Exception as exc:
             logger.debug("TRACING_INIT_FAIL: %s", exc)
         try:
-            await self.tracer.start_session(session_id=self.run_id, metadata=dict(self.metadata_base))
+            await self.tracer.start_session(
+                session_id=self.run_id, metadata=dict(self.metadata_base)
+            )
         except Exception as exc:
             logger.warning("TRACING_START_FAIL: %s", exc)
             self.enabled = False
@@ -379,17 +386,15 @@ class RolloutTracingContext:
         input_tokens = usage.get("input_tokens") or usage.get("prompt_tokens")
         output_tokens = usage.get("output_tokens") or usage.get("completion_tokens")
         total_tokens = usage.get("total_tokens")
-        cost_usd = (
-            usage.get("cost_usd")
-            or usage.get("cost")
-            or usage.get("total_cost")
-        )
+        cost_usd = usage.get("cost_usd") or usage.get("cost") or usage.get("total_cost")
 
         assistant_message = None
         choices = inference_response.get("choices") or []
         if choices:
             assistant_message = choices[0].get("message") or {}
-        assistant_content = assistant_message.get("content") if isinstance(assistant_message, dict) else None
+        assistant_content = (
+            assistant_message.get("content") if isinstance(assistant_message, dict) else None
+        )
 
         raw_response = self._content_to_text(assistant_content)
         if not raw_response:
@@ -397,7 +402,9 @@ class RolloutTracingContext:
 
         base_response = BaseLMResponse(
             raw_response=raw_response,
-            tool_calls=assistant_message.get("tool_calls") if isinstance(assistant_message, dict) else None,
+            tool_calls=assistant_message.get("tool_calls")
+            if isinstance(assistant_message, dict)
+            else None,
             usage=usage or None,
             api_type="chat_completions",
         )
@@ -469,7 +476,9 @@ class RolloutTracingContext:
             ),
             "assistant": {
                 "content": assistant_text,
-                "tool_calls": assistant_message.get("tool_calls") if isinstance(assistant_message, dict) else [],
+                "tool_calls": assistant_message.get("tool_calls")
+                if isinstance(assistant_message, dict)
+                else [],
             },
             "timestamp": datetime.utcnow().isoformat(),
         }
@@ -488,11 +497,19 @@ class RolloutTracingContext:
             return None
 
         try:
-            prev_summary = _summarize_observation_for_storage(env_handle, prev_obs or {}) if prev_obs is not None else None
+            prev_summary = (
+                _summarize_observation_for_storage(env_handle, prev_obs or {})
+                if prev_obs is not None
+                else None
+            )
        except Exception:
            prev_summary = None
        try:
-            next_summary = _summarize_observation_for_storage(env_handle, next_obs or {}) if next_obs is not None else None
+            next_summary = (
+                _summarize_observation_for_storage(env_handle, next_obs or {})
+                if next_obs is not None
+                else None
+            )
        except Exception:
            next_summary = None
 
@@ -640,7 +657,11 @@ class RolloutTracingContext:
             "lm_calls": self.lm_calls_summary,
             "decision_rewards": self.decision_rewards,
         }
-def _summarize_observation_for_storage(env_handle: Any, observation: Dict[str, Any]) -> Dict[str, Any]:
+
+
+def _summarize_observation_for_storage(
+    env_handle: Any, observation: Dict[str, Any]
+) -> Dict[str, Any]:
     """Return a compact dict for trajectory storage instead of the raw observation.
 
     - For Crafter, use the same summary used for the policy user prompt
@@ -652,9 +673,12 @@ def _summarize_observation_for_storage(env_handle: Any, observation: Dict[str, A
     except Exception:
         _CrafterWrapper = None  # type: ignore
 
-    if _CrafterWrapper is not None and isinstance(getattr(env_handle, "env", None), _CrafterWrapper):
+    if _CrafterWrapper is not None and isinstance(
+        getattr(env_handle, "env", None), _CrafterWrapper
+    ):
         try:
             from .envs.crafter.shared import format_observation as _fmt  # type: ignore
+
             text = _fmt(observation or {})
             return {"text": text}
         except Exception:
@@ -671,8 +695,12 @@ def _summarize_observation_for_storage(env_handle: Any, observation: Dict[str, A
             summary = {
                 "position": pos,
                 "health": health,
-                "inventory_keys": sorted([k for k, v in (inv or {}).items() if v])[:10] if isinstance(inv, dict) else None,
-                "achievements_unlocked": sorted([k for k, v in (ach or {}).items() if v])[:10] if isinstance(ach, dict) else None,
+                "inventory_keys": sorted([k for k, v in (inv or {}).items() if v])[:10]
+                if isinstance(inv, dict)
+                else None,
+                "achievements_unlocked": sorted([k for k, v in (ach or {}).items() if v])[:10]
+                if isinstance(ach, dict)
+                else None,
             }
             return {"text": json.dumps(summary, ensure_ascii=False)}
         except Exception:
@@ -685,7 +713,6 @@ def _summarize_observation_for_storage(env_handle: Any, observation: Dict[str, A
     return {"text": ""}
 
 
-
 class RunAbortRequest(BaseModel):
     run_id: str
 
@@ -857,9 +884,7 @@ async def execute_rollout(
     # Propagate training_session_id via env config for downstream usage
     _env_config = dict(request.env.config or {})
     if request.training_session_id is not None:
-        _env_config.setdefault(
-            "training_session_id", request.training_session_id
-        )
+        _env_config.setdefault("training_session_id", request.training_session_id)
     env_response = await create_environment(
         EnvCreateRequest(
             env_name=request.env.env_name,
@@ -893,9 +918,7 @@ async def execute_rollout(
     # Propagate training_session_id and synth_base_url via policy config
     _policy_config = dict(request.policy.config or {})
     if request.training_session_id is not None:
-        _policy_config.setdefault(
-            "training_session_id", request.training_session_id
-        )
+        _policy_config.setdefault("training_session_id", request.training_session_id)
     if request.synth_base_url is not None:
         _policy_config.setdefault("synth_base_url", request.synth_base_url)
     policy_response = await create_policy(
@@ -1065,7 +1088,10 @@ async def execute_rollout(
                 _timing["decision_ms"] = decision_ms
                 if last_env_step_ms is not None:
                     _timing.setdefault("env_step_ms", float(last_env_step_ms))
-                    _timing.setdefault("overhead_ms", max(0.0, decision_ms - float(last_env_step_ms)))
+                    _timing.setdefault(
+                        "overhead_ms",
+                        max(0.0, decision_ms - float(last_env_step_ms)),
+                    )
                 else:
                     _timing.setdefault("overhead_ms", 0.0)
                 _meta["timing"] = _timing
@@ -1107,9 +1133,7 @@ async def execute_rollout(
                 _first_guess = None
                 if _count > 0 and isinstance(_prev_calls[0], dict):
                     _args = (
-                        _prev_calls[0]["arguments"]
-                        if "arguments" in _prev_calls[0]
-                        else None
+                        _prev_calls[0]["arguments"] if "arguments" in _prev_calls[0] else None
                     )
                     if isinstance(_args, str):
                         import json as _json
@@ -1119,9 +1143,9 @@ async def execute_rollout(
                     except Exception:
                         _args = {}
                     if isinstance(_args, dict):
-                        _first_guess = (
-                            _args["guess"] if "guess" in _args else None
-                        ) or (_args["word"] if "word" in _args else None)
+                        _first_guess = (_args["guess"] if "guess" in _args else None) or (
+                            _args["word"] if "word" in _args else None
+                        )
                 logger.info(
                     "POLICY_METADATA: prev_tool_calls=%d first_guess=%r has_prev_env_result=%s",
                     _count,
@@ -1377,7 +1401,9 @@ async def execute_rollout(
                         (env_step_end - float(last_agent_response_ts)) * 1000.0,
                     )
                     timing_last["decision_ms"] = decision_ms
-                    timing_last.setdefault("overhead_ms", max(0.0, decision_ms - env_step_duration_ms))
+                    timing_last.setdefault(
+                        "overhead_ms", max(0.0, decision_ms - env_step_duration_ms)
+                    )
                 except Exception:
                     pass
                 if decision_open:
@@ -1409,9 +1435,7 @@ async def execute_rollout(
                 # Attach policy meta from the immediately preceding agent step
                 try:
                     prev_meta = {}
-                    if "policy_response" in locals() and isinstance(
-                        policy_response.meta, dict
-                    ):  # type: ignore[name-defined]
+                    if "policy_response" in locals() and isinstance(policy_response.meta, dict):  # type: ignore[name-defined]
                         prev_meta = policy_response.meta
                     if prev_meta:
                         _info = dict(_info)
@@ -1452,9 +1476,7 @@ async def execute_rollout(
                     reward_stepwise = float(stats.get("reward", 0.0))
                     stepwise_indicator_sum += float(stats.get("indicator", 0.0))
                     stepwise_reward_sum += reward_stepwise
-                    stepwise_new_achievements_total += int(
-                        stats.get("new_achievements_count", 0.0)
-                    )
+                    stepwise_new_achievements_total += int(stats.get("new_achievements_count", 0.0))
                     if not isinstance(_info, dict):
                         _info = {}
                     else:
@@ -1470,7 +1492,9 @@ async def execute_rollout(
                     # Prepare stable lists for logging/metadata
                     all_list = sorted(list(turned_true))
                     # Ensure nested meta exists
-                    meta_block = _info.get("meta") if isinstance(_info.get("meta"), dict) else {}
+                    meta_block = (
+                        _info.get("meta") if isinstance(_info.get("meta"), dict) else {}
+                    )
                     decision_rewards = {
                         "turn": int(decision_index),
                         "ach_delta": ach_delta,
@@ -1521,9 +1545,7 @@ async def execute_rollout(
                         EnvResetRequest,
                     )
 
-                    reset_response = await reset_environment(
-                        EnvResetRequest(env_id=env_id)
-                    )
+                    reset_response = await reset_environment(EnvResetRequest(env_id=env_id))
                     current_obs = reset_response.observation
                 elif request.on_done == "terminate":
                     break
@@ -1544,15 +1566,11 @@ async def execute_rollout(
         ):
             try:
                 final_now = last_env_step_completed_ts or _time.perf_counter()
-                final_decision_ms = max(
-                    0.0, (final_now - float(last_agent_response_ts)) * 1000.0
-                )
+                final_decision_ms = max(0.0, (final_now - float(last_agent_response_ts)) * 1000.0)
                 timing_final = last_policy_meta.setdefault("timing", {})
                 timing_final["decision_ms"] = final_decision_ms
                 if last_env_step_ms is not None:
-                    timing_final.setdefault(
-                        "env_step_ms", float(last_env_step_ms)
-                    )
+                    timing_final.setdefault("env_step_ms", float(last_env_step_ms))
                     timing_final.setdefault(
                         "overhead_ms",
                         max(0.0, final_decision_ms - float(last_env_step_ms)),
@@ -1601,10 +1619,11 @@ async def execute_rollout(
                 for step in trajectory_steps:
                     formatted_steps.append({"tool_calls": step.tool_calls or []})
 
-                if get_wordle_rollout_summary is not None and log_wordle_rollout_summary is not None:
-                    summary = get_wordle_rollout_summary(
-                        formatted_steps, current_obs, env_handle
-                    )
+                if (
+                    get_wordle_rollout_summary is not None
+                    and log_wordle_rollout_summary is not None
+                ):
+                    summary = get_wordle_rollout_summary(formatted_steps, current_obs, env_handle)
                     log_wordle_rollout_summary(request.run_id, summary)
             except ImportError:
                 # Wordle helpers not available, skip Wordle-specific logging
@@ -1681,9 +1700,7 @@ async def execute_rollout(
             except Exception:
                 pass
         except Exception as _te:
-            logger.warning(
-                f"ROLL_OUT: failed to terminate environment {created_env_id}: {_te}"
-            )
+            logger.warning(f"ROLL_OUT: failed to terminate environment {created_env_id}: {_te}")
 
     # Best-effort policy cleanup if we created one (avoid reuse across rollouts)
     try:
@@ -2,4 +2,4 @@
 
 from .volume import VolumeStorage, storage
 
-__all__ = ["VolumeStorage", "storage"]
+__all__ = ["VolumeStorage", "storage"]
@@ -13,10 +13,10 @@ from typing import Any, Dict, Optional
 
 class VolumeStorage:
     """Helpers for Modal Volume storage operations."""
-
+
     def __init__(self, base_path: str = "/data/state") -> None:
         self.base_path = Path(base_path)
-
+
     def get_snapshot_path(
         self,
         rl_run_id: str,
@@ -27,21 +27,15 @@ class VolumeStorage:
         # Use first 2 chars of snapshot_id for sharding
         shard1 = snapshot_id[:2] if len(snapshot_id) >= 2 else "00"
         shard2 = snapshot_id[2:4] if len(snapshot_id) >= 4 else "00"
-
+
         return (
-            self.base_path
-            / "runs"
-            / rl_run_id
-            / kind
-            / shard1
-            / shard2
-            / f"{snapshot_id}.tar.gz"
+            self.base_path / "runs" / rl_run_id / kind / shard1 / shard2 / f"{snapshot_id}.tar.gz"
         )
-
+
     def get_index_path(self, rl_run_id: str) -> Path:
         """Get the index file path for a run."""
         return self.base_path / "runs" / rl_run_id / "index" / "meta.jsonl"
-
+
     def write_snapshot_atomic(
         self,
         path: Path,
@@ -50,17 +44,17 @@ class VolumeStorage:
         """Atomically write a snapshot archive to disk."""
         # Ensure parent directory exists
         path.parent.mkdir(parents=True, exist_ok=True)
-
+
         # Write to temp file first
         tmp_path = path.with_suffix(".tmp")
         with open(tmp_path, "wb") as f:
             f.write(archive_bytes)
             f.flush()
             os.fsync(f.fileno())
-
+
         # Atomic rename
         os.replace(tmp_path, path)
-
+
     def create_archive(
         self,
         state_dict: Dict[str, Any],
@@ -69,61 +63,61 @@ class VolumeStorage:
         """Create a tar.gz archive with state and metadata."""
         with tempfile.TemporaryDirectory() as tmpdir:
             tmppath = Path(tmpdir)
-
+
             # Write state.json
             state_path = tmppath / "state.json"
             with open(state_path, "w") as f:
                 json.dump(state_dict, f, sort_keys=True, indent=2)
-
+
             # Write meta.json
             meta_path = tmppath / "meta.json"
             with open(meta_path, "w") as f:
                 json.dump(meta, f, sort_keys=True, indent=2)
-
+
             # Create tar archive
             tar_path = tmppath / "archive.tar"
             with tarfile.open(tar_path, "w") as tar:
                 tar.add(state_path, arcname="state.json")
                 tar.add(meta_path, arcname="meta.json")
-
+
             # Compress with gzip
             with open(tar_path, "rb") as f:
                 tar_bytes = f.read()
-
+
             compressed = gzip.compress(tar_bytes, compresslevel=6)
-
+
             return compressed
-
+
     def extract_archive(self, archive_bytes: bytes) -> tuple[Dict[str, Any], Dict[str, Any]]:
         """Extract state and metadata from a tar.gz archive."""
         # Decompress
         tar_bytes = gzip.decompress(archive_bytes)
-
+
         with tempfile.TemporaryDirectory() as tmpdir:
             tmppath = Path(tmpdir)
-
+
             # Write tar bytes to temp file
             tar_path = tmppath / "archive.tar"
             with open(tar_path, "wb") as f:
                 f.write(tar_bytes)
-
+
             # Extract tar
             with tarfile.open(tar_path, "r") as tar:
                 tar.extractall(tmppath)
-
+
             # Read state and meta
             with open(tmppath / "state.json", "r") as f:
                 state = json.load(f)
-
+
             with open(tmppath / "meta.json", "r") as f:
                 meta = json.load(f)
-
+
             return state, meta
-
+
     def compute_snapshot_id(self, archive_bytes: bytes) -> str:
         """Compute content-addressed snapshot ID."""
         return hashlib.sha256(archive_bytes).hexdigest()
-
+
     def save_snapshot(
         self,
         rl_run_id: str,
@@ -140,33 +134,33 @@ class VolumeStorage:
             "schema_version": "1.0",
             "created_at": datetime.utcnow().isoformat(),
         }
-
+
        if parent_snapshot_id:
            meta["parent_snapshot_id"] = parent_snapshot_id
-
+
        if config:
            config_str = json.dumps(config, sort_keys=True)
            meta["config_hash"] = hashlib.sha256(config_str.encode()).hexdigest()
-
+
        # Create archive
        archive_bytes = self.create_archive(state_dict, meta)
-
+
        # Compute snapshot ID
        snapshot_id = self.compute_snapshot_id(archive_bytes)
        meta["snapshot_id"] = snapshot_id
-
+
        # Recreate archive with snapshot_id in metadata
        archive_bytes = self.create_archive(state_dict, meta)
-
+
        # Get path and write
        path = self.get_snapshot_path(rl_run_id, kind, snapshot_id)
        self.write_snapshot_atomic(path, archive_bytes)
-
+
        # Append to index
        self.append_to_index(rl_run_id, meta)
-
+
        return snapshot_id, str(path), len(archive_bytes)
-
+
     def load_snapshot(
         self,
         rl_run_id: str,
@@ -175,16 +169,16 @@ class VolumeStorage:
     ) -> tuple[Dict[str, Any], Dict[str, Any]]:
         """Load a snapshot and return (state_dict, meta)."""
         path = self.get_snapshot_path(rl_run_id, kind, snapshot_id)
-
+
         if not path.exists():
             raise FileNotFoundError(f"Snapshot not found: {path}")
-
+
         with open(path, "rb") as f:
             archive_bytes = f.read()
-
+
         state, meta = self.extract_archive(archive_bytes)
         return state, meta
-
+
     def append_to_index(
         self,
         rl_run_id: str,
@@ -193,25 +187,25 @@ class VolumeStorage:
         """Append metadata to the run's index file."""
         index_path = self.get_index_path(rl_run_id)
         index_path.parent.mkdir(parents=True, exist_ok=True)
-
+
         with open(index_path, "a") as f:
             f.write(json.dumps(meta) + "\n")
-
+
     def read_index(self, rl_run_id: str) -> list[Dict[str, Any]]:
         """Read all entries from a run's index file."""
         index_path = self.get_index_path(rl_run_id)
-
+
         if not index_path.exists():
             return []
-
+
         entries = []
         with open(index_path, "r") as f:
             for line in f:
                 if line.strip():
                     entries.append(json.loads(line))
-
+
         return entries
 
 
 # Global storage instance
-storage = VolumeStorage()
+storage = VolumeStorage()
@@ -82,15 +82,11 @@ async def test_service():
             print(f" Error: {response.status_code} - {response.text}")
         else:
             step_data = response.json()
-            print(
-                f" Step result - done: {step_data['done']}, reward: {step_data.get('reward')}"
-            )
+            print(f" Step result - done: {step_data['done']}, reward: {step_data.get('reward')}")
 
         # Test 6: Environment snapshot
         print("\n6. Creating environment snapshot...")
-        response = await client.post(
-            f"{base_url}/env/snapshot", json={"env_id": env_id}
-        )
+        response = await client.post(f"{base_url}/env/snapshot", json={"env_id": env_id})
         if response.status_code != 200:
             print(f" Error: {response.status_code} - {response.text}")
         else:
@@ -100,9 +96,7 @@ async def test_service():
 
         # Test 7: Policy snapshot
         print("\n7. Creating policy snapshot...")
-        response = await client.post(
-            f"{base_url}/policy/snapshot", json={"policy_id": policy_id}
-        )
+        response = await client.post(f"{base_url}/policy/snapshot", json={"policy_id": policy_id})
         if response.status_code != 200:
             print(f" Error: {response.status_code} - {response.text}")
         else:
@@ -121,9 +115,7 @@ async def test_service():
 
         # Test 9: Terminate environment
         print("\n9. Terminating environment...")
-        response = await client.post(
-            f"{base_url}/env/terminate", json={"env_id": env_id}
-        )
+        response = await client.post(f"{base_url}/env/terminate", json={"env_id": env_id})
         if response.status_code != 200:
             print(f" Error: {response.status_code} - {response.text}")
         else:
@@ -131,9 +123,7 @@ async def test_service():
 
         # Test 10: Terminate policy
         print("\n10. Terminating policy...")
-        response = await client.post(
-            f"{base_url}/policy/terminate", json={"policy_id": policy_id}
-        )
+        response = await client.post(f"{base_url}/policy/terminate", json={"policy_id": policy_id})
         if response.status_code != 200:
             print(f" Error: {response.status_code} - {response.text}")
         else:
synth_ai/__init__.py CHANGED
@@ -5,6 +5,7 @@ Synth AI - Software for aiding the best and multiplying the will.
 # Environment exports - moved from synth-env
 from synth_ai.environments import *  # noqa
 import synth_ai.environments as environments  # expose module name for __all__
+
 try:
     from synth_ai.lm.core.main import LM  # Moved from zyk to lm for better organization
 except Exception:  # allow minimal imports (e.g., tracing) without LM stack