synth-ai 0.2.10__py3-none-any.whl → 0.2.13.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (73)
  1. examples/agora_ex/README_MoE.md +224 -0
  2. examples/agora_ex/__init__.py +7 -0
  3. examples/agora_ex/agora_ex.py +65 -0
  4. examples/agora_ex/agora_ex_task_app.py +590 -0
  5. examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +121 -0
  6. examples/agora_ex/reward_fn_grpo-human.py +129 -0
  7. examples/agora_ex/system_prompt_CURRENT.md +63 -0
  8. examples/agora_ex/task_app/agora_ex_task_app.py +590 -0
  9. examples/agora_ex/task_app/reward_fn_grpo-human.py +129 -0
  10. examples/agora_ex/task_app/system_prompt_CURRENT.md +63 -0
  11. examples/multi_step/configs/crafter_rl_outcome.toml +74 -0
  12. examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +175 -0
  13. examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +83 -0
  14. examples/multi_step/configs/crafter_rl_stepwise_simple.toml +78 -0
  15. examples/multi_step/crafter_rl_lora.md +51 -10
  16. examples/multi_step/sse_metrics_streaming_notes.md +357 -0
  17. examples/multi_step/task_app_config_notes.md +494 -0
  18. examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +35 -0
  19. examples/warming_up_to_rl/configs/eval_stepwise_consistent.toml +26 -0
  20. examples/warming_up_to_rl/configs/eval_stepwise_per_achievement.toml +36 -0
  21. examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +32 -0
  22. examples/warming_up_to_rl/run_eval.py +267 -41
  23. examples/warming_up_to_rl/task_app/grpo_crafter.py +3 -33
  24. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +109 -45
  25. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +42 -46
  26. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +376 -193
  27. synth_ai/__init__.py +41 -1
  28. synth_ai/api/train/builders.py +74 -33
  29. synth_ai/api/train/cli.py +29 -6
  30. synth_ai/api/train/configs/__init__.py +44 -0
  31. synth_ai/api/train/configs/rl.py +133 -0
  32. synth_ai/api/train/configs/sft.py +94 -0
  33. synth_ai/api/train/configs/shared.py +24 -0
  34. synth_ai/api/train/env_resolver.py +18 -19
  35. synth_ai/api/train/supported_algos.py +8 -5
  36. synth_ai/api/train/utils.py +6 -1
  37. synth_ai/cli/__init__.py +4 -2
  38. synth_ai/cli/_storage.py +19 -0
  39. synth_ai/cli/balance.py +14 -2
  40. synth_ai/cli/calc.py +37 -22
  41. synth_ai/cli/demo.py +38 -39
  42. synth_ai/cli/legacy_root_backup.py +12 -14
  43. synth_ai/cli/recent.py +12 -7
  44. synth_ai/cli/rl_demo.py +81 -102
  45. synth_ai/cli/status.py +4 -3
  46. synth_ai/cli/task_apps.py +146 -137
  47. synth_ai/cli/traces.py +4 -3
  48. synth_ai/cli/watch.py +3 -2
  49. synth_ai/demos/core/cli.py +121 -159
  50. synth_ai/environments/examples/crafter_classic/environment.py +16 -0
  51. synth_ai/evals/__init__.py +15 -0
  52. synth_ai/evals/client.py +85 -0
  53. synth_ai/evals/types.py +42 -0
  54. synth_ai/jobs/client.py +15 -3
  55. synth_ai/judge_schemas.py +127 -0
  56. synth_ai/rubrics/__init__.py +22 -0
  57. synth_ai/rubrics/validators.py +126 -0
  58. synth_ai/task/server.py +14 -7
  59. synth_ai/tracing_v3/decorators.py +51 -26
  60. synth_ai/tracing_v3/examples/basic_usage.py +12 -7
  61. synth_ai/tracing_v3/llm_call_record_helpers.py +107 -53
  62. synth_ai/tracing_v3/replica_sync.py +8 -4
  63. synth_ai/tracing_v3/serialization.py +130 -0
  64. synth_ai/tracing_v3/storage/utils.py +11 -9
  65. synth_ai/tracing_v3/turso/__init__.py +12 -0
  66. synth_ai/tracing_v3/turso/daemon.py +2 -1
  67. synth_ai/tracing_v3/turso/native_manager.py +28 -15
  68. {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/METADATA +4 -2
  69. {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/RECORD +73 -40
  70. {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/entry_points.txt +0 -1
  71. {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/WHEEL +0 -0
  72. {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/licenses/LICENSE +0 -0
  73. {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/top_level.txt +0 -0
synth_ai/judge_schemas.py ADDED
@@ -0,0 +1,127 @@
1
+ """
2
+ Judge API Contract Schemas
3
+
4
+ These schemas define the expected structure for requests and responses
5
+ to the judge scoring endpoint at POST /api/judge/v1/score.
6
+
7
+ This is the canonical contract that the backend MUST conform to.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from typing import Any, Dict, List, Literal, Optional
13
+
14
+ from pydantic import BaseModel, Field
15
+
16
+
17
+ class CriterionScorePayload(BaseModel):
18
+ """Per-criterion score returned by the judge."""
19
+
20
+ score: float = Field(..., description="Numeric score for this criterion")
21
+ reason: str = Field(default="", description="Explanation for the score")
22
+ weight: float = Field(default=1.0, description="Weight of this criterion")
23
+ description: str = Field(default="", description="Description of the criterion")
24
+
25
+
26
+ class ReviewPayload(BaseModel):
27
+ """Rubric review (event-level or outcome-level)."""
28
+
29
+ criteria: Dict[str, CriterionScorePayload] = Field(
30
+ default_factory=dict,
31
+ description="Map of criterion keys to their scores"
32
+ )
33
+ total: float = Field(default=0.0, description="Aggregated total score")
34
+ summary: Optional[str] = Field(None, description="Optional text summary")
35
+
36
+
37
+ class JudgeScoreResponse(BaseModel):
38
+ """
39
+ Response body for POST /api/judge/v1/score.
40
+
41
+ This is the canonical contract that judge backends MUST return.
42
+ """
43
+
44
+ status: Literal["ok", "failed"] = Field(default="ok", description="Request status")
45
+ event_reviews: List[ReviewPayload] = Field(
46
+ default_factory=list,
47
+ description="List of per-event rubric reviews (one per step)"
48
+ )
49
+ outcome_review: Optional[ReviewPayload] = Field(
50
+ None,
51
+ description="Optional outcome-level rubric review"
52
+ )
53
+ event_totals: List[float] = Field(
54
+ default_factory=list,
55
+ description="List of aggregated scores per event (matches event_reviews length)"
56
+ )
57
+ details: Dict[str, Any] = Field(
58
+ default_factory=dict,
59
+ description="Additional details (provider, latency, etc.)"
60
+ )
61
+ metadata: Dict[str, Any] = Field(
62
+ default_factory=dict,
63
+ description="Request metadata (provider, options, etc.)"
64
+ )
65
+
66
+ def aggregate_event_reward(self) -> float | None:
67
+ """
68
+ Aggregate all event totals into a single reward.
69
+
70
+ Returns:
71
+ Sum of all event_totals, or None if empty
72
+ """
73
+ if not self.event_totals:
74
+ return None
75
+ return sum(self.event_totals)
76
+
77
+ def aggregate_outcome_reward(self) -> float | None:
78
+ """
79
+ Extract outcome reward from outcome_review.
80
+
81
+ Returns:
82
+ outcome_review.total, or None if no outcome review
83
+ """
84
+ if self.outcome_review is None:
85
+ return None
86
+ return self.outcome_review.total
87
+
88
+
89
+ # Request schemas for completeness
90
+
91
+ class JudgeTaskApp(BaseModel):
92
+ """Task application metadata."""
93
+
94
+ id: str = Field(..., description="Task app identifier")
95
+ base_url: Optional[str] = Field(None, description="Optional base URL for task app")
96
+
97
+
98
+ class JudgeOptions(BaseModel):
99
+ """Judge provider and configuration options."""
100
+
101
+ provider: Optional[str] = Field(None, description="Judge provider (e.g., 'openai', 'groq')")
102
+ model: Optional[str] = Field(None, description="Model identifier")
103
+ rubric_id: Optional[str] = Field(None, description="Rubric identifier")
104
+ event: bool = Field(True, description="Enable event-level judging")
105
+ outcome: bool = Field(True, description="Enable outcome-level judging")
106
+
107
+
108
+ class JudgeTracePayload(BaseModel):
109
+ """Trace payload containing trajectory context."""
110
+
111
+ event_history: List[Dict[str, Any]] = Field(..., description="List of events/steps")
112
+ markov_blanket_message_history: List[Dict[str, Any]] = Field(
113
+ default_factory=list,
114
+ description="Optional message history for context"
115
+ )
116
+ metadata: Dict[str, Any] = Field(default_factory=dict, description="Trace metadata")
117
+
118
+
119
+ class JudgeScoreRequest(BaseModel):
120
+ """Request body for POST /api/judge/v1/score."""
121
+
122
+ policy_name: str = Field(..., description="Name of the policy being evaluated")
123
+ task_app: JudgeTaskApp = Field(..., description="Task application metadata")
124
+ trace: JudgeTracePayload = Field(..., description="Trajectory trace to evaluate")
125
+ options: JudgeOptions = Field(default_factory=lambda: JudgeOptions(), description="Judge options")
126
+ rubric: Optional[Dict[str, Any]] = Field(None, description="Optional explicit rubric criteria")
127
+
synth_ai/rubrics/__init__.py ADDED
@@ -0,0 +1,22 @@
1
+ """
2
+ Rubric utilities.
3
+
4
+ Exposes helpers for validating rubric specifications that are used across
5
+ Crafter-style judge configurations.
6
+ """
7
+
8
+ from .validators import (
9
+ RubricCriterion,
10
+ RubricSpec,
11
+ ValidationError,
12
+ validate_rubric_dict,
13
+ validate_rubric_file,
14
+ )
15
+
16
+ __all__ = [
17
+ "RubricCriterion",
18
+ "RubricSpec",
19
+ "ValidationError",
20
+ "validate_rubric_dict",
21
+ "validate_rubric_file",
22
+ ]
synth_ai/rubrics/validators.py ADDED
@@ -0,0 +1,126 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import math
5
+ from pathlib import Path
6
+ from typing import Any, Iterable, Literal
7
+
8
+ import pydantic
9
+
10
+
11
+ class RubricCriterion(pydantic.BaseModel):
12
+ """Single scoring criterion within a rubric."""
13
+
14
+ id: str
15
+ description: str
16
+ weight: float
17
+ scale: str | None = None
18
+
19
+ @pydantic.field_validator("weight")
20
+ @classmethod
21
+ def _validate_weight(cls, value: float) -> float:
22
+ if not math.isfinite(value):
23
+ raise ValueError("weight must be a finite number")
24
+ if value <= 0.0:
25
+ raise ValueError("weight must be positive")
26
+ if value > 1.0:
27
+ raise ValueError("weight must be <= 1.0")
28
+ return value
29
+
30
+ @pydantic.field_validator("id", "description", mode="before")
31
+ @classmethod
32
+ def _strip_string(cls, value: Any) -> Any:
33
+ if isinstance(value, str):
34
+ return value.strip()
35
+ return value
36
+
37
+
38
+ class RubricSpec(pydantic.BaseModel):
39
+ """High-level rubric definition used by step-wise judges."""
40
+
41
+ version: str
42
+ goal_text: str
43
+ aggregation: Literal["weighted_sum"]
44
+ criteria: list[RubricCriterion]
45
+
46
+ @pydantic.model_validator(mode="after")
47
+ def _validate_weights(self) -> "RubricSpec":
48
+ if not self.criteria:
49
+ raise ValueError("rubric must declare at least one criterion")
50
+ total_weight = sum(criterion.weight for criterion in self.criteria)
51
+ if not math.isclose(total_weight, 1.0, abs_tol=1e-6, rel_tol=1e-6):
52
+ raise ValueError(
53
+ f"criterion weights must sum to 1 (got {total_weight:.6f})"
54
+ )
55
+ return self
56
+
57
+ @pydantic.field_validator("version")
58
+ @classmethod
59
+ def _non_empty_version(cls, value: str) -> str:
60
+ value = value.strip()
61
+ if not value:
62
+ raise ValueError("version string must not be empty")
63
+ return value
64
+
65
+ @pydantic.field_validator("goal_text")
66
+ @classmethod
67
+ def _non_empty_goal_text(cls, value: str) -> str:
68
+ value = value.strip()
69
+ if not value:
70
+ raise ValueError("goal_text must not be empty")
71
+ return value
72
+
73
+
74
+ ValidationError = pydantic.ValidationError
75
+
76
+
77
+ def validate_rubric_dict(payload: dict[str, Any]) -> RubricSpec:
78
+ """
79
+ Validate an in-memory rubric payload and return the parsed model.
80
+
81
+ Args:
82
+ payload: Dictionary representing the rubric JSON.
83
+ Returns:
84
+ Validated RubricSpec instance.
85
+ Raises:
86
+ ValidationError: If the payload is missing required fields or contains
87
+ invalid weights.
88
+ """
89
+
90
+ if not isinstance(payload, dict):
91
+ raise TypeError("rubric payload must be a dictionary")
92
+ return RubricSpec.model_validate(payload)
93
+
94
+
95
+ def _load_payload_from_file(path: Path) -> dict[str, Any]:
96
+ if path.suffix.lower() != ".json":
97
+ raise ValueError(f"Unsupported rubric file type: {path}")
98
+ text = path.read_text(encoding="utf-8")
99
+ return json.loads(text)
100
+
101
+
102
+ def validate_rubric_file(path: Path) -> RubricSpec:
103
+ """
104
+ Load and validate a rubric file.
105
+
106
+ Args:
107
+ path: Path to a JSON rubric document.
108
+ Returns:
109
+ Validated RubricSpec instance.
110
+ """
111
+
112
+ payload = _load_payload_from_file(path)
113
+ return validate_rubric_dict(payload)
114
+
115
+
116
+ def validate_rubric_files(paths: Iterable[Path]) -> list[RubricSpec]:
117
+ """
118
+ Validate multiple rubric files and return their parsed models.
119
+
120
+ Useful for bulk validation inside tests or CI checks.
121
+ """
122
+
123
+ validated: list[RubricSpec] = []
124
+ for path in paths:
125
+ validated.append(validate_rubric_file(path))
126
+ return validated
synth_ai/task/server.py CHANGED
@@ -6,6 +6,7 @@ import asyncio
6
6
  import inspect
7
7
  import os
8
8
  from collections.abc import Awaitable, Callable, Iterable, Mapping, MutableMapping, Sequence
9
+ from contextlib import asynccontextmanager
9
10
  from dataclasses import dataclass, field
10
11
  from pathlib import Path
11
12
  from typing import Any
@@ -34,6 +35,10 @@ InstanceProvider = Callable[[Sequence[int]], Iterable[TaskInfo] | Awaitable[Iter
34
35
  RolloutExecutor = Callable[[RolloutRequest, Request], Any | Awaitable[Any]]
35
36
 
36
37
 
38
+ def _default_app_state() -> dict[str, Any]:
39
+ return {}
40
+
41
+
37
42
  @dataclass(slots=True)
38
43
  class RubricBundle:
39
44
  """Optional rubrics advertised by the task app."""
@@ -69,7 +74,7 @@ class TaskAppConfig:
69
74
  proxy: ProxyConfig | None = None
70
75
  routers: Sequence[APIRouter] = field(default_factory=tuple)
71
76
  middleware: Sequence[Middleware] = field(default_factory=tuple)
72
- app_state: Mapping[str, Any] = field(default_factory=dict)
77
+ app_state: MutableMapping[str, Any] = field(default_factory=_default_app_state)
73
78
  require_api_key: bool = True
74
79
  expose_debug_env: bool = True
75
80
  cors_origins: Sequence[str] | None = None
@@ -260,17 +265,19 @@ def create_task_app(config: TaskAppConfig) -> FastAPI:
260
265
  return _maybe_await(hook(app)) # type: ignore[misc]
261
266
  return _maybe_await(hook())
262
267
 
263
- @app.on_event("startup")
264
- async def _startup() -> None: # pragma: no cover - FastAPI lifecycle
268
+ @asynccontextmanager
269
+ async def lifespan(_: FastAPI):
265
270
  normalize_environment_api_key()
266
271
  normalize_vendor_keys()
267
272
  for hook in cfg.startup_hooks:
268
273
  await _call_hook(hook)
274
+ try:
275
+ yield
276
+ finally:
277
+ for hook in cfg.shutdown_hooks:
278
+ await _call_hook(hook)
269
279
 
270
- @app.on_event("shutdown")
271
- async def _shutdown() -> None: # pragma: no cover - FastAPI lifecycle
272
- for hook in cfg.shutdown_hooks:
273
- await _call_hook(hook)
280
+ app.router.lifespan_context = lifespan
274
281
 
275
282
  @app.get("/")
276
283
  async def root() -> Mapping[str, Any]:
synth_ai/tracing_v3/decorators.py CHANGED
@@ -28,8 +28,8 @@ import asyncio
28
28
  import contextvars
29
29
  import functools
30
30
  import time
31
- from collections.abc import Callable
32
- from typing import Any, TypeVar
31
+ from collections.abc import Awaitable, Callable, Mapping
32
+ from typing import Any, TypeVar, cast, overload
33
33
 
34
34
  from .abstractions import LMCAISEvent, TimeRecord
35
35
  from .utils import calculate_cost, detect_provider
@@ -88,6 +88,16 @@ def get_session_tracer() -> Any:
88
88
  T = TypeVar("T")
89
89
 
90
90
 
91
+ @overload
92
+ def with_session(require: bool = True) -> Callable[[Callable[..., Awaitable[T]]], Callable[..., Awaitable[T]]]:
93
+ ...
94
+
95
+
96
+ @overload
97
+ def with_session(require: bool = True) -> Callable[[Callable[..., T]], Callable[..., T]]:
98
+ ...
99
+
100
+
91
101
  def with_session(require: bool = True):
92
102
  """Decorator that ensures a session is active.
93
103
 
@@ -109,29 +119,31 @@ def with_session(require: bool = True):
109
119
  ```
110
120
  """
111
121
 
112
- def decorator(fn: Callable[..., T]) -> Callable[..., T]:
122
+ def decorator(fn: Callable[..., Awaitable[T]] | Callable[..., T]) -> Callable[..., Awaitable[T]] | Callable[..., T]:
113
123
  if asyncio.iscoroutinefunction(fn):
114
124
 
115
125
  @functools.wraps(fn)
116
- async def async_wrapper(*args, **kwargs):
126
+ async def async_wrapper(*args: Any, **kwargs: Any) -> T:
117
127
  session_id = get_session_id()
118
128
  if require and session_id is None:
119
129
  raise RuntimeError(
120
130
  f"No active session for {getattr(fn, '__name__', 'unknown')}"
121
131
  )
122
- return await fn(*args, **kwargs)
132
+ async_fn = cast(Callable[..., Awaitable[T]], fn)
133
+ return await async_fn(*args, **kwargs)
123
134
 
124
135
  return async_wrapper
125
136
  else:
126
137
 
127
138
  @functools.wraps(fn)
128
- def sync_wrapper(*args, **kwargs):
139
+ def sync_wrapper(*args: Any, **kwargs: Any) -> T:
129
140
  session_id = get_session_id()
130
141
  if require and session_id is None:
131
142
  raise RuntimeError(
132
143
  f"No active session for {getattr(fn, '__name__', 'unknown')}"
133
144
  )
134
- return fn(*args, **kwargs)
145
+ sync_fn = cast(Callable[..., T], fn)
146
+ return sync_fn(*args, **kwargs)
135
147
 
136
148
  return sync_wrapper
137
149
 
@@ -172,31 +184,36 @@ def trace_llm_call(
172
184
  ```
173
185
  """
174
186
 
175
- def decorator(fn: Callable[..., T]) -> Callable[..., T]:
187
+ def decorator(fn: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
176
188
  if asyncio.iscoroutinefunction(fn):
189
+ async_fn: Callable[..., Awaitable[T]] = fn
177
190
 
178
191
  @functools.wraps(fn)
179
- async def async_wrapper(*args, **kwargs):
192
+ async def async_wrapper(*args: Any, **kwargs: Any) -> T:
180
193
  tracer = get_session_tracer()
181
194
  if not tracer:
182
- return await fn(*args, **kwargs)
195
+ return await async_fn(*args, **kwargs)
183
196
 
184
197
  start_time = time.time()
185
198
  system_state_before = kwargs.get("state_before", {})
186
199
 
187
200
  try:
188
- result = await fn(*args, **kwargs)
201
+ result = await async_fn(*args, **kwargs)
189
202
 
190
203
  # Extract metrics from result - this assumes the result follows
191
204
  # common LLM API response formats (OpenAI, Anthropic, etc.)
192
- if extract_tokens and isinstance(result, dict):
193
- input_tokens = result.get("usage", {}).get("prompt_tokens")
194
- output_tokens = result.get("usage", {}).get("completion_tokens")
195
- total_tokens = result.get("usage", {}).get("total_tokens")
196
- actual_model = result.get("model", model_name)
197
- else:
198
- input_tokens = output_tokens = total_tokens = None
199
- actual_model = model_name
205
+ input_tokens = output_tokens = total_tokens = None
206
+ actual_model = model_name
207
+ if extract_tokens and isinstance(result, Mapping):
208
+ result_mapping = cast(Mapping[str, Any], result)
209
+ usage = result_mapping.get("usage")
210
+ if isinstance(usage, Mapping):
211
+ input_tokens = usage.get("prompt_tokens")
212
+ output_tokens = usage.get("completion_tokens")
213
+ total_tokens = usage.get("total_tokens")
214
+ value = result_mapping.get("model")
215
+ if isinstance(value, str):
216
+ actual_model = value
200
217
 
201
218
  latency_ms = int((time.time() - start_time) * 1000)
202
219
 
@@ -272,19 +289,26 @@ def trace_method(event_type: str = "runtime", system_id: str | None = None):
272
289
  ```
273
290
  """
274
291
 
275
- def decorator(fn: Callable[..., T]) -> Callable[..., T]:
292
+ def decorator(
293
+ fn: Callable[..., Awaitable[T]] | Callable[..., T]
294
+ ) -> Callable[..., Awaitable[T]] | Callable[..., T]:
276
295
  if asyncio.iscoroutinefunction(fn):
296
+ async_fn = cast(Callable[..., Awaitable[T]], fn)
277
297
 
278
298
  @functools.wraps(fn)
279
- async def async_wrapper(self, *args, **kwargs):
299
+ async def async_wrapper(*args: Any, **kwargs: Any) -> T:
280
300
  tracer = get_session_tracer()
281
301
  if not tracer:
282
- return await fn(self, *args, **kwargs)
302
+ return await async_fn(*args, **kwargs)
283
303
 
284
304
  from .abstractions import RuntimeEvent
285
305
 
286
306
  # Use class name as system_id if not provided
287
- actual_system_id = system_id or self.__class__.__name__
307
+ self_obj = args[0] if args else None
308
+ inferred_system_id = (
309
+ self_obj.__class__.__name__ if self_obj is not None else "unknown"
310
+ )
311
+ actual_system_id = system_id or inferred_system_id
288
312
 
289
313
  event = RuntimeEvent(
290
314
  system_instance_id=actual_system_id,
@@ -298,17 +322,18 @@ def trace_method(event_type: str = "runtime", system_id: str | None = None):
298
322
  )
299
323
 
300
324
  await tracer.record_event(event)
301
- return await fn(self, *args, **kwargs)
325
+ return await async_fn(*args, **kwargs)
302
326
 
303
327
  return async_wrapper
304
328
  else:
305
329
 
306
330
  @functools.wraps(fn)
307
- def sync_wrapper(self, *args, **kwargs):
331
+ def sync_wrapper(*args: Any, **kwargs: Any) -> T:
308
332
  # For sync methods, we can't easily trace without blocking
309
333
  # the event loop. This is a limitation of the async-first design.
310
334
  # Consider converting to async or using a different approach
311
- return fn(self, *args, **kwargs)
335
+ sync_fn = cast(Callable[..., T], fn)
336
+ return sync_fn(*args, **kwargs)
312
337
 
313
338
  return sync_wrapper
314
339
 
synth_ai/tracing_v3/examples/basic_usage.py CHANGED
@@ -2,13 +2,14 @@
2
2
 
3
3
  import asyncio
4
4
  import time
5
+ from typing import Any
5
6
 
6
- from synth_ai.tracing_v3 import SessionTracer
7
- from synth_ai.tracing_v3.abstractions import EnvironmentEvent, LMCAISEvent, RuntimeEvent, TimeRecord
8
- from synth_ai.tracing_v3.turso.daemon import SqldDaemon
7
+ from .. import SessionTracer
8
+ from ..abstractions import EnvironmentEvent, LMCAISEvent, RuntimeEvent, TimeRecord
9
+ from ..turso.daemon import SqldDaemon
9
10
 
10
11
 
11
- async def simulate_llm_call(model: str, prompt: str) -> dict:
12
+ async def simulate_llm_call(model: str, prompt: str) -> dict[str, Any]:
12
13
  """Simulate an LLM API call."""
13
14
  await asyncio.sleep(0.1) # Simulate network latency
14
15
 
@@ -133,6 +134,9 @@ async def main():
133
134
  print("\n--- Example 3: Querying Data ---")
134
135
 
135
136
  # Get model usage statistics
137
+ if tracer.db is None:
138
+ raise RuntimeError("Tracer database backend is not initialized")
139
+
136
140
  model_usage = await tracer.db.get_model_usage()
137
141
  print("\nModel Usage:")
138
142
  print(model_usage)
@@ -150,9 +154,10 @@ async def main():
150
154
  # Get specific session details
151
155
  if recent_sessions:
152
156
  session_detail = await tracer.db.get_session_trace(recent_sessions[0]["session_id"])
153
- print(f"\nSession Detail for {session_detail['session_id']}:")
154
- print(f" Created: {session_detail['created_at']}")
155
- print(f" Timesteps: {len(session_detail['timesteps'])}")
157
+ if session_detail:
158
+ print(f"\nSession Detail for {session_detail['session_id']}:")
159
+ print(f" Created: {session_detail['created_at']}")
160
+ print(f" Timesteps: {len(session_detail['timesteps'])}")
156
161
 
157
162
  # Example 4: Using hooks
158
163
  print("\n--- Example 4: Hooks ---")