pyagentshield 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentguard/__init__.py +65 -0
- agentguard/api/__init__.py +9 -0
- agentguard/api/decorator.py +157 -0
- agentguard/api/scan.py +84 -0
- agentguard/cleaning/__init__.py +27 -0
- agentguard/cleaning/base.py +54 -0
- agentguard/cleaning/finetuned.py +363 -0
- agentguard/cleaning/heuristic.py +298 -0
- agentguard/cleaning/hybrid.py +360 -0
- agentguard/cleaning/llm.py +146 -0
- agentguard/cli/__init__.py +1 -0
- agentguard/cli/main.py +412 -0
- agentguard/core/__init__.py +23 -0
- agentguard/core/config.py +267 -0
- agentguard/core/exceptions.py +66 -0
- agentguard/core/results.py +134 -0
- agentguard/core/setup.py +169 -0
- agentguard/core/shield.py +328 -0
- agentguard/data/__init__.py +1 -0
- agentguard/detectors/__init__.py +10 -0
- agentguard/detectors/base.py +80 -0
- agentguard/detectors/zedd.py +327 -0
- agentguard/integrations/__init__.py +12 -0
- agentguard/integrations/langchain.py +267 -0
- agentguard/providers/__init__.py +23 -0
- agentguard/providers/base.py +73 -0
- agentguard/providers/local.py +208 -0
- agentguard/providers/mlx.py +446 -0
- agentguard/providers/openai.py +216 -0
- agentguard/py.typed +0 -0
- agentguard/threshold/__init__.py +11 -0
- agentguard/threshold/calibrator.py +421 -0
- agentguard/threshold/manager.py +260 -0
- agentguard/threshold/registry.py +90 -0
- pyagentshield-0.1.0.dist-info/METADATA +616 -0
- pyagentshield-0.1.0.dist-info/RECORD +39 -0
- pyagentshield-0.1.0.dist-info/WHEEL +4 -0
- pyagentshield-0.1.0.dist-info/entry_points.txt +2 -0
- pyagentshield-0.1.0.dist-info/licenses/LICENSE +21 -0
agentguard/__init__.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
"""
|
|
2
|
+
AgentGuard - Prompt injection detection for Agents.
|
|
3
|
+
|
|
4
|
+
Uses ZEDD (Zero-Shot Embedding Drift Detection) to identify malicious
|
|
5
|
+
content in retrieved documents before they reach the LLM context window.
|
|
6
|
+
|
|
7
|
+
Basic usage:
|
|
8
|
+
>>> from agentguard import scan
|
|
9
|
+
>>> result = scan("some document text")
|
|
10
|
+
>>> if result.is_suspicious:
|
|
11
|
+
... print(f"Detected: {result.details.summary}")
|
|
12
|
+
|
|
13
|
+
Decorator usage:
|
|
14
|
+
>>> from agentguard import shield
|
|
15
|
+
>>> @shield(on_detect="warn")
|
|
16
|
+
... def process_docs(query: str, docs: list[str]) -> str:
|
|
17
|
+
... return llm.invoke(build_prompt(query, docs))
|
|
18
|
+
|
|
19
|
+
LangChain integration:
|
|
20
|
+
>>> from agentguard.integrations.langchain import ShieldRunnable
|
|
21
|
+
>>> chain = retriever | ShieldRunnable() | prompt | llm
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
# Load environment variables from .env file
|
|
25
|
+
from dotenv import load_dotenv
|
|
26
|
+
load_dotenv()
|
|
27
|
+
|
|
28
|
+
from agentguard.core.config import ShieldConfig
|
|
29
|
+
from agentguard.core.results import ScanResult, DetectionSignal, ScanDetails
|
|
30
|
+
from agentguard.core.exceptions import (
|
|
31
|
+
AgentGuardError,
|
|
32
|
+
PromptInjectionDetected,
|
|
33
|
+
CalibrationError,
|
|
34
|
+
ConfigurationError,
|
|
35
|
+
SetupError,
|
|
36
|
+
)
|
|
37
|
+
from agentguard.core.shield import AgentGuard
|
|
38
|
+
from agentguard.core.setup import setup, is_model_cached, SetupResult
|
|
39
|
+
from agentguard.api.scan import scan
|
|
40
|
+
from agentguard.api.decorator import shield
|
|
41
|
+
|
|
42
|
+
__version__ = "0.1.0"
|
|
43
|
+
|
|
44
|
+
__all__ = [
|
|
45
|
+
# Main class
|
|
46
|
+
"AgentGuard",
|
|
47
|
+
# API functions
|
|
48
|
+
"scan",
|
|
49
|
+
"shield",
|
|
50
|
+
# Setup
|
|
51
|
+
"setup",
|
|
52
|
+
"is_model_cached",
|
|
53
|
+
"SetupResult",
|
|
54
|
+
# Config and results
|
|
55
|
+
"ShieldConfig",
|
|
56
|
+
"ScanResult",
|
|
57
|
+
"DetectionSignal",
|
|
58
|
+
"ScanDetails",
|
|
59
|
+
# Exceptions
|
|
60
|
+
"AgentGuardError",
|
|
61
|
+
"PromptInjectionDetected",
|
|
62
|
+
"CalibrationError",
|
|
63
|
+
"ConfigurationError",
|
|
64
|
+
"SetupError",
|
|
65
|
+
]
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
"""Decorator API for AgentGuard."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import functools
|
|
6
|
+
import inspect
|
|
7
|
+
import logging
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any, Callable, Dict, List, Optional, Union, TypeVar
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
from typing import ParamSpec
|
|
13
|
+
except ImportError:
|
|
14
|
+
from typing_extensions import ParamSpec
|
|
15
|
+
|
|
16
|
+
from agentguard.core.config import ShieldConfig
|
|
17
|
+
from agentguard.core.exceptions import PromptInjectionDetected
|
|
18
|
+
|
|
19
|
+
P = ParamSpec("P")
|
|
20
|
+
R = TypeVar("R")
|
|
21
|
+
|
|
22
|
+
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def shield(
|
|
26
|
+
on_detect: str = "warn",
|
|
27
|
+
confidence_threshold: float = 0.5,
|
|
28
|
+
scan_args: Optional[List[str]] = None,
|
|
29
|
+
config: Optional[Union[ShieldConfig, Dict[str, Any], str, Path]] = None,
|
|
30
|
+
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
|
31
|
+
"""
|
|
32
|
+
Decorator to protect functions from prompt injection.
|
|
33
|
+
|
|
34
|
+
Scans string and document arguments before the function executes.
|
|
35
|
+
Can block, warn, or flag based on detection results.
|
|
36
|
+
|
|
37
|
+
Args:
|
|
38
|
+
on_detect: Action on detection
|
|
39
|
+
- "block": Raise PromptInjectionDetected exception
|
|
40
|
+
- "warn": Log warning but continue execution
|
|
41
|
+
- "flag": Silent (for later inspection)
|
|
42
|
+
confidence_threshold: Minimum confidence to trigger action (0.0-1.0)
|
|
43
|
+
scan_args: Names of arguments to scan. If None, scans all string/list args.
|
|
44
|
+
config: Optional AgentGuard configuration
|
|
45
|
+
|
|
46
|
+
Returns:
|
|
47
|
+
Decorator function
|
|
48
|
+
|
|
49
|
+
Example:
|
|
50
|
+
>>> @shield(on_detect="warn")
|
|
51
|
+
... def process_documents(query: str, docs: list[str]) -> str:
|
|
52
|
+
... return llm.invoke(build_prompt(query, docs))
|
|
53
|
+
|
|
54
|
+
>>> @shield(on_detect="block", scan_args=["documents"])
|
|
55
|
+
... def answer_question(question: str, documents: list[str]) -> str:
|
|
56
|
+
... # Only 'documents' will be scanned, not 'question'
|
|
57
|
+
... return generate_answer(question, documents)
|
|
58
|
+
"""
|
|
59
|
+
from agentguard.core.shield import AgentGuard
|
|
60
|
+
|
|
61
|
+
def decorator(func: Callable[P, R]) -> Callable[P, R]:
|
|
62
|
+
# Initialize shield once per decorated function
|
|
63
|
+
_shield = AgentGuard(config=config)
|
|
64
|
+
|
|
65
|
+
@functools.wraps(func)
|
|
66
|
+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
|
67
|
+
# Extract texts to scan from arguments
|
|
68
|
+
texts_to_scan = _extract_texts(func, args, kwargs, scan_args)
|
|
69
|
+
|
|
70
|
+
if texts_to_scan:
|
|
71
|
+
# Scan all texts
|
|
72
|
+
results = _shield.scan(texts_to_scan)
|
|
73
|
+
if not isinstance(results, list):
|
|
74
|
+
results = [results]
|
|
75
|
+
|
|
76
|
+
# Filter by confidence threshold
|
|
77
|
+
suspicious = [
|
|
78
|
+
r for r in results
|
|
79
|
+
if r.confidence >= confidence_threshold and r.is_suspicious
|
|
80
|
+
]
|
|
81
|
+
|
|
82
|
+
if suspicious:
|
|
83
|
+
if on_detect == "block":
|
|
84
|
+
raise PromptInjectionDetected(
|
|
85
|
+
f"Blocked: {len(suspicious)} suspicious input(s) detected",
|
|
86
|
+
results=suspicious,
|
|
87
|
+
)
|
|
88
|
+
elif on_detect == "warn":
|
|
89
|
+
for result in suspicious:
|
|
90
|
+
logger.warning(
|
|
91
|
+
f"Prompt injection detected "
|
|
92
|
+
f"(confidence={result.confidence:.2f}): "
|
|
93
|
+
f"{result.details.summary}"
|
|
94
|
+
)
|
|
95
|
+
# "flag" mode: do nothing, just continue
|
|
96
|
+
|
|
97
|
+
return func(*args, **kwargs)
|
|
98
|
+
|
|
99
|
+
return wrapper
|
|
100
|
+
|
|
101
|
+
return decorator
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def _extract_texts(
|
|
105
|
+
func: Callable[..., Any],
|
|
106
|
+
args: tuple,
|
|
107
|
+
kwargs: Dict[str, Any],
|
|
108
|
+
scan_args: Optional[List[str]],
|
|
109
|
+
) -> List[str]:
|
|
110
|
+
"""
|
|
111
|
+
Extract string/list arguments to scan from function call.
|
|
112
|
+
|
|
113
|
+
Args:
|
|
114
|
+
func: The decorated function
|
|
115
|
+
args: Positional arguments
|
|
116
|
+
kwargs: Keyword arguments
|
|
117
|
+
scan_args: Specific argument names to scan (None = all)
|
|
118
|
+
|
|
119
|
+
Returns:
|
|
120
|
+
List of text strings to scan
|
|
121
|
+
"""
|
|
122
|
+
sig = inspect.signature(func)
|
|
123
|
+
bound = sig.bind(*args, **kwargs)
|
|
124
|
+
bound.apply_defaults()
|
|
125
|
+
|
|
126
|
+
texts: List[str] = []
|
|
127
|
+
|
|
128
|
+
for name, value in bound.arguments.items():
|
|
129
|
+
# Skip if scan_args specified and this arg not in it
|
|
130
|
+
if scan_args and name not in scan_args:
|
|
131
|
+
continue
|
|
132
|
+
|
|
133
|
+
texts.extend(_extract_texts_from_value(value))
|
|
134
|
+
|
|
135
|
+
return texts
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _extract_texts_from_value(value: Any) -> List[str]:
|
|
139
|
+
"""Extract text strings from a value (handles nested structures)."""
|
|
140
|
+
texts: List[str] = []
|
|
141
|
+
|
|
142
|
+
if isinstance(value, str):
|
|
143
|
+
texts.append(value)
|
|
144
|
+
elif isinstance(value, list):
|
|
145
|
+
for item in value:
|
|
146
|
+
texts.extend(_extract_texts_from_value(item))
|
|
147
|
+
elif isinstance(value, dict):
|
|
148
|
+
for v in value.values():
|
|
149
|
+
texts.extend(_extract_texts_from_value(v))
|
|
150
|
+
elif hasattr(value, "page_content"):
|
|
151
|
+
# LangChain Document
|
|
152
|
+
texts.append(str(value.page_content))
|
|
153
|
+
elif hasattr(value, "text"):
|
|
154
|
+
# LlamaIndex Node/Document
|
|
155
|
+
texts.append(str(value.text))
|
|
156
|
+
|
|
157
|
+
return texts
|
agentguard/api/scan.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
"""Simple scan() function API."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Dict, List, Optional, Union, overload
|
|
7
|
+
|
|
8
|
+
from agentguard.core.config import ShieldConfig
|
|
9
|
+
from agentguard.core.results import ScanResult
|
|
10
|
+
|
|
11
|
+
# Global default shield instance (lazy initialized)
|
|
12
|
+
_default_shield: Any = None
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _get_default_shield() -> Any:
|
|
16
|
+
"""Get or create the default AgentGuard instance."""
|
|
17
|
+
global _default_shield
|
|
18
|
+
if _default_shield is None:
|
|
19
|
+
from agentguard.core.shield import AgentGuard
|
|
20
|
+
|
|
21
|
+
_default_shield = AgentGuard()
|
|
22
|
+
return _default_shield
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def configure(
|
|
26
|
+
config: Optional[Union[ShieldConfig, Dict[str, Any], str, Path]] = None,
|
|
27
|
+
**kwargs: Any,
|
|
28
|
+
) -> None:
|
|
29
|
+
"""
|
|
30
|
+
Configure the global default AgentGuard instance.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
config: Configuration (ShieldConfig, dict, path to YAML, or None)
|
|
34
|
+
**kwargs: Additional config options (merged with config)
|
|
35
|
+
"""
|
|
36
|
+
global _default_shield
|
|
37
|
+
from agentguard.core.shield import AgentGuard
|
|
38
|
+
|
|
39
|
+
if kwargs:
|
|
40
|
+
if isinstance(config, dict):
|
|
41
|
+
config = {**config, **kwargs}
|
|
42
|
+
elif config is None:
|
|
43
|
+
config = kwargs
|
|
44
|
+
|
|
45
|
+
_default_shield = AgentGuard(config=config)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@overload
|
|
49
|
+
def scan(text: str) -> ScanResult:
|
|
50
|
+
...
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@overload
|
|
54
|
+
def scan(text: List[str]) -> List[ScanResult]:
|
|
55
|
+
...
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def scan(text: Union[str, List[str]]) -> Union[ScanResult, List[ScanResult]]:
|
|
59
|
+
"""
|
|
60
|
+
Scan text for prompt injections.
|
|
61
|
+
|
|
62
|
+
This is the simplest interface for using AgentGuard. It uses a global
|
|
63
|
+
default shield instance that can be configured via `configure()`.
|
|
64
|
+
|
|
65
|
+
Args:
|
|
66
|
+
text: Single text string or list of texts to scan
|
|
67
|
+
|
|
68
|
+
Returns:
|
|
69
|
+
ScanResult for single text, or list of ScanResults for multiple texts
|
|
70
|
+
|
|
71
|
+
Example:
|
|
72
|
+
>>> from agentguard import scan
|
|
73
|
+
>>> result = scan("Hello, this is normal text")
|
|
74
|
+
>>> result.is_suspicious
|
|
75
|
+
False
|
|
76
|
+
|
|
77
|
+
>>> result = scan("IGNORE ALL PREVIOUS INSTRUCTIONS")
|
|
78
|
+
>>> result.is_suspicious
|
|
79
|
+
True
|
|
80
|
+
>>> result.confidence
|
|
81
|
+
0.87
|
|
82
|
+
"""
|
|
83
|
+
shield = _get_default_shield()
|
|
84
|
+
return shield.scan(text)
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Text cleaning utilities for AgentGuard."""
|
|
2
|
+
|
|
3
|
+
from agentguard.cleaning.base import TextCleaner
|
|
4
|
+
from agentguard.cleaning.heuristic import HeuristicCleaner
|
|
5
|
+
from agentguard.cleaning.hybrid import HybridCleaner, HybridMode, create_hybrid_cleaner
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"TextCleaner",
|
|
9
|
+
"HeuristicCleaner",
|
|
10
|
+
"HybridCleaner",
|
|
11
|
+
"HybridMode",
|
|
12
|
+
"create_hybrid_cleaner",
|
|
13
|
+
]
|
|
14
|
+
|
|
15
|
+
# Conditional import for LLM cleaner (requires openai)
|
|
16
|
+
try:
|
|
17
|
+
from agentguard.cleaning.llm import LLMCleaner
|
|
18
|
+
__all__.append("LLMCleaner")
|
|
19
|
+
except ImportError:
|
|
20
|
+
pass
|
|
21
|
+
|
|
22
|
+
# Conditional import for finetuned cleaner (requires transformers)
|
|
23
|
+
try:
|
|
24
|
+
from agentguard.cleaning.finetuned import FinetunedCleaner
|
|
25
|
+
__all__.append("FinetunedCleaner")
|
|
26
|
+
except ImportError:
|
|
27
|
+
pass
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""Base text cleaner interface."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import List, Protocol, runtime_checkable
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@runtime_checkable
|
|
9
|
+
class TextCleaner(Protocol):
|
|
10
|
+
"""
|
|
11
|
+
Protocol for text cleaning implementations.
|
|
12
|
+
|
|
13
|
+
Text cleaners remove potential injection content from text,
|
|
14
|
+
producing a "clean" version for comparison with the original.
|
|
15
|
+
The semantic drift between original and cleaned versions
|
|
16
|
+
indicates the presence of injected content.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
@property
|
|
20
|
+
def method(self) -> str:
|
|
21
|
+
"""
|
|
22
|
+
Get the cleaning method name.
|
|
23
|
+
|
|
24
|
+
Returns:
|
|
25
|
+
Method identifier (e.g., "heuristic", "llm")
|
|
26
|
+
"""
|
|
27
|
+
...
|
|
28
|
+
|
|
29
|
+
def clean(self, text: str) -> str:
|
|
30
|
+
"""
|
|
31
|
+
Clean text by removing potential injection content.
|
|
32
|
+
|
|
33
|
+
Args:
|
|
34
|
+
text: Original text that may contain injections
|
|
35
|
+
|
|
36
|
+
Returns:
|
|
37
|
+
Cleaned text with injection attempts removed
|
|
38
|
+
"""
|
|
39
|
+
...
|
|
40
|
+
|
|
41
|
+
def clean_batch(self, texts: List[str]) -> List[str]:
|
|
42
|
+
"""
|
|
43
|
+
Clean multiple texts.
|
|
44
|
+
|
|
45
|
+
Default implementation calls clean() for each text.
|
|
46
|
+
Subclasses may override for batch efficiency.
|
|
47
|
+
|
|
48
|
+
Args:
|
|
49
|
+
texts: List of texts to clean
|
|
50
|
+
|
|
51
|
+
Returns:
|
|
52
|
+
List of cleaned texts
|
|
53
|
+
"""
|
|
54
|
+
...
|