synth-ai 0.2.10__py3-none-any.whl → 0.2.13.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of synth-ai has been flagged as potentially problematic in its registry listing.
- examples/agora_ex/README_MoE.md +224 -0
- examples/agora_ex/__init__.py +7 -0
- examples/agora_ex/agora_ex.py +65 -0
- examples/agora_ex/agora_ex_task_app.py +590 -0
- examples/agora_ex/configs/rl_lora_qwen3_moe_2xh200.toml +121 -0
- examples/agora_ex/reward_fn_grpo-human.py +129 -0
- examples/agora_ex/system_prompt_CURRENT.md +63 -0
- examples/agora_ex/task_app/agora_ex_task_app.py +590 -0
- examples/agora_ex/task_app/reward_fn_grpo-human.py +129 -0
- examples/agora_ex/task_app/system_prompt_CURRENT.md +63 -0
- examples/multi_step/configs/crafter_rl_outcome.toml +74 -0
- examples/multi_step/configs/crafter_rl_stepwise_hosted_judge.toml +175 -0
- examples/multi_step/configs/crafter_rl_stepwise_shaped.toml +83 -0
- examples/multi_step/configs/crafter_rl_stepwise_simple.toml +78 -0
- examples/multi_step/crafter_rl_lora.md +51 -10
- examples/multi_step/sse_metrics_streaming_notes.md +357 -0
- examples/multi_step/task_app_config_notes.md +494 -0
- examples/warming_up_to_rl/configs/eval_stepwise_complex.toml +35 -0
- examples/warming_up_to_rl/configs/eval_stepwise_consistent.toml +26 -0
- examples/warming_up_to_rl/configs/eval_stepwise_per_achievement.toml +36 -0
- examples/warming_up_to_rl/configs/eval_stepwise_simple.toml +32 -0
- examples/warming_up_to_rl/run_eval.py +267 -41
- examples/warming_up_to_rl/task_app/grpo_crafter.py +3 -33
- examples/warming_up_to_rl/task_app/synth_envs_hosted/inference/openai_client.py +109 -45
- examples/warming_up_to_rl/task_app/synth_envs_hosted/policy_routes.py +42 -46
- examples/warming_up_to_rl/task_app/synth_envs_hosted/rollout.py +376 -193
- synth_ai/__init__.py +41 -1
- synth_ai/api/train/builders.py +74 -33
- synth_ai/api/train/cli.py +29 -6
- synth_ai/api/train/configs/__init__.py +44 -0
- synth_ai/api/train/configs/rl.py +133 -0
- synth_ai/api/train/configs/sft.py +94 -0
- synth_ai/api/train/configs/shared.py +24 -0
- synth_ai/api/train/env_resolver.py +18 -19
- synth_ai/api/train/supported_algos.py +8 -5
- synth_ai/api/train/utils.py +6 -1
- synth_ai/cli/__init__.py +4 -2
- synth_ai/cli/_storage.py +19 -0
- synth_ai/cli/balance.py +14 -2
- synth_ai/cli/calc.py +37 -22
- synth_ai/cli/demo.py +38 -39
- synth_ai/cli/legacy_root_backup.py +12 -14
- synth_ai/cli/recent.py +12 -7
- synth_ai/cli/rl_demo.py +81 -102
- synth_ai/cli/status.py +4 -3
- synth_ai/cli/task_apps.py +146 -137
- synth_ai/cli/traces.py +4 -3
- synth_ai/cli/watch.py +3 -2
- synth_ai/demos/core/cli.py +121 -159
- synth_ai/environments/examples/crafter_classic/environment.py +16 -0
- synth_ai/evals/__init__.py +15 -0
- synth_ai/evals/client.py +85 -0
- synth_ai/evals/types.py +42 -0
- synth_ai/jobs/client.py +15 -3
- synth_ai/judge_schemas.py +127 -0
- synth_ai/rubrics/__init__.py +22 -0
- synth_ai/rubrics/validators.py +126 -0
- synth_ai/task/server.py +14 -7
- synth_ai/tracing_v3/decorators.py +51 -26
- synth_ai/tracing_v3/examples/basic_usage.py +12 -7
- synth_ai/tracing_v3/llm_call_record_helpers.py +107 -53
- synth_ai/tracing_v3/replica_sync.py +8 -4
- synth_ai/tracing_v3/serialization.py +130 -0
- synth_ai/tracing_v3/storage/utils.py +11 -9
- synth_ai/tracing_v3/turso/__init__.py +12 -0
- synth_ai/tracing_v3/turso/daemon.py +2 -1
- synth_ai/tracing_v3/turso/native_manager.py +28 -15
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/METADATA +4 -2
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/RECORD +73 -40
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/entry_points.txt +0 -1
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/WHEEL +0 -0
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/licenses/LICENSE +0 -0
- {synth_ai-0.2.10.dist-info → synth_ai-0.2.13.dev1.dist-info}/top_level.txt +0 -0
synth_ai/judge_schemas.py
ADDED
@@ -0,0 +1,127 @@
+"""
+Judge API Contract Schemas
+
+These schemas define the expected structure for requests and responses
+to the judge scoring endpoint at POST /api/judge/v1/score.
+
+This is the canonical contract that the backend MUST conform to.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Dict, List, Literal, Optional
+
+from pydantic import BaseModel, Field
+
+
+class CriterionScorePayload(BaseModel):
+    """Per-criterion score returned by the judge."""
+
+    score: float = Field(..., description="Numeric score for this criterion")
+    reason: str = Field(default="", description="Explanation for the score")
+    weight: float = Field(default=1.0, description="Weight of this criterion")
+    description: str = Field(default="", description="Description of the criterion")
+
+
+class ReviewPayload(BaseModel):
+    """Rubric review (event-level or outcome-level)."""
+
+    criteria: Dict[str, CriterionScorePayload] = Field(
+        default_factory=dict,
+        description="Map of criterion keys to their scores"
+    )
+    total: float = Field(default=0.0, description="Aggregated total score")
+    summary: Optional[str] = Field(None, description="Optional text summary")
+
+
+class JudgeScoreResponse(BaseModel):
+    """
+    Response body for POST /api/judge/v1/score.
+
+    This is the canonical contract that judge backends MUST return.
+    """
+
+    status: Literal["ok", "failed"] = Field(default="ok", description="Request status")
+    event_reviews: List[ReviewPayload] = Field(
+        default_factory=list,
+        description="List of per-event rubric reviews (one per step)"
+    )
+    outcome_review: Optional[ReviewPayload] = Field(
+        None,
+        description="Optional outcome-level rubric review"
+    )
+    event_totals: List[float] = Field(
+        default_factory=list,
+        description="List of aggregated scores per event (matches event_reviews length)"
+    )
+    details: Dict[str, Any] = Field(
+        default_factory=dict,
+        description="Additional details (provider, latency, etc.)"
+    )
+    metadata: Dict[str, Any] = Field(
+        default_factory=dict,
+        description="Request metadata (provider, options, etc.)"
+    )
+
+    def aggregate_event_reward(self) -> float | None:
+        """
+        Aggregate all event totals into a single reward.
+
+        Returns:
+            Sum of all event_totals, or None if empty
+        """
+        if not self.event_totals:
+            return None
+        return sum(self.event_totals)
+
+    def aggregate_outcome_reward(self) -> float | None:
+        """
+        Extract outcome reward from outcome_review.
+
+        Returns:
+            outcome_review.total, or None if no outcome review
+        """
+        if self.outcome_review is None:
+            return None
+        return self.outcome_review.total
+
+
+# Request schemas for completeness
+
+class JudgeTaskApp(BaseModel):
+    """Task application metadata."""
+
+    id: str = Field(..., description="Task app identifier")
+    base_url: Optional[str] = Field(None, description="Optional base URL for task app")
+
+
+class JudgeOptions(BaseModel):
+    """Judge provider and configuration options."""
+
+    provider: Optional[str] = Field(None, description="Judge provider (e.g., 'openai', 'groq')")
+    model: Optional[str] = Field(None, description="Model identifier")
+    rubric_id: Optional[str] = Field(None, description="Rubric identifier")
+    event: bool = Field(True, description="Enable event-level judging")
+    outcome: bool = Field(True, description="Enable outcome-level judging")
+
+
+class JudgeTracePayload(BaseModel):
+    """Trace payload containing trajectory context."""
+
+    event_history: List[Dict[str, Any]] = Field(..., description="List of events/steps")
+    markov_blanket_message_history: List[Dict[str, Any]] = Field(
+        default_factory=list,
+        description="Optional message history for context"
+    )
+    metadata: Dict[str, Any] = Field(default_factory=dict, description="Trace metadata")
+
+
+class JudgeScoreRequest(BaseModel):
+    """Request body for POST /api/judge/v1/score."""
+
+    policy_name: str = Field(..., description="Name of the policy being evaluated")
+    task_app: JudgeTaskApp = Field(..., description="Task application metadata")
+    trace: JudgeTracePayload = Field(..., description="Trajectory trace to evaluate")
+    options: JudgeOptions = Field(default_factory=lambda: JudgeOptions(), description="Judge options")
+    rubric: Optional[Dict[str, Any]] = Field(None, description="Optional explicit rubric criteria")
+
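For orientation, here is a minimal usage sketch (not part of the diff; the payload values are invented) showing how a client could parse a judge response with these schemas and apply the aggregation helpers:

```python
# Hypothetical consumer of the schemas above; payload values are illustrative.
from synth_ai.judge_schemas import JudgeScoreResponse

payload = {
    "status": "ok",
    "event_reviews": [
        {"criteria": {"progress": {"score": 0.5}}, "total": 0.5},
        {"criteria": {"progress": {"score": 1.0}}, "total": 1.0},
    ],
    "outcome_review": {"criteria": {}, "total": 0.75},
    "event_totals": [0.5, 1.0],
}

response = JudgeScoreResponse.model_validate(payload)
assert response.aggregate_event_reward() == 1.5     # sum of event_totals
assert response.aggregate_outcome_reward() == 0.75  # outcome_review.total
```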
synth_ai/rubrics/__init__.py
ADDED
@@ -0,0 +1,22 @@
+"""
+Rubric utilities.
+
+Exposes helpers for validating rubric specifications that are used across
+Crafter-style judge configurations.
+"""
+
+from .validators import (
+    RubricCriterion,
+    RubricSpec,
+    ValidationError,
+    validate_rubric_dict,
+    validate_rubric_file,
+)
+
+__all__ = [
+    "RubricCriterion",
+    "RubricSpec",
+    "ValidationError",
+    "validate_rubric_dict",
+    "validate_rubric_file",
+]
synth_ai/rubrics/validators.py
ADDED
@@ -0,0 +1,126 @@
+from __future__ import annotations
+
+import json
+import math
+from pathlib import Path
+from typing import Any, Iterable, Literal
+
+import pydantic
+
+
+class RubricCriterion(pydantic.BaseModel):
+    """Single scoring criterion within a rubric."""
+
+    id: str
+    description: str
+    weight: float
+    scale: str | None = None
+
+    @pydantic.field_validator("weight")
+    @classmethod
+    def _validate_weight(cls, value: float) -> float:
+        if not math.isfinite(value):
+            raise ValueError("weight must be a finite number")
+        if value <= 0.0:
+            raise ValueError("weight must be positive")
+        if value > 1.0:
+            raise ValueError("weight must be <= 1.0")
+        return value
+
+    @pydantic.field_validator("id", "description", mode="before")
+    @classmethod
+    def _strip_string(cls, value: Any) -> Any:
+        if isinstance(value, str):
+            return value.strip()
+        return value
+
+
+class RubricSpec(pydantic.BaseModel):
+    """High-level rubric definition used by step-wise judges."""
+
+    version: str
+    goal_text: str
+    aggregation: Literal["weighted_sum"]
+    criteria: list[RubricCriterion]
+
+    @pydantic.model_validator(mode="after")
+    def _validate_weights(self) -> "RubricSpec":
+        if not self.criteria:
+            raise ValueError("rubric must declare at least one criterion")
+        total_weight = sum(criterion.weight for criterion in self.criteria)
+        if not math.isclose(total_weight, 1.0, abs_tol=1e-6, rel_tol=1e-6):
+            raise ValueError(
+                f"criterion weights must sum to 1 (got {total_weight:.6f})"
+            )
+        return self
+
+    @pydantic.field_validator("version")
+    @classmethod
+    def _non_empty_version(cls, value: str) -> str:
+        value = value.strip()
+        if not value:
+            raise ValueError("version string must not be empty")
+        return value
+
+    @pydantic.field_validator("goal_text")
+    @classmethod
+    def _non_empty_goal_text(cls, value: str) -> str:
+        value = value.strip()
+        if not value:
+            raise ValueError("goal_text must not be empty")
+        return value
+
+
+ValidationError = pydantic.ValidationError
+
+
+def validate_rubric_dict(payload: dict[str, Any]) -> RubricSpec:
+    """
+    Validate an in-memory rubric payload and return the parsed model.
+
+    Args:
+        payload: Dictionary representing the rubric JSON.
+    Returns:
+        Validated RubricSpec instance.
+    Raises:
+        ValidationError: If the payload is missing required fields or contains
+            invalid weights.
+    """
+
+    if not isinstance(payload, dict):
+        raise TypeError("rubric payload must be a dictionary")
+    return RubricSpec.model_validate(payload)
+
+
+def _load_payload_from_file(path: Path) -> dict[str, Any]:
+    if path.suffix.lower() != ".json":
+        raise ValueError(f"Unsupported rubric file type: {path}")
+    text = path.read_text(encoding="utf-8")
+    return json.loads(text)
+
+
+def validate_rubric_file(path: Path) -> RubricSpec:
+    """
+    Load and validate a rubric file.
+
+    Args:
+        path: Path to a JSON rubric document.
+    Returns:
+        Validated RubricSpec instance.
+    """
+
+    payload = _load_payload_from_file(path)
+    return validate_rubric_dict(payload)
+
+
+def validate_rubric_files(paths: Iterable[Path]) -> list[RubricSpec]:
+    """
+    Validate multiple rubric files and return their parsed models.
+
+    Useful for bulk validation inside tests or CI checks.
+    """
+
+    validated: list[RubricSpec] = []
+    for path in paths:
+        validated.append(validate_rubric_file(path))
+    return validated
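A quick validation sketch (assumed usage based on the module above; the rubric content is invented): each weight must be positive and at most 1.0, and the weights together must sum to 1, so dropping a criterion makes validation fail.

```python
# Exercises validate_rubric_dict from the new module.
from synth_ai.rubrics import ValidationError, validate_rubric_dict

rubric = {
    "version": "1",
    "goal_text": "Collect wood and craft a table",
    "aggregation": "weighted_sum",
    "criteria": [
        {"id": "progress", "description": "Moved toward the goal", "weight": 0.7},
        {"id": "efficiency", "description": "Avoided wasted steps", "weight": 0.3},
    ],
}

spec = validate_rubric_dict(rubric)  # returns a parsed RubricSpec
print(spec.criteria[0].id)           # -> progress

try:
    validate_rubric_dict({**rubric, "criteria": rubric["criteria"][:1]})
except ValidationError as exc:       # weights now sum to 0.7, not 1.0
    print(exc.errors()[0]["msg"])
```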
synth_ai/task/server.py
CHANGED
@@ -6,6 +6,7 @@ import asyncio
 import inspect
 import os
 from collections.abc import Awaitable, Callable, Iterable, Mapping, MutableMapping, Sequence
+from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any
@@ -34,6 +35,10 @@ InstanceProvider = Callable[[Sequence[int]], Iterable[TaskInfo] | Awaitable[Iter
 RolloutExecutor = Callable[[RolloutRequest, Request], Any | Awaitable[Any]]
 
 
+def _default_app_state() -> dict[str, Any]:
+    return {}
+
+
 @dataclass(slots=True)
 class RubricBundle:
     """Optional rubrics advertised by the task app."""
@@ -69,7 +74,7 @@ class TaskAppConfig:
     proxy: ProxyConfig | None = None
     routers: Sequence[APIRouter] = field(default_factory=tuple)
     middleware: Sequence[Middleware] = field(default_factory=tuple)
-    app_state:
+    app_state: MutableMapping[str, Any] = field(default_factory=_default_app_state)
     require_api_key: bool = True
     expose_debug_env: bool = True
     cors_origins: Sequence[str] | None = None
@@ -260,17 +265,19 @@ def create_task_app(config: TaskAppConfig) -> FastAPI:
            return _maybe_await(hook(app))  # type: ignore[misc]
        return _maybe_await(hook())
 
-    @
-    async def
+    @asynccontextmanager
+    async def lifespan(_: FastAPI):
        normalize_environment_api_key()
        normalize_vendor_keys()
        for hook in cfg.startup_hooks:
            await _call_hook(hook)
+        try:
+            yield
+        finally:
+            for hook in cfg.shutdown_hooks:
+                await _call_hook(hook)
 
-
-    async def _shutdown() -> None:  # pragma: no cover - FastAPI lifecycle
-        for hook in cfg.shutdown_hooks:
-            await _call_hook(hook)
+    app.router.lifespan_context = lifespan
 
    @app.get("/")
    async def root() -> Mapping[str, Any]:
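The change above consolidates the separate startup and shutdown handlers into a single lifespan context manager attached to the existing router. A generic standalone sketch of the same pattern (the print statements stand in for the real hooks):

```python
# Minimal sketch of the lifespan pattern adopted in server.py: startup work
# runs before `yield`, and shutdown hooks run in `finally` even on failure.
from contextlib import asynccontextmanager

from fastapi import FastAPI

app = FastAPI()

@asynccontextmanager
async def lifespan(_: FastAPI):
    print("startup: normalize keys, run startup hooks")
    try:
        yield  # the application serves requests while suspended here
    finally:
        print("shutdown: run shutdown hooks")

# Attach after construction, as create_task_app now does:
app.router.lifespan_context = lifespan
```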
synth_ai/tracing_v3/decorators.py
CHANGED
@@ -28,8 +28,8 @@ import asyncio
 import contextvars
 import functools
 import time
-from collections.abc import Callable
-from typing import Any, TypeVar
+from collections.abc import Awaitable, Callable, Mapping
+from typing import Any, TypeVar, cast, overload
 
 from .abstractions import LMCAISEvent, TimeRecord
 from .utils import calculate_cost, detect_provider
@@ -88,6 +88,16 @@ def get_session_tracer() -> Any:
 T = TypeVar("T")
 
 
+@overload
+def with_session(require: bool = True) -> Callable[[Callable[..., Awaitable[T]]], Callable[..., Awaitable[T]]]:
+    ...
+
+
+@overload
+def with_session(require: bool = True) -> Callable[[Callable[..., T]], Callable[..., T]]:
+    ...
+
+
 def with_session(require: bool = True):
     """Decorator that ensures a session is active.
 
@@ -109,29 +119,31 @@ def with_session(require: bool = True):
     ```
     """
 
-    def decorator(fn: Callable[..., T]) -> Callable[..., T]:
+    def decorator(fn: Callable[..., Awaitable[T]] | Callable[..., T]) -> Callable[..., Awaitable[T]] | Callable[..., T]:
         if asyncio.iscoroutinefunction(fn):
 
             @functools.wraps(fn)
-            async def async_wrapper(*args, **kwargs):
+            async def async_wrapper(*args: Any, **kwargs: Any) -> T:
                 session_id = get_session_id()
                 if require and session_id is None:
                     raise RuntimeError(
                         f"No active session for {getattr(fn, '__name__', 'unknown')}"
                     )
-
+                async_fn = cast(Callable[..., Awaitable[T]], fn)
+                return await async_fn(*args, **kwargs)
 
             return async_wrapper
         else:
 
             @functools.wraps(fn)
-            def sync_wrapper(*args, **kwargs):
+            def sync_wrapper(*args: Any, **kwargs: Any) -> T:
                 session_id = get_session_id()
                 if require and session_id is None:
                     raise RuntimeError(
                         f"No active session for {getattr(fn, '__name__', 'unknown')}"
                     )
-
+                sync_fn = cast(Callable[..., T], fn)
+                return sync_fn(*args, **kwargs)
 
             return sync_wrapper
 
@@ -172,31 +184,36 @@ def trace_llm_call(
     ```
     """
 
-    def decorator(fn: Callable[..., T]) -> Callable[..., T]:
+    def decorator(fn: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
         if asyncio.iscoroutinefunction(fn):
+            async_fn: Callable[..., Awaitable[T]] = fn
 
             @functools.wraps(fn)
-            async def async_wrapper(*args, **kwargs):
+            async def async_wrapper(*args: Any, **kwargs: Any) -> T:
                 tracer = get_session_tracer()
                 if not tracer:
-                    return await
+                    return await async_fn(*args, **kwargs)
 
                 start_time = time.time()
                 system_state_before = kwargs.get("state_before", {})
 
                 try:
-                    result = await
+                    result = await async_fn(*args, **kwargs)
 
                     # Extract metrics from result - this assumes the result follows
                     # common LLM API response formats (OpenAI, Anthropic, etc.)
-
-
-
-
-
-
-
-
+                    input_tokens = output_tokens = total_tokens = None
+                    actual_model = model_name
+                    if extract_tokens and isinstance(result, Mapping):
+                        result_mapping = cast(Mapping[str, Any], result)
+                        usage = result_mapping.get("usage")
+                        if isinstance(usage, Mapping):
+                            input_tokens = usage.get("prompt_tokens")
+                            output_tokens = usage.get("completion_tokens")
+                            total_tokens = usage.get("total_tokens")
+                        value = result_mapping.get("model")
+                        if isinstance(value, str):
+                            actual_model = value
 
                     latency_ms = int((time.time() - start_time) * 1000)
 
@@ -272,19 +289,26 @@ def trace_method(event_type: str = "runtime", system_id: str | None = None):
     ```
     """
 
-    def decorator(
+    def decorator(
+        fn: Callable[..., Awaitable[T]] | Callable[..., T]
+    ) -> Callable[..., Awaitable[T]] | Callable[..., T]:
         if asyncio.iscoroutinefunction(fn):
+            async_fn = cast(Callable[..., Awaitable[T]], fn)
 
             @functools.wraps(fn)
-            async def async_wrapper(
+            async def async_wrapper(*args: Any, **kwargs: Any) -> T:
                 tracer = get_session_tracer()
                 if not tracer:
-                    return await
+                    return await async_fn(*args, **kwargs)
 
                 from .abstractions import RuntimeEvent
 
                 # Use class name as system_id if not provided
-
+                self_obj = args[0] if args else None
+                inferred_system_id = (
+                    self_obj.__class__.__name__ if self_obj is not None else "unknown"
+                )
+                actual_system_id = system_id or inferred_system_id
 
                 event = RuntimeEvent(
                     system_instance_id=actual_system_id,
@@ -298,17 +322,18 @@ def trace_method(event_type: str = "runtime", system_id: str | None = None):
                 )
 
                 await tracer.record_event(event)
-                return await
+                return await async_fn(*args, **kwargs)
 
             return async_wrapper
         else:
 
             @functools.wraps(fn)
-            def sync_wrapper(
+            def sync_wrapper(*args: Any, **kwargs: Any) -> T:
                 # For sync methods, we can't easily trace without blocking
                 # the event loop. This is a limitation of the async-first design.
                 # Consider converting to async or using a different approach
-
+                sync_fn = cast(Callable[..., T], fn)
+                return sync_fn(*args, **kwargs)
 
             return sync_wrapper
 
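The decorator hunks above share one typing pattern: `@overload` advertises both sync and async signatures, `asyncio.iscoroutinefunction` picks the branch at runtime, and `cast` narrows `fn` so each wrapper can call it precisely. A condensed standalone sketch of that pattern (illustrative names, not the library's API):

```python
import asyncio
import functools
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar, cast, overload

T = TypeVar("T")

@overload
def traced(fn: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]: ...
@overload
def traced(fn: Callable[..., T]) -> Callable[..., T]: ...

def traced(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap sync or async callables with the same decorator."""
    if asyncio.iscoroutinefunction(fn):
        # Narrow the union so the wrapper can await it without type errors.
        async_fn = cast(Callable[..., Awaitable[Any]], fn)

        @functools.wraps(fn)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            # tracing/bookkeeping would go here
            return await async_fn(*args, **kwargs)

        return async_wrapper

    sync_fn = cast(Callable[..., Any], fn)

    @functools.wraps(fn)
    def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
        return sync_fn(*args, **kwargs)

    return sync_wrapper
```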
synth_ai/tracing_v3/examples/basic_usage.py
CHANGED
@@ -2,13 +2,14 @@
 
 import asyncio
 import time
+from typing import Any
 
-from
-from
-from
+from .. import SessionTracer
+from ..abstractions import EnvironmentEvent, LMCAISEvent, RuntimeEvent, TimeRecord
+from ..turso.daemon import SqldDaemon
 
 
-async def simulate_llm_call(model: str, prompt: str) -> dict:
+async def simulate_llm_call(model: str, prompt: str) -> dict[str, Any]:
     """Simulate an LLM API call."""
     await asyncio.sleep(0.1)  # Simulate network latency
 
@@ -133,6 +134,9 @@ async def main():
     print("\n--- Example 3: Querying Data ---")
 
     # Get model usage statistics
+    if tracer.db is None:
+        raise RuntimeError("Tracer database backend is not initialized")
+
     model_usage = await tracer.db.get_model_usage()
     print("\nModel Usage:")
     print(model_usage)
@@ -150,9 +154,10 @@ async def main():
     # Get specific session details
     if recent_sessions:
         session_detail = await tracer.db.get_session_trace(recent_sessions[0]["session_id"])
-
-
-
+        if session_detail:
+            print(f"\nSession Detail for {session_detail['session_id']}:")
+            print(f"  Created: {session_detail['created_at']}")
+            print(f"  Timesteps: {len(session_detail['timesteps'])}")
 
     # Example 4: Using hooks
     print("\n--- Example 4: Hooks ---")