synth-ai 0.2.9.dev17__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of synth-ai might be problematic. Click here for more details.

Files changed (56) hide show
  1. examples/dev/qwen3_32b_qlora_4xh100.toml +40 -0
  2. examples/multi_step/crafter_rl_lora.md +29 -0
  3. examples/multi_step/task_app_config_notes.md +488 -0
  4. examples/qwen_coder/infer_ft_smoke.py +1 -0
  5. examples/qwen_coder/scripts/infer_coder.sh +1 -0
  6. examples/qwen_coder/scripts/train_coder_30b.sh +1 -0
  7. examples/qwen_coder/subset_jsonl.py +1 -0
  8. examples/qwen_coder/todos.md +38 -0
  9. examples/qwen_coder/validate_jsonl.py +1 -0
  10. examples/vlm/PROPOSAL.md +53 -0
  11. examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +33 -0
  12. examples/warming_up_to_rl/configs/eval_stepwise_consistent.toml +26 -0
  13. examples/warming_up_to_rl/configs/eval_stepwise_per_achievement.toml +36 -0
  14. examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +30 -0
  15. examples/warming_up_to_rl/old/event_rewards.md +234 -0
  16. examples/warming_up_to_rl/old/notes.md +73 -0
  17. examples/warming_up_to_rl/run_eval.py +142 -25
  18. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +146 -2
  19. synth_ai/__init__.py +5 -20
  20. synth_ai/api/train/builders.py +25 -14
  21. synth_ai/api/train/cli.py +29 -6
  22. synth_ai/api/train/env_resolver.py +18 -19
  23. synth_ai/api/train/supported_algos.py +8 -5
  24. synth_ai/api/train/utils.py +6 -1
  25. synth_ai/cli/__init__.py +4 -2
  26. synth_ai/cli/_storage.py +19 -0
  27. synth_ai/cli/balance.py +14 -2
  28. synth_ai/cli/calc.py +37 -22
  29. synth_ai/cli/legacy_root_backup.py +12 -14
  30. synth_ai/cli/recent.py +12 -7
  31. synth_ai/cli/root.py +1 -23
  32. synth_ai/cli/status.py +4 -3
  33. synth_ai/cli/task_apps.py +143 -137
  34. synth_ai/cli/traces.py +4 -3
  35. synth_ai/cli/watch.py +3 -2
  36. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_modal_ft/filter_traces_sft_turso.py +738 -0
  37. synth_ai/environments/examples/crafter_classic/agent_demos/crafter_openai_ft/filter_traces_sft_turso.py +580 -0
  38. synth_ai/jobs/client.py +15 -3
  39. synth_ai/task/server.py +14 -7
  40. synth_ai/tracing_v3/decorators.py +51 -26
  41. synth_ai/tracing_v3/examples/basic_usage.py +12 -7
  42. synth_ai/tracing_v3/llm_call_record_helpers.py +107 -53
  43. synth_ai/tracing_v3/replica_sync.py +8 -4
  44. synth_ai/tracing_v3/storage/utils.py +11 -9
  45. synth_ai/tracing_v3/turso/__init__.py +12 -0
  46. synth_ai/tracing_v3/turso/daemon.py +2 -1
  47. synth_ai/tracing_v3/turso/native_manager.py +28 -15
  48. {synth_ai-0.2.9.dev17.dist-info → synth_ai-0.2.12.dist-info}/METADATA +33 -88
  49. {synth_ai-0.2.9.dev17.dist-info → synth_ai-0.2.12.dist-info}/RECORD +53 -41
  50. {synth_ai-0.2.9.dev17.dist-info → synth_ai-0.2.12.dist-info}/top_level.txt +0 -1
  51. synth/__init__.py +0 -14
  52. synth_ai/_docs_message.py +0 -10
  53. synth_ai/main.py +0 -5
  54. {synth_ai-0.2.9.dev17.dist-info → synth_ai-0.2.12.dist-info}/WHEEL +0 -0
  55. {synth_ai-0.2.9.dev17.dist-info → synth_ai-0.2.12.dist-info}/entry_points.txt +0 -0
  56. {synth_ai-0.2.9.dev17.dist-info → synth_ai-0.2.12.dist-info}/licenses/LICENSE +0 -0
@@ -28,8 +28,8 @@ import asyncio
28
28
  import contextvars
29
29
  import functools
30
30
  import time
31
- from collections.abc import Callable
32
- from typing import Any, TypeVar
31
+ from collections.abc import Awaitable, Callable, Mapping
32
+ from typing import Any, TypeVar, cast, overload
33
33
 
34
34
  from .abstractions import LMCAISEvent, TimeRecord
35
35
  from .utils import calculate_cost, detect_provider
@@ -88,6 +88,16 @@ def get_session_tracer() -> Any:
88
88
  T = TypeVar("T")
89
89
 
90
90
 
91
+ @overload
92
+ def with_session(require: bool = True) -> Callable[[Callable[..., Awaitable[T]]], Callable[..., Awaitable[T]]]:
93
+ ...
94
+
95
+
96
+ @overload
97
+ def with_session(require: bool = True) -> Callable[[Callable[..., T]], Callable[..., T]]:
98
+ ...
99
+
100
+
91
101
  def with_session(require: bool = True):
92
102
  """Decorator that ensures a session is active.
93
103
 
@@ -109,29 +119,31 @@ def with_session(require: bool = True):
109
119
  ```
110
120
  """
111
121
 
112
- def decorator(fn: Callable[..., T]) -> Callable[..., T]:
122
+ def decorator(fn: Callable[..., Awaitable[T]] | Callable[..., T]) -> Callable[..., Awaitable[T]] | Callable[..., T]:
113
123
  if asyncio.iscoroutinefunction(fn):
114
124
 
115
125
  @functools.wraps(fn)
116
- async def async_wrapper(*args, **kwargs):
126
+ async def async_wrapper(*args: Any, **kwargs: Any) -> T:
117
127
  session_id = get_session_id()
118
128
  if require and session_id is None:
119
129
  raise RuntimeError(
120
130
  f"No active session for {getattr(fn, '__name__', 'unknown')}"
121
131
  )
122
- return await fn(*args, **kwargs)
132
+ async_fn = cast(Callable[..., Awaitable[T]], fn)
133
+ return await async_fn(*args, **kwargs)
123
134
 
124
135
  return async_wrapper
125
136
  else:
126
137
 
127
138
  @functools.wraps(fn)
128
- def sync_wrapper(*args, **kwargs):
139
+ def sync_wrapper(*args: Any, **kwargs: Any) -> T:
129
140
  session_id = get_session_id()
130
141
  if require and session_id is None:
131
142
  raise RuntimeError(
132
143
  f"No active session for {getattr(fn, '__name__', 'unknown')}"
133
144
  )
134
- return fn(*args, **kwargs)
145
+ sync_fn = cast(Callable[..., T], fn)
146
+ return sync_fn(*args, **kwargs)
135
147
 
136
148
  return sync_wrapper
137
149
 
@@ -172,31 +184,36 @@ def trace_llm_call(
172
184
  ```
173
185
  """
174
186
 
175
- def decorator(fn: Callable[..., T]) -> Callable[..., T]:
187
+ def decorator(fn: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
176
188
  if asyncio.iscoroutinefunction(fn):
189
+ async_fn: Callable[..., Awaitable[T]] = fn
177
190
 
178
191
  @functools.wraps(fn)
179
- async def async_wrapper(*args, **kwargs):
192
+ async def async_wrapper(*args: Any, **kwargs: Any) -> T:
180
193
  tracer = get_session_tracer()
181
194
  if not tracer:
182
- return await fn(*args, **kwargs)
195
+ return await async_fn(*args, **kwargs)
183
196
 
184
197
  start_time = time.time()
185
198
  system_state_before = kwargs.get("state_before", {})
186
199
 
187
200
  try:
188
- result = await fn(*args, **kwargs)
201
+ result = await async_fn(*args, **kwargs)
189
202
 
190
203
  # Extract metrics from result - this assumes the result follows
191
204
  # common LLM API response formats (OpenAI, Anthropic, etc.)
192
- if extract_tokens and isinstance(result, dict):
193
- input_tokens = result.get("usage", {}).get("prompt_tokens")
194
- output_tokens = result.get("usage", {}).get("completion_tokens")
195
- total_tokens = result.get("usage", {}).get("total_tokens")
196
- actual_model = result.get("model", model_name)
197
- else:
198
- input_tokens = output_tokens = total_tokens = None
199
- actual_model = model_name
205
+ input_tokens = output_tokens = total_tokens = None
206
+ actual_model = model_name
207
+ if extract_tokens and isinstance(result, Mapping):
208
+ result_mapping = cast(Mapping[str, Any], result)
209
+ usage = result_mapping.get("usage")
210
+ if isinstance(usage, Mapping):
211
+ input_tokens = usage.get("prompt_tokens")
212
+ output_tokens = usage.get("completion_tokens")
213
+ total_tokens = usage.get("total_tokens")
214
+ value = result_mapping.get("model")
215
+ if isinstance(value, str):
216
+ actual_model = value
200
217
 
201
218
  latency_ms = int((time.time() - start_time) * 1000)
202
219
 
@@ -272,19 +289,26 @@ def trace_method(event_type: str = "runtime", system_id: str | None = None):
272
289
  ```
273
290
  """
274
291
 
275
- def decorator(fn: Callable[..., T]) -> Callable[..., T]:
292
+ def decorator(
293
+ fn: Callable[..., Awaitable[T]] | Callable[..., T]
294
+ ) -> Callable[..., Awaitable[T]] | Callable[..., T]:
276
295
  if asyncio.iscoroutinefunction(fn):
296
+ async_fn = cast(Callable[..., Awaitable[T]], fn)
277
297
 
278
298
  @functools.wraps(fn)
279
- async def async_wrapper(self, *args, **kwargs):
299
+ async def async_wrapper(*args: Any, **kwargs: Any) -> T:
280
300
  tracer = get_session_tracer()
281
301
  if not tracer:
282
- return await fn(self, *args, **kwargs)
302
+ return await async_fn(*args, **kwargs)
283
303
 
284
304
  from .abstractions import RuntimeEvent
285
305
 
286
306
  # Use class name as system_id if not provided
287
- actual_system_id = system_id or self.__class__.__name__
307
+ self_obj = args[0] if args else None
308
+ inferred_system_id = (
309
+ self_obj.__class__.__name__ if self_obj is not None else "unknown"
310
+ )
311
+ actual_system_id = system_id or inferred_system_id
288
312
 
289
313
  event = RuntimeEvent(
290
314
  system_instance_id=actual_system_id,
@@ -298,17 +322,18 @@ def trace_method(event_type: str = "runtime", system_id: str | None = None):
298
322
  )
299
323
 
300
324
  await tracer.record_event(event)
301
- return await fn(self, *args, **kwargs)
325
+ return await async_fn(*args, **kwargs)
302
326
 
303
327
  return async_wrapper
304
328
  else:
305
329
 
306
330
  @functools.wraps(fn)
307
- def sync_wrapper(self, *args, **kwargs):
331
+ def sync_wrapper(*args: Any, **kwargs: Any) -> T:
308
332
  # For sync methods, we can't easily trace without blocking
309
333
  # the event loop. This is a limitation of the async-first design.
310
334
  # Consider converting to async or using a different approach
311
- return fn(self, *args, **kwargs)
335
+ sync_fn = cast(Callable[..., T], fn)
336
+ return sync_fn(*args, **kwargs)
312
337
 
313
338
  return sync_wrapper
314
339
 
@@ -2,13 +2,14 @@
2
2
 
3
3
  import asyncio
4
4
  import time
5
+ from typing import Any
5
6
 
6
- from synth_ai.tracing_v3 import SessionTracer
7
- from synth_ai.tracing_v3.abstractions import EnvironmentEvent, LMCAISEvent, RuntimeEvent, TimeRecord
8
- from synth_ai.tracing_v3.turso.daemon import SqldDaemon
7
+ from .. import SessionTracer
8
+ from ..abstractions import EnvironmentEvent, LMCAISEvent, RuntimeEvent, TimeRecord
9
+ from ..turso.daemon import SqldDaemon
9
10
 
10
11
 
11
- async def simulate_llm_call(model: str, prompt: str) -> dict:
12
+ async def simulate_llm_call(model: str, prompt: str) -> dict[str, Any]:
12
13
  """Simulate an LLM API call."""
13
14
  await asyncio.sleep(0.1) # Simulate network latency
14
15
 
@@ -133,6 +134,9 @@ async def main():
133
134
  print("\n--- Example 3: Querying Data ---")
134
135
 
135
136
  # Get model usage statistics
137
+ if tracer.db is None:
138
+ raise RuntimeError("Tracer database backend is not initialized")
139
+
136
140
  model_usage = await tracer.db.get_model_usage()
137
141
  print("\nModel Usage:")
138
142
  print(model_usage)
@@ -150,9 +154,10 @@ async def main():
150
154
  # Get specific session details
151
155
  if recent_sessions:
152
156
  session_detail = await tracer.db.get_session_trace(recent_sessions[0]["session_id"])
153
- print(f"\nSession Detail for {session_detail['session_id']}:")
154
- print(f" Created: {session_detail['created_at']}")
155
- print(f" Timesteps: {len(session_detail['timesteps'])}")
157
+ if session_detail:
158
+ print(f"\nSession Detail for {session_detail['session_id']}:")
159
+ print(f" Created: {session_detail['created_at']}")
160
+ print(f" Timesteps: {len(session_detail['timesteps'])}")
156
161
 
157
162
  # Example 4: Using hooks
158
163
  print("\n--- Example 4: Hooks ---")
@@ -4,11 +4,14 @@ This module provides utilities to convert vendor responses to LLMCallRecord
4
4
  format and compute aggregates from call records.
5
5
  """
6
6
 
7
+ from __future__ import annotations
8
+
7
9
  import uuid
10
+ from dataclasses import dataclass, field
8
11
  from datetime import UTC, datetime
9
- from typing import Any
12
+ from typing import Any, TypedDict, cast
10
13
 
11
- from synth_ai.tracing_v3.lm_call_record_abstractions import (
14
+ from .lm_call_record_abstractions import (
12
15
  LLMCallRecord,
13
16
  LLMChunk,
14
17
  LLMContentPart,
@@ -17,7 +20,21 @@ from synth_ai.tracing_v3.lm_call_record_abstractions import (
17
20
  LLMUsage,
18
21
  ToolCallSpec,
19
22
  )
20
- from synth_ai.v0.lm.vendors.base import BaseLMResponse
23
+
24
+ BaseLMResponse = Any
25
+
26
+
27
+ class _UsageDict(TypedDict, total=False):
28
+ prompt_tokens: int
29
+ completion_tokens: int
30
+ total_tokens: int
31
+ reasoning_tokens: int
32
+ cost_usd: float
33
+ duration_ms: int
34
+ reasoning_input_tokens: int
35
+ reasoning_output_tokens: int
36
+ cache_write_tokens: int
37
+ cache_read_tokens: int
21
38
 
22
39
 
23
40
  def create_llm_call_record_from_response(
@@ -110,9 +127,10 @@ def create_llm_call_record_from_response(
110
127
  )
111
128
 
112
129
  # Extract tool calls if present
113
- output_tool_calls = []
114
- if hasattr(response, "tool_calls") and response.tool_calls:
115
- for idx, tool_call in enumerate(response.tool_calls):
130
+ output_tool_calls: list[ToolCallSpec] = []
131
+ tool_calls_data = cast(list[dict[str, Any]] | None, getattr(response, "tool_calls", None))
132
+ if tool_calls_data:
133
+ for idx, tool_call in enumerate(tool_calls_data):
116
134
  if isinstance(tool_call, dict):
117
135
  output_tool_calls.append(
118
136
  ToolCallSpec(
@@ -125,18 +143,19 @@ def create_llm_call_record_from_response(
125
143
 
126
144
  # Extract usage information
127
145
  usage = None
128
- if hasattr(response, "usage") and response.usage:
146
+ usage_data = cast(_UsageDict | None, getattr(response, "usage", None))
147
+ if usage_data:
129
148
  usage = LLMUsage(
130
- input_tokens=response.usage.get("input_tokens"),
131
- output_tokens=response.usage.get("output_tokens"),
132
- total_tokens=response.usage.get("total_tokens"),
133
- cost_usd=response.usage.get("cost_usd"),
149
+ input_tokens=usage_data.get("input_tokens"),
150
+ output_tokens=usage_data.get("output_tokens"),
151
+ total_tokens=usage_data.get("total_tokens"),
152
+ cost_usd=usage_data.get("cost_usd"),
134
153
  # Additional token accounting if available
135
- reasoning_tokens=response.usage.get("reasoning_tokens"),
136
- reasoning_input_tokens=response.usage.get("reasoning_input_tokens"),
137
- reasoning_output_tokens=response.usage.get("reasoning_output_tokens"),
138
- cache_write_tokens=response.usage.get("cache_write_tokens"),
139
- cache_read_tokens=response.usage.get("cache_read_tokens"),
154
+ reasoning_tokens=usage_data.get("reasoning_tokens"),
155
+ reasoning_input_tokens=usage_data.get("reasoning_input_tokens"),
156
+ reasoning_output_tokens=usage_data.get("reasoning_output_tokens"),
157
+ cache_write_tokens=usage_data.get("cache_write_tokens"),
158
+ cache_read_tokens=usage_data.get("cache_read_tokens"),
140
159
  )
141
160
 
142
161
  # Build request parameters
@@ -188,7 +207,45 @@ def create_llm_call_record_from_response(
188
207
  return record
189
208
 
190
209
 
191
- def compute_aggregates_from_call_records(call_records: list[LLMCallRecord]) -> dict[str, Any]:
210
+ @dataclass
211
+ class _AggregateAccumulator:
212
+ """Mutable accumulator for call record aggregates."""
213
+
214
+ call_count: int = 0
215
+ input_tokens: int = 0
216
+ output_tokens: int = 0
217
+ total_tokens: int = 0
218
+ reasoning_tokens: int = 0
219
+ cost_usd: float = 0.0
220
+ latency_ms: int = 0
221
+ models_used: set[str] = field(default_factory=set)
222
+ providers_used: set[str] = field(default_factory=set)
223
+ tool_calls_count: int = 0
224
+ error_count: int = 0
225
+ success_count: int = 0
226
+
227
+
228
+ class AggregateSummary(TypedDict, total=False):
229
+ """Aggregate metrics derived from call records."""
230
+
231
+ call_count: int
232
+ input_tokens: int
233
+ output_tokens: int
234
+ total_tokens: int
235
+ reasoning_tokens: int
236
+ cost_usd: float
237
+ latency_ms: int
238
+ models_used: list[str]
239
+ providers_used: list[str]
240
+ tool_calls_count: int
241
+ error_count: int
242
+ success_count: int
243
+ avg_latency_ms: float
244
+ avg_input_tokens: float
245
+ avg_output_tokens: float
246
+
247
+
248
+ def compute_aggregates_from_call_records(call_records: list[LLMCallRecord]) -> AggregateSummary:
192
249
  """Compute aggregate statistics from a list of LLMCallRecord instances.
193
250
 
194
251
  Args:
@@ -197,65 +254,62 @@ def compute_aggregates_from_call_records(call_records: list[LLMCallRecord]) -> d
197
254
  Returns:
198
255
  Dictionary containing aggregated statistics
199
256
  """
200
- aggregates = {
201
- "input_tokens": 0,
202
- "output_tokens": 0,
203
- "total_tokens": 0,
204
- "reasoning_tokens": 0,
205
- "cost_usd": 0.0,
206
- "latency_ms": 0,
207
- "models_used": set(),
208
- "providers_used": set(),
209
- "tool_calls_count": 0,
210
- "error_count": 0,
211
- "success_count": 0,
212
- "call_count": len(call_records),
213
- }
257
+ aggregates = _AggregateAccumulator(call_count=len(call_records))
214
258
 
215
259
  for record in call_records:
216
260
  # Token aggregation
217
261
  if record.usage:
218
262
  if record.usage.input_tokens:
219
- aggregates["input_tokens"] += record.usage.input_tokens
263
+ aggregates.input_tokens += record.usage.input_tokens
220
264
  if record.usage.output_tokens:
221
- aggregates["output_tokens"] += record.usage.output_tokens
265
+ aggregates.output_tokens += record.usage.output_tokens
222
266
  if record.usage.total_tokens:
223
- aggregates["total_tokens"] += record.usage.total_tokens
267
+ aggregates.total_tokens += record.usage.total_tokens
224
268
  if record.usage.reasoning_tokens:
225
- aggregates["reasoning_tokens"] += record.usage.reasoning_tokens
269
+ aggregates.reasoning_tokens += record.usage.reasoning_tokens
226
270
  if record.usage.cost_usd:
227
- aggregates["cost_usd"] += record.usage.cost_usd
271
+ aggregates.cost_usd += record.usage.cost_usd
228
272
 
229
273
  # Latency aggregation
230
- if record.latency_ms:
231
- aggregates["latency_ms"] += record.latency_ms
274
+ if record.latency_ms is not None:
275
+ aggregates.latency_ms += record.latency_ms
232
276
 
233
277
  # Model and provider tracking
234
278
  if record.model_name:
235
- aggregates["models_used"].add(record.model_name)
279
+ aggregates.models_used.add(record.model_name)
236
280
  if record.provider:
237
- aggregates["providers_used"].add(record.provider)
281
+ aggregates.providers_used.add(record.provider)
238
282
 
239
283
  # Tool calls
240
- aggregates["tool_calls_count"] += len(record.output_tool_calls)
284
+ aggregates.tool_calls_count += len(record.output_tool_calls)
241
285
 
242
286
  # Success/error tracking
243
287
  if record.outcome == "error":
244
- aggregates["error_count"] += 1
288
+ aggregates.error_count += 1
245
289
  elif record.outcome == "success":
246
- aggregates["success_count"] += 1
247
-
248
- # Convert sets to lists for JSON serialization
249
- aggregates["models_used"] = list(aggregates["models_used"])
250
- aggregates["providers_used"] = list(aggregates["providers_used"])
290
+ aggregates.success_count += 1
291
+
292
+ summary: AggregateSummary = {
293
+ "call_count": aggregates.call_count,
294
+ "input_tokens": aggregates.input_tokens,
295
+ "output_tokens": aggregates.output_tokens,
296
+ "total_tokens": aggregates.total_tokens,
297
+ "reasoning_tokens": aggregates.reasoning_tokens,
298
+ "cost_usd": aggregates.cost_usd,
299
+ "latency_ms": aggregates.latency_ms,
300
+ "models_used": list(aggregates.models_used),
301
+ "providers_used": list(aggregates.providers_used),
302
+ "tool_calls_count": aggregates.tool_calls_count,
303
+ "error_count": aggregates.error_count,
304
+ "success_count": aggregates.success_count,
305
+ }
251
306
 
252
- # Compute averages
253
- if aggregates["call_count"] > 0:
254
- aggregates["avg_latency_ms"] = aggregates["latency_ms"] / aggregates["call_count"]
255
- aggregates["avg_input_tokens"] = aggregates["input_tokens"] / aggregates["call_count"]
256
- aggregates["avg_output_tokens"] = aggregates["output_tokens"] / aggregates["call_count"]
307
+ if aggregates.call_count > 0:
308
+ summary["avg_latency_ms"] = aggregates.latency_ms / aggregates.call_count
309
+ summary["avg_input_tokens"] = aggregates.input_tokens / aggregates.call_count
310
+ summary["avg_output_tokens"] = aggregates.output_tokens / aggregates.call_count
257
311
 
258
- return aggregates
312
+ return summary
259
313
 
260
314
 
261
315
  def create_llm_call_record_from_streaming(
@@ -26,6 +26,7 @@ application to continue without blocking on sync operations.
26
26
 
27
27
  import asyncio
28
28
  import logging
29
+ from typing import Any
29
30
 
30
31
  import libsql
31
32
 
@@ -66,8 +67,8 @@ class ReplicaSync:
66
67
  self.sync_url = sync_url or CONFIG.sync_url
67
68
  self.auth_token = auth_token or CONFIG.auth_token
68
69
  self.sync_interval = sync_interval or CONFIG.sync_interval
69
- self._sync_task: asyncio.Task | None = None
70
- self._conn: libsql.Connection | None = None
70
+ self._sync_task: asyncio.Task[Any] | None = None
71
+ self._conn: Any | None = None
71
72
 
72
73
  def _ensure_connection(self):
73
74
  """Ensure libsql connection is established.
@@ -113,8 +114,11 @@ class ReplicaSync:
113
114
  """
114
115
  try:
115
116
  self._ensure_connection()
117
+ conn = self._conn
118
+ if conn is None:
119
+ raise RuntimeError("Replica sync connection is not available after initialization")
116
120
  # Run sync in thread pool since libsql sync is blocking
117
- await asyncio.to_thread(self._conn.sync)
121
+ await asyncio.to_thread(conn.sync)
118
122
  logger.info("Successfully synced with remote Turso database")
119
123
  return True
120
124
  except Exception as e:
@@ -146,7 +150,7 @@ class ReplicaSync:
146
150
  # Sleep until next sync interval
147
151
  await asyncio.sleep(self.sync_interval)
148
152
 
149
- def start_background_sync(self) -> asyncio.Task:
153
+ def start_background_sync(self) -> asyncio.Task[Any]:
150
154
  """Start the background sync task.
151
155
 
152
156
  Creates an asyncio task that runs the sync loop. The task is stored
@@ -3,8 +3,8 @@
3
3
  import asyncio
4
4
  import functools
5
5
  import time
6
- from collections.abc import Callable
7
- from typing import Any, TypeVar
6
+ from collections.abc import Awaitable, Callable
7
+ from typing import Any, TypeVar, cast
8
8
 
9
9
  T = TypeVar("T")
10
10
 
@@ -18,9 +18,9 @@ def retry_async(max_attempts: int = 3, delay: float = 1.0, backoff: float = 2.0)
18
18
  backoff: Backoff multiplier for each retry
19
19
  """
20
20
 
21
- def decorator(func: Callable[..., T]) -> Callable[..., T]:
21
+ def decorator(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
22
22
  @functools.wraps(func)
23
- async def wrapper(*args, **kwargs):
23
+ async def wrapper(*args: Any, **kwargs: Any) -> T:
24
24
  last_exception: Exception | None = None
25
25
  current_delay = delay
26
26
 
@@ -171,13 +171,14 @@ STORAGE_METRICS = StorageMetrics()
171
171
  def track_metrics(operation: str):
172
172
  """Decorator to track storage operation metrics."""
173
173
 
174
- def decorator(func: Callable[..., T]) -> Callable[..., T]:
174
+ def decorator(func: Callable[..., Awaitable[T]] | Callable[..., T]) -> Callable[..., Awaitable[T]] | Callable[..., T]:
175
175
  @functools.wraps(func)
176
- async def async_wrapper(*args, **kwargs):
176
+ async def async_wrapper(*args: Any, **kwargs: Any) -> T:
177
177
  start_time = time.time()
178
178
  success = False
179
179
  try:
180
- result = await func(*args, **kwargs)
180
+ async_func = cast(Callable[..., Awaitable[T]], func)
181
+ result = await async_func(*args, **kwargs)
181
182
  success = True
182
183
  return result
183
184
  finally:
@@ -185,11 +186,12 @@ def track_metrics(operation: str):
185
186
  STORAGE_METRICS.record_operation(operation, duration, success)
186
187
 
187
188
  @functools.wraps(func)
188
- def sync_wrapper(*args, **kwargs):
189
+ def sync_wrapper(*args: Any, **kwargs: Any) -> T:
189
190
  start_time = time.time()
190
191
  success = False
191
192
  try:
192
- result = func(*args, **kwargs)
193
+ sync_func = cast(Callable[..., T], func)
194
+ result = sync_func(*args, **kwargs)
193
195
  success = True
194
196
  return result
195
197
  finally:
@@ -0,0 +1,12 @@
1
+ """Turso integration package for tracing v3."""
2
+
3
+ from .daemon import SqldDaemon, get_daemon, start_sqld, stop_sqld
4
+ from .native_manager import NativeLibsqlTraceManager
5
+
6
+ __all__ = [
7
+ "SqldDaemon",
8
+ "NativeLibsqlTraceManager",
9
+ "get_daemon",
10
+ "start_sqld",
11
+ "stop_sqld",
12
+ ]
@@ -6,6 +6,7 @@ import subprocess
6
6
  import time
7
7
 
8
8
  import requests
9
+ from requests import RequestException
9
10
 
10
11
  from ..config import CONFIG
11
12
 
@@ -79,7 +80,7 @@ class SqldDaemon:
79
80
  response = requests.get(health_url, timeout=1)
80
81
  if response.status_code == 200:
81
82
  return
82
- except requests.exceptions.RequestException:
83
+ except RequestException:
83
84
  pass
84
85
 
85
86
  # Check if process crashed
@@ -11,18 +11,14 @@ import asyncio
11
11
  import json
12
12
  import logging
13
13
  import re
14
+ from collections.abc import Callable
14
15
  from dataclasses import asdict, dataclass
15
16
  from datetime import UTC, datetime
16
- from typing import Any
17
+ from typing import TYPE_CHECKING, Any, cast
17
18
 
18
19
  import libsql
19
20
  from sqlalchemy.engine import make_url
20
21
 
21
- try: # pragma: no cover - exercised only when pandas present
22
- import pandas as pd # type: ignore
23
- except Exception: # pragma: no cover
24
- pd = None # type: ignore[assignment]
25
-
26
22
  from ..abstractions import (
27
23
  EnvironmentEvent,
28
24
  LMCAISEvent,
@@ -34,6 +30,24 @@ from ..config import CONFIG
34
30
  from ..storage.base import TraceStorage
35
31
  from .models import analytics_views
36
32
 
33
+ if TYPE_CHECKING:
34
+ from sqlite3 import Connection as LibsqlConnection
35
+ else: # pragma: no cover - runtime fallback for typing only
36
+ LibsqlConnection = Any # type: ignore[assignment]
37
+
38
+ _LIBSQL_CONNECT_ATTR = getattr(libsql, "connect", None)
39
+ if _LIBSQL_CONNECT_ATTR is None: # pragma: no cover - defensive guard
40
+ raise RuntimeError("libsql.connect is required for NativeLibsqlTraceManager")
41
+ _libsql_connect: Callable[..., LibsqlConnection] = cast(
42
+ Callable[..., LibsqlConnection],
43
+ _LIBSQL_CONNECT_ATTR,
44
+ )
45
+
46
+ try: # pragma: no cover - exercised only when pandas present
47
+ import pandas as pd # type: ignore
48
+ except Exception: # pragma: no cover
49
+ pd = None # type: ignore[assignment]
50
+
37
51
  logger = logging.getLogger(__name__)
38
52
 
39
53
 
@@ -66,9 +80,8 @@ def _resolve_connection_target(db_url: str | None, auth_token: str | None) -> _C
66
80
  # Fallback to SQLAlchemy URL parsing for anything else we missed.
67
81
  try:
68
82
  parsed = make_url(url)
69
- if parsed.drivername.startswith("sqlite"):
70
- if parsed.database:
71
- return _ConnectionTarget(database=parsed.database, auth_token=auth_token)
83
+ if parsed.drivername.startswith("sqlite") and parsed.database:
84
+ return _ConnectionTarget(database=parsed.database, auth_token=auth_token)
72
85
  if parsed.drivername.startswith("libsql"):
73
86
  database = parsed.render_as_string(hide_password=False)
74
87
  return _ConnectionTarget(database=database, sync_url=database, auth_token=auth_token)
@@ -314,12 +327,12 @@ class NativeLibsqlTraceManager(TraceStorage):
314
327
  ):
315
328
  self._config_auth_token = auth_token
316
329
  self._target = _resolve_connection_target(db_url, auth_token)
317
- self._conn: libsql.Connection | None = None
330
+ self._conn: LibsqlConnection | None = None
318
331
  self._conn_lock = asyncio.Lock()
319
332
  self._op_lock = asyncio.Lock()
320
333
  self._initialized = False
321
334
 
322
- def _open_connection(self) -> libsql.Connection:
335
+ def _open_connection(self) -> LibsqlConnection:
323
336
  """Open a libsql connection for the resolved target."""
324
337
  kwargs: dict[str, Any] = {}
325
338
  if self._target.sync_url and self._target.sync_url.startswith("libsql://"):
@@ -329,7 +342,7 @@ class NativeLibsqlTraceManager(TraceStorage):
329
342
  # Disable automatic background sync; ReplicaSync drives this explicitly.
330
343
  kwargs.setdefault("sync_interval", 0)
331
344
  logger.debug("Opening libsql connection to %s", self._target.database)
332
- return libsql.connect(self._target.database, **kwargs)
345
+ return _libsql_connect(self._target.database, **kwargs)
333
346
 
334
347
  async def initialize(self):
335
348
  """Initialise the backend."""
@@ -493,7 +506,7 @@ class NativeLibsqlTraceManager(TraceStorage):
493
506
  return None
494
507
 
495
508
  session_columns = ["session_id", "created_at", "num_timesteps", "num_events", "num_messages", "metadata"]
496
- session_data = dict(zip(session_columns, session_row))
509
+ session_data = dict(zip(session_columns, session_row, strict=True))
497
510
 
498
511
  timestep_cursor = conn.execute(
499
512
  """
@@ -608,10 +621,10 @@ class NativeLibsqlTraceManager(TraceStorage):
608
621
 
609
622
  if not rows:
610
623
  if pd is not None:
611
- return pd.DataFrame(columns=[col for col in columns])
624
+ return pd.DataFrame(columns=list(columns))
612
625
  return []
613
626
 
614
- records = [dict(zip(columns, row)) for row in rows]
627
+ records = [dict(zip(columns, row, strict=True)) for row in rows]
615
628
  if pd is not None:
616
629
  return pd.DataFrame(records)
617
630
  return records