koreshield 0.1.4__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,8 +1,10 @@
-"""Asynchronous KoreShield client."""
+"""Asynchronous KoreShield client with enhanced features."""
 
 import asyncio
 import time
-from typing import Dict, List, Optional, Any, Union
+import json
+from typing import Dict, List, Optional, Any, Union, AsyncGenerator, Callable
+from contextlib import asynccontextmanager
 import httpx
 
 from .types import (
@@ -12,6 +14,10 @@ from .types import (
     BatchScanRequest,
     BatchScanResponse,
     DetectionResult,
+    StreamingScanRequest,
+    StreamingScanResponse,
+    SecurityPolicy,
+    PerformanceMetrics,
 )
 from .exceptions import (
     KoreShieldError,
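
The `SecurityPolicy` and `PerformanceMetrics` types added to this import block are defined elsewhere in the package and are not part of this diff. Based solely on the attributes the new client code below reads and writes, they might look roughly like the following sketch; the field names come from the diff, while the use of dataclasses, the types, and the defaults are assumptions.

```python
# Hypothetical sketch of the new 0.2.0 types, inferred from attribute usage in the diff below.
from dataclasses import dataclass, field
from typing import List

@dataclass
class SecurityPolicy:
    name: str                                    # SecurityPolicy(name="default") in __init__
    threat_threshold: str = "high"               # assumed; assigned to threat_level/severity when a prompt is blocked
    blocklist_patterns: List[str] = field(default_factory=list)  # checked first in _passes_security_policy
    allowlist_patterns: List[str] = field(default_factory=list)

@dataclass
class PerformanceMetrics:
    total_requests: int = 0                      # incremented in _update_metrics
    total_processing_time_ms: float = 0.0
    error_count: int = 0
    streaming_chunks_processed: int = 0          # incremented per chunk in scan_stream
    batch_efficiency: float = 0.0                # set in _update_batch_metrics
    uptime_seconds: float = 0.0                  # derived in get_performance_metrics
    average_response_time_ms: float = 0.0
    requests_per_second: float = 0.0
```
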
@@ -25,7 +31,7 @@ from .exceptions import (
 
 
 class AsyncKoreShieldClient:
-    """Asynchronous KoreShield API client."""
+    """Asynchronous KoreShield API client with enhanced features."""
 
     def __init__(
         self,
@@ -34,6 +40,9 @@ class AsyncKoreShieldClient:
         timeout: float = 30.0,
         retry_attempts: int = 3,
         retry_delay: float = 1.0,
+        enable_metrics: bool = True,
+        security_policy: Optional[SecurityPolicy] = None,
+        connection_pool_limits: Optional[Dict[str, int]] = None,
     ):
         """Initialize the async KoreShield client.
 
@@ -43,6 +52,9 @@ class AsyncKoreShieldClient:
             timeout: Request timeout in seconds
             retry_attempts: Number of retry attempts
             retry_delay: Delay between retries in seconds
+            enable_metrics: Whether to collect performance metrics
+            security_policy: Custom security policy configuration
+            connection_pool_limits: HTTP connection pool limits
         """
         self.auth_config = AuthConfig(
             api_key=api_key,
@@ -52,12 +64,29 @@ class AsyncKoreShieldClient:
             retry_delay=retry_delay,
         )
 
+        # Performance monitoring
+        self.enable_metrics = enable_metrics
+        self.metrics = PerformanceMetrics()
+        self._start_time = time.time()
+        self._request_count = 0
+
+        # Security policy
+        self.security_policy = security_policy or SecurityPolicy(name="default")
+
+        # Connection pool configuration
+        pool_limits = connection_pool_limits or {
+            "max_keepalive_connections": 20,
+            "max_connections": 100,
+            "keepalive_expiry": 30.0,
+        }
+
         self.client = httpx.AsyncClient(
-            timeout=timeout,
+            timeout=httpx.Timeout(timeout, connect=10.0),
+            limits=httpx.Limits(**pool_limits),
             headers={
                 "Authorization": f"Bearer {api_key}",
                 "Content-Type": "application/json",
-                "User-Agent": f"koreshield-python-sdk/0.1.0",
+                "User-Agent": f"koreshield-python-sdk/0.2.0",
             },
         )
 
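
For illustration only, here is a minimal sketch of how the new constructor options above might be used. The top-level import path, the placeholder API key, and the clean-up call are assumptions, since neither the package's public exports nor its close/context-manager API appear in this diff; only parameters visible in the hunks above are passed.

```python
import asyncio

from koreshield import AsyncKoreShieldClient, SecurityPolicy  # import path assumed

async def main() -> None:
    client = AsyncKoreShieldClient(
        api_key="ks_example_key",              # placeholder credential
        enable_metrics=True,                   # new in 0.2.0
        security_policy=SecurityPolicy(        # new in 0.2.0; field names taken from the diff
            name="strict",
            blocklist_patterns=["ignore previous instructions"],
        ),
        connection_pool_limits={               # forwarded to httpx.Limits(**pool_limits)
            "max_keepalive_connections": 10,
            "max_connections": 50,
            "keepalive_expiry": 15.0,
        },
    )
    try:
        result = await client.scan_prompt("Summarize this document.")
        print(result.is_safe)
    finally:
        # The diff only shows the underlying httpx client being closed
        # (self.client.aclose()); the public clean-up API is not visible here.
        await client.client.aclose()

asyncio.run(main())
```
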
@@ -74,7 +103,7 @@ class AsyncKoreShieldClient:
         await self.client.aclose()
 
     async def scan_prompt(self, prompt: str, **kwargs) -> DetectionResult:
-        """Scan a single prompt for security threats asynchronously.
+        """Scan a single prompt for security threats asynchronously with enhanced features.
 
         Args:
             prompt: The prompt text to scan
@@ -91,53 +120,305 @@ class AsyncKoreShieldClient:
             NetworkError: If network error occurs
             TimeoutError: If request times out
         """
+        start_time = time.time()
+
+        # Apply security policy filtering
+        if not self._passes_security_policy(prompt):
+            # Create blocked result based on policy
+            processing_time = time.time() - start_time
+            self._update_metrics(processing_time)
+
+            return DetectionResult(
+                is_safe=False,
+                threat_level=self.security_policy.threat_threshold,
+                confidence=1.0,
+                indicators=[DetectionIndicator(
+                    type=DetectionType.RULE,
+                    severity=self.security_policy.threat_threshold,
+                    confidence=1.0,
+                    description="Blocked by security policy",
+                    metadata={"policy_name": self.security_policy.name}
+                )],
+                processing_time_ms=processing_time * 1000,
+                scan_id=f"policy_block_{int(time.time())}",
+                metadata={"blocked_by_policy": True}
+            )
+
         request = ScanRequest(prompt=prompt, **kwargs)
 
         for attempt in range(self.auth_config.retry_attempts + 1):
             try:
                 response = await self._make_request("POST", "/v1/scan", request.dict())
                 scan_response = ScanResponse(**response)
+
+                processing_time = time.time() - start_time
+                self._update_metrics(processing_time)
+
                 return scan_response.result
+
             except (RateLimitError, ServerError, NetworkError) as e:
                 if attempt == self.auth_config.retry_attempts:
+                    processing_time = time.time() - start_time
+                    self._update_metrics(processing_time, is_error=True)
                     raise e
                 await asyncio.sleep(self.auth_config.retry_delay * (2 ** attempt))
 
+    def _passes_security_policy(self, prompt: str) -> bool:
+        """Check if prompt passes the current security policy.
+
+        Args:
+            prompt: The prompt to check
+
+        Returns:
+            True if prompt passes policy, False if blocked
+        """
+        # Check blocklist patterns first (blocking takes precedence)
+        for pattern in self.security_policy.blocklist_patterns:
+            if pattern.lower() in prompt.lower():
+                return False
+
+        # Check allowlist patterns
+        for pattern in self.security_policy.allowlist_patterns:
+            if pattern.lower() in prompt.lower():
+                return True
+
+        return True
+
     async def scan_batch(
         self,
         prompts: List[str],
         parallel: bool = True,
         max_concurrent: int = 10,
+        batch_size: int = 50,
+        progress_callback: Optional[Callable[[int, int], None]] = None,
         **kwargs
     ) -> List[DetectionResult]:
-        """Scan multiple prompts for security threats asynchronously.
+        """Scan multiple prompts for security threats asynchronously with enhanced features.
 
         Args:
             prompts: List of prompt texts to scan
             parallel: Whether to process in parallel (default: True)
             max_concurrent: Maximum concurrent requests (default: 10)
+            batch_size: Size of each batch for processing (default: 50)
+            progress_callback: Optional callback for progress updates (current, total)
             **kwargs: Additional context for all requests
 
         Returns:
             List of DetectionResult objects
         """
-        if not parallel:
+        start_time = time.time()
+        total_prompts = len(prompts)
+        all_results = []
+
+        if not parallel or total_prompts == 1:
             # Sequential processing
-            results = []
-            for prompt in prompts:
+            for i, prompt in enumerate(prompts):
                 result = await self.scan_prompt(prompt, **kwargs)
-                results.append(result)
-            return results
-
-        # Parallel processing with semaphore for concurrency control
+                all_results.append(result)
+                if progress_callback:
+                    progress_callback(i + 1, total_prompts)
+            processing_time = time.time() - start_time
+            self._update_batch_metrics(total_prompts, processing_time, len(all_results))
+            return all_results
+
+        # Parallel processing with batching for better performance
         semaphore = asyncio.Semaphore(max_concurrent)
+        completed = 0
 
         async def scan_with_semaphore(prompt: str) -> DetectionResult:
+            nonlocal completed
+            async with semaphore:
+                result = await self.scan_prompt(prompt, **kwargs)
+                completed += 1
+                if progress_callback:
+                    progress_callback(completed, total_prompts)
+                return result
+
+        # Process in batches to avoid overwhelming the server
+        for i in range(0, total_prompts, batch_size):
+            batch = prompts[i:i + batch_size]
+            tasks = [scan_with_semaphore(prompt) for prompt in batch]
+            batch_results = await asyncio.gather(*tasks)
+            all_results.extend(batch_results)
+
+        processing_time = time.time() - start_time
+        self._update_batch_metrics(total_prompts, processing_time, len(all_results))
+        return all_results
+
+    async def scan_stream(
+        self,
+        content: str,
+        chunk_size: int = 1000,
+        overlap: int = 100,
+        **kwargs
+    ) -> StreamingScanResponse:
+        """Scan long content in streaming chunks for real-time security analysis.
+
+        Args:
+            content: The long content to scan in chunks
+            chunk_size: Size of each chunk in characters (default: 1000)
+            overlap: Overlap between chunks in characters (default: 100)
+            **kwargs: Additional context for the scan
+
+        Returns:
+            StreamingScanResponse with chunk-by-chunk results
+        """
+        start_time = time.time()
+
+        # Create overlapping chunks
+        chunks = self._create_overlapping_chunks(content, chunk_size, overlap)
+        chunk_results = []
+
+        # Process chunks concurrently for better performance
+        semaphore = asyncio.Semaphore(5)  # Limit concurrent chunk processing
+
+        async def scan_chunk(chunk: str, chunk_index: int) -> DetectionResult:
             async with semaphore:
-                return await self.scan_prompt(prompt, **kwargs)
+                # Add chunk context
+                chunk_kwargs = {
+                    **kwargs,
+                    "chunk_index": chunk_index,
+                    "total_chunks": len(chunks),
+                    "chunk_metadata": {
+                        "start_pos": chunk_index * (chunk_size - overlap),
+                        "end_pos": min((chunk_index + 1) * (chunk_size - overlap) + chunk_size, len(content)),
+                        "overlap": overlap if chunk_index > 0 else 0
+                    }
+                }
+                result = await self.scan_prompt(chunk, **chunk_kwargs)
+                self.metrics.streaming_chunks_processed += 1
+                return result
+
+        # Process all chunks
+        tasks = [scan_chunk(chunk, i) for i, chunk in enumerate(chunks)]
+        chunk_results = await asyncio.gather(*tasks)
+
+        # Aggregate overall result
+        overall_threat_level = max((r.threat_level for r in chunk_results),
+                                   key=lambda x: ["safe", "low", "medium", "high", "critical"].index(x.value))
+        overall_confidence = sum(r.confidence for r in chunk_results) / len(chunk_results)
+        overall_safe = all(r.is_safe for r in chunk_results)
+
+        # Create aggregate indicators
+        all_indicators = []
+        for i, result in enumerate(chunk_results):
+            for indicator in result.indicators:
+                # Add chunk information to indicators
+                enhanced_indicator = DetectionIndicator(
+                    **indicator.model_dump(),
+                    metadata={
+                        **(indicator.metadata or {}),
+                        "chunk_index": i
+                    }
+                )
+                all_indicators.append(enhanced_indicator)
+
+        overall_result = DetectionResult(
+            is_safe=overall_safe,
+            threat_level=overall_threat_level,
+            confidence=overall_confidence,
+            indicators=all_indicators,
+            processing_time_ms=time.time() - start_time,
+            scan_id=f"stream_{int(time.time())}",
+            metadata={
+                "total_chunks": len(chunks),
+                "chunk_size": chunk_size,
+                "overlap": overlap,
+                "content_length": len(content)
+            }
+        )
+
+        processing_time = time.time() - start_time
+        self._update_metrics(processing_time)
+
+        return StreamingScanResponse(
+            chunk_results=chunk_results,
+            overall_result=overall_result,
+            total_chunks=len(chunks),
+            processing_time_ms=processing_time * 1000,
+            request_id=f"stream_{int(time.time())}",
+            timestamp=time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
+            version="0.2.0"
+        )
+
+    def _create_overlapping_chunks(self, content: str, chunk_size: int, overlap: int) -> List[str]:
+        """Create overlapping chunks from content for streaming analysis."""
+        if len(content) <= chunk_size:
+            return [content]
+
+        chunks = []
+        start = 0
+
+        while start < len(content):
+            end = min(start + chunk_size, len(content))
+            chunk = content[start:end]
+            chunks.append(chunk)
+
+            # Move start position with overlap, but ensure progress
+            start += chunk_size - overlap
+            if start >= end:  # Prevent infinite loop
+                break
+
+        return chunks
+
+    async def get_performance_metrics(self) -> PerformanceMetrics:
+        """Get current performance and usage metrics.
+
+        Returns:
+            PerformanceMetrics object with current statistics
+        """
+        self.metrics.uptime_seconds = time.time() - self._start_time
+
+        if self.metrics.total_requests > 0:
+            self.metrics.average_response_time_ms = (
+                self.metrics.total_processing_time_ms / self.metrics.total_requests
+            )
+            self.metrics.requests_per_second = (
+                self.metrics.total_requests / self.metrics.uptime_seconds
+            )
+
+        return self.metrics
+
+    async def reset_metrics(self) -> None:
+        """Reset performance metrics."""
+        self.metrics = PerformanceMetrics()
+        self._start_time = time.time()
+        self._request_count = 0
+
+    async def apply_security_policy(self, policy: SecurityPolicy) -> None:
+        """Apply a custom security policy to the client.
+
+        Args:
+            policy: SecurityPolicy configuration to apply
+        """
+        self.security_policy = policy
+
+    async def get_security_policy(self) -> SecurityPolicy:
+        """Get the current security policy.
+
+        Returns:
+            Current SecurityPolicy configuration
+        """
+        return self.security_policy
+
+    def _update_metrics(self, processing_time: float, is_error: bool = False) -> None:
+        """Update internal performance metrics."""
+        if not self.enable_metrics:
+            return
+
+        self.metrics.total_requests += 1
+        self.metrics.total_processing_time_ms += processing_time * 1000
+
+        if is_error:
+            self.metrics.error_count += 1
+
+    def _update_batch_metrics(self, total_prompts: int, processing_time: float, results_count: int) -> None:
+        """Update batch processing metrics."""
+        if not self.enable_metrics:
+            return
 
-        tasks = [scan_with_semaphore(prompt) for prompt in prompts]
-        return await asyncio.gather(*tasks)
+        self.metrics.batch_efficiency = results_count / total_prompts if total_prompts > 0 else 0
+        self._update_metrics(processing_time)
 
     async def get_scan_history(self, limit: int = 50, offset: int = 0, **filters) -> Dict[str, Any]:
         """Get scan history with optional filters asynchronously.
@@ -1,15 +1,39 @@
 """Integrations with popular frameworks and libraries."""
 
-from .langchain import (
-    KoreShieldCallbackHandler,
-    AsyncKoreShieldCallbackHandler,
-    create_koreshield_callback,
-    create_async_koreshield_callback,
+# Optional imports for langchain integration
+try:
+    from .langchain import (
+        KoreShieldCallbackHandler,
+        AsyncKoreShieldCallbackHandler,
+        create_koreshield_callback,
+        create_async_koreshield_callback,
+    )
+    _langchain_available = True
+except ImportError:
+    _langchain_available = False
+
+from .frameworks import (
+    FastAPIIntegration,
+    FlaskIntegration,
+    DjangoIntegration,
+    create_fastapi_middleware,
+    create_flask_middleware,
+    create_django_middleware,
 )
 
 __all__ = [
-    "KoreShieldCallbackHandler",
-    "AsyncKoreShieldCallbackHandler",
-    "create_koreshield_callback",
-    "create_async_koreshield_callback",
-]
+    "FastAPIIntegration",
+    "FlaskIntegration",
+    "DjangoIntegration",
+    "create_fastapi_middleware",
+    "create_flask_middleware",
+    "create_django_middleware",
+]
+
+if _langchain_available:
+    __all__.extend([
+        "KoreShieldCallbackHandler",
+        "AsyncKoreShieldCallbackHandler",
+        "create_koreshield_callback",
+        "create_async_koreshield_callback",
+    ])
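
Because the langchain callback handlers in this integrations module are now optional imports while the framework helpers are always exported, downstream code may want to feature-detect them. A sketch under the assumption that the module is importable as `koreshield.integrations`; the factory signatures are not shown in this diff.

```python
from koreshield import integrations  # module path assumed

# Framework helpers are unconditionally exported in 0.2.0.
fastapi_middleware_factory = integrations.create_fastapi_middleware  # signature not shown in this diff

# The langchain handlers are exported only when the optional dependency imports cleanly.
if "KoreShieldCallbackHandler" in integrations.__all__:
    callback_handler_cls = integrations.KoreShieldCallbackHandler
else:
    callback_handler_cls = None  # langchain extra not installed
```
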