guardix 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- guardix/__init__.py +57 -0
- guardix/config.py +54 -0
- guardix/core.py +203 -0
- guardix/decorators.py +123 -0
- guardix/detectors/__init__.py +6 -0
- guardix/detectors/base.py +14 -0
- guardix/detectors/bert_detector.py +108 -0
- guardix/exceptions.py +19 -0
- guardix/logging_config.py +114 -0
- guardix/middleware.py +100 -0
- guardix/providers/__init__.py +15 -0
- guardix/providers/anthropic.py +65 -0
- guardix/providers/base.py +43 -0
- guardix/providers/gemini.py +67 -0
- guardix/providers/generic.py +73 -0
- guardix/providers/openai.py +87 -0
- guardix/responses.py +129 -0
- guardix-0.1.0.dist-info/METADATA +314 -0
- guardix-0.1.0.dist-info/RECORD +22 -0
- guardix-0.1.0.dist-info/WHEEL +5 -0
- guardix-0.1.0.dist-info/licenses/LICENSE +21 -0
- guardix-0.1.0.dist-info/top_level.txt +1 -0
guardix/__init__.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""guardix — Universal LLM prompt guard against injection attacks."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Optional
|
|
4
|
+
|
|
5
|
+
from .core import Guardial, Policy, Decision
|
|
6
|
+
from .exceptions import GuardBlocked, GuardError
|
|
7
|
+
from .config import Config
|
|
8
|
+
from .responses import is_blocked_response
|
|
9
|
+
|
|
10
|
+
__version__ = "0.1.0"

# Names exported by ``from guardix import *``.
__all__ = [
    # engine and verdict types
    "Guardial",
    "Policy",
    "Decision",
    # exceptions
    "GuardBlocked",
    "GuardError",
    # configuration and helpers
    "Config",
    "guard_client",
    "is_blocked_response",
]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def guard_client(client: Any, guardial: Optional[Guardial] = None, provider: Optional[str] = None) -> Any:
    """Return ``client`` wrapped in the matching guarded adapter.

    The adapter is picked by duck typing, checked in this order:

    1. ``client.messages.create`` is callable -> AnthropicAdapter
    2. ``client.models.generate_content`` is callable -> GeminiAdapter (google-genai)
    3. ``client.chat.completions.create`` is callable -> OpenAIAdapter, which also
       covers OpenAI-compatible clients (Azure OpenAI, Groq, OpenRouter, Together, ...)

    Args:
        client: SDK client instance to wrap.
        guardial: Engine handed to the adapter unchanged (may be None).
        provider: Provider name for the OpenAI-compatible adapter, used in its
            logs (e.g. "groq", "openrouter"); defaults to "openai". It is not
            forwarded to the Anthropic or Gemini adapters.

    Raises:
        TypeError: if the client exposes none of the three call shapes.

    Usage:
        from guardix import guard_client
        client = guard_client(OpenAI())
        client.chat.completions.create(...)  # guarded, never raises on block
    """
    # Imported lazily so that importing the package does not pull in every adapter.
    from .providers import AnthropicAdapter, GeminiAdapter, OpenAIAdapter

    def _exposes(owner: Any, method_name: str) -> bool:
        # True when ``owner`` exists and has a callable attribute ``method_name``.
        return owner is not None and callable(getattr(owner, method_name, None))

    if _exposes(getattr(client, "messages", None), "create"):
        return AnthropicAdapter(client, guardial=guardial)

    if _exposes(getattr(client, "models", None), "generate_content"):
        return GeminiAdapter(client, guardial=guardial)

    chat_namespace = getattr(client, "chat", None)
    completions = None if chat_namespace is None else getattr(chat_namespace, "completions", None)
    if _exposes(completions, "create"):
        return OpenAIAdapter(client, guardial=guardial, provider_name=provider or "openai")

    raise TypeError(
        "Unsupported client: expected an object with messages.create (Anthropic), "
        "models.generate_content (Gemini), or chat.completions.create (OpenAI-compatible)."
    )
|
guardix/config.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""Configuration for Guardial guard engine."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Callable, Dict, List, Optional
|
|
4
|
+
|
|
5
|
+
from .detectors.base import BaseDetector
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Config:
    """Settings for the Guardial guard engine.

    A named ``policy`` preset (see ``DEFAULT_POLICIES``) supplies defaults for
    ``threshold``, ``fail_mode`` and ``log_level``. Any of those passed
    explicitly (not None) overrides the preset. An unrecognised policy name
    falls back to the "standard" preset values, while ``self.policy`` still
    records the name that was given.
    """

    # Preset name -> default threshold / fail_mode / log_level.
    DEFAULT_POLICIES: Dict[str, Dict[str, Any]] = {
        "permissive": {"threshold": 0.9, "fail_mode": "open", "log_level": "INFO"},
        "standard": {"threshold": 0.7, "fail_mode": "open", "log_level": "INFO"},
        "strict": {"threshold": 0.5, "fail_mode": "closed", "log_level": "DEBUG"},
    }

    def __init__(
        self,
        policy: str = "standard",
        threshold: Optional[float] = None,
        fail_mode: Optional[str] = None,
        log_level: Optional[str] = None,
        log_sink: Optional[Callable[[Dict[str, Any]], None]] = None,
        log_file: Optional[str] = "logs/guardix.jsonl",
        custom_detectors: Optional[List[BaseDetector]] = None,
        mask_raw_prompt: bool = True,
        block_mode: str = "mock",
        block_message: Optional[str] = None,
    ) -> None:
        preset = self.DEFAULT_POLICIES.get(policy, self.DEFAULT_POLICIES["standard"])

        def resolve(explicit: Any, key: str) -> Any:
            # None means "not given": take the preset's value instead.
            return preset[key] if explicit is None else explicit

        self.policy = policy
        self.threshold = resolve(threshold, "threshold")
        self.fail_mode = resolve(fail_mode, "fail_mode")
        self.log_level = resolve(log_level, "log_level")
        self.log_sink = log_sink
        # Original author's note: when set, structured logs go to this file
        # (folder auto-created) and console output is suppressed.
        self.log_file = log_file
        self.custom_detectors = custom_detectors or []
        self.mask_raw_prompt = mask_raw_prompt
        # "mock": a blocked call gets a provider-shaped mock response so the
        # pipeline keeps running; "raise": a blocked call raises GuardBlocked.
        self.block_mode = block_mode
        self.block_message = block_message

    def to_dict(self) -> Dict[str, Any]:
        """Return the policy-related settings as a plain dict.

        The log sink, log file, custom detectors and block message are not included.
        """
        exported = ("policy", "threshold", "fail_mode", "log_level", "mask_raw_prompt", "block_mode")
        return {key: getattr(self, key) for key in exported}
|
guardix/core.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
"""Core Guardial engine, policies, and decisions."""
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
import uuid
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Any, Dict, List, Optional, Callable
|
|
7
|
+
|
|
8
|
+
from .config import Config
|
|
9
|
+
from .logging_config import StructuredLogger
|
|
10
|
+
from .exceptions import GuardBlocked, GuardError
|
|
11
|
+
from .detectors.base import BaseDetector
|
|
12
|
+
from .detectors.bert_detector import BertDetector
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class Decision:
    """Outcome of scanning a single prompt.

    ``decision`` is one of "ALLOW", "WARN" or "BLOCK". ``raw_prompt`` holds the
    scanned text only when the engine was told not to mask it. It is not part
    of :meth:`to_dict`.
    """

    prompt_id: str
    decision: str  # "ALLOW", "WARN", "BLOCK"
    scores: Dict[str, float] = field(default_factory=dict)  # detector name -> score
    threshold: float = 0.7
    reason: str = ""
    latency_ms: float = 0.0
    provider: str = "unknown"
    raw_prompt: Optional[str] = None
    class_name: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Return every field except ``raw_prompt`` as a plain dict."""
        exported = (
            "prompt_id",
            "decision",
            "scores",
            "threshold",
            "reason",
            "latency_ms",
            "provider",
            "class_name",
        )
        return {name: getattr(self, name) for name in exported}
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class Policy:
    """Turn a detector score into an ALLOW / WARN / BLOCK verdict.

    Scores at or above ``threshold`` block. Scores at or above
    ``warn_threshold`` (default: 85% of ``threshold``) warn. Anything lower is
    allowed.
    """

    def __init__(self, threshold: float = 0.7, warn_threshold: Optional[float] = None) -> None:
        self.threshold = threshold
        if warn_threshold is None:
            warn_threshold = threshold * 0.85
        self.warn_threshold = warn_threshold

    def evaluate(self, max_score: float) -> str:
        """Return the verdict label for ``max_score``; the BLOCK cut-off is checked first."""
        for cutoff, label in ((self.threshold, "BLOCK"), (self.warn_threshold, "WARN")):
            if max_score >= cutoff:
                return label
        return "ALLOW"
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class Guardial:
    """Guard engine powered by fine-tuned BERT-mini.

    Scans prompts with the built-in BertDetector plus any configured custom
    detectors. The highest score is turned into an ALLOW/WARN/BLOCK verdict
    via Policy, and every scan is recorded through StructuredLogger.
    """

    def __init__(
        self,
        policy: Optional[str] = None,
        threshold: Optional[float] = None,
        fail_mode: Optional[str] = None,
        log_level: Optional[str] = None,
        log_sink: Optional[Callable[[Dict[str, Any]], None]] = None,
        log_file: Optional[str] = "logs/guardix.jsonl",
        custom_detectors: Optional[List[BaseDetector]] = None,
        mask_raw_prompt: bool = True,
        block_mode: str = "mock",
        block_message: Optional[str] = None,
        config: Optional[Config] = None,
    ) -> None:
        """Build the engine.

        Args:
            config: A ready-made Config. When given, all other keyword
                arguments are ignored.
            Other arguments: forwarded to Config when ``config`` is None;
                ``policy`` defaults to "standard".
        """
        if config is not None:
            self.config = config
        else:
            self.config = Config(
                policy=policy or "standard",
                threshold=threshold,
                fail_mode=fail_mode,
                log_level=log_level,
                log_sink=log_sink,
                log_file=log_file,
                custom_detectors=custom_detectors,
                mask_raw_prompt=mask_raw_prompt,
                block_mode=block_mode,
                block_message=block_message,
            )

        self.policy = Policy(threshold=self.config.threshold)
        self.logger = StructuredLogger(
            level=self.config.log_level,
            sink=self.config.log_sink,
            # getattr tolerates config objects that have no ``log_file`` attribute.
            log_file=getattr(self.config, "log_file", None),
        )

        # The BERT detector always runs first; custom detectors are appended after it.
        self.detectors: List[BaseDetector] = [BertDetector()]
        if self.config.custom_detectors:
            self.detectors.extend(self.config.custom_detectors)

    def analyze(self, prompt: str, provider: str = "unknown") -> Decision:
        """Scan ``prompt`` with every detector and return a Decision.

        Each detector's score is stored under its ``name``. A detector that
        fails scores 0.0. The highest score is judged by ``self.policy``.

        Failure handling follows ``config.fail_mode``:

        * "closed": a detector or engine failure is logged once and raised as
          GuardError whose ``original_error`` is the underlying exception.
        * anything else: failures are logged and never raised. If the verdict
          was already computed and only a later step (decision logging)
          failed, that verdict is returned unchanged. If scanning itself
          failed, an ALLOW Decision with the error in ``reason`` is returned.

        Args:
            prompt: Text to scan.
            provider: Provider label recorded on the Decision and in logs.

        Returns:
            The Decision for this prompt.

        Raises:
            GuardError: only when ``config.fail_mode == "closed"``.
        """
        prompt_id = str(uuid.uuid4())
        start = time.perf_counter()
        scores: Dict[str, float] = {}
        class_name = ""
        # Holds the verdict once computed so a later failure cannot discard it.
        decision: Optional[Decision] = None

        try:
            for detector in self.detectors:
                try:
                    if hasattr(detector, "detect_and_classify"):
                        score, class_name, _ = detector.detect_and_classify(prompt)
                        scores[detector.name] = score
                    else:
                        score = detector.detect(prompt)
                        scores[detector.name] = score
                        if hasattr(detector, "classify"):
                            try:
                                cn, _ = detector.classify(prompt)
                                class_name = cn
                            except Exception:
                                # Class label is optional metadata; the score above still counts.
                                pass
                except Exception as e:
                    scores[detector.name] = 0.0
                    self.logger.log_error(prompt_id, provider, f"{detector.name} failed: {e}", 0.0)
                    if self.config.fail_mode == "closed":
                        raise GuardError(f"{detector.name} failed: {e}", original_error=e) from e

            max_score = max(scores.values()) if scores else 0.0
            decision_label = self.policy.evaluate(max_score)

            flagged = [
                f"{name}={value:.2f}"
                for name, value in scores.items()
                if value >= self.policy.threshold
            ]
            reason = f"Threshold exceeded by {', '.join(flagged)}" if flagged else "No detectors flagged"

            latency_ms = (time.perf_counter() - start) * 1000
            raw_prompt = prompt if not self.config.mask_raw_prompt else None

            decision = Decision(
                prompt_id=prompt_id,
                decision=decision_label,
                scores=scores,
                threshold=self.config.threshold,
                reason=reason,
                latency_ms=latency_ms,
                provider=provider,
                raw_prompt=raw_prompt,
                class_name=class_name,
            )

            self.logger.log_decision(
                prompt_id=prompt_id,
                provider=provider,
                detector_results=scores,
                decision=decision_label,
                reason=reason,
                latency_ms=latency_ms,
                raw_prompt=raw_prompt,
                level="DEBUG" if decision_label == "ALLOW" else "WARNING",
            )

            return decision

        except GuardError:
            # Raised by the per-detector handler above, which already logged the
            # failure and wrapped the real cause. Re-wrapping here would log it a
            # second time and hide the detector's exception behind another GuardError.
            raise
        except Exception as e:
            latency_ms = (time.perf_counter() - start) * 1000
            self.logger.log_error(prompt_id, provider, f"Guard engine failed: {e}", latency_ms)
            if self.config.fail_mode == "closed":
                raise GuardError(f"Guard engine failed: {e}", original_error=e) from e
            if decision is not None:
                # The scan finished; only decision logging failed. Falling through to
                # the fail-open ALLOW below would let a BLOCKed prompt through.
                return decision
            return Decision(
                prompt_id=prompt_id,
                decision="ALLOW",
                scores={},
                threshold=self.config.threshold,
                reason=f"Guard engine error (fail_open): {e}",
                latency_ms=latency_ms,
                provider=provider,
            )

    def guard(
        self,
        prompt: str,
        provider: str = "unknown",
        on_block: Optional[Callable[[Decision], Any]] = None,
    ) -> Decision:
        """Analyze the prompt and act on a BLOCK according to block_mode.

        - ``on_block`` callback given: call it with the Decision.
        - ``block_mode="raise"``: raise GuardBlocked (legacy behavior).
        - ``block_mode="mock"`` (default): return the Decision unraised so
          the caller (adapter/middleware/decorator) can substitute a
          provider-shaped mock response and keep the pipeline alive.

        Raises:
            GuardBlocked: on BLOCK when ``block_mode="raise"`` and no ``on_block``.
            GuardError: propagated from :meth:`analyze` in fail-closed mode.
        """
        decision = self.analyze(prompt, provider=provider)
        if decision.decision == "BLOCK":
            if on_block:
                on_block(decision)
            elif self.config.block_mode == "raise":
                self.logger.log_block_action(
                    decision.prompt_id, provider, "exception_raised", decision.reason
                )
                raise GuardBlocked(f"Prompt blocked: {decision.reason}", decision=decision)
        return decision
|
guardix/decorators.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
"""Decorators for easy guarding of LLM call functions."""
|
|
2
|
+
|
|
3
|
+
from functools import wraps
|
|
4
|
+
from typing import Any, Callable, Optional
|
|
5
|
+
|
|
6
|
+
from .core import Guardial, Decision
|
|
7
|
+
from .exceptions import GuardBlocked
|
|
8
|
+
from .responses import (
|
|
9
|
+
anthropic_blocked_response,
|
|
10
|
+
gemini_blocked_response,
|
|
11
|
+
openai_blocked_response,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
# response_format name -> builder of the provider-shaped mock response that is
# returned in place of a real completion when a prompt is blocked.
_RESPONSE_BUILDERS = dict(
    openai=openai_blocked_response,
    anthropic=anthropic_blocked_response,
    gemini=gemini_blocked_response,
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def guardial_guard(
    policy: Optional[str] = None,
    threshold: Optional[float] = None,
    fail_mode: Optional[str] = None,
    on_block: Optional[Callable[[Decision], Any]] = None,
    provider: str = "unknown",
    block_mode: str = "mock",
    response_format: str = "openai",
    block_message: Optional[str] = None,
) -> Callable[..., Any]:
    """Decorator that guards a function taking a prompt or messages argument.

    The text to scan is taken from the first of these that is present:

    * a ``messages=`` keyword argument,
    * a ``prompt=`` keyword argument (str or list),
    * the first positional argument, if it is a str or a list.

    If none is found, an empty prompt is scanned.

    By default a blocked prompt does NOT raise: the wrapped function is
    skipped and a provider-shaped mock response (``response_format``:
    "openai", "anthropic", or "gemini"; anything else falls back to
    "openai") is returned, so the pipeline keeps flowing. Pass
    ``block_mode="raise"`` for the old GuardBlocked behavior. When
    ``on_block`` is given it is called with the Decision, and the wrapped
    function still runs afterwards unless the callback raises.

    Usage:
        @guardial_guard(policy="strict")
        def chat(messages):
            return openai_client.chat.completions.create(model="gpt-4", messages=messages)
    """
    g = Guardial(
        policy=policy,
        threshold=threshold,
        fail_mode=fail_mode,
        block_mode=block_mode,
        block_message=block_message,
    )
    builder = _RESPONSE_BUILDERS.get(response_format, openai_blocked_response)

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            messages = kwargs.get("messages")
            if messages is None:
                # Without this check, fn(prompt="...") was scanned as an empty
                # prompt and always allowed through.
                keyword_prompt = kwargs.get("prompt")
                if isinstance(keyword_prompt, (str, list)):
                    messages = keyword_prompt
            if messages is None and args:
                first = args[0]
                if isinstance(first, str):
                    messages = [{"role": "user", "content": first}]
                elif isinstance(first, list):
                    messages = first
            prompt = _messages_to_prompt(messages)
            decision = g.guard(prompt, provider=provider, on_block=on_block)
            if decision.decision == "BLOCK" and on_block is None and block_mode == "mock":
                g.logger.log_block_action(
                    decision.prompt_id, provider, "mock_response", decision.reason
                )
                # The Gemini builder takes no ``model`` argument.
                if response_format == "gemini":
                    return builder(decision, message=block_message)
                return builder(decision, model=kwargs.get("model", "unknown"), message=block_message)
            return func(*args, **kwargs)

        return wrapper

    return decorator
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def guardial_audit(
    policy: Optional[str] = None,
    provider: str = "unknown",
    log_sink: Optional[Callable[[Any], None]] = None,
) -> Callable[..., Any]:
    """Decorator that only audits (never blocks) LLM calls.

    Every call is scanned and logged by a fail-open engine, then the wrapped
    function always runs. The text to scan is taken from the first of these
    that is present:

    * a ``messages=`` keyword argument,
    * a ``prompt=`` keyword argument (str or list),
    * the first positional argument, if it is a str or a list.

    Usage:
        @guardial_audit(policy="standard")
        def chat(messages):
            return openai_client.chat.completions.create(model="gpt-4", messages=messages)
    """
    g = Guardial(policy=policy, fail_mode="open", log_sink=log_sink)

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            messages = kwargs.get("messages")
            if messages is None:
                # A prompt passed as prompt="..." was previously not audited at all.
                keyword_prompt = kwargs.get("prompt")
                if isinstance(keyword_prompt, (str, list)):
                    messages = keyword_prompt
            if messages is None and args:
                first = args[0]
                if isinstance(first, str):
                    messages = [{"role": "user", "content": first}]
                elif isinstance(first, list):
                    messages = first
            prompt = _messages_to_prompt(messages)
            g.analyze(prompt, provider=provider)  # audit only, never blocks
            return func(*args, **kwargs)

        return wrapper

    return decorator
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _content_to_text(content: Any) -> str:
    """Flatten one message ``content`` value into plain text.

    Handles a plain string, ``None`` (e.g. an assistant turn that only carries
    tool calls), and lists of content parts such as
    ``{"type": "text", "text": ...}``. A part dict with no string ``text`` is
    flattened from its nested ``content`` (e.g. tool results), so text hidden
    there is still scanned. Parts with no text at all (images, audio, ...)
    contribute nothing. Any other value is converted with ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        text = content.get("text")
        if isinstance(text, str):
            return text
        return _content_to_text(content.get("content"))
    if isinstance(content, (list, tuple)):
        pieces = (_content_to_text(part) for part in content)
        return "\n".join(piece for piece in pieces if piece)
    return str(content)


def _messages_to_prompt(messages: Any) -> str:
    """Join the text of chat-style ``messages`` into one string for scanning.

    ``None`` gives "", and a plain string is returned unchanged. Otherwise
    ``messages`` is iterated: a dict message contributes the text of its
    ``content`` (missing, ``None``, or multi-part content is flattened), and
    any other item contributes ``str(item)``. The pieces are joined with
    newlines, one per message.
    """
    if messages is None:
        return ""
    if isinstance(messages, str):
        return messages
    parts = []
    for msg in messages:
        if isinstance(msg, dict):
            # Content may be None or a list of parts. Flattening ensures the join
            # below only ever sees strings instead of raising TypeError.
            parts.append(_content_to_text(msg.get("content")))
        else:
            parts.append(str(msg))
    return "\n".join(parts)
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""Base detector interface."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class BaseDetector(ABC):
    """Interface every prompt-injection detector implements.

    Subclasses set ``name`` (the key their score is reported under) and
    implement :meth:`detect`.
    """

    name: str = "base"

    @abstractmethod
    def detect(self, prompt: str) -> float:
        """Score ``prompt``, returning a confidence between 0.0 and 1.0."""
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
"""Single BERT-mini PyTorch detector — replaces all 16 old rule detectors."""
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
import threading
|
|
5
|
+
from typing import Dict, List, Tuple
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
9
|
+
|
|
10
|
+
from .base import BaseDetector
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Hugging Face model id that BertDetector loads by default (passed to from_pretrained).
MODEL_ID = "PraneshJs/promptgaurd"

# model_id -> (tokenizer, model, device, lock). Shared by every BertDetector in
# the process so the model is downloaded and loaded once rather than per instance.
_MODEL_CACHE: Dict[str, Tuple] = {}

# Splits after sentence-ending punctuation or a newline, consuming the whitespace that follows.
_SENTENCE_SPLIT = re.compile(r"(?<=[.!?\n])\s+")
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Serialises first-time loading: without it, concurrent first calls could all
# miss the cache and each download/load a separate copy of the same model.
_MODEL_LOAD_LOCK = threading.Lock()


def _load_model(model_id: str) -> Tuple:
    """Return the shared ``(tokenizer, model, device, lock)`` for ``model_id``.

    On first use the tokenizer and model are loaded, the model is moved to
    CUDA when available (else CPU) and put in eval mode, and the result is
    cached in ``_MODEL_CACHE``. If loading raises, nothing is cached, so a
    later call retries.
    """
    with _MODEL_LOAD_LOCK:
        cached = _MODEL_CACHE.get(model_id)
        if cached is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForSequenceClassification.from_pretrained(model_id).to(device)
            model.eval()
            # Fast tokenizers are not thread-safe ("Already borrowed"); the
            # shared cache means concurrent callers must serialize inference.
            cached = (tokenizer, model, device, threading.Lock())
            _MODEL_CACHE[model_id] = cached
        return cached
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class BertDetector(BaseDetector):
    """Prompt-injection detector backed by a fine-tuned BERT-mini (safe/attack) classifier.

    :meth:`detect` returns the attack probability in [0.0, 1.0].
    :meth:`classify` and :meth:`detect_and_classify` also report the predicted
    label from the model's ``id2label`` mapping.

    Long prompts cannot dodge detection through truncation. The prompt is
    tokenized into 128-token windows that overlap by ``stride`` tokens. When
    the prompt is longer than the window, each sentence is also scored on its
    own, so a short injection buried in benign text is not diluted. Every
    segment goes through the model in one batch, and the segment with the
    highest attack probability decides the result.
    """

    name = "bert_mini"

    def __init__(self, model_id: str = MODEL_ID):
        # Model, tokenizer and their lock come from the process-wide cache.
        shared = _load_model(model_id)
        self.tokenizer, self.model, self.device, self._lock = shared
        self.max_len, self.stride = 128, 64

    def _segments(self, prompt: str) -> List[str]:
        """Return the full prompt, plus its sentences when it exceeds ``max_len`` tokens.

        Short prompts are scored whole only: they cannot hide an injection via
        truncation, and scoring their sentences in isolation causes false
        positives (e.g. a bare "You are a helpful assistant." reads as role
        hijacking). Sentences shorter than 8 characters are dropped, and the
        sentence split is used only if it yields more than one sentence.
        """
        token_count = len(self.tokenizer(prompt, truncation=False)["input_ids"])
        if token_count <= self.max_len:
            return [prompt]
        stripped = [piece.strip() for piece in _SENTENCE_SPLIT.split(prompt)]
        sentences = [piece for piece in stripped if len(piece) >= 8]
        return [prompt, *sentences] if len(sentences) > 1 else [prompt]

    def _predict(self, prompt: str) -> Tuple[float, int, float]:
        """Score every window and sentence in a single batched forward pass.

        Returns ``(attack_prob, pred_id, confidence)`` for the segment with the
        highest attack probability (column 1 of the softmax output).
        """
        with self._lock:
            batch = self.tokenizer(
                self._segments(prompt),
                truncation=True,
                padding=True,
                max_length=self.max_len,
                stride=self.stride,
                return_overflowing_tokens=True,
                return_tensors="pt",
            )
            # Bookkeeping key from overflow handling; presumably removed because
            # the model's forward() would not accept it.
            batch.pop("overflow_to_sample_mapping", None)
            batch = batch.to(self.device)
            with torch.no_grad():
                probabilities = torch.softmax(self.model(**batch).logits, dim=-1)  # (num_segments, 2)
            riskiest = probabilities[int(torch.argmax(probabilities[:, 1]))]
            label_id = int(torch.argmax(riskiest))
            return float(riskiest[1]), label_id, float(riskiest[label_id])

    def detect(self, prompt: str) -> float:
        """Return the attack probability (0.0-1.0)."""
        attack_prob, _, _ = self.detect_and_classify(prompt)
        return attack_prob

    def classify(self, prompt: str) -> tuple:
        """Return ``(class_name, confidence)``."""
        return self.detect_and_classify(prompt)[1:]

    def detect_and_classify(self, prompt: str) -> Tuple[float, str, float]:
        """Return ``(attack_prob, class_name, confidence)`` from one inference.

        Empty or whitespace-only prompts skip the model and are reported as
        ``(0.0, "safe", 1.0)``.
        """
        if prompt and prompt.strip():
            attack_prob, pred_id, confidence = self._predict(prompt)
            return (attack_prob, self.model.config.id2label[pred_id], confidence)
        return (0.0, "safe", 1.0)
|
guardix/exceptions.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Exceptions raised by guardix."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class GuardError(Exception):
    """Signals that the guard itself failed while running in fail-closed mode.

    The exception that caused the failure, if any, is kept on ``original_error``.
    """

    def __init__(self, message: str, original_error: Optional[Exception] = None) -> None:
        super(GuardError, self).__init__(message)
        self.original_error = original_error
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class GuardBlocked(Exception):
    """Raised when a scan blocks a prompt. The verdict is available on ``decision``."""

    def __init__(self, message: str, decision: "Decision") -> None:
        super(GuardBlocked, self).__init__(message)
        # ``Decision`` is only a string annotation here; this module never imports it.
        self.decision = decision
|