koreshield 0.1.5__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,8 +1,10 @@
1
- """Asynchronous KoreShield client."""
1
+ """Asynchronous KoreShield client with enhanced features."""
2
2
 
3
3
  import asyncio
4
4
  import time
5
- from typing import Dict, List, Optional, Any, Union
5
+ import json
6
+ from typing import Dict, List, Optional, Any, Union, AsyncGenerator, Callable
7
+ from contextlib import asynccontextmanager
6
8
  import httpx
7
9
 
8
10
  from .types import (
@@ -15,6 +17,10 @@ from .types import (
15
17
  RAGDocument,
16
18
  RAGScanRequest,
17
19
  RAGScanResponse,
20
+ StreamingScanRequest,
21
+ StreamingScanResponse,
22
+ SecurityPolicy,
23
+ PerformanceMetrics,
18
24
  )
19
25
  from .exceptions import (
20
26
  KoreShieldError,
@@ -28,7 +34,7 @@ from .exceptions import (
28
34
 
29
35
 
30
36
  class AsyncKoreShieldClient:
31
- """Asynchronous KoreShield API client."""
37
+ """Asynchronous KoreShield API client with enhanced features."""
32
38
 
33
39
  def __init__(
34
40
  self,
@@ -37,6 +43,9 @@ class AsyncKoreShieldClient:
37
43
  timeout: float = 30.0,
38
44
  retry_attempts: int = 3,
39
45
  retry_delay: float = 1.0,
46
+ enable_metrics: bool = True,
47
+ security_policy: Optional[SecurityPolicy] = None,
48
+ connection_pool_limits: Optional[Dict[str, int]] = None,
40
49
  ):
41
50
  """Initialize the async KoreShield client.
42
51
 
@@ -46,6 +55,9 @@ class AsyncKoreShieldClient:
46
55
  timeout: Request timeout in seconds
47
56
  retry_attempts: Number of retry attempts
48
57
  retry_delay: Delay between retries in seconds
58
+ enable_metrics: Whether to collect performance metrics
59
+ security_policy: Custom security policy configuration
60
+ connection_pool_limits: HTTP connection pool limits
49
61
  """
50
62
  self.auth_config = AuthConfig(
51
63
  api_key=api_key,
@@ -55,12 +67,29 @@ class AsyncKoreShieldClient:
55
67
  retry_delay=retry_delay,
56
68
  )
57
69
 
70
+ # Performance monitoring
71
+ self.enable_metrics = enable_metrics
72
+ self.metrics = PerformanceMetrics()
73
+ self._start_time = time.time()
74
+ self._request_count = 0
75
+
76
+ # Security policy
77
+ self.security_policy = security_policy or SecurityPolicy(name="default")
78
+
79
+ # Connection pool configuration
80
+ pool_limits = connection_pool_limits or {
81
+ "max_keepalive_connections": 20,
82
+ "max_connections": 100,
83
+ "keepalive_expiry": 30.0,
84
+ }
85
+
58
86
  self.client = httpx.AsyncClient(
59
- timeout=timeout,
87
+ timeout=httpx.Timeout(timeout, connect=10.0),
88
+ limits=httpx.Limits(**pool_limits),
60
89
  headers={
61
90
  "Authorization": f"Bearer {api_key}",
62
91
  "Content-Type": "application/json",
63
- "User-Agent": f"koreshield-python-sdk/0.1.0",
92
+ "User-Agent": f"koreshield-python-sdk/0.2.0",
64
93
  },
65
94
  )
66
95
 
@@ -77,7 +106,7 @@ class AsyncKoreShieldClient:
77
106
  await self.client.aclose()
78
107
 
79
108
  async def scan_prompt(self, prompt: str, **kwargs) -> DetectionResult:
80
- """Scan a single prompt for security threats asynchronously.
109
+ """Scan a single prompt for security threats asynchronously with enhanced features.
81
110
 
82
111
  Args:
83
112
  prompt: The prompt text to scan
@@ -94,53 +123,305 @@ class AsyncKoreShieldClient:
94
123
  NetworkError: If network error occurs
95
124
  TimeoutError: If request times out
96
125
  """
126
+ start_time = time.time()
127
+
128
+ # Apply security policy filtering
129
+ if not self._passes_security_policy(prompt):
130
+ # Create blocked result based on policy
131
+ processing_time = time.time() - start_time
132
+ self._update_metrics(processing_time)
133
+
134
+ return DetectionResult(
135
+ is_safe=False,
136
+ threat_level=self.security_policy.threat_threshold,
137
+ confidence=1.0,
138
+ indicators=[DetectionIndicator(
139
+ type=DetectionType.RULE,
140
+ severity=self.security_policy.threat_threshold,
141
+ confidence=1.0,
142
+ description="Blocked by security policy",
143
+ metadata={"policy_name": self.security_policy.name}
144
+ )],
145
+ processing_time_ms=processing_time * 1000,
146
+ scan_id=f"policy_block_{int(time.time())}",
147
+ metadata={"blocked_by_policy": True}
148
+ )
149
+
97
150
  request = ScanRequest(prompt=prompt, **kwargs)
98
151
 
99
152
  for attempt in range(self.auth_config.retry_attempts + 1):
100
153
  try:
101
154
  response = await self._make_request("POST", "/v1/scan", request.dict())
102
155
  scan_response = ScanResponse(**response)
156
+
157
+ processing_time = time.time() - start_time
158
+ self._update_metrics(processing_time)
159
+
103
160
  return scan_response.result
161
+
104
162
  except (RateLimitError, ServerError, NetworkError) as e:
105
163
  if attempt == self.auth_config.retry_attempts:
164
+ processing_time = time.time() - start_time
165
+ self._update_metrics(processing_time, is_error=True)
106
166
  raise e
107
167
  await asyncio.sleep(self.auth_config.retry_delay * (2 ** attempt))
108
168
 
169
+ def _passes_security_policy(self, prompt: str) -> bool:
170
+ """Check if prompt passes the current security policy.
171
+
172
+ Args:
173
+ prompt: The prompt to check
174
+
175
+ Returns:
176
+ True if prompt passes policy, False if blocked
177
+ """
178
+ # Check blocklist patterns first (blocking takes precedence)
179
+ for pattern in self.security_policy.blocklist_patterns:
180
+ if pattern.lower() in prompt.lower():
181
+ return False
182
+
183
+ # Check allowlist patterns
184
+ for pattern in self.security_policy.allowlist_patterns:
185
+ if pattern.lower() in prompt.lower():
186
+ return True
187
+
188
+ return True
189
+
async def scan_batch(
    self,
    prompts: List[str],
    parallel: bool = True,
    max_concurrent: int = 10,
    batch_size: int = 50,
    progress_callback: Optional[Callable[[int, int], None]] = None,
    **kwargs
) -> List[DetectionResult]:
    """Scan multiple prompts for security threats asynchronously with enhanced features.

    Args:
        prompts: List of prompt texts to scan
        parallel: Whether to process in parallel (default: True)
        max_concurrent: Maximum concurrent requests (default: 10); must be
            at least 1 when scanning in parallel
        batch_size: Size of each batch for processing (default: 50); must be
            at least 1 when scanning in parallel
        progress_callback: Optional callback for progress updates (current, total)
        **kwargs: Additional context for all requests

    Returns:
        List of DetectionResult objects, in the same order as ``prompts``

    Raises:
        ValueError: If ``max_concurrent`` or ``batch_size`` is less than 1
            on the parallel path.
    """
    start_time = time.time()
    total_prompts = len(prompts)
    all_results = []

    if not parallel or total_prompts == 1:
        # Sequential processing
        for position, prompt in enumerate(prompts, start=1):
            all_results.append(await self.scan_prompt(prompt, **kwargs))
            if progress_callback:
                progress_callback(position, total_prompts)
    else:
        # Semaphore(0) would make every acquire wait forever, and range()
        # rejects a zero step, so fail fast with a clear message instead.
        if max_concurrent < 1:
            raise ValueError(f"max_concurrent must be >= 1, got {max_concurrent}")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        semaphore = asyncio.Semaphore(max_concurrent)
        completed = 0

        async def scan_with_semaphore(prompt: str) -> DetectionResult:
            nonlocal completed
            async with semaphore:
                result = await self.scan_prompt(prompt, **kwargs)
                completed += 1
                if progress_callback:
                    progress_callback(completed, total_prompts)
                return result

        # Process in batches to avoid overwhelming the server; gather keeps
        # each batch's results in submission order.
        for offset in range(0, total_prompts, batch_size):
            batch = prompts[offset:offset + batch_size]
            all_results.extend(
                await asyncio.gather(*(scan_with_semaphore(p) for p in batch))
            )

    processing_time = time.time() - start_time
    self._update_batch_metrics(total_prompts, processing_time, len(all_results))
    return all_results
250
+
251
async def scan_stream(
    self,
    content: str,
    chunk_size: int = 1000,
    overlap: int = 100,
    **kwargs
) -> StreamingScanResponse:
    """Scan long content in streaming chunks for real-time security analysis.

    The content is split by ``_create_overlapping_chunks``. Each chunk is
    scanned with ``scan_prompt``, with at most five chunks in flight at once.
    The per-chunk results are folded into one overall ``DetectionResult``:

    - it is safe only if every chunk is safe;
    - it takes the most severe chunk threat level;
    - it averages the chunk confidences;
    - it carries every chunk indicator, tagged with the ``chunk_index`` it
      came from.

    Args:
        content: The long content to scan in chunks
        chunk_size: Size of each chunk in characters (default: 1000)
        overlap: Overlap between chunks in characters (default: 100)
        **kwargs: Additional context for the scan

    Returns:
        StreamingScanResponse with chunk-by-chunk results
    """
    start_time = time.time()

    # Create overlapping chunks
    chunks = self._create_overlapping_chunks(content, chunk_size, overlap)
    total_chunks = len(chunks)
    # Offset between the starts of consecutive chunks from the chunker.
    step = chunk_size - overlap

    # Limit concurrent chunk processing
    semaphore = asyncio.Semaphore(5)

    async def scan_chunk(chunk: str, chunk_index: int) -> DetectionResult:
        chunk_start = chunk_index * step
        chunk_kwargs = {
            **kwargs,
            "chunk_index": chunk_index,
            "total_chunks": total_chunks,
            "chunk_metadata": {
                "start_pos": chunk_start,
                # The chunk is content[start_pos:end_pos]. The last chunk
                # may be shorter than chunk_size, so use its real length.
                "end_pos": chunk_start + len(chunk),
                "overlap": overlap if chunk_index > 0 else 0,
            },
        }
        async with semaphore:
            result = await self.scan_prompt(chunk, **chunk_kwargs)
            # Honour enable_metrics the same way _update_metrics does.
            if self.enable_metrics:
                self.metrics.streaming_chunks_processed += 1
            return result

    chunk_results = await asyncio.gather(
        *(scan_chunk(chunk, index) for index, chunk in enumerate(chunks))
    )

    # Aggregate overall result
    severity_order = ["safe", "low", "medium", "high", "critical"]
    overall_threat_level = max(
        (result.threat_level for result in chunk_results),
        key=lambda level: severity_order.index(level.value),
    )
    overall_confidence = sum(result.confidence for result in chunk_results) / len(chunk_results)
    overall_safe = all(result.is_safe for result in chunk_results)

    # Copy every indicator, adding the index of the chunk it came from.
    all_indicators = []
    for chunk_index, result in enumerate(chunk_results):
        for indicator in result.indicators:
            fields = indicator.model_dump()
            # model_dump() already contains a "metadata" key. Passing a
            # separate metadata= keyword next to **fields would raise
            # TypeError (multiple values for one keyword), so overwrite the
            # entry in the dict instead.
            fields["metadata"] = {**(indicator.metadata or {}), "chunk_index": chunk_index}
            all_indicators.append(DetectionIndicator(**fields))

    processing_time = time.time() - start_time

    overall_result = DetectionResult(
        is_safe=overall_safe,
        threat_level=overall_threat_level,
        confidence=overall_confidence,
        indicators=all_indicators,
        # Seconds -> milliseconds, as the field name (and the response's
        # processing_time_ms below) require.
        processing_time_ms=processing_time * 1000,
        scan_id=f"stream_{int(time.time())}",
        metadata={
            "total_chunks": total_chunks,
            "chunk_size": chunk_size,
            "overlap": overlap,
            "content_length": len(content),
        },
    )

    self._update_metrics(processing_time)

    return StreamingScanResponse(
        chunk_results=chunk_results,
        overall_result=overall_result,
        total_chunks=total_chunks,
        processing_time_ms=processing_time * 1000,
        request_id=f"stream_{int(time.time())}",
        timestamp=time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        version="0.2.0",
    )
346
+
347
+ def _create_overlapping_chunks(self, content: str, chunk_size: int, overlap: int) -> List[str]:
348
+ """Create overlapping chunks from content for streaming analysis."""
349
+ if len(content) <= chunk_size:
350
+ return [content]
351
+
352
+ chunks = []
353
+ start = 0
354
+
355
+ while start < len(content):
356
+ end = min(start + chunk_size, len(content))
357
+ chunk = content[start:end]
358
+ chunks.append(chunk)
359
+
360
+ # Move start position with overlap, but ensure progress
361
+ start += chunk_size - overlap
362
+ if start >= end: # Prevent infinite loop
363
+ break
364
+
365
+ return chunks
366
+
367
async def get_performance_metrics(self) -> PerformanceMetrics:
    """Get current performance and usage metrics.

    Refreshes ``uptime_seconds`` (time since the client was created or the
    metrics were last reset). Once at least one request has been recorded,
    it also refreshes ``average_response_time_ms`` and
    ``requests_per_second``.

    Returns:
        The client's own PerformanceMetrics object (not a copy), with the
        derived statistics updated.
    """
    metrics = self.metrics
    metrics.uptime_seconds = time.time() - self._start_time

    if metrics.total_requests > 0:
        metrics.average_response_time_ms = (
            metrics.total_processing_time_ms / metrics.total_requests
        )
        # Uptime can be 0 on a coarse wall clock right after init/reset, or
        # negative if the clock steps back. Leave the rate untouched rather
        # than raise ZeroDivisionError or report a negative rate.
        if metrics.uptime_seconds > 0:
            metrics.requests_per_second = (
                metrics.total_requests / metrics.uptime_seconds
            )

    return metrics
384
+
385
async def reset_metrics(self) -> None:
    """Discard all collected metrics and restart the uptime clock.

    Installs a fresh ``PerformanceMetrics`` object, records the current time
    as the new start time, and zeroes the internal request counter.
    """
    self.metrics, self._start_time, self._request_count = PerformanceMetrics(), time.time(), 0
390
+
391
async def apply_security_policy(self, policy: SecurityPolicy) -> None:
    """Install *policy* as the client's active security policy.

    The previous policy is replaced outright. Later prompt scans read the
    policy from ``self.security_policy``.

    Args:
        policy: The SecurityPolicy to use from now on.
    """
    self.security_policy = policy
398
+
399
async def get_security_policy(self) -> SecurityPolicy:
    """Return the security policy the client is currently enforcing.

    Returns:
        The active SecurityPolicy object itself (not a copy).
    """
    active_policy = self.security_policy
    return active_policy
406
+
407
+ def _update_metrics(self, processing_time: float, is_error: bool = False) -> None:
408
+ """Update internal performance metrics."""
409
+ if not self.enable_metrics:
410
+ return
411
+
412
+ self.metrics.total_requests += 1
413
+ self.metrics.total_processing_time_ms += processing_time * 1000
414
+
415
+ if is_error:
416
+ self.metrics.error_count += 1
417
+
418
+ def _update_batch_metrics(self, total_prompts: int, processing_time: float, results_count: int) -> None:
419
+ """Update batch processing metrics."""
420
+ if not self.enable_metrics:
421
+ return
422
+
423
+ self.metrics.batch_efficiency = results_count / total_prompts if total_prompts > 0 else 0
424
+ self._update_metrics(processing_time)
144
425
 
145
426
  async def get_scan_history(self, limit: int = 50, offset: int = 0, **filters) -> Dict[str, Any]:
146
427
  """Get scan history with optional filters asynchronously.
@@ -1,15 +1,39 @@
1
1
  """Integrations with popular frameworks and libraries."""
2
2
 
3
- from .langchain import (
4
- KoreShieldCallbackHandler,
5
- AsyncKoreShieldCallbackHandler,
6
- create_koreshield_callback,
7
- create_async_koreshield_callback,
3
+ # Optional imports for langchain integration
4
+ try:
5
+ from .langchain import (
6
+ KoreShieldCallbackHandler,
7
+ AsyncKoreShieldCallbackHandler,
8
+ create_koreshield_callback,
9
+ create_async_koreshield_callback,
10
+ )
11
+ _langchain_available = True
12
+ except ImportError:
13
+ _langchain_available = False
14
+
15
+ from .frameworks import (
16
+ FastAPIIntegration,
17
+ FlaskIntegration,
18
+ DjangoIntegration,
19
+ create_fastapi_middleware,
20
+ create_flask_middleware,
21
+ create_django_middleware,
8
22
  )
9
23
 
# Public API of the integrations package. The framework middleware helpers
# are always exported. The LangChain callback helpers are exported only when
# the optional ``.langchain`` import succeeded (``_langchain_available``).
__all__ = [
    "FastAPIIntegration",
    "FlaskIntegration",
    "DjangoIntegration",
    "create_fastapi_middleware",
    "create_flask_middleware",
    "create_django_middleware",
] + (
    [
        "KoreShieldCallbackHandler",
        "AsyncKoreShieldCallbackHandler",
        "create_koreshield_callback",
        "create_async_koreshield_callback",
    ]
    if _langchain_available
    else []
)