pyagentshield 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. agentguard/__init__.py +65 -0
  2. agentguard/api/__init__.py +9 -0
  3. agentguard/api/decorator.py +157 -0
  4. agentguard/api/scan.py +84 -0
  5. agentguard/cleaning/__init__.py +27 -0
  6. agentguard/cleaning/base.py +54 -0
  7. agentguard/cleaning/finetuned.py +363 -0
  8. agentguard/cleaning/heuristic.py +298 -0
  9. agentguard/cleaning/hybrid.py +360 -0
  10. agentguard/cleaning/llm.py +146 -0
  11. agentguard/cli/__init__.py +1 -0
  12. agentguard/cli/main.py +412 -0
  13. agentguard/core/__init__.py +23 -0
  14. agentguard/core/config.py +267 -0
  15. agentguard/core/exceptions.py +66 -0
  16. agentguard/core/results.py +134 -0
  17. agentguard/core/setup.py +169 -0
  18. agentguard/core/shield.py +328 -0
  19. agentguard/data/__init__.py +1 -0
  20. agentguard/detectors/__init__.py +10 -0
  21. agentguard/detectors/base.py +80 -0
  22. agentguard/detectors/zedd.py +327 -0
  23. agentguard/integrations/__init__.py +12 -0
  24. agentguard/integrations/langchain.py +267 -0
  25. agentguard/providers/__init__.py +23 -0
  26. agentguard/providers/base.py +73 -0
  27. agentguard/providers/local.py +208 -0
  28. agentguard/providers/mlx.py +446 -0
  29. agentguard/providers/openai.py +216 -0
  30. agentguard/py.typed +0 -0
  31. agentguard/threshold/__init__.py +11 -0
  32. agentguard/threshold/calibrator.py +421 -0
  33. agentguard/threshold/manager.py +260 -0
  34. agentguard/threshold/registry.py +90 -0
  35. pyagentshield-0.1.0.dist-info/METADATA +616 -0
  36. pyagentshield-0.1.0.dist-info/RECORD +39 -0
  37. pyagentshield-0.1.0.dist-info/WHEEL +4 -0
  38. pyagentshield-0.1.0.dist-info/entry_points.txt +2 -0
  39. pyagentshield-0.1.0.dist-info/licenses/LICENSE +21 -0
agentguard/__init__.py ADDED
@@ -0,0 +1,65 @@
1
+ """
2
+ AgentGuard - Prompt injection detection for Agents.
3
+
4
+ Uses ZEDD (Zero-Shot Embedding Drift Detection) to identify malicious
5
+ content in retrieved documents before they reach the LLM context window.
6
+
7
+ Basic usage:
8
+ >>> from agentguard import scan
9
+ >>> result = scan("some document text")
10
+ >>> if result.is_suspicious:
11
+ ... print(f"Detected: {result.details.summary}")
12
+
13
+ Decorator usage:
14
+ >>> from agentguard import shield
15
+ >>> @shield(on_detect="warn")
16
+ ... def process_docs(query: str, docs: list[str]) -> str:
17
+ ... return llm.invoke(build_prompt(query, docs))
18
+
19
+ LangChain integration:
20
+ >>> from agentguard.integrations.langchain import ShieldRunnable
21
+ >>> chain = retriever | ShieldRunnable() | prompt | llm
22
+ """
23
+
24
+ # Load environment variables from .env file
25
+ from dotenv import load_dotenv
26
+ load_dotenv()
27
+
28
+ from agentguard.core.config import ShieldConfig
29
+ from agentguard.core.results import ScanResult, DetectionSignal, ScanDetails
30
+ from agentguard.core.exceptions import (
31
+ AgentGuardError,
32
+ PromptInjectionDetected,
33
+ CalibrationError,
34
+ ConfigurationError,
35
+ SetupError,
36
+ )
37
+ from agentguard.core.shield import AgentGuard
38
+ from agentguard.core.setup import setup, is_model_cached, SetupResult
39
+ from agentguard.api.scan import scan
40
+ from agentguard.api.decorator import shield
41
+
42
+ __version__ = "0.1.0"
43
+
44
+ __all__ = [
45
+ # Main class
46
+ "AgentGuard",
47
+ # API functions
48
+ "scan",
49
+ "shield",
50
+ # Setup
51
+ "setup",
52
+ "is_model_cached",
53
+ "SetupResult",
54
+ # Config and results
55
+ "ShieldConfig",
56
+ "ScanResult",
57
+ "DetectionSignal",
58
+ "ScanDetails",
59
+ # Exceptions
60
+ "AgentGuardError",
61
+ "PromptInjectionDetected",
62
+ "CalibrationError",
63
+ "ConfigurationError",
64
+ "SetupError",
65
+ ]
@@ -0,0 +1,9 @@
1
+ """Public API functions for AgentGuard."""
2
+
3
+ from agentguard.api.scan import scan
4
+ from agentguard.api.decorator import shield
5
+
6
+ __all__ = [
7
+ "scan",
8
+ "shield",
9
+ ]
@@ -0,0 +1,157 @@
1
+ """Decorator API for AgentGuard."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+ import inspect
7
+ import logging
8
+ from pathlib import Path
9
+ from typing import Any, Callable, Dict, List, Optional, Union, TypeVar
10
+
11
+ try:
12
+ from typing import ParamSpec
13
+ except ImportError:
14
+ from typing_extensions import ParamSpec
15
+
16
+ from agentguard.core.config import ShieldConfig
17
+ from agentguard.core.exceptions import PromptInjectionDetected
18
+
19
+ P = ParamSpec("P")
20
+ R = TypeVar("R")
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ def shield(
26
+ on_detect: str = "warn",
27
+ confidence_threshold: float = 0.5,
28
+ scan_args: Optional[List[str]] = None,
29
+ config: Optional[Union[ShieldConfig, Dict[str, Any], str, Path]] = None,
30
+ ) -> Callable[[Callable[P, R]], Callable[P, R]]:
31
+ """
32
+ Decorator to protect functions from prompt injection.
33
+
34
+ Scans string and document arguments before the function executes.
35
+ Can block, warn, or flag based on detection results.
36
+
37
+ Args:
38
+ on_detect: Action on detection
39
+ - "block": Raise PromptInjectionDetected exception
40
+ - "warn": Log warning but continue execution
41
+ - "flag": Silent (for later inspection)
42
+ confidence_threshold: Minimum confidence to trigger action (0.0-1.0)
43
+ scan_args: Names of arguments to scan. If None, scans all string/list args.
44
+ config: Optional AgentGuard configuration
45
+
46
+ Returns:
47
+ Decorator function
48
+
49
+ Example:
50
+ >>> @shield(on_detect="warn")
51
+ ... def process_documents(query: str, docs: list[str]) -> str:
52
+ ... return llm.invoke(build_prompt(query, docs))
53
+
54
+ >>> @shield(on_detect="block", scan_args=["documents"])
55
+ ... def answer_question(question: str, documents: list[str]) -> str:
56
+ ... # Only 'documents' will be scanned, not 'question'
57
+ ... return generate_answer(question, documents)
58
+ """
59
+ from agentguard.core.shield import AgentGuard
60
+
61
+ def decorator(func: Callable[P, R]) -> Callable[P, R]:
62
+ # Initialize shield once per decorated function
63
+ _shield = AgentGuard(config=config)
64
+
65
+ @functools.wraps(func)
66
+ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
67
+ # Extract texts to scan from arguments
68
+ texts_to_scan = _extract_texts(func, args, kwargs, scan_args)
69
+
70
+ if texts_to_scan:
71
+ # Scan all texts
72
+ results = _shield.scan(texts_to_scan)
73
+ if not isinstance(results, list):
74
+ results = [results]
75
+
76
+ # Filter by confidence threshold
77
+ suspicious = [
78
+ r for r in results
79
+ if r.confidence >= confidence_threshold and r.is_suspicious
80
+ ]
81
+
82
+ if suspicious:
83
+ if on_detect == "block":
84
+ raise PromptInjectionDetected(
85
+ f"Blocked: {len(suspicious)} suspicious input(s) detected",
86
+ results=suspicious,
87
+ )
88
+ elif on_detect == "warn":
89
+ for result in suspicious:
90
+ logger.warning(
91
+ f"Prompt injection detected "
92
+ f"(confidence={result.confidence:.2f}): "
93
+ f"{result.details.summary}"
94
+ )
95
+ # "flag" mode: do nothing, just continue
96
+
97
+ return func(*args, **kwargs)
98
+
99
+ return wrapper
100
+
101
+ return decorator
102
+
103
+
104
+ def _extract_texts(
105
+ func: Callable[..., Any],
106
+ args: tuple,
107
+ kwargs: Dict[str, Any],
108
+ scan_args: Optional[List[str]],
109
+ ) -> List[str]:
110
+ """
111
+ Extract string/list arguments to scan from function call.
112
+
113
+ Args:
114
+ func: The decorated function
115
+ args: Positional arguments
116
+ kwargs: Keyword arguments
117
+ scan_args: Specific argument names to scan (None = all)
118
+
119
+ Returns:
120
+ List of text strings to scan
121
+ """
122
+ sig = inspect.signature(func)
123
+ bound = sig.bind(*args, **kwargs)
124
+ bound.apply_defaults()
125
+
126
+ texts: List[str] = []
127
+
128
+ for name, value in bound.arguments.items():
129
+ # Skip if scan_args specified and this arg not in it
130
+ if scan_args and name not in scan_args:
131
+ continue
132
+
133
+ texts.extend(_extract_texts_from_value(value))
134
+
135
+ return texts
136
+
137
+
138
+ def _extract_texts_from_value(value: Any) -> List[str]:
139
+ """Extract text strings from a value (handles nested structures)."""
140
+ texts: List[str] = []
141
+
142
+ if isinstance(value, str):
143
+ texts.append(value)
144
+ elif isinstance(value, list):
145
+ for item in value:
146
+ texts.extend(_extract_texts_from_value(item))
147
+ elif isinstance(value, dict):
148
+ for v in value.values():
149
+ texts.extend(_extract_texts_from_value(v))
150
+ elif hasattr(value, "page_content"):
151
+ # LangChain Document
152
+ texts.append(str(value.page_content))
153
+ elif hasattr(value, "text"):
154
+ # LlamaIndex Node/Document
155
+ texts.append(str(value.text))
156
+
157
+ return texts
agentguard/api/scan.py ADDED
@@ -0,0 +1,84 @@
1
+ """Simple scan() function API."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Any, Dict, List, Optional, Union, overload
7
+
8
+ from agentguard.core.config import ShieldConfig
9
+ from agentguard.core.results import ScanResult
10
+
11
+ # Global default shield instance (lazy initialized)
12
+ _default_shield: Any = None
13
+
14
+
15
+ def _get_default_shield() -> Any:
16
+ """Get or create the default AgentGuard instance."""
17
+ global _default_shield
18
+ if _default_shield is None:
19
+ from agentguard.core.shield import AgentGuard
20
+
21
+ _default_shield = AgentGuard()
22
+ return _default_shield
23
+
24
+
25
+ def configure(
26
+ config: Optional[Union[ShieldConfig, Dict[str, Any], str, Path]] = None,
27
+ **kwargs: Any,
28
+ ) -> None:
29
+ """
30
+ Configure the global default AgentGuard instance.
31
+
32
+ Args:
33
+ config: Configuration (ShieldConfig, dict, path to YAML, or None)
34
+ **kwargs: Additional config options (merged with config)
35
+ """
36
+ global _default_shield
37
+ from agentguard.core.shield import AgentGuard
38
+
39
+ if kwargs:
40
+ if isinstance(config, dict):
41
+ config = {**config, **kwargs}
42
+ elif config is None:
43
+ config = kwargs
44
+
45
+ _default_shield = AgentGuard(config=config)
46
+
47
+
48
+ @overload
49
+ def scan(text: str) -> ScanResult:
50
+ ...
51
+
52
+
53
+ @overload
54
+ def scan(text: List[str]) -> List[ScanResult]:
55
+ ...
56
+
57
+
58
+ def scan(text: Union[str, List[str]]) -> Union[ScanResult, List[ScanResult]]:
59
+ """
60
+ Scan text for prompt injections.
61
+
62
+ This is the simplest interface for using AgentGuard. It uses a global
63
+ default shield instance that can be configured via `configure()`.
64
+
65
+ Args:
66
+ text: Single text string or list of texts to scan
67
+
68
+ Returns:
69
+ ScanResult for single text, or list of ScanResults for multiple texts
70
+
71
+ Example:
72
+ >>> from agentguard import scan
73
+ >>> result = scan("Hello, this is normal text")
74
+ >>> result.is_suspicious
75
+ False
76
+
77
+ >>> result = scan("IGNORE ALL PREVIOUS INSTRUCTIONS")
78
+ >>> result.is_suspicious
79
+ True
80
+ >>> result.confidence
81
+ 0.87
82
+ """
83
+ shield = _get_default_shield()
84
+ return shield.scan(text)
@@ -0,0 +1,27 @@
1
+ """Text cleaning utilities for AgentGuard."""
2
+
3
+ from agentguard.cleaning.base import TextCleaner
4
+ from agentguard.cleaning.heuristic import HeuristicCleaner
5
+ from agentguard.cleaning.hybrid import HybridCleaner, HybridMode, create_hybrid_cleaner
6
+
7
+ __all__ = [
8
+ "TextCleaner",
9
+ "HeuristicCleaner",
10
+ "HybridCleaner",
11
+ "HybridMode",
12
+ "create_hybrid_cleaner",
13
+ ]
14
+
15
+ # Conditional import for LLM cleaner (requires openai)
16
+ try:
17
+ from agentguard.cleaning.llm import LLMCleaner
18
+ __all__.append("LLMCleaner")
19
+ except ImportError:
20
+ pass
21
+
22
+ # Conditional import for finetuned cleaner (requires transformers)
23
+ try:
24
+ from agentguard.cleaning.finetuned import FinetunedCleaner
25
+ __all__.append("FinetunedCleaner")
26
+ except ImportError:
27
+ pass
@@ -0,0 +1,54 @@
1
+ """Base text cleaner interface."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import List, Protocol, runtime_checkable
6
+
7
+
8
+ @runtime_checkable
9
+ class TextCleaner(Protocol):
10
+ """
11
+ Protocol for text cleaning implementations.
12
+
13
+ Text cleaners remove potential injection content from text,
14
+ producing a "clean" version for comparison with the original.
15
+ The semantic drift between original and cleaned versions
16
+ indicates the presence of injected content.
17
+ """
18
+
19
+ @property
20
+ def method(self) -> str:
21
+ """
22
+ Get the cleaning method name.
23
+
24
+ Returns:
25
+ Method identifier (e.g., "heuristic", "llm")
26
+ """
27
+ ...
28
+
29
+ def clean(self, text: str) -> str:
30
+ """
31
+ Clean text by removing potential injection content.
32
+
33
+ Args:
34
+ text: Original text that may contain injections
35
+
36
+ Returns:
37
+ Cleaned text with injection attempts removed
38
+ """
39
+ ...
40
+
41
+ def clean_batch(self, texts: List[str]) -> List[str]:
42
+ """
43
+ Clean multiple texts.
44
+
45
+ Default implementation calls clean() for each text.
46
+ Subclasses may override for batch efficiency.
47
+
48
+ Args:
49
+ texts: List of texts to clean
50
+
51
+ Returns:
52
+ List of cleaned texts
53
+ """
54
+ ...