surrogateshield 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- surrogateshield/__init__.py +264 -0
- surrogateshield/_display.py +101 -0
- surrogateshield/_response_parser.py +50 -0
- surrogateshield/_state.py +65 -0
- surrogateshield/core/__init__.py +0 -0
- surrogateshield/core/detection/__init__.py +0 -0
- surrogateshield/core/detection/context_guard.py +178 -0
- surrogateshield/core/detection/entity_trace.py +163 -0
- surrogateshield/core/detection/geo_data.py +102 -0
- surrogateshield/core/detection/pattern_scan.py +358 -0
- surrogateshield/core/detection/pipeline.py +453 -0
- surrogateshield/core/detection/quasi_identifier.py +184 -0
- surrogateshield/core/detection/service_query.py +187 -0
- surrogateshield/core/entities.py +32 -0
- surrogateshield/core/generation/__init__.py +0 -0
- surrogateshield/core/generation/mimic.py +201 -0
- surrogateshield/core/reconstruction/__init__.py +0 -0
- surrogateshield/core/reconstruction/resolve.py +142 -0
- surrogateshield/core/storage/__init__.py +0 -0
- surrogateshield/core/storage/shadow_map.py +158 -0
- surrogateshield-0.1.0.dist-info/METADATA +602 -0
- surrogateshield-0.1.0.dist-info/RECORD +24 -0
- surrogateshield-0.1.0.dist-info/WHEEL +5 -0
- surrogateshield-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
"""
|
|
2
|
+
SurrogateShield — Privacy-preserving PII proxy for LLMs.
|
|
3
|
+
|
|
4
|
+
Intercepts text before it reaches any LLM, replaces all PII with realistic
|
|
5
|
+
fake surrogates, and restores the real values in the LLM response.
|
|
6
|
+
|
|
7
|
+
Public API
|
|
8
|
+
──────────
|
|
9
|
+
import surrogateshield as shield
|
|
10
|
+
|
|
11
|
+
shield.config(pii_off=["phone", "location"])
|
|
12
|
+
sanitized = shield.mask(user_text)
|
|
13
|
+
response = llm.chat(sanitized)
|
|
14
|
+
restored = shield.unmask(response)
|
|
15
|
+
shield.flush()
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
import os
|
|
21
|
+
from typing import Dict, List, Optional
|
|
22
|
+
|
|
23
|
+
from ._state import cfg, session
|
|
24
|
+
from . import _display, _response_parser
|
|
25
|
+
from .core.detection import pipeline as _pipeline
|
|
26
|
+
from .core.detection import service_query as _service_query
|
|
27
|
+
from .core.reconstruction.resolve import ResolvePass as _ResolvePass
|
|
28
|
+
|
|
29
|
+
__version__ = "0.1.0"
|
|
30
|
+
__all__ = ["config", "scan", "pii_finder", "mask", "unmask", "flush"]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
34
|
+
# config()
|
|
35
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
36
|
+
|
|
37
|
+
def config(
    detailed_view: bool = True,
    pii_mem: str = "temp",
    pii_off=None,
    service: bool = True,
    spacy_model: str = "en_core_web_lg",
    context_guard_enabled: bool = True,
    entity_trace_high_threshold: float = 0.85,
    entity_trace_low_threshold: float = 0.60,
    context_guard_threshold: float = 0.70,
    entity_trace_fallback_threshold: float = 0.65,
    fuzzy_threshold: int = 85,
) -> None:
    """
    Configure SurrogateShield.

    Every argument is copied onto the shared ``cfg`` singleton; arguments
    that are omitted are reset to the defaults shown in the signature.

    Args:
        detailed_view:  Print detection/masking tables to stdout.
        pii_mem:        "temp" for in-memory session (default), or
                        a directory path for encrypted persistent storage.
        pii_off:        PII types to detect but NOT replace. May be an
                        iterable of names or a single name given as a string.
                        Accepts type names or aliases:
                        "phone", "name", "location", "org", "email",
                        "ssn", "dob", "address", "zip", "postcode",
                        "credit_card", "ip_address", "api_key",
                        "crypto", "bank", "license", "gender_indicator".
        service:        Enable service-query detection (address fuzzing
                        instead of full replacement for map queries).
        spacy_model:    spaCy model name for named entity recognition.
        context_guard_enabled:          Enable the HuggingFace NER second-pass.
        entity_trace_high_threshold:    spaCy score ≥ this → confirmed entity.
        entity_trace_low_threshold:     spaCy score ≥ this → borderline entity.
        context_guard_threshold:        ContextGuard score ≥ this → confirmed.
        entity_trace_fallback_threshold: Promotion threshold when ContextGuard is off.
        fuzzy_threshold: rapidfuzz partial_ratio threshold for unmask().

    Raises:
        ValueError: If pii_mem is not "temp" and the path does not exist or is
                    not a directory.
    """
    if pii_off is None:
        pii_off = []
    elif isinstance(pii_off, str):
        # list("phone") would yield ['p', 'h', 'o', 'n', 'e']; treat a bare
        # string as a single type name instead.
        pii_off = [pii_off]

    # Validate before touching cfg so a bad path leaves the config untouched.
    if pii_mem != "temp" and not os.path.isdir(pii_mem):
        raise ValueError(
            f"pii_mem path does not exist or is not a directory: {pii_mem!r}"
        )

    cfg.detailed_view = detailed_view
    cfg.pii_mem = pii_mem
    cfg.pii_off = list(pii_off)  # defensive copy of the caller's iterable
    cfg.service = service
    cfg.spacy_model = spacy_model
    cfg.context_guard_enabled = context_guard_enabled
    cfg.entity_trace_high_threshold = entity_trace_high_threshold
    cfg.entity_trace_low_threshold = entity_trace_low_threshold
    cfg.context_guard_threshold = context_guard_threshold
    cfg.entity_trace_fallback_threshold = entity_trace_fallback_threshold
    cfg.fuzzy_threshold = fuzzy_threshold
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
100
|
+
# scan() / pii_finder
|
|
101
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
102
|
+
|
|
103
|
+
def scan(text: str) -> Dict[str, str]:
    """
    Report every PII entity found in *text*, leaving the text untouched.

    The complete detection cascade (PatternScan → EntityTrace → ContextGuard)
    is executed with pii_off disabled, so entity types the user has switched
    off are still reported. The session shadow map is not written to.

    Args:
        text: Any string to scan for PII.

    Returns:
        Dict mapping detected_value → pii_type_string.
        Example: {"john@example.com": "email", "John Smith": "PERSON"}
    """
    cascade_options = dict(
        text=text,
        skip_values=None,
        skip_location_entities=False,
        pii_off=None,  # scan is always comprehensive
        spacy_model=cfg.spacy_model,
        context_guard_enabled=cfg.context_guard_enabled,
        entity_trace_high_threshold=cfg.entity_trace_high_threshold,
        entity_trace_low_threshold=cfg.entity_trace_low_threshold,
        context_guard_threshold=cfg.context_guard_threshold,
        entity_trace_fallback_threshold=cfg.entity_trace_fallback_threshold,
    )
    # Only the confirmed entities are used; the second result is discarded.
    detected, _unused = _pipeline.run_cascade(**cascade_options)

    if cfg.detailed_view:
        # pii_off is passed for display only, so skipped types can be flagged.
        _display.show_scan_results(detected, cfg.pii_off)

    findings: Dict[str, str] = {}
    for entity in detected:
        # A later entity with the same text overwrites the earlier type.
        findings[entity.text] = entity.type
    return findings


# Alias
pii_finder = scan
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
142
|
+
# mask()
|
|
143
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
144
|
+
|
|
145
|
+
def mask(text: str) -> str:
    """
    Replace all PII in *text* with realistic fake surrogates.

    The original→surrogate mapping is stored in the session shadow map so
    that unmask() can restore the real values from the LLM response.
    Originals already present in the shadow map reuse their existing
    surrogate; only unseen originals get a newly generated one.

    Args:
        text: The text to sanitize before sending to an LLM.

    Returns:
        Sanitized text with PII replaced by surrogates. When a service query
        is detected, the returned text also carries the fuzzed addresses.
    """
    # Local import: this module's header does not import ``re``.
    import re

    skip_values = None
    skip_location_entities = False

    # Service-query detection: fuzz addresses, suppress location entities
    if cfg.service and _service_query.is_service_query(text):
        text, fuzz_map = _service_query.fuzz_addresses(text, verify=True)
        skip_location_entities = True
        # The fuzzed replacements must not be detected and masked again.
        skip_values = set(fuzz_map.values())

    # Run detection cascade
    confirmed, _ = _pipeline.run_cascade(
        text=text,
        skip_values=skip_values,
        skip_location_entities=skip_location_entities,
        pii_off=cfg.pii_off,
        spacy_model=cfg.spacy_model,
        context_guard_enabled=cfg.context_guard_enabled,
        entity_trace_high_threshold=cfg.entity_trace_high_threshold,
        entity_trace_low_threshold=cfg.entity_trace_low_threshold,
        context_guard_threshold=cfg.context_guard_threshold,
        entity_trace_fallback_threshold=cfg.entity_trace_fallback_threshold,
    )

    # Deduplicate
    confirmed = _pipeline.deduplicate(confirmed)

    if not confirmed:
        if cfg.detailed_view:
            _display.show_mask_results([], {})
        return text

    # Reuse surrogates for originals already seen in this session
    existing_shadow = session.get_shadow_map().get_all()  # {surrogate: original}
    original_to_surrogate = {v: k for k, v in existing_shadow.items()}

    already_mapped = [e for e in confirmed if e.text in original_to_surrogate]
    new_entities = [e for e in confirmed if e.text not in original_to_surrogate]

    surrogate_map = {e.text: original_to_surrogate[e.text] for e in already_mapped}
    if new_entities:
        surrogate_map.update(session.get_mimic().generate_all(new_entities))

    # Apply all substitutions in ONE pass over the text. Chained
    # str.replace() calls would also rewrite text inside a surrogate that an
    # earlier replacement had just inserted. Alternatives are ordered longest
    # first so that, at any position, the longest original wins over its own
    # substrings. Empty originals are skipped: they would match everywhere.
    originals = sorted((o for o in surrogate_map if o), key=len, reverse=True)
    if originals:
        pattern = re.compile("|".join(re.escape(o) for o in originals))
        sanitized = pattern.sub(lambda m: surrogate_map[m.group(0)], text)
    else:
        sanitized = text

    # Store inverted map (surrogate → original) in session shadow map
    inverted = {v: k for k, v in surrogate_map.items()}
    session.get_shadow_map().update(inverted)

    if cfg.detailed_view:
        _display.show_mask_results(confirmed, surrogate_map)

    return sanitized
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
216
|
+
# unmask()
|
|
217
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
218
|
+
|
|
219
|
+
def unmask(response) -> str:
    """
    Restore original PII values in the LLM *response*.

    Extracts text from any major LLM SDK response object (Anthropic, OpenAI,
    Gemini) or accepts a plain string, then replaces surrogates with the
    originals stored in the session shadow map.

    Args:
        response: An LLM SDK response object or a plain string.

    Returns:
        Response text with surrogates replaced by the original PII values.
    """
    text = _response_parser.extract_text(response)
    shadow_map = session.get_shadow_map().get_all()  # {surrogate: original}

    restored = _ResolvePass().resolve(
        response_text=text,
        shadow_map=shadow_map,
        fuzzy_threshold=cfg.fuzzy_threshold,
    )

    if cfg.detailed_view:
        # Count surrogates that were present verbatim in the response and are
        # gone after restoration. Reporting len(shadow_map) would claim every
        # stored mapping was restored, even ones the response never mentioned.
        # NOTE(review): surrogates matched only approximately by the resolver
        # are not counted here — confirm whether they should be.
        replaced_count = sum(
            1
            for surrogate in shadow_map
            if surrogate in text and surrogate not in restored
        )
        _display.show_unmask_results(replaced_count)

    return restored
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
252
|
+
# flush()
|
|
253
|
+
# ─────────────────────────────────────────────────────────────────────────────
|
|
254
|
+
|
|
255
|
+
def flush() -> None:
    """
    End the current session.

    Resets the shared session object (its surrogate mappings are discarded
    and a new session id is issued), so that mappings from one conversation
    cannot bleed into the next. Prints a confirmation when detailed_view is
    enabled.
    """
    session.reset()
    if not cfg.detailed_view:
        return
    _display.show_flush()
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
"""
|
|
2
|
+
surrogateshield/_display.py — Output display helpers
|
|
3
|
+
|
|
4
|
+
Uses Rich tables when available; falls back to plain print otherwise.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
from rich.console import Console as _Console
|
|
11
|
+
from rich.table import Table as _Table
|
|
12
|
+
_console = _Console()
|
|
13
|
+
HAS_RICH = True
|
|
14
|
+
except ImportError:
|
|
15
|
+
HAS_RICH = False
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def show_scan_results(entities: list, pii_off: list) -> None:
    """Print a table of detected PII entities from scan().

    Args:
        entities: Detected entities; each must expose ``text``, ``type``,
            ``score`` and ``source``.
        pii_off: Type names the user disabled. Entities whose type matches
            one of them (case-insensitively) are flagged as skipped.
    """
    if not entities:
        print("[SurrogateShield] No PII detected.")
        return

    pii_off_lower = {t.lower() for t in (pii_off or [])}

    if HAS_RICH:
        # Local import: only reachable when rich itself imported successfully.
        from rich.markup import escape as _escape

        table = _Table(title="[bold cyan]SurrogateShield — Scan Results[/bold cyan]", show_lines=True)
        table.add_column("Detected Value", style="red bold")
        table.add_column("Type", style="yellow")
        table.add_column("Score", style="white")
        table.add_column("Source", style="dim")
        # Every row carries a fifth cell, so declare its column explicitly.
        table.add_column("Note", style="dim")
        for ent in entities:
            skipped = ent.type.lower() in pii_off_lower
            note = "[dim](skipped — pii_off)[/dim]" if skipped else ""
            table.add_row(
                # Detected text is arbitrary user input; escape it so square
                # brackets are shown literally instead of parsed as markup.
                _escape(ent.text),
                ent.type,
                f"{ent.score:.2f}",
                ent.source,
                note,
            )
        _console.print(table)
    else:
        print("\n[SurrogateShield] Scan Results")
        print(f"{'Detected Value':<30} {'Type':<20} {'Score':<8} {'Source':<10}")
        print("-" * 70)
        for ent in entities:
            skipped = ent.type.lower() in pii_off_lower
            note = "  (skipped — pii_off)" if skipped else ""
            print(f"{ent.text:<30} {ent.type:<20} {ent.score:<8.2f} {ent.source:<10}{note}")
        print()
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def show_mask_results(entities: list, surrogate_map: dict) -> None:
    """Print a table showing original PII and their surrogates.

    Args:
        entities: Detected entities; each must expose ``text``, ``type``,
            ``score`` and ``source``. Nothing is printed when empty.
        surrogate_map: Mapping of original text → surrogate. Entities with
            no entry are shown with a dash.
    """
    if not entities:
        return

    if HAS_RICH:
        # Local import: only reachable when rich itself imported successfully.
        from rich.markup import escape as _escape

        table = _Table(title="[bold cyan]SurrogateShield — Masked[/bold cyan]", show_lines=True)
        table.add_column("Original", style="red bold")
        table.add_column("Type", style="yellow")
        table.add_column("Score", style="white")
        table.add_column("Source", style="dim")
        table.add_column("Surrogate", style="green bold")
        for ent in entities:
            surrogate = surrogate_map.get(ent.text)
            # Only the placeholder is intentional markup; real values are
            # escaped so square brackets in them are displayed literally.
            surrogate_cell = "[dim]—[/dim]" if surrogate is None else _escape(str(surrogate))
            table.add_row(
                _escape(ent.text),
                ent.type,
                f"{ent.score:.2f}",
                ent.source,
                surrogate_cell,
            )
        _console.print(table)
    else:
        print("\n[SurrogateShield] Mask Results")
        print(f"{'Original':<30} {'Type':<20} {'Score':<8} {'Source':<10} {'Surrogate':<30}")
        print("-" * 100)
        for ent in entities:
            surrogate = surrogate_map.get(ent.text, "—")
            print(f"{ent.text:<30} {ent.type:<20} {ent.score:<8.2f} {ent.source:<10} {surrogate:<30}")
        print()
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def show_unmask_results(restored_count: int) -> None:
    """Report, on a single line, the number of surrogates restored.

    Printed in green through the Rich console when Rich is available,
    otherwise through plain print().
    """
    summary = f"[SurrogateShield] Restored {restored_count} surrogate(s)"
    if not HAS_RICH:
        print(summary)
        return
    _console.print(f"[green]{summary}[/green]")
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def show_flush() -> None:
    """Report, on a single line, that the session was cleared.

    Printed in yellow through the Rich console when Rich is available,
    otherwise through plain print().
    """
    notice = "[SurrogateShield] Session memory cleared"
    if not HAS_RICH:
        print(notice)
        return
    _console.print(f"[yellow]{notice}[/yellow]")
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
"""
|
|
2
|
+
surrogateshield/_response_parser.py — LLM response text extractor
|
|
3
|
+
|
|
4
|
+
Extracts plain text from any major LLM SDK response object.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def extract_text(response) -> str:
    """
    Extract the text content from an LLM response object.

    Tries in order:
        1. Anthropic style:  response.content is a list of blocks; the
                             ``text`` of every block that has one is joined
        2. OpenAI style:     response.choices[0].message.content
        3. Gemini style:     response.text (has .text but no .choices)
        4. Fallback:         str(response)

    Args:
        response: Any LLM SDK response, or a plain string.

    Returns:
        The extracted text as a string. An OpenAI-style message whose
        content is None yields "".
    """
    if isinstance(response, str):
        return response

    # Anthropic: response.content is a list of content blocks. The first
    # block is not necessarily text, so collect every block that carries a
    # string ``text`` instead of indexing [0].
    blocks = getattr(response, "content", None)
    if isinstance(blocks, list):
        texts = [
            block.text
            for block in blocks
            if isinstance(getattr(block, "text", None), str)
        ]
        if texts:
            return "".join(texts)

    # OpenAI: response.choices[0].message.content
    try:
        if hasattr(response, "choices"):
            content = response.choices[0].message.content
            # content may be None; honour the declared ``-> str`` contract.
            return "" if content is None else content
    except (AttributeError, IndexError, TypeError):
        # TypeError covers ``choices`` being None or otherwise not indexable.
        pass

    # Gemini: response.text (but no .choices attribute)
    try:
        if hasattr(response, "text") and not hasattr(response, "choices"):
            return response.text
    except AttributeError:
        pass

    return str(response)
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
"""
|
|
2
|
+
surrogateshield/_state.py — Module-level singletons
|
|
3
|
+
|
|
4
|
+
Holds cfg (Config) and session (Session) as module-level singletons
|
|
5
|
+
so all public API calls share the same state within a Python process.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import uuid
|
|
11
|
+
from dataclasses import dataclass, field
|
|
12
|
+
from typing import List, Optional
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class _Config:
    """Holds all library-wide configuration values.

    One instance is created at module level and shared by every public API
    call; config() overwrites these fields.
    """
    # Print detection/masking tables to stdout.
    detailed_view: bool = True
    # "temp" selects in-memory storage; any other value is used as a storage
    # directory path for the session shadow map.
    pii_mem: str = "temp"
    # PII type names that are detected but not replaced.
    pii_off: List[str] = field(default_factory=list)
    # Enable service-query detection (address fuzzing for map-style queries).
    service: bool = True
    # spaCy model name used for named entity recognition.
    spacy_model: str = "en_core_web_lg"
    # Enable the HuggingFace NER second pass (ContextGuard).
    context_guard_enabled: bool = True
    # spaCy score at or above this → confirmed entity.
    entity_trace_high_threshold: float = 0.85
    # spaCy score at or above this → borderline entity.
    entity_trace_low_threshold: float = 0.60
    # ContextGuard score at or above this → confirmed entity.
    context_guard_threshold: float = 0.70
    # Promotion threshold used when ContextGuard is disabled.
    entity_trace_fallback_threshold: float = 0.65
    # rapidfuzz partial_ratio threshold used by unmask().
    fuzzy_threshold: int = 85
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class _Session:
    """Per-session state: a session id plus a lazily built shadow map and mimic generator."""

    def __init__(self) -> None:
        self.id: str = str(uuid.uuid4())
        # Both collaborators are created on first use by the getters below.
        self._shadow_map = None
        self._mimic = None

    def reset(self) -> None:
        """Flush and drop the shadow map, drop the mimic generator, and issue a new id."""
        store = self._shadow_map
        if store is not None:
            store.flush()
        self._shadow_map = self._mimic = None
        self.id = str(uuid.uuid4())

    def get_mimic(self):
        """Return this session's MimicGen, building it on first access."""
        generator = self._mimic
        if generator is None:
            # Imported lazily, only when a generator is first needed.
            from .core.generation.mimic import MimicGen
            generator = self._mimic = MimicGen()
        return generator

    def get_shadow_map(self):
        """Return this session's ShadowMap, building it on first access."""
        if self._shadow_map is not None:
            return self._shadow_map

        from .core.storage.shadow_map import ShadowMap
        # "temp" means no storage directory; anything else is passed as the path.
        persist_dir = cfg.pii_mem if cfg.pii_mem != "temp" else None
        self._shadow_map = ShadowMap(self.id, storage_dir=persist_dir)
        return self._shadow_map
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# Module-level singletons
|
|
64
|
+
cfg = _Config()
|
|
65
|
+
session = _Session()
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,178 @@
|
|
|
1
|
+
"""
|
|
2
|
+
detection/context_guard.py — ContextGuard
|
|
3
|
+
|
|
4
|
+
NER-based detection of named entities using a local HuggingFace model
|
|
5
|
+
(dslim/distilbert-NER by default, ~250 MB).
|
|
6
|
+
|
|
7
|
+
This module detects named entities in the text that PatternScan and
|
|
8
|
+
EntityTrace missed. It does NOT decide whether a geographic entity is
|
|
9
|
+
PII — that decision is made in detection/pipeline.py.
|
|
10
|
+
|
|
11
|
+
Tokenization artefact handling:
|
|
12
|
+
distilbert word-piece tokenisation sometimes produces tokens like ". Sun"
|
|
13
|
+
or "##wick". Both are stripped before emitting entities.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import logging
|
|
19
|
+
import re
|
|
20
|
+
from typing import Dict, List, Tuple
|
|
21
|
+
|
|
22
|
+
from ..entities import DetectedEntity
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
# Cache pipelines keyed by model name
|
|
27
|
+
_ner_pipelines: Dict[str, object] = {}
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _get_ner(model_name: str = "dslim/distilbert-NER"):
    """Return the HuggingFace NER pipeline for *model_name*, loading it on first use.

    The outcome of the first attempt is cached per model name, including a
    failed one: if transformers is missing or the model cannot be loaded,
    None is cached and returned on every later call for that name.
    """
    try:
        return _ner_pipelines[model_name]
    except KeyError:
        pass  # not attempted yet — load below

    try:
        from transformers import pipeline as hf_pipeline

        loaded = hf_pipeline(
            "ner",
            model=model_name,
            aggregation_strategy="simple",
            device=-1,
        )
        logger.info(f"[ContextGuard] Loaded NER model: {model_name}")
    except ImportError:
        logger.warning(
            "[ContextGuard] transformers not installed — skipping. "
            "Run: pip install transformers torch"
        )
        loaded = None
    except Exception as exc:
        logger.warning(f"[ContextGuard] Failed to load NER model: {exc}")
        loaded = None

    _ner_pipelines[model_name] = loaded
    return loaded
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# Model tag → entity type emitted by this module. "PER" is normalised to
# "PERSON"; every other tag maps to itself.
_LABEL_MAP = {
    "PER": "PERSON",
    "PERSON": "PERSON",
    **{tag: tag for tag in ("ORG", "LOC", "GPE", "MISC")},
}

# Tags that produce entities: everything in the map except "MISC".
_KEEP_LABELS = set(_LABEL_MAP) - {"MISC"}

# Lower-cased tokens that are never emitted as entities on their own:
# honorifics and ranks, then name particles.
_CG_BLOCKLIST: frozenset = frozenset(
    "dr mr mrs ms prof professor rev sr jr "
    "sir lord dame capt lt sgt col gen "
    "de le la el al van von".split()
)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _clean_token(raw: str) -> str:
    """Strip HuggingFace word-piece artefacts and leading punctuation.

    Removes every "##" continuation marker, then any leading run of
    characters that are neither letters nor digits, then surrounding
    whitespace.

    Args:
        raw: The ``word`` string reported by the NER pipeline.

    Returns:
        The cleaned token; may be empty.
    """
    text = raw.replace("##", "")
    # [\W_] is Unicode-aware: it removes leading punctuation, whitespace and
    # underscores but keeps non-ASCII letters, which an ASCII-only class such
    # as [^A-Za-z0-9] would strip ("Émile" → "mile").
    text = re.sub(r"^[\W_]+", "", text)
    return text.strip()
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def guard(
    remaining_text: str,
    borderline_entities: List[DetectedEntity],
    model_name: str = "dslim/distilbert-NER",
    enabled: bool = True,
    confidence_threshold: float = 0.70,
) -> Tuple[List[DetectedEntity], List[DetectedEntity]]:
    """
    Run NER on remaining_text and verify borderline_entities.

    Borderline entities are always triaged against confidence_threshold.
    NER inference then adds further entities unless it is disabled, the text
    is blank, the model is unavailable, or inference fails — in each of those
    cases only the triaged borderline entities are returned.

    Args:
        remaining_text:       Text not covered by PatternScan / EntityTrace.
                              Masked spans are marked with "█".
        borderline_entities:  Entities EntityTrace was uncertain about.
        model_name:           HuggingFace model to use for NER inference.
        enabled:              If False, skip NER inference entirely.
        confidence_threshold: Minimum score for an entity (borderline or
                              NER-detected) to count as confirmed.

    Returns:
        Tuple of (confirmed_entities, uncertain_entities).
    """
    confirmed: List[DetectedEntity] = []
    uncertain: List[DetectedEntity] = []

    # Verify borderline entities from EntityTrace against the threshold
    for ent in borderline_entities:
        if ent.score >= confidence_threshold:
            confirmed.append(ent)
            logger.debug(
                f"[ContextGuard] Verified borderline: {ent.text!r} "
                f"({ent.type}, score={ent.score:.2f})"
            )
        else:
            uncertain.append(ent)

    if not enabled:
        return confirmed, uncertain

    # Blank out masked spans one character for one character and do NOT strip
    # the result: the start/end offsets the model reports are then valid
    # positions in remaining_text even when it begins with whitespace or a
    # masked span. Stripping is used only for the emptiness check.
    clean = remaining_text.replace("█", " ")
    if not clean.strip():
        return confirmed, uncertain

    ner = _get_ner(model_name)
    if ner is None:
        return confirmed, uncertain

    try:
        results = ner(clean)
    except Exception as exc:
        # Best effort: a failed inference must not discard earlier results.
        logger.warning(f"[ContextGuard] NER inference failed: {exc}")
        return confirmed, uncertain

    for r in results:
        label = r.get("entity_group", r.get("entity", ""))
        if label not in _KEEP_LABELS:
            continue

        entity_type = _LABEL_MAP.get(label, label)
        score = float(r.get("score", 0.0))

        raw_word = r.get("word", "")
        text = _clean_token(raw_word)

        if len(text) < 3:
            logger.debug(
                f"[ContextGuard] Skipping too-short token: {raw_word!r} → {text!r}"
            )
            continue

        if text.lower() in _CG_BLOCKLIST:
            logger.debug(f"[ContextGuard] Skipping blocklisted token: {text!r}")
            continue

        # NOTE(review): start/end describe the span the model reported; if
        # _clean_token removed leading characters, start still points at the
        # uncleaned span — confirm whether callers rely on exact offsets.
        entity = DetectedEntity(
            text=text,
            start=r.get("start", 0),
            end=r.get("end", len(text)),
            type=entity_type,
            score=score,
            source="slm",
        )

        if score >= confidence_threshold:
            confirmed.append(entity)
            logger.debug(
                f"[ContextGuard] Confirmed: {text!r} ({entity_type}, {score:.2f})"
            )
        else:
            uncertain.append(entity)
            logger.debug(
                f"[ContextGuard] Uncertain: {text!r} ({entity_type}, {score:.2f})"
            )

    logger.info(
        f"[ContextGuard] confirmed={len(confirmed)}, uncertain={len(uncertain)}"
    )
    return confirmed, uncertain
|