agenthacker 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,462 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # Copyright 2026 AgentHacker
3
+
4
+ """Anomaly detection and risk scoring for firewall-sdk agents.
5
+
6
+ Score formula (0–100):
7
+ session_component = session_rate × damping × 70 (max 70 pts)
8
+ weighted_component = weighted_rate × 30 (max 30 pts, only if total >= 8)
9
+ risk_score = min(100, session + weighted)
10
+
11
+ Session = log-decay weighted blocked rate (same weighting as below) over the last 2 hours.
12
+ Damping = min(1.0, session_count / 3) — ramps up over the first 3 messages so
13
+ a single rejection never instantly blocks anyone.
14
+ Weighted = log-decay all-time rate: Σ(w×blocked)/Σ(w) where w=1/ln(rank+1) and rank 1 is the most recent invocation.
15
+ Starts at 0 and stays 0 until the user has 8+ total invocations.
16
+
17
+ Usage (inside an agent endpoint):
18
+ from firewall_sdk import check_user_risk, RiskLevel
19
+
20
+ risk = check_user_risk(user_hash, store._conn)
21
+ if risk.level == RiskLevel.CRITICAL:
22
+ raise HTTPException(423, "Locked", headers={"Retry-After": "3600"})
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import logging
28
+ import os
29
+ from dataclasses import dataclass, field
30
+ from datetime import datetime, timezone
31
+ from enum import Enum
32
+ from typing import Any
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+ _SESSION_HOURS = 2 # window that counts as "current session"
37
+ _WEIGHTED_MIN_MSGS = (
38
+ 8 # weighted component is 0 until user has this many total messages
39
+ )
40
+ _SESSION_DAMPING_MSGS = (
41
+ 3 # session rate ramps to full weight after this many session messages
42
+ )
43
+
44
+
45
+ # ── Feature flag ──────────────────────────────────────────────────────
46
+
47
+
48
def is_anomaly_enabled() -> bool:
    """Return True if anomaly detection is active (default: True).

    Controlled by the ``FIREWALL_ANOMALY_DETECTION`` environment variable.
    Any of ``0``, ``false``, ``no`` or ``off`` disables it. The match is
    case-insensitive and ignores surrounding whitespace. An unset variable or
    any other value leaves detection enabled.

    When disabled, check_user_risk() always returns LOW so no enforcement
    happens and the dashboard can hide the risk panel entirely.
    """
    raw = os.environ.get("FIREWALL_ANOMALY_DETECTION", "1")
    # strip() so values such as "0\n" (from env files) or " off" still disable.
    return raw.strip().lower() not in ("0", "false", "no", "off")
61
+
62
+
63
+ # ── Public types ──────────────────────────────────────────────────────
64
+
65
+
66
class RiskLevel(str, Enum):
    """Risk tier derived from a numerical 0–100 score.

    Being a ``str`` subclass, each member compares equal to its own name,
    e.g. ``RiskLevel.HIGH == "HIGH"``.
    """

    LOW = "LOW"  # score < 31 (0–30): proceed normally
    MEDIUM = "MEDIUM"  # 31 <= score < 61: read-only tools
    HIGH = "HIGH"  # 61 <= score < 81: read-only + shorter context + system-prompt flag
    CRITICAL = "CRITICAL"  # score >= 81: block (HTTP 423)

    @classmethod
    def from_score(cls, score: float) -> "RiskLevel":
        """Map a numeric score onto its tier using the floors 81 / 61 / 31."""
        # Checked from most to least severe; the first floor reached wins.
        floors = (
            (81, cls.CRITICAL),
            (61, cls.HIGH),
            (31, cls.MEDIUM),
        )
        for floor, tier in floors:
            if score >= floor:
                return tier
        return cls.LOW
83
+
84
+
85
@dataclass
class RiskFactor:
    """One named contribution to a user's overall risk score.

    Fields are listed in constructor order; all four are required.
    """

    # Machine-readable identifier, e.g. "session_threat_rate".
    name: str
    # Human-readable explanation of why this factor applies.
    description: str
    # Points this factor adds to the final score.
    contribution: float
    # Raw underlying value(s), kept for display in the dashboard.
    signal_value: Any
93
+
94
+
95
@dataclass
class RiskScore:
    """Complete risk assessment for one user: score, tier and its factors."""

    # Identifier of the assessed user.
    user_hash: str
    # Aggregate score, nominally 0.0–100.0.
    score: float
    # Tier assigned to this assessment.
    level: RiskLevel
    # Contributing factors; each instance gets its own fresh list.
    factors: list[RiskFactor] = field(default_factory=list)
    # Defaults to the current time as a timezone-aware UTC datetime.
    computed_at: datetime = field(
        default_factory=lambda: datetime.now(timezone.utc)
    )
104
+
105
+
106
+ # ── Scorer ────────────────────────────────────────────────────────────
107
+
108
+
109
class RiskScorer:
    """Stateless risk scorer — accepts a psycopg2 connection and queries Postgres.

    Reads the ``invocations`` table (``user_hash``, ``blocked`` and
    ``timestamp`` columns).

    Degrades gracefully. It returns RiskScore(score=0.0, level=LOW,
    factors=[]) when conn is None (NullStore paths) or when scoring fails
    outright. A signal whose query raises is skipped, and the remaining
    signals still count.
    """

    def compute(
        self,
        user_hash: str,
        conn: Any,
        window_hours: int = _SESSION_HOURS,
    ) -> RiskScore:
        """Compute a RiskScore for user_hash.

        Args:
            user_hash: Identifier of the user to score.
            conn: Connection whose ``cursor()`` supports ``with`` (psycopg2
                style), or None to skip scoring entirely.
            window_hours: Length in hours of the "current session" window
                used by the session signal.

        Returns:
            The computed RiskScore. A zero/LOW score is returned when ``conn``
            is None or scoring fails unexpectedly.
        """
        if conn is None:
            return self._zero(user_hash)
        try:
            return self._compute_internal(user_hash, conn, window_hours)
        except Exception:
            logger.warning(
                "RiskScorer.compute failed for %s — returning LOW",
                user_hash,
                exc_info=True,
            )
            return self._zero(user_hash)

    # ── Internal ─────────────────────────────────────────────────────

    @staticmethod
    def _zero(user_hash: str) -> RiskScore:
        """Return the neutral score used for every degraded path."""
        return RiskScore(user_hash=user_hash, score=0.0, level=RiskLevel.LOW)

    def _compute_internal(
        self, user_hash: str, conn: Any, window_hours: int = _SESSION_HOURS
    ) -> RiskScore:
        """Run every signal and fold the successful ones into a RiskScore."""
        # (name, thunk) pairs; the name is only used for the failure log.
        signals = (
            (
                "_signal_session",
                lambda: self._signal_session(user_hash, conn, window_hours),
            ),
            ("_signal_weighted", lambda: self._signal_weighted(user_hash, conn)),
        )
        factors: list[RiskFactor] = []
        for signal_name, run_signal in signals:
            try:
                factors.append(run_signal())
            except Exception:
                # Best-effort: one failing signal must not drop the others.
                # NOTE(review): with psycopg2 outside autocommit, a failed
                # query leaves the transaction aborted, so later queries on
                # this conn will likely fail too — confirm how callers manage
                # transactions.
                logger.debug(
                    "Signal %s failed — skipping", signal_name, exc_info=True
                )

        score = round(min(100.0, sum(f.contribution for f in factors)), 2)
        return RiskScore(
            user_hash=user_hash,
            score=score,
            level=RiskLevel.from_score(score),
            # Normally both factors, so the dashboard can display them. A
            # signal that raised above is absent.
            factors=factors,
        )

    # ── Signal 1: session threat rate (max 70 pts) ────────────────────
    #
    # Uses log-decay recency weighting within the session window (default
    # _SESSION_HOURS = 2h), so the most recent blocked messages count more
    # than older ones in the same session.
    # weight for rank r = 1 / ln(r + 1), r=1 is most recent message.
    #
    #   weighted_rate = Σ(w × blocked) / Σ(w)   within the window
    #   damping       = min(1.0, session_count / _SESSION_DAMPING_MSGS)
    #   contribution  = weighted_rate × damping × 70

    def _signal_session(
        self, user_hash: str, conn: Any, window_hours: int = _SESSION_HOURS
    ) -> RiskFactor:
        """Score the block rate inside the last ``window_hours`` hours."""
        # The window is bound as a number and multiplied by a 1-hour interval.
        # The old `INTERVAL %s` form only parsed because psycopg2 interpolates
        # parameters client-side; it breaks with server-side binding.
        sql = """
            SELECT
                COUNT(*) AS session_total,
                COUNT(*) FILTER (WHERE blocked) AS session_blocked,
                SUM(CASE WHEN blocked THEN 1.0/LN(rn::float + 1) ELSE 0 END) AS w_flagged,
                SUM(1.0/LN(rn::float + 1)) AS w_total
            FROM (
                SELECT blocked,
                       ROW_NUMBER() OVER (ORDER BY timestamp DESC) AS rn
                FROM invocations
                WHERE user_hash = %s
                  AND timestamp > NOW() - %s * INTERVAL '1 hour'
            ) t
        """
        row = _row4(conn, sql, (user_hash, window_hours))
        # Aggregates over zero rows come back as NULL, hence the `or 0`.
        session_total = int(row[0] or 0)
        session_blocked = int(row[1] or 0)
        w_flagged = float(row[2] or 0)
        w_total = float(row[3] or 0)

        if session_total == 0:
            return RiskFactor(
                name="session_threat_rate",
                description="No activity in current session",
                contribution=0.0,
                signal_value={"blocked": 0, "total": 0, "rate": 0.0},
            )

        weighted_rate = w_flagged / w_total if w_total > 0 else 0.0
        # Ramps up over the first few messages so one rejection never blocks.
        damping = min(1.0, session_total / _SESSION_DAMPING_MSGS)
        contribution = round(weighted_rate * damping * 70.0, 2)

        return RiskFactor(
            name="session_threat_rate",
            description=(
                f"{session_blocked}/{session_total} requests blocked in current session "
                f"(log-decay weighted rate {weighted_rate * 100:.1f}%, damping {damping:.2f})"
            ),
            contribution=contribution,
            signal_value={
                "blocked": session_blocked,
                "total": session_total,
                "rate": round(weighted_rate, 3),
                "damping": round(damping, 3),
            },
        )

    # ── Signal 2: weighted all-time threat rate (max 30 pts) ──────────
    #
    # weight for rank r = 1 / ln(r + 1)   (r=1 is most recent)
    # weighted_rate = Σ(w × blocked) / Σ(w)
    # Only activates once the user has >= _WEIGHTED_MIN_MSGS (8) total
    # invocations. Before that, the contribution is always 0, so one bad
    # message in a fresh account cannot accumulate history-based risk.

    def _signal_weighted(self, user_hash: str, conn: Any) -> RiskFactor:
        """Score the recency-weighted block rate over the user's full history."""
        sql = """
            SELECT
                COUNT(*) AS total_count,
                SUM(CASE WHEN blocked THEN 1.0/LN(rn::float + 1) ELSE 0 END) AS w_flagged,
                SUM(1.0/LN(rn::float + 1)) AS w_total
            FROM (
                SELECT blocked,
                       ROW_NUMBER() OVER (ORDER BY timestamp DESC) AS rn
                FROM invocations
                WHERE user_hash = %s
            ) t
        """
        row = _row3(conn, sql, (user_hash,))
        total_count = int(row[0] or 0)
        w_flagged = float(row[1] or 0)
        w_total = float(row[2] or 0)

        if total_count < _WEIGHTED_MIN_MSGS:
            return RiskFactor(
                name="weighted_all_time_rate",
                description=(
                    f"Weighted history not yet active "
                    f"({total_count}/{_WEIGHTED_MIN_MSGS} messages minimum)"
                ),
                contribution=0.0,
                signal_value={
                    "total_count": total_count,
                    "min_required": _WEIGHTED_MIN_MSGS,
                },
            )

        weighted_rate = w_flagged / w_total if w_total > 0 else 0.0
        contribution = round(weighted_rate * 30.0, 2)

        return RiskFactor(
            name="weighted_all_time_rate",
            description=(
                f"Weighted all-time block rate: {weighted_rate * 100:.1f}% "
                f"across {total_count} total requests (log-decay)"
            ),
            contribution=contribution,
            signal_value={
                "total_count": total_count,
                "weighted_rate": round(weighted_rate, 3),
            },
        )
277
+
278
+
279
+ # ── Content signal: script anomaly (soft) ────────────────────────────
280
+
281
+
282
def script_risk_factor(text: str) -> RiskFactor | None:
    """Build a soft RiskFactor for non-Latin / mixed-script / homoglyph input.

    The contribution comes from ``firewall_sdk.lang.script_risk``. By design it
    is capped at 15 points there (assumption carried over from the original
    docs; confirm in ``lang``). It is a soft signal that nudges behavioural
    risk for a burst of homoglyph/mixed-script probing, never a hard block.

    Args:
        text: The user input to inspect.

    Returns:
        A ``script_anomaly`` RiskFactor whose ``signal_value`` carries the
        first 80 characters of ``text``. Returns None when there is no
        anomaly, when ``lang`` is unavailable or raises, or when it returns a
        non-numeric or NaN value. Legitimate non-English traffic with no
        anomaly therefore stays unaffected.
    """
    try:
        from firewall_sdk import lang

        # float() inside the try: a non-numeric result is treated like a
        # failure instead of raising on the comparison below.
        contribution = float(lang.script_risk(text))
    except Exception:
        logger.debug("script_risk_factor: lang.script_risk failed", exc_info=True)
        return None
    # `not (x > 0)` also rejects NaN. A NaN would otherwise make the summed
    # score NaN, and min(100.0, nan) evaluates to 100.0 (CRITICAL).
    if not (contribution > 0.0):
        return None
    return RiskFactor(
        name="script_anomaly",
        description="Non-Latin / mixed-script / homoglyph-confusable input",
        contribution=contribution,
        signal_value={"text_preview": text[:80]},
    )
304
+
305
+
306
def _apply_extra_factors(
    risk: RiskScore, extra_factors: list[RiskFactor] | None
) -> RiskScore:
    """Fold opt-in extra factors into a RiskScore, re-clamping score and level.

    The extras' contributions are added on top of ``risk.score`` rather than
    re-derived from ``risk.factors``. A cloud-sourced score need not equal the
    sum of its reported factors, and may arrive with no factors at all.
    The result is clamped to at most 100, and the level never drops below
    ``risk.level``.

    Skips pardoned scores (a manual pardon means the user is trusted).

    Args:
        risk: The base assessment (local or cloud).
        extra_factors: Additional factors to fold in; None or empty is a no-op.

    Returns:
        ``risk`` itself when there is nothing to add or the user is pardoned.
        Otherwise, a new RiskScore with the extras appended and the same
        ``computed_at``.
    """
    if not extra_factors:
        return risk
    if any(f.name == "manual_pardon" for f in risk.factors):
        return risk
    extras = list(extra_factors)
    score = round(min(100.0, risk.score + sum(f.contribution for f in extras)), 2)
    # Severity follows the enum declaration order (LOW < MEDIUM < HIGH < CRITICAL).
    severity = list(RiskLevel)
    level = max(risk.level, RiskLevel.from_score(score), key=severity.index)
    return RiskScore(
        user_hash=risk.user_hash,
        score=score,
        level=level,
        factors=risk.factors + extras,
        computed_at=risk.computed_at,
    )
326
+
327
+
328
+ # ── Query helpers ─────────────────────────────────────────────────────
329
+
330
+
331
+ def _row(conn: Any, sql: str, params: tuple) -> tuple:
332
+ with conn.cursor() as cur:
333
+ cur.execute(sql, params)
334
+ return cur.fetchone() or (None, None)
335
+
336
+
337
+ def _row3(conn: Any, sql: str, params: tuple) -> tuple:
338
+ with conn.cursor() as cur:
339
+ cur.execute(sql, params)
340
+ return cur.fetchone() or (None, None, None)
341
+
342
+
343
+ def _row4(conn: Any, sql: str, params: tuple) -> tuple:
344
+ with conn.cursor() as cur:
345
+ cur.execute(sql, params)
346
+ return cur.fetchone() or (None, None, None, None)
347
+
348
+
349
+ # ── Public convenience function ───────────────────────────────────────
350
+
351
+
352
def _as_utc(value: datetime) -> datetime:
    """Return ``value`` unchanged if timezone-aware, otherwise tag it as UTC.

    Comparing a naive datetime with an aware ``now`` raises TypeError.
    psycopg2 returns naive datetimes for ``timestamp without time zone``
    columns, if that is how ``pardoned_until`` is declared.
    """
    if value.utcoffset() is None:
        return value.replace(tzinfo=timezone.utc)
    return value


def _parse_pardon_timestamp(value: Any) -> datetime | None:
    """Parse a cloud ISO-8601 pardon expiry into an aware datetime, or None.

    A trailing ``Z`` (rejected by ``fromisoformat`` before Python 3.11) is
    normalised to ``+00:00``. Naive values are treated as UTC.
    """
    if not isinstance(value, str):
        return None
    text = value.strip()
    if text.endswith(("Z", "z")):
        text = text[:-1] + "+00:00"
    try:
        return _as_utc(datetime.fromisoformat(text))
    except ValueError:
        return None


def _pardoned_risk(user_hash: str, description: str, pardoned_until: Any) -> RiskScore:
    """Build the zero/LOW RiskScore returned for a user under an active pardon."""
    return RiskScore(
        user_hash=user_hash,
        score=0.0,
        level=RiskLevel.LOW,
        factors=[
            RiskFactor(
                name="manual_pardon",
                description=description,
                contribution=0.0,
                signal_value={"pardoned_until": pardoned_until},
            )
        ],
    )


def _cloud_risk(
    user_hash: str, extra_factors: list[RiskFactor] | None
) -> RiskScore | None:
    """Score ``user_hash`` through the cloud client.

    Returns None when no client is configured or it has no answer.
    Exceptions propagate so the caller can fall back to local scoring.
    """
    from firewall_sdk.cloud_client import get_client

    client = get_client()
    if client is None:
        return None
    data = client.get_risk_score(user_hash)
    if data is None:
        return None

    raw_until = data.get("pardoned_until")
    if raw_until:
        pardoned_dt = _parse_pardon_timestamp(raw_until)
        if pardoned_dt is not None and pardoned_dt > datetime.now(timezone.utc):
            return _pardoned_risk(user_hash, "Pardoned via admin dashboard", raw_until)

    factors = [
        RiskFactor(
            name=f["name"],
            description=f["description"],
            # Coerce so a JSON string/int cannot break the later score sum.
            contribution=float(f["contribution"]),
            signal_value=f.get("signal_value"),
        )
        for f in data.get("factors", [])
    ]
    return _apply_extra_factors(
        RiskScore(
            user_hash=user_hash,
            score=float(data.get("score", 0.0)),
            level=RiskLevel(data.get("level", "LOW")),
            factors=factors,
        ),
        extra_factors,
    )


def _local_pardon(user_hash: str, conn: Any) -> RiskScore | None:
    """Return a pardoned RiskScore if ``user_risk_scores`` holds an active pardon."""
    with conn.cursor() as cur:
        cur.execute(
            "SELECT pardoned_until FROM user_risk_scores WHERE user_hash = %s",
            (user_hash,),
        )
        row = cur.fetchone()
    if not row or not row[0]:
        return None
    pardoned_until = _as_utc(row[0])
    if pardoned_until <= datetime.now(timezone.utc):
        return None
    stamp = pardoned_until.isoformat()
    return _pardoned_risk(
        user_hash, f"Manually reset by admin — active until {stamp}", stamp
    )


def check_user_risk(
    user_hash: str,
    conn: Any,
    window_hours: int = _SESSION_HOURS,
    *,
    extra_factors: list[RiskFactor] | None = None,
) -> RiskScore:
    """Compute and return a RiskScore for user_hash.

    When a cloud client is configured, delegates to the centralized Aurora
    backend for cross-session, cross-instance risk scoring. Falls back to
    local psycopg2 computation if the cloud client is unavailable, has no
    answer, or raises.

    Without a cloud client, checks for an active manual pardon first, then
    runs the local RiskScorer against the local psycopg2 connection.

    Args:
        user_hash: Identifier of the user to score.
        conn: Local DB connection, or None (local scoring then yields LOW).
        window_hours: Session window for the local scorer. It is not sent
            to the cloud backend.
        extra_factors: Optional opt-in content factors (e.g.
            ``script_risk_factor``) folded into the final score and level.
            Ignored for pardoned users.

    Returns:
        The user's RiskScore. It is always LOW when anomaly detection is
        disabled via ``FIREWALL_ANOMALY_DETECTION``.
    """
    if not is_anomaly_enabled():
        return RiskScore(user_hash=user_hash, score=0.0, level=RiskLevel.LOW)

    # ── Cloud path: delegate to centralized Aurora ─────────────────────
    try:
        cloud = _cloud_risk(user_hash, extra_factors)
    except Exception:
        logger.debug(
            "check_user_risk cloud lookup failed — falling back to local", exc_info=True
        )
    else:
        if cloud is not None:
            return cloud

    # ── Local path: psycopg2 + pardon check ───────────────────────────
    if conn is not None:
        try:
            pardoned = _local_pardon(user_hash, conn)
        except Exception:
            # NOTE(review): if this query fails inside a psycopg2 transaction
            # (e.g. missing table), the transaction is aborted. The scorer's
            # queries below will then likely fail too and fall back to LOW.
            # Confirm how callers configure the connection.
            logger.debug(
                "check_user_risk pardon lookup failed — proceeding with compute",
                exc_info=True,
            )
        else:
            if pardoned is not None:
                return pardoned
    return _apply_extra_factors(
        RiskScorer().compute(user_hash, conn, window_hours), extra_factors
    )
+ )