synth-ai 0.2.9.dev5__py3-none-any.whl → 0.2.9.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/common_old/backend.py +0 -1
- examples/crafter_debug_render.py +15 -6
- examples/evals_old/compare_models.py +1 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +6 -2
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +4 -4
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +4 -3
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +6 -2
- examples/finetuning_old/synth_qwen_v1/finetune.py +1 -1
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +4 -4
- examples/finetuning_old/synth_qwen_v1/infer.py +1 -2
- examples/finetuning_old/synth_qwen_v1/poll.py +4 -2
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +8 -8
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +5 -4
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +11 -8
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +17 -12
- examples/finetuning_old/synth_qwen_v1/upload_data.py +1 -1
- examples/finetuning_old/synth_qwen_v1/util.py +7 -2
- examples/rl/configs/eval_base_qwen.toml +1 -1
- examples/rl/configs/rl_from_base_qwen17.toml +1 -1
- examples/rl/download_dataset.py +26 -10
- examples/rl/run_eval.py +17 -15
- examples/rl/run_rl_and_save.py +24 -7
- examples/rl/task_app/math_single_step.py +128 -11
- examples/rl/task_app/math_task_app.py +11 -3
- examples/rl_old/task_app.py +222 -53
- examples/warming_up_to_rl/analyze_trace_db.py +7 -5
- examples/warming_up_to_rl/export_trace_sft.py +141 -16
- examples/warming_up_to_rl/groq_test.py +11 -4
- examples/warming_up_to_rl/manage_secrets.py +15 -6
- examples/warming_up_to_rl/readme.md +9 -2
- examples/warming_up_to_rl/run_eval.py +108 -30
- examples/warming_up_to_rl/run_fft_and_save.py +128 -52
- examples/warming_up_to_rl/run_local_rollout.py +87 -36
- examples/warming_up_to_rl/run_local_rollout_modal.py +113 -25
- examples/warming_up_to_rl/run_local_rollout_parallel.py +80 -16
- examples/warming_up_to_rl/run_local_rollout_traced.py +125 -20
- examples/warming_up_to_rl/run_rl_and_save.py +31 -7
- examples/warming_up_to_rl/run_rollout_remote.py +37 -10
- examples/warming_up_to_rl/task_app/grpo_crafter.py +90 -27
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +9 -27
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +46 -108
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +50 -17
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +35 -21
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +8 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +29 -26
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +17 -13
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +106 -63
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +82 -84
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +76 -59
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +43 -49
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +5 -15
- synth_ai/__init__.py +1 -0
- synth_ai/api/train/builders.py +34 -10
- synth_ai/api/train/cli.py +172 -32
- synth_ai/api/train/config_finder.py +59 -4
- synth_ai/api/train/env_resolver.py +32 -14
- synth_ai/api/train/pollers.py +11 -3
- synth_ai/api/train/task_app.py +4 -1
- synth_ai/api/train/utils.py +20 -4
- synth_ai/cli/__init__.py +11 -4
- synth_ai/cli/balance.py +1 -1
- synth_ai/cli/demo.py +19 -5
- synth_ai/cli/rl_demo.py +75 -16
- synth_ai/cli/root.py +116 -37
- synth_ai/cli/task_apps.py +1276 -186
- synth_ai/cli/traces.py +1 -0
- synth_ai/cli/turso.py +73 -0
- synth_ai/core/experiment.py +0 -2
- synth_ai/demo_registry.py +67 -30
- synth_ai/demos/core/cli.py +493 -164
- synth_ai/demos/demo_task_apps/core.py +50 -6
- synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +2 -3
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +36 -28
- synth_ai/demos/demo_task_apps/math/_common.py +1 -2
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +0 -2
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +168 -65
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
- synth_ai/environments/examples/bandit/engine.py +12 -4
- synth_ai/environments/examples/bandit/taskset.py +4 -4
- synth_ai/environments/reproducibility/tree.py +3 -1
- synth_ai/environments/service/core_routes.py +6 -2
- synth_ai/evals/base.py +0 -2
- synth_ai/experimental/synth_oss.py +11 -12
- synth_ai/handshake.py +3 -1
- synth_ai/http_client.py +31 -7
- synth_ai/inference/__init__.py +0 -2
- synth_ai/inference/client.py +8 -4
- synth_ai/jobs/client.py +40 -10
- synth_ai/learning/client.py +33 -8
- synth_ai/learning/config.py +0 -2
- synth_ai/learning/constants.py +0 -2
- synth_ai/learning/ft_client.py +6 -3
- synth_ai/learning/health.py +9 -2
- synth_ai/learning/jobs.py +17 -5
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +1 -3
- synth_ai/learning/prompts/random_search.py +4 -1
- synth_ai/learning/prompts/run_random_search_banking77.py +6 -1
- synth_ai/learning/rl_client.py +42 -14
- synth_ai/learning/sse.py +0 -2
- synth_ai/learning/validators.py +6 -2
- synth_ai/lm/caching/ephemeral.py +1 -3
- synth_ai/lm/core/exceptions.py +0 -2
- synth_ai/lm/core/main.py +13 -1
- synth_ai/lm/core/synth_models.py +0 -1
- synth_ai/lm/core/vendor_clients.py +4 -2
- synth_ai/lm/overrides.py +2 -2
- synth_ai/lm/vendors/core/anthropic_api.py +7 -7
- synth_ai/lm/vendors/core/openai_api.py +2 -0
- synth_ai/lm/vendors/openai_standard.py +3 -1
- synth_ai/lm/vendors/openai_standard_responses.py +6 -3
- synth_ai/lm/vendors/supported/custom_endpoint.py +1 -3
- synth_ai/lm/vendors/synth_client.py +37 -10
- synth_ai/rl/__init__.py +0 -1
- synth_ai/rl/contracts.py +0 -2
- synth_ai/rl/env_keys.py +6 -1
- synth_ai/task/__init__.py +1 -0
- synth_ai/task/apps/__init__.py +11 -11
- synth_ai/task/auth.py +29 -17
- synth_ai/task/client.py +3 -1
- synth_ai/task/contracts.py +1 -0
- synth_ai/task/datasets.py +3 -1
- synth_ai/task/errors.py +3 -2
- synth_ai/task/health.py +0 -2
- synth_ai/task/json.py +0 -1
- synth_ai/task/proxy.py +2 -5
- synth_ai/task/rubrics.py +9 -3
- synth_ai/task/server.py +31 -5
- synth_ai/task/tracing_utils.py +8 -3
- synth_ai/task/validators.py +0 -1
- synth_ai/task/vendors.py +0 -1
- synth_ai/tracing_v3/db_config.py +26 -1
- synth_ai/tracing_v3/decorators.py +1 -0
- synth_ai/tracing_v3/examples/basic_usage.py +3 -2
- synth_ai/tracing_v3/hooks.py +2 -0
- synth_ai/tracing_v3/replica_sync.py +1 -0
- synth_ai/tracing_v3/session_tracer.py +24 -3
- synth_ai/tracing_v3/storage/base.py +4 -1
- synth_ai/tracing_v3/storage/factory.py +0 -1
- synth_ai/tracing_v3/turso/manager.py +102 -38
- synth_ai/tracing_v3/turso/models.py +4 -1
- synth_ai/tracing_v3/utils.py +1 -0
- synth_ai/v0/tracing/upload.py +32 -135
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/METADATA +1 -1
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/RECORD +154 -154
- synth_ai/install_sqld.sh +0 -40
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/top_level.txt +0 -0
|
@@ -20,6 +20,7 @@ from .registry import registry
|
|
|
20
20
|
|
|
21
21
|
logger = logging.getLogger(__name__)
|
|
22
22
|
|
|
23
|
+
|
|
23
24
|
# --- Seeding utilities (robust, optional deps) ---
|
|
24
25
|
def _set_global_seed(seed_value: int) -> Dict[str, Any]:
|
|
25
26
|
"""Set global RNG seeds across common libraries; return details for logging/restoration.
|
|
@@ -29,18 +30,21 @@ def _set_global_seed(seed_value: int) -> Dict[str, Any]:
|
|
|
29
30
|
seeded: Dict[str, Any] = {"seed": int(seed_value), "libs": []}
|
|
30
31
|
try:
|
|
31
32
|
import random as _random # type: ignore
|
|
33
|
+
|
|
32
34
|
_random.seed(seed_value)
|
|
33
35
|
seeded["libs"].append("random")
|
|
34
36
|
except Exception:
|
|
35
37
|
pass
|
|
36
38
|
try:
|
|
37
39
|
import numpy as _np # type: ignore
|
|
40
|
+
|
|
38
41
|
_np.random.seed(seed_value)
|
|
39
42
|
seeded["libs"].append("numpy")
|
|
40
43
|
except Exception:
|
|
41
44
|
pass
|
|
42
45
|
try:
|
|
43
46
|
import torch as _torch # type: ignore
|
|
47
|
+
|
|
44
48
|
if hasattr(_torch, "manual_seed"):
|
|
45
49
|
_torch.manual_seed(seed_value)
|
|
46
50
|
seeded["libs"].append("torch")
|
|
@@ -62,12 +66,14 @@ def _set_global_seed(seed_value: int) -> Dict[str, Any]:
|
|
|
62
66
|
pass
|
|
63
67
|
return seeded
|
|
64
68
|
|
|
69
|
+
|
|
65
70
|
def _clear_seed_side_effects() -> None:
|
|
66
71
|
"""Best-effort cleanup to avoid global deterministic side-effects between requests."""
|
|
67
72
|
# We cannot truly restore prior RNG states without capturing them; we just avoid
|
|
68
73
|
# leaving aggressive deterministic flags enabled where it matters.
|
|
69
74
|
try:
|
|
70
75
|
import torch as _torch # type: ignore
|
|
76
|
+
|
|
71
77
|
try:
|
|
72
78
|
if getattr(_torch, "backends", None) and getattr(_torch.backends, "cudnn", None):
|
|
73
79
|
# Re-enable cudnn.benchmark default True only if it was True; safest is False -> leave as is.
|
|
@@ -78,6 +84,7 @@ def _clear_seed_side_effects() -> None:
|
|
|
78
84
|
except Exception:
|
|
79
85
|
pass
|
|
80
86
|
|
|
87
|
+
|
|
81
88
|
router = APIRouter()
|
|
82
89
|
|
|
83
90
|
|
|
@@ -161,11 +168,7 @@ def compute_stepwise_reward(
|
|
|
161
168
|
prev_map = prev_achievements or {}
|
|
162
169
|
next_map = new_achievements or {}
|
|
163
170
|
|
|
164
|
-
unlocked = [
|
|
165
|
-
name
|
|
166
|
-
for name, value in next_map.items()
|
|
167
|
-
if value and not prev_map.get(name, False)
|
|
168
|
-
]
|
|
171
|
+
unlocked = [name for name, value in next_map.items() if value and not prev_map.get(name, False)]
|
|
169
172
|
indicator = 1 if unlocked else 0
|
|
170
173
|
reward_value = float(indicator_lambda) * indicator
|
|
171
174
|
|
|
@@ -227,7 +230,9 @@ class RolloutTracingContext:
|
|
|
227
230
|
self.sft_records: list[dict[str, Any]] = []
|
|
228
231
|
self.latest_system_messages: list[str] = []
|
|
229
232
|
self.latest_user_messages: list[str] = []
|
|
230
|
-
self.trace_format = (
|
|
233
|
+
self.trace_format = (
|
|
234
|
+
getattr(request.record, "trace_format", "compact") or "compact"
|
|
235
|
+
).lower()
|
|
231
236
|
self.return_trace = bool(getattr(request.record, "return_trace", False))
|
|
232
237
|
self.sft_output_dir = getattr(fastapi_request.app.state, "sft_output_dir", None)
|
|
233
238
|
self.session_trace = None
|
|
@@ -257,7 +262,9 @@ class RolloutTracingContext:
|
|
|
257
262
|
except Exception as exc:
|
|
258
263
|
logger.debug("TRACING_INIT_FAIL: %s", exc)
|
|
259
264
|
try:
|
|
260
|
-
await self.tracer.start_session(
|
|
265
|
+
await self.tracer.start_session(
|
|
266
|
+
session_id=self.run_id, metadata=dict(self.metadata_base)
|
|
267
|
+
)
|
|
261
268
|
except Exception as exc:
|
|
262
269
|
logger.warning("TRACING_START_FAIL: %s", exc)
|
|
263
270
|
self.enabled = False
|
|
@@ -379,17 +386,15 @@ class RolloutTracingContext:
|
|
|
379
386
|
input_tokens = usage.get("input_tokens") or usage.get("prompt_tokens")
|
|
380
387
|
output_tokens = usage.get("output_tokens") or usage.get("completion_tokens")
|
|
381
388
|
total_tokens = usage.get("total_tokens")
|
|
382
|
-
cost_usd = (
|
|
383
|
-
usage.get("cost_usd")
|
|
384
|
-
or usage.get("cost")
|
|
385
|
-
or usage.get("total_cost")
|
|
386
|
-
)
|
|
389
|
+
cost_usd = usage.get("cost_usd") or usage.get("cost") or usage.get("total_cost")
|
|
387
390
|
|
|
388
391
|
assistant_message = None
|
|
389
392
|
choices = inference_response.get("choices") or []
|
|
390
393
|
if choices:
|
|
391
394
|
assistant_message = choices[0].get("message") or {}
|
|
392
|
-
assistant_content =
|
|
395
|
+
assistant_content = (
|
|
396
|
+
assistant_message.get("content") if isinstance(assistant_message, dict) else None
|
|
397
|
+
)
|
|
393
398
|
|
|
394
399
|
raw_response = self._content_to_text(assistant_content)
|
|
395
400
|
if not raw_response:
|
|
@@ -397,7 +402,9 @@ class RolloutTracingContext:
|
|
|
397
402
|
|
|
398
403
|
base_response = BaseLMResponse(
|
|
399
404
|
raw_response=raw_response,
|
|
400
|
-
tool_calls=assistant_message.get("tool_calls")
|
|
405
|
+
tool_calls=assistant_message.get("tool_calls")
|
|
406
|
+
if isinstance(assistant_message, dict)
|
|
407
|
+
else None,
|
|
401
408
|
usage=usage or None,
|
|
402
409
|
api_type="chat_completions",
|
|
403
410
|
)
|
|
@@ -469,7 +476,9 @@ class RolloutTracingContext:
|
|
|
469
476
|
),
|
|
470
477
|
"assistant": {
|
|
471
478
|
"content": assistant_text,
|
|
472
|
-
"tool_calls": assistant_message.get("tool_calls")
|
|
479
|
+
"tool_calls": assistant_message.get("tool_calls")
|
|
480
|
+
if isinstance(assistant_message, dict)
|
|
481
|
+
else [],
|
|
473
482
|
},
|
|
474
483
|
"timestamp": datetime.utcnow().isoformat(),
|
|
475
484
|
}
|
|
@@ -488,11 +497,19 @@ class RolloutTracingContext:
|
|
|
488
497
|
return None
|
|
489
498
|
|
|
490
499
|
try:
|
|
491
|
-
prev_summary =
|
|
500
|
+
prev_summary = (
|
|
501
|
+
_summarize_observation_for_storage(env_handle, prev_obs or {})
|
|
502
|
+
if prev_obs is not None
|
|
503
|
+
else None
|
|
504
|
+
)
|
|
492
505
|
except Exception:
|
|
493
506
|
prev_summary = None
|
|
494
507
|
try:
|
|
495
|
-
next_summary =
|
|
508
|
+
next_summary = (
|
|
509
|
+
_summarize_observation_for_storage(env_handle, next_obs or {})
|
|
510
|
+
if next_obs is not None
|
|
511
|
+
else None
|
|
512
|
+
)
|
|
496
513
|
except Exception:
|
|
497
514
|
next_summary = None
|
|
498
515
|
|
|
@@ -640,7 +657,11 @@ class RolloutTracingContext:
|
|
|
640
657
|
"lm_calls": self.lm_calls_summary,
|
|
641
658
|
"decision_rewards": self.decision_rewards,
|
|
642
659
|
}
|
|
643
|
-
|
|
660
|
+
|
|
661
|
+
|
|
662
|
+
def _summarize_observation_for_storage(
|
|
663
|
+
env_handle: Any, observation: Dict[str, Any]
|
|
664
|
+
) -> Dict[str, Any]:
|
|
644
665
|
"""Return a compact dict for trajectory storage instead of the raw observation.
|
|
645
666
|
|
|
646
667
|
- For Crafter, use the same summary used for the policy user prompt
|
|
@@ -652,9 +673,12 @@ def _summarize_observation_for_storage(env_handle: Any, observation: Dict[str, A
|
|
|
652
673
|
except Exception:
|
|
653
674
|
_CrafterWrapper = None # type: ignore
|
|
654
675
|
|
|
655
|
-
if _CrafterWrapper is not None and isinstance(
|
|
676
|
+
if _CrafterWrapper is not None and isinstance(
|
|
677
|
+
getattr(env_handle, "env", None), _CrafterWrapper
|
|
678
|
+
):
|
|
656
679
|
try:
|
|
657
680
|
from .envs.crafter.shared import format_observation as _fmt # type: ignore
|
|
681
|
+
|
|
658
682
|
text = _fmt(observation or {})
|
|
659
683
|
return {"text": text}
|
|
660
684
|
except Exception:
|
|
@@ -671,8 +695,12 @@ def _summarize_observation_for_storage(env_handle: Any, observation: Dict[str, A
|
|
|
671
695
|
summary = {
|
|
672
696
|
"position": pos,
|
|
673
697
|
"health": health,
|
|
674
|
-
"inventory_keys": sorted([k for k, v in (inv or {}).items() if v])[:10]
|
|
675
|
-
|
|
698
|
+
"inventory_keys": sorted([k for k, v in (inv or {}).items() if v])[:10]
|
|
699
|
+
if isinstance(inv, dict)
|
|
700
|
+
else None,
|
|
701
|
+
"achievements_unlocked": sorted([k for k, v in (ach or {}).items() if v])[:10]
|
|
702
|
+
if isinstance(ach, dict)
|
|
703
|
+
else None,
|
|
676
704
|
}
|
|
677
705
|
return {"text": json.dumps(summary, ensure_ascii=False)}
|
|
678
706
|
except Exception:
|
|
@@ -685,7 +713,6 @@ def _summarize_observation_for_storage(env_handle: Any, observation: Dict[str, A
|
|
|
685
713
|
return {"text": ""}
|
|
686
714
|
|
|
687
715
|
|
|
688
|
-
|
|
689
716
|
class RunAbortRequest(BaseModel):
|
|
690
717
|
run_id: str
|
|
691
718
|
|
|
@@ -857,9 +884,7 @@ async def execute_rollout(
|
|
|
857
884
|
# Propagate training_session_id via env config for downstream usage
|
|
858
885
|
_env_config = dict(request.env.config or {})
|
|
859
886
|
if request.training_session_id is not None:
|
|
860
|
-
_env_config.setdefault(
|
|
861
|
-
"training_session_id", request.training_session_id
|
|
862
|
-
)
|
|
887
|
+
_env_config.setdefault("training_session_id", request.training_session_id)
|
|
863
888
|
env_response = await create_environment(
|
|
864
889
|
EnvCreateRequest(
|
|
865
890
|
env_name=request.env.env_name,
|
|
@@ -893,9 +918,7 @@ async def execute_rollout(
|
|
|
893
918
|
# Propagate training_session_id and synth_base_url via policy config
|
|
894
919
|
_policy_config = dict(request.policy.config or {})
|
|
895
920
|
if request.training_session_id is not None:
|
|
896
|
-
_policy_config.setdefault(
|
|
897
|
-
"training_session_id", request.training_session_id
|
|
898
|
-
)
|
|
921
|
+
_policy_config.setdefault("training_session_id", request.training_session_id)
|
|
899
922
|
if request.synth_base_url is not None:
|
|
900
923
|
_policy_config.setdefault("synth_base_url", request.synth_base_url)
|
|
901
924
|
policy_response = await create_policy(
|
|
@@ -1065,7 +1088,10 @@ async def execute_rollout(
|
|
|
1065
1088
|
_timing["decision_ms"] = decision_ms
|
|
1066
1089
|
if last_env_step_ms is not None:
|
|
1067
1090
|
_timing.setdefault("env_step_ms", float(last_env_step_ms))
|
|
1068
|
-
_timing.setdefault(
|
|
1091
|
+
_timing.setdefault(
|
|
1092
|
+
"overhead_ms",
|
|
1093
|
+
max(0.0, decision_ms - float(last_env_step_ms)),
|
|
1094
|
+
)
|
|
1069
1095
|
else:
|
|
1070
1096
|
_timing.setdefault("overhead_ms", 0.0)
|
|
1071
1097
|
_meta["timing"] = _timing
|
|
@@ -1107,9 +1133,7 @@ async def execute_rollout(
|
|
|
1107
1133
|
_first_guess = None
|
|
1108
1134
|
if _count > 0 and isinstance(_prev_calls[0], dict):
|
|
1109
1135
|
_args = (
|
|
1110
|
-
_prev_calls[0]["arguments"]
|
|
1111
|
-
if "arguments" in _prev_calls[0]
|
|
1112
|
-
else None
|
|
1136
|
+
_prev_calls[0]["arguments"] if "arguments" in _prev_calls[0] else None
|
|
1113
1137
|
)
|
|
1114
1138
|
if isinstance(_args, str):
|
|
1115
1139
|
import json as _json
|
|
@@ -1119,9 +1143,9 @@ async def execute_rollout(
|
|
|
1119
1143
|
except Exception:
|
|
1120
1144
|
_args = {}
|
|
1121
1145
|
if isinstance(_args, dict):
|
|
1122
|
-
_first_guess = (
|
|
1123
|
-
_args["
|
|
1124
|
-
)
|
|
1146
|
+
_first_guess = (_args["guess"] if "guess" in _args else None) or (
|
|
1147
|
+
_args["word"] if "word" in _args else None
|
|
1148
|
+
)
|
|
1125
1149
|
logger.info(
|
|
1126
1150
|
"POLICY_METADATA: prev_tool_calls=%d first_guess=%r has_prev_env_result=%s",
|
|
1127
1151
|
_count,
|
|
@@ -1377,7 +1401,9 @@ async def execute_rollout(
|
|
|
1377
1401
|
(env_step_end - float(last_agent_response_ts)) * 1000.0,
|
|
1378
1402
|
)
|
|
1379
1403
|
timing_last["decision_ms"] = decision_ms
|
|
1380
|
-
timing_last.setdefault(
|
|
1404
|
+
timing_last.setdefault(
|
|
1405
|
+
"overhead_ms", max(0.0, decision_ms - env_step_duration_ms)
|
|
1406
|
+
)
|
|
1381
1407
|
except Exception:
|
|
1382
1408
|
pass
|
|
1383
1409
|
if decision_open:
|
|
@@ -1409,9 +1435,7 @@ async def execute_rollout(
|
|
|
1409
1435
|
# Attach policy meta from the immediately preceding agent step
|
|
1410
1436
|
try:
|
|
1411
1437
|
prev_meta = {}
|
|
1412
|
-
if "policy_response" in locals() and isinstance(
|
|
1413
|
-
policy_response.meta, dict
|
|
1414
|
-
): # type: ignore[name-defined]
|
|
1438
|
+
if "policy_response" in locals() and isinstance(policy_response.meta, dict): # type: ignore[name-defined]
|
|
1415
1439
|
prev_meta = policy_response.meta
|
|
1416
1440
|
if prev_meta:
|
|
1417
1441
|
_info = dict(_info)
|
|
@@ -1452,9 +1476,7 @@ async def execute_rollout(
|
|
|
1452
1476
|
reward_stepwise = float(stats.get("reward", 0.0))
|
|
1453
1477
|
stepwise_indicator_sum += float(stats.get("indicator", 0.0))
|
|
1454
1478
|
stepwise_reward_sum += reward_stepwise
|
|
1455
|
-
stepwise_new_achievements_total += int(
|
|
1456
|
-
stats.get("new_achievements_count", 0.0)
|
|
1457
|
-
)
|
|
1479
|
+
stepwise_new_achievements_total += int(stats.get("new_achievements_count", 0.0))
|
|
1458
1480
|
if not isinstance(_info, dict):
|
|
1459
1481
|
_info = {}
|
|
1460
1482
|
else:
|
|
@@ -1470,7 +1492,9 @@ async def execute_rollout(
|
|
|
1470
1492
|
# Prepare stable lists for logging/metadata
|
|
1471
1493
|
all_list = sorted(list(turned_true))
|
|
1472
1494
|
# Ensure nested meta exists
|
|
1473
|
-
meta_block =
|
|
1495
|
+
meta_block = (
|
|
1496
|
+
_info.get("meta") if isinstance(_info.get("meta"), dict) else {}
|
|
1497
|
+
)
|
|
1474
1498
|
decision_rewards = {
|
|
1475
1499
|
"turn": int(decision_index),
|
|
1476
1500
|
"ach_delta": ach_delta,
|
|
@@ -1521,9 +1545,7 @@ async def execute_rollout(
|
|
|
1521
1545
|
EnvResetRequest,
|
|
1522
1546
|
)
|
|
1523
1547
|
|
|
1524
|
-
reset_response = await reset_environment(
|
|
1525
|
-
EnvResetRequest(env_id=env_id)
|
|
1526
|
-
)
|
|
1548
|
+
reset_response = await reset_environment(EnvResetRequest(env_id=env_id))
|
|
1527
1549
|
current_obs = reset_response.observation
|
|
1528
1550
|
elif request.on_done == "terminate":
|
|
1529
1551
|
break
|
|
@@ -1544,15 +1566,11 @@ async def execute_rollout(
|
|
|
1544
1566
|
):
|
|
1545
1567
|
try:
|
|
1546
1568
|
final_now = last_env_step_completed_ts or _time.perf_counter()
|
|
1547
|
-
final_decision_ms = max(
|
|
1548
|
-
0.0, (final_now - float(last_agent_response_ts)) * 1000.0
|
|
1549
|
-
)
|
|
1569
|
+
final_decision_ms = max(0.0, (final_now - float(last_agent_response_ts)) * 1000.0)
|
|
1550
1570
|
timing_final = last_policy_meta.setdefault("timing", {})
|
|
1551
1571
|
timing_final["decision_ms"] = final_decision_ms
|
|
1552
1572
|
if last_env_step_ms is not None:
|
|
1553
|
-
timing_final.setdefault(
|
|
1554
|
-
"env_step_ms", float(last_env_step_ms)
|
|
1555
|
-
)
|
|
1573
|
+
timing_final.setdefault("env_step_ms", float(last_env_step_ms))
|
|
1556
1574
|
timing_final.setdefault(
|
|
1557
1575
|
"overhead_ms",
|
|
1558
1576
|
max(0.0, final_decision_ms - float(last_env_step_ms)),
|
|
@@ -1601,10 +1619,11 @@ async def execute_rollout(
|
|
|
1601
1619
|
for step in trajectory_steps:
|
|
1602
1620
|
formatted_steps.append({"tool_calls": step.tool_calls or []})
|
|
1603
1621
|
|
|
1604
|
-
if
|
|
1605
|
-
|
|
1606
|
-
|
|
1607
|
-
|
|
1622
|
+
if (
|
|
1623
|
+
get_wordle_rollout_summary is not None
|
|
1624
|
+
and log_wordle_rollout_summary is not None
|
|
1625
|
+
):
|
|
1626
|
+
summary = get_wordle_rollout_summary(formatted_steps, current_obs, env_handle)
|
|
1608
1627
|
log_wordle_rollout_summary(request.run_id, summary)
|
|
1609
1628
|
except ImportError:
|
|
1610
1629
|
# Wordle helpers not available, skip Wordle-specific logging
|
|
@@ -1681,9 +1700,7 @@ async def execute_rollout(
|
|
|
1681
1700
|
except Exception:
|
|
1682
1701
|
pass
|
|
1683
1702
|
except Exception as _te:
|
|
1684
|
-
logger.warning(
|
|
1685
|
-
f"ROLL_OUT: failed to terminate environment {created_env_id}: {_te}"
|
|
1686
|
-
)
|
|
1703
|
+
logger.warning(f"ROLL_OUT: failed to terminate environment {created_env_id}: {_te}")
|
|
1687
1704
|
|
|
1688
1705
|
# Best-effort policy cleanup if we created one (avoid reuse across rollouts)
|
|
1689
1706
|
try:
|
|
@@ -13,10 +13,10 @@ from typing import Any, Dict, Optional
|
|
|
13
13
|
|
|
14
14
|
class VolumeStorage:
|
|
15
15
|
"""Helpers for Modal Volume storage operations."""
|
|
16
|
-
|
|
16
|
+
|
|
17
17
|
def __init__(self, base_path: str = "/data/state") -> None:
|
|
18
18
|
self.base_path = Path(base_path)
|
|
19
|
-
|
|
19
|
+
|
|
20
20
|
def get_snapshot_path(
|
|
21
21
|
self,
|
|
22
22
|
rl_run_id: str,
|
|
@@ -27,21 +27,15 @@ class VolumeStorage:
|
|
|
27
27
|
# Use first 2 chars of snapshot_id for sharding
|
|
28
28
|
shard1 = snapshot_id[:2] if len(snapshot_id) >= 2 else "00"
|
|
29
29
|
shard2 = snapshot_id[2:4] if len(snapshot_id) >= 4 else "00"
|
|
30
|
-
|
|
30
|
+
|
|
31
31
|
return (
|
|
32
|
-
self.base_path
|
|
33
|
-
/ "runs"
|
|
34
|
-
/ rl_run_id
|
|
35
|
-
/ kind
|
|
36
|
-
/ shard1
|
|
37
|
-
/ shard2
|
|
38
|
-
/ f"{snapshot_id}.tar.gz"
|
|
32
|
+
self.base_path / "runs" / rl_run_id / kind / shard1 / shard2 / f"{snapshot_id}.tar.gz"
|
|
39
33
|
)
|
|
40
|
-
|
|
34
|
+
|
|
41
35
|
def get_index_path(self, rl_run_id: str) -> Path:
|
|
42
36
|
"""Get the index file path for a run."""
|
|
43
37
|
return self.base_path / "runs" / rl_run_id / "index" / "meta.jsonl"
|
|
44
|
-
|
|
38
|
+
|
|
45
39
|
def write_snapshot_atomic(
|
|
46
40
|
self,
|
|
47
41
|
path: Path,
|
|
@@ -50,17 +44,17 @@ class VolumeStorage:
|
|
|
50
44
|
"""Atomically write a snapshot archive to disk."""
|
|
51
45
|
# Ensure parent directory exists
|
|
52
46
|
path.parent.mkdir(parents=True, exist_ok=True)
|
|
53
|
-
|
|
47
|
+
|
|
54
48
|
# Write to temp file first
|
|
55
49
|
tmp_path = path.with_suffix(".tmp")
|
|
56
50
|
with open(tmp_path, "wb") as f:
|
|
57
51
|
f.write(archive_bytes)
|
|
58
52
|
f.flush()
|
|
59
53
|
os.fsync(f.fileno())
|
|
60
|
-
|
|
54
|
+
|
|
61
55
|
# Atomic rename
|
|
62
56
|
os.replace(tmp_path, path)
|
|
63
|
-
|
|
57
|
+
|
|
64
58
|
def create_archive(
|
|
65
59
|
self,
|
|
66
60
|
state_dict: Dict[str, Any],
|
|
@@ -69,61 +63,61 @@ class VolumeStorage:
|
|
|
69
63
|
"""Create a tar.gz archive with state and metadata."""
|
|
70
64
|
with tempfile.TemporaryDirectory() as tmpdir:
|
|
71
65
|
tmppath = Path(tmpdir)
|
|
72
|
-
|
|
66
|
+
|
|
73
67
|
# Write state.json
|
|
74
68
|
state_path = tmppath / "state.json"
|
|
75
69
|
with open(state_path, "w") as f:
|
|
76
70
|
json.dump(state_dict, f, sort_keys=True, indent=2)
|
|
77
|
-
|
|
71
|
+
|
|
78
72
|
# Write meta.json
|
|
79
73
|
meta_path = tmppath / "meta.json"
|
|
80
74
|
with open(meta_path, "w") as f:
|
|
81
75
|
json.dump(meta, f, sort_keys=True, indent=2)
|
|
82
|
-
|
|
76
|
+
|
|
83
77
|
# Create tar archive
|
|
84
78
|
tar_path = tmppath / "archive.tar"
|
|
85
79
|
with tarfile.open(tar_path, "w") as tar:
|
|
86
80
|
tar.add(state_path, arcname="state.json")
|
|
87
81
|
tar.add(meta_path, arcname="meta.json")
|
|
88
|
-
|
|
82
|
+
|
|
89
83
|
# Compress with gzip
|
|
90
84
|
with open(tar_path, "rb") as f:
|
|
91
85
|
tar_bytes = f.read()
|
|
92
|
-
|
|
86
|
+
|
|
93
87
|
compressed = gzip.compress(tar_bytes, compresslevel=6)
|
|
94
|
-
|
|
88
|
+
|
|
95
89
|
return compressed
|
|
96
|
-
|
|
90
|
+
|
|
97
91
|
def extract_archive(self, archive_bytes: bytes) -> tuple[Dict[str, Any], Dict[str, Any]]:
|
|
98
92
|
"""Extract state and metadata from a tar.gz archive."""
|
|
99
93
|
# Decompress
|
|
100
94
|
tar_bytes = gzip.decompress(archive_bytes)
|
|
101
|
-
|
|
95
|
+
|
|
102
96
|
with tempfile.TemporaryDirectory() as tmpdir:
|
|
103
97
|
tmppath = Path(tmpdir)
|
|
104
|
-
|
|
98
|
+
|
|
105
99
|
# Write tar bytes to temp file
|
|
106
100
|
tar_path = tmppath / "archive.tar"
|
|
107
101
|
with open(tar_path, "wb") as f:
|
|
108
102
|
f.write(tar_bytes)
|
|
109
|
-
|
|
103
|
+
|
|
110
104
|
# Extract tar
|
|
111
105
|
with tarfile.open(tar_path, "r") as tar:
|
|
112
106
|
tar.extractall(tmppath)
|
|
113
|
-
|
|
107
|
+
|
|
114
108
|
# Read state and meta
|
|
115
109
|
with open(tmppath / "state.json", "r") as f:
|
|
116
110
|
state = json.load(f)
|
|
117
|
-
|
|
111
|
+
|
|
118
112
|
with open(tmppath / "meta.json", "r") as f:
|
|
119
113
|
meta = json.load(f)
|
|
120
|
-
|
|
114
|
+
|
|
121
115
|
return state, meta
|
|
122
|
-
|
|
116
|
+
|
|
123
117
|
def compute_snapshot_id(self, archive_bytes: bytes) -> str:
|
|
124
118
|
"""Compute content-addressed snapshot ID."""
|
|
125
119
|
return hashlib.sha256(archive_bytes).hexdigest()
|
|
126
|
-
|
|
120
|
+
|
|
127
121
|
def save_snapshot(
|
|
128
122
|
self,
|
|
129
123
|
rl_run_id: str,
|
|
@@ -140,33 +134,33 @@ class VolumeStorage:
|
|
|
140
134
|
"schema_version": "1.0",
|
|
141
135
|
"created_at": datetime.utcnow().isoformat(),
|
|
142
136
|
}
|
|
143
|
-
|
|
137
|
+
|
|
144
138
|
if parent_snapshot_id:
|
|
145
139
|
meta["parent_snapshot_id"] = parent_snapshot_id
|
|
146
|
-
|
|
140
|
+
|
|
147
141
|
if config:
|
|
148
142
|
config_str = json.dumps(config, sort_keys=True)
|
|
149
143
|
meta["config_hash"] = hashlib.sha256(config_str.encode()).hexdigest()
|
|
150
|
-
|
|
144
|
+
|
|
151
145
|
# Create archive
|
|
152
146
|
archive_bytes = self.create_archive(state_dict, meta)
|
|
153
|
-
|
|
147
|
+
|
|
154
148
|
# Compute snapshot ID
|
|
155
149
|
snapshot_id = self.compute_snapshot_id(archive_bytes)
|
|
156
150
|
meta["snapshot_id"] = snapshot_id
|
|
157
|
-
|
|
151
|
+
|
|
158
152
|
# Recreate archive with snapshot_id in metadata
|
|
159
153
|
archive_bytes = self.create_archive(state_dict, meta)
|
|
160
|
-
|
|
154
|
+
|
|
161
155
|
# Get path and write
|
|
162
156
|
path = self.get_snapshot_path(rl_run_id, kind, snapshot_id)
|
|
163
157
|
self.write_snapshot_atomic(path, archive_bytes)
|
|
164
|
-
|
|
158
|
+
|
|
165
159
|
# Append to index
|
|
166
160
|
self.append_to_index(rl_run_id, meta)
|
|
167
|
-
|
|
161
|
+
|
|
168
162
|
return snapshot_id, str(path), len(archive_bytes)
|
|
169
|
-
|
|
163
|
+
|
|
170
164
|
def load_snapshot(
|
|
171
165
|
self,
|
|
172
166
|
rl_run_id: str,
|
|
@@ -175,16 +169,16 @@ class VolumeStorage:
|
|
|
175
169
|
) -> tuple[Dict[str, Any], Dict[str, Any]]:
|
|
176
170
|
"""Load a snapshot and return (state_dict, meta)."""
|
|
177
171
|
path = self.get_snapshot_path(rl_run_id, kind, snapshot_id)
|
|
178
|
-
|
|
172
|
+
|
|
179
173
|
if not path.exists():
|
|
180
174
|
raise FileNotFoundError(f"Snapshot not found: {path}")
|
|
181
|
-
|
|
175
|
+
|
|
182
176
|
with open(path, "rb") as f:
|
|
183
177
|
archive_bytes = f.read()
|
|
184
|
-
|
|
178
|
+
|
|
185
179
|
state, meta = self.extract_archive(archive_bytes)
|
|
186
180
|
return state, meta
|
|
187
|
-
|
|
181
|
+
|
|
188
182
|
def append_to_index(
|
|
189
183
|
self,
|
|
190
184
|
rl_run_id: str,
|
|
@@ -193,25 +187,25 @@ class VolumeStorage:
|
|
|
193
187
|
"""Append metadata to the run's index file."""
|
|
194
188
|
index_path = self.get_index_path(rl_run_id)
|
|
195
189
|
index_path.parent.mkdir(parents=True, exist_ok=True)
|
|
196
|
-
|
|
190
|
+
|
|
197
191
|
with open(index_path, "a") as f:
|
|
198
192
|
f.write(json.dumps(meta) + "\n")
|
|
199
|
-
|
|
193
|
+
|
|
200
194
|
def read_index(self, rl_run_id: str) -> list[Dict[str, Any]]:
|
|
201
195
|
"""Read all entries from a run's index file."""
|
|
202
196
|
index_path = self.get_index_path(rl_run_id)
|
|
203
|
-
|
|
197
|
+
|
|
204
198
|
if not index_path.exists():
|
|
205
199
|
return []
|
|
206
|
-
|
|
200
|
+
|
|
207
201
|
entries = []
|
|
208
202
|
with open(index_path, "r") as f:
|
|
209
203
|
for line in f:
|
|
210
204
|
if line.strip():
|
|
211
205
|
entries.append(json.loads(line))
|
|
212
|
-
|
|
206
|
+
|
|
213
207
|
return entries
|
|
214
208
|
|
|
215
209
|
|
|
216
210
|
# Global storage instance
|
|
217
|
-
storage = VolumeStorage()
|
|
211
|
+
storage = VolumeStorage()
|
|
@@ -82,15 +82,11 @@ async def test_service():
|
|
|
82
82
|
print(f" Error: {response.status_code} - {response.text}")
|
|
83
83
|
else:
|
|
84
84
|
step_data = response.json()
|
|
85
|
-
print(
|
|
86
|
-
f" Step result - done: {step_data['done']}, reward: {step_data.get('reward')}"
|
|
87
|
-
)
|
|
85
|
+
print(f" Step result - done: {step_data['done']}, reward: {step_data.get('reward')}")
|
|
88
86
|
|
|
89
87
|
# Test 6: Environment snapshot
|
|
90
88
|
print("\n6. Creating environment snapshot...")
|
|
91
|
-
response = await client.post(
|
|
92
|
-
f"{base_url}/env/snapshot", json={"env_id": env_id}
|
|
93
|
-
)
|
|
89
|
+
response = await client.post(f"{base_url}/env/snapshot", json={"env_id": env_id})
|
|
94
90
|
if response.status_code != 200:
|
|
95
91
|
print(f" Error: {response.status_code} - {response.text}")
|
|
96
92
|
else:
|
|
@@ -100,9 +96,7 @@ async def test_service():
|
|
|
100
96
|
|
|
101
97
|
# Test 7: Policy snapshot
|
|
102
98
|
print("\n7. Creating policy snapshot...")
|
|
103
|
-
response = await client.post(
|
|
104
|
-
f"{base_url}/policy/snapshot", json={"policy_id": policy_id}
|
|
105
|
-
)
|
|
99
|
+
response = await client.post(f"{base_url}/policy/snapshot", json={"policy_id": policy_id})
|
|
106
100
|
if response.status_code != 200:
|
|
107
101
|
print(f" Error: {response.status_code} - {response.text}")
|
|
108
102
|
else:
|
|
@@ -121,9 +115,7 @@ async def test_service():
|
|
|
121
115
|
|
|
122
116
|
# Test 9: Terminate environment
|
|
123
117
|
print("\n9. Terminating environment...")
|
|
124
|
-
response = await client.post(
|
|
125
|
-
f"{base_url}/env/terminate", json={"env_id": env_id}
|
|
126
|
-
)
|
|
118
|
+
response = await client.post(f"{base_url}/env/terminate", json={"env_id": env_id})
|
|
127
119
|
if response.status_code != 200:
|
|
128
120
|
print(f" Error: {response.status_code} - {response.text}")
|
|
129
121
|
else:
|
|
@@ -131,9 +123,7 @@ async def test_service():
|
|
|
131
123
|
|
|
132
124
|
# Test 10: Terminate policy
|
|
133
125
|
print("\n10. Terminating policy...")
|
|
134
|
-
response = await client.post(
|
|
135
|
-
f"{base_url}/policy/terminate", json={"policy_id": policy_id}
|
|
136
|
-
)
|
|
126
|
+
response = await client.post(f"{base_url}/policy/terminate", json={"policy_id": policy_id})
|
|
137
127
|
if response.status_code != 200:
|
|
138
128
|
print(f" Error: {response.status_code} - {response.text}")
|
|
139
129
|
else:
|
synth_ai/__init__.py
CHANGED
|
@@ -5,6 +5,7 @@ Synth AI - Software for aiding the best and multiplying the will.
|
|
|
5
5
|
# Environment exports - moved from synth-env
|
|
6
6
|
from synth_ai.environments import * # noqa
|
|
7
7
|
import synth_ai.environments as environments # expose module name for __all__
|
|
8
|
+
|
|
8
9
|
try:
|
|
9
10
|
from synth_ai.lm.core.main import LM # Moved from zyk to lm for better organization
|
|
10
11
|
except Exception: # allow minimal imports (e.g., tracing) without LM stack
|