koreshield 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {koreshield-0.1.4.dist-info → koreshield-0.2.0.dist-info}/METADATA +289 -50
- koreshield-0.2.0.dist-info/RECORD +14 -0
- {koreshield-0.1.4.dist-info → koreshield-0.2.0.dist-info}/WHEEL +1 -1
- koreshield_sdk/async_client.py +298 -17
- koreshield_sdk/integrations/__init__.py +34 -10
- koreshield_sdk/integrations/frameworks.py +361 -0
- koreshield_sdk/types.py +53 -1
- koreshield-0.1.4.dist-info/RECORD +0 -13
- {koreshield-0.1.4.dist-info → koreshield-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {koreshield-0.1.4.dist-info → koreshield-0.2.0.dist-info}/top_level.txt +0 -0
koreshield_sdk/async_client.py
CHANGED
|
@@ -1,8 +1,10 @@
|
|
|
1
|
-
"""Asynchronous KoreShield client."""
|
|
1
|
+
"""Asynchronous KoreShield client with enhanced features."""
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
4
|
import time
|
|
5
|
-
|
|
5
|
+
import json
|
|
6
|
+
from typing import Dict, List, Optional, Any, Union, AsyncGenerator, Callable
|
|
7
|
+
from contextlib import asynccontextmanager
|
|
6
8
|
import httpx
|
|
7
9
|
|
|
8
10
|
from .types import (
|
|
@@ -12,6 +14,10 @@ from .types import (
|
|
|
12
14
|
BatchScanRequest,
|
|
13
15
|
BatchScanResponse,
|
|
14
16
|
DetectionResult,
|
|
17
|
+
StreamingScanRequest,
|
|
18
|
+
StreamingScanResponse,
|
|
19
|
+
SecurityPolicy,
|
|
20
|
+
PerformanceMetrics,
|
|
15
21
|
)
|
|
16
22
|
from .exceptions import (
|
|
17
23
|
KoreShieldError,
|
|
@@ -25,7 +31,7 @@ from .exceptions import (
|
|
|
25
31
|
|
|
26
32
|
|
|
27
33
|
class AsyncKoreShieldClient:
|
|
28
|
-
"""Asynchronous KoreShield API client."""
|
|
34
|
+
"""Asynchronous KoreShield API client with enhanced features."""
|
|
29
35
|
|
|
30
36
|
def __init__(
|
|
31
37
|
self,
|
|
@@ -34,6 +40,9 @@ class AsyncKoreShieldClient:
|
|
|
34
40
|
timeout: float = 30.0,
|
|
35
41
|
retry_attempts: int = 3,
|
|
36
42
|
retry_delay: float = 1.0,
|
|
43
|
+
enable_metrics: bool = True,
|
|
44
|
+
security_policy: Optional[SecurityPolicy] = None,
|
|
45
|
+
connection_pool_limits: Optional[Dict[str, int]] = None,
|
|
37
46
|
):
|
|
38
47
|
"""Initialize the async KoreShield client.
|
|
39
48
|
|
|
@@ -43,6 +52,9 @@ class AsyncKoreShieldClient:
|
|
|
43
52
|
timeout: Request timeout in seconds
|
|
44
53
|
retry_attempts: Number of retry attempts
|
|
45
54
|
retry_delay: Delay between retries in seconds
|
|
55
|
+
enable_metrics: Whether to collect performance metrics
|
|
56
|
+
security_policy: Custom security policy configuration
|
|
57
|
+
connection_pool_limits: HTTP connection pool limits
|
|
46
58
|
"""
|
|
47
59
|
self.auth_config = AuthConfig(
|
|
48
60
|
api_key=api_key,
|
|
@@ -52,12 +64,29 @@ class AsyncKoreShieldClient:
|
|
|
52
64
|
retry_delay=retry_delay,
|
|
53
65
|
)
|
|
54
66
|
|
|
67
|
+
# Performance monitoring
|
|
68
|
+
self.enable_metrics = enable_metrics
|
|
69
|
+
self.metrics = PerformanceMetrics()
|
|
70
|
+
self._start_time = time.time()
|
|
71
|
+
self._request_count = 0
|
|
72
|
+
|
|
73
|
+
# Security policy
|
|
74
|
+
self.security_policy = security_policy or SecurityPolicy(name="default")
|
|
75
|
+
|
|
76
|
+
# Connection pool configuration
|
|
77
|
+
pool_limits = connection_pool_limits or {
|
|
78
|
+
"max_keepalive_connections": 20,
|
|
79
|
+
"max_connections": 100,
|
|
80
|
+
"keepalive_expiry": 30.0,
|
|
81
|
+
}
|
|
82
|
+
|
|
55
83
|
self.client = httpx.AsyncClient(
|
|
56
|
-
timeout=timeout,
|
|
84
|
+
timeout=httpx.Timeout(timeout, connect=10.0),
|
|
85
|
+
limits=httpx.Limits(**pool_limits),
|
|
57
86
|
headers={
|
|
58
87
|
"Authorization": f"Bearer {api_key}",
|
|
59
88
|
"Content-Type": "application/json",
|
|
60
|
-
"User-Agent": f"koreshield-python-sdk/0.
|
|
89
|
+
"User-Agent": f"koreshield-python-sdk/0.2.0",
|
|
61
90
|
},
|
|
62
91
|
)
|
|
63
92
|
|
|
@@ -74,7 +103,7 @@ class AsyncKoreShieldClient:
|
|
|
74
103
|
await self.client.aclose()
|
|
75
104
|
|
|
76
105
|
async def scan_prompt(self, prompt: str, **kwargs) -> DetectionResult:
|
|
77
|
-
"""Scan a single prompt for security threats asynchronously.
|
|
106
|
+
"""Scan a single prompt for security threats asynchronously with enhanced features.
|
|
78
107
|
|
|
79
108
|
Args:
|
|
80
109
|
prompt: The prompt text to scan
|
|
@@ -91,53 +120,305 @@ class AsyncKoreShieldClient:
|
|
|
91
120
|
NetworkError: If network error occurs
|
|
92
121
|
TimeoutError: If request times out
|
|
93
122
|
"""
|
|
123
|
+
start_time = time.time()
|
|
124
|
+
|
|
125
|
+
# Apply security policy filtering
|
|
126
|
+
if not self._passes_security_policy(prompt):
|
|
127
|
+
# Create blocked result based on policy
|
|
128
|
+
processing_time = time.time() - start_time
|
|
129
|
+
self._update_metrics(processing_time)
|
|
130
|
+
|
|
131
|
+
return DetectionResult(
|
|
132
|
+
is_safe=False,
|
|
133
|
+
threat_level=self.security_policy.threat_threshold,
|
|
134
|
+
confidence=1.0,
|
|
135
|
+
indicators=[DetectionIndicator(
|
|
136
|
+
type=DetectionType.RULE,
|
|
137
|
+
severity=self.security_policy.threat_threshold,
|
|
138
|
+
confidence=1.0,
|
|
139
|
+
description="Blocked by security policy",
|
|
140
|
+
metadata={"policy_name": self.security_policy.name}
|
|
141
|
+
)],
|
|
142
|
+
processing_time_ms=processing_time * 1000,
|
|
143
|
+
scan_id=f"policy_block_{int(time.time())}",
|
|
144
|
+
metadata={"blocked_by_policy": True}
|
|
145
|
+
)
|
|
146
|
+
|
|
94
147
|
request = ScanRequest(prompt=prompt, **kwargs)
|
|
95
148
|
|
|
96
149
|
for attempt in range(self.auth_config.retry_attempts + 1):
|
|
97
150
|
try:
|
|
98
151
|
response = await self._make_request("POST", "/v1/scan", request.dict())
|
|
99
152
|
scan_response = ScanResponse(**response)
|
|
153
|
+
|
|
154
|
+
processing_time = time.time() - start_time
|
|
155
|
+
self._update_metrics(processing_time)
|
|
156
|
+
|
|
100
157
|
return scan_response.result
|
|
158
|
+
|
|
101
159
|
except (RateLimitError, ServerError, NetworkError) as e:
|
|
102
160
|
if attempt == self.auth_config.retry_attempts:
|
|
161
|
+
processing_time = time.time() - start_time
|
|
162
|
+
self._update_metrics(processing_time, is_error=True)
|
|
103
163
|
raise e
|
|
104
164
|
await asyncio.sleep(self.auth_config.retry_delay * (2 ** attempt))
|
|
105
165
|
|
|
166
|
+
def _passes_security_policy(self, prompt: str) -> bool:
|
|
167
|
+
"""Check if prompt passes the current security policy.
|
|
168
|
+
|
|
169
|
+
Args:
|
|
170
|
+
prompt: The prompt to check
|
|
171
|
+
|
|
172
|
+
Returns:
|
|
173
|
+
True if prompt passes policy, False if blocked
|
|
174
|
+
"""
|
|
175
|
+
# Check blocklist patterns first (blocking takes precedence)
|
|
176
|
+
for pattern in self.security_policy.blocklist_patterns:
|
|
177
|
+
if pattern.lower() in prompt.lower():
|
|
178
|
+
return False
|
|
179
|
+
|
|
180
|
+
# Check allowlist patterns
|
|
181
|
+
for pattern in self.security_policy.allowlist_patterns:
|
|
182
|
+
if pattern.lower() in prompt.lower():
|
|
183
|
+
return True
|
|
184
|
+
|
|
185
|
+
return True
|
|
186
|
+
|
|
106
187
|
async def scan_batch(
    self,
    prompts: List[str],
    parallel: bool = True,
    max_concurrent: int = 10,
    batch_size: int = 50,
    progress_callback: Optional[Callable[[int, int], None]] = None,
    **kwargs
) -> List[DetectionResult]:
    """Scan multiple prompts for security threats asynchronously.

    Args:
        prompts: List of prompt texts to scan
        parallel: Whether to process in parallel (default: True)
        max_concurrent: Maximum concurrent requests (default: 10)
        batch_size: Number of prompts gathered together per batch (default: 50)
        progress_callback: Optional callback invoked as
            ``progress_callback(completed, total)`` after each prompt finishes
        **kwargs: Additional context forwarded to every ``scan_prompt`` call

    Returns:
        List of DetectionResult objects, in the same order as ``prompts``

    Raises:
        ValueError: If the parallel path is taken and ``max_concurrent`` or
            ``batch_size`` is less than 1.
    """
    start_time = time.time()
    total_prompts = len(prompts)
    all_results = []

    if not parallel or total_prompts == 1:
        # Sequential processing
        for i, prompt in enumerate(prompts):
            result = await self.scan_prompt(prompt, **kwargs)
            all_results.append(result)
            if progress_callback:
                progress_callback(i + 1, total_prompts)
        processing_time = time.time() - start_time
        self._update_batch_metrics(total_prompts, processing_time, len(all_results))
        return all_results

    # Reject sizes that would otherwise misbehave silently: a zero-permit
    # semaphore never lets a scan start, and a negative range step yields no
    # batches at all (every prompt would be dropped without being scanned).
    if max_concurrent < 1:
        raise ValueError("max_concurrent must be at least 1")
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    # Parallel processing: the semaphore caps in-flight requests, batching
    # caps how many tasks are created at once.
    semaphore = asyncio.Semaphore(max_concurrent)
    completed = 0

    async def scan_with_semaphore(prompt: str) -> DetectionResult:
        nonlocal completed
        async with semaphore:
            result = await self.scan_prompt(prompt, **kwargs)
            completed += 1
            if progress_callback:
                progress_callback(completed, total_prompts)
            return result

    for batch_start in range(0, total_prompts, batch_size):
        batch = prompts[batch_start:batch_start + batch_size]
        # gather() returns results in task order, so prompt order is preserved.
        batch_results = await asyncio.gather(
            *(scan_with_semaphore(prompt) for prompt in batch)
        )
        all_results.extend(batch_results)

    processing_time = time.time() - start_time
    self._update_batch_metrics(total_prompts, processing_time, len(all_results))
    return all_results
|
|
247
|
+
|
|
248
|
+
async def scan_stream(
    self,
    content: str,
    chunk_size: int = 1000,
    overlap: int = 100,
    **kwargs
) -> StreamingScanResponse:
    """Scan long content as overlapping chunks and aggregate the results.

    The content is split with ``self._create_overlapping_chunks``, each chunk
    is scanned through ``self.scan_prompt`` (at most five at a time), and the
    per-chunk results are combined into one overall ``DetectionResult``.

    Args:
        content: The long content to scan in chunks
        chunk_size: Size of each chunk in characters (default: 1000)
        overlap: Overlap between chunks in characters (default: 100)
        **kwargs: Additional context forwarded to every chunk scan

    Returns:
        StreamingScanResponse with the chunk-by-chunk results and the
        aggregated overall result
    """
    start_time = time.time()

    # Create overlapping chunks
    chunks = self._create_overlapping_chunks(content, chunk_size, overlap)
    total_chunks = len(chunks)
    stride = chunk_size - overlap

    # Limit concurrent chunk processing
    semaphore = asyncio.Semaphore(5)

    async def scan_chunk(chunk: str, chunk_index: int) -> DetectionResult:
        async with semaphore:
            start_pos = chunk_index * stride
            chunk_kwargs = {
                **kwargs,
                "chunk_index": chunk_index,
                "total_chunks": total_chunks,
                "chunk_metadata": {
                    "start_pos": start_pos,
                    # The chunk's real length gives the exact end offset,
                    # including for a short final chunk.
                    "end_pos": start_pos + len(chunk),
                    "overlap": overlap if chunk_index > 0 else 0,
                },
            }
            result = await self.scan_prompt(chunk, **chunk_kwargs)
            self.metrics.streaming_chunks_processed += 1
            return result

    # gather() keeps results in chunk order.
    chunk_results = await asyncio.gather(
        *(scan_chunk(chunk, i) for i, chunk in enumerate(chunks))
    )

    # Aggregate overall result: worst threat level, mean confidence, and
    # safe only if every chunk is safe.
    severity_order = ["safe", "low", "medium", "high", "critical"]
    overall_threat_level = max(
        (r.threat_level for r in chunk_results),
        key=lambda level: severity_order.index(level.value),
    )
    overall_confidence = sum(r.confidence for r in chunk_results) / len(chunk_results)
    overall_safe = all(r.is_safe for r in chunk_results)

    # Copy every indicator, tagging it with the chunk it came from.
    all_indicators = []
    for i, result in enumerate(chunk_results):
        for indicator in result.indicators:
            # Override "metadata" inside the dumped dict; passing it as a
            # separate keyword next to **model_dump() would supply the
            # argument twice and raise TypeError.
            indicator_fields = indicator.model_dump()
            indicator_fields["metadata"] = {
                **(indicator.metadata or {}),
                "chunk_index": i,
            }
            all_indicators.append(DetectionIndicator(**indicator_fields))

    processing_time = time.time() - start_time
    stream_id = f"stream_{int(time.time())}"

    overall_result = DetectionResult(
        is_safe=overall_safe,
        threat_level=overall_threat_level,
        confidence=overall_confidence,
        indicators=all_indicators,
        # Elapsed time is measured in seconds; the field is in milliseconds.
        processing_time_ms=processing_time * 1000,
        scan_id=stream_id,
        metadata={
            "total_chunks": total_chunks,
            "chunk_size": chunk_size,
            "overlap": overlap,
            "content_length": len(content),
        },
    )

    self._update_metrics(processing_time)

    return StreamingScanResponse(
        chunk_results=chunk_results,
        overall_result=overall_result,
        total_chunks=total_chunks,
        processing_time_ms=processing_time * 1000,
        request_id=stream_id,
        timestamp=time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        version="0.2.0",
    )
|
|
343
|
+
|
|
344
|
+
def _create_overlapping_chunks(self, content: str, chunk_size: int, overlap: int) -> List[str]:
|
|
345
|
+
"""Create overlapping chunks from content for streaming analysis."""
|
|
346
|
+
if len(content) <= chunk_size:
|
|
347
|
+
return [content]
|
|
348
|
+
|
|
349
|
+
chunks = []
|
|
350
|
+
start = 0
|
|
351
|
+
|
|
352
|
+
while start < len(content):
|
|
353
|
+
end = min(start + chunk_size, len(content))
|
|
354
|
+
chunk = content[start:end]
|
|
355
|
+
chunks.append(chunk)
|
|
356
|
+
|
|
357
|
+
# Move start position with overlap, but ensure progress
|
|
358
|
+
start += chunk_size - overlap
|
|
359
|
+
if start >= end: # Prevent infinite loop
|
|
360
|
+
break
|
|
361
|
+
|
|
362
|
+
return chunks
|
|
363
|
+
|
|
364
|
+
async def get_performance_metrics(self) -> PerformanceMetrics:
    """Refresh the derived statistics and return the metrics object.

    Updates ``uptime_seconds`` on every call. When at least one request has
    been recorded, ``average_response_time_ms`` and ``requests_per_second``
    are recomputed as well.

    Returns:
        The client's ``self.metrics`` object (the same instance, updated in
        place)
    """
    metrics = self.metrics
    uptime = time.time() - self._start_time
    metrics.uptime_seconds = uptime

    total_requests = metrics.total_requests
    if total_requests > 0:
        metrics.average_response_time_ms = (
            metrics.total_processing_time_ms / total_requests
        )
        # time.time() is a wall clock: with coarse resolution or a clock
        # adjustment the uptime can be zero or negative, so avoid dividing
        # by it in that case.
        metrics.requests_per_second = (
            total_requests / uptime if uptime > 0 else 0.0
        )

    return metrics
|
|
381
|
+
|
|
382
|
+
async def reset_metrics(self) -> None:
    """Discard the collected metrics and restart the uptime clock.

    Replaces ``self.metrics`` with a fresh ``PerformanceMetrics`` instance,
    sets ``self._start_time`` to the current time and zeroes
    ``self._request_count``.
    """
    fresh_metrics = PerformanceMetrics()
    restarted_at = time.time()

    self.metrics = fresh_metrics
    self._start_time = restarted_at
    self._request_count = 0
|
|
387
|
+
|
|
388
|
+
async def apply_security_policy(self, policy: SecurityPolicy) -> None:
    """Replace the client's active security policy.

    The given policy is stored on ``self.security_policy`` as-is; no
    validation or copying is performed here.

    Args:
        policy: SecurityPolicy configuration to use from now on
    """
    self.security_policy = policy
|
|
395
|
+
|
|
396
|
+
async def get_security_policy(self) -> SecurityPolicy:
    """Return the security policy currently stored on the client.

    Returns:
        The object held in ``self.security_policy`` (not a copy)
    """
    active_policy = self.security_policy
    return active_policy
|
|
403
|
+
|
|
404
|
+
def _update_metrics(self, processing_time: float, is_error: bool = False) -> None:
|
|
405
|
+
"""Update internal performance metrics."""
|
|
406
|
+
if not self.enable_metrics:
|
|
407
|
+
return
|
|
408
|
+
|
|
409
|
+
self.metrics.total_requests += 1
|
|
410
|
+
self.metrics.total_processing_time_ms += processing_time * 1000
|
|
411
|
+
|
|
412
|
+
if is_error:
|
|
413
|
+
self.metrics.error_count += 1
|
|
414
|
+
|
|
415
|
+
def _update_batch_metrics(self, total_prompts: int, processing_time: float, results_count: int) -> None:
|
|
416
|
+
"""Update batch processing metrics."""
|
|
417
|
+
if not self.enable_metrics:
|
|
418
|
+
return
|
|
138
419
|
|
|
139
|
-
|
|
140
|
-
|
|
420
|
+
self.metrics.batch_efficiency = results_count / total_prompts if total_prompts > 0 else 0
|
|
421
|
+
self._update_metrics(processing_time)
|
|
141
422
|
|
|
142
423
|
async def get_scan_history(self, limit: int = 50, offset: int = 0, **filters) -> Dict[str, Any]:
|
|
143
424
|
"""Get scan history with optional filters asynchronously.
|
|
@@ -1,15 +1,39 @@
|
|
|
1
1
|
"""Integrations with popular frameworks and libraries."""
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
3
|
+
# The langchain callbacks are optional: if importing them raises ImportError
# the package still loads and simply does not export those names.
try:
    from .langchain import (
        KoreShieldCallbackHandler,
        AsyncKoreShieldCallbackHandler,
        create_koreshield_callback,
        create_async_koreshield_callback,
    )
except ImportError:
    _langchain_available = False
else:
    _langchain_available = True

from .frameworks import (
    FastAPIIntegration,
    FlaskIntegration,
    DjangoIntegration,
    create_fastapi_middleware,
    create_flask_middleware,
    create_django_middleware,
)

# Web-framework integrations are always exported.
__all__ = [
    "FastAPIIntegration",
    "FlaskIntegration",
    "DjangoIntegration",
    "create_fastapi_middleware",
    "create_flask_middleware",
    "create_django_middleware",
]

# The langchain names are advertised only when their import succeeded.
if _langchain_available:
    __all__ += [
        "KoreShieldCallbackHandler",
        "AsyncKoreShieldCallbackHandler",
        "create_koreshield_callback",
        "create_async_koreshield_callback",
    ]
|