synth-ai 0.2.9.dev4__py3-none-any.whl → 0.2.9.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (157)
  1. examples/common_old/backend.py +0 -1
  2. examples/crafter_debug_render.py +15 -6
  3. examples/evals_old/compare_models.py +1 -0
  4. examples/finetuning_old/_backup_synth_qwen/filter_traces_achievements.py +6 -2
  5. examples/finetuning_old/_backup_synth_qwen/react_agent_lm.py +4 -4
  6. examples/finetuning_old/_backup_synth_qwen/sft_kickoff.py +4 -3
  7. examples/finetuning_old/synth_qwen_v1/filter_traces_achievements.py +6 -2
  8. examples/finetuning_old/synth_qwen_v1/finetune.py +1 -1
  9. examples/finetuning_old/synth_qwen_v1/hello_ft_model.py +4 -4
  10. examples/finetuning_old/synth_qwen_v1/infer.py +1 -2
  11. examples/finetuning_old/synth_qwen_v1/poll.py +4 -2
  12. examples/finetuning_old/synth_qwen_v1/prepare_data.py +8 -8
  13. examples/finetuning_old/synth_qwen_v1/react_agent_lm.py +5 -4
  14. examples/finetuning_old/synth_qwen_v1/run_crafter_sft_job.py +11 -8
  15. examples/finetuning_old/synth_qwen_v1/run_ft_job.py +17 -12
  16. examples/finetuning_old/synth_qwen_v1/upload_data.py +1 -1
  17. examples/finetuning_old/synth_qwen_v1/util.py +7 -2
  18. examples/rl/configs/eval_base_qwen.toml +1 -1
  19. examples/rl/configs/rl_from_base_qwen17.toml +1 -1
  20. examples/rl/download_dataset.py +26 -10
  21. examples/rl/run_eval.py +17 -15
  22. examples/rl/run_rl_and_save.py +24 -7
  23. examples/rl/task_app/math_single_step.py +128 -11
  24. examples/rl/task_app/math_task_app.py +11 -3
  25. examples/rl_old/task_app.py +222 -53
  26. examples/warming_up_to_rl/analyze_trace_db.py +7 -5
  27. examples/warming_up_to_rl/export_trace_sft.py +141 -16
  28. examples/warming_up_to_rl/groq_test.py +11 -4
  29. examples/warming_up_to_rl/manage_secrets.py +15 -6
  30. examples/warming_up_to_rl/readme.md +9 -2
  31. examples/warming_up_to_rl/run_eval.py +108 -30
  32. examples/warming_up_to_rl/run_fft_and_save.py +128 -52
  33. examples/warming_up_to_rl/run_local_rollout.py +87 -36
  34. examples/warming_up_to_rl/run_local_rollout_modal.py +113 -25
  35. examples/warming_up_to_rl/run_local_rollout_parallel.py +80 -16
  36. examples/warming_up_to_rl/run_local_rollout_traced.py +125 -20
  37. examples/warming_up_to_rl/run_rl_and_save.py +31 -7
  38. examples/warming_up_to_rl/run_rollout_remote.py +37 -10
  39. examples/warming_up_to_rl/task_app/grpo_crafter.py +90 -27
  40. examples/warming_up_to_rl/task_app/grpo_crafter_task_app.py +9 -27
  41. examples/warming_up_to_rl/task_app/synth_envs_hosted/environment_routes.py +46 -108
  42. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/__init__.py +1 -1
  43. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/__init__.py +1 -1
  44. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/app.py +1 -1
  45. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/environment.py +50 -17
  46. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/policy.py +35 -21
  47. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/react_agent.py +8 -4
  48. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/shared.py +29 -26
  49. examples/warming_up_to_rl/task_app/synth_envs_hosted/envs/crafter/tools.py +1 -1
  50. examples/warming_up_to_rl/task_app/synth_envs_hosted/hosted_app.py +17 -13
  51. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/__init__.py +1 -1
  52. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +106 -63
  53. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +82 -84
  54. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +76 -59
  55. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/__init__.py +1 -1
  56. examples/warming_up_to_rl/task_app/synth_envs_hosted/storage/volume.py +43 -49
  57. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_service.py +5 -15
  58. synth_ai/__init__.py +1 -0
  59. synth_ai/api/train/builders.py +34 -10
  60. synth_ai/api/train/cli.py +172 -32
  61. synth_ai/api/train/config_finder.py +59 -4
  62. synth_ai/api/train/env_resolver.py +32 -14
  63. synth_ai/api/train/pollers.py +11 -3
  64. synth_ai/api/train/task_app.py +4 -1
  65. synth_ai/api/train/utils.py +20 -4
  66. synth_ai/cli/__init__.py +11 -4
  67. synth_ai/cli/balance.py +1 -1
  68. synth_ai/cli/demo.py +19 -5
  69. synth_ai/cli/rl_demo.py +75 -16
  70. synth_ai/cli/root.py +116 -37
  71. synth_ai/cli/task_apps.py +1286 -170
  72. synth_ai/cli/traces.py +1 -0
  73. synth_ai/cli/turso.py +73 -0
  74. synth_ai/core/experiment.py +0 -2
  75. synth_ai/demo_registry.py +67 -30
  76. synth_ai/demos/core/cli.py +493 -164
  77. synth_ai/demos/demo_task_apps/core.py +50 -6
  78. synth_ai/demos/demo_task_apps/crafter/configs/crafter_fft_4b.toml +2 -3
  79. synth_ai/demos/demo_task_apps/crafter/grpo_crafter_task_app.py +36 -28
  80. synth_ai/demos/demo_task_apps/math/_common.py +1 -2
  81. synth_ai/demos/demo_task_apps/math/deploy_modal.py +0 -2
  82. synth_ai/demos/demo_task_apps/math/modal_task_app.py +168 -65
  83. synth_ai/demos/demo_task_apps/math/task_app_entry.py +0 -1
  84. synth_ai/environments/examples/bandit/engine.py +12 -4
  85. synth_ai/environments/examples/bandit/taskset.py +4 -4
  86. synth_ai/environments/reproducibility/tree.py +3 -1
  87. synth_ai/environments/service/core_routes.py +6 -2
  88. synth_ai/evals/base.py +0 -2
  89. synth_ai/experimental/synth_oss.py +11 -12
  90. synth_ai/handshake.py +3 -1
  91. synth_ai/http_client.py +31 -7
  92. synth_ai/inference/__init__.py +0 -2
  93. synth_ai/inference/client.py +8 -4
  94. synth_ai/jobs/client.py +40 -10
  95. synth_ai/learning/client.py +33 -8
  96. synth_ai/learning/config.py +0 -2
  97. synth_ai/learning/constants.py +0 -2
  98. synth_ai/learning/ft_client.py +6 -3
  99. synth_ai/learning/health.py +9 -2
  100. synth_ai/learning/jobs.py +17 -5
  101. synth_ai/learning/prompts/hello_world_in_context_injection_ex.py +1 -3
  102. synth_ai/learning/prompts/random_search.py +4 -1
  103. synth_ai/learning/prompts/run_random_search_banking77.py +6 -1
  104. synth_ai/learning/rl_client.py +42 -14
  105. synth_ai/learning/sse.py +0 -2
  106. synth_ai/learning/validators.py +6 -2
  107. synth_ai/lm/caching/ephemeral.py +1 -3
  108. synth_ai/lm/core/exceptions.py +0 -2
  109. synth_ai/lm/core/main.py +13 -1
  110. synth_ai/lm/core/synth_models.py +0 -1
  111. synth_ai/lm/core/vendor_clients.py +4 -2
  112. synth_ai/lm/overrides.py +2 -2
  113. synth_ai/lm/vendors/core/anthropic_api.py +7 -7
  114. synth_ai/lm/vendors/core/openai_api.py +2 -0
  115. synth_ai/lm/vendors/openai_standard.py +3 -1
  116. synth_ai/lm/vendors/openai_standard_responses.py +6 -3
  117. synth_ai/lm/vendors/supported/custom_endpoint.py +1 -3
  118. synth_ai/lm/vendors/synth_client.py +37 -10
  119. synth_ai/rl/__init__.py +0 -1
  120. synth_ai/rl/contracts.py +0 -2
  121. synth_ai/rl/env_keys.py +6 -1
  122. synth_ai/task/__init__.py +1 -0
  123. synth_ai/task/apps/__init__.py +11 -11
  124. synth_ai/task/auth.py +29 -17
  125. synth_ai/task/client.py +3 -1
  126. synth_ai/task/contracts.py +1 -0
  127. synth_ai/task/datasets.py +3 -1
  128. synth_ai/task/errors.py +3 -2
  129. synth_ai/task/health.py +0 -2
  130. synth_ai/task/json.py +0 -1
  131. synth_ai/task/proxy.py +2 -5
  132. synth_ai/task/rubrics.py +9 -3
  133. synth_ai/task/server.py +31 -5
  134. synth_ai/task/tracing_utils.py +8 -3
  135. synth_ai/task/validators.py +0 -1
  136. synth_ai/task/vendors.py +0 -1
  137. synth_ai/tracing_v3/db_config.py +26 -1
  138. synth_ai/tracing_v3/decorators.py +1 -0
  139. synth_ai/tracing_v3/examples/basic_usage.py +3 -2
  140. synth_ai/tracing_v3/hooks.py +2 -0
  141. synth_ai/tracing_v3/replica_sync.py +1 -0
  142. synth_ai/tracing_v3/session_tracer.py +24 -3
  143. synth_ai/tracing_v3/storage/base.py +4 -1
  144. synth_ai/tracing_v3/storage/factory.py +0 -1
  145. synth_ai/tracing_v3/turso/manager.py +102 -38
  146. synth_ai/tracing_v3/turso/models.py +4 -1
  147. synth_ai/tracing_v3/utils.py +1 -0
  148. synth_ai/v0/tracing/upload.py +32 -135
  149. {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/METADATA +1 -1
  150. {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/RECORD +154 -156
  151. examples/warming_up_to_rl/task_app/synth_envs_hosted/test_stepwise_rewards.py +0 -58
  152. synth_ai/environments/examples/sokoban/units/astar_common.py +0 -95
  153. synth_ai/install_sqld.sh +0 -40
  154. {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/WHEEL +0 -0
  155. {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/entry_points.txt +0 -0
  156. {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/licenses/LICENSE +0 -0
  157. {synth_ai-0.2.9.dev4.dist-info → synth_ai-0.2.9.dev7.dist-info}/top_level.txt +0 -0
synth_ai/task/proxy.py CHANGED
@@ -179,7 +179,7 @@ def parse_tool_call_from_text(text: str) -> Tuple[list[str], str]:
179
179
  if m:
180
180
  items = [part.strip() for part in m.group(1).split(",") if part.strip()]
181
181
  if items:
182
- reasoning = text[:m.start()].strip()
182
+ reasoning = text[: m.start()].strip()
183
183
  return items, reasoning
184
184
 
185
185
  # Patterns like "Action 1: move_right"
@@ -242,9 +242,7 @@ def synthesize_tool_call_if_missing(openai_response: dict[str, Any]) -> dict[str
242
242
  return openai_response
243
243
 
244
244
  new_message = copy.deepcopy(message)
245
- new_message["tool_calls"] = [
246
- _build_tool_call(actions, reasoning)
247
- ]
245
+ new_message["tool_calls"] = [_build_tool_call(actions, reasoning)]
248
246
  if "content" not in new_message:
249
247
  new_message["content"] = None
250
248
 
@@ -255,4 +253,3 @@ def synthesize_tool_call_if_missing(openai_response: dict[str, Any]) -> dict[str
255
253
  result = copy.deepcopy(openai_response)
256
254
  result["choices"] = new_choices
257
255
  return result
258
-
synth_ai/task/rubrics.py CHANGED
@@ -155,7 +155,9 @@ def _as_float(value: Any) -> Optional[float]:
155
155
  return None
156
156
 
157
157
 
158
- def _score(criteria: Iterable[Criterion], values: Dict[str, float], aggregation: str) -> Dict[str, Any]:
158
+ def _score(
159
+ criteria: Iterable[Criterion], values: Dict[str, float], aggregation: str
160
+ ) -> Dict[str, Any]:
159
161
  if aggregation == "inherit":
160
162
  aggregation = "weighted_sum"
161
163
  per_criterion: Dict[str, Dict[str, Any]] = {}
@@ -184,7 +186,9 @@ def _score(criteria: Iterable[Criterion], values: Dict[str, float], aggregation:
184
186
  }
185
187
 
186
188
 
187
- def score_events_against_rubric(events: list[dict[str, Any]], rubric: Rubric | None) -> Dict[str, Any]:
189
+ def score_events_against_rubric(
190
+ events: list[dict[str, Any]], rubric: Rubric | None
191
+ ) -> Dict[str, Any]:
188
192
  if rubric is None:
189
193
  return {"aggregation": "none", "score": None, "per_criterion": {}}
190
194
  values: Dict[str, float] = {}
@@ -203,7 +207,9 @@ def score_outcome_against_rubric(outcome: dict[str, Any], rubric: Rubric | None)
203
207
  return {"aggregation": "none", "score": None, "per_criterion": {}}
204
208
  values: Dict[str, float] = {}
205
209
  if isinstance(outcome, dict):
206
- candidates = outcome.get("criteria") if isinstance(outcome.get("criteria"), dict) else outcome
210
+ candidates = (
211
+ outcome.get("criteria") if isinstance(outcome.get("criteria"), dict) else outcome
212
+ )
207
213
  if isinstance(candidates, dict):
208
214
  for key, value in candidates.items():
209
215
  score = _as_float(value)
synth_ai/task/server.py CHANGED
@@ -120,7 +120,9 @@ def _ensure_task_info(obj: Any) -> TaskInfo:
120
120
  return obj
121
121
  if isinstance(obj, MutableMapping):
122
122
  return TaskInfo.model_validate(obj)
123
- raise TypeError(f"Task instance provider must yield TaskInfo-compatible objects (got {type(obj)!r})")
123
+ raise TypeError(
124
+ f"Task instance provider must yield TaskInfo-compatible objects (got {type(obj)!r})"
125
+ )
124
126
 
125
127
 
126
128
  def _normalise_seeds(values: Sequence[int]) -> list[int]:
@@ -140,7 +142,9 @@ def _build_proxy_routes(
140
142
  if not proxy:
141
143
  return
142
144
 
143
- async def _call_vendor(url: str, payload: dict[str, Any], headers: dict[str, str]) -> dict[str, Any]:
145
+ async def _call_vendor(
146
+ url: str, payload: dict[str, Any], headers: dict[str, str]
147
+ ) -> dict[str, Any]:
144
148
  async with httpx.AsyncClient(timeout=httpx.Timeout(600.0), follow_redirects=True) as client:
145
149
  response = await client.post(url, json=payload, headers=headers)
146
150
  data = (
@@ -168,13 +172,17 @@ def _build_proxy_routes(
168
172
  msg_count = len(messages) if isinstance(messages, list) else 0
169
173
  tool_count = len(payload.get("tools") or []) if isinstance(payload, dict) else 0
170
174
  model = payload.get("model") if isinstance(payload, dict) else None
171
- print(f"[task:proxy:{route}] model={model} messages={msg_count} tools={tool_count}", flush=True)
175
+ print(
176
+ f"[task:proxy:{route}] model={model} messages={msg_count} tools={tool_count}",
177
+ flush=True,
178
+ )
172
179
  except Exception: # pragma: no cover - best effort logging
173
180
  pass
174
181
 
175
182
  system_hint = proxy.system_hint
176
183
 
177
184
  if proxy.enable_openai:
185
+
178
186
  @app.post("/proxy/v1/chat/completions", dependencies=[Depends(auth_dependency)])
179
187
  async def proxy_openai(body: dict[str, Any], request: Request) -> Any: # type: ignore[no-redef]
180
188
  key = get_openai_key_or_503()
@@ -187,6 +195,7 @@ def _build_proxy_routes(
187
195
  return to_jsonable(sanitized)
188
196
 
189
197
  if proxy.enable_groq:
198
+
190
199
  @app.post("/proxy/groq/v1/chat/completions", dependencies=[Depends(auth_dependency)])
191
200
  async def proxy_groq(body: dict[str, Any], request: Request) -> Any: # type: ignore[no-redef]
192
201
  key = get_groq_key_or_503()
@@ -194,7 +203,9 @@ def _build_proxy_routes(
194
203
  payload = prepare_for_groq(model, body)
195
204
  payload = inject_system_hint(payload, system_hint or "")
196
205
  _log_proxy("groq", payload)
197
- data = await _call_vendor(proxy.groq_url.rstrip("/"), payload, {"Authorization": f"Bearer {key}"})
206
+ data = await _call_vendor(
207
+ proxy.groq_url.rstrip("/"), payload, {"Authorization": f"Bearer {key}"}
208
+ )
198
209
  sanitized = synthesize_tool_call_if_missing(data)
199
210
  return to_jsonable(sanitized)
200
211
 
@@ -278,7 +289,15 @@ def create_task_app(config: TaskAppConfig) -> FastAPI:
278
289
  async def health(request: Request) -> Mapping[str, Any]:
279
290
  # If we got here, auth_dependency already verified the key exactly matches
280
291
  expected = normalize_environment_api_key()
281
- return to_jsonable({"healthy": True, "auth": {"required": True, "expected_prefix": (expected[:6] + '...') if expected else '<unset>'}})
292
+ return to_jsonable(
293
+ {
294
+ "healthy": True,
295
+ "auth": {
296
+ "required": True,
297
+ "expected_prefix": (expected[:6] + "...") if expected else "<unset>",
298
+ },
299
+ }
300
+ )
282
301
 
283
302
  @app.get("/info", dependencies=[Depends(auth_dependency)])
284
303
  async def info() -> Mapping[str, Any]:
@@ -335,6 +354,7 @@ def create_task_app(config: TaskAppConfig) -> FastAPI:
335
354
  raise TypeError("Rollout executor must return RolloutResponse or mapping")
336
355
 
337
356
  if cfg.expose_debug_env:
357
+
338
358
  @app.get("/debug/env", dependencies=[Depends(auth_dependency)])
339
359
  async def debug_env() -> Mapping[str, Any]:
340
360
  def _mask(value: str | None) -> str:
@@ -387,6 +407,12 @@ def run_task_app(
387
407
  print(f"[task:server] Loaded environment from: {', '.join(loaded_files)}", flush=True)
388
408
 
389
409
  config = config_factory()
410
+ # Defensive: ensure the factory produced a valid TaskAppConfig to avoid
411
+ # confusing attribute errors later in the boot sequence.
412
+ if not isinstance(config, TaskAppConfig): # type: ignore[arg-type]
413
+ raise TypeError(
414
+ f"Task app config_factory must return TaskAppConfig, got {type(config).__name__}"
415
+ )
390
416
  app = create_task_app(config)
391
417
 
392
418
  try:
synth_ai/task/tracing_utils.py CHANGED
@@ -45,7 +45,9 @@ def resolve_tracing_db_url() -> str | None:
45
45
  return f"sqlite+aiosqlite:///{fallback_path}"
46
46
 
47
47
 
48
- def build_tracer_factory(make_tracer: Callable[..., Any], *, enabled: bool, db_url: str | None) -> Callable[[], Any] | None:
48
+ def build_tracer_factory(
49
+ make_tracer: Callable[..., Any], *, enabled: bool, db_url: str | None
50
+ ) -> Callable[[], Any] | None:
49
51
  """Return a factory that instantiates a tracer when enabled, else None."""
50
52
 
51
53
  if not enabled:
@@ -74,6 +76,9 @@ def resolve_sft_output_dir() -> str | None:
74
76
  def unique_sft_path(base_dir: str, *, run_id: str) -> Path:
75
77
  """Return a unique JSONL path for an SFT record batch."""
76
78
 
77
- ts = int(time.time() * 1000)
78
- name = f"{run_id}_{ts}.jsonl"
79
+ from datetime import datetime
80
+
81
+ now = datetime.now()
82
+ timestamp = now.strftime("%Y-%m-%d_%H-%M-%S")
83
+ name = f"{run_id}_{timestamp}.jsonl"
79
84
  return Path(base_dir) / name
synth_ai/task/validators.py CHANGED
@@ -9,4 +9,3 @@ def validate_task_app_url(url: str, *, name: str = "TASK_APP_BASE_URL") -> None:
9
9
  p = urlparse(url)
10
10
  if p.scheme not in ("http", "https") or not p.netloc:
11
11
  raise ValueError(f"Invalid {name}: malformed: {url}")
12
-
synth_ai/task/vendors.py CHANGED
@@ -58,4 +58,3 @@ def get_groq_key_or_503() -> str:
58
58
  if not key:
59
59
  raise http_exception(503, "missing_groq_api_key", "GROQ_API_KEY is not configured")
60
60
  return key
61
-
synth_ai/tracing_v3/db_config.py CHANGED
@@ -4,6 +4,7 @@ Centralized database configuration for v3 tracing.
4
4
 
5
5
  import logging
6
6
  import os
7
+ import shutil
7
8
  from typing import TYPE_CHECKING, Optional
8
9
 
9
10
  if TYPE_CHECKING:
@@ -30,7 +31,7 @@ class DatabaseConfig:
30
31
  http_port: HTTP port for sqld daemon. If None, uses DEFAULT_HTTP_PORT from serve.sh.
31
32
  use_sqld: Whether to use sqld daemon or direct SQLite.
32
33
  """
33
- self.use_sqld = use_sqld
34
+ self.use_sqld = use_sqld and self._sqld_binary_available()
34
35
  self.http_port = http_port or int(os.getenv("SQLD_HTTP_PORT", self.DEFAULT_HTTP_PORT))
35
36
  self._daemon: SqldDaemon | None = None
36
37
 
@@ -70,6 +71,30 @@ class DatabaseConfig:
70
71
  # SQLite URLs need 3 slashes for absolute paths
71
72
  return f"sqlite+aiosqlite:///{actual_db_path}"
72
73
 
74
+ def _sqld_binary_available(self) -> bool:
75
+ """Check if the sqld (Turso) binary is available on PATH."""
76
+ # Respect explicit SQLD_BINARY override when present
77
+ binary_override = os.getenv("SQLD_BINARY")
78
+ candidates = [binary_override, "sqld", "libsql-server"]
79
+
80
+ for candidate in candidates:
81
+ if candidate and shutil.which(candidate):
82
+ return True
83
+
84
+ if binary_override:
85
+ logger.warning(
86
+ "Configured SQLD_BINARY='%s' but the executable was not found on PATH. "
87
+ "Falling back to direct SQLite.",
88
+ binary_override,
89
+ )
90
+ else:
91
+ logger.warning(
92
+ "sqld binary not detected; falling back to SQLite-only mode. "
93
+ "Install Turso's sqld or set SQLD_BINARY to enable the Turso daemon."
94
+ )
95
+
96
+ return False
97
+
73
98
  def start_daemon(self, wait_time: float = 2.0):
74
99
  """
75
100
  Start the sqld daemon if configured.
synth_ai/tracing_v3/decorators.py CHANGED
@@ -1,4 +1,5 @@
1
1
  from __future__ import annotations
2
+
2
3
  """Async-aware decorators for tracing v3.
3
4
 
4
5
  This module provides decorators and context management utilities for the tracing
synth_ai/tracing_v3/examples/basic_usage.py CHANGED
@@ -166,8 +166,9 @@ async def main():
166
166
 
167
167
  tracer.hooks.register("event_recorded", count_events, name="event_counter")
168
168
 
169
- async with tracer.session(metadata={"example": "hooks"}) as session_id, tracer.timestep(
170
- "hook_test"
169
+ async with (
170
+ tracer.session(metadata={"example": "hooks"}) as session_id,
171
+ tracer.timestep("hook_test"),
171
172
  ):
172
173
  for i in range(3):
173
174
  event = RuntimeEvent(
synth_ai/tracing_v3/hooks.py CHANGED
@@ -1,4 +1,5 @@
1
1
  from __future__ import annotations
2
+
2
3
  """Hook system for extending tracing functionality.
3
4
 
4
5
  The hook system provides a flexible way to extend the tracing system without
@@ -202,6 +203,7 @@ def create_default_hooks() -> HookManager:
202
203
  # Example: Log session starts - useful for debugging and monitoring
203
204
  async def log_session_start(session_id: str, metadata: dict[str, Any]):
204
205
  import os
206
+
205
207
  if os.getenv("SYNTH_TRACE_VERBOSE", "0") in ("1", "true", "True"):
206
208
  print(f"Session started: {session_id}")
207
209
 
synth_ai/tracing_v3/replica_sync.py CHANGED
@@ -180,6 +180,7 @@ class ReplicaSync:
180
180
  # Request cancellation
181
181
  self._sync_task.cancel()
182
182
  import contextlib
183
+
183
184
  with contextlib.suppress(asyncio.CancelledError):
184
185
  # Wait for the task to finish
185
186
  await self._sync_task
synth_ai/tracing_v3/session_tracer.py CHANGED
@@ -1,4 +1,5 @@
1
1
  from __future__ import annotations
2
+
2
3
  """Main SessionTracer class for tracing v3."""
3
4
 
4
5
  import asyncio
@@ -110,7 +111,9 @@ class SessionTracer:
110
111
 
111
112
  # Ensure session row exists for incremental writes
112
113
  if self.db:
113
- await self.db.ensure_session(session_id, created_at=self._current_trace.created_at, metadata=metadata or {})
114
+ await self.db.ensure_session(
115
+ session_id, created_at=self._current_trace.created_at, metadata=metadata or {}
116
+ )
114
117
 
115
118
  # Trigger hooks
116
119
  await self.hooks.trigger(
@@ -435,7 +438,14 @@ class SessionTracer:
435
438
  # Reward recording helpers
436
439
  # -------------------------------
437
440
 
438
- async def record_outcome_reward(self, *, total_reward: int, achievements_count: int, total_steps: int, reward_metadata: dict[str, Any] | None = None) -> int | None:
441
+ async def record_outcome_reward(
442
+ self,
443
+ *,
444
+ total_reward: int,
445
+ achievements_count: int,
446
+ total_steps: int,
447
+ reward_metadata: dict[str, Any] | None = None,
448
+ ) -> int | None:
439
449
  """Record an episode-level outcome reward for the current session."""
440
450
  if self._current_trace is None:
441
451
  raise RuntimeError("No active session")
@@ -462,7 +472,18 @@ class SessionTracer:
462
472
 
463
473
  # StepMetrics removed in favor of event_rewards; use record_event_reward for per-turn shaped values
464
474
 
465
- async def record_event_reward(self, *, event_id: int, message_id: int | None = None, turn_number: int | None = None, reward_value: float = 0.0, reward_type: str | None = None, key: str | None = None, annotation: dict[str, Any] | None = None, source: str | None = None) -> int | None:
475
+ async def record_event_reward(
476
+ self,
477
+ *,
478
+ event_id: int,
479
+ message_id: int | None = None,
480
+ turn_number: int | None = None,
481
+ reward_value: float = 0.0,
482
+ reward_type: str | None = None,
483
+ key: str | None = None,
484
+ annotation: dict[str, Any] | None = None,
485
+ source: str | None = None,
486
+ ) -> int | None:
466
487
  """Record a first-class event-level reward with optional annotations."""
467
488
  if self._current_trace is None:
468
489
  raise RuntimeError("No active session")
synth_ai/tracing_v3/storage/base.py CHANGED
@@ -54,7 +54,10 @@ class TraceStorage(ABC):
54
54
 
55
55
  @abstractmethod
56
56
  async def get_model_usage(
57
- self, start_date: datetime | None = None, end_date: datetime | None = None, model_name: str | None = None
57
+ self,
58
+ start_date: datetime | None = None,
59
+ end_date: datetime | None = None,
60
+ model_name: str | None = None,
58
61
  ) -> Any:
59
62
  """Get model usage statistics.
60
63
 
synth_ai/tracing_v3/storage/factory.py CHANGED
@@ -1,6 +1,5 @@
1
1
  """Factory for creating storage instances."""
2
2
 
3
-
4
3
  from ..turso.manager import AsyncSQLTraceManager
5
4
  from .base import TraceStorage
6
5
  from .config import StorageBackend, StorageConfig
synth_ai/tracing_v3/turso/manager.py CHANGED
@@ -1,4 +1,5 @@
1
1
  from __future__ import annotations
2
+
2
3
  """Async SQLAlchemy-based trace manager for Turso/sqld.
3
4
 
4
5
  This module provides the database interface for the tracing system using
@@ -139,6 +140,7 @@ class AsyncSQLTraceManager:
139
140
  )
140
141
  # Ensure PRAGMA foreign_keys=ON for every connection
141
142
  try:
143
+
142
144
  @event.listens_for(self.engine.sync_engine, "connect")
143
145
  def _set_sqlite_pragma(dbapi_connection, connection_record): # type: ignore[no-redef]
144
146
  try:
@@ -408,9 +410,7 @@ class AsyncSQLTraceManager:
408
410
  ],
409
411
  }
410
412
 
411
- async def query_traces(
412
- self, query: str, params: dict[str, Any] | None = None
413
- ) -> Any:
413
+ async def query_traces(self, query: str, params: dict[str, Any] | None = None) -> Any:
414
414
  """Execute a query and return results.
415
415
 
416
416
  Returns a pandas DataFrame when pandas is available; otherwise a
@@ -577,10 +577,18 @@ class AsyncSQLTraceManager:
577
577
  # Incremental insert helpers
578
578
  # -------------------------------
579
579
 
580
- async def ensure_session(self, session_id: str, *, created_at: datetime | None = None, metadata: dict[str, Any] | None = None):
580
+ async def ensure_session(
581
+ self,
582
+ session_id: str,
583
+ *,
584
+ created_at: datetime | None = None,
585
+ metadata: dict[str, Any] | None = None,
586
+ ):
581
587
  """Ensure a DB session row exists for session_id."""
582
588
  async with self.session() as sess:
583
- result = await sess.execute(select(DBSessionTrace).where(DBSessionTrace.session_id == session_id))
589
+ result = await sess.execute(
590
+ select(DBSessionTrace).where(DBSessionTrace.session_id == session_id)
591
+ )
584
592
  existing = result.scalar_one_or_none()
585
593
  if existing:
586
594
  return
@@ -595,11 +603,23 @@ class AsyncSQLTraceManager:
595
603
  sess.add(row)
596
604
  await sess.commit()
597
605
 
598
- async def ensure_timestep(self, session_id: str, *, step_id: str, step_index: int, turn_number: int | None = None, started_at: datetime | None = None, completed_at: datetime | None = None, metadata: dict[str, Any] | None = None) -> int:
606
+ async def ensure_timestep(
607
+ self,
608
+ session_id: str,
609
+ *,
610
+ step_id: str,
611
+ step_index: int,
612
+ turn_number: int | None = None,
613
+ started_at: datetime | None = None,
614
+ completed_at: datetime | None = None,
615
+ metadata: dict[str, Any] | None = None,
616
+ ) -> int:
599
617
  """Ensure a timestep row exists; return its DB id."""
600
618
  async with self.session() as sess:
601
619
  result = await sess.execute(
602
- select(DBSessionTimestep).where(DBSessionTimestep.session_id == session_id, DBSessionTimestep.step_id == step_id)
620
+ select(DBSessionTimestep).where(
621
+ DBSessionTimestep.session_id == session_id, DBSessionTimestep.step_id == step_id
622
+ )
603
623
  )
604
624
  row = result.scalar_one_or_none()
605
625
  if row:
@@ -626,7 +646,17 @@ class AsyncSQLTraceManager:
626
646
  await sess.commit()
627
647
  return row.id
628
648
 
629
- async def insert_message_row(self, session_id: str, *, timestep_db_id: int | None, message_type: str, content: str, event_time: float | None = None, message_time: int | None = None, metadata: dict[str, Any] | None = None) -> int:
649
+ async def insert_message_row(
650
+ self,
651
+ session_id: str,
652
+ *,
653
+ timestep_db_id: int | None,
654
+ message_type: str,
655
+ content: str,
656
+ event_time: float | None = None,
657
+ message_time: int | None = None,
658
+ metadata: dict[str, Any] | None = None,
659
+ ) -> int:
630
660
  """Insert a message and return its id."""
631
661
  async with self.session() as sess:
632
662
  db_msg = DBMessage(
@@ -649,8 +679,16 @@ class AsyncSQLTraceManager:
649
679
  await sess.commit()
650
680
  return db_msg.id
651
681
 
652
- async def insert_event_row(self, session_id: str, *, timestep_db_id: int | None, event: EnvironmentEvent | LMCAISEvent | RuntimeEvent, metadata_override: dict[str, Any] | None = None) -> int:
682
+ async def insert_event_row(
683
+ self,
684
+ session_id: str,
685
+ *,
686
+ timestep_db_id: int | None,
687
+ event: EnvironmentEvent | LMCAISEvent | RuntimeEvent,
688
+ metadata_override: dict[str, Any] | None = None,
689
+ ) -> int:
653
690
  """Insert an event and return its id."""
691
+
654
692
  def to_cents(cost: float | None) -> int | None:
655
693
  return int(cost * 100) if cost is not None else None
656
694
 
@@ -669,35 +707,41 @@ class AsyncSQLTraceManager:
669
707
  from dataclasses import asdict
670
708
 
671
709
  call_records_data = [asdict(record) for record in event.call_records]
672
- event_data.update({
673
- "event_type": "cais",
674
- "model_name": event.model_name,
675
- "provider": event.provider,
676
- "input_tokens": event.input_tokens,
677
- "output_tokens": event.output_tokens,
678
- "total_tokens": event.total_tokens,
679
- "cost_usd": to_cents(event.cost_usd),
680
- "latency_ms": event.latency_ms,
681
- "span_id": event.span_id,
682
- "trace_id": event.trace_id,
683
- "system_state_before": event.system_state_before,
684
- "system_state_after": event.system_state_after,
685
- "call_records": call_records_data,
686
- })
710
+ event_data.update(
711
+ {
712
+ "event_type": "cais",
713
+ "model_name": event.model_name,
714
+ "provider": event.provider,
715
+ "input_tokens": event.input_tokens,
716
+ "output_tokens": event.output_tokens,
717
+ "total_tokens": event.total_tokens,
718
+ "cost_usd": to_cents(event.cost_usd),
719
+ "latency_ms": event.latency_ms,
720
+ "span_id": event.span_id,
721
+ "trace_id": event.trace_id,
722
+ "system_state_before": event.system_state_before,
723
+ "system_state_after": event.system_state_after,
724
+ "call_records": call_records_data,
725
+ }
726
+ )
687
727
  elif isinstance(event, EnvironmentEvent):
688
- event_data.update({
689
- "event_type": "environment",
690
- "reward": event.reward,
691
- "terminated": event.terminated,
692
- "truncated": event.truncated,
693
- "system_state_before": event.system_state_before,
694
- "system_state_after": event.system_state_after,
695
- })
728
+ event_data.update(
729
+ {
730
+ "event_type": "environment",
731
+ "reward": event.reward,
732
+ "terminated": event.terminated,
733
+ "truncated": event.truncated,
734
+ "system_state_before": event.system_state_before,
735
+ "system_state_after": event.system_state_after,
736
+ }
737
+ )
696
738
  elif isinstance(event, RuntimeEvent):
697
- event_data.update({
698
- "event_type": "runtime",
699
- "event_metadata_json": {**(event.metadata or {}), "actions": event.actions},
700
- })
739
+ event_data.update(
740
+ {
741
+ "event_type": "runtime",
742
+ "event_metadata_json": {**(event.metadata or {}), "actions": event.actions},
743
+ }
744
+ )
701
745
  else:
702
746
  event_data["event_type"] = event.__class__.__name__.lower()
703
747
 
@@ -718,7 +762,15 @@ class AsyncSQLTraceManager:
718
762
  # Reward helpers
719
763
  # -------------------------------
720
764
 
721
- async def insert_outcome_reward(self, session_id: str, *, total_reward: int, achievements_count: int, total_steps: int, reward_metadata: dict | None = None) -> int:
765
+ async def insert_outcome_reward(
766
+ self,
767
+ session_id: str,
768
+ *,
769
+ total_reward: int,
770
+ achievements_count: int,
771
+ total_steps: int,
772
+ reward_metadata: dict | None = None,
773
+ ) -> int:
722
774
  async with self.session() as sess:
723
775
  row = DBOutcomeReward(
724
776
  session_id=session_id,
@@ -732,7 +784,19 @@ class AsyncSQLTraceManager:
732
784
  await sess.commit()
733
785
  return row.id
734
786
 
735
- async def insert_event_reward(self, session_id: str, *, event_id: int, message_id: int | None = None, turn_number: int | None = None, reward_value: float = 0.0, reward_type: str | None = None, key: str | None = None, annotation: dict[str, Any] | None = None, source: str | None = None) -> int:
787
+ async def insert_event_reward(
788
+ self,
789
+ session_id: str,
790
+ *,
791
+ event_id: int,
792
+ message_id: int | None = None,
793
+ turn_number: int | None = None,
794
+ reward_value: float = 0.0,
795
+ reward_type: str | None = None,
796
+ key: str | None = None,
797
+ annotation: dict[str, Any] | None = None,
798
+ source: str | None = None,
799
+ ) -> int:
736
800
  async with self.session() as sess:
737
801
  row = DBEventReward(
738
802
  event_id=event_id,
synth_ai/tracing_v3/turso/models.py CHANGED
@@ -1,4 +1,5 @@
1
1
  from __future__ import annotations
2
+
2
3
  """SQLAlchemy declarative models for tracing v3."""
3
4
 
4
5
  import json
@@ -452,7 +453,9 @@ class EventReward(Base):
452
453
  message_id = Column(Integer, ForeignKey("messages.id"), nullable=True)
453
454
  turn_number = Column(Integer, nullable=True)
454
455
  reward_value = Column(Float, nullable=False, default=0.0)
455
- reward_type = Column(String, nullable=True) # shaped | sparse | achievement | penalty | evaluator | human
456
+ reward_type = Column(
457
+ String, nullable=True
458
+ ) # shaped | sparse | achievement | penalty | evaluator | human
456
459
  key = Column(String, nullable=True) # e.g., achievement name
457
460
  annotation = Column(JSONText) # free-form JSON
458
461
  source = Column(String, nullable=True) # environment | runner | evaluator | human
synth_ai/tracing_v3/utils.py CHANGED
@@ -1,4 +1,5 @@
1
1
  from __future__ import annotations
2
+
2
3
  """Utility functions for tracing v3."""
3
4
 
4
5
  import hashlib