proxilion 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- proxilion/__init__.py +136 -0
- proxilion/audit/__init__.py +133 -0
- proxilion/audit/base_exporters.py +527 -0
- proxilion/audit/compliance/__init__.py +130 -0
- proxilion/audit/compliance/base.py +457 -0
- proxilion/audit/compliance/eu_ai_act.py +603 -0
- proxilion/audit/compliance/iso27001.py +544 -0
- proxilion/audit/compliance/soc2.py +491 -0
- proxilion/audit/events.py +493 -0
- proxilion/audit/explainability.py +1173 -0
- proxilion/audit/exporters/__init__.py +58 -0
- proxilion/audit/exporters/aws_s3.py +636 -0
- proxilion/audit/exporters/azure_storage.py +608 -0
- proxilion/audit/exporters/cloud_base.py +468 -0
- proxilion/audit/exporters/gcp_storage.py +570 -0
- proxilion/audit/exporters/multi_exporter.py +498 -0
- proxilion/audit/hash_chain.py +652 -0
- proxilion/audit/logger.py +543 -0
- proxilion/caching/__init__.py +49 -0
- proxilion/caching/tool_cache.py +633 -0
- proxilion/context/__init__.py +73 -0
- proxilion/context/context_window.py +556 -0
- proxilion/context/message_history.py +505 -0
- proxilion/context/session.py +735 -0
- proxilion/contrib/__init__.py +51 -0
- proxilion/contrib/anthropic.py +609 -0
- proxilion/contrib/google.py +1012 -0
- proxilion/contrib/langchain.py +641 -0
- proxilion/contrib/mcp.py +893 -0
- proxilion/contrib/openai.py +646 -0
- proxilion/core.py +3058 -0
- proxilion/decorators.py +966 -0
- proxilion/engines/__init__.py +287 -0
- proxilion/engines/base.py +266 -0
- proxilion/engines/casbin_engine.py +412 -0
- proxilion/engines/opa_engine.py +493 -0
- proxilion/engines/simple.py +437 -0
- proxilion/exceptions.py +887 -0
- proxilion/guards/__init__.py +54 -0
- proxilion/guards/input_guard.py +522 -0
- proxilion/guards/output_guard.py +634 -0
- proxilion/observability/__init__.py +198 -0
- proxilion/observability/cost_tracker.py +866 -0
- proxilion/observability/hooks.py +683 -0
- proxilion/observability/metrics.py +798 -0
- proxilion/observability/session_cost_tracker.py +1063 -0
- proxilion/policies/__init__.py +67 -0
- proxilion/policies/base.py +304 -0
- proxilion/policies/builtin.py +486 -0
- proxilion/policies/registry.py +376 -0
- proxilion/providers/__init__.py +201 -0
- proxilion/providers/adapter.py +468 -0
- proxilion/providers/anthropic_adapter.py +330 -0
- proxilion/providers/gemini_adapter.py +391 -0
- proxilion/providers/openai_adapter.py +294 -0
- proxilion/py.typed +0 -0
- proxilion/resilience/__init__.py +81 -0
- proxilion/resilience/degradation.py +615 -0
- proxilion/resilience/fallback.py +555 -0
- proxilion/resilience/retry.py +554 -0
- proxilion/scheduling/__init__.py +57 -0
- proxilion/scheduling/priority_queue.py +419 -0
- proxilion/scheduling/scheduler.py +459 -0
- proxilion/security/__init__.py +244 -0
- proxilion/security/agent_trust.py +968 -0
- proxilion/security/behavioral_drift.py +794 -0
- proxilion/security/cascade_protection.py +869 -0
- proxilion/security/circuit_breaker.py +428 -0
- proxilion/security/cost_limiter.py +690 -0
- proxilion/security/idor_protection.py +460 -0
- proxilion/security/intent_capsule.py +849 -0
- proxilion/security/intent_validator.py +495 -0
- proxilion/security/memory_integrity.py +767 -0
- proxilion/security/rate_limiter.py +509 -0
- proxilion/security/scope_enforcer.py +680 -0
- proxilion/security/sequence_validator.py +636 -0
- proxilion/security/trust_boundaries.py +784 -0
- proxilion/streaming/__init__.py +70 -0
- proxilion/streaming/detector.py +761 -0
- proxilion/streaming/transformer.py +674 -0
- proxilion/timeouts/__init__.py +55 -0
- proxilion/timeouts/decorators.py +477 -0
- proxilion/timeouts/manager.py +545 -0
- proxilion/tools/__init__.py +69 -0
- proxilion/tools/decorators.py +493 -0
- proxilion/tools/registry.py +732 -0
- proxilion/types.py +339 -0
- proxilion/validation/__init__.py +93 -0
- proxilion/validation/pydantic_schema.py +351 -0
- proxilion/validation/schema.py +651 -0
- proxilion-0.0.1.dist-info/METADATA +872 -0
- proxilion-0.0.1.dist-info/RECORD +94 -0
- proxilion-0.0.1.dist-info/WHEEL +4 -0
- proxilion-0.0.1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1063 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Enhanced Session-Based Cost Tracking for Proxilion.
|
|
3
|
+
|
|
4
|
+
Extends the base CostTracker with per-user and per-agent session
|
|
5
|
+
tracking, real-time budget alerts, and detailed cost attribution.
|
|
6
|
+
|
|
7
|
+
Features:
|
|
8
|
+
- Per-session cost tracking with automatic session management
|
|
9
|
+
- Per-agent cost attribution in multi-agent systems
|
|
10
|
+
- Real-time budget alerts and callbacks
|
|
11
|
+
- Cost breakdown by tool, model, and time period
|
|
12
|
+
- Session cost limits and automatic termination
|
|
13
|
+
- Cost forecasting based on usage patterns
|
|
14
|
+
- Export formats for billing integration
|
|
15
|
+
|
|
16
|
+
Example:
|
|
17
|
+
>>> from proxilion.observability.session_cost_tracker import (
|
|
18
|
+
... SessionCostTracker,
|
|
19
|
+
... Session,
|
|
20
|
+
... CostAlert,
|
|
21
|
+
... )
|
|
22
|
+
>>>
|
|
23
|
+
>>> # Create tracker with session support
|
|
24
|
+
>>> tracker = SessionCostTracker()
|
|
25
|
+
>>>
|
|
26
|
+
>>> # Start a user session
|
|
27
|
+
>>> session = tracker.start_session(
|
|
28
|
+
... user_id="user_123",
|
|
29
|
+
... agent_id="assistant_main",
|
|
30
|
+
... budget_limit=10.00,
|
|
31
|
+
... )
|
|
32
|
+
>>>
|
|
33
|
+
>>> # Record usage
|
|
34
|
+
>>> tracker.record_session_usage(
|
|
35
|
+
... session_id=session.session_id,
|
|
36
|
+
... model="claude-sonnet-4-20250514",
|
|
37
|
+
... input_tokens=1000,
|
|
38
|
+
... output_tokens=500,
|
|
39
|
+
... tool_name="search",
|
|
40
|
+
... )
|
|
41
|
+
>>>
|
|
42
|
+
>>> # Check session costs
|
|
43
|
+
>>> print(f"Session cost: ${session.total_cost:.4f}")
|
|
44
|
+
>>> print(f"Budget remaining: ${session.budget_remaining:.4f}")
|
|
45
|
+
>>>
|
|
46
|
+
>>> # End session
|
|
47
|
+
>>> summary = tracker.end_session(session.session_id)
|
|
48
|
+
>>> print(f"Final cost: ${summary.total_cost:.4f}")
|
|
49
|
+
"""
|
|
50
|
+
|
|
51
|
+
from __future__ import annotations
|
|
52
|
+
|
|
53
|
+
import hashlib
|
|
54
|
+
import json
|
|
55
|
+
import logging
|
|
56
|
+
import threading
|
|
57
|
+
import uuid
|
|
58
|
+
from collections import defaultdict
|
|
59
|
+
from dataclasses import asdict, dataclass, field
|
|
60
|
+
from datetime import datetime, timedelta, timezone
|
|
61
|
+
from enum import Enum
|
|
62
|
+
from typing import Any, Callable
|
|
63
|
+
|
|
64
|
+
from proxilion.observability.cost_tracker import (
|
|
65
|
+
BudgetPolicy,
|
|
66
|
+
CostSummary,
|
|
67
|
+
CostTracker,
|
|
68
|
+
ModelPricing,
|
|
69
|
+
UsageRecord,
|
|
70
|
+
DEFAULT_PRICING,
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
# Module-level logger, named after this module (proxilion.observability.session_cost_tracker).
logger = logging.getLogger(__name__)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class SessionState(str, Enum):
    """
    Lifecycle states of a tracked session.

    Only ``ACTIVE`` sessions accept new usage; in any other state the
    tracker refuses to record usage for the session.
    """

    # Accepting usage records.
    ACTIVE = "active"
    # Suspended via pause_session.
    PAUSED = "paused"
    # Stopped automatically once spending reached the budget limit.
    BUDGET_EXCEEDED = "budget_exceeded"
    # Presumably set when a session is ended explicitly (setter not shown here).
    TERMINATED = "terminated"
    # Stopped automatically after the inactivity timeout elapsed.
    EXPIRED = "expired"
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class AlertSeverity(str, Enum):
    """How urgent a cost alert is, from informational to critical."""

    INFO = "info"  # e.g. a session timing out
    WARNING = "warning"  # e.g. nearing the budget or a high spend rate
    CRITICAL = "critical"  # e.g. the budget has been exceeded
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class AlertType(str, Enum):
    """Categories of cost alert the tracker can raise."""

    # Spending is approaching the budget limit.
    BUDGET_WARNING = "budget_warning"
    # Spending has hit the budget limit.
    BUDGET_EXCEEDED = "budget_exceeded"
    # Spend rate is unusually high.
    RATE_WARNING = "rate_warning"
    # Unusual spending pattern.
    ANOMALY = "anomaly"
    # The session timed out.
    SESSION_EXPIRED = "session_expired"
    # Spending is projected to exceed the budget.
    FORECAST_WARNING = "forecast_warning"
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@dataclass
class CostAlert:
    """
    A single alert raised when a cost condition is met.

    Attributes:
        alert_id: Unique identifier for this alert.
        alert_type: Which condition fired.
        severity: How urgent the alert is.
        message: Human-readable description.
        current_cost: Cost observed at the moment the alert fired.
        threshold: Threshold that was crossed, if one applies.
        session_id: Session the alert belongs to, if any.
        user_id: User the alert belongs to, if known.
        agent_id: Agent the alert belongs to, if any.
        timestamp: Creation time; timezone-aware UTC by default.
        metadata: Free-form extra data.
    """

    alert_id: str
    alert_type: AlertType
    severity: AlertSeverity
    message: str
    current_cost: float
    threshold: float | None = None
    session_id: str | None = None
    user_id: str | None = None
    agent_id: str | None = None
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the alert; enums become their values, the timestamp an ISO-8601 string."""
        serialized = dict(
            alert_id=self.alert_id,
            alert_type=self.alert_type.value,
            severity=self.severity.value,
            message=self.message,
            current_cost=self.current_cost,
            threshold=self.threshold,
            session_id=self.session_id,
            user_id=self.user_id,
            agent_id=self.agent_id,
            timestamp=self.timestamp.isoformat(),
            metadata=self.metadata,
        )
        return serialized
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
@dataclass
class AgentCostProfile:
    """
    Costs attributed to one agent inside a session.

    In a multi-agent setup each participating agent gets its own profile;
    ``parent_agent_id`` can record the delegating agent.

    Attributes:
        agent_id: Unique agent identifier.
        parent_agent_id: Parent agent (if delegated).
        total_cost: Total cost attributed to this agent.
        input_tokens: Total input tokens.
        output_tokens: Total output tokens.
        tool_calls: Number of tool calls.
        by_tool: Cost breakdown by tool.
        by_model: Cost breakdown by model.
        first_activity: First activity timestamp.
        last_activity: Last activity timestamp.
    """

    agent_id: str
    parent_agent_id: str | None = None
    total_cost: float = 0.0
    input_tokens: int = 0
    output_tokens: int = 0
    tool_calls: int = 0
    by_tool: dict[str, float] = field(default_factory=dict)
    by_model: dict[str, float] = field(default_factory=dict)
    first_activity: datetime | None = None
    last_activity: datetime | None = None

    def to_dict(self) -> dict[str, Any]:
        """Serialize the profile; timestamps become ISO-8601 strings or None."""

        def _iso(moment: datetime | None) -> str | None:
            return moment.isoformat() if moment else None

        return dict(
            agent_id=self.agent_id,
            parent_agent_id=self.parent_agent_id,
            total_cost=self.total_cost,
            input_tokens=self.input_tokens,
            output_tokens=self.output_tokens,
            tool_calls=self.tool_calls,
            by_tool=self.by_tool,
            by_model=self.by_model,
            first_activity=_iso(self.first_activity),
            last_activity=_iso(self.last_activity),
        )
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
@dataclass
class Session:
    """
    A user/agent session with cost tracking.

    Attributes:
        session_id: Unique session identifier.
        user_id: User who owns the session.
        agent_id: Primary agent for the session.
        state: Current session state.
        budget_limit: Maximum cost allowed for this session (None = no limit).
        total_cost: Total cost incurred so far.
        input_tokens: Total input tokens used.
        output_tokens: Total output tokens used.
        record_count: Number of usage records.
        agents: Per-agent cost profiles, keyed by agent ID.
        by_tool: Cost breakdown by tool.
        by_model: Cost breakdown by model.
        start_time: When the session started (UTC by default).
        end_time: When the session ended (if ended).
        last_activity: Last activity timestamp.
        timeout_minutes: Inactivity timeout in minutes.
        metadata: Additional session metadata.
        alerts: Alerts triggered during session.
    """

    session_id: str
    user_id: str
    state: SessionState = SessionState.ACTIVE
    agent_id: str | None = None
    budget_limit: float | None = None
    total_cost: float = 0.0
    input_tokens: int = 0
    output_tokens: int = 0
    record_count: int = 0
    agents: dict[str, AgentCostProfile] = field(default_factory=dict)
    by_tool: dict[str, float] = field(default_factory=dict)
    by_model: dict[str, float] = field(default_factory=dict)
    start_time: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    end_time: datetime | None = None
    last_activity: datetime | None = None
    timeout_minutes: int = 60
    metadata: dict[str, Any] = field(default_factory=dict)
    alerts: list[CostAlert] = field(default_factory=list)

    @property
    def is_active(self) -> bool:
        """Whether the session is in the ACTIVE state."""
        return self.state == SessionState.ACTIVE

    @property
    def budget_remaining(self) -> float | None:
        """Remaining budget (never negative), or None if no limit."""
        if self.budget_limit is None:
            return None
        return max(0.0, self.budget_limit - self.total_cost)

    @property
    def budget_percentage(self) -> float | None:
        """
        Fraction of the budget used (0.0 to 1.0), or None if there is no limit.

        A zero (or negative) limit allows no spending at all, so it is
        reported as fully used (1.0). This matches ``budget_remaining``,
        which is 0.0 for such a limit. Returning None here would make a zero
        budget indistinguishable from "no limit".
        """
        if self.budget_limit is None:
            return None
        if self.budget_limit <= 0:
            return 1.0
        return min(1.0, self.total_cost / self.budget_limit)

    @property
    def duration_seconds(self) -> float:
        """Seconds from start_time to end_time, or to now if not ended."""
        end = self.end_time or datetime.now(timezone.utc)
        return (end - self.start_time).total_seconds()

    @property
    def is_expired(self) -> bool:
        """
        Whether more than ``timeout_minutes`` have passed since last activity.

        Always False when ``last_activity`` is unset. This only looks at the
        clock, not at ``state``.
        """
        if self.last_activity is None:
            return False
        inactive = datetime.now(timezone.utc) - self.last_activity
        return inactive > timedelta(minutes=self.timeout_minutes)

    def to_dict(self) -> dict[str, Any]:
        """Convert to a JSON-friendly dictionary, including derived properties."""
        return {
            "session_id": self.session_id,
            "user_id": self.user_id,
            "agent_id": self.agent_id,
            "state": self.state.value,
            "budget_limit": self.budget_limit,
            "budget_remaining": self.budget_remaining,
            "budget_percentage": self.budget_percentage,
            "total_cost": self.total_cost,
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "record_count": self.record_count,
            "agents": {k: v.to_dict() for k, v in self.agents.items()},
            "by_tool": self.by_tool,
            "by_model": self.by_model,
            "start_time": self.start_time.isoformat(),
            "end_time": self.end_time.isoformat() if self.end_time else None,
            "last_activity": self.last_activity.isoformat() if self.last_activity else None,
            "duration_seconds": self.duration_seconds,
            "timeout_minutes": self.timeout_minutes,
            "metadata": self.metadata,
            "alerts": [a.to_dict() for a in self.alerts],
        }

    def to_json(self) -> str:
        """Convert to an indented JSON string (metadata must be JSON-serializable)."""
        return json.dumps(self.to_dict(), indent=2)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
@dataclass
class SessionSummary:
    """Final cost report for a session once it has been completed."""

    session_id: str  # Session identifier.
    user_id: str  # User who owned the session.
    total_cost: float  # Total cost incurred.
    total_tokens: int  # Total tokens used.
    duration_seconds: float  # Session duration.
    tool_breakdown: dict[str, float]  # Cost by tool.
    model_breakdown: dict[str, float]  # Cost by model.
    agent_breakdown: dict[str, float]  # Cost by agent.
    peak_spend_rate: float = 0.0  # Highest spend rate ($/minute).
    alerts_triggered: int = 0  # Number of alerts triggered.
    end_reason: str = "user_ended"  # Why the session ended.

    def to_dict(self) -> dict[str, Any]:
        """
        Return the summary as a plain dictionary.

        Built with ``dataclasses.asdict``, so the breakdown dicts in the
        result are copies, not the summary's own objects.
        """
        as_mapping = asdict(self)
        return as_mapping
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
# Type alias for alert callbacks
|
|
347
|
+
# Listener signature: called with each triggered CostAlert; any return value is
# ignored, and exceptions raised by the listener are logged rather than propagated.
AlertCallback = Callable[[CostAlert], None]
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
class SessionCostTracker:
|
|
351
|
+
"""
|
|
352
|
+
Enhanced cost tracker with session management.
|
|
353
|
+
|
|
354
|
+
Provides per-user and per-agent cost tracking with automatic
|
|
355
|
+
session management, budget enforcement, and alerting.
|
|
356
|
+
|
|
357
|
+
Example:
|
|
358
|
+
>>> tracker = SessionCostTracker(
|
|
359
|
+
... budget_policy=BudgetPolicy(max_cost_per_user_per_day=100.0)
|
|
360
|
+
... )
|
|
361
|
+
>>>
|
|
362
|
+
>>> # Start session with budget
|
|
363
|
+
>>> session = tracker.start_session(
|
|
364
|
+
... user_id="alice",
|
|
365
|
+
... budget_limit=10.0,
|
|
366
|
+
... )
|
|
367
|
+
>>>
|
|
368
|
+
>>> # Record usage
|
|
369
|
+
>>> record = tracker.record_session_usage(
|
|
370
|
+
... session_id=session.session_id,
|
|
371
|
+
... model="claude-sonnet-4-20250514",
|
|
372
|
+
... input_tokens=1000,
|
|
373
|
+
... output_tokens=500,
|
|
374
|
+
... )
|
|
375
|
+
>>>
|
|
376
|
+
>>> # Check status
|
|
377
|
+
>>> print(f"Session cost: ${session.total_cost:.4f}")
|
|
378
|
+
>>> print(f"Remaining: ${session.budget_remaining:.4f}")
|
|
379
|
+
"""
|
|
380
|
+
|
|
381
|
+
def __init__(
|
|
382
|
+
self,
|
|
383
|
+
base_tracker: CostTracker | None = None,
|
|
384
|
+
budget_policy: BudgetPolicy | None = None,
|
|
385
|
+
default_session_timeout: int = 60,
|
|
386
|
+
budget_warning_threshold: float = 0.8,
|
|
387
|
+
rate_warning_threshold: float = 1.0, # $/minute
|
|
388
|
+
max_sessions: int = 10000,
|
|
389
|
+
enable_forecasting: bool = True,
|
|
390
|
+
) -> None:
|
|
391
|
+
"""
|
|
392
|
+
Initialize the session cost tracker.
|
|
393
|
+
|
|
394
|
+
Args:
|
|
395
|
+
base_tracker: Optional base CostTracker to wrap.
|
|
396
|
+
budget_policy: Budget policy for global limits.
|
|
397
|
+
default_session_timeout: Default session timeout in minutes.
|
|
398
|
+
budget_warning_threshold: Percentage at which to warn (0.0 to 1.0).
|
|
399
|
+
rate_warning_threshold: Spend rate ($/min) that triggers warning.
|
|
400
|
+
max_sessions: Maximum active sessions to track.
|
|
401
|
+
enable_forecasting: Whether to enable cost forecasting.
|
|
402
|
+
"""
|
|
403
|
+
self._lock = threading.RLock()
|
|
404
|
+
|
|
405
|
+
# Use provided tracker or create new one
|
|
406
|
+
self._base_tracker = base_tracker or CostTracker(budget_policy=budget_policy)
|
|
407
|
+
|
|
408
|
+
self._default_timeout = default_session_timeout
|
|
409
|
+
self._budget_warning_threshold = budget_warning_threshold
|
|
410
|
+
self._rate_warning_threshold = rate_warning_threshold
|
|
411
|
+
self._max_sessions = max_sessions
|
|
412
|
+
self._enable_forecasting = enable_forecasting
|
|
413
|
+
|
|
414
|
+
# Session storage
|
|
415
|
+
self._sessions: dict[str, Session] = {}
|
|
416
|
+
self._user_sessions: dict[str, list[str]] = defaultdict(list)
|
|
417
|
+
self._session_records: dict[str, list[UsageRecord]] = defaultdict(list)
|
|
418
|
+
|
|
419
|
+
# Alert callbacks
|
|
420
|
+
self._alert_callbacks: list[AlertCallback] = []
|
|
421
|
+
|
|
422
|
+
# Metrics
|
|
423
|
+
self._total_alerts = 0
|
|
424
|
+
self._sessions_created = 0
|
|
425
|
+
self._sessions_terminated = 0
|
|
426
|
+
|
|
427
|
+
def add_alert_callback(self, callback: AlertCallback) -> None:
    """
    Subscribe ``callback`` to every alert this tracker raises.

    The same callable may be registered more than once; each registration
    is called separately.

    Args:
        callback: Invoked with the CostAlert whenever one is triggered.
    """
    registered = self._alert_callbacks
    registered.append(callback)
|
|
435
|
+
|
|
436
|
+
def remove_alert_callback(self, callback: AlertCallback) -> None:
    """
    Unsubscribe a previously registered alert callback.

    Only the first matching registration is removed; unknown callbacks are
    ignored.

    Args:
        callback: The callable previously passed to ``add_alert_callback``.
    """
    # EAFP: a separate membership test followed by remove() can race with a
    # concurrent removal of the same callback and then raise ValueError.
    try:
        self._alert_callbacks.remove(callback)
    except ValueError:
        pass
|
|
440
|
+
|
|
441
|
+
def start_session(
    self,
    user_id: str,
    agent_id: str | None = None,
    budget_limit: float | None = None,
    timeout_minutes: int | None = None,
    metadata: dict[str, Any] | None = None,
) -> Session:
    """
    Start a new cost tracking session.

    Args:
        user_id: User who owns the session.
        agent_id: Primary agent for the session; an empty cost profile is
            registered for it immediately.
        budget_limit: Maximum cost for this session (None = unlimited).
        timeout_minutes: Inactivity timeout in minutes (the tracker default
            is used if not provided or 0).
        metadata: Additional metadata to store. It is copied, so later
            changes to the caller's dict do not affect the session.

    Returns:
        The created Session, already registered with the tracker.
    """
    session_id = f"sess_{uuid.uuid4().hex[:12]}"

    session = Session(
        session_id=session_id,
        user_id=user_id,
        agent_id=agent_id,
        budget_limit=budget_limit,
        timeout_minutes=timeout_minutes or self._default_timeout,
        # Copy instead of aliasing the caller's dict.
        metadata=dict(metadata) if metadata else {},
        last_activity=datetime.now(timezone.utc),
    )

    # Register primary agent if provided
    if agent_id:
        session.agents[agent_id] = AgentCostProfile(agent_id=agent_id)

    with self._lock:
        # Clean up expired sessions if at capacity.
        # NOTE(review): the new session is stored even if cleanup frees
        # nothing, so max_sessions is a cleanup trigger, not a hard cap.
        if len(self._sessions) >= self._max_sessions:
            self._cleanup_expired_sessions()

        self._sessions[session_id] = session
        self._user_sessions[user_id].append(session_id)
        self._sessions_created += 1

    # Compare with None: a budget of 0 is a limit, not "unlimited".
    budget_str = f"${budget_limit:.2f}" if budget_limit is not None else "unlimited"
    logger.info(
        f"Started session {session_id} for user {user_id} "
        f"(budget: {budget_str})"
    )

    return session
|
|
494
|
+
|
|
495
|
+
def get_session(self, session_id: str) -> Session | None:
    """
    Fetch a session by ID, applying the inactivity timeout on the way.

    A session that is still active but whose timeout has elapsed is expired
    (via ``_expire_session``) before it is returned.

    Args:
        session_id: ID returned by ``start_session``.

    Returns:
        The session, or None if the ID is unknown.
    """
    with self._lock:
        found = self._sessions.get(session_id)
        if not found:
            return found
        if found.is_expired and found.is_active:
            self._expire_session(found)
        return found
|
|
505
|
+
|
|
506
|
+
def get_user_sessions(
    self,
    user_id: str,
    active_only: bool = True,
) -> list[Session]:
    """
    Get all sessions for a user, in the order they were started.

    Like ``get_session``, this applies the inactivity timeout. An active
    session whose timeout has elapsed is expired first, so
    ``active_only=True`` never returns a timed-out session.

    Args:
        user_id: User to get sessions for.
        active_only: Whether to return only active sessions.

    Returns:
        List of sessions (empty for an unknown user).
    """
    with self._lock:
        sessions: list[Session] = []

        for sid in self._user_sessions.get(user_id, []):
            session = self._sessions.get(sid)
            if session is None:
                # Skip IDs that no longer have a stored session.
                continue
            if session.is_active and session.is_expired:
                self._expire_session(session)
            if active_only and not session.is_active:
                continue
            sessions.append(session)

        return sessions
|
|
533
|
+
|
|
534
|
+
def record_session_usage(
    self,
    session_id: str,
    model: str,
    input_tokens: int,
    output_tokens: int,
    cache_read_tokens: int = 0,
    cache_write_tokens: int = 0,
    tool_name: str | None = None,
    agent_id: str | None = None,
    request_id: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> UsageRecord | None:
    """
    Record usage for a session.

    The usage is priced and stored by the wrapped base tracker. Its cost is
    then added to the session totals, the per-tool and per-model breakdowns,
    and the profile of the attributed agent. Finally budget and spend-rate
    alerts are evaluated, which can move the session out of ACTIVE.

    Args:
        session_id: Session to record usage for.
        model: Model used.
        input_tokens: Number of input tokens.
        output_tokens: Number of output tokens.
        cache_read_tokens: Cached tokens read.
        cache_write_tokens: Tokens written to cache.
        tool_name: Tool that triggered the usage.
        agent_id: Agent that incurred the usage; defaults to the session's
            primary agent.
        request_id: Request identifier.
        metadata: Additional metadata; ``session_id`` and ``agent_id`` keys
            are overwritten with the session's values.

    Returns:
        UsageRecord if successful, None if the session is unknown, inactive,
        or has just expired.
    """
    with self._lock:
        session = self._sessions.get(session_id)

        if session is None:
            logger.warning(f"Session {session_id} not found")
            return None

        if not session.is_active:
            logger.warning(f"Session {session_id} is not active (state: {session.state})")
            return None

        # Apply the inactivity timeout before accepting new usage.
        if session.is_expired:
            self._expire_session(session)
            return None

        # The agent the cost is attributed to: explicit one, else the primary agent.
        effective_agent = agent_id or session.agent_id

        # Record in base tracker
        record = self._base_tracker.record_usage(
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cache_read_tokens=cache_read_tokens,
            cache_write_tokens=cache_write_tokens,
            tool_name=tool_name,
            user_id=session.user_id,
            request_id=request_id,
            metadata={
                **(metadata or {}),
                "session_id": session_id,
                # Same agent the session breakdown uses, so the base record
                # and the session attribution agree.
                "agent_id": effective_agent,
            },
        )
        cost = record.cost_usd

        # Update session
        session.total_cost += cost
        session.input_tokens += input_tokens
        session.output_tokens += output_tokens
        session.record_count += 1
        session.last_activity = record.timestamp

        # Update tool breakdown
        if tool_name:
            session.by_tool[tool_name] = session.by_tool.get(tool_name, 0.0) + cost

        # Update model breakdown
        session.by_model[model] = session.by_model.get(model, 0.0) + cost

        # Update agent breakdown
        if effective_agent:
            if effective_agent not in session.agents:
                session.agents[effective_agent] = AgentCostProfile(agent_id=effective_agent)

            agent_profile = session.agents[effective_agent]
            agent_profile.total_cost += cost
            agent_profile.input_tokens += input_tokens
            agent_profile.output_tokens += output_tokens
            # Counts every usage record attributed to the agent.
            agent_profile.tool_calls += 1

            if tool_name:
                agent_profile.by_tool[tool_name] = agent_profile.by_tool.get(tool_name, 0.0) + cost
            agent_profile.by_model[model] = agent_profile.by_model.get(model, 0.0) + cost

            if agent_profile.first_activity is None:
                agent_profile.first_activity = record.timestamp
            agent_profile.last_activity = record.timestamp

        # Store record
        self._session_records[session_id].append(record)

        # Check budget and alerts
        self._check_budget_alerts(session, record)

        return record
|
|
639
|
+
|
|
640
|
+
def _check_budget_alerts(self, session: Session, record: UsageRecord) -> None:
    """
    Raise budget and spend-rate alerts for ``session`` after new usage.

    Budget checks run only when the session has a ``budget_limit``:
      * a one-time BUDGET_WARNING once the used fraction reaches the warning
        threshold while still below the limit;
      * a BUDGET_EXCEEDED alert once spending reaches the limit. The session
        is then moved to BUDGET_EXCEEDED and given an end time.

    The spend-rate check runs for every session, with or without a budget.
    It raises a one-time RATE_WARNING when the average cost per minute since
    the session started exceeds the configured threshold, and is skipped
    until the session has lasted more than a minute.

    Args:
        session: Session whose totals were just updated.
        record: The usage record that triggered the check (not used here).
    """
    limit = session.budget_limit
    if limit is not None:
        # Computed locally so a zero (or negative) limit, which allows no
        # spending, counts as fully used instead of falling back to 0%.
        percentage = min(1.0, session.total_cost / limit) if limit > 0 else 1.0

        # Budget warning (once per session)
        if (
            percentage >= self._budget_warning_threshold
            and percentage < 1.0
            and not any(a.alert_type == AlertType.BUDGET_WARNING for a in session.alerts)
        ):
            alert = CostAlert(
                alert_id=f"alert_{uuid.uuid4().hex[:8]}",
                alert_type=AlertType.BUDGET_WARNING,
                severity=AlertSeverity.WARNING,
                session_id=session.session_id,
                user_id=session.user_id,
                agent_id=session.agent_id,
                message=(
                    f"Session approaching budget limit: "
                    f"${session.total_cost:.2f}/${session.budget_limit:.2f} "
                    f"({percentage:.0%})"
                ),
                current_cost=session.total_cost,
                threshold=session.budget_limit * self._budget_warning_threshold,
            )
            self._trigger_alert(session, alert)

        # Budget exceeded
        if percentage >= 1.0:
            alert = CostAlert(
                alert_id=f"alert_{uuid.uuid4().hex[:8]}",
                alert_type=AlertType.BUDGET_EXCEEDED,
                severity=AlertSeverity.CRITICAL,
                session_id=session.session_id,
                user_id=session.user_id,
                agent_id=session.agent_id,
                message=(
                    f"Session budget exceeded: "
                    f"${session.total_cost:.2f}/${session.budget_limit:.2f}"
                ),
                current_cost=session.total_cost,
                threshold=session.budget_limit,
            )
            self._trigger_alert(session, alert)

            # Terminate session
            session.state = SessionState.BUDGET_EXCEEDED
            session.end_time = datetime.now(timezone.utc)
            logger.warning(f"Session {session.session_id} terminated: budget exceeded")

    # Spend rate does not depend on the budget, so check it for every session.
    elapsed = session.duration_seconds
    if elapsed > 60:  # At least 1 minute
        rate_per_minute = session.total_cost / (elapsed / 60)

        if (
            rate_per_minute > self._rate_warning_threshold
            and not any(a.alert_type == AlertType.RATE_WARNING for a in session.alerts)
        ):
            alert = CostAlert(
                alert_id=f"alert_{uuid.uuid4().hex[:8]}",
                alert_type=AlertType.RATE_WARNING,
                severity=AlertSeverity.WARNING,
                session_id=session.session_id,
                user_id=session.user_id,
                agent_id=session.agent_id,
                message=(
                    f"High spend rate detected: ${rate_per_minute:.2f}/minute "
                    f"(threshold: ${self._rate_warning_threshold:.2f}/minute)"
                ),
                current_cost=session.total_cost,
                threshold=self._rate_warning_threshold,
                metadata={"rate_per_minute": rate_per_minute},
            )
            self._trigger_alert(session, alert)
|
|
717
|
+
|
|
718
|
+
def _trigger_alert(self, session: Session, alert: CostAlert) -> None:
    """
    Record an alert on its session and dispatch it to registered callbacks.

    The alert is appended to ``session.alerts``, the tracker-wide alert
    counter is incremented and a warning is logged; then every callback in
    ``self._alert_callbacks`` is invoked with the alert.

    A failing callback never prevents the remaining callbacks from running.
    Its exception is logged together with the full traceback.

    Args:
        session: Session the alert belongs to.
        alert: Alert to record and dispatch.
    """
    session.alerts.append(alert)
    self._total_alerts += 1

    logger.warning(f"Cost alert: {alert.message}")

    # Iterate over a snapshot so a callback that adds/removes callbacks
    # cannot cause entries to be skipped during this dispatch round.
    for callback in tuple(self._alert_callbacks):
        try:
            callback(alert)
        except Exception as e:
            # Callbacks are external code: isolate their failures, but keep
            # the traceback (logger.error dropped it) so they can be debugged.
            logger.exception("Alert callback error: %s", e)
|
|
731
|
+
|
|
732
|
+
def _expire_session(self, session: Session) -> None:
    """
    Move a session to the ``EXPIRED`` state and raise an informational alert.

    Sets ``state`` to ``EXPIRED`` and ``end_time`` to the current UTC time,
    dispatches an ``INFO``-severity ``SESSION_EXPIRED`` alert through
    ``_trigger_alert``, and finally logs the expiry.

    Args:
        session: Session to expire.
    """
    expired_at = datetime.now(timezone.utc)
    session.state = SessionState.EXPIRED
    session.end_time = expired_at

    sid = session.session_id
    notice = CostAlert(
        alert_id=f"alert_{uuid.uuid4().hex[:8]}",
        alert_type=AlertType.SESSION_EXPIRED,
        severity=AlertSeverity.INFO,
        session_id=sid,
        user_id=session.user_id,
        agent_id=session.agent_id,
        message=f"Session expired after {session.timeout_minutes} minutes of inactivity",
        current_cost=session.total_cost,
    )
    self._trigger_alert(session, notice)

    logger.info(f"Session {sid} expired")
|
|
750
|
+
|
|
751
|
+
def pause_session(self, session_id: str) -> bool:
    """
    Pause an active session.

    Args:
        session_id: Identifier of the session to pause.

    Returns:
        True if the session exists and was active (it is now ``PAUSED``);
        False otherwise, in which case nothing is modified.
    """
    with self._lock:
        target = self._sessions.get(session_id)
        if not target or not target.is_active:
            return False

        target.state = SessionState.PAUSED
        logger.info(f"Session {session_id} paused")
        return True
|
|
768
|
+
|
|
769
|
+
def resume_session(self, session_id: str) -> bool:
    """
    Return a paused session to the ``ACTIVE`` state.

    Resuming also refreshes ``last_activity`` to the current UTC time.

    Args:
        session_id: Identifier of the session to resume.

    Returns:
        True if the session existed and was ``PAUSED``; False otherwise
        (nothing is modified in that case).
    """
    with self._lock:
        paused = self._sessions.get(session_id)
        if not paused or paused.state != SessionState.PAUSED:
            return False

        paused.state = SessionState.ACTIVE
        paused.last_activity = datetime.now(timezone.utc)
        logger.info(f"Session {session_id} resumed")
        return True
|
|
787
|
+
|
|
788
|
+
def end_session(self, session_id: str, reason: str = "user_ended") -> SessionSummary | None:
    """
    End a session and return its cost summary.

    The session is moved to ``TERMINATED`` and stamped with the current UTC
    end time. Ending a session that is already ``TERMINATED`` keeps its
    original end time and does not increment the terminated-sessions
    counter again; a fresh summary is still returned.

    Args:
        session_id: Session to end.
        reason: Reason for ending; stored as the summary's ``end_reason``.

    Returns:
        SessionSummary, or None if the session is not tracked.
    """
    with self._lock:
        session = self._sessions.get(session_id)

        if session is None:
            return None

        # Transition and count only once, so repeated calls do not inflate
        # the sessions_terminated statistic or move end_time forward.
        if session.state != SessionState.TERMINATED:
            session.state = SessionState.TERMINATED
            session.end_time = datetime.now(timezone.utc)
            self._sessions_terminated += 1

        peak_rate = self._peak_spend_rate(self._session_records.get(session_id, []))

        summary = SessionSummary(
            session_id=session.session_id,
            user_id=session.user_id,
            total_cost=session.total_cost,
            total_tokens=session.input_tokens + session.output_tokens,
            duration_seconds=session.duration_seconds,
            tool_breakdown=dict(session.by_tool),
            model_breakdown=dict(session.by_model),
            agent_breakdown={
                agent_id: profile.total_cost
                for agent_id, profile in session.agents.items()
            },
            peak_spend_rate=peak_rate,
            alerts_triggered=len(session.alerts),
            end_reason=reason,
        )

        logger.info(
            f"Session {session_id} ended: ${session.total_cost:.4f}, "
            f"{session.record_count} records, {len(session.alerts)} alerts"
        )

        return summary

@staticmethod
def _peak_spend_rate(records: list[UsageRecord], window_seconds: float = 60.0) -> float:
    """
    Return the highest total cost spent within any rolling time window.

    Every record starts a window that includes all records whose timestamp
    is at most ``window_seconds`` later (inclusive). The largest window sum
    is returned, or 0.0 when there are no records. The last record and
    single-record sessions also form windows, so a lone expensive call is
    reported rather than ignored.

    Args:
        records: Usage records with ``timestamp`` and ``cost_usd``.
        window_seconds: Window length in seconds (default: one minute).
    """
    # Sorting is a no-op for records already in chronological order, and it
    # makes the early ``break`` below valid even if they are not.
    ordered = sorted(records, key=lambda record: record.timestamp)
    peak = 0.0

    for start, first in enumerate(ordered):
        window_cost = 0.0
        for index in range(start, len(ordered)):
            if (ordered[index].timestamp - first.timestamp).total_seconds() > window_seconds:
                break
            window_cost += ordered[index].cost_usd
        peak = max(peak, window_cost)

    return peak
|
|
851
|
+
|
|
852
|
+
def get_session_records(
    self,
    session_id: str,
    limit: int | None = None,
) -> list[UsageRecord]:
    """
    Get usage records for a session.

    Args:
        session_id: Session to get records for.
        limit: Maximum number of most recent records to return. None returns
            all records; 0 or a negative value returns an empty list (a
            plain truthiness test previously made ``limit=0`` return
            everything).

    Returns:
        A new list of usage records, oldest first; empty if the session has
        no records.
    """
    with self._lock:
        records = self._session_records.get(session_id, [])
        if limit is None:
            return list(records)
        if limit <= 0:
            return []
        return list(records[-limit:])
|
|
872
|
+
|
|
873
|
+
def forecast_session_cost(
    self,
    session_id: str,
    duration_minutes: int = 60,
) -> float | None:
    """
    Project additional session spend from the session's average rate so far.

    The average rate is ``total_cost`` divided by the elapsed minutes, and it
    is extrapolated linearly over ``duration_minutes``.

    Args:
        session_id: Session to forecast.
        duration_minutes: Horizon of the forecast, in minutes.

    Returns:
        Forecasted additional cost in USD. Returns None when forecasting is
        disabled, the session is unknown, or it has run for less than two
        minutes.
    """
    if not self._enable_forecasting:
        return None

    with self._lock:
        session = self._sessions.get(session_id)
        # Too little history (under 120 s) gives an unreliable rate.
        if session is None or session.duration_seconds < 120:
            return None

        elapsed_minutes = session.duration_seconds / 60
        return session.total_cost / elapsed_minutes * duration_minutes
|
|
904
|
+
|
|
905
|
+
def _cleanup_expired_sessions(self) -> int:
    """
    Remove sessions that finished more than one hour ago.

    A session is eligible when it reports ``is_expired`` or is in the
    ``TERMINATED``/``EXPIRED`` state, and its ``end_time`` is set and lies
    more than one hour in the past. Sessions without an ``end_time`` are
    kept. Removing a session also drops its usage records. Its ID is also
    removed from the owner's ``_user_sessions`` entry, and the entry itself
    is dropped once empty, so that mapping no longer grows without bound.

    NOTE(review): no lock is taken here; presumably callers already hold
    ``self._lock`` — confirm.

    Returns:
        Number of sessions removed.
    """
    # One cutoff for the whole pass instead of re-reading the clock per session.
    now = datetime.now(timezone.utc)
    retention = timedelta(hours=1)

    expired_ids = []
    for session_id, session in self._sessions.items():
        finished = session.is_expired or session.state in (
            SessionState.TERMINATED,
            SessionState.EXPIRED,
        )
        if finished and session.end_time and now - session.end_time > retention:
            expired_ids.append(session_id)

    for session_id in expired_ids:
        session = self._sessions.pop(session_id)
        self._session_records.pop(session_id, None)

        # Drop the dangling reference from the per-user index.
        owned = self._user_sessions.get(session.user_id)
        if owned is not None and session_id in owned:
            owned.remove(session_id)
            if not owned:
                del self._user_sessions[session.user_id]

    cleaned = len(expired_ids)
    if cleaned:
        logger.info(f"Cleaned up {cleaned} expired sessions")

    return cleaned
|
|
931
|
+
|
|
932
|
+
def get_user_total_cost(
    self,
    user_id: str,
    period: timedelta | None = None,
) -> float:
    """
    Get total cost for a user across all sessions.

    Args:
        user_id: User to check.
        period: Look-back window. Whenever a window is given, including
            ``timedelta(0)``, the figure comes from the base tracker's
            ``get_user_spend``. ``timedelta(0)`` is falsy, so a truthiness
            test used to route it to the all-time branch. With None, the
            costs of the user's sessions still held by this tracker are
            summed.

    Returns:
        Total cost in USD.

    NOTE(review): sessions already removed by ``_cleanup_expired_sessions``
    are no longer in ``self._sessions`` and therefore do not contribute to
    the ``period=None`` figure.
    """
    with self._lock:
        if period is not None:
            return self._base_tracker.get_user_spend(user_id, period)

        # Sum across sessions still tracked; stale IDs are skipped.
        total = 0.0
        for session_id in self._user_sessions.get(user_id, []):
            session = self._sessions.get(session_id)
            if session:
                total += session.total_cost

        return total
|
|
959
|
+
|
|
960
|
+
def get_agent_total_cost(self, agent_id: str) -> float:
    """
    Sum the cost attributed to one agent over every tracked session.

    Args:
        agent_id: Agent whose cost is aggregated.

    Returns:
        Total cost in USD (0.0 if the agent appears in no session).
    """
    with self._lock:
        return sum(
            (
                session.agents[agent_id].total_cost
                for session in self._sessions.values()
                if agent_id in session.agents
            ),
            0.0,
        )
|
|
978
|
+
|
|
979
|
+
def get_stats(self) -> dict[str, Any]:
    """
    Return a snapshot of tracker-wide counters.

    Keys: ``total_sessions`` and ``active_sessions`` (currently held
    sessions, and how many of them report ``is_active``),
    ``sessions_created``, ``sessions_terminated``, ``total_alerts``, and
    ``users_tracked`` (number of entries in the per-user session index).
    """
    with self._lock:
        sessions = self._sessions
        live = [s for s in sessions.values() if s.is_active]
        snapshot = {
            "total_sessions": len(sessions),
            "active_sessions": len(live),
            "sessions_created": self._sessions_created,
            "sessions_terminated": self._sessions_terminated,
            "total_alerts": self._total_alerts,
            "users_tracked": len(self._user_sessions),
        }
    return snapshot
|
|
992
|
+
|
|
993
|
+
def export_session(
    self,
    session_id: str,
    format: str = "json",
) -> str | None:
    """
    Export session data for billing/audit.

    Args:
        session_id: Session to export.
        format: ``"csv"`` for a header row plus one row per usage record;
            any other value yields indented JSON with the session's
            ``to_dict()`` and each record's ``to_dict()``.

    Returns:
        Exported data as a string, or None if the session is not found.
        CSV rows are separated by single newline characters with no trailing
        newline. Fields containing commas, quotes or line breaks are quoted
        so that model/tool/agent names cannot corrupt the row layout.
    """
    import csv
    import io

    with self._lock:
        session = self._sessions.get(session_id)
        if session is None:
            return None

        records = self._session_records.get(session_id, [])

        if format == "csv":
            buffer = io.StringIO()
            # csv.writer escapes delimiter/quote/newline characters;
            # joining with "," by hand did not.
            writer = csv.writer(buffer, lineterminator="\n")
            writer.writerow([
                "timestamp", "model", "input_tokens", "output_tokens",
                "cost_usd", "tool_name", "agent_id",
            ])
            for record in records:
                agent_id = record.metadata.get("agent_id", "")
                writer.writerow([
                    record.timestamp.isoformat(),
                    f"{record.model}",
                    f"{record.input_tokens}",
                    f"{record.output_tokens}",
                    f"{record.cost_usd:.6f}",
                    f"{record.tool_name or ''}",
                    f"{agent_id}",
                ])
            # The header row guarantees a final "\n"; drop it to keep the
            # previous output shape (no trailing newline).
            return buffer.getvalue()[:-1]

        else:
            return json.dumps({
                "session": session.to_dict(),
                "records": [r.to_dict() for r in records],
            }, indent=2)
|
|
1035
|
+
|
|
1036
|
+
@property
def base_tracker(self) -> CostTracker:
    """The wrapped ``CostTracker`` instance held by this session tracker."""
    wrapped = self._base_tracker
    return wrapped
|
|
1040
|
+
|
|
1041
|
+
|
|
1042
|
+
def create_session_cost_tracker(
    budget_policy: BudgetPolicy | None = None,
    default_session_budget: float | None = None,
    alert_callback: AlertCallback | None = None,
) -> SessionCostTracker:
    """
    Build a SessionCostTracker and optionally register one alert callback.

    Args:
        budget_policy: Global budget policy passed to the tracker.
        default_session_budget: NOTE(review): accepted but never used by
            this factory — it is not forwarded to ``SessionCostTracker``, so
            passing it currently has no effect. Confirm how the default
            session budget is meant to be applied.
        alert_callback: If truthy, registered via ``add_alert_callback``.

    Returns:
        The newly created SessionCostTracker.
    """
    tracker = SessionCostTracker(budget_policy=budget_policy)

    if not alert_callback:
        return tracker

    tracker.add_alert_callback(alert_callback)
    return tracker
|