tetherai-python 0.1.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tetherai/__init__.py +42 -0
- tetherai/_version.py +1 -0
- tetherai/budget.py +118 -0
- tetherai/circuit_breaker.py +132 -0
- tetherai/config.py +91 -0
- tetherai/crewai/__init__.py +3 -0
- tetherai/crewai/integration.py +68 -0
- tetherai/exceptions.py +92 -0
- tetherai/exporter.py +60 -0
- tetherai/interceptor.py +258 -0
- tetherai/pricing.py +99 -0
- tetherai/token_counter.py +150 -0
- tetherai/trace.py +117 -0
- tetherai_python-0.1.0a0.dist-info/METADATA +35 -0
- tetherai_python-0.1.0a0.dist-info/RECORD +17 -0
- tetherai_python-0.1.0a0.dist-info/WHEEL +5 -0
- tetherai_python-0.1.0a0.dist-info/top_level.txt +1 -0
tetherai/__init__.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""TetherAI public API.

Re-exports the package version, configuration helpers, the exception
hierarchy, and the ``enforce_budget`` decorator under one namespace.
"""

from typing import Any

from tetherai._version import __version__
from tetherai.circuit_breaker import enforce_budget
from tetherai.config import TetherConfig, load_config
from tetherai.exceptions import (
    BudgetExceededError,
    TetherError,
    TokenCountError,
    TurnLimitError,
    UnknownModelError,
)


class Tether:
    """TetherAI namespace class."""

    enforce_budget = staticmethod(enforce_budget)


# Lowercase alias so callers can write ``tether.enforce_budget(...)``.
tether = Tether


def protect_crew(*args: Any, **kwargs: Any) -> Any:
    """Wrap a CrewAI crew with budget enforcement.

    The integration module is imported lazily so that ``crewai`` remains an
    optional dependency of the package.
    """
    from tetherai.crewai.integration import protect_crew as _impl

    return _impl(*args, **kwargs)


__all__ = [
    "BudgetExceededError",
    "TetherConfig",
    "TetherError",
    "TokenCountError",
    "TurnLimitError",
    "UnknownModelError",
    "__version__",
    "enforce_budget",
    "load_config",
    "protect_crew",
    "tether",
]
|
tetherai/_version.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
__version__ = "0.1.0-alpha"
|
tetherai/budget.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
import threading
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from tetherai.exceptions import BudgetExceededError, TurnLimitError
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class CallRecord:
    """Record of a single LLM call, kept for budget accounting and summaries.

    All attributes are plain values so the record can be serialized into a
    trace summary without conversion.
    """

    # Tokens sent in the prompt.
    input_tokens: int
    # Tokens produced by the completion.
    output_tokens: int
    # Model identifier as passed to the provider.
    model: str
    # Cost attributed to this call, in US dollars.
    cost_usd: float
    # Wall-clock duration of the call, in milliseconds.
    duration_ms: float
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class BudgetTracker:
|
|
18
|
+
def __init__(
|
|
19
|
+
self,
|
|
20
|
+
run_id: str,
|
|
21
|
+
max_usd: float,
|
|
22
|
+
max_turns: int | None = None,
|
|
23
|
+
):
|
|
24
|
+
self.run_id = run_id
|
|
25
|
+
self.max_usd = max_usd
|
|
26
|
+
self.max_turns = max_turns
|
|
27
|
+
self._spent_usd = 0.0
|
|
28
|
+
self._turn_count = 0
|
|
29
|
+
self._calls: list[CallRecord] = []
|
|
30
|
+
self._lock = threading.Lock()
|
|
31
|
+
|
|
32
|
+
@property
|
|
33
|
+
def spent_usd(self) -> float:
|
|
34
|
+
with self._lock:
|
|
35
|
+
return self._spent_usd
|
|
36
|
+
|
|
37
|
+
@property
|
|
38
|
+
def remaining_usd(self) -> float:
|
|
39
|
+
with self._lock:
|
|
40
|
+
return max(0.0, self.max_usd - self._spent_usd)
|
|
41
|
+
|
|
42
|
+
@property
|
|
43
|
+
def turn_count(self) -> int:
|
|
44
|
+
with self._lock:
|
|
45
|
+
return self._turn_count
|
|
46
|
+
|
|
47
|
+
@property
|
|
48
|
+
def is_exceeded(self) -> bool:
|
|
49
|
+
with self._lock:
|
|
50
|
+
return self._spent_usd >= self.max_usd
|
|
51
|
+
|
|
52
|
+
def pre_check(self, estimated_input_cost: float) -> None:
|
|
53
|
+
with self._lock:
|
|
54
|
+
projected = self._spent_usd + estimated_input_cost
|
|
55
|
+
if projected > self.max_usd:
|
|
56
|
+
raise BudgetExceededError(
|
|
57
|
+
message=f"Budget exceeded: ${projected:.2f} > ${self.max_usd:.2f}",
|
|
58
|
+
run_id=self.run_id,
|
|
59
|
+
budget_usd=self.max_usd,
|
|
60
|
+
spent_usd=projected,
|
|
61
|
+
last_model="unknown",
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
def record_call(
|
|
65
|
+
self,
|
|
66
|
+
input_tokens: int,
|
|
67
|
+
output_tokens: int,
|
|
68
|
+
model: str,
|
|
69
|
+
cost_usd: float,
|
|
70
|
+
duration_ms: float,
|
|
71
|
+
) -> None:
|
|
72
|
+
if cost_usd < 0:
|
|
73
|
+
raise ValueError("cost_usd must be non-negative")
|
|
74
|
+
|
|
75
|
+
with self._lock:
|
|
76
|
+
if self.max_turns is not None and self._turn_count >= self.max_turns:
|
|
77
|
+
raise TurnLimitError(
|
|
78
|
+
message=f"Turn limit exceeded: {self._turn_count} >= {self.max_turns}",
|
|
79
|
+
run_id=self.run_id,
|
|
80
|
+
max_turns=self.max_turns,
|
|
81
|
+
current_turn=self._turn_count + 1,
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
self._spent_usd += cost_usd
|
|
85
|
+
self._turn_count += 1
|
|
86
|
+
|
|
87
|
+
self._calls.append(
|
|
88
|
+
CallRecord(
|
|
89
|
+
input_tokens=input_tokens,
|
|
90
|
+
output_tokens=output_tokens,
|
|
91
|
+
model=model,
|
|
92
|
+
cost_usd=cost_usd,
|
|
93
|
+
duration_ms=duration_ms,
|
|
94
|
+
)
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
if self._spent_usd > self.max_usd:
|
|
98
|
+
self._spent_usd = self.max_usd
|
|
99
|
+
|
|
100
|
+
def get_summary(self) -> dict[str, Any]:
|
|
101
|
+
with self._lock:
|
|
102
|
+
return {
|
|
103
|
+
"run_id": self.run_id,
|
|
104
|
+
"budget_usd": self.max_usd,
|
|
105
|
+
"spent_usd": self._spent_usd,
|
|
106
|
+
"remaining_usd": max(0.0, self.max_usd - self._spent_usd),
|
|
107
|
+
"turn_count": self._turn_count,
|
|
108
|
+
"calls": [
|
|
109
|
+
{
|
|
110
|
+
"input_tokens": call.input_tokens,
|
|
111
|
+
"output_tokens": call.output_tokens,
|
|
112
|
+
"model": call.model,
|
|
113
|
+
"cost_usd": call.cost_usd,
|
|
114
|
+
"duration_ms": call.duration_ms,
|
|
115
|
+
}
|
|
116
|
+
for call in self._calls
|
|
117
|
+
],
|
|
118
|
+
}
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import functools
|
|
3
|
+
import uuid
|
|
4
|
+
from collections.abc import Callable
|
|
5
|
+
from typing import Any, TypeVar
|
|
6
|
+
|
|
7
|
+
from tetherai.budget import BudgetTracker
|
|
8
|
+
from tetherai.config import TetherConfig
|
|
9
|
+
from tetherai.exceptions import BudgetExceededError
|
|
10
|
+
from tetherai.exporter import get_exporter
|
|
11
|
+
from tetherai.interceptor import LLMInterceptor
|
|
12
|
+
from tetherai.pricing import PricingRegistry
|
|
13
|
+
from tetherai.token_counter import TokenCounter
|
|
14
|
+
from tetherai.trace import TraceCollector
|
|
15
|
+
|
|
16
|
+
F = TypeVar("F", bound=Callable[..., Any])
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def enforce_budget(
    max_usd: float,
    max_turns: int | None = None,
    on_exceed: str = "raise",
    trace_export: str | None = None,
) -> Callable[[F], F]:
    """Decorator that runs the wrapped callable under a per-run budget.

    Both sync and async callables are supported; the matching wrapper is
    chosen once, at decoration time.
    """

    def decorator(func: F) -> F:
        if asyncio.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
                return await _run_with_budget_async(
                    func, max_usd, max_turns, on_exceed, trace_export, *args, **kwargs
                )

            return async_wrapper  # type: ignore[return-value]

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            return _run_with_budget(
                func, max_usd, max_turns, on_exceed, trace_export, *args, **kwargs
            )

        return wrapper  # type: ignore[return-value]

    return decorator
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _prepare_run(
    max_usd: float,
    max_turns: int | None,
    trace_export: str | None,
) -> "tuple[TetherConfig, str, LLMInterceptor, TraceCollector]":
    """Build the per-run plumbing shared by the sync and async runners.

    Returns the loaded config, the resolved trace-export mode, an (inactive)
    interceptor wired to a fresh budget tracker, and the trace collector
    with a started trace.
    """
    run_id = f"run-{uuid.uuid4().hex[:8]}"
    config = TetherConfig()

    # Fall back to the configured export mode when the decorator did not
    # specify one explicitly.
    if trace_export is None:
        trace_export = config.trace_export

    budget_tracker = BudgetTracker(run_id=run_id, max_usd=max_usd, max_turns=max_turns)
    token_counter = TokenCounter()
    pricing = PricingRegistry()
    trace_collector = TraceCollector()

    trace_collector.start_trace(run_id, budget_tracker.get_summary())

    interceptor = LLMInterceptor(
        budget_tracker=budget_tracker,
        token_counter=token_counter,
        pricing=pricing,
        trace_collector=trace_collector,
    )
    return config, trace_export, interceptor, trace_collector


def _finish_run(
    interceptor: "LLMInterceptor",
    trace_collector: "TraceCollector",
    trace_export: str,
    config: "TetherConfig",
) -> None:
    """Deactivate the interceptor and export the collected trace, if any."""
    interceptor.deactivate()
    trace = trace_collector.end_trace()
    if trace and trace_export != "none":
        exporter = get_exporter(trace_export, config.trace_export_path)
        exporter.export(trace)


def _run_with_budget(
    func: Callable[..., Any],
    max_usd: float,
    max_turns: int | None,
    on_exceed: str,
    trace_export: str | None,
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Execute ``func`` with litellm patched for budget enforcement.

    When ``on_exceed == "return_none"`` a ``BudgetExceededError`` is
    converted into a ``None`` result; any other value re-raises.
    """
    config, trace_export, interceptor, trace_collector = _prepare_run(
        max_usd, max_turns, trace_export
    )

    try:
        interceptor.activate()
        return func(*args, **kwargs)
    except BudgetExceededError:
        if on_exceed == "return_none":
            return None
        raise
    finally:
        _finish_run(interceptor, trace_collector, trace_export, config)


async def _run_with_budget_async(
    func: Callable[..., Any],
    max_usd: float,
    max_turns: int | None,
    on_exceed: str,
    trace_export: str | None,
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Async counterpart of ``_run_with_budget`` with identical semantics."""
    config, trace_export, interceptor, trace_collector = _prepare_run(
        max_usd, max_turns, trace_export
    )

    try:
        interceptor.activate()
        return await func(*args, **kwargs)
    except BudgetExceededError:
        if on_exceed == "return_none":
            return None
        raise
    finally:
        _finish_run(interceptor, trace_collector, trace_export, config)
|
tetherai/config.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
import os
from dataclasses import dataclass
from typing import Any, Literal, cast, get_args
|
|
4
|
+
|
|
5
|
+
TokenCounterBackend = Literal["tiktoken", "litellm", "auto"]
|
|
6
|
+
PricingSource = Literal["bundled", "litellm"]
|
|
7
|
+
TraceExport = Literal["console", "json", "none", "otlp"]
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass(frozen=True)
|
|
11
|
+
class TetherConfig:
|
|
12
|
+
collector_url: str | None = None
|
|
13
|
+
default_budget_usd: float = 10.0
|
|
14
|
+
default_max_turns: int = 50
|
|
15
|
+
token_counter_backend: TokenCounterBackend = "auto"
|
|
16
|
+
pricing_source: PricingSource = "bundled"
|
|
17
|
+
log_level: str = "WARNING"
|
|
18
|
+
trace_export: TraceExport = "console"
|
|
19
|
+
trace_export_path: str = "./tetherai_traces/"
|
|
20
|
+
|
|
21
|
+
def __post_init__(self) -> None:
|
|
22
|
+
if self.default_budget_usd < 0:
|
|
23
|
+
raise ValueError("default_budget_usd must be non-negative")
|
|
24
|
+
|
|
25
|
+
if self.default_max_turns is not None and self.default_max_turns < 0:
|
|
26
|
+
raise ValueError("default_max_turns must be non-negative")
|
|
27
|
+
|
|
28
|
+
valid_backends = ("tiktoken", "litellm", "auto")
|
|
29
|
+
if self.token_counter_backend not in valid_backends:
|
|
30
|
+
raise ValueError(
|
|
31
|
+
f"Invalid token_counter_backend: {self.token_counter_backend}. "
|
|
32
|
+
f"Must be one of {valid_backends}"
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
valid_pricing = ("bundled", "litellm")
|
|
36
|
+
if self.pricing_source not in valid_pricing:
|
|
37
|
+
raise ValueError(
|
|
38
|
+
f"Invalid pricing_source: {self.pricing_source}. Must be one of {valid_pricing}"
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
valid_export = ("console", "json", "none", "otlp")
|
|
42
|
+
if self.trace_export not in valid_export:
|
|
43
|
+
raise ValueError(
|
|
44
|
+
f"Invalid trace_export: {self.trace_export}. Must be one of {valid_export}"
|
|
45
|
+
)
|
|
46
|
+
|
|
47
|
+
@classmethod
|
|
48
|
+
def from_env(cls) -> "TetherConfig":
|
|
49
|
+
return cls(
|
|
50
|
+
collector_url=os.getenv("TETHERAI_COLLECTOR_URL"),
|
|
51
|
+
default_budget_usd=float(os.getenv("TETHERAI_DEFAULT_BUDGET_USD", "10.0")),
|
|
52
|
+
default_max_turns=int(os.getenv("TETHERAI_DEFAULT_MAX_TURNS", "50")),
|
|
53
|
+
token_counter_backend=cls._resolve_backend(
|
|
54
|
+
os.getenv("TETHERAI_TOKEN_COUNTER_BACKEND", "auto")
|
|
55
|
+
),
|
|
56
|
+
pricing_source=cast(
|
|
57
|
+
PricingSource, os.getenv("TETHERAI_PRICING_SOURCE", "bundled") or "bundled"
|
|
58
|
+
),
|
|
59
|
+
log_level=os.getenv("TETHERAI_LOG_LEVEL", "WARNING"),
|
|
60
|
+
trace_export=cast(
|
|
61
|
+
TraceExport, os.getenv("TETHERAI_TRACE_EXPORT", "console") or "console"
|
|
62
|
+
),
|
|
63
|
+
trace_export_path=os.getenv("TETHERAI_TRACE_EXPORT_PATH", "./tetherai_traces/"),
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
@staticmethod
|
|
67
|
+
def _resolve_backend(backend: str) -> TokenCounterBackend:
|
|
68
|
+
if backend == "auto":
|
|
69
|
+
try:
|
|
70
|
+
import litellm # noqa: F401
|
|
71
|
+
|
|
72
|
+
return "litellm"
|
|
73
|
+
except ImportError:
|
|
74
|
+
return "tiktoken"
|
|
75
|
+
return backend # type: ignore[return-value]
|
|
76
|
+
|
|
77
|
+
def resolve_backend(self) -> TokenCounterBackend:
|
|
78
|
+
return self._resolve_backend(self.token_counter_backend)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def load_config(**kwargs: Any) -> TetherConfig:
    """Build a ``TetherConfig`` from the environment, overridden by ``kwargs``.

    Environment variables provide the baseline; a keyword argument whose
    name matches a config field takes precedence over the environment value.
    Keyword arguments that do not match any field are silently ignored.
    """
    env_config = TetherConfig.from_env()

    merged = {
        name: kwargs.get(name, getattr(env_config, name))
        for name in TetherConfig.__dataclass_fields__
    }
    return TetherConfig(**merged)
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
from typing import TYPE_CHECKING, Any
|
|
3
|
+
|
|
4
|
+
if TYPE_CHECKING:
|
|
5
|
+
from crewai import Crew
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _check_crewai_installed() -> None:
    """Fail fast with an actionable message when the crewai extra is missing."""
    try:
        import crewai  # noqa: F401
    except ImportError:
        msg = "crewai is not installed. Install it with: pip install tetherai[crewai]"
        raise ImportError(msg) from None
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def protect_crew(
    crew: "Crew",
    max_usd: float,
    max_turns: int | None = None,
) -> "Crew":
    """Wrap ``crew.kickoff`` with budget enforcement and return the crew.

    Existing agent step callbacks and task callbacks are re-wrapped in a
    pass-through shim; callbacks that are unset are left untouched.
    Raises ``ImportError`` when crewai is not installed.
    """
    _check_crewai_installed()

    from tetherai.circuit_breaker import enforce_budget

    original_kickoff = crew.kickoff

    @enforce_budget(max_usd=max_usd, max_turns=max_turns)
    def wrapped_kickoff(*args: Any, **kwargs: Any) -> Any:
        return original_kickoff(*args, **kwargs)

    crew.kickoff = wrapped_kickoff  # type: ignore[method-assign]

    def _wrap(original: Callable[..., Any] | None) -> Callable[..., Any]:
        # Factory so each shim closes over its own ``original`` rather than a
        # loop variable (avoids the classic late-binding pitfall).
        def shim(output: Any) -> None:
            if original:
                original(output)

        return shim

    for agent in crew.agents:
        existing_step = getattr(agent, "step_callback", None)
        if existing_step is not None:
            agent.step_callback = _wrap(existing_step)  # type: ignore[attr-defined]

    for task in crew.tasks:
        existing_task = getattr(task, "callback", None)
        if existing_task is not None:
            task.callback = _wrap(existing_task)

    return crew
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def tether_step_callback(step_output: Any) -> None:
    """Placeholder CrewAI step callback; intentionally a no-op."""


def tether_task_callback(task_output: Any) -> None:
    """Placeholder CrewAI task callback; intentionally a no-op."""
|
tetherai/exceptions.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
class TetherError(Exception):
|
|
2
|
+
"""Base exception for all TetherAI errors."""
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class BudgetExceededError(TetherError):
|
|
6
|
+
"""Raised when a run's accumulated cost exceeds its budget."""
|
|
7
|
+
|
|
8
|
+
def __init__(
|
|
9
|
+
self,
|
|
10
|
+
message: str,
|
|
11
|
+
run_id: str,
|
|
12
|
+
budget_usd: float,
|
|
13
|
+
spent_usd: float,
|
|
14
|
+
last_model: str,
|
|
15
|
+
trace_url: str | None = None,
|
|
16
|
+
) -> None:
|
|
17
|
+
super().__init__(message)
|
|
18
|
+
self.run_id = run_id
|
|
19
|
+
self.budget_usd = budget_usd
|
|
20
|
+
self.spent_usd = spent_usd
|
|
21
|
+
self.last_model = last_model
|
|
22
|
+
self.trace_url = trace_url
|
|
23
|
+
|
|
24
|
+
def __str__(self) -> str:
|
|
25
|
+
return (
|
|
26
|
+
f"Budget exceeded: ${self.spent_usd:.2f} / ${self.budget_usd:.2f} on run {self.run_id}"
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
def __reduce__(self) -> tuple[type, tuple]: # type: ignore[type-arg]
|
|
30
|
+
return (
|
|
31
|
+
self.__class__,
|
|
32
|
+
(
|
|
33
|
+
self.args[0],
|
|
34
|
+
self.run_id,
|
|
35
|
+
self.budget_usd,
|
|
36
|
+
self.spent_usd,
|
|
37
|
+
self.last_model,
|
|
38
|
+
self.trace_url,
|
|
39
|
+
),
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class TurnLimitError(TetherError):
|
|
44
|
+
"""Raised when an agent exceeds max allowed LLM calls."""
|
|
45
|
+
|
|
46
|
+
def __init__(
|
|
47
|
+
self,
|
|
48
|
+
message: str,
|
|
49
|
+
run_id: str,
|
|
50
|
+
max_turns: int,
|
|
51
|
+
current_turn: int,
|
|
52
|
+
) -> None:
|
|
53
|
+
super().__init__(message)
|
|
54
|
+
self.run_id = run_id
|
|
55
|
+
self.max_turns = max_turns
|
|
56
|
+
self.current_turn = current_turn
|
|
57
|
+
|
|
58
|
+
def __str__(self) -> str:
|
|
59
|
+
return f"Turn limit exceeded: {self.current_turn} / {self.max_turns} on run {self.run_id}"
|
|
60
|
+
|
|
61
|
+
def __reduce__(self) -> tuple[type, tuple]: # type: ignore[type-arg]
|
|
62
|
+
return (
|
|
63
|
+
self.__class__,
|
|
64
|
+
(
|
|
65
|
+
self.args[0],
|
|
66
|
+
self.run_id,
|
|
67
|
+
self.max_turns,
|
|
68
|
+
self.current_turn,
|
|
69
|
+
),
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class TokenCountError(TetherError):
|
|
74
|
+
"""Raised when token counting fails (e.g., unknown encoding)."""
|
|
75
|
+
|
|
76
|
+
def __init__(self, message: str, model: str | None = None) -> None:
|
|
77
|
+
super().__init__(message)
|
|
78
|
+
self.model = model
|
|
79
|
+
|
|
80
|
+
def __reduce__(self) -> tuple[type, tuple]: # type: ignore[type-arg]
|
|
81
|
+
return (self.__class__, (self.args[0], self.model))
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class UnknownModelError(TetherError):
|
|
85
|
+
"""Raised when a model is not found in the pricing registry."""
|
|
86
|
+
|
|
87
|
+
def __init__(self, message: str, model: str) -> None:
|
|
88
|
+
super().__init__(message)
|
|
89
|
+
self.model = model
|
|
90
|
+
|
|
91
|
+
def __reduce__(self) -> tuple[type, tuple]: # type: ignore[type-arg]
|
|
92
|
+
return (self.__class__, (self.args[0], self.model))
|
tetherai/exporter.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import sys
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Protocol, runtime_checkable
|
|
5
|
+
|
|
6
|
+
from tetherai.trace import Trace
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@runtime_checkable
class TraceExporter(Protocol):
    """Structural interface: anything with an ``export(trace)`` method."""

    def export(self, trace: Trace) -> None: ...


class ConsoleExporter:
    """Pretty-prints a trace summary and its spans to stderr."""

    def export(self, trace: Trace) -> None:
        err = sys.stderr
        print(f"=== TetherAI Trace: {trace.run_id} ===", file=err)
        print(f"Total Cost: ${trace.total_cost:.4f}", file=err)
        print(f"Input Tokens: {trace.total_input_tokens}", file=err)
        print(f"Output Tokens: {trace.total_output_tokens}", file=err)
        print(f"Spans: {len(trace.spans)}", file=err)
        print(file=err)

        for i, span in enumerate(trace.spans):
            print(f" [{i + 1}] {span.span_type}: {span.model or 'N/A'}", file=err)
            if span.cost_usd is not None:
                print(f" Cost: ${span.cost_usd:.6f}", file=err)
            if span.input_tokens:
                print(f" Input: {span.input_tokens} tokens", file=err)
            if span.output_tokens:
                print(f" Output: {span.output_tokens} tokens", file=err)
            print(file=err)


class JSONFileExporter:
    """Writes each trace as ``<run_id>.json`` under ``output_dir``."""

    def __init__(self, output_dir: str = "./tetherai_traces/"):
        self.output_dir = Path(output_dir)

    def export(self, trace: Trace) -> None:
        # Create the target directory lazily so construction stays cheap.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        target = self.output_dir / f"{trace.run_id}.json"

        with open(target, "w") as f:
            json.dump(trace.to_dict(), f, indent=2)


class NoopExporter:
    """Discards traces; used when exporting is disabled."""

    def export(self, trace: Trace) -> None:
        pass


def get_exporter(exporter_type: str, output_dir: str = "./tetherai_traces/") -> TraceExporter:
    """Map an exporter name to an instance.

    Raises ``ValueError`` for names other than "console", "json", "none",
    or "noop".
    """
    if exporter_type == "console":
        return ConsoleExporter()
    if exporter_type == "json":
        return JSONFileExporter(output_dir=output_dir)
    if exporter_type in ("none", "noop"):
        return NoopExporter()
    raise ValueError(f"Unknown exporter type: {exporter_type}")
|
tetherai/interceptor.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
import time
|
|
2
|
+
from collections.abc import Callable
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from tetherai.budget import BudgetTracker
|
|
6
|
+
from tetherai.exceptions import BudgetExceededError, TetherError
|
|
7
|
+
from tetherai.pricing import PricingRegistry
|
|
8
|
+
from tetherai.token_counter import TokenCounter
|
|
9
|
+
from tetherai.trace import Span, TraceCollector
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class LLMInterceptor:
    """Patches ``litellm.completion``/``litellm.acompletion`` so every LLM
    call is budget-checked before it runs and accounted for afterwards.

    If litellm is not installed, ``activate`` is a no-op and calls pass
    through unobserved.
    """

    def __init__(
        self,
        budget_tracker: BudgetTracker,
        token_counter: TokenCounter,
        pricing: PricingRegistry,
        trace_collector: TraceCollector,
    ):
        self.budget_tracker = budget_tracker
        self.token_counter = token_counter
        self.pricing = pricing
        self.trace_collector = trace_collector

        # Saved originals so deactivate() can restore litellm untouched.
        self._original_completion: Callable[..., Any] | None = None
        self._original_acompletion: Callable[..., Any] | None = None
        self._active = False

    def activate(self) -> None:
        """Install the patches. Raises ``TetherError`` if already active."""
        if self._active:
            raise TetherError("Interceptor is already active")

        try:
            import litellm
        except ImportError:
            # litellm absent: nothing to patch, stay inactive.
            return

        self._original_completion = litellm.completion
        self._original_acompletion = litellm.acompletion

        litellm.completion = self._patched_completion
        litellm.acompletion = self._patched_acompletion
        self._active = True

    def deactivate(self) -> None:
        """Restore the original litellm entry points (safe to call twice)."""
        if not self._active:
            return

        try:
            import litellm
        except ImportError:
            self._active = False
            return

        if self._original_completion:
            litellm.completion = self._original_completion
        if self._original_acompletion:
            litellm.acompletion = self._original_acompletion

        self._active = False

    def __enter__(self) -> "LLMInterceptor":
        self.activate()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.deactivate()

    def _patched_completion(self, *args: Any, **kwargs: Any) -> Any:
        return self._intercept_call(self._original_completion, *args, **kwargs)

    async def _patched_acompletion(self, *args: Any, **kwargs: Any) -> Any:
        return await self._intercept_call_async(self._original_acompletion, *args, **kwargs)

    def _pre_call(
        self, args: "tuple[Any, ...]", kwargs: "dict[str, Any]"
    ) -> "tuple[str, int, Span]":
        """Shared pre-flight: estimate tokens, budget pre-check, open a span.

        Raises ``BudgetExceededError`` when the estimated input cost would
        push the run past its budget. Token counting and pricing failures
        are best-effort (treated as zero) so unknown models never block a
        call outright.
        """
        model = kwargs.get("model", args[0] if args else "unknown")
        messages = kwargs.get("messages", [])

        try:
            input_tokens = self.token_counter.count_messages(messages, model)
        except Exception:
            input_tokens = 0

        try:
            estimated_input_cost = self.pricing.get_input_cost(model) * input_tokens
        except Exception:
            estimated_input_cost = 0

        self.budget_tracker.pre_check(estimated_input_cost)

        span = Span(
            run_id=self.budget_tracker.run_id,
            span_type="llm_call",
            model=model,
            input_tokens=input_tokens,
            input_preview=messages[0].get("content", "")[:200] if messages else None,
        )
        self.trace_collector.add_span(span)
        return model, input_tokens, span

    def _post_call(
        self,
        span: "Span",
        response: Any,
        model: str,
        input_tokens: int,
        duration_ms: float,
    ) -> None:
        """Shared accounting: read usage, price the call, record it.

        Parsing/pricing failures are swallowed (accounting is best-effort),
        but ``TetherError`` raised by ``record_call`` is deliberately
        propagated. BUG FIX: the previous broad ``except Exception: pass``
        also swallowed ``TurnLimitError``, so turn limits were never
        enforced mid-run.
        """
        try:
            usage = getattr(response, "usage", None)
            if usage:
                output_tokens = getattr(usage, "completion_tokens", 0)
                actual_input_tokens = getattr(usage, "prompt_tokens", input_tokens)
            else:
                output_tokens = 0
                actual_input_tokens = input_tokens

            cost_usd = self.pricing.estimate_call_cost(model, actual_input_tokens, output_tokens)

            span.output_tokens = output_tokens
            span.input_tokens = actual_input_tokens
            span.cost_usd = cost_usd
            span.duration_ms = duration_ms
            span.status = "ok"

            try:
                content = response.choices[0].message.content if response.choices else ""
                span.output_preview = content[:200] if content else None
            except Exception:
                pass

            self.budget_tracker.record_call(
                actual_input_tokens,
                output_tokens,
                model,
                cost_usd,
                duration_ms,
            )
        except TetherError:
            raise
        except Exception:
            pass

    def _intercept_call(
        self, original_fn: Callable[..., Any] | None, *args: Any, **kwargs: Any
    ) -> Any:
        """Sync interception pipeline around ``litellm.completion``."""
        if original_fn is None:
            raise TetherError("Interceptor not properly activated")

        start_time = time.time()
        model, input_tokens, span = self._pre_call(args, kwargs)

        try:
            response = original_fn(*args, **kwargs)
        except Exception:
            span.status = "error"
            span.duration_ms = (time.time() - start_time) * 1000
            raise

        self._post_call(span, response, model, input_tokens, (time.time() - start_time) * 1000)
        return response

    async def _intercept_call_async(
        self, original_fn: Callable[..., Any] | None, *args: Any, **kwargs: Any
    ) -> Any:
        """Async interception pipeline around ``litellm.acompletion``."""
        if original_fn is None:
            raise TetherError("Interceptor not properly activated")

        start_time = time.time()
        model, input_tokens, span = self._pre_call(args, kwargs)

        try:
            response = await original_fn(*args, **kwargs)
        except Exception:
            span.status = "error"
            span.duration_ms = (time.time() - start_time) * 1000
            raise

        self._post_call(span, response, model, input_tokens, (time.time() - start_time) * 1000)
        return response

    def track_call(
        self,
        model: str,
        input_tokens: int,
        output_tokens: int,
    ) -> None:
        """Manually account for a call made outside the patched entry points.

        Runs the budget pre-check, records the call, and emits a span with
        zero duration. Raises ``BudgetExceededError``/``TurnLimitError``
        like the automatic path.
        """
        cost_usd = self.pricing.estimate_call_cost(model, input_tokens, output_tokens)
        self.budget_tracker.pre_check(self.pricing.get_input_cost(model) * input_tokens)
        self.budget_tracker.record_call(
            input_tokens,
            output_tokens,
            model,
            cost_usd,
            0.0,
        )

        span = Span(
            run_id=self.budget_tracker.run_id,
            span_type="llm_call",
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cost_usd=cost_usd,
            status="ok",
        )
        self.trace_collector.add_span(span)
|
tetherai/pricing.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
from tetherai.exceptions import UnknownModelError
|
|
2
|
+
|
|
3
|
+
# Static fallback pricing shipped with the package: model name ->
# (input_rate, output_rate) in USD. NOTE(review): these values match
# publicly listed per-1K-token rates, but estimate_call_cost multiplies
# them by raw token counts — confirm the intended unit (per-token vs
# per-1K tokens).
BUNDLED_PRICING: dict[str, tuple[float, float]] = {
    # OpenAI
    "gpt-4o": (0.0025, 0.01),
    "gpt-4o-mini": (0.00015, 0.0006),
    "gpt-4-turbo": (0.01, 0.03),
    "gpt-4": (0.03, 0.06),
    "gpt-3.5-turbo": (0.0005, 0.002),
    # Anthropic (dated and undated spellings priced identically)
    "claude-3-5-sonnet-20241022": (0.003, 0.015),
    "claude-3-5-sonnet": (0.003, 0.015),
    "claude-3-opus-20240229": (0.015, 0.075),
    "claude-3-opus": (0.015, 0.075),
    "claude-3-sonnet-20240229": (0.003, 0.015),
    "claude-3-sonnet": (0.003, 0.015),
    "claude-3-haiku-20240307": (0.00025, 0.00125),
    "claude-3-haiku": (0.00025, 0.00125),
    # Google
    "gemini-1.5-pro": (0.00125, 0.005),
    "gemini-1.5-flash": (0.000075, 0.0003),
    "gemini-1.5-flash-8b": (0.0000375, 0.00015),
    # Open-weights / hosted (symmetric input/output rates for llama, mixtral)
    "llama-3-70b": (0.0008, 0.0008),
    "llama-3-8b": (0.0002, 0.0002),
    "mixtral-8x7b": (0.00024, 0.00024),
    "mistral-small": (0.001, 0.003),
    "mistral-medium": (0.0024, 0.0072),
    "mistral-large": (0.004, 0.012),
}
|
|
27
|
+
|
|
28
|
+
# Shorthand -> canonical model-name map used by
# PricingRegistry.resolve_model_alias. Keys are matched against the
# lowercased, stripped input name. Identity entries (e.g. "gpt-4o") are
# redundant — resolve_model_alias falls back to the input name anyway —
# but harmless.
MODEL_ALIASES: dict[str, str] = {
    "gpt4o": "gpt-4o",
    "gpt-4o": "gpt-4o",
    "gpt4o-mini": "gpt-4o-mini",
    "gpt-4-turbo": "gpt-4-turbo",
    "gpt4": "gpt-4",
    "gpt-4": "gpt-4",
    "gpt-3.5-turbo": "gpt-3.5-turbo",
    # Undated Claude shorthands resolve to a specific dated snapshot.
    "claude-sonnet": "claude-3-5-sonnet-20241022",
    "claude-3.5-sonnet": "claude-3-5-sonnet-20241022",
    "claude-opus": "claude-3-opus-20240229",
    "claude-3-opus": "claude-3-opus-20240229",
    "claude-sonnet-20240229": "claude-3-sonnet-20240229",
    "claude-3-sonnet-20240229": "claude-3-sonnet-20240229",
    "claude-haiku": "claude-3-haiku-20240307",
    "claude-3-haiku": "claude-3-haiku-20240307",
}
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class PricingRegistry:
    """Per-call USD pricing lookup for LLM models.

    Resolution order for a model name: alias resolution first, then
    custom-registered overrides, then the bundled table, then (only when
    constructed with source="litellm") a litellm fallback. Unknown models
    raise UnknownModelError.
    """

    def __init__(self, source: str = "bundled"):
        # source="litellm" enables the litellm fallback for unknown models.
        self._source = source
        self._custom_models: dict[str, tuple[float, float]] = {}
        # Copy so per-instance mutation never touches the module constant.
        self._bundled = BUNDLED_PRICING.copy()

    def get_input_cost(self, model: str) -> float:
        """Return the input-side (prompt) cost rate for *model*."""
        return self._lookup_cost(model, "input")

    def get_output_cost(self, model: str) -> float:
        """Return the output-side (completion) cost rate for *model*."""
        return self._lookup_cost(model, "output")

    def _lookup_cost(self, model: str, direction: str) -> float:
        # Shared lookup for get_input_cost/get_output_cost. Tuple slot 0 is
        # the input rate, slot 1 the output rate.
        idx = 0 if direction == "input" else 1
        resolved = self.resolve_model_alias(model)
        if resolved in self._custom_models:
            return self._custom_models[resolved][idx]
        if resolved in self._bundled:
            return self._bundled[resolved][idx]
        if self._source == "litellm":
            return self._get_litellm_cost(model, direction)
        raise UnknownModelError(f"Unknown model: {model}", model)

    def estimate_call_cost(self, model: str, input_tokens: int, output_tokens: int) -> float:
        """Estimate the USD cost of one call from its token counts.

        NOTE(review): rates are multiplied by raw token counts here, while
        the bundled values look like per-1K-token figures — confirm the
        intended unit.
        """
        input_cost = self.get_input_cost(model) * input_tokens
        output_cost = self.get_output_cost(model) * output_tokens
        return input_cost + output_cost

    def resolve_model_alias(self, model: str) -> str:
        """Map a shorthand model name to its canonical name.

        Matching is done on the lowercased, stripped input; unknown names
        pass through unchanged (original casing preserved).
        """
        normalized = model.lower().strip()
        return MODEL_ALIASES.get(normalized, model)

    def register_custom_model(self, model: str, input_cost: float, output_cost: float) -> None:
        """Register custom pricing for *model*, overriding bundled rates.

        Bug fix: the name is alias-resolved before storing. Lookups resolve
        aliases first, so a registration made under an alias (e.g. "gpt4o")
        was previously stored under the raw name and silently ignored in
        favor of the bundled price.
        """
        self._custom_models[self.resolve_model_alias(model)] = (input_cost, output_cost)

    def _get_litellm_cost(self, model: str, direction: str) -> float:
        # Fallback for models absent from both tables; reached only when
        # source="litellm".
        try:
            import litellm
        except ImportError:
            raise UnknownModelError(
                f"Unknown model: {model} (litellm not installed)", model
            ) from None
        # NOTE(review): litellm's documented API is
        # cost_per_token(model=..., prompt_tokens=..., completion_tokens=...)
        # returning a (prompt_cost, completion_cost) tuple — confirm this
        # positional "direction" call against the pinned litellm version.
        cost = litellm.cost_per_token(model, direction)  # type: ignore[arg-type,attr-defined]
        if isinstance(cost, tuple):
            return cost[0] if direction == "input" else cost[1]
        return cost
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def get_pricing_registry(source: str = "bundled") -> PricingRegistry:
    """Construct a fresh PricingRegistry backed by the given *source*."""
    registry = PricingRegistry(source=source)
    return registry
|
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from tetherai.exceptions import TokenCountError
|
|
5
|
+
|
|
6
|
+
logger = logging.getLogger(__name__)

# Process-wide cache of loaded tiktoken encoders, keyed by encoding name
# (see _get_tiktoken_encoder).
TOKENIZER_CACHE: dict[str, Any] = {}

# Per-role ChatML wrappers used to approximate chat-message token counts
# with tiktoken (see TokenCounter._count_messages_with_tiktoken).
CHATML_FORMATTING = {
    "system": {"prefix": "<|im_start|>system\n", "suffix": "<|im_end|>\n"},
    "user": {"prefix": "<|im_start|>user\n", "suffix": "<|im_end|>\n"},
    "assistant": {"prefix": "<|im_start|>assistant\n", "suffix": "<|im_end|>\n"},
    "tool": {"prefix": "<|im_start|>tool\n", "suffix": "<|im_end|>\n"},
}
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _get_tiktoken_encoder(encoding_name: str = "cl100k_base") -> Any:
    """Return a tiktoken encoder for *encoding_name*, caching it globally.

    Raises TokenCountError if tiktoken is unavailable or the encoding
    cannot be loaded.
    """
    cached = TOKENIZER_CACHE.get(encoding_name)
    if cached is not None:
        return cached

    try:
        import tiktoken

        encoder = tiktoken.get_encoding(encoding_name)
    except Exception as e:
        raise TokenCountError(f"Failed to load tiktoken: {e}") from e

    TOKENIZER_CACHE[encoding_name] = encoder
    return encoder
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _get_litellm_tokenizer(model: str) -> Any:
    """Return litellm's token_counter callable.

    Raises TokenCountError (carrying *model*) when litellm is not installed.
    """
    try:
        import litellm
    except ImportError as e:
        raise TokenCountError("litellm not installed", model) from e
    return litellm.token_counter  # type: ignore[attr-defined]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class TokenCounter:
    """Count tokens for plain text and chat-message lists.

    backend="auto" selects litellm when it is importable and otherwise
    falls back to tiktoken. The tiktoken encoder is loaded eagerly when
    tiktoken is the chosen backend; the litellm tokenizer is resolved
    lazily on first use.
    """

    def __init__(self, backend: str = "auto"):
        self._backend = backend
        self._tiktoken_encoder = None
        self._litellm_tokenizer = None

        if backend == "auto":
            # Prefer litellm when available.
            try:
                import litellm  # noqa: F401
            except ImportError:
                self._backend = "tiktoken"
            else:
                self._backend = "litellm"

        if self._backend == "tiktoken":
            self._tiktoken_encoder = _get_tiktoken_encoder()

    def count_tokens(self, text: str, model: str = "gpt-4o") -> int:
        """Return the token count of *text* for *model*; empty text is 0.

        Raises TokenCountError for an unrecognized backend.
        """
        if not text:
            return 0
        if self._backend == "tiktoken":
            return self._count_with_tiktoken(text, model)
        if self._backend == "litellm":
            return self._count_with_litellm(text, model)
        raise TokenCountError(f"Unknown backend: {self._backend}")

    def count_messages(self, messages: list[dict[str, str]], model: str = "gpt-4o") -> int:
        """Return the token count of a chat *messages* list; empty list is 0.

        Raises TokenCountError for an unrecognized backend.
        """
        if not messages:
            return 0
        if self._backend == "tiktoken":
            return self._count_messages_with_tiktoken(messages, model)
        if self._backend == "litellm":
            return self._count_messages_with_litellm(messages, model)
        raise TokenCountError(f"Unknown backend: {self._backend}")

    def _count_with_tiktoken(self, text: str, model: str) -> int:
        # tiktoken models OpenAI tokenizers only, so warn for Claude names.
        if model.startswith("claude-"):
            logger.warning(
                f"Using tiktoken for Claude model {model}. "
                f"Token counts may be inaccurate (up to 12% error)."
            )

        encoder = self._tiktoken_encoder
        if encoder is None:
            encoder = _get_tiktoken_encoder()
        return len(encoder.encode(text))

    def _count_with_litellm(self, text: str, model: str) -> int:
        if self._litellm_tokenizer is None:
            self._litellm_tokenizer = _get_litellm_tokenizer(model)

        if not model.startswith("claude-"):
            return self._litellm_tokenizer(model=model, text=text)  # type: ignore[no-any-return,misc]

        # For Claude models, any litellm failure degrades to tiktoken.
        try:
            return self._litellm_tokenizer(model=model, text=text)  # type: ignore[no-any-return,misc]
        except Exception:
            logger.warning(
                f"litellm token_counter failed for {model}, falling back to tiktoken"
            )
            return self._count_with_tiktoken(text, model)

    def _count_messages_with_tiktoken(self, messages: list[dict[str, str]], model: str) -> int:
        encoder = self._tiktoken_encoder
        if encoder is None:
            encoder = _get_tiktoken_encoder()

        if model.startswith("claude-"):
            logger.warning(
                f"Using tiktoken for Claude model {model}. Token counts may be inaccurate."
            )

        # Wrap each message in ChatML markers so structural tokens are counted.
        total = 0
        for msg in messages:
            wrapper = CHATML_FORMATTING.get(msg.get("role", "user"), CHATML_FORMATTING["user"])
            rendered = f"{wrapper['prefix']}{msg.get('content', '')}{wrapper['suffix']}"
            total += len(encoder.encode(rendered))
        # Flat 3-token per-request overhead, matching the original accounting.
        return total + 3

    def _count_messages_with_litellm(self, messages: list[dict[str, str]], model: str) -> int:
        if self._litellm_tokenizer is None:
            self._litellm_tokenizer = _get_litellm_tokenizer(model)

        try:
            return self._litellm_tokenizer(model=model, messages=messages)  # type: ignore[no-any-return,misc]
        except Exception:
            logger.warning(f"litellm token_counter failed for {model}, falling back to tiktoken")
            return self._count_messages_with_tiktoken(messages, model)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def count_tokens(text: str, model: str = "gpt-4o", backend: str = "auto") -> int:
    """Module-level convenience: count *text* tokens with a fresh TokenCounter."""
    return TokenCounter(backend=backend).count_tokens(text, model)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def count_messages(
    messages: list[dict[str, str]], model: str = "gpt-4o", backend: str = "auto"
) -> int:
    """Module-level convenience: count chat-message tokens with a fresh TokenCounter."""
    return TokenCounter(backend=backend).count_messages(messages, model)
|
tetherai/trace.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
# Span input/output previews are truncated to this many characters
# (an "..." marker is appended past the limit; see Span.__post_init__).
MAX_PREVIEW_LENGTH = 200
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def generate_id() -> str:
    """Return a random 16-character lowercase-hex identifier."""
    from uuid import uuid4

    return uuid4().hex[:16]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class Span:
    """One traced operation (e.g. an LLM call) with token/cost accounting."""

    span_id: str = field(default_factory=generate_id)
    parent_span_id: str | None = None
    run_id: str = ""
    timestamp: datetime = field(default_factory=datetime.now)
    duration_ms: float = 0.0
    span_type: str = "llm_call"
    model: str | None = None
    input_tokens: int | None = None
    output_tokens: int | None = None
    cost_usd: float | None = None
    status: str = "ok"
    metadata: dict[str, Any] = field(default_factory=dict)
    input_preview: str | None = None
    output_preview: str | None = None

    def __post_init__(self) -> None:
        # Cap both previews so traces stay small; truncation is marked "...".
        for attr in ("input_preview", "output_preview"):
            value = getattr(self, attr)
            if value and len(value) > MAX_PREVIEW_LENGTH:
                setattr(self, attr, value[:MAX_PREVIEW_LENGTH] + "...")

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-friendly dict (timestamp as ISO-8601)."""
        return {
            "span_id": self.span_id,
            "parent_span_id": self.parent_span_id,
            "run_id": self.run_id,
            "timestamp": self.timestamp.isoformat(),
            "duration_ms": self.duration_ms,
            "span_type": self.span_type,
            "model": self.model,
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "cost_usd": self.cost_usd,
            "status": self.status,
            "metadata": self.metadata,
            "input_preview": self.input_preview,
            "output_preview": self.output_preview,
        }
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass
class Trace:
    """A complete run: collected spans plus budget summary and timing."""

    run_id: str
    spans: list[Span] = field(default_factory=list)
    budget_summary: dict[str, Any] = field(default_factory=dict)
    start_time: datetime = field(default_factory=datetime.now)
    end_time: datetime | None = None

    def add_span(self, span: Span) -> None:
        """Append one span to this trace."""
        self.spans.append(span)

    @property
    def total_cost(self) -> float:
        """Sum of span costs; spans without a cost contribute 0."""
        return sum(s.cost_usd or 0 for s in self.spans)

    @property
    def total_input_tokens(self) -> int:
        """Sum of span input-token counts; missing counts contribute 0."""
        return sum(s.input_tokens or 0 for s in self.spans)

    @property
    def total_output_tokens(self) -> int:
        """Sum of span output-token counts; missing counts contribute 0."""
        return sum(s.output_tokens or 0 for s in self.spans)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a JSON-friendly dict (datetimes as ISO-8601)."""
        serialized_end = self.end_time.isoformat() if self.end_time else None
        return {
            "run_id": self.run_id,
            "spans": [s.to_dict() for s in self.spans],
            "budget_summary": self.budget_summary,
            "start_time": self.start_time.isoformat(),
            "end_time": serialized_end,
            "total_cost": self.total_cost,
            "total_input_tokens": self.total_input_tokens,
            "total_output_tokens": self.total_output_tokens,
        }
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class TraceCollector:
    """Holds at most one active Trace and routes spans into it."""

    def __init__(self) -> None:
        self._current_trace: Trace | None = None

    def start_trace(self, run_id: str, budget_summary: dict[str, Any] | None = None) -> Trace:
        """Begin a new trace (replacing any active one) and return it."""
        trace = Trace(
            run_id=run_id,
            budget_summary=budget_summary or {},
        )
        self._current_trace = trace
        return trace

    def end_trace(self) -> Trace | None:
        """Stamp end_time on the active trace, clear it, and return it.

        Returns None when no trace is active.
        """
        active = self._current_trace
        if not active:
            return None
        active.end_time = datetime.now()
        self._current_trace = None
        return active

    def add_span(self, span: Span) -> None:
        """Attach *span* to the active trace; no-op when none is active."""
        if self._current_trace:
            self._current_trace.add_span(span)

    def get_current_trace(self) -> Trace | None:
        """Return the active trace, or None."""
        return self._current_trace
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: tetherai-python
|
|
3
|
+
Version: 0.1.0a0
|
|
4
|
+
Summary: AI budget guardrails for LLM applications
|
|
5
|
+
Author-email: TetherAI Engineering <engineering@tetherai.com>
|
|
6
|
+
License: Apache-2.0
|
|
7
|
+
Project-URL: Homepage, https://github.com/tetherai/tetherai-python
|
|
8
|
+
Project-URL: Repository, https://github.com/tetherai/tetherai-python
|
|
9
|
+
Classifier: Development Status :: 3 - Alpha
|
|
10
|
+
Classifier: Intended Audience :: Developers
|
|
11
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
17
|
+
Requires-Python: >=3.10
|
|
18
|
+
Description-Content-Type: text/markdown
|
|
19
|
+
Requires-Dist: tiktoken>=0.7.0
|
|
20
|
+
Provides-Extra: crewai
|
|
21
|
+
Requires-Dist: crewai>=1.0.0; extra == "crewai"
|
|
22
|
+
Provides-Extra: litellm
|
|
23
|
+
Requires-Dist: litellm>=1.40.0; extra == "litellm"
|
|
24
|
+
Provides-Extra: dev
|
|
25
|
+
Requires-Dist: pytest>=8.0; extra == "dev"
|
|
26
|
+
Requires-Dist: pytest-cov>=5.0; extra == "dev"
|
|
27
|
+
Requires-Dist: pytest-asyncio>=0.23; extra == "dev"
|
|
28
|
+
Requires-Dist: pytest-mock>=3.12; extra == "dev"
|
|
29
|
+
Requires-Dist: ruff>=0.4; extra == "dev"
|
|
30
|
+
Requires-Dist: mypy>=1.10; extra == "dev"
|
|
31
|
+
Requires-Dist: crewai>=1.0.0; extra == "dev"
|
|
32
|
+
Requires-Dist: litellm>=1.40.0; extra == "dev"
|
|
33
|
+
|
|
34
|
+
[](https://github.com/tetherai/tetherai-python/actions/workflows/ci.yml)
|
|
35
|
+
[](https://pypi.org/project/tetherai/)
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
tetherai/__init__.py,sha256=OTKdW-yPFkSNB_1U5FbKJ7EUuH3NT-IlvSOpJrgSY1U,852
|
|
2
|
+
tetherai/_version.py,sha256=04Pi7kDrn0FVD-1c5mk_o8ryOaXTkkiLyp6TW-qxcaE,28
|
|
3
|
+
tetherai/budget.py,sha256=UUY481RFsiTPAwmaq2cYEx5Kdb7ifsTvYbTblmUvUfo,3552
|
|
4
|
+
tetherai/circuit_breaker.py,sha256=FLIfVhJ2WnGRRXaofqH5w6JUlZdR2Kl7mvi_a9JiuBo,3878
|
|
5
|
+
tetherai/config.py,sha256=lAo9hPj2fpU9V1AJw0PAOFPwJhloxxxD0UVpl814Rm4,3432
|
|
6
|
+
tetherai/exceptions.py,sha256=kx21E9QV7oQJ1-bDfxVfmrokrHg7LaHgu7dI-W5s-Do,2643
|
|
7
|
+
tetherai/exporter.py,sha256=tsyk1ol-n_w_hrUjRT3zdfwmcWFHu6DvmpR8f1CuewU,2137
|
|
8
|
+
tetherai/interceptor.py,sha256=HOYSLoduxPM-6kt-cg06RFCJi4uBllXHg5g2BFTIemE,8163
|
|
9
|
+
tetherai/pricing.py,sha256=Cdj6QDRvZOoYlO6aQORrVsQFHcXYto1lOxd7C6EMifg,3913
|
|
10
|
+
tetherai/token_counter.py,sha256=DkvL2qpHpKJzp0ODNs9xZ-bzzwEJhExmE8g64CeaHME,5326
|
|
11
|
+
tetherai/trace.py,sha256=blJ4liFo_mdzzDX8VeAu5_OlmgUxZQnZ-1VvFHoqUBE,3821
|
|
12
|
+
tetherai/crewai/__init__.py,sha256=LDjmhTX2DfeN4z6-YY2wODCT60NlK551xA4pB6VI9Rw,173
|
|
13
|
+
tetherai/crewai/integration.py,sha256=nDWk4DQhbGZmzxnEK7Ra_x8C_VbmvEjdn0Zagg032WY,1918
|
|
14
|
+
tetherai_python-0.1.0a0.dist-info/METADATA,sha256=i6Wr5UmNn89mMoJvq8KB174yev1EcccODgF9IHlAyQ4,1605
|
|
15
|
+
tetherai_python-0.1.0a0.dist-info/WHEEL,sha256=YCfwYGOYMi5Jhw2fU4yNgwErybb2IX5PEwBKV4ZbdBo,91
|
|
16
|
+
tetherai_python-0.1.0a0.dist-info/top_level.txt,sha256=kKKtbtmezQqlgMwH9lmXDCY_GTSz1Z36J_VdCqUcxJM,9
|
|
17
|
+
tetherai_python-0.1.0a0.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
tetherai
|