koreshield 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,263 @@
1
+ """Asynchronous KoreShield client."""
2
+
3
+ import asyncio
4
+ import time
5
+ from typing import Dict, List, Optional, Any, Union
6
+ import httpx
7
+
8
+ from .types import (
9
+ AuthConfig,
10
+ ScanRequest,
11
+ ScanResponse,
12
+ BatchScanRequest,
13
+ BatchScanResponse,
14
+ DetectionResult,
15
+ )
16
+ from .exceptions import (
17
+ KoreShieldError,
18
+ AuthenticationError,
19
+ ValidationError,
20
+ RateLimitError,
21
+ ServerError,
22
+ NetworkError,
23
+ TimeoutError,
24
+ )
25
+
26
+
27
class AsyncKoreShieldClient:
    """Asynchronous KoreShield API client.

    Wraps an ``httpx.AsyncClient`` that sends the API key as a bearer token.
    Use it as an async context manager (``async with AsyncKoreShieldClient(...)``)
    or call :meth:`close` to release the underlying connections.
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://api.koreshield.com",
        timeout: float = 30.0,
        retry_attempts: int = 3,
        retry_delay: float = 1.0,
    ):
        """Initialize the async KoreShield client.

        Args:
            api_key: Your KoreShield API key
            base_url: Base URL for the API (default: production); a trailing
                slash is stripped
            timeout: Request timeout in seconds
            retry_attempts: Number of retries after the initial attempt when
                :meth:`scan_prompt` hits a transient error
            retry_delay: Base delay between retries in seconds; doubled after
                each failed attempt (exponential backoff)
        """
        self.auth_config = AuthConfig(
            api_key=api_key,
            base_url=base_url.rstrip("/"),
            timeout=timeout,
            retry_attempts=retry_attempts,
            retry_delay=retry_delay,
        )

        self.client = httpx.AsyncClient(
            timeout=timeout,
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
                "User-Agent": "koreshield-python-sdk/0.1.0",
            },
        )

    async def __aenter__(self):
        """Async context manager entry; returns the client itself."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit; closes the HTTP client."""
        await self.close()

    async def close(self):
        """Close the underlying ``httpx.AsyncClient``."""
        await self.client.aclose()

    async def scan_prompt(self, prompt: str, **kwargs) -> DetectionResult:
        """Scan a single prompt for security threats asynchronously.

        Rate-limit, server and network errors are retried up to
        ``retry_attempts`` times with exponential backoff before being raised.

        Args:
            prompt: The prompt text to scan
            **kwargs: Additional context (user_id, session_id, metadata, etc.)

        Returns:
            DetectionResult with security analysis

        Raises:
            AuthenticationError: If API key is invalid
            ValidationError: If request is malformed
            RateLimitError: If rate limit exceeded (after retries)
            ServerError: If server error occurs (after retries)
            NetworkError: If network error occurs (after retries)
            TimeoutError: If request times out
        """
        request = ScanRequest(prompt=prompt, **kwargs)
        # Pydantic v2 serialization, matching the synchronous client;
        # ``.dict()`` is deprecated in v2.
        payload = request.model_dump()
        # A negative setting would make the loop below run zero times and the
        # method silently return None; always make at least one attempt.
        retries = max(self.auth_config.retry_attempts, 0)

        for attempt in range(retries + 1):
            try:
                response = await self._make_request("POST", "/v1/scan", payload)
            except (RateLimitError, ServerError, NetworkError):
                if attempt == retries:
                    raise
                await asyncio.sleep(self.auth_config.retry_delay * (2 ** attempt))
            else:
                return ScanResponse(**response).result

    async def scan_batch(
        self,
        prompts: List[str],
        parallel: bool = True,
        max_concurrent: int = 10,
        **kwargs
    ) -> List[DetectionResult]:
        """Scan multiple prompts for security threats asynchronously.

        Each prompt is sent as its own :meth:`scan_prompt` call, so per-prompt
        retries apply.

        Args:
            prompts: List of prompt texts to scan
            parallel: Whether to process in parallel (default: True)
            max_concurrent: Maximum concurrent requests when ``parallel`` is
                True (default: 10); must be at least 1
            **kwargs: Additional context for all requests

        Returns:
            List of DetectionResult objects, in the same order as ``prompts``

        Raises:
            ValueError: If ``parallel`` is True and ``max_concurrent`` < 1
            KoreShieldError: The first error raised by any individual scan
        """
        if not parallel:
            # Sequential processing, one request at a time.
            return [await self.scan_prompt(prompt, **kwargs) for prompt in prompts]

        if max_concurrent < 1:
            # Semaphore(0) would block every scan forever.
            raise ValueError("max_concurrent must be at least 1")

        # Parallel processing with a semaphore for concurrency control.
        semaphore = asyncio.Semaphore(max_concurrent)

        async def scan_with_semaphore(prompt: str) -> DetectionResult:
            async with semaphore:
                return await self.scan_prompt(prompt, **kwargs)

        tasks = [asyncio.create_task(scan_with_semaphore(prompt)) for prompt in prompts]
        try:
            return await asyncio.gather(*tasks)
        except BaseException:
            # gather() does not cancel the remaining awaitables when one fails;
            # cancel them so no requests keep running after the caller has
            # received the error, and drain them so their exceptions are
            # retrieved rather than reported as "never retrieved".
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise

    async def get_scan_history(self, limit: int = 50, offset: int = 0, **filters) -> Dict[str, Any]:
        """Get scan history with optional filters asynchronously.

        Args:
            limit: Maximum number of results (default: 50)
            offset: Offset for pagination (default: 0)
            **filters: Additional filters (user_id, threat_level, etc.),
                sent as query parameters

        Returns:
            Dictionary with scan history and pagination info
        """
        params = {"limit": limit, "offset": offset, **filters}
        return await self._make_request("GET", "/v1/scans", params=params)

    async def get_scan_details(self, scan_id: str) -> Dict[str, Any]:
        """Get detailed information about a specific scan asynchronously.

        Args:
            scan_id: The scan ID to retrieve

        Returns:
            Dictionary with scan details
        """
        from urllib.parse import quote

        # Percent-encode the ID so characters such as "/", "?" or "#" cannot
        # change which endpoint the request hits.
        safe_id = quote(str(scan_id), safe="")
        return await self._make_request("GET", f"/v1/scans/{safe_id}")

    async def health_check(self) -> Dict[str, Any]:
        """Check API health and version information asynchronously.

        Returns:
            Dictionary with health status and version info
        """
        return await self._make_request("GET", "/health")

    async def _make_request(
        self,
        method: str,
        endpoint: str,
        data: Optional[Dict] = None,
        params: Optional[Dict] = None
    ) -> Dict[str, Any]:
        """Make an asynchronous HTTP request to the API.

        Args:
            method: HTTP method (GET, POST, etc.)
            endpoint: API endpoint path, appended to the base URL
            data: Request body data, sent as JSON
            params: Query parameters

        Returns:
            Parsed JSON response

        Raises:
            TimeoutError: If the request times out
            NetworkError: If the connection fails or another transport error occurs
            Various KoreShieldError subclasses based on the response status
        """
        url = f"{self.auth_config.base_url}{endpoint}"

        try:
            response = await self.client.request(
                method=method,
                url=url,
                json=data,
                params=params,
            )
        except httpx.TimeoutException as exc:
            raise TimeoutError("Request timed out") from exc
        except httpx.ConnectError as exc:
            raise NetworkError("Network connection failed") from exc
        except httpx.RequestError as exc:
            raise NetworkError(f"Request failed: {exc}") from exc

        return self._handle_response(response)

    def _handle_response(self, response: httpx.Response) -> Dict[str, Any]:
        """Handle API response and raise appropriate exceptions.

        Args:
            response: The HTTP response object

        Returns:
            Parsed JSON response data for any 2xx status ({} for an empty body)

        Raises:
            AuthenticationError: On HTTP 401
            ValidationError: On HTTP 400
            RateLimitError: On HTTP 429
            ServerError: On HTTP 5xx
            KoreShieldError: On any other non-2xx status
        """
        try:
            data = response.json() if response.content else {}
        except ValueError:
            data = {"message": "Invalid JSON response"}

        # Any 2xx (e.g. 201 Created, 204 No Content) is a success, not just 200.
        if 200 <= response.status_code < 300:
            return data

        # Error bodies are not guaranteed to be JSON objects (a bare string or
        # list is valid JSON); wrap them so the .get() lookups below are safe.
        if not isinstance(data, dict):
            data = {"message": data} if isinstance(data, str) else {"response": data}

        status = response.status_code
        if status == 401:
            raise AuthenticationError(
                data.get("message", "Authentication failed"),
                status_code=status,
                response_data=data,
            )
        elif status == 400:
            raise ValidationError(
                data.get("message", "Validation failed"),
                status_code=status,
                response_data=data,
            )
        elif status == 429:
            raise RateLimitError(
                data.get("message", "Rate limit exceeded"),
                status_code=status,
                response_data=data,
            )
        elif status >= 500:
            raise ServerError(
                data.get("message", "Server error"),
                status_code=status,
                response_data=data,
            )
        else:
            raise KoreShieldError(
                data.get("message", f"Unexpected error: {status}"),
                status_code=status,
                response_data=data,
            )
@@ -0,0 +1,227 @@
1
+ """Synchronous KoreShield client."""
2
+
3
+ import time
4
+ from typing import Dict, List, Optional, Any, Union
5
+ import requests
6
+ from requests.adapters import HTTPAdapter
7
+ from urllib3.util.retry import Retry
8
+
9
+ from .types import (
10
+ AuthConfig,
11
+ ScanRequest,
12
+ ScanResponse,
13
+ BatchScanRequest,
14
+ BatchScanResponse,
15
+ DetectionResult,
16
+ )
17
+ from .exceptions import (
18
+ KoreShieldError,
19
+ AuthenticationError,
20
+ ValidationError,
21
+ RateLimitError,
22
+ ServerError,
23
+ NetworkError,
24
+ TimeoutError,
25
+ )
26
+
27
+
28
class KoreShieldClient:
    """Synchronous KoreShield API client.

    Backed by a ``requests.Session`` that retries transient HTTP statuses.
    Use it as a context manager (``with KoreShieldClient(...) as c:``) or call
    :meth:`close` to release pooled connections.
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://api.koreshield.com",
        timeout: float = 30.0,
        retry_attempts: int = 3,
        retry_delay: float = 1.0,
    ):
        """Initialize the KoreShield client.

        Args:
            api_key: Your KoreShield API key
            base_url: Base URL for the API (default: production); a trailing
                slash is stripped
            timeout: Request timeout in seconds
            retry_attempts: Maximum number of retries for transient failures
                (default: 3)
            retry_delay: Backoff factor in seconds passed to urllib3's
                ``Retry`` (default: 1.0)
        """
        self.auth_config = AuthConfig(
            api_key=api_key,
            base_url=base_url.rstrip("/"),
            timeout=timeout,
            retry_attempts=retry_attempts,
            retry_delay=retry_delay,
        )

        self.session = requests.Session()

        # Retry rate limiting and 5xx responses with exponential backoff.
        # raise_on_status=False makes urllib3 return the final response once
        # retries are exhausted, so _handle_response can raise RateLimitError /
        # ServerError. Otherwise requests raises a generic RetryError, which
        # _make_request would report as NetworkError.
        # NOTE: urllib3 retries only idempotent methods by default, so POST
        # requests (e.g. /v1/scan) are not retried by this strategy.
        retry_strategy = Retry(
            total=retry_attempts,
            status_forcelist=[429, 500, 502, 503, 504],
            backoff_factor=retry_delay,
            raise_on_status=False,
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        self.session.mount("http://", adapter)
        self.session.mount("https://", adapter)

        # Set default headers
        self.session.headers.update({
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
            "User-Agent": "koreshield-python-sdk/0.1.0",
        })

    def __enter__(self):
        """Context manager entry; returns the client itself."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit; closes the HTTP session."""
        self.close()

    def close(self):
        """Close the underlying ``requests.Session`` and its pooled connections."""
        self.session.close()

    def scan_prompt(self, prompt: str, **kwargs) -> DetectionResult:
        """Scan a single prompt for security threats.

        Args:
            prompt: The prompt text to scan
            **kwargs: Additional context (user_id, session_id, metadata, etc.)

        Returns:
            DetectionResult with security analysis

        Raises:
            AuthenticationError: If API key is invalid
            ValidationError: If request is malformed
            RateLimitError: If rate limit exceeded
            ServerError: If server error occurs
            NetworkError: If network error occurs
            TimeoutError: If request times out
        """
        request = ScanRequest(prompt=prompt, **kwargs)
        response = self._make_request("POST", "/v1/scan", request.model_dump())

        scan_response = ScanResponse(**response)
        return scan_response.result

    def scan_batch(
        self,
        prompts: List[str],
        parallel: bool = True,
        max_concurrent: int = 10,
        **kwargs
    ) -> List[DetectionResult]:
        """Scan multiple prompts for security threats in one batch request.

        Args:
            prompts: List of prompt texts to scan
            parallel: Whether the server should process in parallel (default: True)
            max_concurrent: Maximum concurrent scans server-side (default: 10)
            **kwargs: Additional context applied to every request

        Returns:
            List of DetectionResult objects ([] for an empty ``prompts`` list)
        """
        if not prompts:
            # Nothing to scan; avoid a round trip, matching the async client.
            return []

        batch_request = BatchScanRequest(
            requests=[ScanRequest(prompt=prompt, **kwargs) for prompt in prompts],
            parallel=parallel,
            max_concurrent=max_concurrent,
        )

        response = self._make_request("POST", "/v1/scan/batch", batch_request.model_dump())
        batch_response = BatchScanResponse(**response)

        return [scan_response.result for scan_response in batch_response.results]

    def get_scan_history(self, limit: int = 50, offset: int = 0, **filters) -> Dict[str, Any]:
        """Get scan history with optional filters.

        Args:
            limit: Maximum number of results (default: 50)
            offset: Offset for pagination (default: 0)
            **filters: Additional filters (user_id, threat_level, etc.),
                sent as query parameters

        Returns:
            Dictionary with scan history and pagination info
        """
        params = {"limit": limit, "offset": offset, **filters}
        return self._make_request("GET", "/v1/scans", params=params)

    def get_scan_details(self, scan_id: str) -> Dict[str, Any]:
        """Get detailed information about a specific scan.

        Args:
            scan_id: The scan ID to retrieve

        Returns:
            Dictionary with scan details
        """
        from urllib.parse import quote

        # Percent-encode the ID so characters such as "/", "?" or "#" cannot
        # change which endpoint the request hits.
        safe_id = quote(str(scan_id), safe="")
        return self._make_request("GET", f"/v1/scans/{safe_id}")

    def health_check(self) -> Dict[str, Any]:
        """Check API health and version information.

        Returns:
            Dictionary with health status and version info
        """
        return self._make_request("GET", "/health")

    def _make_request(self, method: str, endpoint: str, data: Optional[Dict] = None, params: Optional[Dict] = None) -> Dict[str, Any]:
        """Make an HTTP request to the API.

        Args:
            method: HTTP method (GET, POST, etc.)
            endpoint: API endpoint path, appended to the base URL
            data: Request body data, sent as JSON
            params: Query parameters

        Returns:
            Parsed JSON response

        Raises:
            TimeoutError: If the request times out
            NetworkError: If the connection fails or another transport error occurs
            Various KoreShieldError subclasses based on the response status
        """
        url = f"{self.auth_config.base_url}{endpoint}"

        try:
            response = self.session.request(
                method=method,
                url=url,
                json=data,
                params=params,
                timeout=self.auth_config.timeout,
            )
        except requests.exceptions.Timeout as exc:
            raise TimeoutError("Request timed out") from exc
        except requests.exceptions.ConnectionError as exc:
            raise NetworkError("Network connection failed") from exc
        except requests.exceptions.RequestException as exc:
            raise NetworkError(f"Request failed: {exc}") from exc

        return self._handle_response(response)

    def _handle_response(self, response: requests.Response) -> Dict[str, Any]:
        """Handle API response and raise appropriate exceptions.

        Args:
            response: The HTTP response object

        Returns:
            Parsed JSON response data for any 2xx status ({} for an empty body)

        Raises:
            AuthenticationError: On HTTP 401
            ValidationError: On HTTP 400
            RateLimitError: On HTTP 429
            ServerError: On HTTP 5xx
            KoreShieldError: On any other non-2xx status
        """
        try:
            data = response.json() if response.content else {}
        except ValueError:
            data = {"message": "Invalid JSON response"}

        # Any 2xx (e.g. 201 Created, 204 No Content) is a success, not just 200.
        if 200 <= response.status_code < 300:
            return data

        # Error bodies are not guaranteed to be JSON objects (a bare string or
        # list is valid JSON); wrap them so the .get() lookups below are safe.
        if not isinstance(data, dict):
            data = {"message": data} if isinstance(data, str) else {"response": data}

        status = response.status_code
        if status == 401:
            raise AuthenticationError(
                data.get("message", "Authentication failed"),
                status_code=status,
                response_data=data,
            )
        elif status == 400:
            raise ValidationError(
                data.get("message", "Validation failed"),
                status_code=status,
                response_data=data,
            )
        elif status == 429:
            raise RateLimitError(
                data.get("message", "Rate limit exceeded"),
                status_code=status,
                response_data=data,
            )
        elif status >= 500:
            raise ServerError(
                data.get("message", "Server error"),
                status_code=status,
                response_data=data,
            )
        else:
            raise KoreShieldError(
                data.get("message", f"Unexpected error: {status}"),
                status_code=status,
                response_data=data,
            )
@@ -0,0 +1,41 @@
1
+ """Exceptions for KoreShield SDK."""
2
+
3
+
4
class KoreShieldError(Exception):
    """Root of the KoreShield SDK exception hierarchy.

    Attributes:
        message: Human-readable description of the failure (also ``str(err)``).
        status_code: HTTP status code associated with the failure, or None.
        response_data: Parsed response payload; an empty dict when none was given.
    """

    def __init__(self, message: str, status_code: int = None, response_data: dict = None):
        self.message = message
        self.status_code = status_code
        # Normalize a missing/falsy payload to a fresh empty dict so callers
        # can always use mapping operations on it.
        self.response_data = response_data if response_data else {}
        super().__init__(message)
12
+
13
+
14
class AuthenticationError(KoreShieldError):
    """Signals rejected credentials; the bundled clients raise it for HTTP 401."""
17
+
18
+
19
class ValidationError(KoreShieldError):
    """Signals a malformed request; the bundled clients raise it for HTTP 400."""
22
+
23
+
24
class RateLimitError(KoreShieldError):
    """Signals that the rate limit was hit; the bundled clients raise it for HTTP 429."""
27
+
28
+
29
class ServerError(KoreShieldError):
    """Signals a server-side failure; the bundled clients raise it for HTTP 5xx."""
32
+
33
+
34
class NetworkError(KoreShieldError):
    """Signals a transport-level failure (e.g. the connection could not be made)."""
37
+
38
+
39
class TimeoutError(KoreShieldError):
    """Signals that a request exceeded its timeout.

    Note: this name shadows the built-in ``TimeoutError`` in any module that
    imports it, and it is not a subclass of the built-in.
    """
@@ -0,0 +1,15 @@
1
+ """Integrations with popular frameworks and libraries."""
2
+
3
+ from .langchain import (
4
+ KoreShieldCallbackHandler,
5
+ AsyncKoreShieldCallbackHandler,
6
+ create_koreshield_callback,
7
+ create_async_koreshield_callback,
8
+ )
9
+
10
# Public API of the integrations package: the LangChain callback handlers
# and their factory helpers, re-exported from the ``langchain`` submodule.
__all__ = [
    "KoreShieldCallbackHandler",
    "AsyncKoreShieldCallbackHandler",
    "create_koreshield_callback",
    "create_async_koreshield_callback",
]