synth-ai 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of synth-ai might be problematic.

Files changed (38)
  1. examples/multi_step/task_app_config_notes.md +488 -0
  2. examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +33 -0
  3. examples/warming_up_to_rl/configs/eval_stepwise_consistent.toml +26 -0
  4. examples/warming_up_to_rl/configs/eval_stepwise_per_achievement.toml +36 -0
  5. examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +30 -0
  6. examples/warming_up_to_rl/run_eval.py +142 -25
  7. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +146 -2
  8. synth_ai/api/train/builders.py +25 -14
  9. synth_ai/api/train/cli.py +29 -6
  10. synth_ai/api/train/env_resolver.py +18 -19
  11. synth_ai/api/train/supported_algos.py +8 -5
  12. synth_ai/api/train/utils.py +6 -1
  13. synth_ai/cli/__init__.py +4 -2
  14. synth_ai/cli/_storage.py +19 -0
  15. synth_ai/cli/balance.py +14 -2
  16. synth_ai/cli/calc.py +37 -22
  17. synth_ai/cli/legacy_root_backup.py +12 -14
  18. synth_ai/cli/recent.py +12 -7
  19. synth_ai/cli/status.py +4 -3
  20. synth_ai/cli/task_apps.py +143 -137
  21. synth_ai/cli/traces.py +4 -3
  22. synth_ai/cli/watch.py +3 -2
  23. synth_ai/jobs/client.py +15 -3
  24. synth_ai/task/server.py +14 -7
  25. synth_ai/tracing_v3/decorators.py +51 -26
  26. synth_ai/tracing_v3/examples/basic_usage.py +12 -7
  27. synth_ai/tracing_v3/llm_call_record_helpers.py +107 -53
  28. synth_ai/tracing_v3/replica_sync.py +8 -4
  29. synth_ai/tracing_v3/storage/utils.py +11 -9
  30. synth_ai/tracing_v3/turso/__init__.py +12 -0
  31. synth_ai/tracing_v3/turso/daemon.py +2 -1
  32. synth_ai/tracing_v3/turso/native_manager.py +28 -15
  33. {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/METADATA +4 -2
  34. {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/RECORD +38 -31
  35. {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/WHEEL +0 -0
  36. {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/entry_points.txt +0 -0
  37. {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/licenses/LICENSE +0 -0
  38. {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/top_level.txt +0 -0
synth_ai/task/server.py CHANGED
@@ -6,6 +6,7 @@ import asyncio
 import inspect
 import os
 from collections.abc import Awaitable, Callable, Iterable, Mapping, MutableMapping, Sequence
+from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any
@@ -34,6 +35,10 @@ InstanceProvider = Callable[[Sequence[int]], Iterable[TaskInfo] | Awaitable[Iter
 RolloutExecutor = Callable[[RolloutRequest, Request], Any | Awaitable[Any]]
 
 
+def _default_app_state() -> dict[str, Any]:
+    return {}
+
+
 @dataclass(slots=True)
 class RubricBundle:
     """Optional rubrics advertised by the task app."""
@@ -69,7 +74,7 @@ class TaskAppConfig:
     proxy: ProxyConfig | None = None
     routers: Sequence[APIRouter] = field(default_factory=tuple)
     middleware: Sequence[Middleware] = field(default_factory=tuple)
-    app_state: Mapping[str, Any] = field(default_factory=dict)
+    app_state: MutableMapping[str, Any] = field(default_factory=_default_app_state)
     require_api_key: bool = True
     expose_debug_env: bool = True
     cors_origins: Sequence[str] | None = None
@@ -260,17 +265,19 @@ def create_task_app(config: TaskAppConfig) -> FastAPI:
             return _maybe_await(hook(app))  # type: ignore[misc]
         return _maybe_await(hook())
 
-    @app.on_event("startup")
-    async def _startup() -> None:  # pragma: no cover - FastAPI lifecycle
+    @asynccontextmanager
+    async def lifespan(_: FastAPI):
         normalize_environment_api_key()
         normalize_vendor_keys()
         for hook in cfg.startup_hooks:
            await _call_hook(hook)
+        try:
+            yield
+        finally:
+            for hook in cfg.shutdown_hooks:
+                await _call_hook(hook)
 
-    @app.on_event("shutdown")
-    async def _shutdown() -> None:  # pragma: no cover - FastAPI lifecycle
-        for hook in cfg.shutdown_hooks:
-            await _call_hook(hook)
+    app.router.lifespan_context = lifespan
 
     @app.get("/")
     async def root() -> Mapping[str, Any]:
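This is the headline change in server.py: the deprecated `@app.on_event("startup")` / `@app.on_event("shutdown")` hooks are folded into a single lifespan context manager. A minimal standalone sketch of the same pattern (the `resources` dict is illustrative, not synth-ai's actual state):

```python
from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup work runs once, before the app starts serving requests.
    app.state.resources = {"ready": True}
    try:
        yield  # the app serves requests while suspended here
    finally:
        # Shutdown work runs when the server exits.
        app.state.resources.clear()


app = FastAPI(lifespan=lifespan)
```

synth-ai attaches the manager after construction via `app.router.lifespan_context = lifespan`, which is equivalent to passing `lifespan=` to the constructor when the `FastAPI` instance already exists.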
synth_ai/tracing_v3/decorators.py CHANGED
@@ -28,8 +28,8 @@ import asyncio
 import contextvars
 import functools
 import time
-from collections.abc import Callable
-from typing import Any, TypeVar
+from collections.abc import Awaitable, Callable, Mapping
+from typing import Any, TypeVar, cast, overload
 
 from .abstractions import LMCAISEvent, TimeRecord
 from .utils import calculate_cost, detect_provider
@@ -88,6 +88,16 @@ def get_session_tracer() -> Any:
 T = TypeVar("T")
 
 
+@overload
+def with_session(require: bool = True) -> Callable[[Callable[..., Awaitable[T]]], Callable[..., Awaitable[T]]]:
+    ...
+
+
+@overload
+def with_session(require: bool = True) -> Callable[[Callable[..., T]], Callable[..., T]]:
+    ...
+
+
 def with_session(require: bool = True):
     """Decorator that ensures a session is active.
 
@@ -109,29 +119,31 @@ def with_session(require: bool = True):
     ```
     """
 
-    def decorator(fn: Callable[..., T]) -> Callable[..., T]:
+    def decorator(fn: Callable[..., Awaitable[T]] | Callable[..., T]) -> Callable[..., Awaitable[T]] | Callable[..., T]:
         if asyncio.iscoroutinefunction(fn):
 
             @functools.wraps(fn)
-            async def async_wrapper(*args, **kwargs):
+            async def async_wrapper(*args: Any, **kwargs: Any) -> T:
                 session_id = get_session_id()
                 if require and session_id is None:
                     raise RuntimeError(
                         f"No active session for {getattr(fn, '__name__', 'unknown')}"
                     )
-                return await fn(*args, **kwargs)
+                async_fn = cast(Callable[..., Awaitable[T]], fn)
+                return await async_fn(*args, **kwargs)
 
             return async_wrapper
         else:
 
             @functools.wraps(fn)
-            def sync_wrapper(*args, **kwargs):
+            def sync_wrapper(*args: Any, **kwargs: Any) -> T:
                 session_id = get_session_id()
                 if require and session_id is None:
                     raise RuntimeError(
                         f"No active session for {getattr(fn, '__name__', 'unknown')}"
                     )
-                return fn(*args, **kwargs)
+                sync_fn = cast(Callable[..., T], fn)
+                return sync_fn(*args, **kwargs)
 
             return sync_wrapper
 
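The retyped `with_session` keeps its runtime contract: the wrapped callable fails fast when no session is active. A short usage sketch, assuming the import path below and that no tracing session has been opened:

```python
import asyncio

from synth_ai.tracing_v3.decorators import with_session


@with_session(require=True)
async def score_step(observation: dict) -> float:
    # Only runs inside an active tracing session; otherwise the
    # decorator raises RuntimeError before this body executes.
    return float(len(observation))


try:
    asyncio.run(score_step({"x": 1}))
except RuntimeError as err:
    print(err)  # No active session for score_step
```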
@@ -172,31 +184,36 @@ def trace_llm_call(
     ```
     """
 
-    def decorator(fn: Callable[..., T]) -> Callable[..., T]:
+    def decorator(fn: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
         if asyncio.iscoroutinefunction(fn):
+            async_fn: Callable[..., Awaitable[T]] = fn
 
             @functools.wraps(fn)
-            async def async_wrapper(*args, **kwargs):
+            async def async_wrapper(*args: Any, **kwargs: Any) -> T:
                 tracer = get_session_tracer()
                 if not tracer:
-                    return await fn(*args, **kwargs)
+                    return await async_fn(*args, **kwargs)
 
                 start_time = time.time()
                 system_state_before = kwargs.get("state_before", {})
 
                 try:
-                    result = await fn(*args, **kwargs)
+                    result = await async_fn(*args, **kwargs)
 
                     # Extract metrics from result - this assumes the result follows
                     # common LLM API response formats (OpenAI, Anthropic, etc.)
-                    if extract_tokens and isinstance(result, dict):
-                        input_tokens = result.get("usage", {}).get("prompt_tokens")
-                        output_tokens = result.get("usage", {}).get("completion_tokens")
-                        total_tokens = result.get("usage", {}).get("total_tokens")
-                        actual_model = result.get("model", model_name)
-                    else:
-                        input_tokens = output_tokens = total_tokens = None
-                        actual_model = model_name
+                    input_tokens = output_tokens = total_tokens = None
+                    actual_model = model_name
+                    if extract_tokens and isinstance(result, Mapping):
+                        result_mapping = cast(Mapping[str, Any], result)
+                        usage = result_mapping.get("usage")
+                        if isinstance(usage, Mapping):
+                            input_tokens = usage.get("prompt_tokens")
+                            output_tokens = usage.get("completion_tokens")
+                            total_tokens = usage.get("total_tokens")
+                        value = result_mapping.get("model")
+                        if isinstance(value, str):
+                            actual_model = value
 
                     latency_ms = int((time.time() - start_time) * 1000)
 
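The rewritten extraction assigns safe defaults first and only narrows when the result really is a `Mapping` with a well-formed `usage` entry, so a malformed response can no longer raise mid-trace. A self-contained sketch of the same defensive pattern (`extract_usage` is illustrative, not a synth-ai function):

```python
from collections.abc import Mapping
from typing import Any


def extract_usage(result: Any, fallback_model: str) -> tuple[int | None, str]:
    # Defaults first, then narrow: mirrors the hardened logic above.
    total_tokens: int | None = None
    model = fallback_model
    if isinstance(result, Mapping):
        usage = result.get("usage")
        if isinstance(usage, Mapping):
            total_tokens = usage.get("total_tokens")
        value = result.get("model")
        if isinstance(value, str):
            model = value
    return total_tokens, model


print(extract_usage({"usage": {"total_tokens": 42}, "model": "gpt-4o-mini"}, "fallback"))
print(extract_usage("not a mapping", "fallback"))  # (None, 'fallback')
```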
@@ -272,19 +289,26 @@ def trace_method(event_type: str = "runtime", system_id: str | None = None):
     ```
     """
 
-    def decorator(fn: Callable[..., T]) -> Callable[..., T]:
+    def decorator(
+        fn: Callable[..., Awaitable[T]] | Callable[..., T]
+    ) -> Callable[..., Awaitable[T]] | Callable[..., T]:
         if asyncio.iscoroutinefunction(fn):
+            async_fn = cast(Callable[..., Awaitable[T]], fn)
 
             @functools.wraps(fn)
-            async def async_wrapper(self, *args, **kwargs):
+            async def async_wrapper(*args: Any, **kwargs: Any) -> T:
                 tracer = get_session_tracer()
                 if not tracer:
-                    return await fn(self, *args, **kwargs)
+                    return await async_fn(*args, **kwargs)
 
                 from .abstractions import RuntimeEvent
 
                 # Use class name as system_id if not provided
-                actual_system_id = system_id or self.__class__.__name__
+                self_obj = args[0] if args else None
+                inferred_system_id = (
+                    self_obj.__class__.__name__ if self_obj is not None else "unknown"
+                )
+                actual_system_id = system_id or inferred_system_id
 
                 event = RuntimeEvent(
                     system_instance_id=actual_system_id,
@@ -298,17 +322,18 @@ def trace_method(event_type: str = "runtime", system_id: str | None = None):
                 )
 
                 await tracer.record_event(event)
-                return await fn(self, *args, **kwargs)
+                return await async_fn(*args, **kwargs)
 
             return async_wrapper
         else:
 
             @functools.wraps(fn)
-            def sync_wrapper(self, *args, **kwargs):
+            def sync_wrapper(*args: Any, **kwargs: Any) -> T:
                 # For sync methods, we can't easily trace without blocking
                 # the event loop. This is a limitation of the async-first design.
                 # Consider converting to async or using a different approach
-                return fn(self, *args, **kwargs)
+                sync_fn = cast(Callable[..., T], fn)
+                return sync_fn(*args, **kwargs)
 
             return sync_wrapper
 
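Dropping the explicit `self` parameter from the wrapper signatures is what lets the decorator wrap plain functions as well as methods; when there is a bound instance, it simply arrives as `args[0]`. A toy decorator demonstrating the same trick, independent of synth-ai:

```python
import functools
from typing import Any


def log_owner(fn):
    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any):
        # Treat the first positional argument as the bound instance, if any.
        self_obj = args[0] if args else None
        owner = self_obj.__class__.__name__ if self_obj is not None else "unknown"
        print(f"calling {fn.__name__} on {owner}")
        return fn(*args, **kwargs)

    return wrapper


class Agent:
    @log_owner
    def act(self) -> str:
        return "ok"


Agent().act()  # prints: calling act on Agent
```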
synth_ai/tracing_v3/examples/basic_usage.py CHANGED
@@ -2,13 +2,14 @@
 
 import asyncio
 import time
+from typing import Any
 
-from synth_ai.tracing_v3 import SessionTracer
-from synth_ai.tracing_v3.abstractions import EnvironmentEvent, LMCAISEvent, RuntimeEvent, TimeRecord
-from synth_ai.tracing_v3.turso.daemon import SqldDaemon
+from .. import SessionTracer
+from ..abstractions import EnvironmentEvent, LMCAISEvent, RuntimeEvent, TimeRecord
+from ..turso.daemon import SqldDaemon
 
 
-async def simulate_llm_call(model: str, prompt: str) -> dict:
+async def simulate_llm_call(model: str, prompt: str) -> dict[str, Any]:
    """Simulate an LLM API call."""
     await asyncio.sleep(0.1)  # Simulate network latency
 
@@ -133,6 +134,9 @@ async def main():
     print("\n--- Example 3: Querying Data ---")
 
     # Get model usage statistics
+    if tracer.db is None:
+        raise RuntimeError("Tracer database backend is not initialized")
+
     model_usage = await tracer.db.get_model_usage()
     print("\nModel Usage:")
     print(model_usage)
@@ -150,9 +154,10 @@ async def main():
     # Get specific session details
     if recent_sessions:
         session_detail = await tracer.db.get_session_trace(recent_sessions[0]["session_id"])
-        print(f"\nSession Detail for {session_detail['session_id']}:")
-        print(f"  Created: {session_detail['created_at']}")
-        print(f"  Timesteps: {len(session_detail['timesteps'])}")
+        if session_detail:
+            print(f"\nSession Detail for {session_detail['session_id']}:")
+            print(f"  Created: {session_detail['created_at']}")
+            print(f"  Timesteps: {len(session_detail['timesteps'])}")
 
     # Example 4: Using hooks
     print("\n--- Example 4: Hooks ---")
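Both basic_usage fixes are the same move: narrow an Optional before using it, so a missing backend surfaces as a clear error rather than an `AttributeError` deep in a query. A minimal sketch of the guard (the `Tracer` class here is a stand-in, not the real one):

```python
from typing import Any


class Tracer:
    def __init__(self, db: Any | None = None) -> None:
        self.db = db


def require_db(tracer: Tracer) -> Any:
    # Fail fast with an actionable message; this also narrows the
    # Optional for static type checkers.
    if tracer.db is None:
        raise RuntimeError("Tracer database backend is not initialized")
    return tracer.db
```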
synth_ai/tracing_v3/llm_call_record_helpers.py CHANGED
@@ -4,11 +4,14 @@ This module provides utilities to convert vendor responses to LLMCallRecord
 format and compute aggregates from call records.
 """
 
+from __future__ import annotations
+
 import uuid
+from dataclasses import dataclass, field
 from datetime import UTC, datetime
-from typing import Any
+from typing import Any, TypedDict, cast
 
-from synth_ai.tracing_v3.lm_call_record_abstractions import (
+from .lm_call_record_abstractions import (
     LLMCallRecord,
     LLMChunk,
     LLMContentPart,
@@ -17,7 +20,21 @@ from synth_ai.tracing_v3.lm_call_record_abstractions import (
     LLMUsage,
     ToolCallSpec,
 )
-from synth_ai.v0.lm.vendors.base import BaseLMResponse
+
+BaseLMResponse = Any
+
+
+class _UsageDict(TypedDict, total=False):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+    reasoning_tokens: int
+    cost_usd: float
+    duration_ms: int
+    reasoning_input_tokens: int
+    reasoning_output_tokens: int
+    cache_write_tokens: int
+    cache_read_tokens: int
 
 
 def create_llm_call_record_from_response(
@@ -110,9 +127,10 @@
     )
 
     # Extract tool calls if present
-    output_tool_calls = []
-    if hasattr(response, "tool_calls") and response.tool_calls:
-        for idx, tool_call in enumerate(response.tool_calls):
+    output_tool_calls: list[ToolCallSpec] = []
+    tool_calls_data = cast(list[dict[str, Any]] | None, getattr(response, "tool_calls", None))
+    if tool_calls_data:
+        for idx, tool_call in enumerate(tool_calls_data):
            if isinstance(tool_call, dict):
                output_tool_calls.append(
                    ToolCallSpec(
@@ -125,18 +143,19 @@
 
     # Extract usage information
     usage = None
-    if hasattr(response, "usage") and response.usage:
+    usage_data = cast(_UsageDict | None, getattr(response, "usage", None))
+    if usage_data:
         usage = LLMUsage(
-            input_tokens=response.usage.get("input_tokens"),
-            output_tokens=response.usage.get("output_tokens"),
-            total_tokens=response.usage.get("total_tokens"),
-            cost_usd=response.usage.get("cost_usd"),
+            input_tokens=usage_data.get("input_tokens"),
+            output_tokens=usage_data.get("output_tokens"),
+            total_tokens=usage_data.get("total_tokens"),
+            cost_usd=usage_data.get("cost_usd"),
             # Additional token accounting if available
-            reasoning_tokens=response.usage.get("reasoning_tokens"),
-            reasoning_input_tokens=response.usage.get("reasoning_input_tokens"),
-            reasoning_output_tokens=response.usage.get("reasoning_output_tokens"),
-            cache_write_tokens=response.usage.get("cache_write_tokens"),
-            cache_read_tokens=response.usage.get("cache_read_tokens"),
+            reasoning_tokens=usage_data.get("reasoning_tokens"),
+            reasoning_input_tokens=usage_data.get("reasoning_input_tokens"),
+            reasoning_output_tokens=usage_data.get("reasoning_output_tokens"),
+            cache_write_tokens=usage_data.get("cache_write_tokens"),
+            cache_read_tokens=usage_data.get("cache_read_tokens"),
         )
 
     # Build request parameters
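`getattr(response, "usage", None)` collapses the old hasattr-then-access pair into one lookup that type checkers can narrow, and it behaves identically on objects without the attribute. A small illustration (`FakeResponse` and `read_usage` are hypothetical):

```python
from typing import Any


class FakeResponse:
    # Hypothetical stand-in for a vendor response object.
    usage = {"prompt_tokens": 12, "completion_tokens": 3}


def read_usage(response: Any) -> dict[str, Any] | None:
    usage = getattr(response, "usage", None)  # no hasattr round-trip
    return dict(usage) if usage else None


print(read_usage(FakeResponse()))  # {'prompt_tokens': 12, 'completion_tokens': 3}
print(read_usage(object()))        # None
```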
@@ -188,7 +207,45 @@
     return record
 
 
-def compute_aggregates_from_call_records(call_records: list[LLMCallRecord]) -> dict[str, Any]:
+@dataclass
+class _AggregateAccumulator:
+    """Mutable accumulator for call record aggregates."""
+
+    call_count: int = 0
+    input_tokens: int = 0
+    output_tokens: int = 0
+    total_tokens: int = 0
+    reasoning_tokens: int = 0
+    cost_usd: float = 0.0
+    latency_ms: int = 0
+    models_used: set[str] = field(default_factory=set)
+    providers_used: set[str] = field(default_factory=set)
+    tool_calls_count: int = 0
+    error_count: int = 0
+    success_count: int = 0
+
+
+class AggregateSummary(TypedDict, total=False):
+    """Aggregate metrics derived from call records."""
+
+    call_count: int
+    input_tokens: int
+    output_tokens: int
+    total_tokens: int
+    reasoning_tokens: int
+    cost_usd: float
+    latency_ms: int
+    models_used: list[str]
+    providers_used: list[str]
+    tool_calls_count: int
+    error_count: int
+    success_count: int
+    avg_latency_ms: float
+    avg_input_tokens: float
+    avg_output_tokens: float
+
+
+def compute_aggregates_from_call_records(call_records: list[LLMCallRecord]) -> AggregateSummary:
     """Compute aggregate statistics from a list of LLMCallRecord instances.
 
     Args:
@@ -197,65 +254,62 @@ def compute_aggregates_from_call_records(call_records: list[LLMCallRecord]) -> d
     Returns:
         Dictionary containing aggregated statistics
     """
-    aggregates = {
-        "input_tokens": 0,
-        "output_tokens": 0,
-        "total_tokens": 0,
-        "reasoning_tokens": 0,
-        "cost_usd": 0.0,
-        "latency_ms": 0,
-        "models_used": set(),
-        "providers_used": set(),
-        "tool_calls_count": 0,
-        "error_count": 0,
-        "success_count": 0,
-        "call_count": len(call_records),
-    }
+    aggregates = _AggregateAccumulator(call_count=len(call_records))
 
     for record in call_records:
         # Token aggregation
         if record.usage:
             if record.usage.input_tokens:
-                aggregates["input_tokens"] += record.usage.input_tokens
+                aggregates.input_tokens += record.usage.input_tokens
             if record.usage.output_tokens:
-                aggregates["output_tokens"] += record.usage.output_tokens
+                aggregates.output_tokens += record.usage.output_tokens
             if record.usage.total_tokens:
-                aggregates["total_tokens"] += record.usage.total_tokens
+                aggregates.total_tokens += record.usage.total_tokens
             if record.usage.reasoning_tokens:
-                aggregates["reasoning_tokens"] += record.usage.reasoning_tokens
+                aggregates.reasoning_tokens += record.usage.reasoning_tokens
             if record.usage.cost_usd:
-                aggregates["cost_usd"] += record.usage.cost_usd
+                aggregates.cost_usd += record.usage.cost_usd
 
         # Latency aggregation
-        if record.latency_ms:
-            aggregates["latency_ms"] += record.latency_ms
+        if record.latency_ms is not None:
+            aggregates.latency_ms += record.latency_ms
 
         # Model and provider tracking
         if record.model_name:
-            aggregates["models_used"].add(record.model_name)
+            aggregates.models_used.add(record.model_name)
         if record.provider:
-            aggregates["providers_used"].add(record.provider)
+            aggregates.providers_used.add(record.provider)
 
         # Tool calls
-        aggregates["tool_calls_count"] += len(record.output_tool_calls)
+        aggregates.tool_calls_count += len(record.output_tool_calls)
 
         # Success/error tracking
         if record.outcome == "error":
-            aggregates["error_count"] += 1
+            aggregates.error_count += 1
         elif record.outcome == "success":
-            aggregates["success_count"] += 1
-
-    # Convert sets to lists for JSON serialization
-    aggregates["models_used"] = list(aggregates["models_used"])
-    aggregates["providers_used"] = list(aggregates["providers_used"])
+            aggregates.success_count += 1
+
+    summary: AggregateSummary = {
+        "call_count": aggregates.call_count,
+        "input_tokens": aggregates.input_tokens,
+        "output_tokens": aggregates.output_tokens,
+        "total_tokens": aggregates.total_tokens,
+        "reasoning_tokens": aggregates.reasoning_tokens,
+        "cost_usd": aggregates.cost_usd,
+        "latency_ms": aggregates.latency_ms,
+        "models_used": list(aggregates.models_used),
+        "providers_used": list(aggregates.providers_used),
+        "tool_calls_count": aggregates.tool_calls_count,
+        "error_count": aggregates.error_count,
+        "success_count": aggregates.success_count,
+    }
 
-    # Compute averages
-    if aggregates["call_count"] > 0:
-        aggregates["avg_latency_ms"] = aggregates["latency_ms"] / aggregates["call_count"]
-        aggregates["avg_input_tokens"] = aggregates["input_tokens"] / aggregates["call_count"]
-        aggregates["avg_output_tokens"] = aggregates["output_tokens"] / aggregates["call_count"]
+    if aggregates.call_count > 0:
+        summary["avg_latency_ms"] = aggregates.latency_ms / aggregates.call_count
+        summary["avg_input_tokens"] = aggregates.input_tokens / aggregates.call_count
+        summary["avg_output_tokens"] = aggregates.output_tokens / aggregates.call_count
 
-    return aggregates
+    return summary
 
 
 def create_llm_call_record_from_streaming(
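The aggregation refactor separates a typed mutable accumulator from the serialized summary, instead of mutating one dict whose value types change midway (sets silently becoming lists). A compact sketch of the accumulate-then-serialize shape, with toy names (`Acc`, `Summary`, `summarize`):

```python
from dataclasses import dataclass, field
from typing import TypedDict


@dataclass
class Acc:
    # Strongly typed while accumulating: sets stay sets.
    total: int = 0
    models: set[str] = field(default_factory=set)


class Summary(TypedDict, total=False):
    total: int
    models: list[str]  # serialized to a list for JSON
    avg: float         # present only when there was data


def summarize(counts: list[int]) -> Summary:
    acc = Acc()
    for n in counts:
        acc.total += n
        acc.models.add(f"model-{n % 2}")
    out: Summary = {"total": acc.total, "models": sorted(acc.models)}
    if counts:
        out["avg"] = acc.total / len(counts)
    return out


print(summarize([1, 2, 3]))  # {'total': 6, 'models': ['model-0', 'model-1'], 'avg': 2.0}
```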
synth_ai/tracing_v3/replica_sync.py CHANGED
@@ -26,6 +26,7 @@ application to continue without blocking on sync operations.
 
 import asyncio
 import logging
+from typing import Any
 
 import libsql
 
@@ -66,8 +67,8 @@ class ReplicaSync:
         self.sync_url = sync_url or CONFIG.sync_url
         self.auth_token = auth_token or CONFIG.auth_token
         self.sync_interval = sync_interval or CONFIG.sync_interval
-        self._sync_task: asyncio.Task | None = None
-        self._conn: libsql.Connection | None = None
+        self._sync_task: asyncio.Task[Any] | None = None
+        self._conn: Any | None = None
 
     def _ensure_connection(self):
         """Ensure libsql connection is established.
@@ -113,8 +114,11 @@ class ReplicaSync:
         """
         try:
             self._ensure_connection()
+            conn = self._conn
+            if conn is None:
+                raise RuntimeError("Replica sync connection is not available after initialization")
             # Run sync in thread pool since libsql sync is blocking
-            await asyncio.to_thread(self._conn.sync)
+            await asyncio.to_thread(conn.sync)
             logger.info("Successfully synced with remote Turso database")
             return True
         except Exception as e:
@@ -146,7 +150,7 @@ class ReplicaSync:
             # Sleep until next sync interval
             await asyncio.sleep(self.sync_interval)
 
-    def start_background_sync(self) -> asyncio.Task:
+    def start_background_sync(self) -> asyncio.Task[Any]:
         """Start the background sync task.
 
         Creates an asyncio task that runs the sync loop. The task is stored
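Binding `self._conn` to a local before the `await` pins the connection used by the worker thread and gives the type checker something to narrow. A runnable sketch of the `asyncio.to_thread` hand-off, with `blocking_sync` standing in for libsql's blocking `Connection.sync()`:

```python
import asyncio
import time


def blocking_sync() -> str:
    # Stand-in for a blocking library call such as Connection.sync().
    time.sleep(0.2)
    return "synced"


async def main() -> None:
    # to_thread runs the blocking call in a worker thread so the
    # event loop stays responsive while it waits.
    result = await asyncio.to_thread(blocking_sync)
    print(result)


asyncio.run(main())
```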
synth_ai/tracing_v3/storage/utils.py CHANGED
@@ -3,8 +3,8 @@
 import asyncio
 import functools
 import time
-from collections.abc import Callable
-from typing import Any, TypeVar
+from collections.abc import Awaitable, Callable
+from typing import Any, TypeVar, cast
 
 T = TypeVar("T")
 
@@ -18,9 +18,9 @@ def retry_async(max_attempts: int = 3, delay: float = 1.0, backoff: float = 2.0)
         backoff: Backoff multiplier for each retry
     """
 
-    def decorator(func: Callable[..., T]) -> Callable[..., T]:
+    def decorator(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
         @functools.wraps(func)
-        async def wrapper(*args, **kwargs):
+        async def wrapper(*args: Any, **kwargs: Any) -> T:
             last_exception: Exception | None = None
             current_delay = delay
 
@@ -171,13 +171,14 @@ STORAGE_METRICS = StorageMetrics()
 def track_metrics(operation: str):
     """Decorator to track storage operation metrics."""
 
-    def decorator(func: Callable[..., T]) -> Callable[..., T]:
+    def decorator(func: Callable[..., Awaitable[T]] | Callable[..., T]) -> Callable[..., Awaitable[T]] | Callable[..., T]:
         @functools.wraps(func)
-        async def async_wrapper(*args, **kwargs):
+        async def async_wrapper(*args: Any, **kwargs: Any) -> T:
             start_time = time.time()
             success = False
             try:
-                result = await func(*args, **kwargs)
+                async_func = cast(Callable[..., Awaitable[T]], func)
+                result = await async_func(*args, **kwargs)
                 success = True
                 return result
             finally:
@@ -185,11 +186,12 @@ def track_metrics(operation: str):
                 STORAGE_METRICS.record_operation(operation, duration, success)
 
         @functools.wraps(func)
-        def sync_wrapper(*args, **kwargs):
+        def sync_wrapper(*args: Any, **kwargs: Any) -> T:
             start_time = time.time()
             success = False
             try:
-                result = func(*args, **kwargs)
+                sync_func = cast(Callable[..., T], func)
+                result = sync_func(*args, **kwargs)
                 success = True
                 return result
             finally:
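With the new annotations, `retry_async` now advertises that it wraps coroutine functions and preserves the result type. Call sites are unchanged; a usage sketch assuming synth-ai is installed:

```python
import asyncio

from synth_ai.tracing_v3.storage.utils import retry_async


@retry_async(max_attempts=3, delay=0.5, backoff=2.0)
async def flaky_write() -> str:
    # Retried up to three times with exponential backoff (0.5s, 1.0s)
    # if it raises; here it simply succeeds.
    return "ok"


print(asyncio.run(flaky_write()))  # ok
```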
synth_ai/tracing_v3/turso/__init__.py ADDED
@@ -0,0 +1,12 @@
+"""Turso integration package for tracing v3."""
+
+from .daemon import SqldDaemon, get_daemon, start_sqld, stop_sqld
+from .native_manager import NativeLibsqlTraceManager
+
+__all__ = [
+    "SqldDaemon",
+    "NativeLibsqlTraceManager",
+    "get_daemon",
+    "start_sqld",
+    "stop_sqld",
+]
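With the new `__init__.py`, callers can import the daemon helpers from the package root instead of reaching into the submodules; a one-line sketch, assuming synth-ai 0.2.12 is installed:

```python
from synth_ai.tracing_v3.turso import NativeLibsqlTraceManager, SqldDaemon, start_sqld
```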
synth_ai/tracing_v3/turso/daemon.py CHANGED
@@ -6,6 +6,7 @@ import subprocess
 import time
 
 import requests
+from requests import RequestException
 
 from ..config import CONFIG
 
@@ -79,7 +80,7 @@ class SqldDaemon:
                 response = requests.get(health_url, timeout=1)
                 if response.status_code == 200:
                     return
-            except requests.exceptions.RequestException:
+            except RequestException:
                 pass
 
             # Check if process crashed
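Import tidy-up aside, the surrounding loop is a standard poll-until-healthy readiness check: `RequestException` is the base class for requests' network errors, so one except clause covers timeouts and refused connections alike. A self-contained sketch of such a loop (`wait_for_health` is illustrative, not the daemon's actual method):

```python
import time

import requests
from requests import RequestException


def wait_for_health(url: str, timeout_s: float = 10.0) -> bool:
    # Poll an HTTP health endpoint until it answers 200 or we give up.
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            if requests.get(url, timeout=1).status_code == 200:
                return True
        except RequestException:
            pass
        time.sleep(0.1)
    return False
```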