synth-ai 0.2.10__py3-none-any.whl → 0.2.13.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of synth-ai has been flagged for review in its registry.

Files changed (73)
  1. examples/agora_ex/README_MoE.md +224 -0
  2. examples/agora_ex/__init__.py +7 -0
  3. examples/agora_ex/agora_ex.py +65 -0
  4. examples/agora_ex/agora_ex_task_app.py +590 -0
  5. examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +121 -0
  6. examples/agora_ex/reward_fn_grpo-human.py +129 -0
  7. examples/agora_ex/system_prompt_CURRENT.md +63 -0
  8. examples/agora_ex/task_app/agora_ex_task_app.py +590 -0
  9. examples/agora_ex/task_app/reward_fn_grpo-human.py +129 -0
  10. examples/agora_ex/task_app/system_prompt_CURRENT.md +63 -0
  11. examples/multi_step/configs/crafter_rl_outcome.toml +74 -0
  12. examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +175 -0
  13. examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +83 -0
  14. examples/multi_step/configs/crafter_rl_stepwise_simple.toml +78 -0
  15. examples/multi_step/crafter_rl_lora.md +51 -10
  16. examples/multi_step/sse_metrics_streaming_notes.md +357 -0
  17. examples/multi_step/task_app_config_notes.md +494 -0
  18. examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +35 -0
  19. examples/warming_up_to_rl/configs/eval_stepwise_consistent.toml +26 -0
  20. examples/warming_up_to_rl/configs/eval_stepwise_per_achievement.toml +36 -0
  21. examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +32 -0
  22. examples/warming_up_to_rl/run_eval.py +267 -41
  23. examples/warming_up_to_rl/task_app/grpo_crafter.py +3 -33
  24. examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +109 -45
  25. examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +42 -46
  26. examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +376 -193
  27. synth_ai/__init__.py +41 -1
  28. synth_ai/api/train/builders.py +74 -33
  29. synth_ai/api/train/cli.py +29 -6
  30. synth_ai/api/train/configs/__init__.py +44 -0
  31. synth_ai/api/train/configs/rl.py +133 -0
  32. synth_ai/api/train/configs/sft.py +94 -0
  33. synth_ai/api/train/configs/shared.py +24 -0
  34. synth_ai/api/train/env_resolver.py +18 -19
  35. synth_ai/api/train/supported_algos.py +8 -5
  36. synth_ai/api/train/utils.py +6 -1
  37. synth_ai/cli/__init__.py +4 -2
  38. synth_ai/cli/_storage.py +19 -0
  39. synth_ai/cli/balance.py +14 -2
  40. synth_ai/cli/calc.py +37 -22
  41. synth_ai/cli/demo.py +38 -39
  42. synth_ai/cli/legacy_root_backup.py +12 -14
  43. synth_ai/cli/recent.py +12 -7
  44. synth_ai/cli/rl_demo.py +81 -102
  45. synth_ai/cli/status.py +4 -3
  46. synth_ai/cli/task_apps.py +146 -137
  47. synth_ai/cli/traces.py +4 -3
  48. synth_ai/cli/watch.py +3 -2
  49. synth_ai/demos/core/cli.py +121 -159
  50. synth_ai/environments/examples/crafter_classic/environment.py +16 -0
  51. synth_ai/evals/__init__.py +15 -0
  52. synth_ai/evals/client.py +85 -0
  53. synth_ai/evals/types.py +42 -0
  54. synth_ai/jobs/client.py +15 -3
  55. synth_ai/judge_schemas.py +127 -0
  56. synth_ai/rubrics/__init__.py +22 -0
  57. synth_ai/rubrics/validators.py +126 -0
  58. synth_ai/task/server.py +14 -7
  59. synth_ai/tracing_v3/decorators.py +51 -26
  60. synth_ai/tracing_v3/examples/basic_usage.py +12 -7
  61. synth_ai/tracing_v3/llm_call_record_helpers.py +107 -53
  62. synth_ai/tracing_v3/replica_sync.py +8 -4
  63. synth_ai/tracing_v3/serialization.py +130 -0
  64. synth_ai/tracing_v3/storage/utils.py +11 -9
  65. synth_ai/tracing_v3/turso/__init__.py +12 -0
  66. synth_ai/tracing_v3/turso/daemon.py +2 -1
  67. synth_ai/tracing_v3/turso/native_manager.py +28 -15
  68. {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/METADATA +4 -2
  69. {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/RECORD +73 -40
  70. {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/entry_points.txt +0 -1
  71. {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/WHEEL +0 -0
  72. {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/licenses/LICENSE +0 -0
  73. {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/top_level.txt +0 -0
synth_ai/tracing_v3/llm_call_record_helpers.py

```diff
@@ -4,11 +4,14 @@ This module provides utilities to convert vendor responses to LLMCallRecord
 format and compute aggregates from call records.
 """
 
+from __future__ import annotations
+
 import uuid
+from dataclasses import dataclass, field
 from datetime import UTC, datetime
-from typing import Any
+from typing import Any, TypedDict, cast
 
-from synth_ai.tracing_v3.lm_call_record_abstractions import (
+from .lm_call_record_abstractions import (
     LLMCallRecord,
     LLMChunk,
     LLMContentPart,
@@ -17,7 +20,21 @@ from synth_ai.tracing_v3.lm_call_record_abstractions import (
     LLMUsage,
     ToolCallSpec,
 )
-from synth_ai.v0.lm.vendors.base import BaseLMResponse
+
+BaseLMResponse = Any
+
+
+class _UsageDict(TypedDict, total=False):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+    reasoning_tokens: int
+    cost_usd: float
+    duration_ms: int
+    reasoning_input_tokens: int
+    reasoning_output_tokens: int
+    cache_write_tokens: int
+    cache_read_tokens: int
 
 
 def create_llm_call_record_from_response(
@@ -110,9 +127,10 @@ def create_llm_call_record_from_response(
     )
 
     # Extract tool calls if present
-    output_tool_calls = []
-    if hasattr(response, "tool_calls") and response.tool_calls:
-        for idx, tool_call in enumerate(response.tool_calls):
+    output_tool_calls: list[ToolCallSpec] = []
+    tool_calls_data = cast(list[dict[str, Any]] | None, getattr(response, "tool_calls", None))
+    if tool_calls_data:
+        for idx, tool_call in enumerate(tool_calls_data):
             if isinstance(tool_call, dict):
                 output_tool_calls.append(
                     ToolCallSpec(
@@ -125,18 +143,19 @@ def create_llm_call_record_from_response(
 
     # Extract usage information
     usage = None
-    if hasattr(response, "usage") and response.usage:
+    usage_data = cast(_UsageDict | None, getattr(response, "usage", None))
+    if usage_data:
         usage = LLMUsage(
-            input_tokens=response.usage.get("input_tokens"),
-            output_tokens=response.usage.get("output_tokens"),
-            total_tokens=response.usage.get("total_tokens"),
-            cost_usd=response.usage.get("cost_usd"),
+            input_tokens=usage_data.get("input_tokens"),
+            output_tokens=usage_data.get("output_tokens"),
+            total_tokens=usage_data.get("total_tokens"),
+            cost_usd=usage_data.get("cost_usd"),
             # Additional token accounting if available
-            reasoning_tokens=response.usage.get("reasoning_tokens"),
-            reasoning_input_tokens=response.usage.get("reasoning_input_tokens"),
-            reasoning_output_tokens=response.usage.get("reasoning_output_tokens"),
-            cache_write_tokens=response.usage.get("cache_write_tokens"),
-            cache_read_tokens=response.usage.get("cache_read_tokens"),
+            reasoning_tokens=usage_data.get("reasoning_tokens"),
+            reasoning_input_tokens=usage_data.get("reasoning_input_tokens"),
+            reasoning_output_tokens=usage_data.get("reasoning_output_tokens"),
+            cache_write_tokens=usage_data.get("cache_write_tokens"),
+            cache_read_tokens=usage_data.get("cache_read_tokens"),
         )
 
     # Build request parameters
```
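The tool-call and usage hunks above swap `hasattr` checks for a `getattr` + `cast` pattern. A minimal sketch of why that helps a type checker (the names below are illustrative, not from the package):

```python
from typing import Any, TypedDict, cast


class UsageDict(TypedDict, total=False):
    input_tokens: int
    output_tokens: int


def summarize_usage(response: Any) -> str:
    # hasattr() tells the checker nothing about the attribute's type;
    # binding getattr()'s result through cast() gives .get() typed results.
    usage = cast(UsageDict | None, getattr(response, "usage", None))
    if not usage:
        return "no usage reported"
    return f"in={usage.get('input_tokens', 0)} out={usage.get('output_tokens', 0)}"
```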
synth_ai/tracing_v3/llm_call_record_helpers.py (continued)

```diff
@@ -188,7 +207,45 @@ def create_llm_call_record_from_response(
     return record
 
 
-def compute_aggregates_from_call_records(call_records: list[LLMCallRecord]) -> dict[str, Any]:
+@dataclass
+class _AggregateAccumulator:
+    """Mutable accumulator for call record aggregates."""
+
+    call_count: int = 0
+    input_tokens: int = 0
+    output_tokens: int = 0
+    total_tokens: int = 0
+    reasoning_tokens: int = 0
+    cost_usd: float = 0.0
+    latency_ms: int = 0
+    models_used: set[str] = field(default_factory=set)
+    providers_used: set[str] = field(default_factory=set)
+    tool_calls_count: int = 0
+    error_count: int = 0
+    success_count: int = 0
+
+
+class AggregateSummary(TypedDict, total=False):
+    """Aggregate metrics derived from call records."""
+
+    call_count: int
+    input_tokens: int
+    output_tokens: int
+    total_tokens: int
+    reasoning_tokens: int
+    cost_usd: float
+    latency_ms: int
+    models_used: list[str]
+    providers_used: list[str]
+    tool_calls_count: int
+    error_count: int
+    success_count: int
+    avg_latency_ms: float
+    avg_input_tokens: float
+    avg_output_tokens: float
+
+
+def compute_aggregates_from_call_records(call_records: list[LLMCallRecord]) -> AggregateSummary:
     """Compute aggregate statistics from a list of LLMCallRecord instances.
 
     Args:
@@ -197,65 +254,62 @@ def compute_aggregates_from_call_records(call_records: list[LLMCallRecord]) -> d
     Returns:
         Dictionary containing aggregated statistics
     """
-    aggregates = {
-        "input_tokens": 0,
-        "output_tokens": 0,
-        "total_tokens": 0,
-        "reasoning_tokens": 0,
-        "cost_usd": 0.0,
-        "latency_ms": 0,
-        "models_used": set(),
-        "providers_used": set(),
-        "tool_calls_count": 0,
-        "error_count": 0,
-        "success_count": 0,
-        "call_count": len(call_records),
-    }
+    aggregates = _AggregateAccumulator(call_count=len(call_records))
 
     for record in call_records:
         # Token aggregation
         if record.usage:
             if record.usage.input_tokens:
-                aggregates["input_tokens"] += record.usage.input_tokens
+                aggregates.input_tokens += record.usage.input_tokens
             if record.usage.output_tokens:
-                aggregates["output_tokens"] += record.usage.output_tokens
+                aggregates.output_tokens += record.usage.output_tokens
             if record.usage.total_tokens:
-                aggregates["total_tokens"] += record.usage.total_tokens
+                aggregates.total_tokens += record.usage.total_tokens
             if record.usage.reasoning_tokens:
-                aggregates["reasoning_tokens"] += record.usage.reasoning_tokens
+                aggregates.reasoning_tokens += record.usage.reasoning_tokens
             if record.usage.cost_usd:
-                aggregates["cost_usd"] += record.usage.cost_usd
+                aggregates.cost_usd += record.usage.cost_usd
 
         # Latency aggregation
-        if record.latency_ms:
-            aggregates["latency_ms"] += record.latency_ms
+        if record.latency_ms is not None:
+            aggregates.latency_ms += record.latency_ms
 
         # Model and provider tracking
         if record.model_name:
-            aggregates["models_used"].add(record.model_name)
+            aggregates.models_used.add(record.model_name)
         if record.provider:
-            aggregates["providers_used"].add(record.provider)
+            aggregates.providers_used.add(record.provider)
 
         # Tool calls
-        aggregates["tool_calls_count"] += len(record.output_tool_calls)
+        aggregates.tool_calls_count += len(record.output_tool_calls)
 
         # Success/error tracking
         if record.outcome == "error":
-            aggregates["error_count"] += 1
+            aggregates.error_count += 1
         elif record.outcome == "success":
-            aggregates["success_count"] += 1
-
-    # Convert sets to lists for JSON serialization
-    aggregates["models_used"] = list(aggregates["models_used"])
-    aggregates["providers_used"] = list(aggregates["providers_used"])
+            aggregates.success_count += 1
+
+    summary: AggregateSummary = {
+        "call_count": aggregates.call_count,
+        "input_tokens": aggregates.input_tokens,
+        "output_tokens": aggregates.output_tokens,
+        "total_tokens": aggregates.total_tokens,
+        "reasoning_tokens": aggregates.reasoning_tokens,
+        "cost_usd": aggregates.cost_usd,
+        "latency_ms": aggregates.latency_ms,
+        "models_used": list(aggregates.models_used),
+        "providers_used": list(aggregates.providers_used),
+        "tool_calls_count": aggregates.tool_calls_count,
+        "error_count": aggregates.error_count,
+        "success_count": aggregates.success_count,
+    }
 
-    # Compute averages
-    if aggregates["call_count"] > 0:
-        aggregates["avg_latency_ms"] = aggregates["latency_ms"] / aggregates["call_count"]
-        aggregates["avg_input_tokens"] = aggregates["input_tokens"] / aggregates["call_count"]
-        aggregates["avg_output_tokens"] = aggregates["output_tokens"] / aggregates["call_count"]
+    if aggregates.call_count > 0:
+        summary["avg_latency_ms"] = aggregates.latency_ms / aggregates.call_count
+        summary["avg_input_tokens"] = aggregates.input_tokens / aggregates.call_count
+        summary["avg_output_tokens"] = aggregates.output_tokens / aggregates.call_count
 
-    return aggregates
+    return summary
 
 
 def create_llm_call_record_from_streaming(
```
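A hedged usage sketch of the new `AggregateSummary` return type (the import path follows the file list above; `call_records` is assumed to come from an existing trace session):

```python
from synth_ai.tracing_v3.llm_call_record_helpers import (
    compute_aggregates_from_call_records,
)


def report(call_records: list) -> None:
    summary = compute_aggregates_from_call_records(call_records)
    print(f"{summary['call_count']} calls, {summary['total_tokens']} tokens")
    # AggregateSummary is declared total=False: the avg_* keys exist only
    # when call_count > 0, so .get() is the safe accessor here.
    avg = summary.get("avg_latency_ms")
    if avg is not None:
        print(f"avg latency: {avg:.1f} ms")
```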
synth_ai/tracing_v3/replica_sync.py

```diff
@@ -26,6 +26,7 @@ application to continue without blocking on sync operations.
 
 import asyncio
 import logging
+from typing import Any
 
 import libsql
 
@@ -66,8 +67,8 @@ class ReplicaSync:
         self.sync_url = sync_url or CONFIG.sync_url
         self.auth_token = auth_token or CONFIG.auth_token
         self.sync_interval = sync_interval or CONFIG.sync_interval
-        self._sync_task: asyncio.Task | None = None
-        self._conn: libsql.Connection | None = None
+        self._sync_task: asyncio.Task[Any] | None = None
+        self._conn: Any | None = None
 
     def _ensure_connection(self):
         """Ensure libsql connection is established.
@@ -113,8 +114,11 @@ class ReplicaSync:
         """
         try:
             self._ensure_connection()
+            conn = self._conn
+            if conn is None:
+                raise RuntimeError("Replica sync connection is not available after initialization")
             # Run sync in thread pool since libsql sync is blocking
-            await asyncio.to_thread(self._conn.sync)
+            await asyncio.to_thread(conn.sync)
             logger.info("Successfully synced with remote Turso database")
             return True
         except Exception as e:
@@ -146,7 +150,7 @@ class ReplicaSync:
             # Sleep until next sync interval
             await asyncio.sleep(self.sync_interval)
 
-    def start_background_sync(self) -> asyncio.Task:
+    def start_background_sync(self) -> asyncio.Task[Any]:
         """Start the background sync task.
 
         Creates an asyncio task that runs the sync loop. The task is stored
```
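The three added lines narrow the `Optional` connection into a local before handing it to `asyncio.to_thread`: the local binding lets a type checker prove the value is non-None at the await site and avoids racing with a concurrent reassignment of `self._conn`. A self-contained sketch of the same pattern (the `SupportsSync` protocol and `Syncer` class are illustrative, not from the package):

```python
import asyncio
from typing import Protocol


class SupportsSync(Protocol):
    def sync(self) -> None: ...


class Syncer:
    def __init__(self, conn: SupportsSync | None = None) -> None:
        self._conn = conn

    async def sync_once(self) -> bool:
        conn = self._conn  # narrow the Optional to a local binding
        if conn is None:
            raise RuntimeError("connection not initialized")
        # The blocking call runs off the event loop; `conn` cannot be None
        # here, and reassigning self._conn concurrently cannot change what
        # was handed to the worker thread.
        await asyncio.to_thread(conn.sync)
        return True
```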
synth_ai/tracing_v3/serialization.py (new file)

```diff
@@ -0,0 +1,130 @@
+"""HTTP-safe serialization helpers for tracing v3.
+
+These utilities normalize tracing structures (including dataclasses) into
+JSON-serializable forms and provide a compact JSON encoder suitable for
+HTTP transmission to backend services.
+
+Design goals:
+- Preserve structure while ensuring standard-compliant JSON (no NaN/Infinity)
+- Handle common non-JSON types: datetime, Decimal, bytes, set/tuple, numpy scalars
+- Keep output compact (no unnecessary whitespace) while readable if needed
+"""
+
+from __future__ import annotations
+
+import base64
+import json
+from dataclasses import asdict, is_dataclass
+from datetime import date, datetime
+from decimal import Decimal
+from enum import Enum
+from typing import Any
+
+try:
+    import numpy as _np  # type: ignore
+except Exception:  # pragma: no cover - numpy optional at runtime
+    _np = None  # type: ignore
+
+
+def normalize_for_json(value: Any) -> Any:
+    """Return a JSON-serializable version of ``value``.
+
+    Rules:
+    - dataclass → dict (recursively normalized)
+    - datetime/date → ISO-8601 string (UTC-aware datetimes preserve tzinfo)
+    - Decimal → float (fallback to string if not finite)
+    - bytes/bytearray → base64 string (RFC 4648)
+    - set/tuple → list
+    - Enum → enum.value (normalized)
+    - numpy scalar → corresponding Python scalar
+    - float NaN/Inf/−Inf → None (to keep JSON standard compliant)
+    - dict / list → recursively normalized
+    - other primitives (str, int, bool, None, float) passed through
+    """
+
+    # Dataclasses
+    if is_dataclass(value) and not isinstance(value, type):
+        try:
+            return normalize_for_json(asdict(value))
+        except Exception:
+            # Fallback: best-effort conversion via __dict__
+            return normalize_for_json(getattr(value, "__dict__", {}))
+
+    # Mapping
+    if isinstance(value, dict):
+        return {str(k): normalize_for_json(v) for k, v in value.items()}
+
+    # Sequences
+    if isinstance(value, (list, tuple, set)):
+        return [normalize_for_json(v) for v in value]
+
+    # Datetime / Date
+    if isinstance(value, (datetime, date)):
+        return value.isoformat()
+
+    # Decimal
+    if isinstance(value, Decimal):
+        try:
+            f = float(value)
+            if f != f or f in (float("inf"), float("-inf")):
+                return str(value)
+            return f
+        except Exception:
+            return str(value)
+
+    # Bytes-like
+    if isinstance(value, (bytes, bytearray)):
+        return base64.b64encode(bytes(value)).decode("ascii")
+
+    # Enum
+    if isinstance(value, Enum):
+        return normalize_for_json(value.value)
+
+    # Numpy scalars / arrays
+    if _np is not None:
+        if isinstance(value, (_np.generic,)):  # type: ignore[attr-defined]
+            return normalize_for_json(value.item())
+        if isinstance(value, (_np.ndarray,)):
+            return normalize_for_json(value.tolist())
+
+    # Floats: sanitize NaN / Infinity to None
+    if isinstance(value, float):
+        if value != value or value in (float("inf"), float("-inf")):
+            return None
+        return value
+
+    return value
+
+
+def dumps_http_json(payload: Any) -> str:
+    """Dump ``payload`` into a compact, HTTP-safe JSON string.
+
+    - Recursively normalizes non-JSON types (see ``normalize_for_json``)
+    - Disallows NaN/Infinity per RFC 8259 (allow_nan=False)
+    - Uses compact separators and preserves Unicode (ensure_ascii=False)
+    """
+
+    normalized = normalize_for_json(payload)
+    return json.dumps(
+        normalized,
+        ensure_ascii=False,
+        allow_nan=False,
+        separators=(",", ":"),
+    )
+
+
+def serialize_trace_for_http(trace: Any) -> str:
+    """Serialize a tracing v3 session (or dict-like) to HTTP-safe JSON.
+
+    Accepts either a dataclass (e.g., SessionTrace) or a dict/list and
+    applies normalization and compact JSON encoding.
+    """
+
+    if is_dataclass(trace) and not isinstance(trace, type):
+        try:
+            return dumps_http_json(asdict(trace))
+        except Exception:
+            return dumps_http_json(getattr(trace, "__dict__", {}))
+    return dumps_http_json(trace)
```
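A usage sketch for the new module (the import path comes from the file list above; the `Step` dataclass is invented for illustration):

```python
from dataclasses import dataclass
from datetime import UTC, datetime

from synth_ai.tracing_v3.serialization import dumps_http_json, normalize_for_json


@dataclass
class Step:
    name: str
    started_at: datetime
    score: float


step = Step("rollout", datetime.now(UTC), float("nan"))
print(normalize_for_json(step))
# -> {'name': 'rollout', 'started_at': '...', 'score': None} (NaN sanitized)
print(dumps_http_json({"steps": [step], "tags": {"rl", "eval"}}))
# -> compact JSON; the set becomes a list and the NaN becomes null
```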
synth_ai/tracing_v3/storage/utils.py

```diff
@@ -3,8 +3,8 @@
 import asyncio
 import functools
 import time
-from collections.abc import Callable
-from typing import Any, TypeVar
+from collections.abc import Awaitable, Callable
+from typing import Any, TypeVar, cast
 
 T = TypeVar("T")
 
@@ -18,9 +18,9 @@ def retry_async(max_attempts: int = 3, delay: float = 1.0, backoff: float = 2.0)
         backoff: Backoff multiplier for each retry
     """
 
-    def decorator(func: Callable[..., T]) -> Callable[..., T]:
+    def decorator(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
         @functools.wraps(func)
-        async def wrapper(*args, **kwargs):
+        async def wrapper(*args: Any, **kwargs: Any) -> T:
             last_exception: Exception | None = None
             current_delay = delay
 
@@ -171,13 +171,14 @@ STORAGE_METRICS = StorageMetrics()
 def track_metrics(operation: str):
     """Decorator to track storage operation metrics."""
 
-    def decorator(func: Callable[..., T]) -> Callable[..., T]:
+    def decorator(func: Callable[..., Awaitable[T]] | Callable[..., T]) -> Callable[..., Awaitable[T]] | Callable[..., T]:
         @functools.wraps(func)
-        async def async_wrapper(*args, **kwargs):
+        async def async_wrapper(*args: Any, **kwargs: Any) -> T:
             start_time = time.time()
             success = False
             try:
-                result = await func(*args, **kwargs)
+                async_func = cast(Callable[..., Awaitable[T]], func)
+                result = await async_func(*args, **kwargs)
                 success = True
                 return result
             finally:
@@ -185,11 +186,12 @@ def track_metrics(operation: str):
                 STORAGE_METRICS.record_operation(operation, duration, success)
 
         @functools.wraps(func)
-        def sync_wrapper(*args, **kwargs):
+        def sync_wrapper(*args: Any, **kwargs: Any) -> T:
             start_time = time.time()
             success = False
             try:
-                result = func(*args, **kwargs)
+                sync_func = cast(Callable[..., T], func)
+                result = sync_func(*args, **kwargs)
                 success = True
                 return result
             finally:
```
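A sketch of `retry_async` in use with its newly typed signature; the parameter names come from the hunk above, but the module path is inferred from the hunk line counts, so treat the import as an assumption:

```python
import asyncio

from synth_ai.tracing_v3.storage.utils import retry_async


@retry_async(max_attempts=3, delay=0.5, backoff=2.0)
async def flaky_write(payload: dict) -> bool:
    # Raising here triggers a retry after 0.5 s, then 1.0 s (backoff=2.0);
    # the decorated callable now type-checks as Callable[..., Awaitable[bool]].
    return True


asyncio.run(flaky_write({"k": "v"}))
```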
synth_ai/tracing_v3/turso/__init__.py (new file)

```diff
@@ -0,0 +1,12 @@
+"""Turso integration package for tracing v3."""
+
+from .daemon import SqldDaemon, get_daemon, start_sqld, stop_sqld
+from .native_manager import NativeLibsqlTraceManager
+
+__all__ = [
+    "SqldDaemon",
+    "NativeLibsqlTraceManager",
+    "get_daemon",
+    "start_sqld",
+    "stop_sqld",
+]
```
synth_ai/tracing_v3/turso/daemon.py

```diff
@@ -6,6 +6,7 @@ import subprocess
 import time
 
 import requests
+from requests import RequestException
 
 from ..config import CONFIG
 
@@ -79,7 +80,7 @@ class SqldDaemon:
             response = requests.get(health_url, timeout=1)
             if response.status_code == 200:
                 return
-        except requests.exceptions.RequestException:
+        except RequestException:
             pass
 
         # Check if process crashed
```
synth_ai/tracing_v3/turso/native_manager.py

```diff
@@ -11,18 +11,14 @@ import asyncio
 import json
 import logging
 import re
+from collections.abc import Callable
 from dataclasses import asdict, dataclass
 from datetime import UTC, datetime
-from typing import Any
+from typing import TYPE_CHECKING, Any, cast
 
 import libsql
 from sqlalchemy.engine import make_url
 
-try:  # pragma: no cover - exercised only when pandas present
-    import pandas as pd  # type: ignore
-except Exception:  # pragma: no cover
-    pd = None  # type: ignore[assignment]
-
 from ..abstractions import (
     EnvironmentEvent,
     LMCAISEvent,
@@ -34,6 +30,24 @@ from ..config import CONFIG
 from ..storage.base import TraceStorage
 from .models import analytics_views
 
+if TYPE_CHECKING:
+    from sqlite3 import Connection as LibsqlConnection
+else:  # pragma: no cover - runtime fallback for typing only
+    LibsqlConnection = Any  # type: ignore[assignment]
+
+_LIBSQL_CONNECT_ATTR = getattr(libsql, "connect", None)
+if _LIBSQL_CONNECT_ATTR is None:  # pragma: no cover - defensive guard
+    raise RuntimeError("libsql.connect is required for NativeLibsqlTraceManager")
+_libsql_connect: Callable[..., LibsqlConnection] = cast(
+    Callable[..., LibsqlConnection],
+    _LIBSQL_CONNECT_ATTR,
+)
+
+try:  # pragma: no cover - exercised only when pandas present
+    import pandas as pd  # type: ignore
+except Exception:  # pragma: no cover
+    pd = None  # type: ignore[assignment]
+
 logger = logging.getLogger(__name__)
 
 
@@ -66,9 +80,8 @@ def _resolve_connection_target(db_url: str | None, auth_token: str | None) -> _C
     # Fallback to SQLAlchemy URL parsing for anything else we missed.
     try:
         parsed = make_url(url)
-        if parsed.drivername.startswith("sqlite"):
-            if parsed.database:
-                return _ConnectionTarget(database=parsed.database, auth_token=auth_token)
+        if parsed.drivername.startswith("sqlite") and parsed.database:
+            return _ConnectionTarget(database=parsed.database, auth_token=auth_token)
         if parsed.drivername.startswith("libsql"):
             database = parsed.render_as_string(hide_password=False)
             return _ConnectionTarget(database=database, sync_url=database, auth_token=auth_token)
@@ -314,12 +327,12 @@ class NativeLibsqlTraceManager(TraceStorage):
     ):
         self._config_auth_token = auth_token
         self._target = _resolve_connection_target(db_url, auth_token)
-        self._conn: libsql.Connection | None = None
+        self._conn: LibsqlConnection | None = None
         self._conn_lock = asyncio.Lock()
         self._op_lock = asyncio.Lock()
         self._initialized = False
 
-    def _open_connection(self) -> libsql.Connection:
+    def _open_connection(self) -> LibsqlConnection:
         """Open a libsql connection for the resolved target."""
         kwargs: dict[str, Any] = {}
         if self._target.sync_url and self._target.sync_url.startswith("libsql://"):
@@ -329,7 +342,7 @@ class NativeLibsqlTraceManager(TraceStorage):
         # Disable automatic background sync; ReplicaSync drives this explicitly.
         kwargs.setdefault("sync_interval", 0)
         logger.debug("Opening libsql connection to %s", self._target.database)
-        return libsql.connect(self._target.database, **kwargs)
+        return _libsql_connect(self._target.database, **kwargs)
 
     async def initialize(self):
         """Initialise the backend."""
@@ -493,7 +506,7 @@ class NativeLibsqlTraceManager(TraceStorage):
             return None
 
         session_columns = ["session_id", "created_at", "num_timesteps", "num_events", "num_messages", "metadata"]
-        session_data = dict(zip(session_columns, session_row))
+        session_data = dict(zip(session_columns, session_row, strict=True))
 
         timestep_cursor = conn.execute(
             """
@@ -608,10 +621,10 @@ class NativeLibsqlTraceManager(TraceStorage):
 
         if not rows:
            if pd is not None:
-                return pd.DataFrame(columns=[col for col in columns])
+                return pd.DataFrame(columns=list(columns))
             return []
 
-        records = [dict(zip(columns, row)) for row in rows]
+        records = [dict(zip(columns, row, strict=True)) for row in rows]
         if pd is not None:
             return pd.DataFrame(records)
         return records
```
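The `strict=True` additions are the behavioral change here: since Python 3.10, `zip(..., strict=True)` raises instead of silently truncating, so a row/column count mismatch now fails loudly rather than producing a partially populated dict. For example:

```python
columns = ["session_id", "created_at"]
row = ("abc123", "2024-01-01T00:00:00Z", "extra")

print(dict(zip(columns, row)))  # silently drops "extra"
try:
    dict(zip(columns, row, strict=True))
except ValueError as err:
    print(err)  # zip() argument 2 is longer than argument 1
```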
synth_ai-0.2.13.dev1.dist-info/METADATA

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: synth-ai
-Version: 0.2.10
+Version: 0.2.13.dev1
 Summary: RL as a service SDK - Core AI functionality and tracing
 Author-email: Synth AI <josh@usesynth.ai>
 License-Expression: MIT
@@ -79,7 +79,7 @@ Dynamic: license-file
 
 [![Python](https://img.shields.io/badge/python-3.11+-blue)](https://www.python.org/)
 [![License](https://img.shields.io/badge/license-MIT-green)](LICENSE)
-[![PyPI](https://img.shields.io/badge/PyPI-0.2.4.dev9-orange)](https://pypi.org/project/synth-ai/)
+[![PyPI](https://img.shields.io/badge/PyPI-0.2.10-orange)](https://pypi.org/project/synth-ai/)
 ![Coverage](https://img.shields.io/badge/coverage-9.09%25-red)
 ![Tests](https://img.shields.io/badge/tests-37%2F38%20passing-brightgreen)
 ![Blacksmith CI](https://img.shields.io/badge/CI-Blacksmith%20Worker-blue)
@@ -88,6 +88,8 @@ Docs: [Synth‑AI Documentation](https://docs.usesynth.ai/welcome/introduction)
 
 Fast and effective reinforcement learning for agents, via an API
 
+> Latest: 0.2.10 published to PyPI (uv publish)
+
 ## Highlights
 
 - Easily scale gpu topologies - train on 3 a10gs or 8 H100s (multi-node available upon request)
```