studiomeyer-aishield 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_shield/__init__.py +47 -0
- ai_shield/audit/__init__.py +13 -0
- ai_shield/audit/logger.py +155 -0
- ai_shield/audit/types.py +19 -0
- ai_shield/cache/__init__.py +7 -0
- ai_shield/cache/lru.py +89 -0
- ai_shield/cost/__init__.py +17 -0
- ai_shield/cost/anomaly.py +50 -0
- ai_shield/cost/pricing.py +62 -0
- ai_shield/cost/tracker.py +175 -0
- ai_shield/mcp_server.py +108 -0
- ai_shield/policy/__init__.py +14 -0
- ai_shield/policy/engine.py +117 -0
- ai_shield/policy/tools.py +102 -0
- ai_shield/scanner/__init__.py +18 -0
- ai_shield/scanner/canary.py +23 -0
- ai_shield/scanner/chain.py +65 -0
- ai_shield/scanner/heuristic.py +545 -0
- ai_shield/scanner/pii.py +254 -0
- ai_shield/shield.py +200 -0
- ai_shield/types.py +226 -0
- studiomeyer_aishield-0.1.0.dist-info/METADATA +291 -0
- studiomeyer_aishield-0.1.0.dist-info/RECORD +26 -0
- studiomeyer_aishield-0.1.0.dist-info/WHEEL +4 -0
- studiomeyer_aishield-0.1.0.dist-info/entry_points.txt +2 -0
- studiomeyer_aishield-0.1.0.dist-info/licenses/LICENSE +23 -0
ai_shield/__init__.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
"""ai-shield — LLM input shield for prompt-injection, PII, tool-policy, cost.
|
|
2
|
+
|
|
3
|
+
Python 1:1 port of ai-shield-core (TypeScript, MIT, 4 audit rounds).
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from ai_shield.shield import AIShield
|
|
9
|
+
from ai_shield.types import (
|
|
10
|
+
AuditRecord,
|
|
11
|
+
BudgetCheckResult,
|
|
12
|
+
BudgetConfig,
|
|
13
|
+
BudgetPeriod,
|
|
14
|
+
CostRecord,
|
|
15
|
+
Decision,
|
|
16
|
+
PIIAction,
|
|
17
|
+
PIIEntity,
|
|
18
|
+
PIIType,
|
|
19
|
+
ScannerResult,
|
|
20
|
+
ScanResult,
|
|
21
|
+
ToolCall,
|
|
22
|
+
ToolManifestPin,
|
|
23
|
+
Violation,
|
|
24
|
+
ViolationType,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
__version__ = "0.1.0"
|
|
28
|
+
|
|
29
|
+
__all__ = [
|
|
30
|
+
"AIShield",
|
|
31
|
+
"AuditRecord",
|
|
32
|
+
"BudgetCheckResult",
|
|
33
|
+
"BudgetConfig",
|
|
34
|
+
"BudgetPeriod",
|
|
35
|
+
"CostRecord",
|
|
36
|
+
"Decision",
|
|
37
|
+
"PIIAction",
|
|
38
|
+
"PIIEntity",
|
|
39
|
+
"PIIType",
|
|
40
|
+
"ScanResult",
|
|
41
|
+
"ScannerResult",
|
|
42
|
+
"ToolCall",
|
|
43
|
+
"ToolManifestPin",
|
|
44
|
+
"Violation",
|
|
45
|
+
"ViolationType",
|
|
46
|
+
"__version__",
|
|
47
|
+
]
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Audit subpackage — async batched logger + store interface."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from ai_shield.audit.logger import AuditLogger, ConsoleAuditStore, MemoryAuditStore
|
|
6
|
+
from ai_shield.audit.types import AuditStore
|
|
7
|
+
|
|
8
|
+
__all__ = [
|
|
9
|
+
"AuditLogger",
|
|
10
|
+
"AuditStore",
|
|
11
|
+
"ConsoleAuditStore",
|
|
12
|
+
"MemoryAuditStore",
|
|
13
|
+
]
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
"""Async batched audit logger.
|
|
2
|
+
|
|
3
|
+
1:1 port of `packages/core/src/audit/logger.ts`.
|
|
4
|
+
|
|
5
|
+
Inputs are NEVER stored in plain text — only `sha256(input)` and an
|
|
6
|
+
optional `sha256(user_id)[:32]` end up on disk / in the store.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import asyncio
|
|
12
|
+
import contextlib
|
|
13
|
+
import hashlib
|
|
14
|
+
import sys
|
|
15
|
+
import unicodedata
|
|
16
|
+
from dataclasses import dataclass, field
|
|
17
|
+
from datetime import datetime, timezone
|
|
18
|
+
from typing import Any, TextIO
|
|
19
|
+
|
|
20
|
+
from ai_shield.audit.types import AuditStore
|
|
21
|
+
from ai_shield.types import AuditRecord, Decision, Violation
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _hash_input(text: str) -> str:
    """Return the hex SHA-256 digest of *text* after NFKD normalization.

    Normalizing before hashing makes compatibility-equivalent spellings of
    the same input produce the same digest.
    """
    hasher = hashlib.sha256()
    hasher.update(unicodedata.normalize("NFKD", text).encode("utf-8"))
    return hasher.hexdigest()
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _hash_user(user_id: str | None) -> str | None:
    """Return a 32-hex-character pseudonym for *user_id*, or None if absent.

    The pseudonym is the first half of the SHA-256 hex digest of the UTF-8
    encoded id.
    """
    if user_id is not None:
        full_digest = hashlib.sha256(user_id.encode("utf-8")).hexdigest()
        return full_digest[:32]
    return None
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class ConsoleAuditStore:
    """Audit store that emits one JSON document per line to a text stream.

    The stream defaults to whatever ``sys.stderr`` is at construction time.
    """

    sink: TextIO = field(default_factory=lambda: sys.stderr)

    def _emit(self, record: AuditRecord) -> None:
        """Write *record* to the sink as one newline-terminated JSON line."""
        self.sink.write(record.model_dump_json() + "\n")

    async def write(self, record: AuditRecord) -> None:
        """Write a single record and flush the sink right away."""
        self._emit(record)
        self.sink.flush()

    async def write_batch(self, records: list[AuditRecord]) -> None:
        """Write every record in order, then flush the sink once."""
        for entry in records:
            self._emit(entry)
        self.sink.flush()

    async def flush(self) -> None:
        """Flush the sink."""
        self.sink.flush()

    async def close(self) -> None:
        """Flush the sink; the stream itself is left open."""
        self.sink.flush()
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
@dataclass
class MemoryAuditStore:
    """Audit store that keeps every record in a plain in-process list.

    Handy for tests and short-lived processes; nothing is persisted.
    """

    records: list[AuditRecord] = field(default_factory=list)

    async def write(self, record: AuditRecord) -> None:
        """Append one record to ``records``."""
        self.records.append(record)

    async def write_batch(self, records: list[AuditRecord]) -> None:
        """Append all of *records*, preserving their order."""
        self.records += records

    async def flush(self) -> None:
        """No-op: there is nothing buffered outside ``records``."""

    async def close(self) -> None:
        """No-op: the collected records stay available after closing."""
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class AuditLogger:
    """Buffer audit records and write them to an AuditStore in batches.

    A batch is written when the buffer reaches ``max_buffer`` records, when a
    one-shot timer started by ``log()`` fires after
    ``flush_interval_seconds``, or when ``flush()`` / ``close()`` is called.
    """

    def __init__(
        self,
        store: AuditStore | None = None,
        *,
        flush_interval_seconds: float = 5.0,
        max_buffer: int = 100,
    ) -> None:
        """Create a logger.

        Args:
            store: Backend receiving the batches; a ``ConsoleAuditStore`` is
                created when omitted.
            flush_interval_seconds: Delay before a timed flush after a record
                is buffered.
            max_buffer: Buffer size that triggers an immediate flush.
        """
        self.store: AuditStore = store if store is not None else ConsoleAuditStore()
        self._flush_interval = flush_interval_seconds
        self._max_buffer = max_buffer
        self._buffer: list[AuditRecord] = []
        self._lock = asyncio.Lock()
        self._closed = False
        self._task: asyncio.Task[None] | None = None

    async def log(
        self,
        *,
        text: str,
        decision: Decision,
        violations: list[Violation],
        score: float,
        user_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Buffer one audit record, flushing immediately if the buffer is full.

        ``text`` and ``user_id`` are passed through ``_hash_input`` and
        ``_hash_user``; the record carries those hashes, not the raw values.
        ``metadata`` is stored as given (an empty dict when omitted).
        """
        record = AuditRecord(
            timestamp=datetime.now(timezone.utc).isoformat(),
            user_id_hash=_hash_user(user_id),
            input_sha256=_hash_input(text),
            decision=decision,
            violations=violations,
            score=score,
            metadata=metadata or {},
        )
        async with self._lock:
            self._buffer.append(record)
            buffer_full = len(self._buffer) >= self._max_buffer

        # flush() takes the lock itself, so it must be called after release.
        if buffer_full:
            await self.flush()
        else:
            self._ensure_task()

    def _ensure_task(self) -> None:
        """Start the one-shot flush timer unless one is pending or we are closed."""
        if self._closed:
            return
        if self._task is not None and not self._task.done():
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop: nothing to schedule on; rely on explicit flushes.
            return
        self._task = loop.create_task(self._auto_flush())

    async def _auto_flush(self) -> None:
        """Wait one flush interval, then flush the buffer once."""
        try:
            await asyncio.sleep(self._flush_interval)
            # close() cancels this task. Cancellation may cut the wait short,
            # but it must not abort a store write that has already taken the
            # batch out of the buffer, so the flush itself is shielded.
            await asyncio.shield(self.flush())
        except asyncio.CancelledError:
            pass

    async def flush(self) -> None:
        """Write all buffered records to the store as one batch.

        Does nothing when the buffer is empty. The lock is held for the whole
        write, so concurrent flushes are serialized.
        """
        async with self._lock:
            if not self._buffer:
                return
            batch = self._buffer
            self._buffer = []
            await self.store.write_batch(batch)

    async def close(self) -> None:
        """Stop the timer, flush what is buffered, and close the store.

        The store is closed even if the final flush raises; the flush error
        still propagates to the caller.
        """
        self._closed = True
        if self._task is not None and not self._task.done():
            self._task.cancel()
            with contextlib.suppress(asyncio.CancelledError, Exception):
                await self._task
        try:
            await self.flush()
        finally:
            await self.store.close()
|
ai_shield/audit/types.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""AuditStore protocol — pluggable backend for AuditLogger.
|
|
2
|
+
|
|
3
|
+
1:1 port of `packages/core/src/audit/types.ts`.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from typing import Protocol
|
|
9
|
+
|
|
10
|
+
from ai_shield.types import AuditRecord
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AuditStore(Protocol):
    """Structural interface for the backend an AuditLogger writes to."""

    async def write(self, record: AuditRecord) -> None:
        """Persist a single audit record."""

    async def write_batch(self, records: list[AuditRecord]) -> None:
        """Persist several audit records in one call."""

    async def flush(self) -> None:
        """Push any output the backend is still holding."""

    async def close(self) -> None:
        """Shut the backend down."""
|
ai_shield/cache/lru.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
"""TTL + insertion-order LRU cache for scan results.
|
|
2
|
+
|
|
3
|
+
1:1 port of `packages/core/src/cache/lru.ts` — recency order is tracked with
`collections.OrderedDict` (`move_to_end` on a hit, `popitem(last=False)` to evict).
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import time
|
|
10
|
+
from collections import OrderedDict
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from typing import Generic, TypeVar
|
|
13
|
+
|
|
14
|
+
V = TypeVar("V")


@dataclass
class _Entry(Generic[V]):
    """A cached value paired with its expiry deadline."""

    value: V
    # Deadline in milliseconds on the same clock as ScanLRUCache._now_ms().
    expires_at: float


class ScanLRUCache(Generic[V]):
    """Bounded, string-keyed LRU cache whose entries also expire after a TTL.

    ``get`` promotes a live entry to most-recently-used; ``set`` evicts the
    least-recently-used entry when a new key would exceed ``max_size``. A
    ``ttl_ms`` of 0 disables expiry.
    """

    def __init__(self, *, max_size: int = 1000, ttl_ms: int = 300_000) -> None:
        """Create an empty cache.

        Raises:
            ValueError: if ``max_size`` is below 1 or ``ttl_ms`` is negative.
        """
        if max_size < 1:
            raise ValueError("max_size must be >= 1")
        if ttl_ms < 0:
            raise ValueError("ttl_ms must be >= 0")
        self._data: OrderedDict[str, _Entry[V]] = OrderedDict()
        self._max_size = max_size
        self._ttl_ms = ttl_ms

    def get(self, key: str) -> V | None:
        """Return the live value for *key* (promoting it), else None.

        An expired entry is removed as a side effect.
        """
        try:
            cached = self._data[key]
        except KeyError:
            return None
        if not self._expired(cached):
            # A hit makes this key the most recently used one.
            self._data.move_to_end(key)
            return cached.value
        del self._data[key]
        return None

    def set(self, key: str, value: V) -> None:
        """Store *value* under *key* as the most recently used entry."""
        was_present = self._data.pop(key, None) is not None
        if not was_present and len(self._data) >= self._max_size:
            # Full and the key is new: drop the least recently used entry.
            self._data.popitem(last=False)
        self._data[key] = _Entry(value=value, expires_at=self._now_ms() + self._ttl_ms)

    def has(self, key: str) -> bool:
        """Report whether *key* holds a live entry, without promoting it.

        An expired entry is removed as a side effect.
        """
        try:
            cached = self._data[key]
        except KeyError:
            return False
        if not self._expired(cached):
            return True
        del self._data[key]
        return False

    def delete(self, key: str) -> bool:
        """Remove *key*; return True if it was present."""
        return self._data.pop(key, None) is not None

    def clear(self) -> None:
        """Drop every entry."""
        self._data.clear()

    def prune(self) -> int:
        """Remove all expired entries and return how many were removed."""
        removed = 0
        # Iterate over a snapshot so entries can be deleted while scanning.
        for name, cached in list(self._data.items()):
            if self._expired(cached):
                del self._data[name]
                removed += 1
        return removed

    def __len__(self) -> int:
        """Return the number of stored entries, expired or not."""
        return len(self._data)

    @staticmethod
    def _now_ms() -> float:
        """Return the monotonic clock reading in milliseconds."""
        return 1000.0 * time.monotonic()

    def _expired(self, entry: _Entry[V]) -> bool:
        """Return True when *entry* is past its deadline (never if TTL is 0)."""
        return self._ttl_ms != 0 and self._now_ms() >= entry.expires_at
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""Cost subpackage — tracker + pricing + anomaly."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from ai_shield.cost.anomaly import detect_anomaly
|
|
6
|
+
from ai_shield.cost.pricing import MODEL_PRICING, estimate_cost, get_model_pricing
|
|
7
|
+
from ai_shield.cost.tracker import CostTracker, MemoryStore, RedisLike
|
|
8
|
+
|
|
9
|
+
__all__ = [
|
|
10
|
+
"MODEL_PRICING",
|
|
11
|
+
"CostTracker",
|
|
12
|
+
"MemoryStore",
|
|
13
|
+
"RedisLike",
|
|
14
|
+
"detect_anomaly",
|
|
15
|
+
"estimate_cost",
|
|
16
|
+
"get_model_pricing",
|
|
17
|
+
]
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""Z-score anomaly detection for cost-spike alerts.
|
|
2
|
+
|
|
3
|
+
1:1 port of `packages/core/src/cost/anomaly.ts`.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import math
|
|
9
|
+
|
|
10
|
+
from ai_shield.types import AnomalyResult
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def detect_anomaly(
    samples: list[float],
    current: float,
    *,
    z_threshold: float = 2.5,
) -> AnomalyResult:
    """Flag `current` when its z-score against `samples` reaches the threshold.

    The z-score uses the population standard deviation of `samples`.
    With fewer than 3 samples the result is not an anomaly and reports a mean
    and std_dev of 0.0; with a zero standard deviation it is not an anomaly
    and reports the actual mean.
    """

    def _result(flagged: bool, z: float, mean: float, std: float) -> AnomalyResult:
        # Every outcome reports the same `current`; only the statistics vary.
        return AnomalyResult(
            is_anomaly=flagged,
            z_score=z,
            current_value=current,
            mean=mean,
            std_dev=std,
        )

    count = len(samples)
    if count < 3:
        return _result(False, 0.0, 0.0, 0.0)

    mean = sum(samples) / count
    variance = sum((s - mean) ** 2 for s in samples) / count
    std = math.sqrt(variance)
    if std == 0.0:
        return _result(False, 0.0, mean, 0.0)

    z = (current - mean) / std
    return _result(abs(z) >= z_threshold, z, mean, std)
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
"""Model pricing table + estimate_cost helper.
|
|
2
|
+
|
|
3
|
+
1:1 port of `packages/core/src/cost/pricing.ts`.
|
|
4
|
+
USD per 1M tokens. Update as providers change rates.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from ai_shield.types import ModelPricing
|
|
10
|
+
|
|
11
|
+
# Price table keyed by model name, in USD per 1M tokens (input / output).
# get_model_pricing() also treats each key as a prefix for versioned model
# names, and falls back to the "gpt-4o-mini" entry for unknown models.
MODEL_PRICING: dict[str, ModelPricing] = {
    # OpenAI
    "gpt-5.2": ModelPricing(input_per_1m=4.50, output_per_1m=18.00),
    "gpt-5.0": ModelPricing(input_per_1m=3.00, output_per_1m=12.00),
    "gpt-4.1": ModelPricing(input_per_1m=2.50, output_per_1m=10.00),
    "gpt-4o": ModelPricing(input_per_1m=2.50, output_per_1m=10.00),
    "gpt-4o-mini": ModelPricing(input_per_1m=0.15, output_per_1m=0.60),
    "o3": ModelPricing(input_per_1m=15.00, output_per_1m=60.00),
    "o4-mini": ModelPricing(input_per_1m=1.10, output_per_1m=4.40),
    # Anthropic
    "claude-opus-4-7": ModelPricing(input_per_1m=15.00, output_per_1m=75.00),
    "claude-opus-4-6": ModelPricing(input_per_1m=15.00, output_per_1m=75.00),
    "claude-sonnet-4-7": ModelPricing(input_per_1m=3.00, output_per_1m=15.00),
    "claude-sonnet-4-5": ModelPricing(input_per_1m=3.00, output_per_1m=15.00),
    "claude-haiku-4-7": ModelPricing(input_per_1m=0.80, output_per_1m=4.00),
    # Google
    "gemini-2.5-pro": ModelPricing(input_per_1m=1.25, output_per_1m=5.00),
    "gemini-2.5-flash": ModelPricing(input_per_1m=0.075, output_per_1m=0.30),
    "gemini-2.5-flash-lite": ModelPricing(input_per_1m=0.0375, output_per_1m=0.15),
    # xAI
    "grok-4": ModelPricing(input_per_1m=5.00, output_per_1m=15.00),
    # Mistral
    "mistral-large": ModelPricing(input_per_1m=2.00, output_per_1m=6.00),
    "mistral-small": ModelPricing(input_per_1m=0.20, output_per_1m=0.60),
}
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def get_model_pricing(model: str) -> ModelPricing:
    """Look up pricing for `model`: exact name, else longest prefix, else fallback.

    Picking the longest registered prefix means a versioned snapshot such as
    `gpt-4o-mini-2024-07-18` resolves to `gpt-4o-mini` rather than the shorter
    `gpt-4o`. Models matching nothing are priced as `gpt-4o-mini`.
    """
    exact = MODEL_PRICING.get(model)
    if exact is not None:
        return exact

    # At most one key of any given length can prefix `model`, so the longest
    # match is unique.
    best_prefix = max(
        (name for name in MODEL_PRICING if model.startswith(name)),
        key=len,
        default=None,
    )
    if best_prefix is None:
        return MODEL_PRICING["gpt-4o-mini"]
    return MODEL_PRICING[best_prefix]
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def estimate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
    """Return the estimated USD cost of one LLM call.

    Raises:
        ValueError: if either token count is negative.
    """
    if any(count < 0 for count in (input_tokens, output_tokens)):
        raise ValueError("Token counts must be non-negative")

    pricing = get_model_pricing(model)
    # Prices are quoted per one million tokens.
    input_usd = (input_tokens / 1_000_000) * pricing.input_per_1m
    output_usd = (output_tokens / 1_000_000) * pricing.output_per_1m
    return input_usd + output_usd
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""Async cost tracker — soft/hard budgets, in-memory or Redis backend.
|
|
2
|
+
|
|
3
|
+
1:1 port of `packages/core/src/cost/tracker.ts`.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import asyncio
|
|
9
|
+
import os
|
|
10
|
+
from collections import deque
|
|
11
|
+
from dataclasses import dataclass, field
|
|
12
|
+
from datetime import datetime, timezone
|
|
13
|
+
from typing import Protocol
|
|
14
|
+
|
|
15
|
+
from ai_shield.cost.pricing import estimate_cost
|
|
16
|
+
from ai_shield.types import (
|
|
17
|
+
BudgetCheckResult,
|
|
18
|
+
BudgetConfig,
|
|
19
|
+
BudgetPeriod,
|
|
20
|
+
CostRecord,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
# Default for CostTracker.max_records. Read once at import time from the
# AI_SHIELD_MAX_RECORDS environment variable; 10000 when it is unset.
_DEFAULT_MAX_RECORDS = int(os.environ.get("AI_SHIELD_MAX_RECORDS", "10000"))
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class RedisLike(Protocol):
    """The three async Redis commands the cost tracker relies on."""

    async def incrbyfloat(self, key: str, value: float) -> float:
        """Add *value* to the number stored at *key* and return the new total."""

    async def expire(self, key: str, seconds: int) -> bool:
        """Set a time-to-live of *seconds* on *key*."""

    async def get(self, key: str) -> str | None:
        """Return the value stored at *key*, or None when there is none."""
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass
class MemoryStore:
    """In-process fallback used when no Redis client is provided.

    Implements the ``RedisLike`` commands on two dicts guarded by an
    ``asyncio.Lock``. Expiry deadlines are measured on the running event
    loop's clock and enforced lazily: expired keys are dropped at the start
    of every command.
    """

    _data: dict[str, float] = field(default_factory=dict)
    # Key -> deadline on the event loop clock; keys without a TTL are absent.
    _expires_at: dict[str, float] = field(default_factory=dict)
    _lock: asyncio.Lock = field(default_factory=asyncio.Lock)

    async def incrbyfloat(self, key: str, value: float) -> float:
        """Add *value* to *key* (starting from 0.0) and return the new total."""
        async with self._lock:
            self._sweep_expired()
            total = self._data.get(key, 0.0) + value
            self._data[key] = total
            return total

    async def expire(self, key: str, seconds: int) -> bool:
        """Give *key* a TTL of *seconds*; return False if the key does not exist.

        Expired keys are swept first, so a key whose previous TTL has already
        elapsed counts as missing instead of being revived with a new TTL.
        """
        async with self._lock:
            self._sweep_expired()
            if key not in self._data:
                return False
            self._expires_at[key] = self._now() + seconds
            return True

    async def get(self, key: str) -> str | None:
        """Return the value at *key* as a string, or None if absent or expired."""
        async with self._lock:
            self._sweep_expired()
            stored = self._data.get(key)
            return None if stored is None else str(stored)

    def _sweep_expired(self) -> None:
        """Drop every key whose deadline has passed. Caller must hold the lock."""
        now = self._now()
        expired = [k for k, deadline in self._expires_at.items() if deadline <= now]
        for k in expired:
            self._data.pop(k, None)
            self._expires_at.pop(k, None)

    @staticmethod
    def _now() -> float:
        """Return the running event loop's clock reading, in seconds."""
        return asyncio.get_running_loop().time()
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _period_seconds(period: BudgetPeriod) -> int:
    """Return the length of *period* in seconds (a month counts as 31 days).

    Raises:
        ValueError: if *period* is not "hourly", "daily" or "monthly".
    """
    durations = (
        ("hourly", 3600),
        ("daily", 86400),
        ("monthly", 86400 * 31),
    )
    for name, seconds in durations:
        if period == name:
            return seconds
    raise ValueError(f"Unknown period: {period!r}")
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _period_key(period: BudgetPeriod, now: datetime | None = None) -> str:
    """Return the bucket label for *period* at *now* (current UTC time if omitted).

    The label is the timestamp truncated to the period: ``YYYYMMDDHH`` for
    hourly, ``YYYYMMDD`` for daily, ``YYYYMM`` for monthly.

    Raises:
        ValueError: if *period* is not "hourly", "daily" or "monthly".
    """
    moment = now if now else datetime.now(timezone.utc)
    patterns = (
        ("hourly", "%Y%m%d%H"),
        ("daily", "%Y%m%d"),
        ("monthly", "%Y%m"),
    )
    for name, pattern in patterns:
        if period == name:
            return moment.strftime(pattern)
    raise ValueError(f"Unknown period: {period!r}")
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@dataclass
class CostTracker:
    """Track LLM spend per entity against a soft/hard budget.

    Running totals are kept in ``store`` (any ``RedisLike`` client); when none
    is given, an in-process ``MemoryStore`` is created. Each increment is
    atomic at the store level (``INCRBYFLOAT`` on Redis, ``asyncio.Lock`` on
    the in-memory store). The most recent ``max_records`` cost records are
    also kept in memory.
    """

    budget: BudgetConfig = field(default_factory=BudgetConfig)
    store: RedisLike | None = None
    max_records: int = _DEFAULT_MAX_RECORDS

    # Bounded history of recorded calls; once full, the oldest entry is dropped.
    _records: deque[CostRecord] = field(init=False)
    # Backend actually used: the caller's store, or a private MemoryStore.
    _store: RedisLike = field(init=False)

    def __post_init__(self) -> None:
        """Create the bounded history and resolve the backing store."""
        self._records = deque(maxlen=self.max_records)
        self._store = self.store if self.store is not None else MemoryStore()

    @staticmethod
    def _key(entity_id: str, period: BudgetPeriod) -> str:
        """Build the store key for *entity_id* in the current *period* bucket."""
        return f"ai-shield:cost:{entity_id}:{period}:{_period_key(period)}"

    async def record(
        self,
        *,
        entity_id: str,
        model: str,
        input_tokens: int,
        output_tokens: int,
        actual_usd: float | None = None,
    ) -> CostRecord:
        """Record one LLM call and add its cost to the entity's running total.

        Args:
            entity_id: Whoever the spend is attributed to.
            model: Model name, used for the price lookup.
            input_tokens: Prompt tokens consumed.
            output_tokens: Completion tokens produced.
            actual_usd: Known cost of the call; when omitted the cost is
                estimated with ``estimate_cost``.

        Returns:
            The CostRecord that was stored; its ``actual_usd`` holds the cost
            that was added to the total (given or estimated).
        """
        cost = (
            actual_usd
            if actual_usd is not None
            else estimate_cost(
                model,
                input_tokens,
                output_tokens,
            )
        )
        record = CostRecord(
            entity_id=entity_id,
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            actual_usd=cost,
            timestamp=datetime.now(timezone.utc).isoformat(),
        )
        self._records.append(record)

        key = self._key(entity_id, self.budget.period)
        await self._store.incrbyfloat(key, cost)
        # Refresh the TTL so the bucket's key is eventually dropped by the store.
        await self._store.expire(key, _period_seconds(self.budget.period))
        return record

    async def check_budget(self, entity_id: str) -> BudgetCheckResult:
        """Compare the entity's current spend with the configured limits.

        A limit counts as exceeded once spend is greater than or equal to it.
        The call is ``allowed`` unless the hard limit is exceeded; the
        reported ``limit_usd`` is the hard limit when set, else the soft one.
        """
        spend = await self.get_current_spend(entity_id)

        soft = self.budget.soft_limit_usd
        hard = self.budget.hard_limit_usd

        soft_exceeded = soft is not None and spend >= soft
        hard_exceeded = hard is not None and spend >= hard

        return BudgetCheckResult(
            allowed=not hard_exceeded,
            current_spend_usd=spend,
            limit_usd=hard if hard is not None else soft,
            period=self.budget.period,
            soft_exceeded=soft_exceeded,
            hard_exceeded=hard_exceeded,
        )

    async def get_current_spend(self, entity_id: str) -> float:
        """Return the entity's spend in the current period (0.0 if none)."""
        key = self._key(entity_id, self.budget.period)
        raw = await self._store.get(key)
        return float(raw) if raw is not None else 0.0

    def recent_records(self, n: int = 100) -> list[CostRecord]:
        """Return up to *n* of the most recent records, oldest first.

        A non-positive *n* yields an empty list.
        """
        # Guard first: a bare ``[-n:]`` with n == 0 is ``[0:]`` and would
        # return the entire history instead of nothing.
        if n <= 0:
            return []
        return list(self._records)[-n:]
|