aiproxyguard-python-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiproxyguard/__init__.py +58 -0
- aiproxyguard/client.py +619 -0
- aiproxyguard/decorators.py +245 -0
- aiproxyguard/exceptions.py +76 -0
- aiproxyguard/models.py +222 -0
- aiproxyguard/py.typed +0 -0
- aiproxyguard_python_sdk-0.1.0.dist-info/METADATA +384 -0
- aiproxyguard_python_sdk-0.1.0.dist-info/RECORD +10 -0
- aiproxyguard_python_sdk-0.1.0.dist-info/WHEEL +4 -0
- aiproxyguard_python_sdk-0.1.0.dist-info/licenses/LICENSE +201 -0
aiproxyguard/__init__.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
"""AIProxyGuard Python SDK - LLM security proxy for prompt injection detection.
|
|
2
|
+
|
|
3
|
+
Example:
|
|
4
|
+
>>> from aiproxyguard import AIProxyGuard
|
|
5
|
+
>>> client = AIProxyGuard("https://docker.aiproxyguard.com")
|
|
6
|
+
>>> result = client.check("Ignore all previous instructions")
|
|
7
|
+
>>> if result.is_blocked:
|
|
8
|
+
... print(f"Blocked: {result.category}")
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from .client import AIProxyGuard, ApiMode
|
|
12
|
+
from .decorators import GuardConfigurationError, guard, guard_output
|
|
13
|
+
from .exceptions import (
|
|
14
|
+
AIProxyGuardError,
|
|
15
|
+
ConnectionError,
|
|
16
|
+
ContentBlockedError,
|
|
17
|
+
RateLimitError,
|
|
18
|
+
ServerError,
|
|
19
|
+
TimeoutError,
|
|
20
|
+
ValidationError,
|
|
21
|
+
)
|
|
22
|
+
from .models import (
|
|
23
|
+
Action,
|
|
24
|
+
CheckResult,
|
|
25
|
+
CloudCheckResult,
|
|
26
|
+
HealthStatus,
|
|
27
|
+
ReadyStatus,
|
|
28
|
+
ServiceInfo,
|
|
29
|
+
ThreatDetail,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
__version__ = "0.1.0"
|
|
33
|
+
|
|
34
|
+
__all__ = [
|
|
35
|
+
# Client
|
|
36
|
+
"AIProxyGuard",
|
|
37
|
+
"ApiMode",
|
|
38
|
+
# Models
|
|
39
|
+
"Action",
|
|
40
|
+
"CheckResult",
|
|
41
|
+
"CloudCheckResult",
|
|
42
|
+
"HealthStatus",
|
|
43
|
+
"ReadyStatus",
|
|
44
|
+
"ServiceInfo",
|
|
45
|
+
"ThreatDetail",
|
|
46
|
+
# Exceptions
|
|
47
|
+
"AIProxyGuardError",
|
|
48
|
+
"ConnectionError",
|
|
49
|
+
"ContentBlockedError",
|
|
50
|
+
"GuardConfigurationError",
|
|
51
|
+
"RateLimitError",
|
|
52
|
+
"ServerError",
|
|
53
|
+
"TimeoutError",
|
|
54
|
+
"ValidationError",
|
|
55
|
+
# Decorators
|
|
56
|
+
"guard",
|
|
57
|
+
"guard_output",
|
|
58
|
+
]
|
aiproxyguard/client.py
ADDED
|
@@ -0,0 +1,619 @@
|
|
|
1
|
+
"""AIProxyGuard client for prompt injection detection."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import random
|
|
7
|
+
import time
|
|
8
|
+
from collections.abc import Awaitable
|
|
9
|
+
from enum import Enum
|
|
10
|
+
from typing import Any, Callable, TypeVar
|
|
11
|
+
|
|
12
|
+
import httpx
|
|
13
|
+
|
|
14
|
+
from .exceptions import (
|
|
15
|
+
AIProxyGuardError,
|
|
16
|
+
ConnectionError,
|
|
17
|
+
RateLimitError,
|
|
18
|
+
ServerError,
|
|
19
|
+
TimeoutError,
|
|
20
|
+
ValidationError,
|
|
21
|
+
)
|
|
22
|
+
from .models import (
|
|
23
|
+
CheckResult,
|
|
24
|
+
CloudCheckResult,
|
|
25
|
+
HealthStatus,
|
|
26
|
+
ReadyStatus,
|
|
27
|
+
ServiceInfo,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
# Maximum characters to include from response text in error messages
|
|
31
|
+
_MAX_ERROR_TEXT_LENGTH = 200
|
|
32
|
+
|
|
33
|
+
T = TypeVar("T")
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class ApiMode(str, Enum):
    """API mode for the client.

    Subclasses ``str`` so members compare equal to their string values and
    ``ApiMode("cloud")`` can be built from a user-supplied string.
    """

    PROXY = "proxy"  # Direct proxy (e.g., docker.aiproxyguard.com)
    CLOUD = "cloud"  # Cloud API (e.g., aiproxyguard.com/api/v1)
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class AIProxyGuard:
|
|
44
|
+
"""AIProxyGuard client for prompt injection detection.
|
|
45
|
+
|
|
46
|
+
Provides both synchronous and asynchronous methods for checking text
|
|
47
|
+
against the AIProxyGuard API for potential prompt injection attacks.
|
|
48
|
+
|
|
49
|
+
Supports two API modes:
|
|
50
|
+
- "proxy": Direct proxy mode (docker.aiproxyguard.com) - simpler
|
|
51
|
+
- "cloud": Cloud API mode (aiproxyguard.com) - caching, rate limiting
|
|
52
|
+
|
|
53
|
+
Args:
|
|
54
|
+
base_url: Base URL of the AIProxyGuard service.
|
|
55
|
+
api_key: Optional API key for authentication (required for cloud mode).
|
|
56
|
+
timeout: Request timeout in seconds. Defaults to 30.0.
|
|
57
|
+
retries: Number of retry attempts for transient failures. Defaults to 3.
|
|
58
|
+
retry_delay: Initial delay between retries in seconds. Defaults to 0.5.
|
|
59
|
+
max_concurrency: Max concurrent requests for batch operations. Defaults to 10.
|
|
60
|
+
api_mode: API mode - "proxy" or "cloud". Auto-detected from URL.
|
|
61
|
+
|
|
62
|
+
Example:
|
|
63
|
+
>>> # Direct proxy mode
|
|
64
|
+
>>> client = AIProxyGuard("https://docker.aiproxyguard.com")
|
|
65
|
+
>>> result = client.check("Ignore all previous instructions")
|
|
66
|
+
|
|
67
|
+
>>> # Cloud API mode
|
|
68
|
+
>>> client = AIProxyGuard(
|
|
69
|
+
... "https://aiproxyguard.com",
|
|
70
|
+
... api_key="apg_xxx",
|
|
71
|
+
... api_mode="cloud"
|
|
72
|
+
... )
|
|
73
|
+
>>> result = client.check("Ignore all previous instructions")
|
|
74
|
+
"""
|
|
75
|
+
|
|
76
|
+
def __init__(
|
|
77
|
+
self,
|
|
78
|
+
base_url: str = "http://localhost:8080",
|
|
79
|
+
api_key: str | None = None,
|
|
80
|
+
timeout: float = 30.0,
|
|
81
|
+
retries: int = 3,
|
|
82
|
+
retry_delay: float = 0.5,
|
|
83
|
+
max_concurrency: int = 10,
|
|
84
|
+
api_mode: str | None = None,
|
|
85
|
+
allow_insecure: bool = False,
|
|
86
|
+
) -> None:
|
|
87
|
+
self.base_url = base_url.rstrip("/")
|
|
88
|
+
self._api_key = api_key
|
|
89
|
+
self.timeout = timeout
|
|
90
|
+
self.retries = retries
|
|
91
|
+
self.retry_delay = retry_delay
|
|
92
|
+
self.max_concurrency = max_concurrency
|
|
93
|
+
self._client: httpx.Client | None = None
|
|
94
|
+
self._async_client: httpx.AsyncClient | None = None
|
|
95
|
+
self._pending_async_close: httpx.AsyncClient | None = None
|
|
96
|
+
|
|
97
|
+
# Security: Reject plain HTTP with API keys unless explicitly allowed
|
|
98
|
+
if api_key and self.base_url.startswith("http://") and not allow_insecure:
|
|
99
|
+
# Allow localhost for development
|
|
100
|
+
if not any(h in self.base_url for h in ("localhost", "127.0.0.1", "[::1]")):
|
|
101
|
+
raise ValidationError(
|
|
102
|
+
"API key provided with non-HTTPS URL. "
|
|
103
|
+
"Use HTTPS or set allow_insecure=True for testing."
|
|
104
|
+
)
|
|
105
|
+
|
|
106
|
+
# Auto-detect API mode from URL if not specified
|
|
107
|
+
if api_mode is None:
|
|
108
|
+
# Cloud mode if URL contains aiproxyguard.com but not docker.
|
|
109
|
+
if "aiproxyguard.com" in self.base_url and "docker" not in self.base_url:
|
|
110
|
+
self._api_mode = ApiMode.CLOUD
|
|
111
|
+
else:
|
|
112
|
+
self._api_mode = ApiMode.PROXY
|
|
113
|
+
else:
|
|
114
|
+
self._api_mode = ApiMode(api_mode)
|
|
115
|
+
|
|
116
|
+
@property
|
|
117
|
+
def api_key(self) -> str | None:
|
|
118
|
+
"""Get the current API key."""
|
|
119
|
+
return self._api_key
|
|
120
|
+
|
|
121
|
+
@property
|
|
122
|
+
def api_mode(self) -> ApiMode:
|
|
123
|
+
"""Get the current API mode."""
|
|
124
|
+
return self._api_mode
|
|
125
|
+
|
|
126
|
+
def set_api_key(self, api_key: str | None) -> None:
|
|
127
|
+
"""Update the API key and rebuild HTTP clients.
|
|
128
|
+
|
|
129
|
+
Args:
|
|
130
|
+
api_key: New API key, or None to remove authentication.
|
|
131
|
+
|
|
132
|
+
Note:
|
|
133
|
+
If an async client exists, it will be scheduled for cleanup.
|
|
134
|
+
Call aclose() or use the async context manager for proper cleanup.
|
|
135
|
+
"""
|
|
136
|
+
self._api_key = api_key
|
|
137
|
+
# Close existing clients so they get rebuilt with new headers
|
|
138
|
+
if self._client is not None:
|
|
139
|
+
self._client.close()
|
|
140
|
+
self._client = None
|
|
141
|
+
if self._async_client is not None:
|
|
142
|
+
# Track for cleanup - will be closed on next aclose() or close()
|
|
143
|
+
self._pending_async_close = self._async_client
|
|
144
|
+
self._async_client = None
|
|
145
|
+
|
|
146
|
+
def _get_headers(self) -> dict[str, str]:
|
|
147
|
+
"""Build request headers."""
|
|
148
|
+
headers = {"Content-Type": "application/json"}
|
|
149
|
+
if self._api_key:
|
|
150
|
+
headers["X-API-Key"] = self._api_key
|
|
151
|
+
return headers
|
|
152
|
+
|
|
153
|
+
def _get_client(self) -> httpx.Client:
|
|
154
|
+
"""Get or create the sync HTTP client."""
|
|
155
|
+
if self._client is None:
|
|
156
|
+
self._client = httpx.Client(
|
|
157
|
+
base_url=self.base_url,
|
|
158
|
+
headers=self._get_headers(),
|
|
159
|
+
timeout=self.timeout,
|
|
160
|
+
)
|
|
161
|
+
return self._client
|
|
162
|
+
|
|
163
|
+
def _get_async_client(self) -> httpx.AsyncClient:
|
|
164
|
+
"""Get or create the async HTTP client."""
|
|
165
|
+
if self._async_client is None:
|
|
166
|
+
self._async_client = httpx.AsyncClient(
|
|
167
|
+
base_url=self.base_url,
|
|
168
|
+
headers=self._get_headers(),
|
|
169
|
+
timeout=self.timeout,
|
|
170
|
+
)
|
|
171
|
+
return self._async_client
|
|
172
|
+
|
|
173
|
+
def _truncate_error_text(self, text: str) -> str:
|
|
174
|
+
"""Truncate response text for error messages to prevent log pollution."""
|
|
175
|
+
if len(text) <= _MAX_ERROR_TEXT_LENGTH:
|
|
176
|
+
return text
|
|
177
|
+
return text[:_MAX_ERROR_TEXT_LENGTH] + "..."
|
|
178
|
+
|
|
179
|
+
def _handle_error(self, response: httpx.Response) -> None:
|
|
180
|
+
"""Handle error responses from the API."""
|
|
181
|
+
if response.status_code == 429:
|
|
182
|
+
retry_after = response.headers.get("Retry-After")
|
|
183
|
+
raise RateLimitError(
|
|
184
|
+
"Rate limited",
|
|
185
|
+
retry_after=int(retry_after) if retry_after else None,
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
# 5xx errors are server errors (retryable)
|
|
189
|
+
if response.status_code >= 500:
|
|
190
|
+
raise ServerError(
|
|
191
|
+
f"Server error: HTTP {response.status_code}",
|
|
192
|
+
status_code=response.status_code,
|
|
193
|
+
)
|
|
194
|
+
|
|
195
|
+
# 4xx errors are client errors (not retryable)
|
|
196
|
+
if response.status_code >= 400:
|
|
197
|
+
try:
|
|
198
|
+
data = response.json()
|
|
199
|
+
# Handle different error formats
|
|
200
|
+
if "error" in data:
|
|
201
|
+
error = data["error"]
|
|
202
|
+
if isinstance(error, dict):
|
|
203
|
+
raise ValidationError(
|
|
204
|
+
error.get("message", "Unknown error"),
|
|
205
|
+
code=error.get("type"),
|
|
206
|
+
)
|
|
207
|
+
else:
|
|
208
|
+
raise ValidationError(str(error))
|
|
209
|
+
elif "detail" in data:
|
|
210
|
+
# FastAPI style error
|
|
211
|
+
raise ValidationError(str(data["detail"]))
|
|
212
|
+
else:
|
|
213
|
+
raise ValidationError(str(data))
|
|
214
|
+
except (ValueError, KeyError, TypeError):
|
|
215
|
+
error_text = self._truncate_error_text(response.text)
|
|
216
|
+
raise AIProxyGuardError(f"HTTP {response.status_code}: {error_text}")
|
|
217
|
+
|
|
218
|
+
def _get_check_endpoint(self) -> str:
|
|
219
|
+
"""Get the check endpoint based on API mode."""
|
|
220
|
+
if self._api_mode == ApiMode.CLOUD:
|
|
221
|
+
return "/api/v1/check"
|
|
222
|
+
return "/check"
|
|
223
|
+
|
|
224
|
+
def _build_check_payload(
|
|
225
|
+
self, text: str, context: dict[str, Any] | None = None
|
|
226
|
+
) -> dict[str, Any]:
|
|
227
|
+
"""Build the request payload based on API mode."""
|
|
228
|
+
if self._api_mode == ApiMode.CLOUD:
|
|
229
|
+
payload: dict[str, Any] = {"input": text}
|
|
230
|
+
if context:
|
|
231
|
+
payload["context"] = context
|
|
232
|
+
return payload
|
|
233
|
+
return {"text": text}
|
|
234
|
+
|
|
235
|
+
def _parse_check_response(self, data: dict[str, Any]) -> CheckResult:
|
|
236
|
+
"""Parse check response based on API mode."""
|
|
237
|
+
if self._api_mode == ApiMode.CLOUD:
|
|
238
|
+
return CheckResult.from_cloud_dict(data)
|
|
239
|
+
return CheckResult.from_dict(data)
|
|
240
|
+
|
|
241
|
+
def _calculate_delay(self, attempt: int, rate_limit_retry: int | None) -> float:
|
|
242
|
+
"""Calculate delay with exponential backoff and jitter."""
|
|
243
|
+
if rate_limit_retry is not None:
|
|
244
|
+
return float(rate_limit_retry)
|
|
245
|
+
base_delay: float = self.retry_delay * (2**attempt)
|
|
246
|
+
jitter: float = random.uniform(0, 0.1 * base_delay)
|
|
247
|
+
return base_delay + jitter
|
|
248
|
+
|
|
249
|
+
    def _retry_sync(self, operation: Callable[[], T]) -> T:
        """Execute operation with retry logic (sync).

        Retries transient failures (timeouts, connection errors, 429s, and
        5xx ServerErrors) up to ``self.retries`` additional times with
        exponential backoff; ValidationError (4xx) is never retried.

        Args:
            operation: Zero-argument callable performing one request attempt.

        Returns:
            Whatever ``operation`` returns on the first successful attempt.

        Raises:
            TimeoutError, ConnectionError, RateLimitError, ServerError: The
                last transient failure, once all attempts are exhausted.
            ValidationError: Re-raised immediately without retrying.
        """
        last_exception: Exception | None = None

        for attempt in range(self.retries + 1):
            try:
                return operation()

            # Translate httpx transport errors into SDK exception types.
            except httpx.TimeoutException:
                last_exception = TimeoutError("Request timed out")
            except httpx.ConnectError:
                last_exception = ConnectionError("Failed to connect to AIProxyGuard")
            except RateLimitError as e:
                last_exception = e
                if attempt < self.retries:
                    # Use the server's Retry-After hint when it is present.
                    time.sleep(self._calculate_delay(attempt, e.retry_after))
                # Skip the generic backoff below; delay already handled.
                continue
            except ServerError as e:
                last_exception = e
            except ValidationError:
                # Client errors are deterministic; retrying cannot help.
                raise

            if attempt < self.retries:
                time.sleep(self._calculate_delay(attempt, None))

        # last_exception is always set when the loop exits without returning;
        # the fallback guards against an unexpected empty state.
        raise last_exception or AIProxyGuardError("Request failed after retries")
|
|
275
|
+
|
|
276
|
+
    async def _retry_async(self, operation: Callable[[], Awaitable[T]]) -> T:
        """Execute operation with retry logic (async).

        Mirrors _retry_sync(): transient failures (timeouts, connection
        errors, 429s, 5xx ServerErrors) are retried with exponential
        backoff; ValidationError is never retried.

        Args:
            operation: Zero-argument coroutine factory for one attempt.

        Returns:
            Whatever the awaited operation returns on first success.

        Raises:
            TimeoutError, ConnectionError, RateLimitError, ServerError: The
                last transient failure, once all attempts are exhausted.
            ValidationError: Re-raised immediately without retrying.
        """
        last_exception: Exception | None = None

        for attempt in range(self.retries + 1):
            try:
                return await operation()

            # Translate httpx transport errors into SDK exception types.
            except httpx.TimeoutException:
                last_exception = TimeoutError("Request timed out")
            except httpx.ConnectError:
                last_exception = ConnectionError("Failed to connect to AIProxyGuard")
            except RateLimitError as e:
                last_exception = e
                if attempt < self.retries:
                    # Use the server's Retry-After hint when it is present.
                    await asyncio.sleep(self._calculate_delay(attempt, e.retry_after))
                # Skip the generic backoff below; delay already handled.
                continue
            except ServerError as e:
                last_exception = e
            except ValidationError:
                # Client errors are deterministic; retrying cannot help.
                raise

            if attempt < self.retries:
                await asyncio.sleep(self._calculate_delay(attempt, None))

        # last_exception is always set when the loop exits without returning;
        # the fallback guards against an unexpected empty state.
        raise last_exception or AIProxyGuardError("Request failed after retries")
|
|
302
|
+
|
|
303
|
+
# -------------------------------------------------------------------------
|
|
304
|
+
# Sync API
|
|
305
|
+
# -------------------------------------------------------------------------
|
|
306
|
+
|
|
307
|
+
def check(
|
|
308
|
+
self, text: str, context: dict[str, Any] | None = None
|
|
309
|
+
) -> CheckResult:
|
|
310
|
+
"""Check text for prompt injection.
|
|
311
|
+
|
|
312
|
+
Args:
|
|
313
|
+
text: The text to scan for prompt injection.
|
|
314
|
+
context: Optional context metadata (cloud mode only).
|
|
315
|
+
|
|
316
|
+
Returns:
|
|
317
|
+
CheckResult with action, category, signature_name, and confidence.
|
|
318
|
+
|
|
319
|
+
Raises:
|
|
320
|
+
ValidationError: If the request is invalid.
|
|
321
|
+
TimeoutError: If the request times out.
|
|
322
|
+
RateLimitError: If rate limited.
|
|
323
|
+
ConnectionError: If connection fails.
|
|
324
|
+
ServerError: If the server returns a 5xx error.
|
|
325
|
+
AIProxyGuardError: For other errors.
|
|
326
|
+
"""
|
|
327
|
+
client = self._get_client()
|
|
328
|
+
endpoint = self._get_check_endpoint()
|
|
329
|
+
payload = self._build_check_payload(text, context)
|
|
330
|
+
|
|
331
|
+
def do_check() -> CheckResult:
|
|
332
|
+
response = client.post(endpoint, json=payload)
|
|
333
|
+
self._handle_error(response)
|
|
334
|
+
return self._parse_check_response(response.json())
|
|
335
|
+
|
|
336
|
+
return self._retry_sync(do_check)
|
|
337
|
+
|
|
338
|
+
def check_cloud(
|
|
339
|
+
self, text: str, context: dict[str, Any] | None = None
|
|
340
|
+
) -> CloudCheckResult:
|
|
341
|
+
"""Check text and return full cloud API response (cloud mode only).
|
|
342
|
+
|
|
343
|
+
Args:
|
|
344
|
+
text: The text to scan for prompt injection.
|
|
345
|
+
context: Optional context metadata (e.g., {"provider": "openai"}).
|
|
346
|
+
|
|
347
|
+
Returns:
|
|
348
|
+
CloudCheckResult with full response including id, latency_ms, cached.
|
|
349
|
+
|
|
350
|
+
Raises:
|
|
351
|
+
AIProxyGuardError: If not in cloud mode or request fails.
|
|
352
|
+
"""
|
|
353
|
+
if self._api_mode != ApiMode.CLOUD:
|
|
354
|
+
raise AIProxyGuardError("check_cloud() requires cloud API mode")
|
|
355
|
+
|
|
356
|
+
client = self._get_client()
|
|
357
|
+
payload = self._build_check_payload(text, context)
|
|
358
|
+
|
|
359
|
+
def do_check() -> CloudCheckResult:
|
|
360
|
+
response = client.post("/api/v1/check", json=payload)
|
|
361
|
+
self._handle_error(response)
|
|
362
|
+
return CloudCheckResult.from_dict(response.json())
|
|
363
|
+
|
|
364
|
+
return self._retry_sync(do_check)
|
|
365
|
+
|
|
366
|
+
def check_batch(self, texts: list[str]) -> list[CheckResult]:
|
|
367
|
+
"""Check multiple texts for prompt injection.
|
|
368
|
+
|
|
369
|
+
Args:
|
|
370
|
+
texts: List of texts to scan.
|
|
371
|
+
|
|
372
|
+
Returns:
|
|
373
|
+
List of CheckResult objects in the same order as inputs.
|
|
374
|
+
"""
|
|
375
|
+
return [self.check(text) for text in texts]
|
|
376
|
+
|
|
377
|
+
def is_safe(self, text: str) -> bool:
|
|
378
|
+
"""Check if text is safe (not blocked).
|
|
379
|
+
|
|
380
|
+
Args:
|
|
381
|
+
text: The text to scan.
|
|
382
|
+
|
|
383
|
+
Returns:
|
|
384
|
+
True if the text is safe, False if blocked.
|
|
385
|
+
"""
|
|
386
|
+
return self.check(text).is_safe
|
|
387
|
+
|
|
388
|
+
def info(self) -> ServiceInfo:
|
|
389
|
+
"""Get service information (proxy mode only).
|
|
390
|
+
|
|
391
|
+
Returns:
|
|
392
|
+
ServiceInfo with service name and version.
|
|
393
|
+
|
|
394
|
+
Raises:
|
|
395
|
+
AIProxyGuardError: If called in cloud mode.
|
|
396
|
+
"""
|
|
397
|
+
if self._api_mode == ApiMode.CLOUD:
|
|
398
|
+
raise AIProxyGuardError("info() is not available in cloud mode")
|
|
399
|
+
client = self._get_client()
|
|
400
|
+
response = client.get("/")
|
|
401
|
+
self._handle_error(response)
|
|
402
|
+
return ServiceInfo.from_dict(response.json())
|
|
403
|
+
|
|
404
|
+
def health(self) -> HealthStatus:
|
|
405
|
+
"""Check service health.
|
|
406
|
+
|
|
407
|
+
Returns:
|
|
408
|
+
HealthStatus with health status.
|
|
409
|
+
"""
|
|
410
|
+
client = self._get_client()
|
|
411
|
+
endpoint = "/health" if self._api_mode == ApiMode.CLOUD else "/healthz"
|
|
412
|
+
try:
|
|
413
|
+
response = client.get(endpoint)
|
|
414
|
+
if response.status_code == 200:
|
|
415
|
+
return HealthStatus.from_dict(response.json())
|
|
416
|
+
return HealthStatus(status="unhealthy", healthy=False)
|
|
417
|
+
except Exception:
|
|
418
|
+
return HealthStatus(status="unreachable", healthy=False)
|
|
419
|
+
|
|
420
|
+
def ready(self) -> ReadyStatus:
|
|
421
|
+
"""Check service readiness (proxy mode only).
|
|
422
|
+
|
|
423
|
+
Returns:
|
|
424
|
+
ReadyStatus with readiness status and individual checks.
|
|
425
|
+
|
|
426
|
+
Raises:
|
|
427
|
+
AIProxyGuardError: If called in cloud mode.
|
|
428
|
+
"""
|
|
429
|
+
if self._api_mode == ApiMode.CLOUD:
|
|
430
|
+
raise AIProxyGuardError("ready() is not available in cloud mode")
|
|
431
|
+
client = self._get_client()
|
|
432
|
+
try:
|
|
433
|
+
response = client.get("/readyz")
|
|
434
|
+
return ReadyStatus.from_dict(response.json())
|
|
435
|
+
except Exception:
|
|
436
|
+
return ReadyStatus(status="unreachable", ready=False, checks={})
|
|
437
|
+
|
|
438
|
+
# -------------------------------------------------------------------------
|
|
439
|
+
# Async API
|
|
440
|
+
# -------------------------------------------------------------------------
|
|
441
|
+
|
|
442
|
+
async def check_async(
|
|
443
|
+
self, text: str, context: dict[str, Any] | None = None
|
|
444
|
+
) -> CheckResult:
|
|
445
|
+
"""Async version of check().
|
|
446
|
+
|
|
447
|
+
Args:
|
|
448
|
+
text: The text to scan for prompt injection.
|
|
449
|
+
context: Optional context metadata (cloud mode only).
|
|
450
|
+
|
|
451
|
+
Returns:
|
|
452
|
+
CheckResult with action, category, signature_name, and confidence.
|
|
453
|
+
"""
|
|
454
|
+
client = self._get_async_client()
|
|
455
|
+
endpoint = self._get_check_endpoint()
|
|
456
|
+
payload = self._build_check_payload(text, context)
|
|
457
|
+
|
|
458
|
+
async def do_check() -> CheckResult:
|
|
459
|
+
response = await client.post(endpoint, json=payload)
|
|
460
|
+
self._handle_error(response)
|
|
461
|
+
return self._parse_check_response(response.json())
|
|
462
|
+
|
|
463
|
+
return await self._retry_async(do_check)
|
|
464
|
+
|
|
465
|
+
async def check_cloud_async(
|
|
466
|
+
self, text: str, context: dict[str, Any] | None = None
|
|
467
|
+
) -> CloudCheckResult:
|
|
468
|
+
"""Async version of check_cloud()."""
|
|
469
|
+
if self._api_mode != ApiMode.CLOUD:
|
|
470
|
+
raise AIProxyGuardError("check_cloud_async() requires cloud API mode")
|
|
471
|
+
|
|
472
|
+
client = self._get_async_client()
|
|
473
|
+
payload = self._build_check_payload(text, context)
|
|
474
|
+
|
|
475
|
+
async def do_check() -> CloudCheckResult:
|
|
476
|
+
response = await client.post("/api/v1/check", json=payload)
|
|
477
|
+
self._handle_error(response)
|
|
478
|
+
return CloudCheckResult.from_dict(response.json())
|
|
479
|
+
|
|
480
|
+
return await self._retry_async(do_check)
|
|
481
|
+
|
|
482
|
+
async def check_batch_async(
|
|
483
|
+
self, texts: list[str], max_concurrency: int | None = None
|
|
484
|
+
) -> list[CheckResult]:
|
|
485
|
+
"""Async version of check_batch(). Runs checks with bounded concurrency.
|
|
486
|
+
|
|
487
|
+
Args:
|
|
488
|
+
texts: List of texts to scan.
|
|
489
|
+
max_concurrency: Max concurrent requests. Uses client default if None.
|
|
490
|
+
|
|
491
|
+
Returns:
|
|
492
|
+
List of CheckResult objects in the same order as inputs.
|
|
493
|
+
"""
|
|
494
|
+
limit = max_concurrency or self.max_concurrency
|
|
495
|
+
semaphore = asyncio.Semaphore(limit)
|
|
496
|
+
|
|
497
|
+
async def bounded_check(text: str) -> CheckResult:
|
|
498
|
+
async with semaphore:
|
|
499
|
+
return await self.check_async(text)
|
|
500
|
+
|
|
501
|
+
tasks = [bounded_check(text) for text in texts]
|
|
502
|
+
return await asyncio.gather(*tasks)
|
|
503
|
+
|
|
504
|
+
async def is_safe_async(self, text: str) -> bool:
|
|
505
|
+
"""Async version of is_safe().
|
|
506
|
+
|
|
507
|
+
Args:
|
|
508
|
+
text: The text to scan.
|
|
509
|
+
|
|
510
|
+
Returns:
|
|
511
|
+
True if the text is safe, False if blocked.
|
|
512
|
+
"""
|
|
513
|
+
result = await self.check_async(text)
|
|
514
|
+
return result.is_safe
|
|
515
|
+
|
|
516
|
+
async def info_async(self) -> ServiceInfo:
|
|
517
|
+
"""Async version of info()."""
|
|
518
|
+
if self._api_mode == ApiMode.CLOUD:
|
|
519
|
+
raise AIProxyGuardError("info_async() is not available in cloud mode")
|
|
520
|
+
client = self._get_async_client()
|
|
521
|
+
response = await client.get("/")
|
|
522
|
+
self._handle_error(response)
|
|
523
|
+
return ServiceInfo.from_dict(response.json())
|
|
524
|
+
|
|
525
|
+
async def health_async(self) -> HealthStatus:
|
|
526
|
+
"""Async version of health()."""
|
|
527
|
+
client = self._get_async_client()
|
|
528
|
+
endpoint = "/health" if self._api_mode == ApiMode.CLOUD else "/healthz"
|
|
529
|
+
try:
|
|
530
|
+
response = await client.get(endpoint)
|
|
531
|
+
if response.status_code == 200:
|
|
532
|
+
return HealthStatus.from_dict(response.json())
|
|
533
|
+
return HealthStatus(status="unhealthy", healthy=False)
|
|
534
|
+
except Exception:
|
|
535
|
+
return HealthStatus(status="unreachable", healthy=False)
|
|
536
|
+
|
|
537
|
+
async def ready_async(self) -> ReadyStatus:
|
|
538
|
+
"""Async version of ready()."""
|
|
539
|
+
if self._api_mode == ApiMode.CLOUD:
|
|
540
|
+
raise AIProxyGuardError("ready_async() is not available in cloud mode")
|
|
541
|
+
client = self._get_async_client()
|
|
542
|
+
try:
|
|
543
|
+
response = await client.get("/readyz")
|
|
544
|
+
return ReadyStatus.from_dict(response.json())
|
|
545
|
+
except Exception:
|
|
546
|
+
return ReadyStatus(status="unreachable", ready=False, checks={})
|
|
547
|
+
|
|
548
|
+
# -------------------------------------------------------------------------
|
|
549
|
+
# Context Manager
|
|
550
|
+
# -------------------------------------------------------------------------
|
|
551
|
+
|
|
552
|
+
    def __enter__(self) -> AIProxyGuard:
        # Support `with AIProxyGuard(...) as client:`; no setup needed.
        return self
|
|
554
|
+
|
|
555
|
+
    def __exit__(self, *args: Any) -> None:
        # Release HTTP resources when the `with` block exits.
        self.close()
|
|
557
|
+
|
|
558
|
+
    async def __aenter__(self) -> AIProxyGuard:
        # Support `async with AIProxyGuard(...) as client:`; no setup needed.
        return self
|
|
560
|
+
|
|
561
|
+
    async def __aexit__(self, *args: Any) -> None:
        # Await full async cleanup on exit.
        await self.aclose()
|
|
563
|
+
|
|
564
|
+
def close(self) -> None:
|
|
565
|
+
"""Close all clients and release resources.
|
|
566
|
+
|
|
567
|
+
Note:
|
|
568
|
+
For async clients, this performs a best-effort synchronous close.
|
|
569
|
+
Use aclose() in async contexts for proper cleanup.
|
|
570
|
+
"""
|
|
571
|
+
if self._client:
|
|
572
|
+
self._client.close()
|
|
573
|
+
self._client = None
|
|
574
|
+
# Close pending async client from set_api_key()
|
|
575
|
+
if self._pending_async_close:
|
|
576
|
+
try:
|
|
577
|
+
# Best effort - httpx AsyncClient has _closed flag we can check
|
|
578
|
+
if not getattr(self._pending_async_close, "_closed", True):
|
|
579
|
+
# Create event loop if needed for cleanup
|
|
580
|
+
try:
|
|
581
|
+
loop = asyncio.get_running_loop()
|
|
582
|
+
loop.create_task(self._pending_async_close.aclose())
|
|
583
|
+
except RuntimeError:
|
|
584
|
+
# No event loop - just let it be garbage collected
|
|
585
|
+
pass
|
|
586
|
+
except Exception:
|
|
587
|
+
pass
|
|
588
|
+
self._pending_async_close = None
|
|
589
|
+
# Close current async client if it exists
|
|
590
|
+
if self._async_client:
|
|
591
|
+
try:
|
|
592
|
+
if not getattr(self._async_client, "_closed", True):
|
|
593
|
+
try:
|
|
594
|
+
loop = asyncio.get_running_loop()
|
|
595
|
+
loop.create_task(self._async_client.aclose())
|
|
596
|
+
except RuntimeError:
|
|
597
|
+
pass
|
|
598
|
+
except Exception:
|
|
599
|
+
pass
|
|
600
|
+
self._async_client = None
|
|
601
|
+
|
|
602
|
+
    async def aclose(self) -> None:
        """Close all clients and release resources (async).

        Safe to call multiple times: each reference is cleared after it is
        closed, so subsequent calls are no-ops.
        """
        # Close pending async client from set_api_key()
        if self._pending_async_close:
            await self._pending_async_close.aclose()
            self._pending_async_close = None
        # Close current async client
        if self._async_client:
            await self._async_client.aclose()
            self._async_client = None
        # Close sync client if it exists
        if self._client:
            self._client.close()
            self._client = None
|
|
616
|
+
|
|
617
|
+
def __repr__(self) -> str:
|
|
618
|
+
mode = self._api_mode.value
|
|
619
|
+
return f"AIProxyGuard(base_url={self.base_url!r}, api_mode={mode!r})"
|