proxilion 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94) hide show
  1. proxilion/__init__.py +136 -0
  2. proxilion/audit/__init__.py +133 -0
  3. proxilion/audit/base_exporters.py +527 -0
  4. proxilion/audit/compliance/__init__.py +130 -0
  5. proxilion/audit/compliance/base.py +457 -0
  6. proxilion/audit/compliance/eu_ai_act.py +603 -0
  7. proxilion/audit/compliance/iso27001.py +544 -0
  8. proxilion/audit/compliance/soc2.py +491 -0
  9. proxilion/audit/events.py +493 -0
  10. proxilion/audit/explainability.py +1173 -0
  11. proxilion/audit/exporters/__init__.py +58 -0
  12. proxilion/audit/exporters/aws_s3.py +636 -0
  13. proxilion/audit/exporters/azure_storage.py +608 -0
  14. proxilion/audit/exporters/cloud_base.py +468 -0
  15. proxilion/audit/exporters/gcp_storage.py +570 -0
  16. proxilion/audit/exporters/multi_exporter.py +498 -0
  17. proxilion/audit/hash_chain.py +652 -0
  18. proxilion/audit/logger.py +543 -0
  19. proxilion/caching/__init__.py +49 -0
  20. proxilion/caching/tool_cache.py +633 -0
  21. proxilion/context/__init__.py +73 -0
  22. proxilion/context/context_window.py +556 -0
  23. proxilion/context/message_history.py +505 -0
  24. proxilion/context/session.py +735 -0
  25. proxilion/contrib/__init__.py +51 -0
  26. proxilion/contrib/anthropic.py +609 -0
  27. proxilion/contrib/google.py +1012 -0
  28. proxilion/contrib/langchain.py +641 -0
  29. proxilion/contrib/mcp.py +893 -0
  30. proxilion/contrib/openai.py +646 -0
  31. proxilion/core.py +3058 -0
  32. proxilion/decorators.py +966 -0
  33. proxilion/engines/__init__.py +287 -0
  34. proxilion/engines/base.py +266 -0
  35. proxilion/engines/casbin_engine.py +412 -0
  36. proxilion/engines/opa_engine.py +493 -0
  37. proxilion/engines/simple.py +437 -0
  38. proxilion/exceptions.py +887 -0
  39. proxilion/guards/__init__.py +54 -0
  40. proxilion/guards/input_guard.py +522 -0
  41. proxilion/guards/output_guard.py +634 -0
  42. proxilion/observability/__init__.py +198 -0
  43. proxilion/observability/cost_tracker.py +866 -0
  44. proxilion/observability/hooks.py +683 -0
  45. proxilion/observability/metrics.py +798 -0
  46. proxilion/observability/session_cost_tracker.py +1063 -0
  47. proxilion/policies/__init__.py +67 -0
  48. proxilion/policies/base.py +304 -0
  49. proxilion/policies/builtin.py +486 -0
  50. proxilion/policies/registry.py +376 -0
  51. proxilion/providers/__init__.py +201 -0
  52. proxilion/providers/adapter.py +468 -0
  53. proxilion/providers/anthropic_adapter.py +330 -0
  54. proxilion/providers/gemini_adapter.py +391 -0
  55. proxilion/providers/openai_adapter.py +294 -0
  56. proxilion/py.typed +0 -0
  57. proxilion/resilience/__init__.py +81 -0
  58. proxilion/resilience/degradation.py +615 -0
  59. proxilion/resilience/fallback.py +555 -0
  60. proxilion/resilience/retry.py +554 -0
  61. proxilion/scheduling/__init__.py +57 -0
  62. proxilion/scheduling/priority_queue.py +419 -0
  63. proxilion/scheduling/scheduler.py +459 -0
  64. proxilion/security/__init__.py +244 -0
  65. proxilion/security/agent_trust.py +968 -0
  66. proxilion/security/behavioral_drift.py +794 -0
  67. proxilion/security/cascade_protection.py +869 -0
  68. proxilion/security/circuit_breaker.py +428 -0
  69. proxilion/security/cost_limiter.py +690 -0
  70. proxilion/security/idor_protection.py +460 -0
  71. proxilion/security/intent_capsule.py +849 -0
  72. proxilion/security/intent_validator.py +495 -0
  73. proxilion/security/memory_integrity.py +767 -0
  74. proxilion/security/rate_limiter.py +509 -0
  75. proxilion/security/scope_enforcer.py +680 -0
  76. proxilion/security/sequence_validator.py +636 -0
  77. proxilion/security/trust_boundaries.py +784 -0
  78. proxilion/streaming/__init__.py +70 -0
  79. proxilion/streaming/detector.py +761 -0
  80. proxilion/streaming/transformer.py +674 -0
  81. proxilion/timeouts/__init__.py +55 -0
  82. proxilion/timeouts/decorators.py +477 -0
  83. proxilion/timeouts/manager.py +545 -0
  84. proxilion/tools/__init__.py +69 -0
  85. proxilion/tools/decorators.py +493 -0
  86. proxilion/tools/registry.py +732 -0
  87. proxilion/types.py +339 -0
  88. proxilion/validation/__init__.py +93 -0
  89. proxilion/validation/pydantic_schema.py +351 -0
  90. proxilion/validation/schema.py +651 -0
  91. proxilion-0.0.1.dist-info/METADATA +872 -0
  92. proxilion-0.0.1.dist-info/RECORD +94 -0
  93. proxilion-0.0.1.dist-info/WHEEL +4 -0
  94. proxilion-0.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1063 @@
1
+ """
2
+ Enhanced Session-Based Cost Tracking for Proxilion.
3
+
4
+ Extends the base CostTracker with per-user and per-agent session
5
+ tracking, real-time budget alerts, and detailed cost attribution.
6
+
7
+ Features:
8
+ - Per-session cost tracking with automatic session management
9
+ - Per-agent cost attribution in multi-agent systems
10
+ - Real-time budget alerts and callbacks
11
+ - Cost breakdown by tool, model, and time period
12
+ - Session cost limits and automatic termination
13
+ - Cost forecasting based on usage patterns
14
+ - Export formats for billing integration
15
+
16
+ Example:
17
+ >>> from proxilion.observability.session_cost_tracker import (
18
+ ... SessionCostTracker,
19
+ ... Session,
20
+ ... CostAlert,
21
+ ... )
22
+ >>>
23
+ >>> # Create tracker with session support
24
+ >>> tracker = SessionCostTracker()
25
+ >>>
26
+ >>> # Start a user session
27
+ >>> session = tracker.start_session(
28
+ ... user_id="user_123",
29
+ ... agent_id="assistant_main",
30
+ ... budget_limit=10.00,
31
+ ... )
32
+ >>>
33
+ >>> # Record usage
34
+ >>> tracker.record_session_usage(
35
+ ... session_id=session.session_id,
36
+ ... model="claude-sonnet-4-20250514",
37
+ ... input_tokens=1000,
38
+ ... output_tokens=500,
39
+ ... tool_name="search",
40
+ ... )
41
+ >>>
42
+ >>> # Check session costs
43
+ >>> print(f"Session cost: ${session.total_cost:.4f}")
44
+ >>> print(f"Budget remaining: ${session.budget_remaining:.4f}")
45
+ >>>
46
+ >>> # End session
47
+ >>> summary = tracker.end_session(session.session_id)
48
+ >>> print(f"Final cost: ${summary.total_cost:.4f}")
49
+ """
50
+
51
+ from __future__ import annotations
52
+
53
+ import hashlib
54
+ import json
55
+ import logging
56
+ import threading
57
+ import uuid
58
+ from collections import defaultdict
59
+ from dataclasses import asdict, dataclass, field
60
+ from datetime import datetime, timedelta, timezone
61
+ from enum import Enum
62
+ from typing import Any, Callable
63
+
64
+ from proxilion.observability.cost_tracker import (
65
+ BudgetPolicy,
66
+ CostSummary,
67
+ CostTracker,
68
+ ModelPricing,
69
+ UsageRecord,
70
+ DEFAULT_PRICING,
71
+ )
72
+
73
+ logger = logging.getLogger(__name__)
74
+
75
+
76
class SessionState(str, Enum):
    """Lifecycle states a tracked session can be in.

    Members subclass ``str``, so each one compares equal to its raw value.
    """

    # The only state in which record_session_usage() accepts new usage.
    ACTIVE = "active"
    # Set by pause_session(); cleared again by resume_session().
    PAUSED = "paused"
    # Set by the budget check once the session budget is fully consumed.
    BUDGET_EXCEEDED = "budget_exceeded"
    # Set by end_session().
    TERMINATED = "terminated"
    # Set when the inactivity timeout has elapsed.
    EXPIRED = "expired"
84
+
85
+
86
class AlertSeverity(str, Enum):
    """How serious a cost alert is.

    Members subclass ``str``, so each one compares equal to its raw value.
    """

    # Used for session-expiry notifications.
    INFO = "info"
    # Used for budget-approaching and high-spend-rate notifications.
    WARNING = "warning"
    # Used when a session budget has been exceeded.
    CRITICAL = "critical"
92
+
93
+
94
class AlertType(str, Enum):
    """Categories of cost alert.

    Members subclass ``str``, so each one compares equal to its raw value.
    """

    # Spend is approaching the budget limit.
    BUDGET_WARNING = "budget_warning"
    # The budget limit has been hit.
    BUDGET_EXCEEDED = "budget_exceeded"
    # Spend per minute is high.
    RATE_WARNING = "rate_warning"
    # Unusual spending pattern.
    ANOMALY = "anomaly"
    # The session timed out.
    SESSION_EXPIRED = "session_expired"
    # Spend is projected to exceed the budget.
    FORECAST_WARNING = "forecast_warning"
103
+
104
+
105
@dataclass
class CostAlert:
    """
    A single cost-related notification.

    Attributes:
        alert_id: Unique identifier of the alert.
        alert_type: Category of the alert.
        severity: How serious the alert is.
        message: Human-readable description.
        current_cost: Cost at the moment the alert fired.
        threshold: Threshold that was crossed, if any.
        session_id: Session the alert belongs to, if any.
        user_id: User the alert belongs to, if any.
        agent_id: Agent the alert belongs to, if any.
        timestamp: Creation time; defaults to the current UTC time.
        metadata: Free-form extra data.
    """

    alert_id: str
    alert_type: AlertType
    severity: AlertSeverity
    message: str
    current_cost: float
    threshold: float | None = None
    session_id: str | None = None
    user_id: str | None = None
    agent_id: str | None = None
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Return the alert as a plain dictionary.

        Enum fields are reduced to their ``.value`` and the timestamp to an
        ISO-8601 string; ``metadata`` is passed through as the same object.
        """
        serialized: dict[str, Any] = {
            "alert_id": self.alert_id,
            "alert_type": self.alert_type.value,
            "severity": self.severity.value,
        }
        # These fields are copied through unchanged, in declaration order.
        passthrough = (
            "message",
            "current_cost",
            "threshold",
            "session_id",
            "user_id",
            "agent_id",
        )
        for name in passthrough:
            serialized[name] = getattr(self, name)
        serialized["timestamp"] = self.timestamp.isoformat()
        serialized["metadata"] = self.metadata
        return serialized
151
+
152
+
153
@dataclass
class AgentCostProfile:
    """
    Running cost totals for one agent inside a session.

    Attributes:
        agent_id: Unique agent identifier.
        parent_agent_id: Parent agent, if this agent was delegated to.
        total_cost: Total cost attributed to this agent.
        input_tokens: Total input tokens.
        output_tokens: Total output tokens.
        tool_calls: Number of tool calls.
        by_tool: Cost keyed by tool name.
        by_model: Cost keyed by model name.
        first_activity: Timestamp of the first recorded activity.
        last_activity: Timestamp of the most recent recorded activity.
    """

    agent_id: str
    parent_agent_id: str | None = None
    total_cost: float = 0.0
    input_tokens: int = 0
    output_tokens: int = 0
    tool_calls: int = 0
    by_tool: dict[str, float] = field(default_factory=dict)
    by_model: dict[str, float] = field(default_factory=dict)
    first_activity: datetime | None = None
    last_activity: datetime | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return the profile as a plain dictionary.

        Timestamps become ISO-8601 strings (or ``None`` when unset); the
        breakdown dictionaries are passed through as the same objects.
        """

        def _iso(moment: datetime | None) -> str | None:
            return moment.isoformat() if moment else None

        scalar_fields = (
            "agent_id",
            "parent_agent_id",
            "total_cost",
            "input_tokens",
            "output_tokens",
            "tool_calls",
            "by_tool",
            "by_model",
        )
        snapshot: dict[str, Any] = {name: getattr(self, name) for name in scalar_fields}
        snapshot["first_activity"] = _iso(self.first_activity)
        snapshot["last_activity"] = _iso(self.last_activity)
        return snapshot
199
+
200
+
201
@dataclass
class Session:
    """
    A user/agent session together with its accumulated cost data.

    Attributes:
        session_id: Unique session identifier.
        user_id: User who owns the session.
        state: Current lifecycle state.
        agent_id: Primary agent for the session.
        budget_limit: Maximum cost allowed, or None for no limit.
        total_cost: Cost incurred so far.
        input_tokens: Input tokens used so far.
        output_tokens: Output tokens used so far.
        record_count: Number of usage records.
        agents: Cost profiles keyed by agent id.
        by_tool: Cost keyed by tool name.
        by_model: Cost keyed by model name.
        start_time: When the session started (defaults to current UTC time).
        end_time: When the session ended, if it has.
        last_activity: Timestamp of the last activity.
        timeout_minutes: Inactivity timeout in minutes.
        metadata: Additional session metadata.
        alerts: Alerts triggered during the session.
    """

    session_id: str
    user_id: str
    state: SessionState = SessionState.ACTIVE
    agent_id: str | None = None
    budget_limit: float | None = None
    total_cost: float = 0.0
    input_tokens: int = 0
    output_tokens: int = 0
    record_count: int = 0
    agents: dict[str, AgentCostProfile] = field(default_factory=dict)
    by_tool: dict[str, float] = field(default_factory=dict)
    by_model: dict[str, float] = field(default_factory=dict)
    start_time: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    end_time: datetime | None = None
    last_activity: datetime | None = None
    timeout_minutes: int = 60
    metadata: dict[str, Any] = field(default_factory=dict)
    alerts: list[CostAlert] = field(default_factory=list)

    @property
    def is_active(self) -> bool:
        """True only while the state is ``SessionState.ACTIVE``."""
        return self.state == SessionState.ACTIVE

    @property
    def budget_remaining(self) -> float | None:
        """Budget still available (never negative), or None without a limit."""
        limit = self.budget_limit
        if limit is None:
            return None
        return max(0.0, limit - self.total_cost)

    @property
    def budget_percentage(self) -> float | None:
        """Fraction of the budget used, capped at 1.0.

        Returns None when there is no limit or the limit is zero.
        """
        limit = self.budget_limit
        if limit is None or limit == 0:
            return None
        used_fraction = self.total_cost / limit
        return min(1.0, used_fraction)

    @property
    def duration_seconds(self) -> float:
        """Seconds from start to end, or to now if the session has no end."""
        finish = self.end_time if self.end_time else datetime.now(timezone.utc)
        elapsed = finish - self.start_time
        return elapsed.total_seconds()

    @property
    def is_expired(self) -> bool:
        """True when inactivity exceeds ``timeout_minutes``.

        A session with no recorded ``last_activity`` is never expired.
        """
        if self.last_activity is None:
            return False
        allowed_idle = timedelta(minutes=self.timeout_minutes)
        idle_for = datetime.now(timezone.utc) - self.last_activity
        return idle_for > allowed_idle

    def to_dict(self) -> dict[str, Any]:
        """Return the session, including derived values, as a dictionary."""
        agent_dicts = {agent_key: profile.to_dict() for agent_key, profile in self.agents.items()}
        alert_dicts = [alert.to_dict() for alert in self.alerts]
        ended_at = self.end_time.isoformat() if self.end_time else None
        last_seen = self.last_activity.isoformat() if self.last_activity else None

        return {
            "session_id": self.session_id,
            "user_id": self.user_id,
            "agent_id": self.agent_id,
            "state": self.state.value,
            "budget_limit": self.budget_limit,
            "budget_remaining": self.budget_remaining,
            "budget_percentage": self.budget_percentage,
            "total_cost": self.total_cost,
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "record_count": self.record_count,
            "agents": agent_dicts,
            "by_tool": self.by_tool,
            "by_model": self.by_model,
            "start_time": self.start_time.isoformat(),
            "end_time": ended_at,
            "last_activity": last_seen,
            "duration_seconds": self.duration_seconds,
            "timeout_minutes": self.timeout_minutes,
            "metadata": self.metadata,
            "alerts": alert_dicts,
        }

    def to_json(self) -> str:
        """Return ``to_dict()`` rendered as JSON with a two-space indent."""
        as_mapping = self.to_dict()
        return json.dumps(as_mapping, indent=2)
308
+
309
+
310
@dataclass
class SessionSummary:
    """
    Final figures for a session that has been ended.

    Attributes:
        session_id: Session identifier.
        user_id: User who owned the session.
        total_cost: Total cost incurred.
        total_tokens: Total tokens used.
        duration_seconds: How long the session lasted.
        tool_breakdown: Cost keyed by tool.
        model_breakdown: Cost keyed by model.
        agent_breakdown: Cost keyed by agent.
        peak_spend_rate: Highest spend rate ($/minute).
        alerts_triggered: Number of alerts triggered.
        end_reason: Why the session ended.
    """

    session_id: str
    user_id: str
    total_cost: float
    total_tokens: int
    duration_seconds: float
    tool_breakdown: dict[str, float]
    model_breakdown: dict[str, float]
    agent_breakdown: dict[str, float]
    peak_spend_rate: float = 0.0
    alerts_triggered: int = 0
    end_reason: str = "user_ended"

    def to_dict(self) -> dict[str, Any]:
        """Return every field as a dictionary, using ``dataclasses.asdict``."""
        as_mapping = asdict(self)
        return as_mapping
344
+
345
+
346
+ # Type alias for alert callbacks
347
+ AlertCallback = Callable[[CostAlert], None]
348
+
349
+
350
+ class SessionCostTracker:
351
+ """
352
+ Enhanced cost tracker with session management.
353
+
354
+ Provides per-user and per-agent cost tracking with automatic
355
+ session management, budget enforcement, and alerting.
356
+
357
+ Example:
358
+ >>> tracker = SessionCostTracker(
359
+ ... budget_policy=BudgetPolicy(max_cost_per_user_per_day=100.0)
360
+ ... )
361
+ >>>
362
+ >>> # Start session with budget
363
+ >>> session = tracker.start_session(
364
+ ... user_id="alice",
365
+ ... budget_limit=10.0,
366
+ ... )
367
+ >>>
368
+ >>> # Record usage
369
+ >>> record = tracker.record_session_usage(
370
+ ... session_id=session.session_id,
371
+ ... model="claude-sonnet-4-20250514",
372
+ ... input_tokens=1000,
373
+ ... output_tokens=500,
374
+ ... )
375
+ >>>
376
+ >>> # Check status
377
+ >>> print(f"Session cost: ${session.total_cost:.4f}")
378
+ >>> print(f"Remaining: ${session.budget_remaining:.4f}")
379
+ """
380
+
381
def __init__(
    self,
    base_tracker: CostTracker | None = None,
    budget_policy: BudgetPolicy | None = None,
    default_session_timeout: int = 60,
    budget_warning_threshold: float = 0.8,
    rate_warning_threshold: float = 1.0,  # $/minute
    max_sessions: int = 10000,
    enable_forecasting: bool = True,
) -> None:
    """
    Set up an empty session cost tracker.

    Args:
        base_tracker: Existing CostTracker to delegate usage recording to.
            When omitted (or falsy) a new CostTracker is created with
            ``budget_policy``.
        budget_policy: Budget policy handed to the newly created
            CostTracker; not used when ``base_tracker`` is supplied.
        default_session_timeout: Default session timeout in minutes.
        budget_warning_threshold: Budget fraction (0.0 to 1.0) at which a
            warning is raised.
        rate_warning_threshold: Spend rate ($/min) that triggers a warning.
        max_sessions: Session count at which expired sessions are cleaned up.
        enable_forecasting: Whether cost forecasting is enabled.
    """
    # Configuration
    self._enable_forecasting = enable_forecasting
    self._max_sessions = max_sessions
    self._rate_warning_threshold = rate_warning_threshold
    self._budget_warning_threshold = budget_warning_threshold
    self._default_timeout = default_session_timeout

    # Reentrant so methods holding the lock can call each other.
    self._lock = threading.RLock()

    # Wrap the tracker we were given, otherwise build our own.
    if base_tracker:
        self._base_tracker = base_tracker
    else:
        self._base_tracker = CostTracker(budget_policy=budget_policy)

    # Session storage: sessions by id, session ids by user, records by session.
    self._sessions: dict[str, Session] = {}
    self._user_sessions: dict[str, list[str]] = defaultdict(list)
    self._session_records: dict[str, list[UsageRecord]] = defaultdict(list)

    # Callbacks notified for every alert.
    self._alert_callbacks: list[AlertCallback] = []

    # Counters
    self._sessions_terminated = 0
    self._sessions_created = 0
    self._total_alerts = 0
426
+
427
def add_alert_callback(self, callback: AlertCallback) -> None:
    """
    Register a function to be called with each new cost alert.

    Args:
        callback: Callable that receives the triggered CostAlert.
    """
    registered = self._alert_callbacks
    registered.append(callback)
435
+
436
def remove_alert_callback(self, callback: AlertCallback) -> None:
    """Unregister an alert callback; unknown callbacks are ignored.

    Only the first matching registration is removed.
    """
    try:
        self._alert_callbacks.remove(callback)
    except ValueError:
        # The callback was never registered: nothing to do.
        pass
440
+
441
def start_session(
    self,
    user_id: str,
    agent_id: str | None = None,
    budget_limit: float | None = None,
    timeout_minutes: int | None = None,
    metadata: dict[str, Any] | None = None,
) -> Session:
    """
    Start a new cost tracking session.

    Args:
        user_id: User who owns the session.
        agent_id: Primary agent for the session.
        budget_limit: Maximum cost for this session (None means no limit).
        timeout_minutes: Session timeout (uses default if not provided).
        metadata: Additional metadata to store.

    Returns:
        The created Session.
    """
    session_id = f"sess_{uuid.uuid4().hex[:12]}"

    session = Session(
        session_id=session_id,
        user_id=user_id,
        agent_id=agent_id,
        budget_limit=budget_limit,
        timeout_minutes=timeout_minutes or self._default_timeout,
        metadata=metadata or {},
        # Seed last_activity so the inactivity timeout counts from creation.
        last_activity=datetime.now(timezone.utc),
    )

    # Register primary agent if provided
    if agent_id:
        session.agents[agent_id] = AgentCostProfile(agent_id=agent_id)

    with self._lock:
        # Clean up expired sessions if at capacity
        if len(self._sessions) >= self._max_sessions:
            self._cleanup_expired_sessions()

        self._sessions[session_id] = session
        self._user_sessions[user_id].append(session_id)
        self._sessions_created += 1

    # Compare against None rather than truthiness: a 0.0 budget is a real
    # limit and must not be reported as "unlimited".
    budget_str = f"${budget_limit:.2f}" if budget_limit is not None else "unlimited"
    logger.info(
        f"Started session {session_id} for user {user_id} "
        f"(budget: {budget_str})"
    )

    return session
494
+
495
def get_session(self, session_id: str) -> Session | None:
    """Look up a session by id, expiring it first if it has timed out.

    Returns:
        The Session (possibly just marked expired), or None if unknown.
    """
    with self._lock:
        found = self._sessions.get(session_id)
        if not found:
            return found

        # An active session past its inactivity timeout is expired lazily,
        # at lookup time.
        if found.is_expired and found.is_active:
            self._expire_session(found)

        return found
505
+
506
def get_user_sessions(
    self,
    user_id: str,
    active_only: bool = True,
) -> list[Session]:
    """
    List the sessions that belong to a user.

    Args:
        user_id: User whose sessions are wanted.
        active_only: When True, sessions that are not active are left out.

    Returns:
        Matching sessions in the order they were started; ids that no
        longer resolve to a stored session are skipped.
    """
    with self._lock:
        owned_ids = self._user_sessions.get(user_id, [])
        resolved = (self._sessions.get(sid) for sid in owned_ids)
        return [
            candidate
            for candidate in resolved
            if candidate and (not active_only or candidate.is_active)
        ]
533
+
534
def record_session_usage(
    self,
    session_id: str,
    model: str,
    input_tokens: int,
    output_tokens: int,
    cache_read_tokens: int = 0,
    cache_write_tokens: int = 0,
    tool_name: str | None = None,
    agent_id: str | None = None,
    request_id: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> UsageRecord | None:
    """
    Record one usage event against a session.

    The event is first recorded in the base tracker; the resulting record's
    cost and timestamp are then folded into the session totals, the
    per-tool / per-model breakdowns and the acting agent's profile, after
    which budget alerts are evaluated.

    Args:
        session_id: Session to record usage for.
        model: Model used.
        input_tokens: Number of input tokens.
        output_tokens: Number of output tokens.
        cache_read_tokens: Cached tokens read.
        cache_write_tokens: Tokens written to cache.
        tool_name: Tool that triggered the usage.
        agent_id: Agent that incurred the usage; falls back to the
            session's primary agent when omitted.
        request_id: Request identifier.
        metadata: Additional metadata; ``session_id`` and ``agent_id`` keys
            are added (overwriting any supplied values).

    Returns:
        UsageRecord if successful, None if the session is unknown, not
        active, or has just been found to be expired.
    """
    with self._lock:
        target = self._sessions.get(session_id)

        # Guard clauses: unknown, inactive, or timed-out sessions record nothing.
        if target is None:
            logger.warning(f"Session {session_id} not found")
            return None
        if not target.is_active:
            logger.warning(f"Session {session_id} is not active (state: {target.state})")
            return None
        if target.is_expired:
            self._expire_session(target)
            return None

        # Tag the base-tracker record with the session and agent.
        record_metadata = dict(metadata or {})
        record_metadata["session_id"] = session_id
        record_metadata["agent_id"] = agent_id

        record = self._base_tracker.record_usage(
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            cache_read_tokens=cache_read_tokens,
            cache_write_tokens=cache_write_tokens,
            tool_name=tool_name,
            user_id=target.user_id,
            request_id=request_id,
            metadata=record_metadata,
        )

        # Session-level totals.
        target.total_cost += record.cost_usd
        target.input_tokens += input_tokens
        target.output_tokens += output_tokens
        target.record_count += 1
        target.last_activity = record.timestamp

        # Session-level breakdowns.
        if tool_name:
            target.by_tool[tool_name] = target.by_tool.get(tool_name, 0.0) + record.cost_usd
        target.by_model[model] = target.by_model.get(model, 0.0) + record.cost_usd

        # Per-agent attribution: explicit agent wins over the session's primary.
        acting_agent = agent_id or target.agent_id
        if acting_agent:
            if acting_agent not in target.agents:
                target.agents[acting_agent] = AgentCostProfile(agent_id=acting_agent)
            profile = target.agents[acting_agent]

            profile.total_cost += record.cost_usd
            profile.input_tokens += input_tokens
            profile.output_tokens += output_tokens
            profile.tool_calls += 1

            if tool_name:
                profile.by_tool[tool_name] = profile.by_tool.get(tool_name, 0.0) + record.cost_usd
            profile.by_model[model] = profile.by_model.get(model, 0.0) + record.cost_usd

            if profile.first_activity is None:
                profile.first_activity = record.timestamp
            profile.last_activity = record.timestamp

        # Keep the raw record, then evaluate budget and rate alerts.
        self._session_records[session_id].append(record)
        self._check_budget_alerts(target, record)

        return record
639
+
640
def _check_budget_alerts(self, session: Session, record: UsageRecord) -> None:
    """Raise budget and spend-rate alerts for a session after new usage.

    Three checks are made:

    * budget warning - once per session, when the used fraction reaches the
      warning threshold but is still below 1.0;
    * budget exceeded - when the used fraction reaches 1.0; the session is
      then moved to ``BUDGET_EXCEEDED`` and given an end time;
    * spend rate - once per session, when the average $/minute over a
      session longer than one minute is above the rate threshold.

    The spend-rate check does not depend on a budget, so it runs for
    sessions without a ``budget_limit`` as well.

    Args:
        session: Session that was just updated.
        record: The usage record that triggered the check (not inspected).
    """
    if session.budget_limit is not None:
        percentage = session.budget_percentage or 0.0

        # Budget warning (at most one per session)
        if (
            percentage >= self._budget_warning_threshold
            and percentage < 1.0
            and not any(a.alert_type == AlertType.BUDGET_WARNING for a in session.alerts)
        ):
            alert = CostAlert(
                alert_id=f"alert_{uuid.uuid4().hex[:8]}",
                alert_type=AlertType.BUDGET_WARNING,
                severity=AlertSeverity.WARNING,
                session_id=session.session_id,
                user_id=session.user_id,
                agent_id=session.agent_id,
                message=(
                    f"Session approaching budget limit: "
                    f"${session.total_cost:.2f}/${session.budget_limit:.2f} "
                    f"({percentage:.0%})"
                ),
                current_cost=session.total_cost,
                threshold=session.budget_limit * self._budget_warning_threshold,
            )
            self._trigger_alert(session, alert)

        # Budget exceeded
        if percentage >= 1.0:
            alert = CostAlert(
                alert_id=f"alert_{uuid.uuid4().hex[:8]}",
                alert_type=AlertType.BUDGET_EXCEEDED,
                severity=AlertSeverity.CRITICAL,
                session_id=session.session_id,
                user_id=session.user_id,
                agent_id=session.agent_id,
                message=(
                    f"Session budget exceeded: "
                    f"${session.total_cost:.2f}/${session.budget_limit:.2f}"
                ),
                current_cost=session.total_cost,
                threshold=session.budget_limit,
            )
            self._trigger_alert(session, alert)

            # Terminate session
            session.state = SessionState.BUDGET_EXCEEDED
            session.end_time = datetime.now(timezone.utc)
            logger.warning(f"Session {session.session_id} terminated: budget exceeded")

    # Check spend rate. This must not sit behind the budget check: a session
    # with no budget limit can still burn money too fast.
    if session.duration_seconds > 60:  # At least 1 minute
        rate_per_minute = session.total_cost / (session.duration_seconds / 60)

        if (
            rate_per_minute > self._rate_warning_threshold
            and not any(a.alert_type == AlertType.RATE_WARNING for a in session.alerts)
        ):
            alert = CostAlert(
                alert_id=f"alert_{uuid.uuid4().hex[:8]}",
                alert_type=AlertType.RATE_WARNING,
                severity=AlertSeverity.WARNING,
                session_id=session.session_id,
                user_id=session.user_id,
                agent_id=session.agent_id,
                message=(
                    f"High spend rate detected: ${rate_per_minute:.2f}/minute "
                    f"(threshold: ${self._rate_warning_threshold:.2f}/minute)"
                ),
                current_cost=session.total_cost,
                threshold=self._rate_warning_threshold,
                metadata={"rate_per_minute": rate_per_minute},
            )
            self._trigger_alert(session, alert)
717
+
718
def _trigger_alert(self, session: Session, alert: CostAlert) -> None:
    """Record an alert on the session and notify every registered callback.

    A failing callback is logged (with traceback) and does not stop the
    remaining callbacks from running.

    Args:
        session: Session the alert is attached to.
        alert: The alert to store and broadcast.
    """
    session.alerts.append(alert)
    self._total_alerts += 1

    logger.warning(f"Cost alert: {alert.message}")

    # Iterate over a snapshot: a callback may register or unregister
    # callbacks (including itself), and mutating the list mid-iteration
    # would cause other callbacks to be skipped.
    for callback in list(self._alert_callbacks):
        try:
            callback(alert)
        except Exception as e:
            # Best-effort notification: keep going, but keep the traceback.
            logger.error(f"Alert callback error: {e}", exc_info=True)
731
+
732
def _expire_session(self, session: Session) -> None:
    """Move a session to EXPIRED, stamp its end time and raise an INFO alert."""
    session.state = SessionState.EXPIRED
    session.end_time = datetime.now(timezone.utc)

    expiry_notice = f"Session expired after {session.timeout_minutes} minutes of inactivity"
    alert_fields = {
        "alert_id": f"alert_{uuid.uuid4().hex[:8]}",
        "alert_type": AlertType.SESSION_EXPIRED,
        "severity": AlertSeverity.INFO,
        "session_id": session.session_id,
        "user_id": session.user_id,
        "agent_id": session.agent_id,
        "message": expiry_notice,
        "current_cost": session.total_cost,
    }
    self._trigger_alert(session, CostAlert(**alert_fields))

    logger.info(f"Session {session.session_id} expired")
750
+
751
def pause_session(self, session_id: str) -> bool:
    """
    Put an active session into the PAUSED state.

    Args:
        session_id: Session to pause.

    Returns:
        True if the session existed, was active and is now paused;
        False otherwise.
    """
    with self._lock:
        target = self._sessions.get(session_id)

        # Unknown or non-active sessions cannot be paused.
        if not target or not target.is_active:
            return False

        target.state = SessionState.PAUSED
        logger.info(f"Session {session_id} paused")
        return True
768
+
769
def resume_session(self, session_id: str) -> bool:
    """
    Return a paused session to the ACTIVE state.

    The session's ``last_activity`` is reset to the current UTC time so the
    time spent paused does not count towards the inactivity timeout.

    Args:
        session_id: Session to resume.

    Returns:
        True if the session existed, was paused and is now active;
        False otherwise.
    """
    with self._lock:
        target = self._sessions.get(session_id)
        if not target:
            return False
        if not (target.state == SessionState.PAUSED):
            return False

        target.state = SessionState.ACTIVE
        target.last_activity = datetime.now(timezone.utc)
        logger.info(f"Session {session_id} resumed")
        return True
787
+
788
def end_session(self, session_id: str, reason: str = "user_ended") -> SessionSummary | None:
    """
    End a session and get summary.

    The session is marked TERMINATED and given an end time, and a summary is
    built from its totals. ``peak_spend_rate`` is the largest total cost of
    any run of records that falls within 60 seconds of the run's first
    record; every record - including the last one, and the only one in a
    single-record session - is considered as a window start.

    Args:
        session_id: Session to end.
        reason: Reason for ending.

    Returns:
        SessionSummary or None if not found.
    """
    with self._lock:
        session = self._sessions.get(session_id)

        if session is None:
            return None

        # Mark as terminated
        session.state = SessionState.TERMINATED
        session.end_time = datetime.now(timezone.utc)
        self._sessions_terminated += 1

        # Calculate peak spend rate over rolling 1-minute windows.
        # Records are stored in the order they were recorded, so the inner
        # scan stops at the first record outside the window.
        peak_rate = 0.0
        records = self._session_records.get(session_id, [])
        for i, window_first in enumerate(records):
            window_cost = 0.0
            window_start = window_first.timestamp

            for j in range(i, len(records)):
                if (records[j].timestamp - window_start).total_seconds() <= 60:
                    window_cost += records[j].cost_usd
                else:
                    break

            if window_cost > peak_rate:
                peak_rate = window_cost

        summary = SessionSummary(
            session_id=session.session_id,
            user_id=session.user_id,
            total_cost=session.total_cost,
            total_tokens=session.input_tokens + session.output_tokens,
            duration_seconds=session.duration_seconds,
            # Copies, so later changes to the session do not alter the summary.
            tool_breakdown=dict(session.by_tool),
            model_breakdown=dict(session.by_model),
            agent_breakdown={
                agent_id: profile.total_cost
                for agent_id, profile in session.agents.items()
            },
            peak_spend_rate=peak_rate,
            alerts_triggered=len(session.alerts),
            end_reason=reason,
        )

        logger.info(
            f"Session {session_id} ended: ${session.total_cost:.4f}, "
            f"{session.record_count} records, {len(session.alerts)} alerts"
        )

        return summary
851
+
852
def get_session_records(
    self,
    session_id: str,
    limit: int | None = None,
) -> list[UsageRecord]:
    """
    Get usage records for a session.

    Args:
        session_id: Session to get records for.
        limit: Maximum records to return; the most recent ``limit``
            records are kept. None means no limit. Zero or a negative
            value yields an empty list.

    Returns:
        A new list of usage records (empty for an unknown session), so
        callers can mutate it without touching the tracker's storage.
    """
    with self._lock:
        records = self._session_records.get(session_id, [])

        # Compare against None explicitly: a truthiness test would treat
        # limit=0 as "no limit" and hand back every record.
        if limit is None:
            return list(records)

        # A non-positive limit would otherwise slice from the wrong end
        # (records[-(-2):] == records[2:]), so short-circuit it.
        if limit <= 0:
            return []

        return list(records[-limit:])
872
+
873
def forecast_session_cost(
    self,
    session_id: str,
    duration_minutes: int = 60,
) -> float | None:
    """
    Project further spend for a session from its average rate so far.

    Args:
        session_id: Session to forecast.
        duration_minutes: Length of the horizon to project, in minutes.

    Returns:
        The session's average cost per minute multiplied by
        ``duration_minutes``, or None when forecasting is disabled, the
        session is unknown, or it has run for less than two minutes.
    """
    if not self._enable_forecasting:
        return None

    with self._lock:
        tracked = self._sessions.get(session_id)

        # Unknown sessions and sessions younger than 120 seconds do not
        # provide enough data for a forecast.
        if tracked is None or tracked.duration_seconds < 120:
            return None

        elapsed_minutes = tracked.duration_seconds / 60
        cost_per_minute = tracked.total_cost / elapsed_minutes
        return cost_per_minute * duration_minutes
904
+
905
def _cleanup_expired_sessions(self) -> int:
    """
    Drop finished sessions that ended more than one hour ago.

    A session qualifies when it reports ``is_expired`` or its state is
    TERMINATED or EXPIRED, and its ``end_time`` is set and older than one
    hour. Sessions without an ``end_time`` are never removed. The
    session's usage records are removed together with it.

    This method does not acquire the lock itself.

    Returns:
        Number of sessions removed.
    """
    retention = timedelta(hours=1)
    stale_ids = []

    # First pass only collects ids, so the dict is not mutated while it
    # is being iterated.
    for sid, candidate in self._sessions.items():
        finished = candidate.is_expired or candidate.state in (
            SessionState.TERMINATED,
            SessionState.EXPIRED,
        )
        if not finished or not candidate.end_time:
            continue
        if datetime.now(timezone.utc) - candidate.end_time > retention:
            stale_ids.append(sid)

    for sid in stale_ids:
        del self._sessions[sid]
        self._session_records.pop(sid, None)

    removed = len(stale_ids)
    if removed:
        logger.info(f"Cleaned up {removed} expired sessions")

    return removed
931
+
932
def get_user_total_cost(
    self,
    user_id: str,
    period: timedelta | None = None,
) -> float:
    """
    Get total cost for a user across all sessions.

    Args:
        user_id: User to check.
        period: Time period (None for all time). Any timedelta,
            including a zero-length one, is forwarded to the base
            tracker's ``get_user_spend``.

    Returns:
        Total cost in USD. For the all-time case this is the sum of
        ``total_cost`` over the user's sessions that are still held by
        this tracker; 0.0 for an unknown user.
    """
    with self._lock:
        # Test against None explicitly: timedelta(0) is falsy and would
        # otherwise be silently treated as "all time".
        if period is not None:
            return self._base_tracker.get_user_spend(user_id, period)

        # Sum across all sessions; ids whose session is no longer stored
        # are skipped.
        total = 0.0
        for session_id in self._user_sessions.get(user_id, []):
            session = self._sessions.get(session_id)
            if session is not None:
                total += session.total_cost

        return total
959
+
960
def get_agent_total_cost(self, agent_id: str) -> float:
    """
    Add up the cost recorded for one agent across all stored sessions.

    Args:
        agent_id: Agent whose spend is requested.

    Returns:
        Sum of the agent's ``total_cost`` over every session in which it
        appears, in USD; 0.0 when no session lists the agent.
    """
    with self._lock:
        running_total = 0.0

        for tracked in self._sessions.values():
            # Sessions in which this agent never took part add nothing.
            if agent_id not in tracked.agents:
                continue
            running_total += tracked.agents[agent_id].total_cost

        return running_total
978
+
979
def get_stats(self) -> dict[str, Any]:
    """
    Snapshot the tracker's counters.

    Returns:
        Dict with the number of stored sessions, how many of them report
        ``is_active``, the created/terminated/alert counters, and the
        number of users that have a session list.
    """
    with self._lock:
        live_sessions = [s for s in self._sessions.values() if s.is_active]

        stats = {
            "total_sessions": len(self._sessions),
            "active_sessions": len(live_sessions),
            "sessions_created": self._sessions_created,
            "sessions_terminated": self._sessions_terminated,
            "total_alerts": self._total_alerts,
            "users_tracked": len(self._user_sessions),
        }
        return stats
992
+
993
def export_session(
    self,
    session_id: str,
    format: str = "json",
) -> str | None:
    """
    Export session data for billing/audit.

    Args:
        session_id: Session to export.
        format: Output format ("json" or "csv"). Any value other than
            "csv" produces JSON.

    Returns:
        Exported data as string, or None if not found. CSV output has a
        header row followed by one row per usage record, rows separated
        by "\\n" with no trailing newline; fields containing commas,
        quotes or line breaks are quoted per the csv module's default
        dialect. JSON output is an object with "session" and "records"
        keys, indented by two spaces.
    """
    # Local imports keep this block self-contained without touching the
    # module's import section.
    import csv
    import io

    with self._lock:
        session = self._sessions.get(session_id)
        if session is None:
            return None

        records = self._session_records.get(session_id, [])

        if format == "csv":
            buffer = io.StringIO()
            # csv.writer quotes fields where needed; plain interpolation
            # would let a comma in a model, tool or agent name shift
            # every following column.
            writer = csv.writer(buffer, lineterminator="\n")
            writer.writerow(
                [
                    "timestamp",
                    "model",
                    "input_tokens",
                    "output_tokens",
                    "cost_usd",
                    "tool_name",
                    "agent_id",
                ]
            )
            for record in records:
                agent_id = record.metadata.get("agent_id", "")
                writer.writerow(
                    [
                        record.timestamp.isoformat(),
                        f"{record.model}",
                        f"{record.input_tokens}",
                        f"{record.output_tokens}",
                        f"{record.cost_usd:.6f}",
                        f"{record.tool_name or ''}",
                        f"{agent_id}",
                    ]
                )
            # The header is always written, so the text ends with exactly
            # one row terminator; drop it to keep the no-trailing-newline
            # shape of the export.
            return buffer.getvalue()[:-1]

        return json.dumps({
            "session": session.to_dict(),
            "records": [r.to_dict() for r in records],
        }, indent=2)
1035
+
1036
@property
def base_tracker(self) -> CostTracker:
    """Return the CostTracker instance stored on this tracker."""
    wrapped = self._base_tracker
    return wrapped
1040
+
1041
+
1042
def create_session_cost_tracker(
    budget_policy: BudgetPolicy | None = None,
    default_session_budget: float | None = None,
    alert_callback: AlertCallback | None = None,
) -> SessionCostTracker:
    """
    Build a SessionCostTracker and optionally attach an alert callback.

    Args:
        budget_policy: Global budget policy passed to the tracker.
        default_session_budget: Default budget for new sessions.
            NOTE(review): this value is accepted but not used anywhere
            in this function - confirm whether it should be forwarded
            to the tracker.
        alert_callback: Optional callback registered through
            ``add_alert_callback`` when provided.

    Returns:
        The newly created SessionCostTracker.
    """
    new_tracker = SessionCostTracker(budget_policy=budget_policy)

    # Nothing to register: hand the tracker back as-is.
    if not alert_callback:
        return new_tracker

    new_tracker.add_alert_callback(alert_callback)
    return new_tracker