koreshield-0.1.1-py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
@@ -0,0 +1,275 @@
+ """LangChain integration for KoreShield."""
+
+ import warnings
+ from typing import Any, Dict, List
+
+ from langchain_core.callbacks import AsyncCallbackHandler, BaseCallbackHandler
+ from langchain_core.outputs import LLMResult
+
+ from ..client import KoreShieldClient
+ from ..async_client import AsyncKoreShieldClient
+ from ..types import DetectionResult, ThreatLevel
+ from ..exceptions import KoreShieldError
+
+
+ class KoreShieldCallbackHandler(BaseCallbackHandler):
+     """LangChain callback handler for KoreShield security monitoring.
+
+     This handler automatically scans prompts and responses for security
+     threats during LangChain operations.
+     """
+
+     def __init__(
+         self,
+         api_key: str,
+         base_url: str = "https://api.koreshield.com",
+         block_on_threat: bool = False,
+         threat_threshold: ThreatLevel = ThreatLevel.MEDIUM,
+         scan_responses: bool = True,
+         **client_kwargs: Any
+     ):
+         """Initialize the KoreShield callback handler.
+
+         Args:
+             api_key: KoreShield API key.
+             base_url: KoreShield API base URL.
+             block_on_threat: Whether to raise an exception on detected threats.
+             threat_threshold: Minimum threat level that triggers blocking.
+             scan_responses: Whether to scan LLM responses as well as prompts.
+             **client_kwargs: Additional arguments for KoreShieldClient.
+         """
+         super().__init__()
+         self.client = KoreShieldClient(api_key, base_url, **client_kwargs)
+         self.block_on_threat = block_on_threat
+         self.threat_threshold = threat_threshold
+         self.scan_responses = scan_responses
+         self.scan_results: List[Dict[str, Any]] = []
+
+     def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
+         """Called when the LLM starts processing prompts."""
+         for prompt in prompts:
+             try:
+                 result = self.client.scan_prompt(prompt)
+                 self.scan_results.append({
+                     "type": "prompt",
+                     "content": prompt,
+                     "result": result,
+                     "run_id": kwargs.get("run_id"),
+                 })
+
+                 if self._should_block(result):
+                     raise KoreShieldError(
+                         f"Security threat detected in prompt: {result.threat_level.value} "
+                         f"(confidence: {result.confidence:.2f})"
+                     )
+
+             except KoreShieldError:
+                 raise  # Re-raise security errors
+             except Exception as e:
+                 # Warn but don't block on client errors
+                 warnings.warn(f"KoreShield scan failed: {e}")
+
+     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+         """Called when the LLM finishes processing."""
+         if not self.scan_responses:
+             return
+
+         for generation in response.generations:
+             for gen in generation:
+                 if hasattr(gen, "text") and gen.text:
+                     try:
+                         result = self.client.scan_prompt(gen.text)
+                         self.scan_results.append({
+                             "type": "response",
+                             "content": gen.text,
+                             "result": result,
+                             "run_id": kwargs.get("run_id"),
+                         })
+
+                         if self._should_block(result):
+                             raise KoreShieldError(
+                                 f"Security threat detected in response: {result.threat_level.value} "
+                                 f"(confidence: {result.confidence:.2f})"
+                             )
+
+                     except KoreShieldError:
+                         raise  # Re-raise security errors
+                     except Exception as e:
+                         # Warn but don't block on client errors
+                         warnings.warn(f"KoreShield scan failed: {e}")
+
+     def get_scan_results(self) -> List[Dict[str, Any]]:
+         """Return a copy of all scan results recorded by this handler."""
+         return self.scan_results.copy()
+
+     def clear_scan_results(self) -> None:
+         """Clear stored scan results."""
+         self.scan_results.clear()
+
+     def _should_block(self, result: DetectionResult) -> bool:
+         """Determine whether a result should trigger blocking."""
+         if not self.block_on_threat:
+             return False
+
+         threat_levels = {
+             ThreatLevel.SAFE: 0,
+             ThreatLevel.LOW: 1,
+             ThreatLevel.MEDIUM: 2,
+             ThreatLevel.HIGH: 3,
+             ThreatLevel.CRITICAL: 4,
+         }
+
+         result_level = threat_levels.get(result.threat_level, 0)
+         threshold_level = threat_levels.get(self.threat_threshold, 2)
+
+         return result_level >= threshold_level
+
+
+ class AsyncKoreShieldCallbackHandler(AsyncCallbackHandler):
+     """Async LangChain callback handler for KoreShield security monitoring.
+
+     Subclasses AsyncCallbackHandler so LangChain awaits the async hooks.
+     """
+
+     def __init__(
+         self,
+         api_key: str,
+         base_url: str = "https://api.koreshield.com",
+         block_on_threat: bool = False,
+         threat_threshold: ThreatLevel = ThreatLevel.MEDIUM,
+         scan_responses: bool = True,
+         **client_kwargs: Any
+     ):
+         """Initialize the async KoreShield callback handler."""
+         super().__init__()
+         self.client = AsyncKoreShieldClient(api_key, base_url, **client_kwargs)
+         self.block_on_threat = block_on_threat
+         self.threat_threshold = threat_threshold
+         self.scan_responses = scan_responses
+         self.scan_results: List[Dict[str, Any]] = []
+
+     async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
+         """Called when the LLM starts processing prompts (async)."""
+         for prompt in prompts:
+             try:
+                 result = await self.client.scan_prompt(prompt)
+                 self.scan_results.append({
+                     "type": "prompt",
+                     "content": prompt,
+                     "result": result,
+                     "run_id": kwargs.get("run_id"),
+                 })
+
+                 if self._should_block(result):
+                     raise KoreShieldError(
+                         f"Security threat detected in prompt: {result.threat_level.value} "
+                         f"(confidence: {result.confidence:.2f})"
+                     )
+
+             except KoreShieldError:
+                 raise  # Re-raise security errors
+             except Exception as e:
+                 # Warn but don't block on client errors
+                 warnings.warn(f"KoreShield scan failed: {e}")
+
+     async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+         """Called when the LLM finishes processing (async)."""
+         if not self.scan_responses:
+             return
+
+         for generation in response.generations:
+             for gen in generation:
+                 if hasattr(gen, "text") and gen.text:
+                     try:
+                         result = await self.client.scan_prompt(gen.text)
+                         self.scan_results.append({
+                             "type": "response",
+                             "content": gen.text,
+                             "result": result,
+                             "run_id": kwargs.get("run_id"),
+                         })
+
+                         if self._should_block(result):
+                             raise KoreShieldError(
+                                 f"Security threat detected in response: {result.threat_level.value} "
+                                 f"(confidence: {result.confidence:.2f})"
+                             )
+
+                     except KoreShieldError:
+                         raise  # Re-raise security errors
+                     except Exception as e:
+                         # Warn but don't block on client errors
+                         warnings.warn(f"KoreShield scan failed: {e}")
+
+     def get_scan_results(self) -> List[Dict[str, Any]]:
+         """Return a copy of all scan results recorded by this handler."""
+         return self.scan_results.copy()
+
+     def clear_scan_results(self) -> None:
+         """Clear stored scan results."""
+         self.scan_results.clear()
+
+     def _should_block(self, result: DetectionResult) -> bool:
+         """Determine whether a result should trigger blocking."""
+         if not self.block_on_threat:
+             return False
+
+         threat_levels = {
+             ThreatLevel.SAFE: 0,
+             ThreatLevel.LOW: 1,
+             ThreatLevel.MEDIUM: 2,
+             ThreatLevel.HIGH: 3,
+             ThreatLevel.CRITICAL: 4,
+         }
+
+         result_level = threat_levels.get(result.threat_level, 0)
+         threshold_level = threat_levels.get(self.threat_threshold, 2)
+
+         return result_level >= threshold_level
+
+
+ # Convenience functions for easy integration
+ def create_koreshield_callback(
+     api_key: str,
+     block_on_threat: bool = False,
+     threat_threshold: ThreatLevel = ThreatLevel.MEDIUM,
+     **kwargs: Any
+ ) -> KoreShieldCallbackHandler:
+     """Create a KoreShield callback handler for LangChain.
+
+     Args:
+         api_key: KoreShield API key.
+         block_on_threat: Whether to block on detected threats.
+         threat_threshold: Minimum threat level for blocking.
+         **kwargs: Additional arguments for KoreShieldClient.
+
+     Returns:
+         Configured KoreShieldCallbackHandler.
+     """
+     return KoreShieldCallbackHandler(
+         api_key=api_key,
+         block_on_threat=block_on_threat,
+         threat_threshold=threat_threshold,
+         **kwargs
+     )
+
+
+ def create_async_koreshield_callback(
+     api_key: str,
+     block_on_threat: bool = False,
+     threat_threshold: ThreatLevel = ThreatLevel.MEDIUM,
+     **kwargs: Any
+ ) -> AsyncKoreShieldCallbackHandler:
+     """Create an async KoreShield callback handler for LangChain.
+
+     Args:
+         api_key: KoreShield API key.
+         block_on_threat: Whether to block on detected threats.
+         threat_threshold: Minimum threat level for blocking.
+         **kwargs: Additional arguments for AsyncKoreShieldClient.
+
+     Returns:
+         Configured AsyncKoreShieldCallbackHandler.
+     """
+     return AsyncKoreShieldCallbackHandler(
+         api_key=api_key,
+         block_on_threat=block_on_threat,
+         threat_threshold=threat_threshold,
+         **kwargs
+     )
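
For orientation, here is a minimal usage sketch for the handlers above, not part of the published package. It assumes the integration module is importable as koreshield.integrations.langchain (this diff does not show the file's path inside the package), that langchain_openai is installed, and that "ks-..." stands in for a real API key:

    # Editor-added usage sketch; module path and model name are assumptions.
    from langchain_openai import ChatOpenAI

    from koreshield.exceptions import KoreShieldError
    from koreshield.integrations.langchain import create_koreshield_callback  # assumed path
    from koreshield.types import ThreatLevel

    handler = create_koreshield_callback(
        api_key="ks-...",                     # placeholder API key
        block_on_threat=True,                 # raise instead of only recording scans
        threat_threshold=ThreatLevel.MEDIUM,  # block at MEDIUM and above
    )

    llm = ChatOpenAI(model="gpt-4o-mini")
    try:
        # Callbacks passed via the runnable config fire on_llm_start/on_llm_end.
        llm.invoke("Summarize our refund policy.", config={"callbacks": [handler]})
    except KoreShieldError as exc:
        print(f"Blocked: {exc}")

    # Scan records are kept whether or not anything was blocked.
    for entry in handler.get_scan_results():
        print(entry["type"], entry["result"].threat_level.value)

The async handler is wired identically; only the client calls inside it are awaited.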
@@ -0,0 +1 @@
+ # Marker file for PEP 561
@@ -0,0 +1,90 @@
+ """Type definitions for KoreShield SDK."""
+
+ from enum import Enum
+ from typing import Dict, List, Optional, Any
+ from pydantic import BaseModel, Field, ConfigDict
+
+
+ class ThreatLevel(str, Enum):
+     """Threat level enumeration."""
+     SAFE = "safe"
+     LOW = "low"
+     MEDIUM = "medium"
+     HIGH = "high"
+     CRITICAL = "critical"
+
+
+ class DetectionType(str, Enum):
+     """Detection type enumeration."""
+     KEYWORD = "keyword"
+     PATTERN = "pattern"
+     RULE = "rule"
+     ML = "ml"
+     BLOCKLIST = "blocklist"
+     ALLOWLIST = "allowlist"
+
+
+ class DetectionIndicator(BaseModel):
+     """Individual detection indicator."""
+     type: DetectionType
+     severity: ThreatLevel
+     confidence: float = Field(ge=0.0, le=1.0)
+     description: str
+     metadata: Optional[Dict[str, Any]] = None
+
+
+ class DetectionResult(BaseModel):
+     """Result of a security scan."""
+     is_safe: bool
+     threat_level: ThreatLevel
+     confidence: float = Field(ge=0.0, le=1.0)
+     indicators: List[DetectionIndicator] = Field(default_factory=list)
+     processing_time_ms: float
+     scan_id: Optional[str] = None
+     metadata: Optional[Dict[str, Any]] = None
+
+
+ class ScanRequest(BaseModel):
+     """Request for security scanning."""
+     prompt: str
+     context: Optional[Dict[str, Any]] = None
+     user_id: Optional[str] = None
+     session_id: Optional[str] = None
+     metadata: Optional[Dict[str, Any]] = None
+
+     model_config = ConfigDict(extra="allow")  # Allow extra fields
+
+
+ class ScanResponse(BaseModel):
+     """Response from security scanning."""
+     result: DetectionResult
+     request_id: str
+     timestamp: str
+     version: str
+
+
+ class AuthConfig(BaseModel):
+     """Authentication configuration."""
+     api_key: str
+     base_url: str = "https://api.koreshield.com"
+     timeout: float = 30.0
+     retry_attempts: int = 3
+     retry_delay: float = 1.0
+
+
+ class BatchScanRequest(BaseModel):
+     """Request for batch security scanning."""
+     requests: List[ScanRequest]
+     parallel: bool = True
+     max_concurrent: int = 10
+
+
+ class BatchScanResponse(BaseModel):
+     """Response from batch security scanning."""
+     results: List[ScanResponse]
+     total_processed: int
+     total_safe: int
+     total_unsafe: int
+     processing_time_ms: float
+     request_id: str
+     timestamp: str
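
To make the models concrete, a small editor-added sketch (assuming pydantic v2, which the ConfigDict import implies): constructing a DetectionResult by hand and serializing it. The Field(ge=0.0, le=1.0) bounds mean an out-of-range confidence such as 1.5 raises a pydantic ValidationError.

    # Illustrative only; field values are made up for the example.
    from koreshield.types import (
        DetectionIndicator,
        DetectionResult,
        DetectionType,
        ThreatLevel,
    )

    result = DetectionResult(
        is_safe=False,
        threat_level=ThreatLevel.HIGH,
        confidence=0.93,
        indicators=[
            DetectionIndicator(
                type=DetectionType.PATTERN,
                severity=ThreatLevel.HIGH,
                confidence=0.93,
                description="Prompt-injection pattern matched",
            )
        ],
        processing_time_ms=12.4,
    )
    print(result.model_dump_json(indent=2))  # pydantic v2 serialization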