synth-ai 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of synth-ai has been flagged as possibly problematic.
- examples/multi_step/task_app_config_notes.md +488 -0
- examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +33 -0
- examples/warming_up_to_rl/configs/eval_stepwise_consistent.toml +26 -0
- examples/warming_up_to_rl/configs/eval_stepwise_per_achievement.toml +36 -0
- examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +30 -0
- examples/warming_up_to_rl/run_eval.py +142 -25
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +146 -2
- synth_ai/api/train/builders.py +25 -14
- synth_ai/api/train/cli.py +29 -6
- synth_ai/api/train/env_resolver.py +18 -19
- synth_ai/api/train/supported_algos.py +8 -5
- synth_ai/api/train/utils.py +6 -1
- synth_ai/cli/__init__.py +4 -2
- synth_ai/cli/_storage.py +19 -0
- synth_ai/cli/balance.py +14 -2
- synth_ai/cli/calc.py +37 -22
- synth_ai/cli/legacy_root_backup.py +12 -14
- synth_ai/cli/recent.py +12 -7
- synth_ai/cli/status.py +4 -3
- synth_ai/cli/task_apps.py +143 -137
- synth_ai/cli/traces.py +4 -3
- synth_ai/cli/watch.py +3 -2
- synth_ai/jobs/client.py +15 -3
- synth_ai/task/server.py +14 -7
- synth_ai/tracing_v3/decorators.py +51 -26
- synth_ai/tracing_v3/examples/basic_usage.py +12 -7
- synth_ai/tracing_v3/llm_call_record_helpers.py +107 -53
- synth_ai/tracing_v3/replica_sync.py +8 -4
- synth_ai/tracing_v3/storage/utils.py +11 -9
- synth_ai/tracing_v3/turso/__init__.py +12 -0
- synth_ai/tracing_v3/turso/daemon.py +2 -1
- synth_ai/tracing_v3/turso/native_manager.py +28 -15
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/METADATA +4 -2
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/RECORD +38 -31
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/entry_points.txt +0 -0
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.12.dist-info}/top_level.txt +0 -0
synth_ai/task/server.py
CHANGED
@@ -6,6 +6,7 @@ import asyncio
 import inspect
 import os
 from collections.abc import Awaitable, Callable, Iterable, Mapping, MutableMapping, Sequence
+from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any
@@ -34,6 +35,10 @@ InstanceProvider = Callable[[Sequence[int]], Iterable[TaskInfo] | Awaitable[Iter
 RolloutExecutor = Callable[[RolloutRequest, Request], Any | Awaitable[Any]]
 
 
+def _default_app_state() -> dict[str, Any]:
+    return {}
+
+
 @dataclass(slots=True)
 class RubricBundle:
     """Optional rubrics advertised by the task app."""
@@ -69,7 +74,7 @@ class TaskAppConfig:
     proxy: ProxyConfig | None = None
     routers: Sequence[APIRouter] = field(default_factory=tuple)
     middleware: Sequence[Middleware] = field(default_factory=tuple)
-    app_state:
+    app_state: MutableMapping[str, Any] = field(default_factory=_default_app_state)
     require_api_key: bool = True
     expose_debug_env: bool = True
     cors_origins: Sequence[str] | None = None
@@ -260,17 +265,19 @@ def create_task_app(config: TaskAppConfig) -> FastAPI:
             return _maybe_await(hook(app))  # type: ignore[misc]
         return _maybe_await(hook())
 
-    @
-    async def
+    @asynccontextmanager
+    async def lifespan(_: FastAPI):
         normalize_environment_api_key()
         normalize_vendor_keys()
         for hook in cfg.startup_hooks:
             await _call_hook(hook)
+        try:
+            yield
+        finally:
+            for hook in cfg.shutdown_hooks:
+                await _call_hook(hook)
 
-
-    async def _shutdown() -> None:  # pragma: no cover - FastAPI lifecycle
-        for hook in cfg.shutdown_hooks:
-            await _call_hook(hook)
+    app.router.lifespan_context = lifespan
 
     @app.get("/")
     async def root() -> Mapping[str, Any]:
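Note: the change above migrates the task app from separate startup/shutdown callbacks to FastAPI's lifespan protocol. A minimal sketch of the same pattern, assuming nothing beyond FastAPI itself (hook bodies here are illustrative placeholders):

    from contextlib import asynccontextmanager

    from fastapi import FastAPI


    @asynccontextmanager
    async def lifespan(_: FastAPI):
        # Startup: runs once before the app starts serving requests.
        print("startup hooks run here")
        try:
            yield  # the app serves requests while suspended at this point
        finally:
            # Shutdown: runs once the server stops, even if a request handler raised.
            print("shutdown hooks run here")


    app = FastAPI(lifespan=lifespan)

Passing lifespan= at construction is the documented FastAPI route; the diff instead assigns app.router.lifespan_context = lifespan, which wires up the same underlying Starlette mechanism on an app object that already exists.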
synth_ai/tracing_v3/decorators.py
CHANGED

@@ -28,8 +28,8 @@ import asyncio
 import contextvars
 import functools
 import time
-from collections.abc import Callable
-from typing import Any, TypeVar
+from collections.abc import Awaitable, Callable, Mapping
+from typing import Any, TypeVar, cast, overload
 
 from .abstractions import LMCAISEvent, TimeRecord
 from .utils import calculate_cost, detect_provider
@@ -88,6 +88,16 @@ def get_session_tracer() -> Any:
 T = TypeVar("T")
 
 
+@overload
+def with_session(require: bool = True) -> Callable[[Callable[..., Awaitable[T]]], Callable[..., Awaitable[T]]]:
+    ...
+
+
+@overload
+def with_session(require: bool = True) -> Callable[[Callable[..., T]], Callable[..., T]]:
+    ...
+
+
 def with_session(require: bool = True):
     """Decorator that ensures a session is active.
 
@@ -109,29 +119,31 @@ def with_session(require: bool = True):
     ```
     """
 
-    def decorator(fn: Callable[..., T]) -> Callable[..., T]:
+    def decorator(fn: Callable[..., Awaitable[T]] | Callable[..., T]) -> Callable[..., Awaitable[T]] | Callable[..., T]:
         if asyncio.iscoroutinefunction(fn):
 
             @functools.wraps(fn)
-            async def async_wrapper(*args, **kwargs):
+            async def async_wrapper(*args: Any, **kwargs: Any) -> T:
                 session_id = get_session_id()
                 if require and session_id is None:
                     raise RuntimeError(
                         f"No active session for {getattr(fn, '__name__', 'unknown')}"
                     )
-
+                async_fn = cast(Callable[..., Awaitable[T]], fn)
+                return await async_fn(*args, **kwargs)
 
             return async_wrapper
         else:
 
             @functools.wraps(fn)
-            def sync_wrapper(*args, **kwargs):
+            def sync_wrapper(*args: Any, **kwargs: Any) -> T:
                 session_id = get_session_id()
                 if require and session_id is None:
                     raise RuntimeError(
                         f"No active session for {getattr(fn, '__name__', 'unknown')}"
                     )
-
+                sync_fn = cast(Callable[..., T], fn)
+                return sync_fn(*args, **kwargs)
 
             return sync_wrapper
 
@@ -172,31 +184,36 @@ def trace_llm_call(
     ```
     """
 
-    def decorator(fn: Callable[..., T]) -> Callable[..., T]:
+    def decorator(fn: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
         if asyncio.iscoroutinefunction(fn):
+            async_fn: Callable[..., Awaitable[T]] = fn
 
             @functools.wraps(fn)
-            async def async_wrapper(*args, **kwargs):
+            async def async_wrapper(*args: Any, **kwargs: Any) -> T:
                 tracer = get_session_tracer()
                 if not tracer:
-                    return await
+                    return await async_fn(*args, **kwargs)
 
                 start_time = time.time()
                 system_state_before = kwargs.get("state_before", {})
 
                 try:
-                    result = await
+                    result = await async_fn(*args, **kwargs)
 
                     # Extract metrics from result - this assumes the result follows
                     # common LLM API response formats (OpenAI, Anthropic, etc.)
-
-
-
-
-
-
-
-
+                    input_tokens = output_tokens = total_tokens = None
+                    actual_model = model_name
+                    if extract_tokens and isinstance(result, Mapping):
+                        result_mapping = cast(Mapping[str, Any], result)
+                        usage = result_mapping.get("usage")
+                        if isinstance(usage, Mapping):
+                            input_tokens = usage.get("prompt_tokens")
+                            output_tokens = usage.get("completion_tokens")
+                            total_tokens = usage.get("total_tokens")
+                        value = result_mapping.get("model")
+                        if isinstance(value, str):
+                            actual_model = value
 
                     latency_ms = int((time.time() - start_time) * 1000)
 
@@ -272,19 +289,26 @@ def trace_method(event_type: str = "runtime", system_id: str | None = None):
     ```
     """
 
-    def decorator(
+    def decorator(
+        fn: Callable[..., Awaitable[T]] | Callable[..., T]
+    ) -> Callable[..., Awaitable[T]] | Callable[..., T]:
         if asyncio.iscoroutinefunction(fn):
+            async_fn = cast(Callable[..., Awaitable[T]], fn)
 
             @functools.wraps(fn)
-            async def async_wrapper(
+            async def async_wrapper(*args: Any, **kwargs: Any) -> T:
                 tracer = get_session_tracer()
                 if not tracer:
-                    return await
+                    return await async_fn(*args, **kwargs)
 
                 from .abstractions import RuntimeEvent
 
                 # Use class name as system_id if not provided
-
+                self_obj = args[0] if args else None
+                inferred_system_id = (
+                    self_obj.__class__.__name__ if self_obj is not None else "unknown"
+                )
+                actual_system_id = system_id or inferred_system_id
 
                 event = RuntimeEvent(
                     system_instance_id=actual_system_id,
@@ -298,17 +322,18 @@ def trace_method(event_type: str = "runtime", system_id: str | None = None):
                 )
 
                 await tracer.record_event(event)
-                return await
+                return await async_fn(*args, **kwargs)
 
             return async_wrapper
         else:
 
             @functools.wraps(fn)
-            def sync_wrapper(
+            def sync_wrapper(*args: Any, **kwargs: Any) -> T:
                 # For sync methods, we can't easily trace without blocking
                 # the event loop. This is a limitation of the async-first design.
                 # Consider converting to async or using a different approach
-
+                sync_fn = cast(Callable[..., T], fn)
+                return sync_fn(*args, **kwargs)
 
             return sync_wrapper
 
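The decorator changes above follow one pattern throughout: branch on asyncio.iscoroutinefunction, then use typing.cast so type checkers keep the precise return type T on both paths. A standalone sketch of that pattern (illustrative names, not synth-ai's actual code):

    import asyncio
    import functools
    from collections.abc import Awaitable, Callable
    from typing import Any, TypeVar, cast

    T = TypeVar("T")


    def log_calls(fn: Callable[..., Awaitable[T]] | Callable[..., T]) -> Callable[..., Awaitable[T]] | Callable[..., T]:
        """Wrap sync and async callables alike, preserving T for type checkers."""
        if asyncio.iscoroutinefunction(fn):
            # Tell the checker which half of the union applies on this branch.
            async_fn = cast(Callable[..., Awaitable[T]], fn)

            @functools.wraps(fn)
            async def async_wrapper(*args: Any, **kwargs: Any) -> T:
                print(f"calling {fn.__name__}")
                return await async_fn(*args, **kwargs)

            return async_wrapper

        sync_fn = cast(Callable[..., T], fn)

        @functools.wraps(fn)
        def sync_wrapper(*args: Any, **kwargs: Any) -> T:
            print(f"calling {fn.__name__}")
            return sync_fn(*args, **kwargs)

        return sync_wrapper

cast is a no-op at runtime; it only resolves the union for the checker inside each branch, which is why the diff can add it without changing behavior.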
synth_ai/tracing_v3/examples/basic_usage.py
CHANGED

@@ -2,13 +2,14 @@
 
 import asyncio
 import time
+from typing import Any
 
-from
-from
-from
+from .. import SessionTracer
+from ..abstractions import EnvironmentEvent, LMCAISEvent, RuntimeEvent, TimeRecord
+from ..turso.daemon import SqldDaemon
 
 
-async def simulate_llm_call(model: str, prompt: str) -> dict:
+async def simulate_llm_call(model: str, prompt: str) -> dict[str, Any]:
     """Simulate an LLM API call."""
     await asyncio.sleep(0.1)  # Simulate network latency
 
@@ -133,6 +134,9 @@ async def main():
     print("\n--- Example 3: Querying Data ---")
 
     # Get model usage statistics
+    if tracer.db is None:
+        raise RuntimeError("Tracer database backend is not initialized")
+
     model_usage = await tracer.db.get_model_usage()
     print("\nModel Usage:")
     print(model_usage)
@@ -150,9 +154,10 @@ async def main():
     # Get specific session details
     if recent_sessions:
         session_detail = await tracer.db.get_session_trace(recent_sessions[0]["session_id"])
-
-
-
+        if session_detail:
+            print(f"\nSession Detail for {session_detail['session_id']}:")
+            print(f"  Created: {session_detail['created_at']}")
+            print(f"  Timesteps: {len(session_detail['timesteps'])}")
 
     # Example 4: Using hooks
     print("\n--- Example 4: Hooks ---")
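The guard added to the example is the standard Optional-narrowing idiom: raise early when an optional backend is missing so every later attribute access is safe and type-checks. A tiny sketch with a stand-in Store class:

    class Store:
        def query(self) -> str:
            return "rows"


    db: Store | None = None  # e.g. backend not yet initialized

    if db is None:
        raise RuntimeError("database backend is not initialized")
    # Past the raise, type checkers narrow db from Store | None to Store:
    print(db.query())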
synth_ai/tracing_v3/llm_call_record_helpers.py
CHANGED

@@ -4,11 +4,14 @@ This module provides utilities to convert vendor responses to LLMCallRecord
 format and compute aggregates from call records.
 """
 
+from __future__ import annotations
+
 import uuid
+from dataclasses import dataclass, field
 from datetime import UTC, datetime
-from typing import Any
+from typing import Any, TypedDict, cast
 
-from
+from .lm_call_record_abstractions import (
     LLMCallRecord,
     LLMChunk,
     LLMContentPart,
@@ -17,7 +20,21 @@ from synth_ai.tracing_v3.lm_call_record_abstractions import (
     LLMUsage,
     ToolCallSpec,
 )
-
+
+BaseLMResponse = Any
+
+
+class _UsageDict(TypedDict, total=False):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+    reasoning_tokens: int
+    cost_usd: float
+    duration_ms: int
+    reasoning_input_tokens: int
+    reasoning_output_tokens: int
+    cache_write_tokens: int
+    cache_read_tokens: int
 
 
 def create_llm_call_record_from_response(
@@ -110,9 +127,10 @@ def create_llm_call_record_from_response(
     )
 
     # Extract tool calls if present
-    output_tool_calls = []
-
-
+    output_tool_calls: list[ToolCallSpec] = []
+    tool_calls_data = cast(list[dict[str, Any]] | None, getattr(response, "tool_calls", None))
+    if tool_calls_data:
+        for idx, tool_call in enumerate(tool_calls_data):
             if isinstance(tool_call, dict):
                 output_tool_calls.append(
                     ToolCallSpec(
@@ -125,18 +143,19 @@ def create_llm_call_record_from_response(
 
     # Extract usage information
     usage = None
-
+    usage_data = cast(_UsageDict | None, getattr(response, "usage", None))
+    if usage_data:
         usage = LLMUsage(
-            input_tokens=
-            output_tokens=
-            total_tokens=
-            cost_usd=
+            input_tokens=usage_data.get("input_tokens"),
+            output_tokens=usage_data.get("output_tokens"),
+            total_tokens=usage_data.get("total_tokens"),
+            cost_usd=usage_data.get("cost_usd"),
             # Additional token accounting if available
-            reasoning_tokens=
-            reasoning_input_tokens=
-            reasoning_output_tokens=
-            cache_write_tokens=
-            cache_read_tokens=
+            reasoning_tokens=usage_data.get("reasoning_tokens"),
+            reasoning_input_tokens=usage_data.get("reasoning_input_tokens"),
+            reasoning_output_tokens=usage_data.get("reasoning_output_tokens"),
+            cache_write_tokens=usage_data.get("cache_write_tokens"),
+            cache_read_tokens=usage_data.get("cache_read_tokens"),
         )
 
     # Build request parameters
@@ -188,7 +207,45 @@ def create_llm_call_record_from_response(
     return record
 
 
-
+@dataclass
+class _AggregateAccumulator:
+    """Mutable accumulator for call record aggregates."""
+
+    call_count: int = 0
+    input_tokens: int = 0
+    output_tokens: int = 0
+    total_tokens: int = 0
+    reasoning_tokens: int = 0
+    cost_usd: float = 0.0
+    latency_ms: int = 0
+    models_used: set[str] = field(default_factory=set)
+    providers_used: set[str] = field(default_factory=set)
+    tool_calls_count: int = 0
+    error_count: int = 0
+    success_count: int = 0
+
+
+class AggregateSummary(TypedDict, total=False):
+    """Aggregate metrics derived from call records."""
+
+    call_count: int
+    input_tokens: int
+    output_tokens: int
+    total_tokens: int
+    reasoning_tokens: int
+    cost_usd: float
+    latency_ms: int
+    models_used: list[str]
+    providers_used: list[str]
+    tool_calls_count: int
+    error_count: int
+    success_count: int
+    avg_latency_ms: float
+    avg_input_tokens: float
+    avg_output_tokens: float
+
+
+def compute_aggregates_from_call_records(call_records: list[LLMCallRecord]) -> AggregateSummary:
     """Compute aggregate statistics from a list of LLMCallRecord instances.
 
     Args:
@@ -197,65 +254,62 @@ def compute_aggregates_from_call_records(call_records: list[LLMCallRecord]) -> d
     Returns:
         Dictionary containing aggregated statistics
     """
-    aggregates =
-        "input_tokens": 0,
-        "output_tokens": 0,
-        "total_tokens": 0,
-        "reasoning_tokens": 0,
-        "cost_usd": 0.0,
-        "latency_ms": 0,
-        "models_used": set(),
-        "providers_used": set(),
-        "tool_calls_count": 0,
-        "error_count": 0,
-        "success_count": 0,
-        "call_count": len(call_records),
-    }
+    aggregates = _AggregateAccumulator(call_count=len(call_records))
 
     for record in call_records:
         # Token aggregation
         if record.usage:
            if record.usage.input_tokens:
-                aggregates
+                aggregates.input_tokens += record.usage.input_tokens
            if record.usage.output_tokens:
-                aggregates
+                aggregates.output_tokens += record.usage.output_tokens
            if record.usage.total_tokens:
-                aggregates
+                aggregates.total_tokens += record.usage.total_tokens
            if record.usage.reasoning_tokens:
-                aggregates
+                aggregates.reasoning_tokens += record.usage.reasoning_tokens
            if record.usage.cost_usd:
-                aggregates
+                aggregates.cost_usd += record.usage.cost_usd
 
         # Latency aggregation
-        if record.latency_ms:
-            aggregates
+        if record.latency_ms is not None:
+            aggregates.latency_ms += record.latency_ms
 
         # Model and provider tracking
         if record.model_name:
-            aggregates
+            aggregates.models_used.add(record.model_name)
         if record.provider:
-            aggregates
+            aggregates.providers_used.add(record.provider)
 
         # Tool calls
-        aggregates
+        aggregates.tool_calls_count += len(record.output_tool_calls)
 
         # Success/error tracking
         if record.outcome == "error":
-            aggregates
+            aggregates.error_count += 1
         elif record.outcome == "success":
-            aggregates
-
-
-
+            aggregates.success_count += 1
+
+    summary: AggregateSummary = {
+        "call_count": aggregates.call_count,
+        "input_tokens": aggregates.input_tokens,
+        "output_tokens": aggregates.output_tokens,
+        "total_tokens": aggregates.total_tokens,
+        "reasoning_tokens": aggregates.reasoning_tokens,
+        "cost_usd": aggregates.cost_usd,
+        "latency_ms": aggregates.latency_ms,
+        "models_used": list(aggregates.models_used),
+        "providers_used": list(aggregates.providers_used),
+        "tool_calls_count": aggregates.tool_calls_count,
+        "error_count": aggregates.error_count,
+        "success_count": aggregates.success_count,
+    }
 
-
-
-
-
-        aggregates["avg_output_tokens"] = aggregates["output_tokens"] / aggregates["call_count"]
+    if aggregates.call_count > 0:
+        summary["avg_latency_ms"] = aggregates.latency_ms / aggregates.call_count
+        summary["avg_input_tokens"] = aggregates.input_tokens / aggregates.call_count
+        summary["avg_output_tokens"] = aggregates.output_tokens / aggregates.call_count
 
-    return
+    return summary
 
 
 def create_llm_call_record_from_streaming(
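The rewrite above trades a loosely typed dict for a dataclass accumulator plus a TypedDict summary; total=False is what lets the avg_* keys be absent when there are no calls. A self-contained sketch of that shape (field names abbreviated, the record type stubbed out; not synth-ai's actual code):

    from dataclasses import dataclass, field
    from typing import TypedDict


    @dataclass
    class _Acc:
        # Mutable running totals; attribute access replaces stringly-keyed dict access.
        call_count: int = 0
        input_tokens: int = 0
        models_used: set[str] = field(default_factory=set)


    class Summary(TypedDict, total=False):
        call_count: int
        input_tokens: int
        models_used: list[str]
        avg_input_tokens: float  # optional: only present when call_count > 0


    def summarize(token_counts: list[int], model: str) -> Summary:
        acc = _Acc(call_count=len(token_counts))
        for n in token_counts:
            acc.input_tokens += n
            acc.models_used.add(model)
        out: Summary = {
            "call_count": acc.call_count,
            "input_tokens": acc.input_tokens,
            "models_used": sorted(acc.models_used),
        }
        if acc.call_count:  # averages only make sense with at least one call
            out["avg_input_tokens"] = acc.input_tokens / acc.call_count
        return out


    print(summarize([10, 20, 30], "gpt-test"))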
synth_ai/tracing_v3/replica_sync.py
CHANGED

@@ -26,6 +26,7 @@ application to continue without blocking on sync operations.
 
 import asyncio
 import logging
+from typing import Any
 
 import libsql
 
@@ -66,8 +67,8 @@ class ReplicaSync:
         self.sync_url = sync_url or CONFIG.sync_url
         self.auth_token = auth_token or CONFIG.auth_token
         self.sync_interval = sync_interval or CONFIG.sync_interval
-        self._sync_task: asyncio.Task | None = None
-        self._conn:
+        self._sync_task: asyncio.Task[Any] | None = None
+        self._conn: Any | None = None
 
     def _ensure_connection(self):
         """Ensure libsql connection is established.
@@ -113,8 +114,11 @@ class ReplicaSync:
         """
         try:
            self._ensure_connection()
+            conn = self._conn
+            if conn is None:
+                raise RuntimeError("Replica sync connection is not available after initialization")
             # Run sync in thread pool since libsql sync is blocking
-            await asyncio.to_thread(
+            await asyncio.to_thread(conn.sync)
             logger.info("Successfully synced with remote Turso database")
             return True
         except Exception as e:
@@ -146,7 +150,7 @@ class ReplicaSync:
             # Sleep until next sync interval
             await asyncio.sleep(self.sync_interval)
 
-    def start_background_sync(self) -> asyncio.Task:
+    def start_background_sync(self) -> asyncio.Task[Any]:
         """Start the background sync task.
 
         Creates an asyncio task that runs the sync loop. The task is stored
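The sync path above narrows the optional connection and then hands the bound conn.sync method to asyncio.to_thread, so the blocking libsql call runs in a worker thread instead of stalling the event loop. A minimal sketch of the same offloading, with a sleep standing in for the blocking sync:

    import asyncio
    import time


    def blocking_sync() -> None:
        # Stand-in for a blocking call such as libsql's conn.sync().
        time.sleep(0.5)


    async def main() -> None:
        # to_thread runs the callable in a thread pool; the loop stays responsive.
        await asyncio.to_thread(blocking_sync)
        print("synced without blocking the loop")


    asyncio.run(main())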
synth_ai/tracing_v3/storage/utils.py
CHANGED

@@ -3,8 +3,8 @@
 import asyncio
 import functools
 import time
-from collections.abc import Callable
-from typing import Any, TypeVar
+from collections.abc import Awaitable, Callable
+from typing import Any, TypeVar, cast
 
 T = TypeVar("T")
 
@@ -18,9 +18,9 @@ def retry_async(max_attempts: int = 3, delay: float = 1.0, backoff: float = 2.0)
         backoff: Backoff multiplier for each retry
     """
 
-    def decorator(func: Callable[..., T]) -> Callable[..., T]:
+    def decorator(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
         @functools.wraps(func)
-        async def wrapper(*args, **kwargs):
+        async def wrapper(*args: Any, **kwargs: Any) -> T:
             last_exception: Exception | None = None
             current_delay = delay
 
@@ -171,13 +171,14 @@ STORAGE_METRICS = StorageMetrics()
 def track_metrics(operation: str):
     """Decorator to track storage operation metrics."""
 
-    def decorator(func: Callable[..., T]) -> Callable[..., T]:
+    def decorator(func: Callable[..., Awaitable[T]] | Callable[..., T]) -> Callable[..., Awaitable[T]] | Callable[..., T]:
         @functools.wraps(func)
-        async def async_wrapper(*args, **kwargs):
+        async def async_wrapper(*args: Any, **kwargs: Any) -> T:
             start_time = time.time()
             success = False
             try:
-
+                async_func = cast(Callable[..., Awaitable[T]], func)
+                result = await async_func(*args, **kwargs)
                 success = True
                 return result
             finally:
@@ -185,11 +186,12 @@ def track_metrics(operation: str):
                 STORAGE_METRICS.record_operation(operation, duration, success)
 
         @functools.wraps(func)
-        def sync_wrapper(*args, **kwargs):
+        def sync_wrapper(*args: Any, **kwargs: Any) -> T:
             start_time = time.time()
             success = False
             try:
-
+                sync_func = cast(Callable[..., T], func)
+                result = sync_func(*args, **kwargs)
                 success = True
                 return result
             finally:
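With the tightened signature, retry_async now only accepts coroutine functions. A hypothetical usage sketch, assuming the decorator re-raises once max_attempts is exhausted (the flaky function is invented for illustration; the parameter names come from the signature shown in the hunk header above):

    import asyncio
    import random

    from synth_ai.tracing_v3.storage.utils import retry_async


    @retry_async(max_attempts=3, delay=0.1, backoff=2.0)
    async def flaky_write() -> str:
        # Fails roughly half the time to exercise the retry/backoff path.
        if random.random() < 0.5:
            raise ConnectionError("transient storage error")
        return "ok"


    print(asyncio.run(flaky_write()))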
synth_ai/tracing_v3/turso/__init__.py
ADDED

@@ -0,0 +1,12 @@
+"""Turso integration package for tracing v3."""
+
+from .daemon import SqldDaemon, get_daemon, start_sqld, stop_sqld
+from .native_manager import NativeLibsqlTraceManager
+
+__all__ = [
+    "SqldDaemon",
+    "NativeLibsqlTraceManager",
+    "get_daemon",
+    "start_sqld",
+    "stop_sqld",
+]
synth_ai/tracing_v3/turso/daemon.py
CHANGED

@@ -6,6 +6,7 @@ import subprocess
 import time
 
 import requests
+from requests import RequestException
 
 from ..config import CONFIG
 
@@ -79,7 +80,7 @@ class SqldDaemon:
                 response = requests.get(health_url, timeout=1)
                 if response.status_code == 200:
                     return
-            except
+            except RequestException:
                 pass
 
         # Check if process crashed