llm-firewall 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_firewall-0.2.0/LICENSE +21 -0
- llm_firewall-0.2.0/PKG-INFO +41 -0
- llm_firewall-0.2.0/README.md +15 -0
- llm_firewall-0.2.0/llm_firewall.egg-info/PKG-INFO +41 -0
- llm_firewall-0.2.0/llm_firewall.egg-info/SOURCES.txt +23 -0
- llm_firewall-0.2.0/llm_firewall.egg-info/dependency_links.txt +1 -0
- llm_firewall-0.2.0/llm_firewall.egg-info/entry_points.txt +2 -0
- llm_firewall-0.2.0/llm_firewall.egg-info/requires.txt +13 -0
- llm_firewall-0.2.0/llm_firewall.egg-info/top_level.txt +1 -0
- llm_firewall-0.2.0/llm_shield/__init__.py +3 -0
- llm_firewall-0.2.0/llm_shield/adapters/__init__.py +86 -0
- llm_firewall-0.2.0/llm_shield/cli.py +40 -0
- llm_firewall-0.2.0/llm_shield/core.py +176 -0
- llm_firewall-0.2.0/llm_shield/dynamic.py +90 -0
- llm_firewall-0.2.0/llm_shield/logger.py +55 -0
- llm_firewall-0.2.0/llm_shield/profiles/__init__.py +48 -0
- llm_firewall-0.2.0/llm_shield/utils/__init__.py +0 -0
- llm_firewall-0.2.0/pyproject.toml +34 -0
- llm_firewall-0.2.0/setup.cfg +4 -0
- llm_firewall-0.2.0/tests/test_adapters.py +70 -0
- llm_firewall-0.2.0/tests/test_core.py +71 -0
- llm_firewall-0.2.0/tests/test_dynamic.py +34 -0
- llm_firewall-0.2.0/tests/test_dynamic_firewall.py +162 -0
- llm_firewall-0.2.0/tests/test_logger.py +32 -0
- llm_firewall-0.2.0/tests/test_profiles.py +80 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Parshva Shah
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: llm-firewall
|
|
3
|
+
Version: 0.2.0
|
|
4
|
+
Summary: A lightweight, domain-aware safety firewall for LLM applications — works with any API or local model, with dynamic policy-document support.
|
|
5
|
+
Author-email: Parshva Shah <shahparshva2005@gmail.com>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Repository, https://github.com/Parshva2605/LLM-Firewall
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
10
|
+
Classifier: Operating System :: OS Independent
|
|
11
|
+
Classifier: Topic :: Security
|
|
12
|
+
Classifier: Development Status :: 3 - Alpha
|
|
13
|
+
Requires-Python: >=3.9
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
License-File: LICENSE
|
|
16
|
+
Requires-Dist: PyYAML
|
|
17
|
+
Provides-Extra: openai
|
|
18
|
+
Requires-Dist: openai; extra == "openai"
|
|
19
|
+
Provides-Extra: groq
|
|
20
|
+
Requires-Dist: groq; extra == "groq"
|
|
21
|
+
Provides-Extra: dev
|
|
22
|
+
Requires-Dist: pytest; extra == "dev"
|
|
23
|
+
Provides-Extra: dynamic
|
|
24
|
+
Requires-Dist: sentence-transformers; extra == "dynamic"
|
|
25
|
+
Dynamic: license-file
|
|
26
|
+
|
|
27
|
+
# LLM-Shield
|
|
28
|
+
|
|
29
|
+
A lightweight, domain-aware safety firewall for LLM applications — works with any API or local model.
|
|
30
|
+
|
|
31
|
+
Install with:
|
|
32
|
+
|
|
33
|
+
```bash
|
|
34
|
+
pip install llm-firewall
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
Import with:
|
|
38
|
+
|
|
39
|
+
```python
|
|
40
|
+
from llm_shield import Firewall
|
|
41
|
+
```
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# LLM-Shield
|
|
2
|
+
|
|
3
|
+
A lightweight, domain-aware safety firewall for LLM applications — works with any API or local model.
|
|
4
|
+
|
|
5
|
+
Install with:
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
pip install llm-firewall
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
Import with:
|
|
12
|
+
|
|
13
|
+
```python
|
|
14
|
+
from llm_shield import Firewall
|
|
15
|
+
```
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: llm-firewall
|
|
3
|
+
Version: 0.2.0
|
|
4
|
+
Summary: A lightweight, domain-aware safety firewall for LLM applications — works with any API or local model, with dynamic policy-document support.
|
|
5
|
+
Author-email: Parshva Shah <shahparshva2005@gmail.com>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Repository, https://github.com/Parshva2605/LLM-Firewall
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
10
|
+
Classifier: Operating System :: OS Independent
|
|
11
|
+
Classifier: Topic :: Security
|
|
12
|
+
Classifier: Development Status :: 3 - Alpha
|
|
13
|
+
Requires-Python: >=3.9
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
License-File: LICENSE
|
|
16
|
+
Requires-Dist: PyYAML
|
|
17
|
+
Provides-Extra: openai
|
|
18
|
+
Requires-Dist: openai; extra == "openai"
|
|
19
|
+
Provides-Extra: groq
|
|
20
|
+
Requires-Dist: groq; extra == "groq"
|
|
21
|
+
Provides-Extra: dev
|
|
22
|
+
Requires-Dist: pytest; extra == "dev"
|
|
23
|
+
Provides-Extra: dynamic
|
|
24
|
+
Requires-Dist: sentence-transformers; extra == "dynamic"
|
|
25
|
+
Dynamic: license-file
|
|
26
|
+
|
|
27
|
+
# LLM-Shield
|
|
28
|
+
|
|
29
|
+
A lightweight, domain-aware safety firewall for LLM applications — works with any API or local model.
|
|
30
|
+
|
|
31
|
+
Install with:
|
|
32
|
+
|
|
33
|
+
```bash
|
|
34
|
+
pip install llm-firewall
|
|
35
|
+
```
|
|
36
|
+
|
|
37
|
+
Import with:
|
|
38
|
+
|
|
39
|
+
```python
|
|
40
|
+
from llm_shield import Firewall
|
|
41
|
+
```
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
llm_firewall.egg-info/PKG-INFO
|
|
5
|
+
llm_firewall.egg-info/SOURCES.txt
|
|
6
|
+
llm_firewall.egg-info/dependency_links.txt
|
|
7
|
+
llm_firewall.egg-info/entry_points.txt
|
|
8
|
+
llm_firewall.egg-info/requires.txt
|
|
9
|
+
llm_firewall.egg-info/top_level.txt
|
|
10
|
+
llm_shield/__init__.py
|
|
11
|
+
llm_shield/cli.py
|
|
12
|
+
llm_shield/core.py
|
|
13
|
+
llm_shield/dynamic.py
|
|
14
|
+
llm_shield/logger.py
|
|
15
|
+
llm_shield/adapters/__init__.py
|
|
16
|
+
llm_shield/profiles/__init__.py
|
|
17
|
+
llm_shield/utils/__init__.py
|
|
18
|
+
tests/test_adapters.py
|
|
19
|
+
tests/test_core.py
|
|
20
|
+
tests/test_dynamic.py
|
|
21
|
+
tests/test_dynamic_firewall.py
|
|
22
|
+
tests/test_logger.py
|
|
23
|
+
tests/test_profiles.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
llm_shield
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
from ..core import Firewall
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def detect_client_type(client) -> str:
    """Guess which provider SDK *client* belongs to from its class's module path.

    The module name is lower-cased and searched for "openai", "groq" and "ollama" in
    that order; the first provider name found is returned, otherwise "unknown".
    """
    owning_module = type(client).__module__.lower()
    for provider in ("openai", "groq", "ollama"):
        if provider in owning_module:
            return provider
    return "unknown"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class FirewallProtectedClient:
    """Wrap an LLM client so input is filtered, the model is called, then output is filtered.

    The wrapper keeps the underlying client transparent while routing requests through a Firewall
    instance before and after the LLM call.

    ``chat`` always returns a dict with the keys ``blocked``, ``stage`` (``"input"``,
    ``"output"``, ``"error"`` or ``None``), ``reason``, ``category`` and ``response``
    (the model's reply text when the exchange is allowed, otherwise ``None``).
    """

    # Model used for each supported client type when chat() is called without model=.
    _DEFAULT_MODELS = {
        "openai": "gpt-4o-mini",
        "groq": "llama-3.1-8b-instant",
        "ollama": "llama3.2",
    }

    def __init__(self, client, firewall):
        """Store *client* and *firewall* and detect which provider API *client* speaks."""
        self.client = client
        self.firewall = firewall
        self.client_type = detect_client_type(client)

    def chat(self, message: str, **kwargs) -> dict:
        """Send *message* through the firewall and the wrapped client.

        Only ``model`` is read from *kwargs*; other keyword arguments are ignored.
        Exceptions raised while calling the client (including an unsupported client
        type) are returned as a ``stage="error"`` result instead of propagating.
        """
        input_result = self.firewall.check_input(message)
        if input_result.blocked:
            return self._blocked_response("input", input_result)

        try:
            response_text = self._call_model(message, kwargs)
        except Exception as exception:
            return {
                "blocked": False,
                "stage": "error",
                "reason": str(exception),
                "category": "error",
                "response": None,
            }

        # OpenAI-compatible SDKs can return message.content=None (e.g. a reply that only
        # contains tool calls). The firewall's output checks run regexes over a str, so
        # check and return "" instead of crashing outside the error handler above.
        if response_text is None:
            response_text = ""

        output_result = self.firewall.check_output(response_text)
        if output_result.blocked:
            return self._blocked_response("output", output_result)

        return {
            "blocked": False,
            "stage": None,
            "reason": "",
            "category": "clean",
            "response": response_text,
        }

    def _call_model(self, message: str, kwargs: dict):
        """Call the wrapped client with *message* and return the reply text (may be None)."""
        messages = [{"role": "user", "content": message}]
        if self.client_type in ("openai", "groq"):
            # Both SDKs expose the same chat.completions.create interface.
            response = self.client.chat.completions.create(
                model=kwargs.get("model", self._DEFAULT_MODELS[self.client_type]),
                messages=messages,
            )
            return response.choices[0].message.content
        if self.client_type == "ollama":
            response = self.client.chat(
                model=kwargs.get("model", self._DEFAULT_MODELS["ollama"]),
                messages=messages,
            )
            return response["message"]["content"]
        raise NotImplementedError(
            f"Unsupported client type '{self.client_type}'. Supported client types are: openai, groq, ollama."
        )

    @staticmethod
    def _blocked_response(stage: str, result) -> dict:
        """Build the chat() result for a check that blocked at *stage*."""
        return {
            "blocked": True,
            "stage": stage,
            "reason": result.reason,
            "category": result.category,
            "response": None,
        }
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import json
|
|
3
|
+
|
|
4
|
+
from llm_shield.logger import AuditLogger
|
|
5
|
+
from llm_shield.profiles import list_profiles
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def main():
    """Run the ``llm-shield`` command-line interface.

    Subcommands:
        profiles  Print the names of the built-in profiles, one per line.
        logs      Print an audit-log summary followed by the most recent entries as JSON.
                  ``--log-file`` selects the log (default ``firewall_audit.log``) and
                  ``--limit`` how many recent entries are shown (default 10).

    Without a subcommand the usage help is printed.
    """
    parser = argparse.ArgumentParser(prog="llm-shield")
    subparsers = parser.add_subparsers(dest="command")

    subparsers.add_parser("profiles", help="List available firewall profiles")
    logs_parser = subparsers.add_parser("logs", help="Show audit log summary and recent entries")
    # Firewall(log_file=...) can write its audit log to any path, so the CLI must be able
    # to read a non-default file too. Defaults keep AuditLogger's default path and 10 entries.
    logs_parser.add_argument(
        "--log-file",
        default="firewall_audit.log",
        help="Audit log file to read (default: firewall_audit.log)",
    )
    logs_parser.add_argument(
        "--limit",
        type=int,
        default=10,
        help="Number of most recent entries to print (default: 10)",
    )

    args = parser.parse_args()

    if args.command == "profiles":
        for profile_name in list_profiles():
            print(profile_name)
        return

    if args.command == "logs":
        logger = AuditLogger(log_file=args.log_file)
        summary = logger.get_summary()
        print(f"total_requests: {summary['total_requests']}")
        print(f"blocked_count: {summary['blocked_count']}")
        print(f"allowed_count: {summary['allowed_count']}")
        print("by_category:")
        for category, count in summary["by_category"].items():
            print(f" {category}: {count}")
        print("recent_entries:")
        for entry in logger.read_logs(limit=args.limit):
            print(json.dumps(entry, indent=2))
        return

    parser.print_help()
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# Allow running ``python -m llm_shield.cli`` in addition to the ``llm-shield`` console script.
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
import re
|
|
3
|
+
|
|
4
|
+
from .logger import AuditLogger
|
|
5
|
+
from .profiles import load_profile, load_profile_from_path
|
|
6
|
+
|
|
7
|
+
try:
|
|
8
|
+
from llm_shield.dynamic import LocalEmbedder, SENTENCE_TRANSFORMERS_AVAILABLE
|
|
9
|
+
except ImportError:
|
|
10
|
+
LocalEmbedder = None
|
|
11
|
+
SENTENCE_TRANSFORMERS_AVAILABLE = False
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class FirewallResult:
    """Result of a firewall check.

    The category field may be one of: off_topic, prompt_injection, safety_violation,
    pii_leak, malicious_code, or clean.
    """
    # True when the checked text should be rejected.
    blocked: bool
    # Human-readable explanation of the decision; Firewall uses "" for clean results.
    reason: str
    # One of the labels listed in the class docstring.
    category: str = "clean"
    # Firewall sets 1.0 for rule/regex matches and derives it from 1 - similarity score
    # for policy-document (off_topic) blocks.
    confidence: float = 1.0
    # The blocked keyword that matched, or a detector name such as "email_pattern" or
    # "policy_document_similarity"; "" when nothing matched.
    matched_rule: str = ""
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class Firewall:
    """Rule-based firewall for basic input blocking and output leak detection.

    It checks user input against explicit blocked keyword patterns and checks model output for
    simple PII and secret-like leaks using regexes. It can also optionally use a policy document
    to reject messages that fall outside the document's semantic scope.

    Note: The policy_document semantic scope-checking is a best-effort probabilistic layer based
    on local embeddings. It works best when the policy document is detailed (a few sentences per
    topic, not single short phrases) and should be used alongside blocked_keywords for explicit
    or known-bad patterns rather than as a sole line of defense. Results may vary based on the
    specificity of the policy text and the similarity threshold.
    """

    # Trigger phrases used to label a matched blocked keyword. The groups are tested in this
    # order, so a keyword containing triggers from several groups gets the first group's label.
    _PROMPT_INJECTION_TRIGGERS = (
        "ignore", "disregard", "you are now", "act as", "new instructions", "system prompt", "forget your",
    )
    _SAFETY_VIOLATION_TRIGGERS = (
        "phone number", "address", "meet me", "keep this secret", "don't tell", "keep secret",
        "your location", "where do you live",
    )
    _MALICIOUS_CODE_TRIGGERS = (
        "write code", "python script", "javascript", "write a program", "sql injection", "exploit",
    )

    # Output leak detectors, compiled once for the class rather than on every check_output call.
    _EMAIL_PATTERN = re.compile(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}")
    _API_KEY_PATTERNS = (
        re.compile(r"sk-[A-Za-z0-9]{20,}"),
        # Keys with these prefixes typically continue with '-' and '_' (e.g. "sk-ant-api03-..."),
        # which a purely alphanumeric class stops at, so allow them after the specific prefix.
        re.compile(r"sk-ant-[A-Za-z0-9_-]{20,}"),
        re.compile(r"sk-proj-[A-Za-z0-9_-]{20,}"),
    )

    def __init__(self, allowed_topics: list[str] = None, blocked_keywords: list[str] = None, profile: str = None, profile_path: str = None, enable_logging: bool = False, log_file: str = "firewall_audit.log", policy_document: str = None, similarity_threshold: float = 0.12):
        """Configure the firewall from explicit arguments and/or a YAML profile.

        Args:
            allowed_topics: Stored as ``self.allowed_topics``. When a profile is given the
                profile's list is used instead. The checks in this class do not read it.
            blocked_keywords: Case-insensitive substrings that block input. With a profile,
                they are appended after the profile's keywords (duplicates removed).
            profile: Name of a built-in profile, loaded with ``load_profile``.
            profile_path: Path of a YAML profile file, loaded with ``load_profile_from_path``.
            enable_logging: When true, every check result is appended to ``log_file``.
            log_file: Audit log path used when ``enable_logging`` is true.
            policy_document: Text defining the allowed semantic scope. When not passed, the
                profile's ``policy_document`` (if any) is used.
            similarity_threshold: Texts whose best similarity to the policy document is below
                this value are blocked as off_topic.

        Raises:
            ValueError: If both ``profile`` and ``profile_path`` are given.
            ImportError: If a policy document is in effect but sentence-transformers is missing.
        """
        if profile is not None and profile_path is not None:
            raise ValueError("Provide either 'profile' or 'profile_path', not both")

        self.profile_data = None
        if profile is not None:
            self.profile_data = load_profile(profile)
        elif profile_path is not None:
            self.profile_data = load_profile_from_path(profile_path)

        if self.profile_data is not None:
            # A YAML key that is present with no value loads as None; treat it as an empty list.
            self.allowed_topics = self.profile_data.get("allowed_topics") or []
            profile_blocked_keywords = self.profile_data.get("blocked_keywords") or []
            # dict.fromkeys de-duplicates while keeping profile keywords before caller keywords.
            self.blocked_keywords = list(dict.fromkeys(profile_blocked_keywords + (blocked_keywords or [])))
            # If policy_document not explicitly passed but profile contains one, use it
            if policy_document is None and "policy_document" in self.profile_data:
                policy_document = self.profile_data.get("policy_document")
        else:
            self.allowed_topics = allowed_topics or []
            self.blocked_keywords = blocked_keywords or []

        self.logger = AuditLogger(log_file=log_file) if enable_logging else None
        self.similarity_threshold = similarity_threshold

        if policy_document is not None:
            if not SENTENCE_TRANSFORMERS_AVAILABLE:
                raise ImportError(
                    "sentence-transformers is required to use policy_document. Install with: pip install llm-firewall[dynamic]"
                )
            self.embedder = LocalEmbedder()
            self.embedder.index_document(policy_document)
        else:
            self.embedder = None

    def check_input(self, text: str) -> FirewallResult:
        """Check input text against blocked keyword patterns, then the policy document.

        Keywords are matched as case-insensitive substrings in list order; the first match wins
        and is labeled prompt_injection, safety_violation, malicious_code, or off_topic by
        ``_classify_keyword``. Blank keywords are skipped because an empty substring would match
        every input. If no keyword matches and a policy document is configured, input whose best
        similarity is below ``similarity_threshold`` is blocked as off_topic. The result is
        logged when logging is enabled.
        """
        lowered_text = text.lower()
        for keyword in self.blocked_keywords:
            lowered_keyword = keyword.lower()
            if not lowered_keyword.strip():
                continue
            if lowered_keyword in lowered_text:
                result = FirewallResult(
                    blocked=True,
                    reason=f"Input matched blocked pattern: '{keyword}'",
                    category=self._classify_keyword(lowered_keyword),
                    confidence=1.0,
                    matched_rule=keyword,
                )
                return self._record(result, input_text=text)

        if self.embedder is not None:
            score = self.embedder.max_similarity_score(text)
            if score < self.similarity_threshold:
                result = FirewallResult(
                    blocked=True,
                    reason=f"Message does not appear related to the provided policy document (similarity score: {score:.2f}, threshold: {self.similarity_threshold})",
                    category="off_topic",
                    confidence=self._off_topic_confidence(score),
                    matched_rule="policy_document_similarity",
                )
                return self._record(result, input_text=text)

        result = FirewallResult(blocked=False, reason="", category="clean", confidence=1.0, matched_rule="")
        return self._record(result, input_text=text)

    def check_output(self, text: str) -> FirewallResult:
        """Check output text for obvious PII or secret-like leaks.

        Email addresses are checked first, then API-key-like tokens (``sk-``, ``sk-ant-``,
        ``sk-proj-`` prefixes); either is blocked as pii_leak. If a policy document is
        configured, output whose best similarity is below ``similarity_threshold`` is blocked
        as off_topic. The result is logged when logging is enabled.
        """
        if self._EMAIL_PATTERN.search(text):
            result = FirewallResult(
                blocked=True,
                reason="Output contains a detected email address leak.",
                category="pii_leak",
                confidence=1.0,
                matched_rule="email_pattern",
            )
            return self._record(result, output_text=text)

        if any(pattern.search(text) for pattern in self._API_KEY_PATTERNS):
            result = FirewallResult(
                blocked=True,
                reason="Output contains a detected API key leak.",
                category="pii_leak",
                confidence=1.0,
                matched_rule="api_key_pattern",
            )
            return self._record(result, output_text=text)

        if self.embedder is not None:
            score = self.embedder.max_similarity_score(text)
            if score < self.similarity_threshold:
                result = FirewallResult(
                    blocked=True,
                    reason=f"Response does not appear related to the provided policy document (similarity score: {score:.2f}, threshold: {self.similarity_threshold})",
                    category="off_topic",
                    confidence=self._off_topic_confidence(score),
                    matched_rule="policy_document_similarity_output",
                )
                return self._record(result, output_text=text)

        result = FirewallResult(blocked=False, reason="", category="clean", confidence=1.0, matched_rule="")
        return self._record(result, output_text=text)

    @classmethod
    def _classify_keyword(cls, lowered_keyword: str) -> str:
        """Return the category label for a matched, lower-cased blocked keyword."""
        if any(trigger in lowered_keyword for trigger in cls._PROMPT_INJECTION_TRIGGERS):
            return "prompt_injection"
        if any(trigger in lowered_keyword for trigger in cls._SAFETY_VIOLATION_TRIGGERS):
            return "safety_violation"
        if any(trigger in lowered_keyword for trigger in cls._MALICIOUS_CODE_TRIGGERS):
            return "malicious_code"
        return "off_topic"

    @staticmethod
    def _off_topic_confidence(score: float) -> float:
        """Map a policy similarity score to a confidence in [0, 1] for an off_topic block."""
        # Cosine similarity can be negative, which would push 1 - score above 1.0.
        return max(0.0, min(1.0, round(1.0 - score, 2)))

    def _record(self, result: FirewallResult, **texts) -> FirewallResult:
        """Log *result* with the given input_text/output_text when logging is on; return it."""
        if self.logger is not None:
            self.logger.log(result, **texts)
        return result
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""Dynamic embedding utilities for local semantic indexing.
|
|
2
|
+
|
|
3
|
+
This module provides a small helper for chunking text and computing similarity scores
|
|
4
|
+
with sentence-transformers when it is installed.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import os
|
|
8
|
+
import re
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
try:
|
|
14
|
+
from sentence_transformers import SentenceTransformer
|
|
15
|
+
except ImportError:
|
|
16
|
+
SentenceTransformer = None
|
|
17
|
+
SENTENCE_TRANSFORMERS_AVAILABLE = False
|
|
18
|
+
else:
|
|
19
|
+
SENTENCE_TRANSFORMERS_AVAILABLE = True
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def chunk_text(text: str, chunk_size: int = 200, overlap: int = 50) -> list[str]:
    """Break *text* into windows of up to ``chunk_size`` whitespace-separated words.

    Consecutive windows start ``chunk_size - overlap`` words apart (never less than one
    word), so neighbouring windows share up to ``overlap`` words. Text with at most
    ``chunk_size`` words is returned unchanged as a single-element list.
    """
    tokens = text.split()
    total = len(tokens)
    if total <= chunk_size:
        return [text]

    stride = max(1, chunk_size - overlap)
    pieces = []
    start = 0
    while start < total:
        window = tokens[start:start + chunk_size]
        if not window:
            break
        pieces.append(" ".join(window))
        # The window that reaches the end of the word list is the last one.
        if start + chunk_size >= total:
            break
        start += stride
    return pieces
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class LocalEmbedder:
    """Index text locally and compare queries with sentence-transformer embeddings.

    The class stores chunked document text and computes cosine similarity scores for queries
    against the indexed chunks.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """Load a sentence-transformers model for local semantic comparison.

        Raises ImportError when sentence-transformers is not installed.
        """
        if not SENTENCE_TRANSFORMERS_AVAILABLE:
            raise ImportError(
                "sentence-transformers is required for dynamic profiles. Install with: pip install llm-firewall[dynamic]"
            )
        self.model = SentenceTransformer(model_name)
        self.chunks = []
        self.embeddings = None

    def index_document(self, text: str) -> None:
        """Chunk *text* with chunk_text() and embed every chunk, replacing any previous index."""
        self.chunks = chunk_text(text)
        self.embeddings = self.model.encode(self.chunks)

    def most_similar(self, query: str, top_k: int = 3) -> list[dict]:
        """Return up to *top_k* indexed chunks ranked by cosine similarity to *query*.

        Each item is ``{"chunk": str, "score": float}``, best first; equal scores keep
        their indexing order. Returns [] when nothing is indexed or ``top_k <= 0``.
        A zero-length vector (query or chunk) scores 0.0 instead of dividing by zero.
        """
        if self.embeddings is None or top_k <= 0:
            return []

        query_embedding = np.asarray(self.model.encode([query])[0], dtype=float)
        chunk_embeddings = np.asarray(self.embeddings, dtype=float)

        denominator = np.linalg.norm(chunk_embeddings, axis=1) * np.linalg.norm(query_embedding)
        scores = np.zeros(len(chunk_embeddings), dtype=float)
        valid_mask = denominator != 0
        scores[valid_mask] = np.dot(chunk_embeddings[valid_mask], query_embedding) / denominator[valid_mask]

        # Stable sort on negated scores gives a descending ranking with deterministic ties.
        ranked_indices = np.argsort(-scores, kind="stable")[:top_k]
        return [{"chunk": self.chunks[index], "score": float(scores[index])} for index in ranked_indices]

    def max_similarity_score(self, query: str) -> float:
        """Return the highest similarity score for a query or 0.0 when no chunks are indexed."""
        results = self.most_similar(query, top_k=1)
        if not results:
            return 0.0
        return results[0]["score"]
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
from datetime import datetime, timezone
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class AuditLogger:
    """Append firewall check results to a JSON-lines audit file and read them back.

    Each :meth:`log` call writes one JSON object per line. :meth:`read_logs` and
    :meth:`get_summary` treat a missing file as empty and skip blank, unparseable,
    or non-object lines.
    """

    def __init__(self, log_file: str = "firewall_audit.log"):
        """Remember the log path; the file is created on the first :meth:`log` call."""
        # os.fspath converts os.PathLike values (e.g. pathlib.Path) to their path string.
        self.log_file = os.fspath(log_file)

    def log(self, result, input_text: str = None, output_text: str = None) -> None:
        """Append one entry describing *result*.

        *result* must expose blocked, category, reason, matched_rule and confidence (as
        FirewallResult does). Only the first 100 characters of the input/output text are
        stored, as ``input_preview`` / ``output_preview``; absent texts are stored as null.
        """
        entry = {
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "blocked": result.blocked,
            "category": result.category,
            "reason": result.reason,
            "matched_rule": result.matched_rule,
            "confidence": result.confidence,
            "input_preview": input_text[:100] if input_text is not None else None,
            "output_preview": output_text[:100] if output_text is not None else None,
        }
        with Path(self.log_file).open("a", encoding="utf-8") as log_handle:
            log_handle.write(json.dumps(entry) + "\n")

    def _iter_entries(self):
        """Yield each logged entry (a dict) in file order, skipping lines that aren't one."""
        log_path = Path(self.log_file)
        if not log_path.exists():
            return
        with log_path.open("r", encoding="utf-8") as log_handle:
            for line in log_handle:
                line = line.strip()
                if not line:
                    continue
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError:
                    # A torn (e.g. interrupted write) or hand-edited line must not make
                    # the rest of the audit log unreadable.
                    continue
                if isinstance(entry, dict):
                    yield entry

    def read_logs(self, limit: int = 50) -> list:
        """Return the last *limit* entries, oldest first; [] when ``limit <= 0``."""
        from collections import deque

        # Previously entries[-0:] returned everything for limit=0.
        if limit <= 0:
            return []
        # A bounded deque keeps only the newest entries instead of the whole file.
        return list(deque(self._iter_entries(), maxlen=limit))

    def get_summary(self) -> dict:
        """Count total, blocked and allowed entries, plus entries per category."""
        summary = {"total_requests": 0, "blocked_count": 0, "allowed_count": 0, "by_category": {}}
        for entry in self._iter_entries():
            summary["total_requests"] += 1
            if entry.get("blocked"):
                summary["blocked_count"] += 1
            else:
                summary["allowed_count"] += 1
            category = entry.get("category", "")
            summary["by_category"][category] = summary["by_category"].get(category, 0) + 1

        return summary
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
import yaml
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def _profiles_data_dir() -> Path:
|
|
7
|
+
return Path(__file__).resolve().parent / "data"
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def list_profiles() -> list[str]:
    """List the built-in profile names, alphabetically.

    A profile name is the stem of a ``*.yaml`` file in the bundled profiles directory;
    an empty or missing directory yields an empty list.
    """
    names = [candidate.stem for candidate in _profiles_data_dir().glob("*.yaml")]
    names.sort()
    return names
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def load_profile(name: str) -> dict:
    """Load a named YAML profile from the profiles data directory.

    *name* must be a bare profile name (as returned by ``list_profiles``), without a
    directory component or the ``.yaml`` suffix.

    Returns the parsed mapping, or ``{}`` for an empty file.

    Raises ValueError if the name contains a path component, if the requested profile file
    does not exist, or if the file's top level is not a YAML mapping.
    """
    profiles_dir = _profiles_data_dir()
    profile_path = profiles_dir / f"{name}.yaml"
    # Only bare names are accepted so a name like "../other" cannot load YAML from outside
    # the bundled profiles directory; such names are reported like any unknown profile.
    is_bare_name = isinstance(name, str) and name != "" and Path(name).name == name
    if not is_bare_name or not profile_path.exists():
        available_profiles = list_profiles()
        raise ValueError(f"Profile '{name}' not found. Available profiles: {available_profiles}")

    with profile_path.open("r", encoding="utf-8") as profile_file:
        # safe_load: profile YAML is data and must not construct arbitrary objects.
        profile_data = yaml.safe_load(profile_file)

    # An empty document loads as None; treat any falsy document as an empty profile.
    if not profile_data:
        return {}
    if not isinstance(profile_data, dict):
        # Firewall reads profile keys with dict.get(), so reject lists/scalars up front.
        raise ValueError(
            f"Profile '{name}' must be a YAML mapping at the top level, got {type(profile_data).__name__}"
        )
    return profile_data
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def load_profile_from_path(path: str) -> dict:
    """Load a YAML profile from an external file path.

    Returns the parsed mapping, or ``{}`` for an empty file.

    Raises FileNotFoundError if the file does not exist, and ValueError if the file's
    top level is not a YAML mapping.
    """
    profile_path = Path(path)
    if not profile_path.exists():
        raise FileNotFoundError(f"Profile file not found: {path}")

    with profile_path.open("r", encoding="utf-8") as profile_file:
        # safe_load: profile YAML is data and must not construct arbitrary objects.
        profile_data = yaml.safe_load(profile_file)

    # An empty document loads as None; treat any falsy document as an empty profile.
    if not profile_data:
        return {}
    if not isinstance(profile_data, dict):
        # Firewall reads profile keys with dict.get(), so reject lists/scalars up front.
        raise ValueError(
            f"Profile file {path} must be a YAML mapping at the top level, got {type(profile_data).__name__}"
        )
    return profile_data
|
|
File without changes
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
[build-system]
requires = ["setuptools>=61", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "llm-firewall"
version = "0.2.0"
description = "A lightweight, domain-aware safety firewall for LLM applications — works with any API or local model, with dynamic policy-document support."
authors = [{ name = "Parshva Shah", email = "shahparshva2005@gmail.com" }]
license = { text = "MIT" }
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
    "Topic :: Security",
    "Development Status :: 3 - Alpha",
]
readme = "README.md"
requires-python = ">=3.9"
dependencies = [
    "PyYAML",
]

[project.optional-dependencies]
openai = ["openai"]
groq = ["groq"]
dev = ["pytest"]
dynamic = ["sentence-transformers"]

[project.scripts]
llm-shield = "llm_shield.cli:main"

[project.urls]
Repository = "https://github.com/Parshva2605/LLM-Firewall"

# llm_shield.profiles loads built-in profiles from its data/*.yaml files; declare them
# as package data so they are shipped in the sdist and wheel.
[tool.setuptools.package-data]
"llm_shield.profiles" = ["data/*.yaml"]
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
from llm_shield.adapters import FirewallProtectedClient, detect_client_type
|
|
2
|
+
from llm_shield.core import Firewall
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class FakeOpenAIResponseMessage:
    """Stand-in for the SDK message object at ``response.choices[0].message``."""

    def __init__(self, content):
        # FirewallProtectedClient reads only ``.content`` from the message.
        self.content = content
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class FakeOpenAIResponseChoice:
    """Stand-in for one entry of ``response.choices``; wraps a FakeOpenAIResponseMessage."""

    def __init__(self, content):
        self.message = FakeOpenAIResponseMessage(content)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class FakeOpenAIResponse:
    """Stand-in for a chat-completion response that carries exactly one choice."""

    def __init__(self, content):
        # FirewallProtectedClient only reads choices[0], so one choice is enough.
        self.choices = [FakeOpenAIResponseChoice(content)]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class FakeOpenAICompletions:
    """Fake ``client.chat.completions`` that returns a canned reply and records use.

    ``called`` becomes True on the first create() call, so tests can assert whether
    the wrapped client was actually reached.
    """

    def __init__(self, response_text):
        self.response_text = response_text
        self.called = False

    def create(self, model, messages):
        """Mimic ``completions.create``; *model* and *messages* are accepted but unused."""
        reply = FakeOpenAIResponse(self.response_text)
        self.called = True
        return reply
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class FakeOpenAIChat:
    """Stand-in for ``client.chat``; exposes the fake ``completions`` endpoint."""

    def __init__(self, response_text):
        self.completions = FakeOpenAICompletions(response_text)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class FakeOpenAIClient:
    """Fake OpenAI client exposing ``chat.completions.create`` with a canned reply."""

    # detect_client_type() inspects type(client).__module__, so pretend this class
    # lives inside the openai package.
    __module__ = "openai.some.fake.path"

    def __init__(self, response_text="This is a test response"):
        self.chat = FakeOpenAIChat(response_text)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def test_detect_client_type_identifies_openai():
    """A client whose class reports an ``openai.*`` module is detected as openai."""
    detected = detect_client_type(FakeOpenAIClient())

    assert detected == "openai"
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def test_firewall_protected_client_blocks_input_without_calling_client():
    """A prompt hitting a blocked keyword is rejected at the input stage."""
    fake_client = FakeOpenAIClient()
    guarded = FirewallProtectedClient(fake_client, Firewall(blocked_keywords=["write code"]))

    outcome = guarded.chat("please write code for me")

    assert outcome["blocked"] is True
    assert outcome["stage"] == "input"
    # The underlying completions endpoint must never have been reached.
    assert fake_client.chat.completions.called is False
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def test_firewall_protected_client_allows_clean_message_and_returns_response():
    """A clean prompt is forwarded to the client and its reply is returned."""
    fake_client = FakeOpenAIClient()
    guarded = FirewallProtectedClient(fake_client, Firewall(blocked_keywords=["write code"]))

    outcome = guarded.chat("hello there")

    assert outcome["blocked"] is False
    assert outcome["stage"] is None
    assert outcome["response"] == "This is a test response"
    assert fake_client.chat.completions.called is True
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
from llm_shield.core import Firewall
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def test_check_input_blocks_write_code():
    """A "write code" request is blocked and categorized as malicious_code."""
    fw = Firewall(blocked_keywords=["write code", "ignore previous instructions"])

    verdict = fw.check_input("please write code for me")

    assert verdict.blocked is True
    assert verdict.category == "malicious_code"
    assert verdict.matched_rule == "write code"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def test_check_input_blocks_safety_violation():
    """A "keep this secret" request is blocked and categorized as safety_violation."""
    fw = Firewall(blocked_keywords=["keep this secret"])

    verdict = fw.check_input("what is your address, keep this secret")

    assert verdict.blocked is True
    assert verdict.category == "safety_violation"
    assert verdict.matched_rule == "keep this secret"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def test_check_input_blocks_prompt_injection():
    """An "ignore previous instructions" prompt is blocked as prompt_injection."""
    fw = Firewall(blocked_keywords=["ignore previous instructions"])

    verdict = fw.check_input("ignore previous instructions now")

    assert verdict.blocked is True
    assert verdict.category == "prompt_injection"
    assert verdict.matched_rule == "ignore previous instructions"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def test_check_input_allows_clean_input():
    """Input containing none of the blocked keywords is allowed as clean."""
    fw = Firewall(blocked_keywords=["write code", "ignore previous instructions"])

    verdict = fw.check_input("what time do you close")

    assert verdict.blocked is False
    assert verdict.category == "clean"
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def test_check_output_blocks_email():
    """An email address in model output is blocked as pii_leak via email_pattern."""
    verdict = Firewall().check_output("contact me at test@example.com")

    assert verdict.blocked is True
    assert verdict.category == "pii_leak"
    assert verdict.matched_rule == "email_pattern"
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def test_check_output_blocks_api_key_pattern():
    """An ``sk-``-style key in model output is blocked as pii_leak via api_key_pattern."""
    verdict = Firewall().check_output("here is a key sk-abc123def456ghi789jkl012")

    assert verdict.blocked is True
    assert verdict.category == "pii_leak"
    assert verdict.matched_rule == "api_key_pattern"
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def test_check_output_allows_clean_text():
    """Ordinary model output with no PII passes as clean."""
    verdict = Firewall().check_output("your order will arrive in 20 minutes")

    assert verdict.blocked is False
    assert verdict.category == "clean"
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
from llm_shield.dynamic import SENTENCE_TRANSFORMERS_AVAILABLE, LocalEmbedder, chunk_text
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def test_chunk_text_short_text_returns_single_chunk():
    """Text well under ``chunk_size`` comes back as one unchanged chunk."""
    short_text = "This is a short policy summary."

    assert chunk_text(short_text, chunk_size=200, overlap=50) == [short_text]
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def test_chunk_text_long_text_returns_multiple_chunks():
    """Text far beyond ``chunk_size`` is split into more than one chunk."""
    long_text = ("word " * 500).strip()

    pieces = chunk_text(long_text, chunk_size=200, overlap=50)

    assert len(pieces) > 1
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@pytest.mark.skipif(not SENTENCE_TRANSFORMERS_AVAILABLE, reason="sentence-transformers not installed")
def test_local_embedder_similarity_scores_refund_policy_high():
    """A refund question scores above 0.3 against an indexed returns policy."""
    refund_policy = (
        "Our company accepts returns within 30 days of purchase with a receipt. "
        "Refunds are processed to the original payment method after the item is inspected. "
        "Shipping fees are non-refundable unless the return is due to our error."
    )
    embedder = LocalEmbedder()
    embedder.index_document(refund_policy)

    matches = embedder.most_similar("what is your refund policy")

    assert matches
    # Only the first returned match is checked against the similarity floor.
    assert matches[0]["score"] > 0.3
|
|
@@ -0,0 +1,162 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
from llm_shield.core import Firewall
|
|
4
|
+
from llm_shield.dynamic import SENTENCE_TRANSFORMERS_AVAILABLE
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@pytest.mark.skipif(not SENTENCE_TRANSFORMERS_AVAILABLE, reason="sentence-transformers not installed")
def test_policy_document_allows_in_scope_question():
    """A question the coffee policy covers (roast levels) is allowed through."""
    coffee_policy = (
        "Our company sells organic coffee beans and offers subscription delivery. "
        "We answer questions about roast levels, brewing methods, and shipping times."
    )
    firewall = Firewall(policy_document=coffee_policy)

    verdict = firewall.check_input("what roast level do you recommend")

    assert verdict.blocked is False
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@pytest.mark.skipif(not SENTENCE_TRANSFORMERS_AVAILABLE, reason="sentence-transformers not installed")
def test_policy_document_blocks_out_of_scope_message():
    """A web-scraping request falls outside the coffee policy and is off_topic."""
    coffee_policy = (
        "Our company sells organic coffee beans and offers subscription delivery. "
        "We answer questions about roast levels, brewing methods, and shipping times."
    )
    firewall = Firewall(policy_document=coffee_policy)

    verdict = firewall.check_input("can you help me write a Python script to scrape websites")

    assert verdict.blocked is True
    assert verdict.category == "off_topic"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def test_firewall_without_policy_document_still_allows_clean_input():
    """Without a policy document, arbitrary clean input is allowed.

    This uses only the default ``Firewall()`` + ``check_input`` path, which the
    core tests already run without any sentence-transformers guard. The
    former ``skipif`` therefore only caused this test to be silently skipped
    on installs without the optional ``dynamic`` extra.
    """
    firewall = Firewall()

    verdict = firewall.check_input("random text")

    assert verdict.blocked is False
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@pytest.fixture(scope="module")
def policy_firewall():
    """Coffee-shop policy Firewall, built once per module and shared by requesting tests."""
    coffee_policy = (
        "Our company sells organic coffee beans and offers subscription delivery. "
        "We answer questions about roast levels, brewing methods, and shipping times."
    )
    return Firewall(policy_document=coffee_policy)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@pytest.mark.skipif(not SENTENCE_TRANSFORMERS_AVAILABLE, reason="sentence-transformers not installed")
def test_policy_document_allows_relevant_output(policy_firewall):
    """Model output about roasts and espresso stays within the coffee policy."""
    reply = "Our medium roast has notes of chocolate and caramel, perfect for espresso"

    assert policy_firewall.check_output(reply).blocked is False
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@pytest.mark.skipif(not SENTENCE_TRANSFORMERS_AVAILABLE, reason="sentence-transformers not installed")
def test_policy_document_blocks_out_of_scope_output(policy_firewall):
    """Model output about the stock market is blocked as off_topic."""
    reply = "The stock market rose 2% today amid strong tech earnings reports"

    verdict = policy_firewall.check_output(reply)

    assert verdict.blocked is True
    assert verdict.category == "off_topic"
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@pytest.mark.skipif(not SENTENCE_TRANSFORMERS_AVAILABLE, reason="sentence-transformers not installed")
def test_profile_path_with_policy_document_blocks_off_topic_input(tmp_path):
    """A policy_document loaded from a YAML profile blocks an off-topic request."""
    profile_yaml = """name: coffee_shop_profile
allowed_topics:
  - coffee
  - beverages
blocked_keywords:
  - refund
policy_document: |
  Our coffee shop specializes in artisan coffee drinks and pastries.
  We answer questions about our menu, ingredients, opening hours, and locations.
  We help customers choose drinks based on their preferences.
"""
    profile_file = tmp_path / "coffee_shop_profile.yaml"
    profile_file.write_text(profile_yaml)

    firewall = Firewall(profile_path=str(profile_file))
    verdict = firewall.check_input("write me a python script to analyze data")

    assert verdict.blocked is True
    assert verdict.category == "off_topic"
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@pytest.mark.skipif(not SENTENCE_TRANSFORMERS_AVAILABLE, reason="sentence-transformers not installed")
def test_profile_path_with_policy_document_allows_on_topic_input(tmp_path):
    """A policy_document loaded from a YAML profile lets a drink question through."""
    profile_yaml = """name: coffee_shop_profile
allowed_topics:
  - coffee
  - beverages
blocked_keywords:
  - refund
policy_document: |
  Our coffee shop specializes in artisan coffee drinks and pastries.
  We answer questions about our menu, ingredients, opening hours, and locations.
  We help customers choose drinks based on their preferences.
"""
    profile_file = tmp_path / "coffee_shop_profile.yaml"
    profile_file.write_text(profile_yaml)

    firewall = Firewall(profile_path=str(profile_file))
    verdict = firewall.check_input("what drinks do you recommend for someone who likes sweet flavors")

    assert verdict.blocked is False
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
@pytest.mark.skipif(not SENTENCE_TRANSFORMERS_AVAILABLE, reason="sentence-transformers not installed")
def test_explicit_policy_document_overrides_profile_policy_document(tmp_path):
    """An explicit ``policy_document=`` argument wins over the one in the YAML profile."""
    profile_file = tmp_path / "coffee_shop_profile.yaml"
    profile_file.write_text("""name: coffee_shop_profile
policy_document: |
  This is about coffee shops and beverages.
""")

    # The tech-support policy passed here must replace the coffee policy from the YAML.
    firewall = Firewall(
        profile_path=str(profile_file),
        policy_document="We provide technical support for software issues and troubleshooting.",
    )

    # A coffee question is now out of scope...
    coffee_verdict = firewall.check_input("what coffee drinks do you have")
    assert coffee_verdict.blocked is True
    assert coffee_verdict.category == "off_topic"

    # ...while a software-troubleshooting question is in scope.
    tech_verdict = firewall.check_input("I need help troubleshooting my software installation")
    assert tech_verdict.blocked is False
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
@pytest.mark.skipif(not SENTENCE_TRANSFORMERS_AVAILABLE, reason="sentence-transformers not installed")
def test_builtin_profile_with_policy_document_works_identically(tmp_path):
    """A bookstore policy_document loaded via ``profile_path=`` gates input by topic.

    Despite its name, this test exercises only an external profile passed
    through ``profile_path=``; no built-in profile (``profile=``) is loaded.
    The name is kept unchanged so the test ID stays stable.
    TODO(review): add a ``profile=`` case once a built-in profile ships a
    policy_document.
    """
    profile_file = tmp_path / "book_store_profile.yaml"
    profile_file.write_text("""name: book_store_profile
allowed_topics:
  - books
  - authors
blocked_keywords:
  - malicious
policy_document: |
  We are an online bookstore. We answer questions about book recommendations,
  author information, genres, and order status.
""")

    firewall = Firewall(profile_path=str(profile_file))

    # A book-related question is in scope and passes.
    result = firewall.check_input("can you recommend a mystery novel")
    assert result.blocked is False

    # An unrelated question is outside the policy_document and is blocked.
    result = firewall.check_input("what is the weather forecast for tomorrow")
    assert result.blocked is True
    assert result.category == "off_topic"
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from llm_shield.core import FirewallResult
|
|
2
|
+
from llm_shield.logger import AuditLogger
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def test_audit_logger_summary(tmp_path):
    """get_summary() reports total, blocked and allowed counts for logged results."""
    audit = AuditLogger(log_file=str(tmp_path / "audit.log"))

    audit.log(
        FirewallResult(blocked=True, reason="blocked one", category="prompt_injection",
                       confidence=1.0, matched_rule="rule1"),
        input_text="input 1",
    )
    audit.log(
        FirewallResult(blocked=True, reason="blocked two", category="pii_leak",
                       confidence=1.0, matched_rule="rule2"),
        output_text="output 2",
    )
    audit.log(
        FirewallResult(blocked=False, reason="", category="clean",
                       confidence=1.0, matched_rule=""),
        input_text="input 3",
    )

    stats = audit.get_summary()

    assert stats["total_requests"] == 3
    assert stats["blocked_count"] == 2
    assert stats["allowed_count"] == 1
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def test_audit_logger_read_logs_limit(tmp_path):
    """read_logs(limit=2) returns the two most recently logged entries, oldest first."""
    audit = AuditLogger(log_file=str(tmp_path / "audit.log"))

    audit.log(FirewallResult(blocked=True, reason="blocked one", category="prompt_injection",
                             confidence=1.0, matched_rule="rule1"))
    audit.log(FirewallResult(blocked=True, reason="blocked two", category="pii_leak",
                             confidence=1.0, matched_rule="rule2"))
    audit.log(FirewallResult(blocked=False, reason="", category="clean",
                             confidence=1.0, matched_rule=""))

    recent = audit.read_logs(limit=2)

    assert len(recent) == 2
    assert recent[0]["reason"] == "blocked two"
    assert recent[1]["reason"] == ""
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
from llm_shield.core import Firewall
|
|
4
|
+
from llm_shield.profiles import list_profiles, load_profile, load_profile_from_path
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def test_load_profile_returns_fast_food_support_profile():
    """The fast_food_support profile loads by name and reports that same name."""
    loaded = load_profile("fast_food_support")
    assert loaded["name"] == "fast_food_support"
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def test_load_profile_raises_for_missing_profile():
    """Requesting an unknown profile name raises ValueError."""
    unknown_name = "nonexistent_profile"
    with pytest.raises(ValueError):
        load_profile(unknown_name)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def test_list_profiles_returns_all_profile_names():
    """list_profiles() includes each of the three known profile names."""
    available = list_profiles()

    for expected in ("fast_food_support", "hospital_patient_chat", "school_tutor_bot"):
        assert expected in available
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def test_firewall_profile_blocks_fast_food_code_request():
    """The fast_food_support profile blocks a request to write code."""
    verdict = Firewall(profile="fast_food_support").check_input("can you write code for me")
    assert verdict.blocked is True
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def test_firewall_profile_blocks_hospital_diagnosis_request():
    """The hospital_patient_chat profile blocks a request for a diagnosis."""
    verdict = Firewall(profile="hospital_patient_chat").check_input("can you diagnose me")
    assert verdict.blocked is True
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def test_load_profile_from_path_loads_yaml_file(tmp_path):
    """load_profile_from_path parses name, allowed_topics and blocked_keywords from YAML."""
    profile_file = tmp_path / "test_profile.yaml"
    profile_file.write_text("""name: test_profile
allowed_topics:
  - testing
  - demo
blocked_keywords:
  - forbidden
  - banned
""")

    loaded = load_profile_from_path(str(profile_file))

    assert loaded["name"] == "test_profile"
    assert loaded["allowed_topics"] == ["testing", "demo"]
    assert loaded["blocked_keywords"] == ["forbidden", "banned"]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def test_load_profile_from_path_raises_for_nonexistent_file(tmp_path):
    """A profile path that does not exist raises FileNotFoundError.

    The path is built inside pytest's per-test ``tmp_path`` directory, so it is
    guaranteed to be missing on every machine. The previous hard-coded
    absolute path could, in principle, exist on the test host.
    """
    missing_path = tmp_path / "does_not_exist" / "profile.yaml"

    with pytest.raises(FileNotFoundError):
        load_profile_from_path(str(missing_path))
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def test_firewall_profile_path_blocks_using_loaded_keywords(tmp_path):
    """blocked_keywords read from a profile_path YAML file are enforced on input."""
    profile_file = tmp_path / "custom_profile.yaml"
    profile_file.write_text("""name: custom_profile
blocked_keywords:
  - test_blocked_word
""")

    firewall = Firewall(profile_path=str(profile_file))
    verdict = firewall.check_input("this contains test_blocked_word in it")

    assert verdict.blocked is True
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def test_firewall_raises_error_when_both_profile_and_profile_path_provided():
    """Supplying both ``profile`` and ``profile_path`` is rejected with ValueError."""
    expected_message = "Provide either 'profile' or 'profile_path', not both"

    with pytest.raises(ValueError, match=expected_message):
        Firewall(profile="fast_food_support", profile_path="somepath.yaml")
|