aiproxyguard-python-sdk 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,245 @@
1
+ """AIProxyGuard decorator utilities."""
2
+
3
from __future__ import annotations

import asyncio
import inspect
import warnings
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload

from .exceptions import ContentBlockedError
12
+
13
if TYPE_CHECKING:
    # Only needed for annotations; TYPE_CHECKING is False at runtime, so the
    # client module is never imported from here (presumably to avoid an
    # import cycle with the client - confirm).
    from .client import AIProxyGuard

# Type of the decorated callable; the decorators return a callable of the same type.
F = TypeVar("F", bound=Callable[..., Any])
17
+
18
+
19
class GuardConfigurationError(ValueError):
    """Signals a misconfigured ``guard()`` decorator.

    For example, an ``input_arg`` that does not match the decorated function.
    Because it subclasses ValueError, existing ``except ValueError`` handlers
    also catch it.
    """
23
+
24
+
25
@overload
def guard(
    client: AIProxyGuard,
    *,
    input_arg: str = "prompt",
    raise_on_block: bool = True,
    fail_closed: bool = True,
) -> Callable[[F], F]:
    """Typing overload: ``input_arg`` given as a parameter name (the default)."""


@overload
def guard(
    client: AIProxyGuard,
    *,
    input_arg: int,
    raise_on_block: bool = True,
    fail_closed: bool = True,
) -> Callable[[F], F]:
    """Typing overload: ``input_arg`` given as a positional index."""
43
+
44
+
45
def guard(
    client: AIProxyGuard,
    *,
    input_arg: str | int = "prompt",
    raise_on_block: bool = True,
    fail_closed: bool = True,
) -> Callable[[F], F]:
    """Decorator to guard a function with prompt injection detection.

    Checks the specified input argument before the function executes.
    If the content is blocked and raise_on_block is True, raises ContentBlockedError.

    Plain functions are checked with ``client.check``. Coroutine functions are
    checked with ``await client.check_async``. The argument value is converted
    with ``str()`` before checking; ``None`` is checked as ``""``.

    Args:
        client: AIProxyGuard client instance.
        input_arg: Name or index of the argument to check. Defaults to "prompt".
            A string names a parameter of the decorated function. It is
            resolved by binding the call to the function's signature, so the
            value is found whether it was passed positionally, by keyword, or
            left at its default. An int indexes the call's positional
            arguments; negative indices count from the end.
        raise_on_block: If True, raise ContentBlockedError when content is blocked.
            If False, the function is not called and the wrapper returns None.
        fail_closed: If True (default), raise GuardConfigurationError when
            input_arg cannot be resolved for a call. If False, issue a
            RuntimeWarning and call the function without checking.

    Returns:
        Decorated function that checks input before execution.

    Raises:
        GuardConfigurationError: At decoration time if a string input_arg is not
            a parameter of the function. At call time if input_arg cannot be
            resolved and fail_closed is True.
        ContentBlockedError: At call time if the content is blocked and
            raise_on_block is True.

    Example:
        >>> client = AIProxyGuard("http://localhost:8080")
        >>> @guard(client)
        ... def call_llm(prompt: str) -> str:
        ...     return "response"
        ...
        >>> @guard(client, input_arg="user_input")
        ... def process(user_input: str, system: str) -> str:
        ...     return "processed"
        ...
        >>> @guard(client, input_arg=0)
        ... def handle(text: str) -> str:
        ...     return "handled"
    """

    def decorator(func: F) -> F:
        # inspect.iscoroutinefunction is the supported check;
        # asyncio.iscoroutinefunction is deprecated since Python 3.14.
        is_async = inspect.iscoroutinefunction(func)

        # Introspect the signature once, at decoration time.
        sig = inspect.signature(func)
        param_names = list(sig.parameters)

        # Validate a string input_arg at decoration time.
        cached_arg_index: int | None = None
        if isinstance(input_arg, str):
            if input_arg not in param_names:
                raise GuardConfigurationError(
                    f"guard(): input_arg '{input_arg}' not found in function "
                    f"'{func.__name__}' parameters: {param_names}"
                )
            # Declaration slot of the parameter. Used only as a fallback when
            # a call cannot be bound to the signature.
            cached_arg_index = param_names.index(input_arg)

        def _extract_text(args: tuple[Any, ...], kwargs: dict[str, Any]) -> str | None:
            """Return the text to check, or None if input_arg could not be resolved.

            When fail_closed is True, an unresolvable input_arg raises
            GuardConfigurationError instead of returning None.
            """
            value: Any
            if isinstance(input_arg, int):
                # Accept any valid sequence index (including negative ones),
                # but never let an empty args tuple raise IndexError.
                if -len(args) <= input_arg < len(args):
                    value = args[input_arg]
                elif fail_closed:
                    raise GuardConfigurationError(
                        f"guard(): input_arg index {input_arg} out of range. "
                        f"'{func.__name__}' got {len(args)} positional args."
                    )
                else:
                    # stacklevel=3: _extract_text -> wrapper -> caller of the
                    # decorated function.
                    warnings.warn(
                        f"guard(): input_arg index {input_arg} out of range for "
                        f"'{func.__name__}'. Skipping security check.",
                        RuntimeWarning,
                        stacklevel=3,
                    )
                    return None
            else:
                try:
                    bound = sig.bind(*args, **kwargs)
                except TypeError:
                    # The call does not match the signature, so the function
                    # itself will raise TypeError. Fall back to a direct lookup
                    # so the check still runs when the value is visible.
                    bound = None

                if bound is not None:
                    # apply_defaults() fills omitted parameters, so an argument
                    # left at its default is still checked, not treated as missing.
                    bound.apply_defaults()
                    value = bound.arguments[input_arg]
                elif input_arg in kwargs:
                    value = kwargs[input_arg]
                elif cached_arg_index is not None and cached_arg_index < len(args):
                    value = args[cached_arg_index]
                elif fail_closed:
                    raise GuardConfigurationError(
                        f"guard(): Could not resolve input_arg '{input_arg}' "
                        f"for function '{func.__name__}'."
                    )
                else:
                    warnings.warn(
                        f"guard(): Could not resolve '{input_arg}' for "
                        f"'{func.__name__}'. Skipping security check.",
                        RuntimeWarning,
                        stacklevel=3,
                    )
                    return None

            return str(value) if value is not None else ""

        @wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            text = _extract_text(args, kwargs)

            # None means "unresolved, skip"; an empty string is still checked.
            if text is not None:
                result = client.check(text)
                if result.is_blocked:
                    if raise_on_block:
                        raise ContentBlockedError(result)
                    return None

            return func(*args, **kwargs)

        @wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            text = _extract_text(args, kwargs)

            # None means "unresolved, skip"; an empty string is still checked.
            if text is not None:
                result = await client.check_async(text)
                if result.is_blocked:
                    if raise_on_block:
                        raise ContentBlockedError(result)
                    return None

            return await func(*args, **kwargs)

        return async_wrapper if is_async else sync_wrapper  # type: ignore[return-value]

    return decorator
187
+
188
+
189
def guard_output(
    client: AIProxyGuard,
    *,
    raise_on_block: bool = True,
) -> Callable[[F], F]:
    """Decorator to guard a function's output with prompt injection detection.

    Checks the function's return value after execution.
    Useful for validating LLM responses before returning them.

    A return value of None is passed through unchecked. Any other value is
    converted with ``str()`` and checked with ``client.check`` (plain
    functions) or ``await client.check_async`` (coroutine functions). The
    wrapped function always runs; only its return value is withheld when blocked.

    Args:
        client: AIProxyGuard client instance.
        raise_on_block: If True, raise ContentBlockedError when content is blocked.
            If False, returns None instead of the blocked content.

    Returns:
        Decorated function that checks output after execution.

    Raises:
        ContentBlockedError: If the output is blocked and raise_on_block is True.

    Example:
        >>> client = AIProxyGuard("http://localhost:8080")
        >>> @guard_output(client)
        ... def get_llm_response(prompt: str) -> str:
        ...     return llm.generate(prompt)
    """

    def decorator(func: F) -> F:
        # inspect.iscoroutinefunction is the supported check;
        # asyncio.iscoroutinefunction is deprecated since Python 3.14.
        is_async = inspect.iscoroutinefunction(func)

        @wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            output = func(*args, **kwargs)

            if output is not None:
                result = client.check(str(output))
                if result.is_blocked:
                    if raise_on_block:
                        raise ContentBlockedError(result)
                    return None

            return output

        @wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            output = await func(*args, **kwargs)

            if output is not None:
                result = await client.check_async(str(output))
                if result.is_blocked:
                    if raise_on_block:
                        raise ContentBlockedError(result)
                    return None

            return output

        return async_wrapper if is_async else sync_wrapper  # type: ignore[return-value]

    return decorator
@@ -0,0 +1,76 @@
1
+ """AIProxyGuard SDK exceptions."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ if TYPE_CHECKING:
8
+ from .models import CheckResult
9
+
10
+
11
class AIProxyGuardError(Exception):
    """Root of the AIProxyGuard SDK exception hierarchy.

    Attributes:
        message: Human-readable description. It is also the exception's single
            ``args`` entry, so ``str(exc)`` returns it.
        code: Optional machine-readable error code (None when not supplied).
    """

    def __init__(self, message: str, code: str | None = None) -> None:
        super().__init__(message)
        self.message, self.code = message, code

    def __repr__(self) -> str:
        name = type(self).__name__
        return "{}(message={!r}, code={!r})".format(name, self.message, self.code)
22
+
23
+
24
class ValidationError(AIProxyGuardError):
    """Signals that the service rejected the request as invalid (400 errors)."""
28
+
29
+
30
class ConnectionError(AIProxyGuardError):
    """Signals that the SDK could not connect to the AIProxyGuard service.

    NOTE: this name shadows the built-in ``ConnectionError`` inside this module
    and for callers that import it by name. It does not subclass the built-in,
    so ``except builtins.ConnectionError`` will not catch it.
    """
34
+
35
+
36
class TimeoutError(AIProxyGuardError):
    """Signals that a request to the AIProxyGuard service timed out.

    NOTE: this name shadows the built-in ``TimeoutError`` inside this module
    and for callers that import it by name. It does not subclass the built-in.
    """
40
+
41
+
42
class ServerError(AIProxyGuardError):
    """Raised when the server returns a 5xx error (retryable).

    Attributes:
        status_code: HTTP status code returned by the server.
        code: Always "server_error".
    """

    def __init__(self, message: str, status_code: int) -> None:
        super().__init__(message, code="server_error")
        self.status_code = status_code

    def __reduce__(self) -> tuple[object, ...]:
        # BaseException's default reduce re-creates the instance as
        # ServerError(*self.args) == ServerError(message), which fails because
        # status_code is required. That breaks pickle and copy, e.g. when the
        # error crosses a process boundary. Rebuild the instance from both
        # constructor arguments, then restore the instance __dict__.
        return (type(self), (self.message, self.status_code), self.__dict__)

    def __repr__(self) -> str:
        return f"ServerError(status_code={self.status_code}, message={self.message!r})"
51
+
52
+
53
class RateLimitError(AIProxyGuardError):
    """Raised when rate limited (429 errors).

    Attributes:
        retry_after: Retry delay supplied by whoever raised the error, or None.
            Units are not established here; presumably seconds from a
            Retry-After header (TODO confirm against the client).
        code: Always "rate_limit".
    """

    def __init__(
        self, message: str = "Rate limited", retry_after: int | None = None
    ) -> None:
        super().__init__(message, code="rate_limit")
        self.retry_after = retry_after

    def __repr__(self) -> str:
        # Label the field with its real attribute name and use the dynamic
        # class name, as AIProxyGuardError.__repr__ does.
        cls = type(self).__name__
        return f"{cls}(message={self.message!r}, retry_after={self.retry_after!r})"
64
+
65
+
66
class ContentBlockedError(AIProxyGuardError):
    """Raised when content is blocked due to prompt injection detection.

    The message is "Content blocked: <category>". It reads
    "Content blocked: None" when the result carries no category.

    Attributes:
        result: The CheckResult that caused the block.
        code: Always "content_blocked".
    """

    def __init__(self, result: CheckResult) -> None:
        super().__init__(f"Content blocked: {result.category}", code="content_blocked")
        self.result = result

    def __reduce__(self) -> tuple[object, ...]:
        # The default reduce would call ContentBlockedError(message_string),
        # and __init__ would then fail on ``message_string.category``.
        # Rebuild from the original result and restore the instance __dict__.
        return (type(self), (self.result,), self.__dict__)

    def __repr__(self) -> str:
        cat = self.result.category
        conf = self.result.confidence
        return f"ContentBlockedError(category={cat!r}, confidence={conf})"
aiproxyguard/models.py ADDED
@@ -0,0 +1,222 @@
1
+ """AIProxyGuard SDK data models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass
6
+ from enum import Enum
7
+ from typing import Any
8
+
9
+
10
class Action(str, Enum):
    """Verdict AIProxyGuard assigns to a piece of scanned content.

    Because of the ``str`` mixin, each member compares equal to its lowercase
    value, e.g. ``Action.BLOCK == "block"``. ``Action("block")`` looks a
    member up by value.
    """

    # Declaration order is also the iteration order.
    ALLOW = "allow"
    LOG = "log"
    WARN = "warn"
    BLOCK = "block"
17
+
18
+
19
@dataclass(frozen=True)
class ThreatDetail:
    """One threat reported by the cloud API.

    Attributes:
        type: Threat category label, e.g. "prompt-injection".
        confidence: Detection confidence between 0.0 and 1.0.
        rule: Identifier of the rule/signature that fired, or None.
    """

    type: str
    confidence: float
    rule: str | None = None

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ThreatDetail:
        """Build a ThreatDetail from one entry of an API ``threats`` list.

        ``type`` is required (KeyError if absent). ``confidence`` is coerced
        with ``float()`` and defaults to 0.0. ``rule`` defaults to None.
        """
        threat_type = data["type"]
        score = float(data.get("confidence", 0.0))
        rule_id = data.get("rule")
        return cls(type=threat_type, confidence=score, rule=rule_id)
41
+
42
+
43
@dataclass(frozen=True)
class CheckResult:
    """Outcome of scanning one piece of text for prompt injection.

    Built either from a proxy API response (``from_dict``) or from a cloud API
    response (``from_cloud_dict``). Both produce the same flat shape.

    Attributes:
        action: Verdict for the text (allow, log, warn, or block).
        category: Category of the detected threat, or None.
        signature_name: Name of the matching signature/rule, or None.
        confidence: Detection confidence from 0.0 to 1.0.
    """

    action: Action
    category: str | None
    signature_name: str | None
    confidence: float

    @property
    def is_safe(self) -> bool:
        """True for every verdict other than BLOCK."""
        return self.action != Action.BLOCK

    @property
    def is_blocked(self) -> bool:
        """True only when the verdict is BLOCK."""
        return self.action == Action.BLOCK

    @property
    def requires_attention(self) -> bool:
        """True when the verdict is WARN or BLOCK."""
        flagged_actions = (Action.WARN, Action.BLOCK)
        return self.action in flagged_actions

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> CheckResult:
        """Build a CheckResult from a flat proxy API response.

        ``action`` is required and must be a valid Action value. The other
        keys are optional; ``confidence`` defaults to 0.0.
        """
        verdict = Action(data["action"])
        category = data.get("category")
        signature = data.get("signature_name")
        score = float(data.get("confidence", 0.0))
        return cls(
            action=verdict,
            category=category,
            signature_name=signature,
            confidence=score,
        )

    @classmethod
    def from_cloud_dict(cls, data: dict[str, Any]) -> CheckResult:
        """Build a CheckResult from a cloud API response.

        The cloud API reports a ``threats`` array. For backwards compatibility
        only the first threat is kept: its ``type``, ``rule`` and
        ``confidence`` become category, signature_name and confidence. With
        no threats (missing, empty or null), those are None, None and 0.0.
        """
        threats = data.get("threats", [])
        if not threats:
            category, signature, score = None, None, 0.0
        else:
            lead = threats[0]
            category = lead.get("type")
            signature = lead.get("rule")
            score = float(lead.get("confidence", 0.0))

        return cls(
            action=Action(data["action"]),
            category=category,
            signature_name=signature,
            confidence=score,
        )
108
+
109
+
110
@dataclass(frozen=True)
class CloudCheckResult:
    """Extended result from the cloud API with additional metadata.

    Attributes:
        id: Unique check ID.
        flagged: Whether any threat was detected.
        action: The action taken (allow, log, warn, or block).
        threats: List of detected threats (possibly empty).
        latency_ms: Processing time in milliseconds.
        cached: Whether result was served from cache.
    """

    id: str
    flagged: bool
    action: Action
    threats: list[ThreatDetail]
    latency_ms: float
    cached: bool

    @property
    def is_safe(self) -> bool:
        """Returns True if the text was not blocked."""
        return self.action != Action.BLOCK

    @property
    def is_blocked(self) -> bool:
        """Returns True if the text was blocked."""
        return self.action == Action.BLOCK

    @property
    def category(self) -> str | None:
        """Returns the first threat's type, or None when there are no threats."""
        return self.threats[0].type if self.threats else None

    @property
    def confidence(self) -> float:
        """Returns the first threat's confidence, or 0.0 when there are no threats."""
        return self.threats[0].confidence if self.threats else 0.0

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> CloudCheckResult:
        """Create CloudCheckResult from API response dictionary.

        ``id``, ``flagged`` and ``action`` are required. ``threats`` may be
        missing or null (treated as no threats). ``latency_ms`` defaults to
        0.0 and ``cached`` to False.
        """
        return cls(
            id=data["id"],
            flagged=data["flagged"],
            action=Action(data["action"]),
            # ``or []`` also covers an explicit JSON null, which previously raised
            # TypeError; CheckResult.from_cloud_dict already tolerates it.
            threats=[ThreatDetail.from_dict(t) for t in data.get("threats") or []],
            latency_ms=float(data.get("latency_ms", 0.0)),
            cached=data.get("cached", False),
        )
161
+
162
+
163
@dataclass(frozen=True)
class ServiceInfo:
    """Name and version reported by the AIProxyGuard API.

    Attributes:
        service: Service name.
        version: Service version string.
    """

    service: str
    version: str

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ServiceInfo:
        """Build ServiceInfo from an API payload; both keys are required (KeyError otherwise)."""
        name = data["service"]
        release = data["version"]
        return cls(name, release)
179
+
180
+
181
@dataclass(frozen=True)
class HealthStatus:
    """Health status reported by the AIProxyGuard API.

    Attributes:
        status: Raw status string ("unknown" when the payload omits it).
        healthy: True exactly when ``status`` equals "healthy".
    """

    status: str
    healthy: bool

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> HealthStatus:
        """Build a HealthStatus from an API payload."""
        reported = data.get("status", "unknown")
        return cls(reported, reported == "healthy")
198
+
199
+
200
@dataclass(frozen=True)
class ReadyStatus:
    """Readiness status reported by the AIProxyGuard API.

    Attributes:
        status: Raw status string ("unknown" when the payload omits it).
        ready: True exactly when ``status`` equals "ready".
        checks: The payload's ``checks`` mapping, passed through as-is
            (the same object, not a copy); an empty dict if absent.
    """

    status: str
    ready: bool
    checks: dict[str, Any]

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ReadyStatus:
        """Build a ReadyStatus from an API payload."""
        reported = data.get("status", "unknown")
        is_ready = reported == "ready"
        sub_checks = data.get("checks", {})
        return cls(status=reported, ready=is_ready, checks=sub_checks)
aiproxyguard/py.typed ADDED
File without changes