synth-ai 0.2.9.dev5__py3-none-any.whl → 0.2.9.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/common_old/backend.py +0 -1
- examples/crafter_debug_render.py +15 -6
- examples/evals_old/compare_models.py +1 -0
- examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +6 -2
- examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +4 -4
- examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +4 -3
- examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +6 -2
- examples/finetuning_old/synth_qwen_v1/finetune.py +1 -1
- examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +4 -4
- examples/finetuning_old/synth_qwen_v1/infer.py +1 -2
- examples/finetuning_old/synth_qwen_v1/poll.py +4 -2
- examples/finetuning_old/synth_qwen_v1/prepare_data.py +8 -8
- examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +5 -4
- examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +11 -8
- examples/finetuning_old/synth_qwen_v1/run_ft_job.py +17 -12
- examples/finetuning_old/synth_qwen_v1/upload_data.py +1 -1
- examples/finetuning_old/synth_qwen_v1/util.py +7 -2
- examples/rl/configs/eval_base_qwen.toml +1 -1
- examples/rl/configs/rl_from_base_qwen17.toml +1 -1
- examples/rl/download_dataset.py +26 -10
- examples/rl/run_eval.py +17 -15
- examples/rl/run_rl_and_save.py +24 -7
- examples/rl/task_app/math_single_step.py +128 -11
- examples/rl/task_app/math_task_app.py +11 -3
- examples/rl_old/task_app.py +222 -53
- examples/warming_up_to_rl/analyze_trace_db.py +7 -5
- examples/warming_up_to_rl/export_trace_sft.py +141 -16
- examples/warming_up_to_rl/groq_test.py +11 -4
- examples/warming_up_to_rl/manage_secrets.py +15 -6
- examples/warming_up_to_rl/readme.md +9 -2
- examples/warming_up_to_rl/run_eval.py +108 -30
- examples/warming_up_to_rl/run_fft_and_save.py +128 -52
- examples/warming_up_to_rl/run_local_rollout.py +87 -36
- examples/warming_up_to_rl/run_local_rollout_modal.py +113 -25
- examples/warming_up_to_rl/run_local_rollout_parallel.py +80 -16
- examples/warming_up_to_rl/run_local_rollout_traced.py +125 -20
- examples/warming_up_to_rl/run_rl_and_save.py +31 -7
- examples/warming_up_to_rl/run_rollout_remote.py +37 -10
- examples/warming_up_to_rl/task_app/grpo_crafter.py +90 -27
- examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +9 -27
- examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +46 -108
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +50 -17
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +35 -21
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +8 -4
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +29 -26
- examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +17 -13
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +106 -63
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +82 -84
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +76 -59
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +1 -1
- examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +43 -49
- examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +5 -15
- synth_ai/__init__.py +1 -0
- synth_ai/api/train/builders.py +34 -10
- synth_ai/api/train/cli.py +172 -32
- synth_ai/api/train/config_finder.py +59 -4
- synth_ai/api/train/env_resolver.py +32 -14
- synth_ai/api/train/pollers.py +11 -3
- synth_ai/api/train/task_app.py +4 -1
- synth_ai/api/train/utils.py +20 -4
- synth_ai/cli/__init__.py +11 -4
- synth_ai/cli/balance.py +1 -1
- synth_ai/cli/demo.py +19 -5
- synth_ai/cli/rl_demo.py +75 -16
- synth_ai/cli/root.py +116 -37
- synth_ai/cli/task_apps.py +1276 -186
- synth_ai/cli/traces.py +1 -0
- synth_ai/cli/turso.py +73 -0
- synth_ai/core/experiment.py +0 -2
- synth_ai/demo_registry.py +67 -30
- synth_ai/demos/core/cli.py +493 -164
- synth_ai/demos/demo_task_apps/core.py +50 -6
- synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +2 -3
- synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +36 -28
- synth_ai/demos/demo_task_apps/math/_common.py +1 -2
- synth_ai/demos/demo_task_apps/math/deploy_modal.py +0 -2
- synth_ai/demos/demo_task_apps/math/modal_task_app.py +168 -65
- synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
- synth_ai/environments/examples/bandit/engine.py +12 -4
- synth_ai/environments/examples/bandit/taskset.py +4 -4
- synth_ai/environments/reproducibility/tree.py +3 -1
- synth_ai/environments/service/core_routes.py +6 -2
- synth_ai/evals/base.py +0 -2
- synth_ai/experimental/synth_oss.py +11 -12
- synth_ai/handshake.py +3 -1
- synth_ai/http_client.py +31 -7
- synth_ai/inference/__init__.py +0 -2
- synth_ai/inference/client.py +8 -4
- synth_ai/jobs/client.py +40 -10
- synth_ai/learning/client.py +33 -8
- synth_ai/learning/config.py +0 -2
- synth_ai/learning/constants.py +0 -2
- synth_ai/learning/ft_client.py +6 -3
- synth_ai/learning/health.py +9 -2
- synth_ai/learning/jobs.py +17 -5
- synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +1 -3
- synth_ai/learning/prompts/random_search.py +4 -1
- synth_ai/learning/prompts/run_random_search_banking77.py +6 -1
- synth_ai/learning/rl_client.py +42 -14
- synth_ai/learning/sse.py +0 -2
- synth_ai/learning/validators.py +6 -2
- synth_ai/lm/caching/ephemeral.py +1 -3
- synth_ai/lm/core/exceptions.py +0 -2
- synth_ai/lm/core/main.py +13 -1
- synth_ai/lm/core/synth_models.py +0 -1
- synth_ai/lm/core/vendor_clients.py +4 -2
- synth_ai/lm/overrides.py +2 -2
- synth_ai/lm/vendors/core/anthropic_api.py +7 -7
- synth_ai/lm/vendors/core/openai_api.py +2 -0
- synth_ai/lm/vendors/openai_standard.py +3 -1
- synth_ai/lm/vendors/openai_standard_responses.py +6 -3
- synth_ai/lm/vendors/supported/custom_endpoint.py +1 -3
- synth_ai/lm/vendors/synth_client.py +37 -10
- synth_ai/rl/__init__.py +0 -1
- synth_ai/rl/contracts.py +0 -2
- synth_ai/rl/env_keys.py +6 -1
- synth_ai/task/__init__.py +1 -0
- synth_ai/task/apps/__init__.py +11 -11
- synth_ai/task/auth.py +29 -17
- synth_ai/task/client.py +3 -1
- synth_ai/task/contracts.py +1 -0
- synth_ai/task/datasets.py +3 -1
- synth_ai/task/errors.py +3 -2
- synth_ai/task/health.py +0 -2
- synth_ai/task/json.py +0 -1
- synth_ai/task/proxy.py +2 -5
- synth_ai/task/rubrics.py +9 -3
- synth_ai/task/server.py +31 -5
- synth_ai/task/tracing_utils.py +8 -3
- synth_ai/task/validators.py +0 -1
- synth_ai/task/vendors.py +0 -1
- synth_ai/tracing_v3/db_config.py +26 -1
- synth_ai/tracing_v3/decorators.py +1 -0
- synth_ai/tracing_v3/examples/basic_usage.py +3 -2
- synth_ai/tracing_v3/hooks.py +2 -0
- synth_ai/tracing_v3/replica_sync.py +1 -0
- synth_ai/tracing_v3/session_tracer.py +24 -3
- synth_ai/tracing_v3/storage/base.py +4 -1
- synth_ai/tracing_v3/storage/factory.py +0 -1
- synth_ai/tracing_v3/turso/manager.py +102 -38
- synth_ai/tracing_v3/turso/models.py +4 -1
- synth_ai/tracing_v3/utils.py +1 -0
- synth_ai/v0/tracing/upload.py +32 -135
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/METADATA +1 -1
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/RECORD +154 -154
- synth_ai/install_sqld.sh +0 -40
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.9.dev5.dist-info → synth_ai-0.2.9.dev7.dist-info}/top_level.txt +0 -0
synth_ai/task/proxy.py
CHANGED
|
@@ -179,7 +179,7 @@ def parse_tool_call_from_text(text: str) -> Tuple[list[str], str]:
|
|
|
179
179
|
if m:
|
|
180
180
|
items = [part.strip() for part in m.group(1).split(",") if part.strip()]
|
|
181
181
|
if items:
|
|
182
|
-
reasoning = text[:m.start()].strip()
|
|
182
|
+
reasoning = text[: m.start()].strip()
|
|
183
183
|
return items, reasoning
|
|
184
184
|
|
|
185
185
|
# Patterns like "Action 1: move_right"
|
|
@@ -242,9 +242,7 @@ def synthesize_tool_call_if_missing(openai_response: dict[str, Any]) -> dict[str
|
|
|
242
242
|
return openai_response
|
|
243
243
|
|
|
244
244
|
new_message = copy.deepcopy(message)
|
|
245
|
-
new_message["tool_calls"] = [
|
|
246
|
-
_build_tool_call(actions, reasoning)
|
|
247
|
-
]
|
|
245
|
+
new_message["tool_calls"] = [_build_tool_call(actions, reasoning)]
|
|
248
246
|
if "content" not in new_message:
|
|
249
247
|
new_message["content"] = None
|
|
250
248
|
|
|
@@ -255,4 +253,3 @@ def synthesize_tool_call_if_missing(openai_response: dict[str, Any]) -> dict[str
|
|
|
255
253
|
result = copy.deepcopy(openai_response)
|
|
256
254
|
result["choices"] = new_choices
|
|
257
255
|
return result
|
|
258
|
-
|
synth_ai/task/rubrics.py
CHANGED
|
@@ -155,7 +155,9 @@ def _as_float(value: Any) -> Optional[float]:
|
|
|
155
155
|
return None
|
|
156
156
|
|
|
157
157
|
|
|
158
|
-
def _score(
|
|
158
|
+
def _score(
|
|
159
|
+
criteria: Iterable[Criterion], values: Dict[str, float], aggregation: str
|
|
160
|
+
) -> Dict[str, Any]:
|
|
159
161
|
if aggregation == "inherit":
|
|
160
162
|
aggregation = "weighted_sum"
|
|
161
163
|
per_criterion: Dict[str, Dict[str, Any]] = {}
|
|
@@ -184,7 +186,9 @@ def _score(criteria: Iterable[Criterion], values: Dict[str, float], aggregation:
|
|
|
184
186
|
}
|
|
185
187
|
|
|
186
188
|
|
|
187
|
-
def score_events_against_rubric(
|
|
189
|
+
def score_events_against_rubric(
|
|
190
|
+
events: list[dict[str, Any]], rubric: Rubric | None
|
|
191
|
+
) -> Dict[str, Any]:
|
|
188
192
|
if rubric is None:
|
|
189
193
|
return {"aggregation": "none", "score": None, "per_criterion": {}}
|
|
190
194
|
values: Dict[str, float] = {}
|
|
@@ -203,7 +207,9 @@ def score_outcome_against_rubric(outcome: dict[str, Any], rubric: Rubric | None)
|
|
|
203
207
|
return {"aggregation": "none", "score": None, "per_criterion": {}}
|
|
204
208
|
values: Dict[str, float] = {}
|
|
205
209
|
if isinstance(outcome, dict):
|
|
206
|
-
candidates =
|
|
210
|
+
candidates = (
|
|
211
|
+
outcome.get("criteria") if isinstance(outcome.get("criteria"), dict) else outcome
|
|
212
|
+
)
|
|
207
213
|
if isinstance(candidates, dict):
|
|
208
214
|
for key, value in candidates.items():
|
|
209
215
|
score = _as_float(value)
|
synth_ai/task/server.py
CHANGED
|
@@ -120,7 +120,9 @@ def _ensure_task_info(obj: Any) -> TaskInfo:
|
|
|
120
120
|
return obj
|
|
121
121
|
if isinstance(obj, MutableMapping):
|
|
122
122
|
return TaskInfo.model_validate(obj)
|
|
123
|
-
raise TypeError(
|
|
123
|
+
raise TypeError(
|
|
124
|
+
f"Task instance provider must yield TaskInfo-compatible objects (got {type(obj)!r})"
|
|
125
|
+
)
|
|
124
126
|
|
|
125
127
|
|
|
126
128
|
def _normalise_seeds(values: Sequence[int]) -> list[int]:
|
|
@@ -140,7 +142,9 @@ def _build_proxy_routes(
|
|
|
140
142
|
if not proxy:
|
|
141
143
|
return
|
|
142
144
|
|
|
143
|
-
async def _call_vendor(
|
|
145
|
+
async def _call_vendor(
|
|
146
|
+
url: str, payload: dict[str, Any], headers: dict[str, str]
|
|
147
|
+
) -> dict[str, Any]:
|
|
144
148
|
async with httpx.AsyncClient(timeout=httpx.Timeout(600.0), follow_redirects=True) as client:
|
|
145
149
|
response = await client.post(url, json=payload, headers=headers)
|
|
146
150
|
data = (
|
|
@@ -168,13 +172,17 @@ def _build_proxy_routes(
|
|
|
168
172
|
msg_count = len(messages) if isinstance(messages, list) else 0
|
|
169
173
|
tool_count = len(payload.get("tools") or []) if isinstance(payload, dict) else 0
|
|
170
174
|
model = payload.get("model") if isinstance(payload, dict) else None
|
|
171
|
-
print(
|
|
175
|
+
print(
|
|
176
|
+
f"[task:proxy:{route}] model={model} messages={msg_count} tools={tool_count}",
|
|
177
|
+
flush=True,
|
|
178
|
+
)
|
|
172
179
|
except Exception: # pragma: no cover - best effort logging
|
|
173
180
|
pass
|
|
174
181
|
|
|
175
182
|
system_hint = proxy.system_hint
|
|
176
183
|
|
|
177
184
|
if proxy.enable_openai:
|
|
185
|
+
|
|
178
186
|
@app.post("/proxy/v1/chat/completions", dependencies=[Depends(auth_dependency)])
|
|
179
187
|
async def proxy_openai(body: dict[str, Any], request: Request) -> Any: # type: ignore[no-redef]
|
|
180
188
|
key = get_openai_key_or_503()
|
|
@@ -187,6 +195,7 @@ def _build_proxy_routes(
|
|
|
187
195
|
return to_jsonable(sanitized)
|
|
188
196
|
|
|
189
197
|
if proxy.enable_groq:
|
|
198
|
+
|
|
190
199
|
@app.post("/proxy/groq/v1/chat/completions", dependencies=[Depends(auth_dependency)])
|
|
191
200
|
async def proxy_groq(body: dict[str, Any], request: Request) -> Any: # type: ignore[no-redef]
|
|
192
201
|
key = get_groq_key_or_503()
|
|
@@ -194,7 +203,9 @@ def _build_proxy_routes(
|
|
|
194
203
|
payload = prepare_for_groq(model, body)
|
|
195
204
|
payload = inject_system_hint(payload, system_hint or "")
|
|
196
205
|
_log_proxy("groq", payload)
|
|
197
|
-
data = await _call_vendor(
|
|
206
|
+
data = await _call_vendor(
|
|
207
|
+
proxy.groq_url.rstrip("/"), payload, {"Authorization": f"Bearer {key}"}
|
|
208
|
+
)
|
|
198
209
|
sanitized = synthesize_tool_call_if_missing(data)
|
|
199
210
|
return to_jsonable(sanitized)
|
|
200
211
|
|
|
@@ -278,7 +289,15 @@ def create_task_app(config: TaskAppConfig) -> FastAPI:
|
|
|
278
289
|
async def health(request: Request) -> Mapping[str, Any]:
|
|
279
290
|
# If we got here, auth_dependency already verified the key exactly matches
|
|
280
291
|
expected = normalize_environment_api_key()
|
|
281
|
-
return to_jsonable(
|
|
292
|
+
return to_jsonable(
|
|
293
|
+
{
|
|
294
|
+
"healthy": True,
|
|
295
|
+
"auth": {
|
|
296
|
+
"required": True,
|
|
297
|
+
"expected_prefix": (expected[:6] + "...") if expected else "<unset>",
|
|
298
|
+
},
|
|
299
|
+
}
|
|
300
|
+
)
|
|
282
301
|
|
|
283
302
|
@app.get("/info", dependencies=[Depends(auth_dependency)])
|
|
284
303
|
async def info() -> Mapping[str, Any]:
|
|
@@ -335,6 +354,7 @@ def create_task_app(config: TaskAppConfig) -> FastAPI:
|
|
|
335
354
|
raise TypeError("Rollout executor must return RolloutResponse or mapping")
|
|
336
355
|
|
|
337
356
|
if cfg.expose_debug_env:
|
|
357
|
+
|
|
338
358
|
@app.get("/debug/env", dependencies=[Depends(auth_dependency)])
|
|
339
359
|
async def debug_env() -> Mapping[str, Any]:
|
|
340
360
|
def _mask(value: str | None) -> str:
|
|
@@ -387,6 +407,12 @@ def run_task_app(
|
|
|
387
407
|
print(f"[task:server] Loaded environment from: {', '.join(loaded_files)}", flush=True)
|
|
388
408
|
|
|
389
409
|
config = config_factory()
|
|
410
|
+
# Defensive: ensure the factory produced a valid TaskAppConfig to avoid
|
|
411
|
+
# confusing attribute errors later in the boot sequence.
|
|
412
|
+
if not isinstance(config, TaskAppConfig): # type: ignore[arg-type]
|
|
413
|
+
raise TypeError(
|
|
414
|
+
f"Task app config_factory must return TaskAppConfig, got {type(config).__name__}"
|
|
415
|
+
)
|
|
390
416
|
app = create_task_app(config)
|
|
391
417
|
|
|
392
418
|
try:
|
synth_ai/task/tracing_utils.py
CHANGED
|
@@ -45,7 +45,9 @@ def resolve_tracing_db_url() -> str | None:
|
|
|
45
45
|
return f"sqlite+aiosqlite:///{fallback_path}"
|
|
46
46
|
|
|
47
47
|
|
|
48
|
-
def build_tracer_factory(
|
|
48
|
+
def build_tracer_factory(
|
|
49
|
+
make_tracer: Callable[..., Any], *, enabled: bool, db_url: str | None
|
|
50
|
+
) -> Callable[[], Any] | None:
|
|
49
51
|
"""Return a factory that instantiates a tracer when enabled, else None."""
|
|
50
52
|
|
|
51
53
|
if not enabled:
|
|
@@ -74,6 +76,9 @@ def resolve_sft_output_dir() -> str | None:
|
|
|
74
76
|
def unique_sft_path(base_dir: str, *, run_id: str) -> Path:
|
|
75
77
|
"""Return a unique JSONL path for an SFT record batch."""
|
|
76
78
|
|
|
77
|
-
|
|
78
|
-
|
|
79
|
+
from datetime import datetime
|
|
80
|
+
|
|
81
|
+
now = datetime.now()
|
|
82
|
+
timestamp = now.strftime("%Y-%m-%d_%H-%M-%S")
|
|
83
|
+
name = f"{run_id}_{timestamp}.jsonl"
|
|
79
84
|
return Path(base_dir) / name
|
synth_ai/task/validators.py
CHANGED
synth_ai/task/vendors.py
CHANGED
synth_ai/tracing_v3/db_config.py
CHANGED
|
@@ -4,6 +4,7 @@ Centralized database configuration for v3 tracing.
|
|
|
4
4
|
|
|
5
5
|
import logging
|
|
6
6
|
import os
|
|
7
|
+
import shutil
|
|
7
8
|
from typing import TYPE_CHECKING, Optional
|
|
8
9
|
|
|
9
10
|
if TYPE_CHECKING:
|
|
@@ -30,7 +31,7 @@ class DatabaseConfig:
|
|
|
30
31
|
http_port: HTTP port for sqld daemon. If None, uses DEFAULT_HTTP_PORT from serve.sh.
|
|
31
32
|
use_sqld: Whether to use sqld daemon or direct SQLite.
|
|
32
33
|
"""
|
|
33
|
-
self.use_sqld = use_sqld
|
|
34
|
+
self.use_sqld = use_sqld and self._sqld_binary_available()
|
|
34
35
|
self.http_port = http_port or int(os.getenv("SQLD_HTTP_PORT", self.DEFAULT_HTTP_PORT))
|
|
35
36
|
self._daemon: SqldDaemon | None = None
|
|
36
37
|
|
|
@@ -70,6 +71,30 @@ class DatabaseConfig:
|
|
|
70
71
|
# SQLite URLs need 3 slashes for absolute paths
|
|
71
72
|
return f"sqlite+aiosqlite:///{actual_db_path}"
|
|
72
73
|
|
|
74
|
+
def _sqld_binary_available(self) -> bool:
|
|
75
|
+
"""Check if the sqld (Turso) binary is available on PATH."""
|
|
76
|
+
# Respect explicit SQLD_BINARY override when present
|
|
77
|
+
binary_override = os.getenv("SQLD_BINARY")
|
|
78
|
+
candidates = [binary_override, "sqld", "libsql-server"]
|
|
79
|
+
|
|
80
|
+
for candidate in candidates:
|
|
81
|
+
if candidate and shutil.which(candidate):
|
|
82
|
+
return True
|
|
83
|
+
|
|
84
|
+
if binary_override:
|
|
85
|
+
logger.warning(
|
|
86
|
+
"Configured SQLD_BINARY='%s' but the executable was not found on PATH. "
|
|
87
|
+
"Falling back to direct SQLite.",
|
|
88
|
+
binary_override,
|
|
89
|
+
)
|
|
90
|
+
else:
|
|
91
|
+
logger.warning(
|
|
92
|
+
"sqld binary not detected; falling back to SQLite-only mode. "
|
|
93
|
+
"Install Turso's sqld or set SQLD_BINARY to enable the Turso daemon."
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
return False
|
|
97
|
+
|
|
73
98
|
def start_daemon(self, wait_time: float = 2.0):
|
|
74
99
|
"""
|
|
75
100
|
Start the sqld daemon if configured.
|
|
@@ -166,8 +166,9 @@ async def main():
|
|
|
166
166
|
|
|
167
167
|
tracer.hooks.register("event_recorded", count_events, name="event_counter")
|
|
168
168
|
|
|
169
|
-
async with
|
|
170
|
-
"
|
|
169
|
+
async with (
|
|
170
|
+
tracer.session(metadata={"example": "hooks"}) as session_id,
|
|
171
|
+
tracer.timestep("hook_test"),
|
|
171
172
|
):
|
|
172
173
|
for i in range(3):
|
|
173
174
|
event = RuntimeEvent(
|
synth_ai/tracing_v3/hooks.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
|
|
2
3
|
"""Hook system for extending tracing functionality.
|
|
3
4
|
|
|
4
5
|
The hook system provides a flexible way to extend the tracing system without
|
|
@@ -202,6 +203,7 @@ def create_default_hooks() -> HookManager:
|
|
|
202
203
|
# Example: Log session starts - useful for debugging and monitoring
|
|
203
204
|
async def log_session_start(session_id: str, metadata: dict[str, Any]):
|
|
204
205
|
import os
|
|
206
|
+
|
|
205
207
|
if os.getenv("SYNTH_TRACE_VERBOSE", "0") in ("1", "true", "True"):
|
|
206
208
|
print(f"Session started: {session_id}")
|
|
207
209
|
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
|
|
2
3
|
"""Main SessionTracer class for tracing v3."""
|
|
3
4
|
|
|
4
5
|
import asyncio
|
|
@@ -110,7 +111,9 @@ class SessionTracer:
|
|
|
110
111
|
|
|
111
112
|
# Ensure session row exists for incremental writes
|
|
112
113
|
if self.db:
|
|
113
|
-
await self.db.ensure_session(
|
|
114
|
+
await self.db.ensure_session(
|
|
115
|
+
session_id, created_at=self._current_trace.created_at, metadata=metadata or {}
|
|
116
|
+
)
|
|
114
117
|
|
|
115
118
|
# Trigger hooks
|
|
116
119
|
await self.hooks.trigger(
|
|
@@ -435,7 +438,14 @@ class SessionTracer:
|
|
|
435
438
|
# Reward recording helpers
|
|
436
439
|
# -------------------------------
|
|
437
440
|
|
|
438
|
-
async def record_outcome_reward(
|
|
441
|
+
async def record_outcome_reward(
|
|
442
|
+
self,
|
|
443
|
+
*,
|
|
444
|
+
total_reward: int,
|
|
445
|
+
achievements_count: int,
|
|
446
|
+
total_steps: int,
|
|
447
|
+
reward_metadata: dict[str, Any] | None = None,
|
|
448
|
+
) -> int | None:
|
|
439
449
|
"""Record an episode-level outcome reward for the current session."""
|
|
440
450
|
if self._current_trace is None:
|
|
441
451
|
raise RuntimeError("No active session")
|
|
@@ -462,7 +472,18 @@ class SessionTracer:
|
|
|
462
472
|
|
|
463
473
|
# StepMetrics removed in favor of event_rewards; use record_event_reward for per-turn shaped values
|
|
464
474
|
|
|
465
|
-
async def record_event_reward(
|
|
475
|
+
async def record_event_reward(
|
|
476
|
+
self,
|
|
477
|
+
*,
|
|
478
|
+
event_id: int,
|
|
479
|
+
message_id: int | None = None,
|
|
480
|
+
turn_number: int | None = None,
|
|
481
|
+
reward_value: float = 0.0,
|
|
482
|
+
reward_type: str | None = None,
|
|
483
|
+
key: str | None = None,
|
|
484
|
+
annotation: dict[str, Any] | None = None,
|
|
485
|
+
source: str | None = None,
|
|
486
|
+
) -> int | None:
|
|
466
487
|
"""Record a first-class event-level reward with optional annotations."""
|
|
467
488
|
if self._current_trace is None:
|
|
468
489
|
raise RuntimeError("No active session")
|
|
@@ -54,7 +54,10 @@ class TraceStorage(ABC):
|
|
|
54
54
|
|
|
55
55
|
@abstractmethod
|
|
56
56
|
async def get_model_usage(
|
|
57
|
-
self,
|
|
57
|
+
self,
|
|
58
|
+
start_date: datetime | None = None,
|
|
59
|
+
end_date: datetime | None = None,
|
|
60
|
+
model_name: str | None = None,
|
|
58
61
|
) -> Any:
|
|
59
62
|
"""Get model usage statistics.
|
|
60
63
|
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
|
|
2
3
|
"""Async SQLAlchemy-based trace manager for Turso/sqld.
|
|
3
4
|
|
|
4
5
|
This module provides the database interface for the tracing system using
|
|
@@ -139,6 +140,7 @@ class AsyncSQLTraceManager:
|
|
|
139
140
|
)
|
|
140
141
|
# Ensure PRAGMA foreign_keys=ON for every connection
|
|
141
142
|
try:
|
|
143
|
+
|
|
142
144
|
@event.listens_for(self.engine.sync_engine, "connect")
|
|
143
145
|
def _set_sqlite_pragma(dbapi_connection, connection_record): # type: ignore[no-redef]
|
|
144
146
|
try:
|
|
@@ -408,9 +410,7 @@ class AsyncSQLTraceManager:
|
|
|
408
410
|
],
|
|
409
411
|
}
|
|
410
412
|
|
|
411
|
-
async def query_traces(
|
|
412
|
-
self, query: str, params: dict[str, Any] | None = None
|
|
413
|
-
) -> Any:
|
|
413
|
+
async def query_traces(self, query: str, params: dict[str, Any] | None = None) -> Any:
|
|
414
414
|
"""Execute a query and return results.
|
|
415
415
|
|
|
416
416
|
Returns a pandas DataFrame when pandas is available; otherwise a
|
|
@@ -577,10 +577,18 @@ class AsyncSQLTraceManager:
|
|
|
577
577
|
# Incremental insert helpers
|
|
578
578
|
# -------------------------------
|
|
579
579
|
|
|
580
|
-
async def ensure_session(
|
|
580
|
+
async def ensure_session(
|
|
581
|
+
self,
|
|
582
|
+
session_id: str,
|
|
583
|
+
*,
|
|
584
|
+
created_at: datetime | None = None,
|
|
585
|
+
metadata: dict[str, Any] | None = None,
|
|
586
|
+
):
|
|
581
587
|
"""Ensure a DB session row exists for session_id."""
|
|
582
588
|
async with self.session() as sess:
|
|
583
|
-
result = await sess.execute(
|
|
589
|
+
result = await sess.execute(
|
|
590
|
+
select(DBSessionTrace).where(DBSessionTrace.session_id == session_id)
|
|
591
|
+
)
|
|
584
592
|
existing = result.scalar_one_or_none()
|
|
585
593
|
if existing:
|
|
586
594
|
return
|
|
@@ -595,11 +603,23 @@ class AsyncSQLTraceManager:
|
|
|
595
603
|
sess.add(row)
|
|
596
604
|
await sess.commit()
|
|
597
605
|
|
|
598
|
-
async def ensure_timestep(
|
|
606
|
+
async def ensure_timestep(
|
|
607
|
+
self,
|
|
608
|
+
session_id: str,
|
|
609
|
+
*,
|
|
610
|
+
step_id: str,
|
|
611
|
+
step_index: int,
|
|
612
|
+
turn_number: int | None = None,
|
|
613
|
+
started_at: datetime | None = None,
|
|
614
|
+
completed_at: datetime | None = None,
|
|
615
|
+
metadata: dict[str, Any] | None = None,
|
|
616
|
+
) -> int:
|
|
599
617
|
"""Ensure a timestep row exists; return its DB id."""
|
|
600
618
|
async with self.session() as sess:
|
|
601
619
|
result = await sess.execute(
|
|
602
|
-
select(DBSessionTimestep).where(
|
|
620
|
+
select(DBSessionTimestep).where(
|
|
621
|
+
DBSessionTimestep.session_id == session_id, DBSessionTimestep.step_id == step_id
|
|
622
|
+
)
|
|
603
623
|
)
|
|
604
624
|
row = result.scalar_one_or_none()
|
|
605
625
|
if row:
|
|
@@ -626,7 +646,17 @@ class AsyncSQLTraceManager:
|
|
|
626
646
|
await sess.commit()
|
|
627
647
|
return row.id
|
|
628
648
|
|
|
629
|
-
async def insert_message_row(
|
|
649
|
+
async def insert_message_row(
|
|
650
|
+
self,
|
|
651
|
+
session_id: str,
|
|
652
|
+
*,
|
|
653
|
+
timestep_db_id: int | None,
|
|
654
|
+
message_type: str,
|
|
655
|
+
content: str,
|
|
656
|
+
event_time: float | None = None,
|
|
657
|
+
message_time: int | None = None,
|
|
658
|
+
metadata: dict[str, Any] | None = None,
|
|
659
|
+
) -> int:
|
|
630
660
|
"""Insert a message and return its id."""
|
|
631
661
|
async with self.session() as sess:
|
|
632
662
|
db_msg = DBMessage(
|
|
@@ -649,8 +679,16 @@ class AsyncSQLTraceManager:
|
|
|
649
679
|
await sess.commit()
|
|
650
680
|
return db_msg.id
|
|
651
681
|
|
|
652
|
-
async def insert_event_row(
|
|
682
|
+
async def insert_event_row(
|
|
683
|
+
self,
|
|
684
|
+
session_id: str,
|
|
685
|
+
*,
|
|
686
|
+
timestep_db_id: int | None,
|
|
687
|
+
event: EnvironmentEvent | LMCAISEvent | RuntimeEvent,
|
|
688
|
+
metadata_override: dict[str, Any] | None = None,
|
|
689
|
+
) -> int:
|
|
653
690
|
"""Insert an event and return its id."""
|
|
691
|
+
|
|
654
692
|
def to_cents(cost: float | None) -> int | None:
|
|
655
693
|
return int(cost * 100) if cost is not None else None
|
|
656
694
|
|
|
@@ -669,35 +707,41 @@ class AsyncSQLTraceManager:
|
|
|
669
707
|
from dataclasses import asdict
|
|
670
708
|
|
|
671
709
|
call_records_data = [asdict(record) for record in event.call_records]
|
|
672
|
-
event_data.update(
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
710
|
+
event_data.update(
|
|
711
|
+
{
|
|
712
|
+
"event_type": "cais",
|
|
713
|
+
"model_name": event.model_name,
|
|
714
|
+
"provider": event.provider,
|
|
715
|
+
"input_tokens": event.input_tokens,
|
|
716
|
+
"output_tokens": event.output_tokens,
|
|
717
|
+
"total_tokens": event.total_tokens,
|
|
718
|
+
"cost_usd": to_cents(event.cost_usd),
|
|
719
|
+
"latency_ms": event.latency_ms,
|
|
720
|
+
"span_id": event.span_id,
|
|
721
|
+
"trace_id": event.trace_id,
|
|
722
|
+
"system_state_before": event.system_state_before,
|
|
723
|
+
"system_state_after": event.system_state_after,
|
|
724
|
+
"call_records": call_records_data,
|
|
725
|
+
}
|
|
726
|
+
)
|
|
687
727
|
elif isinstance(event, EnvironmentEvent):
|
|
688
|
-
event_data.update(
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
728
|
+
event_data.update(
|
|
729
|
+
{
|
|
730
|
+
"event_type": "environment",
|
|
731
|
+
"reward": event.reward,
|
|
732
|
+
"terminated": event.terminated,
|
|
733
|
+
"truncated": event.truncated,
|
|
734
|
+
"system_state_before": event.system_state_before,
|
|
735
|
+
"system_state_after": event.system_state_after,
|
|
736
|
+
}
|
|
737
|
+
)
|
|
696
738
|
elif isinstance(event, RuntimeEvent):
|
|
697
|
-
event_data.update(
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
739
|
+
event_data.update(
|
|
740
|
+
{
|
|
741
|
+
"event_type": "runtime",
|
|
742
|
+
"event_metadata_json": {**(event.metadata or {}), "actions": event.actions},
|
|
743
|
+
}
|
|
744
|
+
)
|
|
701
745
|
else:
|
|
702
746
|
event_data["event_type"] = event.__class__.__name__.lower()
|
|
703
747
|
|
|
@@ -718,7 +762,15 @@ class AsyncSQLTraceManager:
|
|
|
718
762
|
# Reward helpers
|
|
719
763
|
# -------------------------------
|
|
720
764
|
|
|
721
|
-
async def insert_outcome_reward(
|
|
765
|
+
async def insert_outcome_reward(
|
|
766
|
+
self,
|
|
767
|
+
session_id: str,
|
|
768
|
+
*,
|
|
769
|
+
total_reward: int,
|
|
770
|
+
achievements_count: int,
|
|
771
|
+
total_steps: int,
|
|
772
|
+
reward_metadata: dict | None = None,
|
|
773
|
+
) -> int:
|
|
722
774
|
async with self.session() as sess:
|
|
723
775
|
row = DBOutcomeReward(
|
|
724
776
|
session_id=session_id,
|
|
@@ -732,7 +784,19 @@ class AsyncSQLTraceManager:
|
|
|
732
784
|
await sess.commit()
|
|
733
785
|
return row.id
|
|
734
786
|
|
|
735
|
-
async def insert_event_reward(
|
|
787
|
+
async def insert_event_reward(
|
|
788
|
+
self,
|
|
789
|
+
session_id: str,
|
|
790
|
+
*,
|
|
791
|
+
event_id: int,
|
|
792
|
+
message_id: int | None = None,
|
|
793
|
+
turn_number: int | None = None,
|
|
794
|
+
reward_value: float = 0.0,
|
|
795
|
+
reward_type: str | None = None,
|
|
796
|
+
key: str | None = None,
|
|
797
|
+
annotation: dict[str, Any] | None = None,
|
|
798
|
+
source: str | None = None,
|
|
799
|
+
) -> int:
|
|
736
800
|
async with self.session() as sess:
|
|
737
801
|
row = DBEventReward(
|
|
738
802
|
event_id=event_id,
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
|
|
2
3
|
"""SQLAlchemy declarative models for tracing v3."""
|
|
3
4
|
|
|
4
5
|
import json
|
|
@@ -452,7 +453,9 @@ class EventReward(Base):
|
|
|
452
453
|
message_id = Column(Integer, ForeignKey("messages.id"), nullable=True)
|
|
453
454
|
turn_number = Column(Integer, nullable=True)
|
|
454
455
|
reward_value = Column(Float, nullable=False, default=0.0)
|
|
455
|
-
reward_type = Column(
|
|
456
|
+
reward_type = Column(
|
|
457
|
+
String, nullable=True
|
|
458
|
+
) # shaped | sparse | achievement | penalty | evaluator | human
|
|
456
459
|
key = Column(String, nullable=True) # e.g., achievement name
|
|
457
460
|
annotation = Column(JSONText) # free-form JSON
|
|
458
461
|
source = Column(String, nullable=True) # environment | runner | evaluator | human
|