studiomeyer-aishield 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ai_shield/__init__.py ADDED
@@ -0,0 +1,47 @@
1
+ """ai-shield — LLM input shield for prompt-injection, PII, tool-policy, cost.
2
+
3
+ Python 1:1 port of ai-shield-core (TypeScript, MIT, 4 audit rounds).
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from ai_shield.shield import AIShield
9
+ from ai_shield.types import (
10
+ AuditRecord,
11
+ BudgetCheckResult,
12
+ BudgetConfig,
13
+ BudgetPeriod,
14
+ CostRecord,
15
+ Decision,
16
+ PIIAction,
17
+ PIIEntity,
18
+ PIIType,
19
+ ScannerResult,
20
+ ScanResult,
21
+ ToolCall,
22
+ ToolManifestPin,
23
+ Violation,
24
+ ViolationType,
25
+ )
26
+
27
+ __version__ = "0.1.0"
28
+
29
+ __all__ = [
30
+ "AIShield",
31
+ "AuditRecord",
32
+ "BudgetCheckResult",
33
+ "BudgetConfig",
34
+ "BudgetPeriod",
35
+ "CostRecord",
36
+ "Decision",
37
+ "PIIAction",
38
+ "PIIEntity",
39
+ "PIIType",
40
+ "ScanResult",
41
+ "ScannerResult",
42
+ "ToolCall",
43
+ "ToolManifestPin",
44
+ "Violation",
45
+ "ViolationType",
46
+ "__version__",
47
+ ]
@@ -0,0 +1,13 @@
1
+ """Audit subpackage — async batched logger + store interface."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from ai_shield.audit.logger import AuditLogger, ConsoleAuditStore, MemoryAuditStore
6
+ from ai_shield.audit.types import AuditStore
7
+
8
+ __all__ = [
9
+ "AuditLogger",
10
+ "AuditStore",
11
+ "ConsoleAuditStore",
12
+ "MemoryAuditStore",
13
+ ]
@@ -0,0 +1,155 @@
1
+ """Async batched audit logger.
2
+
3
+ 1:1 port of `packages/core/src/audit/logger.ts`.
4
+
5
+ Inputs are NEVER stored in plain text — only `sha256(input)` and an
6
+ optional `sha256(user_id)[:32]` end up on disk / in the store.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import asyncio
12
+ import contextlib
13
+ import hashlib
14
+ import sys
15
+ import unicodedata
16
+ from dataclasses import dataclass, field
17
+ from datetime import datetime, timezone
18
+ from typing import Any, TextIO
19
+
20
+ from ai_shield.audit.types import AuditStore
21
+ from ai_shield.types import AuditRecord, Decision, Violation
22
+
23
+
24
+ def _hash_input(text: str) -> str:
25
+ normalized = unicodedata.normalize("NFKD", text).encode("utf-8")
26
+ return hashlib.sha256(normalized).hexdigest()
27
+
28
+
29
+ def _hash_user(user_id: str | None) -> str | None:
30
+ if user_id is None:
31
+ return None
32
+ return hashlib.sha256(user_id.encode("utf-8")).hexdigest()[:32]
33
+
34
+
35
@dataclass
class ConsoleAuditStore:
    """Audit store that emits each record as one JSON line on a text stream.

    The stream defaults to ``sys.stderr``, looked up when the store is created.
    """

    sink: TextIO = field(default_factory=lambda: sys.stderr)

    async def write(self, record: AuditRecord) -> None:
        """Write one record as a JSON line, then flush the sink."""
        line = record.model_dump_json() + "\n"
        self.sink.write(line)
        self.sink.flush()

    async def write_batch(self, records: list[AuditRecord]) -> None:
        """Write every record as its own JSON line, flushing once at the end."""
        out = self.sink
        for record in records:
            out.write(record.model_dump_json() + "\n")
        out.flush()

    async def flush(self) -> None:
        """Flush the underlying sink."""
        self.sink.flush()

    async def close(self) -> None:
        """Flush the sink; the stream itself is left open."""
        self.sink.flush()
55
+
56
+
57
@dataclass
class MemoryAuditStore:
    """Audit store that keeps every record in a plain in-process list.

    Handy for tests and short-lived processes; nothing is persisted.
    """

    records: list[AuditRecord] = field(default_factory=list)

    async def write(self, record: AuditRecord) -> None:
        """Append a single record to ``records``."""
        self.records.append(record)

    async def write_batch(self, records: list[AuditRecord]) -> None:
        """Append all given records to ``records``, preserving their order."""
        self.records += records

    async def flush(self) -> None:
        """No-op: there is nothing buffered to flush."""

    async def close(self) -> None:
        """No-op: there is no resource to release."""
74
+
75
+
76
class AuditLogger:
    """Buffer audit records and flush them to an AuditStore in batches.

    Records accumulate in an in-memory buffer. The buffer is handed to the
    store when it reaches ``max_buffer`` entries, when the background timer
    started by :meth:`log` fires after ``flush_interval_seconds``, or when
    :meth:`flush` / :meth:`close` is called explicitly.
    """

    def __init__(
        self,
        store: AuditStore | None = None,
        *,
        flush_interval_seconds: float = 5.0,
        max_buffer: int = 100,
    ) -> None:
        """Create a logger.

        Args:
            store: Backend that receives the batches. Defaults to a new
                ``ConsoleAuditStore``.
            flush_interval_seconds: Delay before the background task flushes
                records that have not filled the buffer.
            max_buffer: Buffer length at which :meth:`log` flushes right away.
        """
        self.store: AuditStore = store if store is not None else ConsoleAuditStore()
        self._flush_interval = flush_interval_seconds
        self._max_buffer = max_buffer
        self._buffer: list[AuditRecord] = []
        self._lock = asyncio.Lock()
        self._closed = False
        self._task: asyncio.Task[None] | None = None

    async def log(
        self,
        *,
        text: str,
        decision: Decision,
        violations: list[Violation],
        score: float,
        user_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Queue one audit record, flushing immediately if the buffer is full.

        The raw ``text`` and ``user_id`` are not stored; only their hashes
        (``_hash_input`` / ``_hash_user``) go into the record. The timestamp is
        the current UTC time in ISO-8601 form.
        """
        record = AuditRecord(
            timestamp=datetime.now(timezone.utc).isoformat(),
            user_id_hash=_hash_user(user_id),
            input_sha256=_hash_input(text),
            decision=decision,
            violations=violations,
            score=score,
            metadata=metadata or {},
        )
        async with self._lock:
            self._buffer.append(record)
            buffer_full = len(self._buffer) >= self._max_buffer

        # flush() takes the (non-reentrant) lock itself, so it must be called
        # only after the lock above has been released.
        if buffer_full:
            await self.flush()
        else:
            self._ensure_task()

    def _ensure_task(self) -> None:
        """Start the delayed auto-flush task unless one is already pending.

        Does nothing once the logger is closed or when no event loop is
        running.
        """
        if self._closed:
            return
        if self._task is not None and not self._task.done():
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return
        self._task = loop.create_task(self._auto_flush())

    async def _auto_flush(self) -> None:
        """Sleep for the flush interval, then flush; exit quietly if cancelled."""
        try:
            await asyncio.sleep(self._flush_interval)
            await self.flush()
        except asyncio.CancelledError:
            # Cancellation comes from close(), which performs the final flush.
            pass

    async def flush(self) -> None:
        """Write all buffered records to the store as one batch.

        If the write is cancelled part-way (for example because ``close()``
        cancels the timer task while it is writing), the batch is put back at
        the front of the buffer so a later flush can deliver it instead of the
        records being silently lost. A store that had already persisted part
        of the batch may therefore see those records again. Other store errors
        propagate unchanged.
        """
        async with self._lock:
            if not self._buffer:
                return
            batch = self._buffer
            self._buffer = []
        try:
            await self.store.write_batch(batch)
        except asyncio.CancelledError:
            # Synchronous splice (no await), so it cannot interleave with
            # other coroutines; older records stay ahead of newer ones.
            self._buffer[:0] = batch
            raise

    async def close(self) -> None:
        """Stop the timer task, flush what is left, and close the store."""
        self._closed = True
        if self._task is not None and not self._task.done():
            self._task.cancel()
            with contextlib.suppress(asyncio.CancelledError, Exception):
                await self._task
        await self.flush()
        await self.store.close()
@@ -0,0 +1,19 @@
1
+ """AuditStore protocol — pluggable backend for AuditLogger.
2
+
3
+ 1:1 port of `packages/core/src/audit/types.ts`.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from typing import Protocol
9
+
10
+ from ai_shield.types import AuditRecord
11
+
12
+
13
class AuditStore(Protocol):
    """Structural interface a backend must provide to receive AuditLogger writes."""

    async def write(self, record: AuditRecord) -> None:
        """Accept a single audit record."""

    async def write_batch(self, records: list[AuditRecord]) -> None:
        """Accept several audit records in one call."""

    async def flush(self) -> None:
        """Flush the backend; exact semantics are up to the implementation."""

    async def close(self) -> None:
        """Close the backend; exact semantics are up to the implementation."""
@@ -0,0 +1,7 @@
1
+ """Cache subpackage — TTL + insertion-order LRU."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from ai_shield.cache.lru import ScanLRUCache
6
+
7
+ __all__ = ["ScanLRUCache"]
ai_shield/cache/lru.py ADDED
@@ -0,0 +1,89 @@
1
+ """TTL + insertion-order LRU cache for scan results.
2
+
3
+ 1:1 port of `packages/core/src/cache/lru.ts` — recency order is kept in a
4
+ `collections.OrderedDict` (promote with `move_to_end`, evict the oldest entry first).
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import time
10
+ from collections import OrderedDict
11
+ from dataclasses import dataclass
12
+ from typing import Generic, TypeVar
13
+
14
V = TypeVar("V")


@dataclass
class _Entry(Generic[V]):
    """A cached value together with its absolute expiry time."""

    value: V
    # Expiry instant in monotonic-clock milliseconds (same scale as
    # ScanLRUCache._now_ms()).
    expires_at: float


class ScanLRUCache(Generic[V]):
    """Bounded cache with least-recently-used eviction and a per-entry TTL.

    A ``ttl_ms`` of 0 disables expiry. Only ``get`` refreshes an entry's
    recency; ``has`` does not.
    """

    def __init__(self, *, max_size: int = 1000, ttl_ms: int = 300_000) -> None:
        """Create an empty cache.

        Raises:
            ValueError: if ``max_size`` is below 1 or ``ttl_ms`` is negative.
        """
        if max_size < 1:
            raise ValueError("max_size must be >= 1")
        if ttl_ms < 0:
            raise ValueError("ttl_ms must be >= 0")
        self._max_size = max_size
        self._ttl_ms = ttl_ms
        self._data: OrderedDict[str, _Entry[V]] = OrderedDict()

    def get(self, key: str) -> V | None:
        """Return the live value for *key* and mark it most recently used.

        Returns None for a missing key; an expired entry is dropped and also
        yields None.
        """
        try:
            hit = self._data[key]
        except KeyError:
            return None
        if self._expired(hit):
            del self._data[key]
            return None
        self._data.move_to_end(key)
        return hit.value

    def set(self, key: str, value: V) -> None:
        """Store *value* under *key* as the most recently used entry.

        Overwriting an existing key never evicts; inserting a new key into a
        full cache evicts the least recently used entry first.
        """
        replaced = self._data.pop(key, None)
        if replaced is None and len(self._data) >= self._max_size:
            self._data.popitem(last=False)
        deadline = self._now_ms() + self._ttl_ms
        self._data[key] = _Entry(value=value, expires_at=deadline)

    def has(self, key: str) -> bool:
        """Report whether *key* holds a live entry, dropping it if expired."""
        try:
            hit = self._data[key]
        except KeyError:
            return False
        if self._expired(hit):
            del self._data[key]
            return False
        return True

    def delete(self, key: str) -> bool:
        """Remove *key*; return True if it was present."""
        return self._data.pop(key, None) is not None

    def clear(self) -> None:
        """Drop every entry."""
        self._data.clear()

    def prune(self) -> int:
        """Remove all expired entries. Returns count removed."""
        removed = 0
        # Iterate over a snapshot so entries can be deleted while scanning.
        for key, entry in list(self._data.items()):
            if self._expired(entry):
                del self._data[key]
                removed += 1
        return removed

    def __len__(self) -> int:
        """Number of stored entries, including expired ones not yet dropped."""
        return len(self._data)

    @staticmethod
    def _now_ms() -> float:
        """Current monotonic clock reading in milliseconds."""
        return time.monotonic() * 1000.0

    def _expired(self, entry: _Entry[V]) -> bool:
        """True when TTLs are enabled and *entry* has reached its deadline."""
        return self._ttl_ms != 0 and self._now_ms() >= entry.expires_at
@@ -0,0 +1,17 @@
1
+ """Cost subpackage — tracker + pricing + anomaly."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from ai_shield.cost.anomaly import detect_anomaly
6
+ from ai_shield.cost.pricing import MODEL_PRICING, estimate_cost, get_model_pricing
7
+ from ai_shield.cost.tracker import CostTracker, MemoryStore, RedisLike
8
+
9
+ __all__ = [
10
+ "MODEL_PRICING",
11
+ "CostTracker",
12
+ "MemoryStore",
13
+ "RedisLike",
14
+ "detect_anomaly",
15
+ "estimate_cost",
16
+ "get_model_pricing",
17
+ ]
@@ -0,0 +1,50 @@
1
+ """Z-score anomaly detection for cost-spike alerts.
2
+
3
+ 1:1 port of `packages/core/src/cost/anomaly.ts`.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import math
9
+
10
+ from ai_shield.types import AnomalyResult
11
+
12
+
13
def detect_anomaly(
    samples: list[float],
    current: float,
    *,
    z_threshold: float = 2.5,
) -> AnomalyResult:
    """Score `current` against `samples` and flag it when |z| >= `z_threshold`.

    The z-score uses the population standard deviation (divide by n) of
    `samples`. With fewer than 3 samples, or when the samples have zero
    spread, the result is is_anomaly=False with z_score=0.0; in the
    too-few-samples case mean and std_dev are reported as 0.0 as well.
    """

    def _result(flagged: bool, z_score: float, mean: float, std_dev: float) -> AnomalyResult:
        # Single construction point so every branch fills the same fields.
        return AnomalyResult(
            is_anomaly=flagged,
            z_score=z_score,
            current_value=current,
            mean=mean,
            std_dev=std_dev,
        )

    count = len(samples)
    if count < 3:
        return _result(False, 0.0, 0.0, 0.0)

    average = sum(samples) / count
    spread = math.sqrt(sum((x - average) ** 2 for x in samples) / count)
    if spread == 0.0:
        # Identical samples: a z-score would divide by zero.
        return _result(False, 0.0, average, 0.0)

    z_value = (current - average) / spread
    return _result(abs(z_value) >= z_threshold, z_value, average, spread)
@@ -0,0 +1,62 @@
1
+ """Model pricing table + estimate_cost helper.
2
+
3
+ 1:1 port of `packages/core/src/cost/pricing.ts`.
4
+ USD per 1M tokens. Update as providers change rates.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from ai_shield.types import ModelPricing
10
+
11
# Pricing table keyed by model name; every figure is USD per 1M tokens.
# Keys are also used as prefixes by get_model_pricing(), so a versioned name
# such as "gpt-4o-mini-2024-07-18" resolves to its longest matching key.
MODEL_PRICING: dict[str, ModelPricing] = {
    # OpenAI
    "gpt-5.2": ModelPricing(input_per_1m=4.50, output_per_1m=18.00),
    "gpt-5.0": ModelPricing(input_per_1m=3.00, output_per_1m=12.00),
    "gpt-4.1": ModelPricing(input_per_1m=2.50, output_per_1m=10.00),
    "gpt-4o": ModelPricing(input_per_1m=2.50, output_per_1m=10.00),
    "gpt-4o-mini": ModelPricing(input_per_1m=0.15, output_per_1m=0.60),
    "o3": ModelPricing(input_per_1m=15.00, output_per_1m=60.00),
    "o4-mini": ModelPricing(input_per_1m=1.10, output_per_1m=4.40),
    # Anthropic
    "claude-opus-4-7": ModelPricing(input_per_1m=15.00, output_per_1m=75.00),
    "claude-opus-4-6": ModelPricing(input_per_1m=15.00, output_per_1m=75.00),
    "claude-sonnet-4-7": ModelPricing(input_per_1m=3.00, output_per_1m=15.00),
    "claude-sonnet-4-5": ModelPricing(input_per_1m=3.00, output_per_1m=15.00),
    "claude-haiku-4-7": ModelPricing(input_per_1m=0.80, output_per_1m=4.00),
    # Google
    "gemini-2.5-pro": ModelPricing(input_per_1m=1.25, output_per_1m=5.00),
    "gemini-2.5-flash": ModelPricing(input_per_1m=0.075, output_per_1m=0.30),
    "gemini-2.5-flash-lite": ModelPricing(input_per_1m=0.0375, output_per_1m=0.15),
    # xAI
    "grok-4": ModelPricing(input_per_1m=5.00, output_per_1m=15.00),
    # Mistral
    "mistral-large": ModelPricing(input_per_1m=2.00, output_per_1m=6.00),
    "mistral-small": ModelPricing(input_per_1m=0.20, output_per_1m=0.60),
}
36
+
37
+
38
def get_model_pricing(model: str) -> ModelPricing:
    """Look up pricing for `model`: exact key, else longest prefix, else default.

    Picking the longest registered key that `model` starts with makes a
    versioned snapshot such as `gpt-4o-mini-2024-07-18` resolve to
    `gpt-4o-mini` rather than the shorter `gpt-4o`. Names that match nothing
    are priced as `gpt-4o-mini`.
    """
    try:
        return MODEL_PRICING[model]
    except KeyError:
        pass
    # Two distinct keys of equal length cannot both prefix the same string,
    # so the longest matching key is unique.
    best_prefix = max(
        (name for name in MODEL_PRICING if model.startswith(name)),
        key=len,
        default=None,
    )
    return MODEL_PRICING["gpt-4o-mini" if best_prefix is None else best_prefix]
53
+
54
+
55
def estimate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
    """Estimate the USD cost of one LLM call from its token counts.

    Rates come from `get_model_pricing(model)` and are expressed per one
    million tokens.

    Raises:
        ValueError: if either token count is negative.
    """
    if input_tokens < 0 or output_tokens < 0:
        raise ValueError("Token counts must be non-negative")
    rates = get_model_pricing(model)
    input_usd = (input_tokens / 1_000_000) * rates.input_per_1m
    output_usd = (output_tokens / 1_000_000) * rates.output_per_1m
    return input_usd + output_usd
@@ -0,0 +1,175 @@
1
+ """Async cost tracker — soft/hard budgets, in-memory or Redis backend.
2
+
3
+ 1:1 port of `packages/core/src/cost/tracker.ts`.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import asyncio
9
+ import os
10
+ from collections import deque
11
+ from dataclasses import dataclass, field
12
+ from datetime import datetime, timezone
13
+ from typing import Protocol
14
+
15
+ from ai_shield.cost.pricing import estimate_cost
16
+ from ai_shield.types import (
17
+ BudgetCheckResult,
18
+ BudgetConfig,
19
+ BudgetPeriod,
20
+ CostRecord,
21
+ )
22
+
23
# Default for CostTracker.max_records. Read once at import time from the
# AI_SHIELD_MAX_RECORDS environment variable (10000 when unset); a
# non-integer value makes int() raise ValueError during import.
_DEFAULT_MAX_RECORDS = int(os.environ.get("AI_SHIELD_MAX_RECORDS", "10000"))
24
+
25
+
26
class RedisLike(Protocol):
    """Minimal async key-value interface: `incrbyfloat`, `expire` and `get`.

    Anything exposing these three coroutines can back a CostTracker.
    """

    async def incrbyfloat(self, key: str, value: float) -> float:
        """Add `value` to the number stored at `key` and return the new total."""

    async def expire(self, key: str, seconds: int) -> bool:
        """Give `key` a time-to-live of `seconds`; the bool reports success."""

    async def get(self, key: str) -> str | None:
        """Return the value stored at `key` as a string, or None if absent."""
32
+
33
+
34
@dataclass
class MemoryStore:
    """In-process fallback when Redis is not provided.

    Values live in a dict, with optional per-key deadlines measured on the
    event loop clock. Expired keys are removed lazily at the start of each
    operation.
    """

    # key -> accumulated float value
    _data: dict[str, float] = field(default_factory=dict)
    # key -> event-loop time at which the key stops existing
    _expires_at: dict[str, float] = field(default_factory=dict)
    # serialises the coroutines below
    _lock: asyncio.Lock = field(default_factory=asyncio.Lock)

    async def incrbyfloat(self, key: str, value: float) -> float:
        """Add `value` to `key` (a missing key counts as 0.0); return the new total."""
        async with self._lock:
            self._sweep_expired()
            current = self._data.get(key, 0.0)
            new = current + value
            self._data[key] = new
            return new

    async def expire(self, key: str, seconds: int) -> bool:
        """Set `key` to expire `seconds` from now.

        Returns False when the key does not exist. That includes a key whose
        earlier deadline has already passed: it is swept first, so its stale
        value cannot be brought back by assigning a new deadline.
        """
        async with self._lock:
            self._sweep_expired()
            if key not in self._data:
                return False
            loop_time = asyncio.get_event_loop().time()
            self._expires_at[key] = loop_time + seconds
            return True

    async def get(self, key: str) -> str | None:
        """Return the value at `key` rendered with `str()`, or None if absent or expired."""
        async with self._lock:
            self._sweep_expired()
            v = self._data.get(key)
            return None if v is None else str(v)

    def _sweep_expired(self) -> None:
        """Delete every key whose deadline is at or before the current loop time."""
        now = asyncio.get_event_loop().time()
        expired = [k for k, t in self._expires_at.items() if t <= now]
        for k in expired:
            self._data.pop(k, None)
            self._expires_at.pop(k, None)
70
+
71
+
72
+ def _period_seconds(period: BudgetPeriod) -> int:
73
+ if period == "hourly":
74
+ return 3600
75
+ if period == "daily":
76
+ return 86400
77
+ if period == "monthly":
78
+ return 86400 * 31
79
+ raise ValueError(f"Unknown period: {period!r}")
80
+
81
+
82
+ def _period_key(period: BudgetPeriod, now: datetime | None = None) -> str:
83
+ moment = now or datetime.now(timezone.utc)
84
+ if period == "hourly":
85
+ return moment.strftime("%Y%m%d%H")
86
+ if period == "daily":
87
+ return moment.strftime("%Y%m%d")
88
+ if period == "monthly":
89
+ return moment.strftime("%Y%m")
90
+ raise ValueError(f"Unknown period: {period!r}")
91
+
92
+
93
@dataclass
class CostTracker:
    """Track LLM spend with optional Redis backend.

    All writes are atomic at the store level (`INCRBYFLOAT` on Redis,
    `asyncio.Lock` on the in-memory store).

    Spend is accumulated per entity and per budget period under a key of the
    form ``ai-shield:cost:<entity_id>:<period>:<period bucket>``. The most
    recent ``max_records`` CostRecord objects are also kept in memory.
    """

    budget: BudgetConfig = field(default_factory=BudgetConfig)
    store: RedisLike | None = None
    max_records: int = _DEFAULT_MAX_RECORDS

    # Bounded history of recorded calls (oldest dropped first).
    _records: deque[CostRecord] = field(init=False)
    # Backend actually used: `store` if given, otherwise a MemoryStore.
    _store: RedisLike = field(init=False)

    def __post_init__(self) -> None:
        """Create the bounded history and choose the storage backend."""
        self._records = deque(maxlen=self.max_records)
        self._store = self.store if self.store is not None else MemoryStore()

    @staticmethod
    def _key(entity_id: str, period: BudgetPeriod) -> str:
        """Build the store key for `entity_id` in the current `period` bucket."""
        return f"ai-shield:cost:{entity_id}:{period}:{_period_key(period)}"

    async def record(
        self,
        *,
        entity_id: str,
        model: str,
        input_tokens: int,
        output_tokens: int,
        actual_usd: float | None = None,
    ) -> CostRecord:
        """Record one LLM call and add its cost to the entity's running spend.

        The cost is `actual_usd` when given, otherwise the estimate from
        `estimate_cost(model, input_tokens, output_tokens)`. The spend key's
        TTL is refreshed to one full period on every call.

        Returns:
            The CostRecord appended to the in-memory history.
        """
        cost = (
            actual_usd
            if actual_usd is not None
            else estimate_cost(
                model,
                input_tokens,
                output_tokens,
            )
        )
        record = CostRecord(
            entity_id=entity_id,
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            actual_usd=cost,
            timestamp=datetime.now(timezone.utc).isoformat(),
        )
        self._records.append(record)

        key = self._key(entity_id, self.budget.period)
        await self._store.incrbyfloat(key, cost)
        await self._store.expire(key, _period_seconds(self.budget.period))
        return record

    async def check_budget(self, entity_id: str) -> BudgetCheckResult:
        """Compare the entity's current spend with the soft and hard limits.

        A limit counts as exceeded once spend reaches it (`>=`). The call is
        `allowed` unless the hard limit is exceeded. `limit_usd` reports the
        hard limit when one is set, otherwise the soft limit.
        """
        spend = await self.get_current_spend(entity_id)

        soft = self.budget.soft_limit_usd
        hard = self.budget.hard_limit_usd

        soft_exceeded = soft is not None and spend >= soft
        hard_exceeded = hard is not None and spend >= hard

        return BudgetCheckResult(
            allowed=not hard_exceeded,
            current_spend_usd=spend,
            limit_usd=hard if hard is not None else soft,
            period=self.budget.period,
            soft_exceeded=soft_exceeded,
            hard_exceeded=hard_exceeded,
        )

    async def get_current_spend(self, entity_id: str) -> float:
        """Return the entity's spend in the current period (0.0 if none recorded)."""
        key = self._key(entity_id, self.budget.period)
        raw = await self._store.get(key)
        return float(raw) if raw is not None else 0.0

    def recent_records(self, n: int = 100) -> list[CostRecord]:
        """Return up to the `n` most recent records, oldest first.

        A non-positive `n` yields an empty list. Without this guard, `n == 0`
        would slice `[-0:]` and return the entire history.
        """
        if n <= 0:
            return []
        return list(self._records)[-n:]