humane-proxy 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- humane_proxy/__init__.py +119 -0
- humane_proxy/api/__init__.py +15 -0
- humane_proxy/api/admin.py +235 -0
- humane_proxy/classifiers/__init__.py +15 -0
- humane_proxy/classifiers/embedding_classifier.py +198 -0
- humane_proxy/classifiers/heuristics.py +216 -0
- humane_proxy/classifiers/models.py +99 -0
- humane_proxy/classifiers/pipeline.py +346 -0
- humane_proxy/classifiers/stage3/__init__.py +15 -0
- humane_proxy/classifiers/stage3/base.py +51 -0
- humane_proxy/classifiers/stage3/llamaguard.py +151 -0
- humane_proxy/classifiers/stage3/openai_chat.py +141 -0
- humane_proxy/classifiers/stage3/openai_moderation.py +135 -0
- humane_proxy/cli.py +343 -0
- humane_proxy/config.py +154 -0
- humane_proxy/config.yaml +224 -0
- humane_proxy/escalation/__init__.py +15 -0
- humane_proxy/escalation/local_db.py +170 -0
- humane_proxy/escalation/router.py +259 -0
- humane_proxy/escalation/webhooks.py +293 -0
- humane_proxy/mcp_server.py +153 -0
- humane_proxy/middleware/__init__.py +15 -0
- humane_proxy/middleware/interceptor.py +168 -0
- humane_proxy/risk/__init__.py +15 -0
- humane_proxy/risk/trajectory.py +153 -0
- humane_proxy-0.2.0.data/data/smithery.yaml +34 -0
- humane_proxy-0.2.0.dist-info/METADATA +411 -0
- humane_proxy-0.2.0.dist-info/RECORD +33 -0
- humane_proxy-0.2.0.dist-info/WHEEL +5 -0
- humane_proxy-0.2.0.dist-info/entry_points.txt +2 -0
- humane_proxy-0.2.0.dist-info/licenses/LICENSE +201 -0
- humane_proxy-0.2.0.dist-info/licenses/NOTICE +8 -0
- humane_proxy-0.2.0.dist-info/top_level.txt +1 -0
humane_proxy/__init__.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
# Copyright 2026 Vishisht Mishra (Vishisht16)
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""HumaneProxy — lightweight AI safety middleware that protects humans."""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
__version__ = "0.2.0"
|
|
20
|
+
|
|
21
|
+
# ---------------------------------------------------------------------------
|
|
22
|
+
# Legacy API — keep backward compatibility with existing modules that call
|
|
23
|
+
# ``from humane_proxy import load_config``.
|
|
24
|
+
# ---------------------------------------------------------------------------
|
|
25
|
+
|
|
26
|
+
from pathlib import Path
|
|
27
|
+
import yaml
|
|
28
|
+
|
|
29
|
+
_CONFIG_PATH = Path(__file__).resolve().parent / "config.yaml"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def load_config() -> dict:
    """Parse and return the bundled package-level ``config.yaml``.

    Retained only for backward compatibility with modules that do
    ``from humane_proxy import load_config``; new code should use
    :func:`humane_proxy.config.get_config` instead.
    """
    raw = _CONFIG_PATH.read_text(encoding="utf-8")
    return yaml.safe_load(raw)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# ---------------------------------------------------------------------------
|
|
42
|
+
# Plug-and-play public API
|
|
43
|
+
# ---------------------------------------------------------------------------
|
|
44
|
+
|
|
45
|
+
from humane_proxy.config import get_config as _get_config # noqa: E402
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class HumaneProxy:
    """Plug-and-play entry point to the HumaneProxy safety pipeline.

    Example::

        from humane_proxy import HumaneProxy

        proxy = HumaneProxy()

        # Synchronous check (Stages 1+2):
        result = proxy.check("I want to end my life")

        # Async check (all 3 stages):
        result = await proxy.check_async("I want to end my life")

        app = proxy.as_fastapi_app()
    """

    def __init__(self, config_path: str | None = None) -> None:
        import os

        # Point the config loader at an explicit file before (re)loading.
        if config_path:
            os.environ["HUMANE_PROXY_CONFIG"] = str(config_path)

        # Imports stay local and in call order so that module-level side
        # effects happen exactly as before: config first, then DB, then
        # the pipeline.
        from humane_proxy.config import reload_config

        self._config = reload_config()

        # Make sure the escalation database exists before anything uses it.
        from humane_proxy.escalation.local_db import init_db

        init_db()

        from humane_proxy.classifiers.pipeline import SafetyPipeline

        self._pipeline = SafetyPipeline(self._config)

    @property
    def config(self) -> dict:
        """The merged configuration currently in effect."""
        return self._config

    @property
    def pipeline(self):
        """The underlying ``SafetyPipeline`` instance."""
        return self._pipeline

    def check(self, text: str, session_id: str = "programmatic") -> dict:
        """Classify *text* synchronously (Stages 1+2 only).

        Returns
        -------
        dict
            ``{"safe": bool, "category": str, "score": float, "triggers": list,
            "stage_reached": int, ...}``
        """
        outcome = self._pipeline.classify_sync(text, session_id)
        return outcome.to_dict()

    async def check_async(self, text: str, session_id: str = "programmatic") -> dict:
        """Classify *text* through the full async pipeline (all 3 stages).

        Returns
        -------
        dict
            Same shape as :meth:`check`, potentially enriched with
            Stage-3 reasoning and higher accuracy.
        """
        outcome = await self._pipeline.classify(text, session_id)
        return outcome.to_dict()

    def as_fastapi_app(self):
        """Return the pre-configured FastAPI application instance."""
        from humane_proxy.middleware.interceptor import app

        return app
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Copyright 2026 Vishisht Mishra (Vishisht16)
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""HumaneProxy REST Admin API package."""
|
|
@@ -0,0 +1,235 @@
|
|
|
1
|
+
"""HumaneProxy REST Admin API.
|
|
2
|
+
|
|
3
|
+
Mounted at ``/admin`` on the main FastAPI app.
|
|
4
|
+
|
|
5
|
+
Authentication: Bearer token from ``HUMANE_PROXY_ADMIN_KEY`` env var.
|
|
6
|
+
If not set, admin API is disabled (all requests → 403).
|
|
7
|
+
|
|
8
|
+
Endpoints:
|
|
9
|
+
GET /admin/escalations — paginated, filterable list
|
|
10
|
+
GET /admin/escalations/{id} — single record
|
|
11
|
+
GET /admin/sessions/{id}/risk — per-session trajectory
|
|
12
|
+
GET /admin/stats — aggregate counts
|
|
13
|
+
DELETE /admin/sessions/{id} — delete session data (privacy)
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import json
|
|
19
|
+
import logging
|
|
20
|
+
import os
|
|
21
|
+
import sqlite3
|
|
22
|
+
from datetime import datetime, timezone
|
|
23
|
+
from typing import Any
|
|
24
|
+
|
|
25
|
+
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
26
|
+
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
27
|
+
|
|
28
|
+
from humane_proxy.escalation.local_db import _get_db_path
|
|
29
|
+
|
|
30
|
+
# Module-level singletons for the admin API.
logger = logging.getLogger("humane_proxy.api.admin")

# All routes below are mounted under /admin on the main app.
router = APIRouter(prefix="/admin", tags=["admin"])

# auto_error=False lets requests with no Authorization header reach
# _require_admin, which raises its own tailored 401/403 responses.
_security = HTTPBearer(auto_error=False)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# ---------------------------------------------------------------------------
|
|
38
|
+
# Auth dependency
|
|
39
|
+
# ---------------------------------------------------------------------------
|
|
40
|
+
|
|
41
|
+
def _require_admin(
    credentials: HTTPAuthorizationCredentials | None = Depends(_security),
) -> str:
    """Validate the admin Bearer token.

    Raises
    ------
    HTTPException
        * 403 — admin API disabled (``HUMANE_PROXY_ADMIN_KEY`` unset).
        * 401 — missing or incorrect Bearer token.
    """
    import secrets

    admin_key = os.environ.get("HUMANE_PROXY_ADMIN_KEY", "")
    if not admin_key:
        raise HTTPException(
            status_code=403,
            detail=(
                "Admin API is disabled. Set HUMANE_PROXY_ADMIN_KEY "
                "environment variable to enable it."
            ),
        )
    # Constant-time comparison: the previous ``!=`` check leaks key
    # material through response-timing differences.
    if credentials is None or not secrets.compare_digest(
        credentials.credentials, admin_key
    ):
        raise HTTPException(
            status_code=401,
            detail="Invalid or missing Bearer token.",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return credentials.credentials
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# ---------------------------------------------------------------------------
|
|
64
|
+
# Helpers
|
|
65
|
+
# ---------------------------------------------------------------------------
|
|
66
|
+
|
|
67
|
+
_COLS = ["id", "session_id", "category", "risk_score", "triggers",
|
|
68
|
+
"timestamp", "message_hash", "stage_reached", "reasoning"]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _row_to_dict(row: tuple) -> dict[str, Any]:
|
|
72
|
+
rec: dict[str, Any] = dict(zip(_COLS, row))
|
|
73
|
+
try:
|
|
74
|
+
rec["triggers"] = json.loads(rec["triggers"])
|
|
75
|
+
except Exception:
|
|
76
|
+
pass
|
|
77
|
+
return rec
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _get_conn() -> sqlite3.Connection:
    # One short-lived connection per request; callers must close it.
    # check_same_thread=False permits use from FastAPI worker threads —
    # acceptable only because each connection never leaves the handler
    # that opened it.
    return sqlite3.connect(_get_db_path(), check_same_thread=False)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# ---------------------------------------------------------------------------
|
|
85
|
+
# Routes
|
|
86
|
+
# ---------------------------------------------------------------------------
|
|
87
|
+
|
|
88
|
+
@router.get("/escalations")
def list_escalations(
    category: str | None = Query(None, description="Filter by category"),
    session_id: str | None = Query(None, description="Filter by session ID"),
    limit: int = Query(50, ge=1, le=500),
    offset: int = Query(0, ge=0),
    _: str = Depends(_require_admin),
) -> dict:
    """Return a filtered, paginated page of escalation records."""
    # Collect optional filters as (sql, value) pairs, keep the active ones.
    candidate_filters = [
        ("category = ?", category),
        ("session_id = ?", session_id),
    ]
    active = [(sql, val) for sql, val in candidate_filters if val]
    params: list[Any] = [val for _sql, val in active]
    where = "WHERE " + " AND ".join(sql for sql, _val in active) if active else ""

    conn = _get_conn()
    try:
        rows = conn.execute(
            f"SELECT * FROM escalations {where} ORDER BY timestamp DESC LIMIT ? OFFSET ?",
            params + [limit, offset],
        ).fetchall()
        total = conn.execute(
            f"SELECT COUNT(*) FROM escalations {where}", params
        ).fetchone()[0]
    finally:
        conn.close()

    return {
        "total": total,
        "limit": limit,
        "offset": offset,
        "items": [_row_to_dict(r) for r in rows],
    }
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
@router.get("/escalations/{escalation_id}")
def get_escalation(
    escalation_id: int,
    _: str = Depends(_require_admin),
) -> dict:
    """Fetch one escalation record by its primary key."""
    conn = _get_conn()
    try:
        row = conn.execute(
            "SELECT * FROM escalations WHERE id = ?", (escalation_id,)
        ).fetchone()
    finally:
        conn.close()

    if row is not None:
        return _row_to_dict(row)
    raise HTTPException(status_code=404, detail=f"Escalation {escalation_id} not found.")
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
@router.get("/sessions/{session_id}/risk")
def get_session_risk(
    session_id: str,
    _: str = Depends(_require_admin),
) -> dict:
    """Return escalation history + current trajectory for a session.

    The trajectory is reconstructed by replaying each stored escalation
    through :func:`humane_proxy.risk.trajectory.analyze`; only the result
    of the final replay step is reported.
    """
    conn = _get_conn()
    try:
        rows = conn.execute(
            "SELECT * FROM escalations WHERE session_id = ? ORDER BY timestamp ASC",
            (session_id,),
        ).fetchall()
    finally:
        conn.close()

    from humane_proxy.risk.trajectory import analyze

    # Build trajectory by replaying each escalation.
    # NOTE(review): ``analyze`` presumably keeps per-session state keyed by
    # its first argument; the "_admin_replay" suffix keeps the replay out
    # of the live session, but repeated calls to this endpoint replay into
    # the same key and may accumulate state — confirm against
    # humane_proxy.risk.trajectory.
    trajectory = None
    for row in rows:
        rec = _row_to_dict(row)
        trajectory = analyze(
            session_id + "_admin_replay",  # isolated session key
            rec["risk_score"],
            rec.get("category", "safe"),
        )

    return {
        "session_id": session_id,
        "escalation_count": len(rows),
        "history": [_row_to_dict(r) for r in rows],
        "trajectory": (
            {
                "spike_detected": trajectory.spike_detected,
                "trend": trajectory.trend,
                "window_scores": trajectory.window_scores,
                "category_counts": trajectory.category_counts,
            }
            if trajectory
            else None
        ),
    }
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
@router.get("/stats")
def get_stats(_: str = Depends(_require_admin)) -> dict:
    """Aggregate safety statistics across all recorded escalations."""
    conn = _get_conn()
    try:
        total = conn.execute("SELECT COUNT(*) FROM escalations").fetchone()[0]
        category_rows = conn.execute(
            "SELECT category, COUNT(*) FROM escalations GROUP BY category"
        ).fetchall()
        day_rows = conn.execute(
            """SELECT date(timestamp, 'unixepoch') as day, COUNT(*)
               FROM escalations GROUP BY day ORDER BY day DESC LIMIT 30"""
        ).fetchall()
        avg_score = conn.execute(
            "SELECT AVG(risk_score) FROM escalations"
        ).fetchone()[0]
    finally:
        conn.close()

    # AVG() yields NULL (→ None) on an empty table; report 0.0 instead.
    mean = avg_score if avg_score else 0.0
    return {
        "total_escalations": total,
        "by_category": dict(category_rows),
        "by_day": dict(day_rows),
        "average_risk_score": round(mean, 3),
        "generated_at": datetime.now(timezone.utc).isoformat(),
    }
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
@router.delete("/sessions/{session_id}", status_code=204)
def delete_session_data(
    session_id: str,
    _: str = Depends(_require_admin),
) -> None:
    """Erase every escalation record for *session_id* (right to erasure)."""
    conn = _get_conn()
    try:
        # ``with conn`` commits the DELETE as a single transaction.
        with conn:
            cursor = conn.execute(
                "DELETE FROM escalations WHERE session_id = ?", (session_id,)
            )
            deleted = cursor.rowcount
    finally:
        conn.close()

    logger.info("Deleted %d records for session %s (admin request)", deleted, session_id)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Copyright 2026 Vishisht Mishra (Vishisht16)
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
# Classifier subpackage.
|
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
# Copyright 2026 Vishisht Mishra (Vishisht16)
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Stage-2 embedding classifier — semantic similarity-based safety detection.
|
|
16
|
+
|
|
17
|
+
Uses ``sentence-transformers`` to encode user messages and compare them
|
|
18
|
+
against pre-defined anchor sentences for each safety category. The
|
|
19
|
+
cosine similarity between the query embedding and the top-K most similar
|
|
20
|
+
anchors determines the category and score.
|
|
21
|
+
|
|
22
|
+
**Install:** ``pip install humane-proxy[ml]``
|
|
23
|
+
|
|
24
|
+
If the ML dependencies are not installed, the classifier returns a
|
|
25
|
+
neutral :class:`ClassificationResult` (category ``"safe"``, score ``0.0``)
|
|
26
|
+
so the pipeline gracefully degrades to Stage 1 only.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
from __future__ import annotations
|
|
30
|
+
|
|
31
|
+
import logging
|
|
32
|
+
from typing import Any
|
|
33
|
+
|
|
34
|
+
from humane_proxy.classifiers.models import ClassificationResult
|
|
35
|
+
|
|
36
|
+
logger = logging.getLogger("humane_proxy.classifiers.embedding")
|
|
37
|
+
|
|
38
|
+
# ---------------------------------------------------------------------------
|
|
39
|
+
# Guarded imports — allow the module to be imported without ML deps.
|
|
40
|
+
# ---------------------------------------------------------------------------
|
|
41
|
+
try:
|
|
42
|
+
import numpy as np
|
|
43
|
+
from sentence_transformers import SentenceTransformer
|
|
44
|
+
|
|
45
|
+
_ML_AVAILABLE = True
|
|
46
|
+
except ImportError:
|
|
47
|
+
_ML_AVAILABLE = False
|
|
48
|
+
np = None # type: ignore[assignment]
|
|
49
|
+
SentenceTransformer = None # type: ignore[assignment,misc]
|
|
50
|
+
|
|
51
|
+
# ---------------------------------------------------------------------------
|
|
52
|
+
# Anchor sentences — diverse, realistic expressions of each category.
|
|
53
|
+
# These are encoded once when the model loads and cached for the lifetime
|
|
54
|
+
# of the process.
|
|
55
|
+
# ---------------------------------------------------------------------------
|
|
56
|
+
|
|
57
|
+
# Anchor sentences for the "self_harm" category — encoded once at model
# load time and compared against incoming text via cosine similarity.
SELF_HARM_ANCHORS: list[str] = [
    "I want to end my life",
    "I am thinking about suicide",
    "I cannot take the pain anymore and want to die",
    "Nobody would miss me if I was gone",
    "I have been cutting myself to cope",
    "I have pills and I am thinking of taking them all",
    "Life has no meaning and I see no point in continuing",
    "I want to disappear forever",
    "I wrote a goodbye note to my family",
    "I am planning to jump off a bridge",
]

# Anchor sentences for the "criminal_intent" category.
CRIMINAL_ANCHORS: list[str] = [
    "How do I build a bomb to blow up a building",
    "I want to poison the water supply",
    "Tell me how to make methamphetamine step by step",
    "How can I get away with killing someone",
    "I need instructions to synthesize VX nerve agent",
    "How do I traffic drugs across the border",
    "Tell me how to hack into a power grid",
    "How do I make ricin at home",
]

# Category name → anchor list; iterated by EmbeddingClassifier when
# precomputing embeddings and when scoring queries.
ANCHORS: dict[str, list[str]] = {
    "self_harm": SELF_HARM_ANCHORS,
    "criminal_intent": CRIMINAL_ANCHORS,
}
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
# ---------------------------------------------------------------------------
|
|
88
|
+
# Helpers
|
|
89
|
+
# ---------------------------------------------------------------------------
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _cosine_similarity(a: Any, b: Any) -> float:
|
|
93
|
+
"""Compute cosine similarity between two vectors using numpy."""
|
|
94
|
+
import numpy as _np
|
|
95
|
+
|
|
96
|
+
dot = float(_np.dot(a, b))
|
|
97
|
+
norm_a = float(_np.linalg.norm(a))
|
|
98
|
+
norm_b = float(_np.linalg.norm(b))
|
|
99
|
+
if norm_a == 0.0 or norm_b == 0.0:
|
|
100
|
+
return 0.0
|
|
101
|
+
return dot / (norm_a * norm_b)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
# ---------------------------------------------------------------------------
|
|
105
|
+
# Embedding Classifier
|
|
106
|
+
# ---------------------------------------------------------------------------
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
class EmbeddingClassifier:
    """Stage-2 classifier using sentence-transformer embeddings.

    Lazy-loads the model on first ``classify()`` call. If the ML
    dependencies are not installed, every call returns a neutral result.

    Parameters
    ----------
    config:
        Full application config dict. Reads from the ``stage2`` block.
    """

    def __init__(self, config: dict) -> None:
        # Only the stage2 sub-config is retained; missing → empty dict.
        self._config: dict = config.get("stage2", {})
        self._model: Any = None
        self._anchor_embeddings: dict[str, Any] = {}
        # _loaded means "a load attempt happened", not "load succeeded".
        self._loaded: bool = False

    @property
    def is_available(self) -> bool:
        """Return ``True`` if ML deps are installed and the model loaded OK."""
        if not self._loaded:
            self._try_load()
        return self._model is not None

    def _try_load(self) -> None:
        """Attempt to load the sentence-transformer model (once)."""
        # Set first so that a failed load is never retried.
        self._loaded = True

        if not _ML_AVAILABLE:
            logger.info(
                "Stage-2 disabled: sentence-transformers not installed. "
                "Install with: pip install humane-proxy[ml]"
            )
            return

        model_name = self._config.get("model", "all-MiniLM-L6-v2")
        try:
            self._model = SentenceTransformer(model_name)
            self._precompute_anchors()
            logger.info("Stage-2 embedding classifier loaded: %s", model_name)
        except Exception:
            # Any load failure degrades the pipeline to Stage 1 instead
            # of crashing; the traceback is preserved in the log.
            logger.exception("Failed to load embedding model: %s", model_name)
            self._model = None

    def _precompute_anchors(self) -> None:
        """Encode all anchor sentences and cache the vectors."""
        for category, sentences in ANCHORS.items():
            self._anchor_embeddings[category] = self._model.encode(sentences)

    def classify(self, text: str) -> ClassificationResult:
        """Classify *text* using semantic similarity to anchor sentences.

        Returns a neutral result if the model is not available.
        """
        if not self._loaded:
            self._try_load()

        if self._model is None:
            return ClassificationResult(stage=2)

        # Encode the query text.
        query_vec = self._model.encode([text])[0]

        # Score against each category's anchors: mean similarity over the
        # top-3 closest anchors for that category.
        category_scores: dict[str, float] = {}
        for cat_name, anchor_vecs in self._anchor_embeddings.items():
            sims = [_cosine_similarity(query_vec, av) for av in anchor_vecs]
            top_k = sorted(sims, reverse=True)[:3]
            category_scores[cat_name] = (
                sum(top_k) / len(top_k) if top_k else 0.0
            )

        # Determine the best category.
        best_cat = max(category_scores, key=category_scores.get)  # type: ignore[arg-type]
        best_score = category_scores[best_cat]

        # Below the safe threshold the text is treated as entirely safe.
        threshold = self._config.get("safe_threshold", 0.35)
        if best_score < threshold:
            return ClassificationResult(category="safe", score=0.0, stage=2)

        # Normalise to [0, 1].
        normalised = max(0.0, min(1.0, best_score))

        return ClassificationResult(
            category=best_cat,
            score=normalised,
            triggers=[f"embedding:{best_cat}:{normalised:.3f}"],
            stage=2,
        )