koreshield 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- koreshield-0.1.1.dist-info/METADATA +455 -0
- koreshield-0.1.1.dist-info/RECORD +13 -0
- koreshield-0.1.1.dist-info/WHEEL +5 -0
- koreshield-0.1.1.dist-info/licenses/LICENSE +21 -0
- koreshield-0.1.1.dist-info/top_level.txt +1 -0
- koreshield_sdk/__init__.py +33 -0
- koreshield_sdk/async_client.py +263 -0
- koreshield_sdk/client.py +227 -0
- koreshield_sdk/exceptions.py +41 -0
- koreshield_sdk/integrations/__init__.py +15 -0
- koreshield_sdk/integrations/langchain.py +275 -0
- koreshield_sdk/py.typed +1 -0
- koreshield_sdk/types.py +90 -0
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
"""LangChain integration for KoreShield."""
|
|
2
|
+
|
|
3
|
+
import logging
from typing import Any, Dict, List, Optional, Union

from langchain_core.callbacks import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_core.messages import BaseMessage

from ..client import KoreShieldClient
from ..async_client import AsyncKoreShieldClient
from ..types import DetectionResult, ThreatLevel
from ..exceptions import KoreShieldError
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class KoreShieldCallbackHandler(BaseCallbackHandler):
    """LangChain callback handler for KoreShield security monitoring.

    Every prompt passed to ``on_llm_start`` is sent to
    ``KoreShieldClient.scan_prompt``. When ``scan_responses`` is enabled, every
    non-empty generation text passed to ``on_llm_end`` is scanned as well.
    Each completed scan is appended to ``scan_results``. When
    ``block_on_threat`` is enabled and a result reaches ``threat_threshold``,
    a ``KoreShieldError`` is raised to abort the run.

    Failures of the scan call itself are logged and never block the chain.
    """

    # LangChain's callback manager catches exceptions raised by a handler,
    # logs them, and re-raises only when the handler's ``raise_error`` is
    # true. Without this, the KoreShieldError raised for a blocked threat
    # would be swallowed and nothing would actually be blocked. Scan failures
    # are caught inside this class, so only the deliberate blocking error
    # propagates.
    raise_error = True

    # Ordinal severity used to compare a result against ``threat_threshold``.
    _THREAT_RANK = {
        ThreatLevel.SAFE: 0,
        ThreatLevel.LOW: 1,
        ThreatLevel.MEDIUM: 2,
        ThreatLevel.HIGH: 3,
        ThreatLevel.CRITICAL: 4,
    }

    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://api.koreshield.com",
        block_on_threat: bool = False,
        threat_threshold: ThreatLevel = ThreatLevel.MEDIUM,
        scan_responses: bool = True,
        **client_kwargs: Any,
    ):
        """Initialize the KoreShield callback handler.

        Args:
            api_key: KoreShield API key.
            base_url: KoreShield API base URL.
            block_on_threat: Whether to raise ``KoreShieldError`` on detected
                threats.
            threat_threshold: Minimum threat level that triggers blocking.
            scan_responses: Whether to scan LLM responses.
            **client_kwargs: Additional arguments for ``KoreShieldClient``.
        """
        super().__init__()
        self.client = KoreShieldClient(api_key, base_url, **client_kwargs)
        self.block_on_threat = block_on_threat
        self.threat_threshold = threat_threshold
        self.scan_responses = scan_responses
        self.scan_results: List[Dict[str, Any]] = []

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Scan each prompt before the LLM processes it.

        Raises:
            KoreShieldError: If blocking is enabled and a prompt's threat
                level meets ``threat_threshold``.
        """
        run_id = kwargs.get("run_id")
        for prompt in prompts:
            result = self._safe_scan(prompt)
            if result is not None:
                self._record_and_check("prompt", prompt, result, run_id)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Scan each non-empty generated text once the LLM has finished.

        Does nothing when ``scan_responses`` is False.

        Raises:
            KoreShieldError: If blocking is enabled and a response's threat
                level meets ``threat_threshold``.
        """
        if not self.scan_responses:
            return

        run_id = kwargs.get("run_id")
        for generation_list in response.generations:
            for gen in generation_list:
                text = getattr(gen, "text", None)
                if not text:
                    continue
                # NOTE(review): responses go through scan_prompt as before;
                # confirm whether the client offers a response-specific scan.
                result = self._safe_scan(text)
                if result is not None:
                    self._record_and_check("response", text, result, run_id)

    def get_scan_results(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of all scan records collected so far."""
        return self.scan_results.copy()

    def clear_scan_results(self) -> None:
        """Clear stored scan results."""
        self.scan_results.clear()

    def _safe_scan(self, text: str) -> Optional[DetectionResult]:
        """Scan ``text``; on any client failure, log it and return None."""
        try:
            return self.client.scan_prompt(text)
        except Exception as e:
            # Best-effort monitoring: a KoreShield outage or client error
            # must not break the user's chain.
            self._logger.warning("KoreShield scan failed: %s", e)
            return None

    def _record_and_check(
        self, kind: str, content: str, result: DetectionResult, run_id: Any
    ) -> None:
        """Store one scan record, then raise if it must block the run.

        Args:
            kind: ``"prompt"`` or ``"response"``; also used in the error text.
            content: The text that was scanned.
            result: The scan result for ``content``.
            run_id: LangChain run id from the callback kwargs (may be None).

        Raises:
            KoreShieldError: If ``_should_block(result)`` is True.
        """
        self.scan_results.append({
            "type": kind,
            "content": content,
            "result": result,
            # Legacy key: despite its name it has always held LangChain's
            # run_id, not a time. Kept for backward compatibility.
            "timestamp": run_id,
            "run_id": run_id,
        })

        if self._should_block(result):
            raise KoreShieldError(
                f"Security threat detected in {kind}: {result.threat_level.value} "
                f"(confidence: {result.confidence:.2f})"
            )

    def _should_block(self, result: DetectionResult) -> bool:
        """Return True if blocking is enabled and ``result`` meets the threshold.

        Unknown result levels rank as SAFE (0). An unknown threshold ranks as
        MEDIUM (2).
        """
        if not self.block_on_threat:
            return False

        result_level = self._THREAT_RANK.get(result.threat_level, 0)
        threshold_level = self._THREAT_RANK.get(self.threat_threshold, 2)
        return result_level >= threshold_level
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class AsyncKoreShieldCallbackHandler(AsyncCallbackHandler):
    """Async LangChain callback handler for KoreShield security monitoring.

    This is the async counterpart of the synchronous handler. Prompts (and,
    when ``scan_responses`` is enabled, non-empty generation texts) are
    scanned with ``AsyncKoreShieldClient.scan_prompt`` and recorded in
    ``scan_results``. When ``block_on_threat`` is enabled and a result
    reaches ``threat_threshold``, a ``KoreShieldError`` is raised.

    It subclasses ``AsyncCallbackHandler``, which is itself a
    ``BaseCallbackHandler``, because its hooks are coroutines. Failures of the
    scan call itself are logged and never block the chain.
    """

    # LangChain only propagates handler exceptions when ``raise_error`` is
    # true. Without this, the blocking KoreShieldError would be logged and
    # swallowed. Scan failures are caught inside this class, so only the
    # deliberate blocking error propagates.
    raise_error = True

    # Ordinal severity used to compare a result against ``threat_threshold``.
    _THREAT_RANK = {
        ThreatLevel.SAFE: 0,
        ThreatLevel.LOW: 1,
        ThreatLevel.MEDIUM: 2,
        ThreatLevel.HIGH: 3,
        ThreatLevel.CRITICAL: 4,
    }

    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://api.koreshield.com",
        block_on_threat: bool = False,
        threat_threshold: ThreatLevel = ThreatLevel.MEDIUM,
        scan_responses: bool = True,
        **client_kwargs: Any,
    ):
        """Initialize the async KoreShield callback handler.

        Args:
            api_key: KoreShield API key.
            base_url: KoreShield API base URL.
            block_on_threat: Whether to raise ``KoreShieldError`` on detected
                threats.
            threat_threshold: Minimum threat level that triggers blocking.
            scan_responses: Whether to scan LLM responses.
            **client_kwargs: Additional arguments for
                ``AsyncKoreShieldClient``.
        """
        super().__init__()
        self.client = AsyncKoreShieldClient(api_key, base_url, **client_kwargs)
        self.block_on_threat = block_on_threat
        self.threat_threshold = threat_threshold
        self.scan_responses = scan_responses
        self.scan_results: List[Dict[str, Any]] = []

    async def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Scan each prompt before the LLM processes it (async).

        Raises:
            KoreShieldError: If blocking is enabled and a prompt's threat
                level meets ``threat_threshold``.
        """
        run_id = kwargs.get("run_id")
        for prompt in prompts:
            result = await self._safe_scan(prompt)
            if result is not None:
                self._record_and_check("prompt", prompt, result, run_id)

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Scan each non-empty generated text once the LLM has finished (async).

        Does nothing when ``scan_responses`` is False.

        Raises:
            KoreShieldError: If blocking is enabled and a response's threat
                level meets ``threat_threshold``.
        """
        if not self.scan_responses:
            return

        run_id = kwargs.get("run_id")
        for generation_list in response.generations:
            for gen in generation_list:
                text = getattr(gen, "text", None)
                if not text:
                    continue
                # NOTE(review): responses go through scan_prompt as before;
                # confirm whether the client offers a response-specific scan.
                result = await self._safe_scan(text)
                if result is not None:
                    self._record_and_check("response", text, result, run_id)

    def get_scan_results(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of all scan records collected so far."""
        return self.scan_results.copy()

    def clear_scan_results(self) -> None:
        """Clear stored scan results."""
        self.scan_results.clear()

    async def _safe_scan(self, text: str) -> Optional[DetectionResult]:
        """Scan ``text``; on any client failure, log it and return None."""
        try:
            return await self.client.scan_prompt(text)
        except Exception as e:
            # Best-effort monitoring: a KoreShield outage or client error
            # must not break the user's chain.
            self._logger.warning("KoreShield scan failed: %s", e)
            return None

    def _record_and_check(
        self, kind: str, content: str, result: DetectionResult, run_id: Any
    ) -> None:
        """Store one scan record, then raise if it must block the run.

        Args:
            kind: ``"prompt"`` or ``"response"``; also used in the error text.
            content: The text that was scanned.
            result: The scan result for ``content``.
            run_id: LangChain run id from the callback kwargs (may be None).

        Raises:
            KoreShieldError: If ``_should_block(result)`` is True.
        """
        self.scan_results.append({
            "type": kind,
            "content": content,
            "result": result,
            # Legacy key: despite its name it has always held LangChain's
            # run_id, not a time. Kept for backward compatibility.
            "timestamp": run_id,
            "run_id": run_id,
        })

        if self._should_block(result):
            raise KoreShieldError(
                f"Security threat detected in {kind}: {result.threat_level.value} "
                f"(confidence: {result.confidence:.2f})"
            )

    def _should_block(self, result: DetectionResult) -> bool:
        """Return True if blocking is enabled and ``result`` meets the threshold.

        Unknown result levels rank as SAFE (0). An unknown threshold ranks as
        MEDIUM (2).
        """
        if not self.block_on_threat:
            return False

        result_level = self._THREAT_RANK.get(result.threat_level, 0)
        threshold_level = self._THREAT_RANK.get(self.threat_threshold, 2)
        return result_level >= threshold_level
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
# Convenience functions for easy integration
|
|
228
|
+
def create_koreshield_callback(
    api_key: str,
    block_on_threat: bool = False,
    threat_threshold: ThreatLevel = ThreatLevel.MEDIUM,
    **kwargs
) -> KoreShieldCallbackHandler:
    """Build a ``KoreShieldCallbackHandler`` with the most common options.

    Args:
        api_key: KoreShield API key.
        block_on_threat: Raise ``KoreShieldError`` when a scan meets the
            threshold.
        threat_threshold: Lowest threat level that triggers blocking.
        **kwargs: Forwarded unchanged to ``KoreShieldCallbackHandler`` (for
            example ``base_url`` or ``scan_responses``). Anything the handler
            does not consume is passed on to ``KoreShieldClient``.

    Returns:
        A configured handler for use as a LangChain callback.
    """
    handler_options = {
        "block_on_threat": block_on_threat,
        "threat_threshold": threat_threshold,
        **kwargs,
    }
    return KoreShieldCallbackHandler(api_key, **handler_options)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def create_async_koreshield_callback(
    api_key: str,
    block_on_threat: bool = False,
    threat_threshold: ThreatLevel = ThreatLevel.MEDIUM,
    **kwargs
) -> AsyncKoreShieldCallbackHandler:
    """Build an ``AsyncKoreShieldCallbackHandler`` with the most common options.

    Args:
        api_key: KoreShield API key.
        block_on_threat: Raise ``KoreShieldError`` when a scan meets the
            threshold.
        threat_threshold: Lowest threat level that triggers blocking.
        **kwargs: Forwarded unchanged to ``AsyncKoreShieldCallbackHandler``
            (for example ``base_url`` or ``scan_responses``). Anything the
            handler does not consume is passed on to ``AsyncKoreShieldClient``.

    Returns:
        A configured async handler for use as a LangChain callback.
    """
    handler_options = {
        "block_on_threat": block_on_threat,
        "threat_threshold": threat_threshold,
        **kwargs,
    }
    return AsyncKoreShieldCallbackHandler(api_key, **handler_options)
|
koreshield_sdk/py.typed
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Marker file for PEP 561
|
koreshield_sdk/types.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""Type definitions for KoreShield SDK."""
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import Dict, List, Optional, Any
|
|
5
|
+
from pydantic import BaseModel, Field, ConfigDict
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ThreatLevel(str, Enum):
    """Threat level enumeration, declared from least to most severe.

    Because this is a ``str`` enum, the built-in comparison operators compare
    the *string* values. For example, ``ThreatLevel.HIGH < ThreatLevel.LOW``
    is True because ``"high" < "low"``. Use :attr:`rank` to compare severity.
    """

    SAFE = "safe"
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"

    @property
    def rank(self) -> int:
        """Ordinal severity: 0 for SAFE up to 4 for CRITICAL.

        Derived from declaration order, so members must stay listed from
        least to most severe.
        """
        return list(type(self)).index(self)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class DetectionType(str, Enum):
    """Kind of detection mechanism that produced an indicator.

    Members are ``str`` subclasses, so they compare equal to (and serialize
    as) their lowercase string values.
    """

    # Matched a keyword.
    KEYWORD = "keyword"
    # Matched a pattern.
    PATTERN = "pattern"
    # Triggered a rule.
    RULE = "rule"
    # Flagged by a machine-learning detector.
    ML = "ml"
    # Matched a blocklist entry.
    BLOCKLIST = "blocklist"
    # Matched an allowlist entry.
    ALLOWLIST = "allowlist"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class DetectionIndicator(BaseModel):
    """A single signal that contributed to a scan verdict.

    Fields:
        type: Detection mechanism that produced the signal.
        severity: Threat level attributed to this signal.
        confidence: Detector confidence, validated to lie in [0.0, 1.0].
        description: Human-readable explanation of the signal.
        metadata: Optional free-form extra data (None when absent).
    """

    type: DetectionType
    severity: ThreatLevel
    confidence: float = Field(..., ge=0.0, le=1.0)
    description: str
    metadata: Optional[Dict[str, Any]] = Field(default=None)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class DetectionResult(BaseModel):
    """Outcome of scanning one piece of text.

    Fields:
        is_safe: Whether the scan judged the text safe.
        threat_level: Threat level assigned to the scanned text.
        confidence: Confidence of the verdict, validated to lie in [0.0, 1.0].
        indicators: Individual signals behind the verdict; each instance gets
            its own fresh empty list by default.
        processing_time_ms: Scan duration (milliseconds, per the field name).
        scan_id: Optional identifier for the scan.
        metadata: Optional free-form extra data (None when absent).
    """

    is_safe: bool
    threat_level: ThreatLevel
    confidence: float = Field(..., ge=0.0, le=1.0)
    indicators: List[DetectionIndicator] = Field(default_factory=list)
    processing_time_ms: float
    scan_id: Optional[str] = Field(default=None)
    metadata: Optional[Dict[str, Any]] = Field(default=None)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class ScanRequest(BaseModel):
    """Payload describing one text to be security-scanned.

    Fields:
        prompt: The text to scan.
        context: Optional free-form context data.
        user_id: Optional caller-supplied user identifier.
        session_id: Optional caller-supplied session identifier.
        metadata: Optional free-form extra data.
    """

    # Unknown keyword arguments are kept on the model (and included in
    # model_dump) instead of being rejected during validation.
    model_config = ConfigDict(extra="allow")

    prompt: str
    context: Optional[Dict[str, Any]] = Field(default=None)
    user_id: Optional[str] = Field(default=None)
    session_id: Optional[str] = Field(default=None)
    metadata: Optional[Dict[str, Any]] = Field(default=None)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class ScanResponse(BaseModel):
    """Envelope returned for a single scan request.

    All fields are required.

    Fields:
        result: The detection outcome.
        request_id: Identifier of the request.
        timestamp: Timestamp, kept as a plain string (format not validated).
        version: Version string reported alongside the result.
    """

    result: DetectionResult = Field(...)
    request_id: str = Field(...)
    timestamp: str = Field(...)
    version: str = Field(...)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class AuthConfig(BaseModel):
    """Connection and authentication settings for the KoreShield API.

    Fields:
        api_key: KoreShield API key (required).
        base_url: API root URL.
        timeout: Request timeout; presumably seconds. TODO confirm.
        retry_attempts: Number of retry attempts.
        retry_delay: Delay between retries; presumably seconds. TODO confirm.
    """

    api_key: str
    base_url: str = Field(default="https://api.koreshield.com")
    timeout: float = Field(default=30.0)
    retry_attempts: int = Field(default=3)
    retry_delay: float = Field(default=1.0)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class BatchScanRequest(BaseModel):
    """Payload for scanning several texts in one call.

    Fields:
        requests: The individual scan requests.
        parallel: Whether the batch may be processed in parallel.
        max_concurrent: Concurrency cap; presumably only meaningful when
            ``parallel`` is True. Verify against the client.
    """

    requests: List[ScanRequest]
    parallel: bool = Field(default=True)
    max_concurrent: int = Field(default=10)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class BatchScanResponse(BaseModel):
    """Aggregate result of a batch scan.

    All fields are required.

    Fields:
        results: One response per scanned request.
        total_processed: Count of requests processed.
        total_safe: Count judged safe.
        total_unsafe: Count judged unsafe.
        processing_time_ms: Total duration (milliseconds, per the field name).
        request_id: Identifier of the batch request.
        timestamp: Timestamp, kept as a plain string (format not validated).
    """

    results: List[ScanResponse] = Field(...)
    total_processed: int = Field(...)
    total_safe: int = Field(...)
    total_unsafe: int = Field(...)
    processing_time_ms: float = Field(...)
    request_id: str = Field(...)
    timestamp: str = Field(...)
|