koreshield 0.1.5__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,361 @@
1
+ """Framework-specific integration helpers for KoreShield SDK."""
2
+
3
+ from typing import Dict, List, Optional, Any, Callable, Union
4
+ from functools import wraps
5
+ import asyncio
6
+ import time
7
+
8
+ from ..async_client import AsyncKoreShieldClient
9
+ from ..types import DetectionResult, SecurityPolicy, ThreatLevel
10
+ from ..exceptions import KoreShieldError
11
+
12
+
13
class FastAPIIntegration:
    """FastAPI integration helper for KoreShield security middleware.

    :meth:`create_middleware` returns an ``async (request, call_next)``
    function. It scans the bodies of POST/PUT/PATCH requests with the supplied
    client. Results at or above ``threat_threshold`` are handled in one of two
    ways: the request is rejected with HTTP 403 (``block_on_threat=True``), or
    the result is recorded on ``request.state.koreshield_threat``.
    """

    # Severity order, least to most severe; used by _is_above_threshold.
    _THREAT_ORDER = (
        ThreatLevel.SAFE,
        ThreatLevel.LOW,
        ThreatLevel.MEDIUM,
        ThreatLevel.HIGH,
        ThreatLevel.CRITICAL,
    )

    # JSON keys whose string values are treated as scannable user text.
    _TEXT_FIELDS = ("prompt", "message", "content", "text", "query", "input")

    def __init__(
        self,
        client: AsyncKoreShieldClient,
        scan_request_body: bool = True,
        scan_response_body: bool = False,
        threat_threshold: ThreatLevel = ThreatLevel.MEDIUM,
        block_on_threat: bool = False,
        exclude_paths: Optional[List[str]] = None,
        custom_scanner: Optional[Callable] = None,
        max_raw_scan_bytes: int = 10000,
    ):
        """Initialize FastAPI integration.

        Args:
            client: AsyncKoreShieldClient instance used to scan text.
            scan_request_body: Whether to scan POST/PUT/PATCH request bodies.
            scan_response_body: Whether to scan response bodies. NOTE: stored
                for API compatibility; response scanning is not implemented.
            threat_threshold: Minimum threat level that counts as a threat.
            block_on_threat: If True, reject threatening requests with HTTP
                403; otherwise store the result on
                ``request.state.koreshield_threat``.
            exclude_paths: Exact request paths to skip (no prefix matching).
                Defaults to ``["/health", "/docs", "/openapi.json"]`` when None.
            custom_scanner: Custom scanning function. NOTE: stored but not
                currently used by the middleware.
            max_raw_scan_bytes: Non-JSON bodies of this many bytes or more are
                not scanned. JSON bodies are not size-limited.
        """
        self.client = client
        self.scan_request_body = scan_request_body
        self.scan_response_body = scan_response_body
        self.threat_threshold = threat_threshold
        self.block_on_threat = block_on_threat
        # `is None` (not `or`) so an explicit empty list means "exclude nothing".
        self.exclude_paths = (
            ["/health", "/docs", "/openapi.json"] if exclude_paths is None else exclude_paths
        )
        self.custom_scanner = custom_scanner
        self.max_raw_scan_bytes = max_raw_scan_bytes

    def create_middleware(self):
        """Create a FastAPI/Starlette ``http`` middleware function.

        Returns:
            ``async def koreshield_middleware(request, call_next)``. It adds an
            ``X-KoreShield-Scanned`` header to every non-excluded response that
            is not blocked. When a scan ran, it also adds
            ``X-KoreShield-Threat-Levels``.
        """
        import logging

        from fastapi import Request
        from fastapi.responses import JSONResponse

        logger = logging.getLogger(__name__)

        async def koreshield_middleware(request: Request, call_next):
            # Excluded paths bypass scanning and get no KoreShield headers.
            if request.url.path in self.exclude_paths:
                return await call_next(request)

            scan_results = []

            if self.scan_request_body and request.method in ("POST", "PUT", "PATCH"):
                try:
                    result = await self._scan_body(await request.body())
                except Exception as exc:
                    # Best-effort: scanner failures must not break the request.
                    logger.warning("KoreShield request scan error: %s", exc, exc_info=True)
                else:
                    if result is not None:
                        scan_results.append(("request", result))

            for scan_type, result in scan_results:
                if result.is_safe or not self._is_above_threshold(result):
                    continue
                if self.block_on_threat:
                    return JSONResponse(
                        status_code=403,
                        content={
                            "error": "Security threat detected",
                            "threat_level": result.threat_level.value,
                            "confidence": result.confidence,
                            "scan_type": scan_type,
                        },
                    )
                # Non-blocking mode: expose the finding to downstream handlers.
                request.state.koreshield_threat = result

            response = await call_next(request)

            # NOTE: scan_response_body is not implemented; responses pass through unscanned.

            response.headers["X-KoreShield-Scanned"] = "true"
            if scan_results:
                response.headers["X-KoreShield-Threat-Levels"] = ",".join(
                    r.threat_level.value for _, r in scan_results
                )
            return response

        return koreshield_middleware

    async def _scan_body(self, body: bytes) -> Optional[DetectionResult]:
        """Scan a raw request body and return the result, or None if nothing was scanned.

        JSON bodies are scanned on their extracted user text. Non-JSON or
        non-UTF-8 bodies are scanned as raw text, but only when shorter than
        ``max_raw_scan_bytes``.
        """
        import json

        if not body:
            return None
        try:
            payload = json.loads(body.decode())
        except (json.JSONDecodeError, UnicodeDecodeError):
            if len(body) < self.max_raw_scan_bytes:
                return await self.client.scan_prompt(body.decode(errors="ignore"))
            return None
        text_content = self._extract_text_from_request(payload)
        if not text_content:
            return None
        return await self.client.scan_prompt(text_content)

    def _extract_text_from_request(self, data: Any) -> str:
        """Collect user-facing text from a decoded JSON payload.

        Collection rules:
        - A top-level string is returned unchanged.
        - In a dict, string values under ``_TEXT_FIELDS`` keys are collected.
        - In a list, every string item is collected.
        - Nested dicts and lists are also walked, so chat-style payloads such
          as ``{"messages": [{"content": "..."}]}`` are covered.

        The walk is iterative, so deep nesting cannot hit the recursion limit.

        Args:
            data: Decoded JSON value (str, dict, list, or scalar).

        Returns:
            Collected strings joined by single spaces, in traversal order (a
            dict's text fields before its nested containers). Returns ``""``
            if nothing is found.
        """
        if isinstance(data, str):
            return data
        parts: List[str] = []
        pending: List[Any] = [data]
        while pending:
            node = pending.pop()
            if isinstance(node, str):
                parts.append(node)
                continue
            if isinstance(node, dict):
                children = [node[key] for key in self._TEXT_FIELDS if isinstance(node.get(key), str)]
                children.extend(value for value in node.values() if isinstance(value, (dict, list)))
            elif isinstance(node, list):
                children = [item for item in node if isinstance(item, (str, dict, list))]
            else:
                continue
            # Push in reverse so children are popped in their original order.
            pending.extend(reversed(children))
        return " ".join(parts)

    def _is_above_threshold(self, result: DetectionResult) -> bool:
        """Return True if ``result.threat_level`` is at or above ``threat_threshold``.

        Raises:
            ValueError: If either level is not one of the known ThreatLevel members.
        """
        order = self._THREAT_ORDER
        return order.index(result.threat_level) >= order.index(self.threat_threshold)
140
+
141
+
142
class FlaskIntegration:
    """Flask integration helper for KoreShield security middleware.

    :meth:`create_middleware` returns a zero-argument hook. It scans JSON
    bodies of POST/PUT/PATCH requests and stores the result on
    ``flask.g.koreshield_result``. With ``block_on_threat``, results at or
    above ``threat_threshold`` are rejected with HTTP 403.
    """

    # Severity order, least to most severe; used by _is_above_threshold.
    _THREAT_ORDER = (
        ThreatLevel.SAFE,
        ThreatLevel.LOW,
        ThreatLevel.MEDIUM,
        ThreatLevel.HIGH,
        ThreatLevel.CRITICAL,
    )

    # JSON keys whose string values are treated as scannable user text.
    _TEXT_FIELDS = ("prompt", "message", "content", "text", "query", "input")

    def __init__(
        self,
        client: AsyncKoreShieldClient,
        scan_request_body: bool = True,
        threat_threshold: ThreatLevel = ThreatLevel.MEDIUM,
        block_on_threat: bool = False,
        exclude_paths: Optional[List[str]] = None,
    ):
        """Initialize Flask integration.

        Args:
            client: AsyncKoreShieldClient instance used to scan text.
            scan_request_body: Whether to scan request bodies. When False the
                hook does nothing.
            threat_threshold: Minimum threat level that counts as a threat.
            block_on_threat: If True, reject threatening requests with HTTP 403.
            exclude_paths: Exact request paths to skip (no prefix matching).
                Defaults to ``["/health", "/static"]`` when None.
        """
        self.client = client
        self.scan_request_body = scan_request_body
        self.threat_threshold = threat_threshold
        self.block_on_threat = block_on_threat
        # `is None` (not `or`) so an explicit empty list means "exclude nothing".
        self.exclude_paths = ["/health", "/static"] if exclude_paths is None else exclude_paths

    def create_middleware(self):
        """Create a Flask request hook for automatic security scanning.

        Returns:
            A zero-argument function shaped for Flask's ``before_request``. It
            returns ``None`` to let the request proceed, or a
            ``(json_response, 403)`` tuple when a threat is blocked.
        """
        import logging

        from flask import g, jsonify, request

        logger = logging.getLogger(__name__)

        def koreshield_middleware():
            # Previously this flag was stored but ignored.
            if not self.scan_request_body:
                return None
            if request.path in self.exclude_paths:
                return None
            # Only JSON bodies of write requests are scanned.
            if request.method not in ("POST", "PUT", "PATCH") or not request.is_json:
                return None

            try:
                # silent=True: malformed JSON yields None instead of raising BadRequest.
                data = request.get_json(silent=True)
                text_content = self._extract_text_from_request(data)
                if text_content:
                    result = self._run_sync(self.client.scan_prompt(text_content))
                    g.koreshield_result = result
                    if not result.is_safe and self._is_above_threshold(result) and self.block_on_threat:
                        return jsonify({
                            "error": "Security threat detected",
                            "threat_level": result.threat_level.value,
                            "confidence": result.confidence,
                        }), 403
            except Exception as exc:
                # Best-effort: scanner failures must not break the request.
                logger.warning("KoreShield middleware error: %s", exc, exc_info=True)

            return None

        return koreshield_middleware

    @staticmethod
    def _run_sync(coro):
        """Run *coro* to completion on a private event loop and return its result.

        The loop is closed afterwards. It is deliberately NOT installed as the
        thread's current loop, so later ``asyncio.get_event_loop()`` calls in
        this thread never see a closed loop.

        NOTE(review): every call uses a fresh loop. If the client keeps a
        loop-bound HTTP session across calls, this may fail — confirm.
        """
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            try:
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()

    def _extract_text_from_request(self, data: Any) -> str:
        """Collect user-facing text from a decoded JSON payload.

        Collection rules:
        - A top-level string is returned unchanged.
        - In a dict, string values under ``_TEXT_FIELDS`` keys are collected.
        - In a list, every string item is collected.
        - Nested dicts and lists are also walked, so chat-style payloads such
          as ``{"messages": [{"content": "..."}]}`` are covered.

        The walk is iterative, so deep nesting cannot hit the recursion limit.

        Returns:
            Collected strings joined by single spaces; ``""`` if nothing is found.
        """
        if isinstance(data, str):
            return data
        parts: List[str] = []
        pending: List[Any] = [data]
        while pending:
            node = pending.pop()
            if isinstance(node, str):
                parts.append(node)
                continue
            if isinstance(node, dict):
                children = [node[key] for key in self._TEXT_FIELDS if isinstance(node.get(key), str)]
                children.extend(value for value in node.values() if isinstance(value, (dict, list)))
            elif isinstance(node, list):
                children = [item for item in node if isinstance(item, (str, dict, list))]
            else:
                continue
            # Push in reverse so children are popped in their original order.
            pending.extend(reversed(children))
        return " ".join(parts)

    def _is_above_threshold(self, result: DetectionResult) -> bool:
        """Return True if ``result.threat_level`` is at or above ``threat_threshold``.

        Raises:
            ValueError: If either level is not one of the known ThreatLevel members.
        """
        order = self._THREAT_ORDER
        return order.index(result.threat_level) >= order.index(self.threat_threshold)
233
+
234
+
235
class DjangoIntegration:
    """Django integration helper for KoreShield security middleware.

    :meth:`create_middleware` returns a middleware class. It scans JSON bodies
    of POST/PUT/PATCH requests. Results at or above ``threat_threshold`` are
    handled in one of two ways: the request is rejected with HTTP 403
    (``block_on_threat=True``), or the result is stored on
    ``request.koreshield_result``.
    """

    # Severity order, least to most severe; used by _is_above_threshold.
    _THREAT_ORDER = (
        ThreatLevel.SAFE,
        ThreatLevel.LOW,
        ThreatLevel.MEDIUM,
        ThreatLevel.HIGH,
        ThreatLevel.CRITICAL,
    )

    # JSON keys whose string values are treated as scannable user text.
    _TEXT_FIELDS = ("prompt", "message", "content", "text", "query", "input")

    def __init__(
        self,
        client: AsyncKoreShieldClient,
        scan_request_body: bool = True,
        threat_threshold: ThreatLevel = ThreatLevel.MEDIUM,
        block_on_threat: bool = False,
        exclude_paths: Optional[List[str]] = None,
    ):
        """Initialize Django integration.

        Args:
            client: AsyncKoreShieldClient instance used to scan text.
            scan_request_body: Whether to scan request bodies.
            threat_threshold: Minimum threat level that counts as a threat.
            block_on_threat: If True, reject threatening requests with HTTP 403.
            exclude_paths: Exact request paths to skip (no prefix matching).
                Defaults to ``["/admin", "/static", "/media"]`` when None.
        """
        self.client = client
        self.scan_request_body = scan_request_body
        self.threat_threshold = threat_threshold
        self.block_on_threat = block_on_threat
        # `is None` (not `or`) so an explicit empty list means "exclude nothing".
        self.exclude_paths = (
            ["/admin", "/static", "/media"] if exclude_paths is None else exclude_paths
        )

    def create_middleware(self):
        """Create a Django middleware class bound to this integration.

        Returns:
            A class following Django's ``__init__(get_response)`` /
            ``__call__(request)`` middleware protocol.
        """
        import logging

        from django.http import JsonResponse

        # Inside the middleware class, `self` is the middleware instance, which
        # only carries `get_response`. All configuration lives on the
        # integration, so capture it explicitly. Reading it through `self`
        # raised AttributeError on every request.
        integration = self
        logger = logging.getLogger(__name__)

        class KoreShieldMiddleware:
            """Django middleware enforcing the enclosing integration's settings."""

            def __init__(self, get_response):
                self.get_response = get_response

            def __call__(self, request):
                if request.path in integration.exclude_paths:
                    return self.get_response(request)
                if request.method not in ("POST", "PUT", "PATCH"):
                    return self.get_response(request)

                if integration.scan_request_body:
                    try:
                        result = None
                        if request.content_type == "application/json":
                            result = integration._scan_json_body(request.body)
                        if (
                            result is not None
                            and not result.is_safe
                            and integration._is_above_threshold(result)
                        ):
                            if integration.block_on_threat:
                                return JsonResponse({
                                    "error": "Security threat detected",
                                    "threat_level": result.threat_level.value,
                                    "confidence": result.confidence,
                                }, status=403)
                            # Non-blocking: keep the finding for views and the header below.
                            request.koreshield_result = result
                    except Exception as exc:
                        # Best-effort: scanner failures must not break the request.
                        logger.warning("KoreShield middleware error: %s", exc, exc_info=True)

                response = self.get_response(request)

                response["X-KoreShield-Scanned"] = "true"
                if hasattr(request, "koreshield_result"):
                    response["X-KoreShield-Threat-Level"] = request.koreshield_result.threat_level.value

                return response

        return KoreShieldMiddleware

    def _scan_json_body(self, raw_body: bytes) -> Optional[DetectionResult]:
        """Decode a UTF-8 JSON body and scan its text.

        Returns:
            The detection result, or None when the payload has no scannable text.

        Raises:
            ValueError: If the body is not valid UTF-8 or not valid JSON.
        """
        import json

        data = json.loads(raw_body.decode())
        text_content = self._extract_text_from_request(data)
        if not text_content:
            return None
        return self._run_sync(self.client.scan_prompt(text_content))

    @staticmethod
    def _run_sync(coro):
        """Run *coro* to completion on a private event loop and return its result.

        The loop is closed afterwards. It is deliberately NOT installed as the
        thread's current loop, so later ``asyncio.get_event_loop()`` calls in
        this thread never see a closed loop.

        NOTE(review): every call uses a fresh loop. If the client keeps a
        loop-bound HTTP session across calls, this may fail — confirm.
        """
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            try:
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()

    def _extract_text_from_request(self, data: Any) -> str:
        """Collect user-facing text from a decoded JSON payload.

        Collection rules:
        - A top-level string is returned unchanged.
        - In a dict, string values under ``_TEXT_FIELDS`` keys are collected.
        - In a list, every string item is collected.
        - Nested dicts and lists are also walked, so chat-style payloads such
          as ``{"messages": [{"content": "..."}]}`` are covered.

        The walk is iterative, so deep nesting cannot hit the recursion limit.

        Returns:
            Collected strings joined by single spaces; ``""`` if nothing is found.
        """
        if isinstance(data, str):
            return data
        parts: List[str] = []
        pending: List[Any] = [data]
        while pending:
            node = pending.pop()
            if isinstance(node, str):
                parts.append(node)
                continue
            if isinstance(node, dict):
                children = [node[key] for key in self._TEXT_FIELDS if isinstance(node.get(key), str)]
                children.extend(value for value in node.values() if isinstance(value, (dict, list)))
            elif isinstance(node, list):
                children = [item for item in node if isinstance(item, (str, dict, list))]
            else:
                continue
            # Push in reverse so children are popped in their original order.
            pending.extend(reversed(children))
        return " ".join(parts)

    def _is_above_threshold(self, result: DetectionResult) -> bool:
        """Return True if ``result.threat_level`` is at or above ``threat_threshold``.

        Raises:
            ValueError: If either level is not one of the known ThreatLevel members.
        """
        order = self._THREAT_ORDER
        return order.index(result.threat_level) >= order.index(self.threat_threshold)
343
+
344
+
345
+ # Convenience functions for quick setup
346
def create_fastapi_middleware(client: AsyncKoreShieldClient, **kwargs):
    """Build a FastAPIIntegration and return its middleware function.

    Args:
        client: Async client used for scanning.
        **kwargs: Passed through unchanged to ``FastAPIIntegration``.

    Returns:
        Whatever ``FastAPIIntegration.create_middleware()`` produces.
    """
    return FastAPIIntegration(client, **kwargs).create_middleware()
350
+
351
+
352
def create_flask_middleware(client: AsyncKoreShieldClient, **kwargs):
    """Build a FlaskIntegration and return its request hook.

    Args:
        client: Async client used for scanning.
        **kwargs: Passed through unchanged to ``FlaskIntegration``.

    Returns:
        Whatever ``FlaskIntegration.create_middleware()`` produces.
    """
    return FlaskIntegration(client, **kwargs).create_middleware()
356
+
357
+
358
def create_django_middleware(client: AsyncKoreShieldClient, **kwargs):
    """Build a DjangoIntegration and return its middleware class.

    Args:
        client: Async client used for scanning.
        **kwargs: Passed through unchanged to ``DjangoIntegration``.

    Returns:
        Whatever ``DjangoIntegration.create_middleware()`` produces.
    """
    return DjangoIntegration(client, **kwargs).create_middleware()
koreshield_sdk/types.py CHANGED
@@ -88,9 +88,10 @@ class BatchScanResponse(BaseModel):
88
88
  processing_time_ms: float
89
89
  request_id: str
90
90
  timestamp: str
91
+ version: Optional[str] = None
91
92
 
92
93
 
93
- # RAG Detection Types
94
+ # RAG Detection Types
94
95
 
95
96
  class InjectionVector(str, Enum):
96
97
  """RAG injection vector taxonomy."""
@@ -201,41 +202,21 @@ class RAGScanResponse(BaseModel):
201
202
  timestamp: Optional[str] = None
202
203
 
203
204
  def get_threat_document_ids(self) -> List[str]:
204
- """Get list of document IDs with detected threats.
205
-
206
- Returns:
207
- List of document IDs that contain threats
208
- """
205
+ """Get list of document IDs with detected threats."""
209
206
  threat_ids = set()
210
-
211
- # From document-level threats
212
207
  for threat in self.context_analysis.document_threats:
213
208
  threat_ids.add(threat.document_id)
214
-
215
- # From cross-document threats
216
209
  for threat in self.context_analysis.cross_document_threats:
217
210
  threat_ids.update(threat.document_ids)
218
-
219
211
  return list(threat_ids)
220
212
 
221
213
  def get_safe_documents(self, original_documents: List[RAGDocument]) -> List[RAGDocument]:
222
- """Filter out threatening documents.
223
-
224
- Args:
225
- original_documents: Original list of documents scanned
226
-
227
- Returns:
228
- List of documents without detected threats
229
- """
214
+ """Filter out threatening documents."""
230
215
  threat_ids = set(self.get_threat_document_ids())
231
216
  return [doc for doc in original_documents if doc.id not in threat_ids]
232
217
 
233
218
  def has_critical_threats(self) -> bool:
234
- """Check if critical threats were detected.
235
-
236
- Returns:
237
- True if any critical severity threats found
238
- """
219
+ """Check if critical threats were detected."""
239
220
  return self.overall_severity == ThreatLevel.CRITICAL
240
221
 
241
222
 
@@ -245,4 +226,57 @@ class RAGScanRequest(BaseModel):
245
226
  documents: List[RAGDocument]
246
227
  config: Optional[Dict[str, Any]] = Field(default_factory=dict)
247
228
 
248
- model_config = ConfigDict(extra="allow")
229
+ model_config = ConfigDict(extra="allow")
230
+
231
+
232
+ # Streaming and Metric Types
233
+
234
class StreamingScanRequest(BaseModel):
    """Request for streaming security scanning.

    Carries the content to scan, chunking parameters, and optional caller
    context. Extra, undeclared fields are accepted (``extra="allow"``).
    """
    # Full text to be scanned.
    content: str
    # Chunk length; unit (characters vs. tokens) is not defined here — TODO confirm.
    chunk_size: int = 1000
    # Presumably how much consecutive chunks share, in chunk_size's unit.
    # NOTE(review): nothing here enforces chunk_size > 0 or overlap < chunk_size.
    overlap: int = 100
    # Optional caller-supplied context; schema not defined here.
    context: Optional[Dict[str, Any]] = None
    # Optional identifiers for the calling user and session.
    user_id: Optional[str] = None
    session_id: Optional[str] = None
    # Optional free-form metadata.
    metadata: Optional[Dict[str, Any]] = None

    model_config = ConfigDict(extra="allow")
245
+
246
+
247
class StreamingScanResponse(BaseModel):
    """Response from streaming security scanning.

    Attributes:
        chunk_results: Per-chunk detection results (presumably in chunk order).
        overall_result: Aggregate detection result for the whole content.
        total_chunks: Number of chunks the content was split into.
        processing_time_ms: Processing time, in milliseconds per the suffix.
        request_id: Identifier of the scan request.
        timestamp: Timestamp string; format not defined here.
        version: Server/API version string, if reported. Optional (default
            None) to match ``BatchScanResponse.version``, so responses that
            omit it still validate.
    """
    chunk_results: List[DetectionResult]
    overall_result: DetectionResult
    total_chunks: int
    processing_time_ms: float
    request_id: str
    timestamp: str
    version: Optional[str] = None
256
+
257
+
258
class SecurityPolicy(BaseModel):
    """Custom security policy configuration.

    A named bundle of the following settings:
    - a minimum threat level (default MEDIUM);
    - detection types to block;
    - custom rules;
    - allow/block pattern lists.

    How these settings are enforced is not defined in this model.
    """
    # Policy identifier.
    name: str
    # Optional human-readable description.
    description: Optional[str] = None
    # Minimum threat level the policy cares about; defaults to MEDIUM.
    threat_threshold: ThreatLevel = ThreatLevel.MEDIUM
    # Detection types this policy blocks; empty by default.
    blocked_detection_types: List[DetectionType] = Field(default_factory=list)
    # Rule definitions; dict schema not defined here — TODO confirm with server docs.
    custom_rules: List[Dict[str, Any]] = Field(default_factory=list)
    # Pattern strings (presumably regex or glob — confirm) to allow / block.
    allowlist_patterns: List[str] = Field(default_factory=list)
    blocklist_patterns: List[str] = Field(default_factory=list)
    # Optional free-form metadata.
    metadata: Optional[Dict[str, Any]] = None
268
+
269
+
270
class PerformanceMetrics(BaseModel):
    """SDK performance and usage metrics.

    Every field has a zero/empty default, so an instance can be created
    before any requests are made.
    """
    # Request count and cumulative processing time (milliseconds, per the suffix).
    total_requests: int = 0
    total_processing_time_ms: float = 0.0
    # Derived rates; how they are computed is not defined in this model.
    average_response_time_ms: float = 0.0
    requests_per_second: float = 0.0
    # Number of failed requests.
    error_count: int = 0
    # Presumably fractions in [0, 1] — TODO confirm.
    cache_hit_rate: float = 0.0
    batch_efficiency: float = 0.0
    # Number of chunks processed by streaming scans.
    streaming_chunks_processed: int = 0
    # Seconds since the metrics collection started (per the name).
    uptime_seconds: float = 0.0
    # Optional memory usage in megabytes; None when not measured.
    memory_usage_mb: Optional[float] = None
    # Arbitrary additional metrics.
    custom_metrics: Dict[str, Any] = Field(default_factory=dict)
@@ -1,13 +0,0 @@
1
- koreshield-0.1.5.dist-info/licenses/LICENSE,sha256=k3qeCwQxhbOO1GtxA10Do4-_veQzgflqjOp5uZD5mug,1071
2
- koreshield_sdk/__init__.py,sha256=JXErgUsoxTgM4EU--Os4ZTobARKWj1Mfurln-hNgCQw,785
3
- koreshield_sdk/async_client.py,sha256=WF4MQVefUJs-YpjVE4qkrP5P9vT6wb5qFJdsdebtOtc,14877
4
- koreshield_sdk/client.py,sha256=LHuCrHwugzDeoMY5bxmYRmIyRUwJUNgL_Vv3f5ncqpE,13217
5
- koreshield_sdk/exceptions.py,sha256=3j1FR4VFbe1Vv4i0bofBgQ_ZGwBfpOInBd9OyNQFUxo,945
6
- koreshield_sdk/py.typed,sha256=8ZJUsxZiuOy1oJeVhsTWQhTG_6pTVHVXk5hJL79ebTk,25
7
- koreshield_sdk/types.py,sha256=UabFBswT4ckPt2Umwl9FqOBSpPl6RN4FWJPl5qDn5cc,7034
8
- koreshield_sdk/integrations/__init__.py,sha256=po_sLSND55Wdu1vDmx4Nrjm072HLf04yxmtWj43yv7Y,382
9
- koreshield_sdk/integrations/langchain.py,sha256=w3BXs3tVk7R4ldFPhAm7qXbJPsHoamY3z2Ke0WPBVas,16542
10
- koreshield-0.1.5.dist-info/METADATA,sha256=XqNTIRL56qucFtHk2U0l7sfvtWSmCfGMiPwUTslFQ6A,15408
11
- koreshield-0.1.5.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
12
- koreshield-0.1.5.dist-info/top_level.txt,sha256=ePw2ZI3SrHZ5CaTRCyj3aya3j_qTcmRAQjoU7s3gAdM,15
13
- koreshield-0.1.5.dist-info/RECORD,,