auditi 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- auditi/__init__.py +47 -0
- auditi/client.py +76 -0
- auditi/context.py +71 -0
- auditi/decorators.py +1441 -0
- auditi/evaluator.py +38 -0
- auditi/events.py +194 -0
- auditi/providers/__init__.py +41 -0
- auditi/providers/anthropic.py +141 -0
- auditi/providers/base.py +156 -0
- auditi/providers/google.py +182 -0
- auditi/providers/openai.py +147 -0
- auditi/providers/registry.py +166 -0
- auditi/transport.py +78 -0
- auditi/types/__init__.py +12 -0
- auditi/types/api_types.py +107 -0
- auditi-0.1.0.dist-info/METADATA +703 -0
- auditi-0.1.0.dist-info/RECORD +19 -0
- auditi-0.1.0.dist-info/WHEEL +4 -0
- auditi-0.1.0.dist-info/licenses/LICENSE +21 -0
auditi/evaluator.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Base evaluator class for implementing custom evaluation logic.
|
|
3
|
+
"""
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from .types import TraceInput, EvaluationResult
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BaseEvaluator(ABC):
    """
    Abstract base for trace evaluators.

    Derive from this class to plug your own evaluation logic for AI agents
    into the SDK. The evaluator is invoked once a trace has completed and
    before that trace is sent to the Auditi platform.

    Example (``calculate_quality`` stands in for your own scoring helper)::

        class QualityEvaluator(BaseEvaluator):
            def evaluate(self, trace: TraceInput) -> EvaluationResult:
                score = calculate_quality(trace.assistant_output)
                return EvaluationResult(
                    status="pass" if score > 0.7 else "fail",
                    score=score,
                    reason="Quality check",
                )
    """

    @abstractmethod
    def evaluate(self, trace: TraceInput) -> EvaluationResult:
        """
        Judge a single finished trace.

        Args:
            trace: The complete trace, including its input, output and spans.

        Returns:
            An EvaluationResult holding a pass/fail status, a score and
            optional details.
        """
        ...
|
auditi/events.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Standardized event types for streaming agent responses.
|
|
3
|
+
|
|
4
|
+
This module provides a clear contract for all event types used in the tracing system.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from enum import Enum
|
|
8
|
+
from typing import Any, Dict, Optional, List
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class EventType(str, Enum):
    """
    Closed set of event kinds an agent stream may emit.

    Members also subclass ``str``, so each one compares equal to its raw
    string value (for example ``EventType.TOKEN == "token"``).
    """

    # --- Content ---------------------------------------------------------
    # One incremental text token from the stream.
    TOKEN = "token"
    # The final, fully assembled response.
    COMPLETE = "complete"

    # --- Agent lifecycle phases -------------------------------------------
    PHASE_START = "phase_start"
    PHASE_END = "phase_end"

    # --- Tool execution -----------------------------------------------------
    TOOL_EXEC_START = "tool_exec_start"
    TOOL_EXEC_END = "tool_exec_end"

    # --- Metadata -----------------------------------------------------------
    # Turn-level metadata (usage, tool_calls, etc.).
    TURN_METADATA = "turn_metadata"
    # Stand-alone usage statistics.
    USAGE = "usage"

    # --- Errors -------------------------------------------------------------
    ERROR = "error"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Bookkeeping-only event kinds: these must NOT be accumulated into span outputs.
INTERNAL_EVENTS = frozenset(
    (
        EventType.PHASE_START,
        EventType.PHASE_END,
        EventType.TOOL_EXEC_START,
        EventType.TOOL_EXEC_END,
        EventType.TURN_METADATA,
        EventType.USAGE,
    )
)

# Event kinds whose content is accumulated.
# Note that EventType.ERROR belongs to neither set.
CONTENT_EVENTS = frozenset((EventType.TOKEN, EventType.COMPLETE))
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@dataclass
class StreamEvent:
    """
    Uniform representation of one event produced while an agent streams.

    Every event has a ``type``. The remaining fields are optional and are only
    filled for the event kinds that use them. ``to_dict`` and ``from_dict``
    convert to and from the plain-dict form kept for backward compatibility.
    """

    type: EventType
    content: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None
    usage: Optional[Dict[str, Any]] = None
    error: Optional[str] = None
    tool_calls: Optional[List[Any]] = field(default=None)
    # Fields that only particular event kinds populate:
    phase: Optional[str] = None  # PHASE_START / PHASE_END
    tool: Optional[str] = None  # TOOL_EXEC_START / TOOL_EXEC_END
    history: Optional[List[Any]] = field(default=None)  # COMPLETE
    messages: Optional[List[Any]] = field(default=None)  # COMPLETE

    def to_dict(self) -> Dict[str, Any]:
        """
        Return the legacy dict form of this event.

        The dict always has ``"type"`` (the enum's string value). It also
        holds every other field whose value is not None, in field order.
        """
        optional_names = (
            "content",
            "metadata",
            "usage",
            "error",
            "tool_calls",
            "phase",
            "tool",
            "history",
            "messages",
        )
        payload: Dict[str, Any] = {"type": self.type.value}
        for attr_name in optional_names:
            attr_value = getattr(self, attr_name)
            if attr_value is not None:
                payload[attr_name] = attr_value
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "StreamEvent":
        """
        Build a StreamEvent from its legacy dict form.

        Args:
            data: Mapping that normally carries a 'type' key. A missing type
                is read as "token".

        Returns:
            A new StreamEvent. A type that is neither an EventType nor a
            recognised string value falls back to EventType.TOKEN.
        """
        raw_type = data.get("type", "token")

        resolved = EventType.TOKEN
        if isinstance(raw_type, EventType):
            resolved = raw_type
        elif isinstance(raw_type, str):
            try:
                resolved = EventType(raw_type)
            except ValueError:
                pass  # unrecognised string: keep the TOKEN fallback

        return cls(
            type=resolved,
            content=data.get("content"),
            metadata=data.get("metadata"),
            usage=data.get("usage"),
            error=data.get("error"),
            tool_calls=data.get("tool_calls"),
            phase=data.get("phase"),
            tool=data.get("tool"),
            history=data.get("history"),
            messages=data.get("messages"),
        )

    def is_internal(self) -> bool:
        """True if this event's type is in INTERNAL_EVENTS (filtered from outputs)."""
        return self.type in INTERNAL_EVENTS

    def is_content(self) -> bool:
        """True if this event's type is in CONTENT_EVENTS (content to accumulate)."""
        return self.type in CONTENT_EVENTS
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
# Helper factory functions for common events
|
|
148
|
+
def token_event(content: str) -> Dict[str, Any]:
    """Build the dict form of a TOKEN event that carries ``content``."""
    return dict(type=EventType.TOKEN.value, content=content)
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def complete_event(
    content: str,
    history: Optional[List[Any]] = None,
    messages: Optional[List[Any]] = None,
) -> Dict[str, Any]:
    """
    Build the dict form of a COMPLETE event.

    ``history`` and ``messages`` are added, in that order, only when they
    are not None.
    """
    extras = {"history": history, "messages": messages}
    event: Dict[str, Any] = {"type": EventType.COMPLETE.value, "content": content}
    event.update((key, val) for key, val in extras.items() if val is not None)
    return event
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def turn_metadata_event(
    tool_calls: Optional[List[Any]] = None,
    usage: Optional[Dict[str, Any]] = None,
    perplexity: Optional[float] = None,
    confidence_level: Optional[str] = None,
    total_tokens: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Build the dict form of a TURN_METADATA event.

    Unlike the other factories, every key is always present (None when not
    supplied). A falsy ``tool_calls`` is replaced with a new empty list.
    """
    calls = tool_calls if tool_calls else []
    return dict(
        type=EventType.TURN_METADATA.value,
        tool_calls=calls,
        usage=usage,
        perplexity=perplexity,
        confidence_level=confidence_level,
        total_tokens=total_tokens,
    )
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def phase_event(phase: str, start: bool = True) -> Dict[str, Any]:
    """Build a PHASE_START (truthy ``start``) or PHASE_END event dict for ``phase``."""
    if start:
        kind = EventType.PHASE_START
    else:
        kind = EventType.PHASE_END
    return {"type": kind.value, "phase": phase}
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def tool_exec_event(tool: str, start: bool = True) -> Dict[str, Any]:
    """Build a TOOL_EXEC_START (truthy ``start``) or TOOL_EXEC_END event dict for ``tool``."""
    if start:
        kind = EventType.TOOL_EXEC_START
    else:
        kind = EventType.TOOL_EXEC_END
    return {"type": kind.value, "tool": tool}
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Provider abstraction layer for LLM usage extraction and cost calculation.
|
|
3
|
+
|
|
4
|
+
This module provides a clean, extensible way to handle different LLM providers
|
|
5
|
+
(OpenAI, Anthropic, Google, etc.) with automatic detection and provider-specific
|
|
6
|
+
pricing and usage extraction.
|
|
7
|
+
|
|
8
|
+
Usage:
|
|
9
|
+
>>> from auditi.providers import detect_provider
|
|
10
|
+
>>>
|
|
11
|
+
>>> # Auto-detect from model name
|
|
12
|
+
>>> provider = detect_provider(model="gpt-4o")
|
|
13
|
+
>>> input_tokens, output_tokens, total = provider.extract_usage(response.usage)
|
|
14
|
+
>>> cost = provider.calculate_cost("gpt-4o", input_tokens, output_tokens)
|
|
15
|
+
>>>
|
|
16
|
+
>>> # Or detect from response structure
|
|
17
|
+
>>> provider = detect_provider(response=api_response)
|
|
18
|
+
>>> model = provider.extract_model(api_response)
|
|
19
|
+
|
|
20
|
+
Adding a new provider:
|
|
21
|
+
1. Create a new file in auditi/providers/ (e.g., cohere.py)
|
|
22
|
+
2. Subclass BaseProvider and implement all abstract methods
|
|
23
|
+
3. Register it in registry.py's __init__ method
|
|
24
|
+
4. That's it! It will automatically be used for detection
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
from .base import BaseProvider
|
|
28
|
+
from .registry import get_registry, detect_provider, ProviderRegistry
|
|
29
|
+
from .openai import OpenAIProvider
|
|
30
|
+
from .anthropic import AnthropicProvider
|
|
31
|
+
from .google import GoogleProvider
|
|
32
|
+
|
|
33
|
+
# Public API re-exported by this package (what `from ... import *` exposes).
__all__ = [
    # Core abstraction and provider lookup/registry helpers
    "BaseProvider",
    "ProviderRegistry",
    "get_registry",
    "detect_provider",
    # Built-in provider implementations
    "OpenAIProvider",
    "AnthropicProvider",
    "GoogleProvider",
]
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Anthropic provider implementation for usage extraction and cost calculation.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from typing import Optional, Any, Dict, Tuple
|
|
6
|
+
from .base import BaseProvider
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _coerce_int(value: Any) -> Optional[int]:
|
|
10
|
+
"""Helper to safely convert values to int."""
|
|
11
|
+
if value is None:
|
|
12
|
+
return None
|
|
13
|
+
try:
|
|
14
|
+
return int(value)
|
|
15
|
+
except (TypeError, ValueError):
|
|
16
|
+
return None
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class AnthropicProvider(BaseProvider):
    """
    Provider implementation for Anthropic Claude models.

    Reads Anthropic-shaped usage payloads (``input_tokens``/``output_tokens``,
    no ``total_tokens``), prices Claude models, and detects ``claude-`` models.
    """

    @property
    def name(self) -> str:
        """Provider identifier: ``"anthropic"``."""
        return "anthropic"

    @property
    def model_pricing(self) -> Dict[str, Tuple[float, float]]:
        """
        Anthropic model pricing per 1M tokens (input, output) in USD.

        These are standard (uncached) rates. Prompt-caching and long-context
        pricing are not modelled. Re-check against Anthropic's published
        pricing when adding models. A fresh dict is built on every access.
        """
        return {
            # Claude 4.5 family (dated IDs and their aliases)
            "claude-opus-4-5-20251101": (5.00, 25.00),
            "claude-opus-4-5": (5.00, 25.00),
            "claude-sonnet-4-5-20250929": (3.00, 15.00),
            "claude-sonnet-4-5": (3.00, 15.00),
            "claude-haiku-4-5-20251001": (1.00, 5.00),
            "claude-haiku-4-5": (1.00, 5.00),
            # Claude 4 / 4.1 family
            "claude-opus-4-1-20250805": (15.00, 75.00),
            "claude-opus-4-1": (15.00, 75.00),
            "claude-opus-4-20250514": (15.00, 75.00),
            "claude-sonnet-4-20250514": (3.00, 15.00),
            # Claude 3.7
            "claude-3-7-sonnet-20250219": (3.00, 15.00),
            "claude-3-7-sonnet-latest": (3.00, 15.00),
            # Claude 3.5 family
            "claude-3-5-sonnet-20241022": (3.00, 15.00),
            "claude-3-5-sonnet-20240620": (3.00, 15.00),
            "claude-3-5-sonnet-latest": (3.00, 15.00),
            "claude-3-5-haiku-20241022": (0.80, 4.00),
            "claude-3-5-haiku-latest": (0.80, 4.00),
            # Claude 3 family
            "claude-3-opus-20240229": (15.00, 75.00),
            "claude-3-opus-latest": (15.00, 75.00),
            "claude-3-sonnet-20240229": (3.00, 15.00),
            "claude-3-haiku-20240307": (0.25, 1.25),
            # Legacy models
            "claude-2.1": (8.00, 24.00),
            "claude-2.0": (8.00, 24.00),
            "claude-instant-1.2": (0.80, 2.40),
        }

    def get_default_pricing(self) -> Tuple[float, float]:
        """Fallback rates for unknown Anthropic models (Sonnet-tier pricing)."""
        return (3.00, 15.00)

    def get_model_prefixes(self) -> list[str]:
        """Anthropic model IDs start with ``claude-``."""
        return ["claude-"]

    def extract_usage(self, usage: Any) -> Tuple[Optional[int], Optional[int], Optional[int]]:
        """
        Extract token counts from an Anthropic usage payload.

        Anthropic structure:
            {
                "usage": {
                    "input_tokens": 100,
                    "output_tokens": 50
                }
            }

        Anthropic does NOT return total_tokens, so it is computed here as
        input + output, treating a missing side as 0. Only ``input_tokens``
        and ``output_tokens`` are read; other usage fields are ignored.

        Args:
            usage: A usage dict or an object exposing the same attributes.

        Returns:
            (input_tokens, output_tokens, total_tokens). All three are None
            when ``usage`` is None or neither count is present.
        """
        if usage is None:
            return None, None, None

        if isinstance(usage, dict):
            input_tokens = _coerce_int(usage.get("input_tokens"))
            output_tokens = _coerce_int(usage.get("output_tokens"))
        else:
            input_tokens = _coerce_int(getattr(usage, "input_tokens", None))
            output_tokens = _coerce_int(getattr(usage, "output_tokens", None))

        total_tokens = None
        if input_tokens is not None or output_tokens is not None:
            total_tokens = (input_tokens or 0) + (output_tokens or 0)

        return input_tokens, output_tokens, total_tokens

    def extract_model(self, response: Any) -> Optional[str]:
        """
        Extract the model name from an Anthropic response.

        Args:
            response: A response dict or an object with a ``model`` attribute.

        Returns:
            The model name as a string, or None when it is absent or None.
        """
        if response is None:
            return None

        if isinstance(response, dict):
            model = response.get("model")
        else:
            model = getattr(response, "model", None)

        # Never turn a missing model into the literal string "None".
        return str(model) if model is not None else None

    @staticmethod
    def _looks_like_anthropic_usage(usage: Any) -> bool:
        """
        Return True when ``usage`` has ``input_tokens`` but no ``total_tokens``.

        OpenAI's Responses API also reports ``input_tokens`` but always
        includes ``total_tokens``. Anthropic never sends ``total_tokens``, so
        its absence tells the two apart. Ambiguous payloads are left to
        model-prefix matching.
        """
        if isinstance(usage, dict):
            return "input_tokens" in usage and "total_tokens" not in usage
        return hasattr(usage, "input_tokens") and not hasattr(usage, "total_tokens")

    def matches_response(self, response: Any) -> bool:
        """
        Detect Anthropic responses by structure.

        Anthropic responses typically have:
            - 'content' array with text blocks
            - 'usage' with 'input_tokens' and 'output_tokens' (NOT prompt_tokens,
              and no 'total_tokens')
            - 'stop_reason' field

        Falls back to model-prefix matching when the structure is inconclusive.
        """
        if response is None:
            return False

        if isinstance(response, dict):
            usage = response.get("usage")
            has_anthropic_usage = isinstance(usage, dict) and self._looks_like_anthropic_usage(usage)
            if has_anthropic_usage or "stop_reason" in response:
                return True
        elif hasattr(response, "usage"):
            if self._looks_like_anthropic_usage(response.usage):
                return True

        # Fallback to model prefix matching
        return super().matches_response(response)
|
auditi/providers/base.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Base provider interface for LLM usage extraction and cost calculation.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from typing import Optional, Any, Dict, Tuple
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class BaseProvider(ABC):
    """
    Abstract base class for LLM provider-specific logic.

    Each provider implements:
        1. Usage extraction from API responses
        2. Model pricing lookup (a pricing table plus a default rate)
        3. Model-name prefixes used for detection

    Cost calculation and model/response matching are implemented here on
    top of those hooks.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Provider name (e.g., 'openai', 'anthropic', 'google')."""
        pass

    @property
    @abstractmethod
    def model_pricing(self) -> Dict[str, Tuple[float, float]]:
        """
        Model pricing dictionary.

        Returns:
            Dict mapping model names to (input_price, output_price) per 1M tokens in USD

        Example:
            {
                "gpt-4o": (2.50, 10.00),
                "claude-3-5-sonnet-20241022": (3.00, 15.00)
            }
        """
        pass

    @abstractmethod
    def extract_usage(self, usage: Any) -> Tuple[Optional[int], Optional[int], Optional[int]]:
        """
        Extract token counts from provider-specific usage object.

        Args:
            usage: Raw usage object/dict from API response

        Returns:
            Tuple of (input_tokens, output_tokens, total_tokens)
            Returns (None, None, None) if extraction fails
        """
        pass

    @abstractmethod
    def extract_model(self, response: Any) -> Optional[str]:
        """
        Extract model name from API response.

        Args:
            response: Raw API response object

        Returns:
            Model name string, or None if not found
        """
        pass

    def calculate_cost(
        self, model: Optional[str], input_tokens: Optional[int], output_tokens: Optional[int]
    ) -> float:
        """
        Calculate cost based on model-specific pricing.

        The model is looked up in ``model_pricing`` by exact name first, then
        by its lowercased name. Lowercasing matches ``matches_model``, which is
        case-insensitive. If neither matches, ``get_default_pricing()`` is used.

        Args:
            model: Model name
            input_tokens: Number of input tokens (None counts as 0)
            output_tokens: Number of output tokens (None counts as 0)

        Returns:
            Total cost in USD, or 0.0 when both token counts are None.
        """
        if input_tokens is None and output_tokens is None:
            return 0.0

        input_tokens = input_tokens or 0
        output_tokens = output_tokens or 0

        # model_pricing may rebuild its dict on every access; read it once.
        pricing_table = self.model_pricing
        pricing = pricing_table.get(model)
        if pricing is None and isinstance(model, str):
            # Keep pricing consistent with case-insensitive detection so that
            # e.g. "Claude-3-Haiku-20240307" is not billed at the default rate.
            pricing = pricing_table.get(model.lower())
        if pricing is None:
            pricing = self.get_default_pricing()

        input_price, output_price = pricing

        # Prices are per 1M tokens.
        input_cost = (input_tokens / 1_000_000) * input_price
        output_cost = (output_tokens / 1_000_000) * output_price

        return input_cost + output_cost

    @abstractmethod
    def get_default_pricing(self) -> Tuple[float, float]:
        """
        Get default fallback pricing for unknown models.

        Returns:
            Tuple of (input_price, output_price) per 1M tokens in USD
        """
        pass

    def matches_model(self, model: Optional[str]) -> bool:
        """
        Check if a model name belongs to this provider.

        The default implementation does a case-insensitive prefix check
        against ``get_model_prefixes()``. Override for custom logic.

        Args:
            model: Model name string

        Returns:
            True if this provider handles the model. Returns False for empty
            or non-string input, which dict-based ``extract_model``
            implementations may pass through.
        """
        if not isinstance(model, str) or not model:
            return False

        model_lower = model.lower()
        return any(model_lower.startswith(prefix) for prefix in self.get_model_prefixes())

    @abstractmethod
    def get_model_prefixes(self) -> list[str]:
        """
        Get list of model name prefixes for this provider.

        Returns:
            List of lowercase prefixes (e.g., ['gpt-', 'o1-'])
        """
        pass

    def matches_response(self, response: Any) -> bool:
        """
        Check if a response object comes from this provider.

        Default implementation extracts the model and checks prefixes.
        Override for custom detection logic (e.g., checking response structure).

        Args:
            response: Raw API response object

        Returns:
            True if this provider handles the response
        """
        model = self.extract_model(response)
        return self.matches_model(model)
|