synth-ai 0.2.10__py3-none-any.whl → 0.2.13.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of synth-ai might be problematic. Click here for more details.
- examples/agora_ex/README_MoE.md +224 -0
- examples/agora_ex/__init__.py +7 -0
- examples/agora_ex/agora_ex.py +65 -0
- examples/agora_ex/agora_ex_task_app.py +590 -0
- examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +121 -0
- examples/agora_ex/reward_fn_grpo-human.py +129 -0
- examples/agora_ex/system_prompt_CURRENT.md +63 -0
- examples/agora_ex/task_app/agora_ex_task_app.py +590 -0
- examples/agora_ex/task_app/reward_fn_grpo-human.py +129 -0
- examples/agora_ex/task_app/system_prompt_CURRENT.md +63 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +74 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +175 -0
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +83 -0
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +78 -0
- examples/multi_step/crafter_rl_lora.md +51 -10
- examples/multi_step/sse_metrics_streaming_notes.md +357 -0
- examples/multi_step/task_app_config_notes.md +494 -0
- examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +35 -0
- examples/warming_up_to_rl/configs/eval_stepwise_consistent.toml +26 -0
- examples/warming_up_to_rl/configs/eval_stepwise_per_achievement.toml +36 -0
- examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +32 -0
- examples/warming_up_to_rl/run_eval.py +267 -41
- examples/warming_up_to_rl/task_app/grpo_crafter.py +3 -33
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +109 -45
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +42 -46
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +376 -193
- synth_ai/__init__.py +41 -1
- synth_ai/api/train/builders.py +74 -33
- synth_ai/api/train/cli.py +29 -6
- synth_ai/api/train/configs/__init__.py +44 -0
- synth_ai/api/train/configs/rl.py +133 -0
- synth_ai/api/train/configs/sft.py +94 -0
- synth_ai/api/train/configs/shared.py +24 -0
- synth_ai/api/train/env_resolver.py +18 -19
- synth_ai/api/train/supported_algos.py +8 -5
- synth_ai/api/train/utils.py +6 -1
- synth_ai/cli/__init__.py +4 -2
- synth_ai/cli/_storage.py +19 -0
- synth_ai/cli/balance.py +14 -2
- synth_ai/cli/calc.py +37 -22
- synth_ai/cli/demo.py +38 -39
- synth_ai/cli/legacy_root_backup.py +12 -14
- synth_ai/cli/recent.py +12 -7
- synth_ai/cli/rl_demo.py +81 -102
- synth_ai/cli/status.py +4 -3
- synth_ai/cli/task_apps.py +146 -137
- synth_ai/cli/traces.py +4 -3
- synth_ai/cli/watch.py +3 -2
- synth_ai/demos/core/cli.py +121 -159
- synth_ai/environments/examples/crafter_classic/environment.py +16 -0
- synth_ai/evals/__init__.py +15 -0
- synth_ai/evals/client.py +85 -0
- synth_ai/evals/types.py +42 -0
- synth_ai/jobs/client.py +15 -3
- synth_ai/judge_schemas.py +127 -0
- synth_ai/rubrics/__init__.py +22 -0
- synth_ai/rubrics/validators.py +126 -0
- synth_ai/task/server.py +14 -7
- synth_ai/tracing_v3/decorators.py +51 -26
- synth_ai/tracing_v3/examples/basic_usage.py +12 -7
- synth_ai/tracing_v3/llm_call_record_helpers.py +107 -53
- synth_ai/tracing_v3/replica_sync.py +8 -4
- synth_ai/tracing_v3/serialization.py +130 -0
- synth_ai/tracing_v3/storage/utils.py +11 -9
- synth_ai/tracing_v3/turso/__init__.py +12 -0
- synth_ai/tracing_v3/turso/daemon.py +2 -1
- synth_ai/tracing_v3/turso/native_manager.py +28 -15
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/METADATA +4 -2
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/RECORD +73 -40
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/entry_points.txt +0 -1
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/top_level.txt +0 -0
|
@@ -4,11 +4,14 @@ This module provides utilities to convert vendor responses to LLMCallRecord
|
|
|
4
4
|
format and compute aggregates from call records.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
7
9
|
import uuid
|
|
10
|
+
from dataclasses import dataclass, field
|
|
8
11
|
from datetime import UTC, datetime
|
|
9
|
-
from typing import Any
|
|
12
|
+
from typing import Any, TypedDict, cast
|
|
10
13
|
|
|
11
|
-
from synth_ai.tracing_v3.lm_call_record_abstractions import (
|
|
14
|
+
from .lm_call_record_abstractions import (
|
|
12
15
|
LLMCallRecord,
|
|
13
16
|
LLMChunk,
|
|
14
17
|
LLMContentPart,
|
|
@@ -17,7 +20,21 @@ from synth_ai.tracing_v3.lm_call_record_abstractions import (
|
|
|
17
20
|
LLMUsage,
|
|
18
21
|
ToolCallSpec,
|
|
19
22
|
)
|
|
20
|
-
|
|
23
|
+
|
|
24
|
+
BaseLMResponse = Any
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class _UsageDict(TypedDict, total=False):
|
|
28
|
+
prompt_tokens: int
|
|
29
|
+
completion_tokens: int
|
|
30
|
+
total_tokens: int
|
|
31
|
+
reasoning_tokens: int
|
|
32
|
+
cost_usd: float
|
|
33
|
+
duration_ms: int
|
|
34
|
+
reasoning_input_tokens: int
|
|
35
|
+
reasoning_output_tokens: int
|
|
36
|
+
cache_write_tokens: int
|
|
37
|
+
cache_read_tokens: int
|
|
21
38
|
|
|
22
39
|
|
|
23
40
|
def create_llm_call_record_from_response(
|
|
@@ -110,9 +127,10 @@ def create_llm_call_record_from_response(
|
|
|
110
127
|
)
|
|
111
128
|
|
|
112
129
|
# Extract tool calls if present
|
|
113
|
-
output_tool_calls = []
|
|
114
|
-
|
|
115
|
-
|
|
130
|
+
output_tool_calls: list[ToolCallSpec] = []
|
|
131
|
+
tool_calls_data = cast(list[dict[str, Any]] | None, getattr(response, "tool_calls", None))
|
|
132
|
+
if tool_calls_data:
|
|
133
|
+
for idx, tool_call in enumerate(tool_calls_data):
|
|
116
134
|
if isinstance(tool_call, dict):
|
|
117
135
|
output_tool_calls.append(
|
|
118
136
|
ToolCallSpec(
|
|
@@ -125,18 +143,19 @@ def create_llm_call_record_from_response(
|
|
|
125
143
|
|
|
126
144
|
# Extract usage information
|
|
127
145
|
usage = None
|
|
128
|
-
|
|
146
|
+
usage_data = cast(_UsageDict | None, getattr(response, "usage", None))
|
|
147
|
+
if usage_data:
|
|
129
148
|
usage = LLMUsage(
|
|
130
|
-
input_tokens=
|
|
131
|
-
output_tokens=
|
|
132
|
-
total_tokens=
|
|
133
|
-
cost_usd=
|
|
149
|
+
input_tokens=usage_data.get("input_tokens"),
|
|
150
|
+
output_tokens=usage_data.get("output_tokens"),
|
|
151
|
+
total_tokens=usage_data.get("total_tokens"),
|
|
152
|
+
cost_usd=usage_data.get("cost_usd"),
|
|
134
153
|
# Additional token accounting if available
|
|
135
|
-
reasoning_tokens=
|
|
136
|
-
reasoning_input_tokens=
|
|
137
|
-
reasoning_output_tokens=
|
|
138
|
-
cache_write_tokens=
|
|
139
|
-
cache_read_tokens=
|
|
154
|
+
reasoning_tokens=usage_data.get("reasoning_tokens"),
|
|
155
|
+
reasoning_input_tokens=usage_data.get("reasoning_input_tokens"),
|
|
156
|
+
reasoning_output_tokens=usage_data.get("reasoning_output_tokens"),
|
|
157
|
+
cache_write_tokens=usage_data.get("cache_write_tokens"),
|
|
158
|
+
cache_read_tokens=usage_data.get("cache_read_tokens"),
|
|
140
159
|
)
|
|
141
160
|
|
|
142
161
|
# Build request parameters
|
|
@@ -188,7 +207,45 @@ def create_llm_call_record_from_response(
|
|
|
188
207
|
return record
|
|
189
208
|
|
|
190
209
|
|
|
191
|
-
|
|
210
|
+
@dataclass
|
|
211
|
+
class _AggregateAccumulator:
|
|
212
|
+
"""Mutable accumulator for call record aggregates."""
|
|
213
|
+
|
|
214
|
+
call_count: int = 0
|
|
215
|
+
input_tokens: int = 0
|
|
216
|
+
output_tokens: int = 0
|
|
217
|
+
total_tokens: int = 0
|
|
218
|
+
reasoning_tokens: int = 0
|
|
219
|
+
cost_usd: float = 0.0
|
|
220
|
+
latency_ms: int = 0
|
|
221
|
+
models_used: set[str] = field(default_factory=set)
|
|
222
|
+
providers_used: set[str] = field(default_factory=set)
|
|
223
|
+
tool_calls_count: int = 0
|
|
224
|
+
error_count: int = 0
|
|
225
|
+
success_count: int = 0
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
class AggregateSummary(TypedDict, total=False):
|
|
229
|
+
"""Aggregate metrics derived from call records."""
|
|
230
|
+
|
|
231
|
+
call_count: int
|
|
232
|
+
input_tokens: int
|
|
233
|
+
output_tokens: int
|
|
234
|
+
total_tokens: int
|
|
235
|
+
reasoning_tokens: int
|
|
236
|
+
cost_usd: float
|
|
237
|
+
latency_ms: int
|
|
238
|
+
models_used: list[str]
|
|
239
|
+
providers_used: list[str]
|
|
240
|
+
tool_calls_count: int
|
|
241
|
+
error_count: int
|
|
242
|
+
success_count: int
|
|
243
|
+
avg_latency_ms: float
|
|
244
|
+
avg_input_tokens: float
|
|
245
|
+
avg_output_tokens: float
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def compute_aggregates_from_call_records(call_records: list[LLMCallRecord]) -> AggregateSummary:
|
|
192
249
|
"""Compute aggregate statistics from a list of LLMCallRecord instances.
|
|
193
250
|
|
|
194
251
|
Args:
|
|
@@ -197,65 +254,62 @@ def compute_aggregates_from_call_records(call_records: list[LLMCallRecord]) -> d
|
|
|
197
254
|
Returns:
|
|
198
255
|
Dictionary containing aggregated statistics
|
|
199
256
|
"""
|
|
200
|
-
aggregates = {
|
|
201
|
-
"input_tokens": 0,
|
|
202
|
-
"output_tokens": 0,
|
|
203
|
-
"total_tokens": 0,
|
|
204
|
-
"reasoning_tokens": 0,
|
|
205
|
-
"cost_usd": 0.0,
|
|
206
|
-
"latency_ms": 0,
|
|
207
|
-
"models_used": set(),
|
|
208
|
-
"providers_used": set(),
|
|
209
|
-
"tool_calls_count": 0,
|
|
210
|
-
"error_count": 0,
|
|
211
|
-
"success_count": 0,
|
|
212
|
-
"call_count": len(call_records),
|
|
213
|
-
}
|
|
257
|
+
aggregates = _AggregateAccumulator(call_count=len(call_records))
|
|
214
258
|
|
|
215
259
|
for record in call_records:
|
|
216
260
|
# Token aggregation
|
|
217
261
|
if record.usage:
|
|
218
262
|
if record.usage.input_tokens:
|
|
219
|
-
aggregates
|
|
263
|
+
aggregates.input_tokens += record.usage.input_tokens
|
|
220
264
|
if record.usage.output_tokens:
|
|
221
|
-
aggregates
|
|
265
|
+
aggregates.output_tokens += record.usage.output_tokens
|
|
222
266
|
if record.usage.total_tokens:
|
|
223
|
-
aggregates
|
|
267
|
+
aggregates.total_tokens += record.usage.total_tokens
|
|
224
268
|
if record.usage.reasoning_tokens:
|
|
225
|
-
aggregates
|
|
269
|
+
aggregates.reasoning_tokens += record.usage.reasoning_tokens
|
|
226
270
|
if record.usage.cost_usd:
|
|
227
|
-
aggregates
|
|
271
|
+
aggregates.cost_usd += record.usage.cost_usd
|
|
228
272
|
|
|
229
273
|
# Latency aggregation
|
|
230
|
-
if record.latency_ms:
|
|
231
|
-
aggregates
|
|
274
|
+
if record.latency_ms is not None:
|
|
275
|
+
aggregates.latency_ms += record.latency_ms
|
|
232
276
|
|
|
233
277
|
# Model and provider tracking
|
|
234
278
|
if record.model_name:
|
|
235
|
-
aggregates
|
|
279
|
+
aggregates.models_used.add(record.model_name)
|
|
236
280
|
if record.provider:
|
|
237
|
-
aggregates
|
|
281
|
+
aggregates.providers_used.add(record.provider)
|
|
238
282
|
|
|
239
283
|
# Tool calls
|
|
240
|
-
aggregates
|
|
284
|
+
aggregates.tool_calls_count += len(record.output_tool_calls)
|
|
241
285
|
|
|
242
286
|
# Success/error tracking
|
|
243
287
|
if record.outcome == "error":
|
|
244
|
-
aggregates
|
|
288
|
+
aggregates.error_count += 1
|
|
245
289
|
elif record.outcome == "success":
|
|
246
|
-
aggregates
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
290
|
+
aggregates.success_count += 1
|
|
291
|
+
|
|
292
|
+
summary: AggregateSummary = {
|
|
293
|
+
"call_count": aggregates.call_count,
|
|
294
|
+
"input_tokens": aggregates.input_tokens,
|
|
295
|
+
"output_tokens": aggregates.output_tokens,
|
|
296
|
+
"total_tokens": aggregates.total_tokens,
|
|
297
|
+
"reasoning_tokens": aggregates.reasoning_tokens,
|
|
298
|
+
"cost_usd": aggregates.cost_usd,
|
|
299
|
+
"latency_ms": aggregates.latency_ms,
|
|
300
|
+
"models_used": list(aggregates.models_used),
|
|
301
|
+
"providers_used": list(aggregates.providers_used),
|
|
302
|
+
"tool_calls_count": aggregates.tool_calls_count,
|
|
303
|
+
"error_count": aggregates.error_count,
|
|
304
|
+
"success_count": aggregates.success_count,
|
|
305
|
+
}
|
|
251
306
|
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
aggregates["avg_output_tokens"] = aggregates["output_tokens"] / aggregates["call_count"]
|
|
307
|
+
if aggregates.call_count > 0:
|
|
308
|
+
summary["avg_latency_ms"] = aggregates.latency_ms / aggregates.call_count
|
|
309
|
+
summary["avg_input_tokens"] = aggregates.input_tokens / aggregates.call_count
|
|
310
|
+
summary["avg_output_tokens"] = aggregates.output_tokens / aggregates.call_count
|
|
257
311
|
|
|
258
|
-
return
|
|
312
|
+
return summary
|
|
259
313
|
|
|
260
314
|
|
|
261
315
|
def create_llm_call_record_from_streaming(
|
|
@@ -26,6 +26,7 @@ application to continue without blocking on sync operations.
|
|
|
26
26
|
|
|
27
27
|
import asyncio
|
|
28
28
|
import logging
|
|
29
|
+
from typing import Any
|
|
29
30
|
|
|
30
31
|
import libsql
|
|
31
32
|
|
|
@@ -66,8 +67,8 @@ class ReplicaSync:
|
|
|
66
67
|
self.sync_url = sync_url or CONFIG.sync_url
|
|
67
68
|
self.auth_token = auth_token or CONFIG.auth_token
|
|
68
69
|
self.sync_interval = sync_interval or CONFIG.sync_interval
|
|
69
|
-
self._sync_task: asyncio.Task | None = None
|
|
70
|
-
self._conn:
|
|
70
|
+
self._sync_task: asyncio.Task[Any] | None = None
|
|
71
|
+
self._conn: Any | None = None
|
|
71
72
|
|
|
72
73
|
def _ensure_connection(self):
|
|
73
74
|
"""Ensure libsql connection is established.
|
|
@@ -113,8 +114,11 @@ class ReplicaSync:
|
|
|
113
114
|
"""
|
|
114
115
|
try:
|
|
115
116
|
self._ensure_connection()
|
|
117
|
+
conn = self._conn
|
|
118
|
+
if conn is None:
|
|
119
|
+
raise RuntimeError("Replica sync connection is not available after initialization")
|
|
116
120
|
# Run sync in thread pool since libsql sync is blocking
|
|
117
|
-
await asyncio.to_thread(
|
|
121
|
+
await asyncio.to_thread(conn.sync)
|
|
118
122
|
logger.info("Successfully synced with remote Turso database")
|
|
119
123
|
return True
|
|
120
124
|
except Exception as e:
|
|
@@ -146,7 +150,7 @@ class ReplicaSync:
|
|
|
146
150
|
# Sleep until next sync interval
|
|
147
151
|
await asyncio.sleep(self.sync_interval)
|
|
148
152
|
|
|
149
|
-
def start_background_sync(self) -> asyncio.Task:
|
|
153
|
+
def start_background_sync(self) -> asyncio.Task[Any]:
|
|
150
154
|
"""Start the background sync task.
|
|
151
155
|
|
|
152
156
|
Creates an asyncio task that runs the sync loop. The task is stored
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
"""HTTP-safe serialization helpers for tracing v3.
|
|
2
|
+
|
|
3
|
+
These utilities normalize tracing structures (including dataclasses) into
|
|
4
|
+
JSON-serializable forms and provide a compact JSON encoder suitable for
|
|
5
|
+
HTTP transmission to backend services.
|
|
6
|
+
|
|
7
|
+
Design goals:
|
|
8
|
+
- Preserve structure while ensuring standard-compliant JSON (no NaN/Infinity)
|
|
9
|
+
- Handle common non-JSON types: datetime, Decimal, bytes, set/tuple, numpy scalars
|
|
10
|
+
- Keep output compact (no unnecessary whitespace) while readable if needed
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import base64
|
|
16
|
+
import json
|
|
17
|
+
from dataclasses import asdict, is_dataclass
|
|
18
|
+
from datetime import date, datetime
|
|
19
|
+
from decimal import Decimal
|
|
20
|
+
from enum import Enum
|
|
21
|
+
from typing import Any
|
|
22
|
+
|
|
23
|
+
try:
|
|
24
|
+
import numpy as _np # type: ignore
|
|
25
|
+
except Exception: # pragma: no cover - numpy optional at runtime
|
|
26
|
+
_np = None # type: ignore
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def normalize_for_json(value: Any) -> Any:
|
|
30
|
+
"""Return a JSON-serializable version of ``value``.
|
|
31
|
+
|
|
32
|
+
Rules:
|
|
33
|
+
- dataclass → dict (recursively normalized)
|
|
34
|
+
- datetime/date → ISO-8601 string (UTC-aware datetimes preserve tzinfo)
|
|
35
|
+
- Decimal → float (fallback to string if not finite)
|
|
36
|
+
- bytes/bytearray → base64 string (RFC 4648)
|
|
37
|
+
- set/tuple → list
|
|
38
|
+
- Enum → enum.value (normalized)
|
|
39
|
+
- numpy scalar → corresponding Python scalar
|
|
40
|
+
- float NaN/Inf/−Inf → None (to keep JSON standard compliant)
|
|
41
|
+
- dict / list → recursively normalized
|
|
42
|
+
- other primitives (str, int, bool, None, float) passed through
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
# Dataclasses
|
|
46
|
+
if is_dataclass(value) and not isinstance(value, type):
|
|
47
|
+
try:
|
|
48
|
+
return normalize_for_json(asdict(value))
|
|
49
|
+
except Exception:
|
|
50
|
+
# Fallback: best-effort conversion via __dict__
|
|
51
|
+
return normalize_for_json(getattr(value, "__dict__", {}))
|
|
52
|
+
|
|
53
|
+
# Mapping
|
|
54
|
+
if isinstance(value, dict):
|
|
55
|
+
return {str(k): normalize_for_json(v) for k, v in value.items()}
|
|
56
|
+
|
|
57
|
+
# Sequences
|
|
58
|
+
if isinstance(value, (list, tuple, set)):
|
|
59
|
+
return [normalize_for_json(v) for v in value]
|
|
60
|
+
|
|
61
|
+
# Datetime / Date
|
|
62
|
+
if isinstance(value, (datetime, date)):
|
|
63
|
+
return value.isoformat()
|
|
64
|
+
|
|
65
|
+
# Decimal
|
|
66
|
+
if isinstance(value, Decimal):
|
|
67
|
+
try:
|
|
68
|
+
f = float(value)
|
|
69
|
+
if f != f or f in (float("inf"), float("-inf")):
|
|
70
|
+
return str(value)
|
|
71
|
+
return f
|
|
72
|
+
except Exception:
|
|
73
|
+
return str(value)
|
|
74
|
+
|
|
75
|
+
# Bytes-like
|
|
76
|
+
if isinstance(value, (bytes, bytearray)):
|
|
77
|
+
return base64.b64encode(bytes(value)).decode("ascii")
|
|
78
|
+
|
|
79
|
+
# Enum
|
|
80
|
+
if isinstance(value, Enum):
|
|
81
|
+
return normalize_for_json(value.value)
|
|
82
|
+
|
|
83
|
+
# Numpy scalars / arrays
|
|
84
|
+
if _np is not None:
|
|
85
|
+
if isinstance(value, (_np.generic,)): # type: ignore[attr-defined]
|
|
86
|
+
return normalize_for_json(value.item())
|
|
87
|
+
if isinstance(value, (_np.ndarray,)):
|
|
88
|
+
return normalize_for_json(value.tolist())
|
|
89
|
+
|
|
90
|
+
# Floats: sanitize NaN / Infinity to None
|
|
91
|
+
if isinstance(value, float):
|
|
92
|
+
if value != value or value in (float("inf"), float("-inf")):
|
|
93
|
+
return None
|
|
94
|
+
return value
|
|
95
|
+
|
|
96
|
+
return value
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def dumps_http_json(payload: Any) -> str:
|
|
100
|
+
"""Dump ``payload`` into a compact, HTTP-safe JSON string.
|
|
101
|
+
|
|
102
|
+
- Recursively normalizes non-JSON types (see ``normalize_for_json``)
|
|
103
|
+
- Disallows NaN/Infinity per RFC 8259 (allow_nan=False)
|
|
104
|
+
- Uses compact separators and preserves Unicode (ensure_ascii=False)
|
|
105
|
+
"""
|
|
106
|
+
|
|
107
|
+
normalized = normalize_for_json(payload)
|
|
108
|
+
return json.dumps(
|
|
109
|
+
normalized,
|
|
110
|
+
ensure_ascii=False,
|
|
111
|
+
allow_nan=False,
|
|
112
|
+
separators=(",", ":"),
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def serialize_trace_for_http(trace: Any) -> str:
|
|
117
|
+
"""Serialize a tracing v3 session (or dict-like) to HTTP-safe JSON.
|
|
118
|
+
|
|
119
|
+
Accepts either a dataclass (e.g., SessionTrace) or a dict/list and
|
|
120
|
+
applies normalization and compact JSON encoding.
|
|
121
|
+
"""
|
|
122
|
+
|
|
123
|
+
if is_dataclass(trace) and not isinstance(trace, type):
|
|
124
|
+
try:
|
|
125
|
+
return dumps_http_json(asdict(trace))
|
|
126
|
+
except Exception:
|
|
127
|
+
return dumps_http_json(getattr(trace, "__dict__", {}))
|
|
128
|
+
return dumps_http_json(trace)
|
|
129
|
+
|
|
130
|
+
|
|
@@ -3,8 +3,8 @@
|
|
|
3
3
|
import asyncio
|
|
4
4
|
import functools
|
|
5
5
|
import time
|
|
6
|
-
from collections.abc import Callable
|
|
7
|
-
from typing import Any, TypeVar
|
|
6
|
+
from collections.abc import Awaitable, Callable
|
|
7
|
+
from typing import Any, TypeVar, cast
|
|
8
8
|
|
|
9
9
|
T = TypeVar("T")
|
|
10
10
|
|
|
@@ -18,9 +18,9 @@ def retry_async(max_attempts: int = 3, delay: float = 1.0, backoff: float = 2.0)
|
|
|
18
18
|
backoff: Backoff multiplier for each retry
|
|
19
19
|
"""
|
|
20
20
|
|
|
21
|
-
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
|
21
|
+
def decorator(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
|
|
22
22
|
@functools.wraps(func)
|
|
23
|
-
async def wrapper(*args, **kwargs):
|
|
23
|
+
async def wrapper(*args: Any, **kwargs: Any) -> T:
|
|
24
24
|
last_exception: Exception | None = None
|
|
25
25
|
current_delay = delay
|
|
26
26
|
|
|
@@ -171,13 +171,14 @@ STORAGE_METRICS = StorageMetrics()
|
|
|
171
171
|
def track_metrics(operation: str):
|
|
172
172
|
"""Decorator to track storage operation metrics."""
|
|
173
173
|
|
|
174
|
-
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
|
174
|
+
def decorator(func: Callable[..., Awaitable[T]] | Callable[..., T]) -> Callable[..., Awaitable[T]] | Callable[..., T]:
|
|
175
175
|
@functools.wraps(func)
|
|
176
|
-
async def async_wrapper(*args, **kwargs):
|
|
176
|
+
async def async_wrapper(*args: Any, **kwargs: Any) -> T:
|
|
177
177
|
start_time = time.time()
|
|
178
178
|
success = False
|
|
179
179
|
try:
|
|
180
|
-
|
|
180
|
+
async_func = cast(Callable[..., Awaitable[T]], func)
|
|
181
|
+
result = await async_func(*args, **kwargs)
|
|
181
182
|
success = True
|
|
182
183
|
return result
|
|
183
184
|
finally:
|
|
@@ -185,11 +186,12 @@ def track_metrics(operation: str):
|
|
|
185
186
|
STORAGE_METRICS.record_operation(operation, duration, success)
|
|
186
187
|
|
|
187
188
|
@functools.wraps(func)
|
|
188
|
-
def sync_wrapper(*args, **kwargs):
|
|
189
|
+
def sync_wrapper(*args: Any, **kwargs: Any) -> T:
|
|
189
190
|
start_time = time.time()
|
|
190
191
|
success = False
|
|
191
192
|
try:
|
|
192
|
-
|
|
193
|
+
sync_func = cast(Callable[..., T], func)
|
|
194
|
+
result = sync_func(*args, **kwargs)
|
|
193
195
|
success = True
|
|
194
196
|
return result
|
|
195
197
|
finally:
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Turso integration package for tracing v3."""
|
|
2
|
+
|
|
3
|
+
from .daemon import SqldDaemon, get_daemon, start_sqld, stop_sqld
|
|
4
|
+
from .native_manager import NativeLibsqlTraceManager
|
|
5
|
+
|
|
6
|
+
__all__ = [
|
|
7
|
+
"SqldDaemon",
|
|
8
|
+
"NativeLibsqlTraceManager",
|
|
9
|
+
"get_daemon",
|
|
10
|
+
"start_sqld",
|
|
11
|
+
"stop_sqld",
|
|
12
|
+
]
|
|
@@ -6,6 +6,7 @@ import subprocess
|
|
|
6
6
|
import time
|
|
7
7
|
|
|
8
8
|
import requests
|
|
9
|
+
from requests import RequestException
|
|
9
10
|
|
|
10
11
|
from ..config import CONFIG
|
|
11
12
|
|
|
@@ -79,7 +80,7 @@ class SqldDaemon:
|
|
|
79
80
|
response = requests.get(health_url, timeout=1)
|
|
80
81
|
if response.status_code == 200:
|
|
81
82
|
return
|
|
82
|
-
except
|
|
83
|
+
except RequestException:
|
|
83
84
|
pass
|
|
84
85
|
|
|
85
86
|
# Check if process crashed
|
|
@@ -11,18 +11,14 @@ import asyncio
|
|
|
11
11
|
import json
|
|
12
12
|
import logging
|
|
13
13
|
import re
|
|
14
|
+
from collections.abc import Callable
|
|
14
15
|
from dataclasses import asdict, dataclass
|
|
15
16
|
from datetime import UTC, datetime
|
|
16
|
-
from typing import Any
|
|
17
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
17
18
|
|
|
18
19
|
import libsql
|
|
19
20
|
from sqlalchemy.engine import make_url
|
|
20
21
|
|
|
21
|
-
try: # pragma: no cover - exercised only when pandas present
|
|
22
|
-
import pandas as pd # type: ignore
|
|
23
|
-
except Exception: # pragma: no cover
|
|
24
|
-
pd = None # type: ignore[assignment]
|
|
25
|
-
|
|
26
22
|
from ..abstractions import (
|
|
27
23
|
EnvironmentEvent,
|
|
28
24
|
LMCAISEvent,
|
|
@@ -34,6 +30,24 @@ from ..config import CONFIG
|
|
|
34
30
|
from ..storage.base import TraceStorage
|
|
35
31
|
from .models import analytics_views
|
|
36
32
|
|
|
33
|
+
if TYPE_CHECKING:
|
|
34
|
+
from sqlite3 import Connection as LibsqlConnection
|
|
35
|
+
else: # pragma: no cover - runtime fallback for typing only
|
|
36
|
+
LibsqlConnection = Any # type: ignore[assignment]
|
|
37
|
+
|
|
38
|
+
_LIBSQL_CONNECT_ATTR = getattr(libsql, "connect", None)
|
|
39
|
+
if _LIBSQL_CONNECT_ATTR is None: # pragma: no cover - defensive guard
|
|
40
|
+
raise RuntimeError("libsql.connect is required for NativeLibsqlTraceManager")
|
|
41
|
+
_libsql_connect: Callable[..., LibsqlConnection] = cast(
|
|
42
|
+
Callable[..., LibsqlConnection],
|
|
43
|
+
_LIBSQL_CONNECT_ATTR,
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
try: # pragma: no cover - exercised only when pandas present
|
|
47
|
+
import pandas as pd # type: ignore
|
|
48
|
+
except Exception: # pragma: no cover
|
|
49
|
+
pd = None # type: ignore[assignment]
|
|
50
|
+
|
|
37
51
|
logger = logging.getLogger(__name__)
|
|
38
52
|
|
|
39
53
|
|
|
@@ -66,9 +80,8 @@ def _resolve_connection_target(db_url: str | None, auth_token: str | None) -> _C
|
|
|
66
80
|
# Fallback to SQLAlchemy URL parsing for anything else we missed.
|
|
67
81
|
try:
|
|
68
82
|
parsed = make_url(url)
|
|
69
|
-
if parsed.drivername.startswith("sqlite"):
|
|
70
|
-
|
|
71
|
-
return _ConnectionTarget(database=parsed.database, auth_token=auth_token)
|
|
83
|
+
if parsed.drivername.startswith("sqlite") and parsed.database:
|
|
84
|
+
return _ConnectionTarget(database=parsed.database, auth_token=auth_token)
|
|
72
85
|
if parsed.drivername.startswith("libsql"):
|
|
73
86
|
database = parsed.render_as_string(hide_password=False)
|
|
74
87
|
return _ConnectionTarget(database=database, sync_url=database, auth_token=auth_token)
|
|
@@ -314,12 +327,12 @@ class NativeLibsqlTraceManager(TraceStorage):
|
|
|
314
327
|
):
|
|
315
328
|
self._config_auth_token = auth_token
|
|
316
329
|
self._target = _resolve_connection_target(db_url, auth_token)
|
|
317
|
-
self._conn:
|
|
330
|
+
self._conn: LibsqlConnection | None = None
|
|
318
331
|
self._conn_lock = asyncio.Lock()
|
|
319
332
|
self._op_lock = asyncio.Lock()
|
|
320
333
|
self._initialized = False
|
|
321
334
|
|
|
322
|
-
def _open_connection(self) ->
|
|
335
|
+
def _open_connection(self) -> LibsqlConnection:
|
|
323
336
|
"""Open a libsql connection for the resolved target."""
|
|
324
337
|
kwargs: dict[str, Any] = {}
|
|
325
338
|
if self._target.sync_url and self._target.sync_url.startswith("libsql://"):
|
|
@@ -329,7 +342,7 @@ class NativeLibsqlTraceManager(TraceStorage):
|
|
|
329
342
|
# Disable automatic background sync; ReplicaSync drives this explicitly.
|
|
330
343
|
kwargs.setdefault("sync_interval", 0)
|
|
331
344
|
logger.debug("Opening libsql connection to %s", self._target.database)
|
|
332
|
-
return
|
|
345
|
+
return _libsql_connect(self._target.database, **kwargs)
|
|
333
346
|
|
|
334
347
|
async def initialize(self):
|
|
335
348
|
"""Initialise the backend."""
|
|
@@ -493,7 +506,7 @@ class NativeLibsqlTraceManager(TraceStorage):
|
|
|
493
506
|
return None
|
|
494
507
|
|
|
495
508
|
session_columns = ["session_id", "created_at", "num_timesteps", "num_events", "num_messages", "metadata"]
|
|
496
|
-
session_data = dict(zip(session_columns, session_row))
|
|
509
|
+
session_data = dict(zip(session_columns, session_row, strict=True))
|
|
497
510
|
|
|
498
511
|
timestep_cursor = conn.execute(
|
|
499
512
|
"""
|
|
@@ -608,10 +621,10 @@ class NativeLibsqlTraceManager(TraceStorage):
|
|
|
608
621
|
|
|
609
622
|
if not rows:
|
|
610
623
|
if pd is not None:
|
|
611
|
-
return pd.DataFrame(columns=
|
|
624
|
+
return pd.DataFrame(columns=list(columns))
|
|
612
625
|
return []
|
|
613
626
|
|
|
614
|
-
records = [dict(zip(columns, row)) for row in rows]
|
|
627
|
+
records = [dict(zip(columns, row, strict=True)) for row in rows]
|
|
615
628
|
if pd is not None:
|
|
616
629
|
return pd.DataFrame(records)
|
|
617
630
|
return records
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: synth-ai
|
|
3
|
-
Version: 0.2.10
|
|
3
|
+
Version: 0.2.13.dev1
|
|
4
4
|
Summary: RL as a service SDK - Core AI functionality and tracing
|
|
5
5
|
Author-email: Synth AI <josh@usesynth.ai>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -79,7 +79,7 @@ Dynamic: license-file
|
|
|
79
79
|
|
|
80
80
|
[](https://www.python.org/)
|
|
81
81
|
[](LICENSE)
|
|
82
|
-
[](https://pypi.org/project/synth-ai/)
|
|
83
83
|

|
|
84
84
|

|
|
85
85
|

|
|
@@ -88,6 +88,8 @@ Docs: [Synth‑AI Documentation](https://docs.usesynth.ai/welcome/introduction)
|
|
|
88
88
|
|
|
89
89
|
Fast and effective reinforcement learning for agents, via an API
|
|
90
90
|
|
|
91
|
+
> Latest: 0.2.10 published to PyPI (uv publish)
|
|
92
|
+
|
|
91
93
|
## Highlights
|
|
92
94
|
|
|
93
95
|
- Easily scale gpu topologies - train on 3 a10gs or 8 H100s (multi-node available upon request)
|