netra-sdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of netra-sdk might be problematic. Click here for more details.
- netra/__init__.py +148 -0
- netra/anonymizer/__init__.py +7 -0
- netra/anonymizer/anonymizer.py +79 -0
- netra/anonymizer/base.py +159 -0
- netra/anonymizer/fp_anonymizer.py +182 -0
- netra/config.py +111 -0
- netra/decorators.py +167 -0
- netra/exceptions/__init__.py +6 -0
- netra/exceptions/injection.py +33 -0
- netra/exceptions/pii.py +46 -0
- netra/input_scanner.py +142 -0
- netra/instrumentation/__init__.py +257 -0
- netra/instrumentation/aiohttp/__init__.py +378 -0
- netra/instrumentation/aiohttp/version.py +1 -0
- netra/instrumentation/cohere/__init__.py +446 -0
- netra/instrumentation/cohere/version.py +1 -0
- netra/instrumentation/google_genai/__init__.py +506 -0
- netra/instrumentation/google_genai/config.py +5 -0
- netra/instrumentation/google_genai/utils.py +31 -0
- netra/instrumentation/google_genai/version.py +1 -0
- netra/instrumentation/httpx/__init__.py +545 -0
- netra/instrumentation/httpx/version.py +1 -0
- netra/instrumentation/instruments.py +78 -0
- netra/instrumentation/mistralai/__init__.py +545 -0
- netra/instrumentation/mistralai/config.py +5 -0
- netra/instrumentation/mistralai/utils.py +30 -0
- netra/instrumentation/mistralai/version.py +1 -0
- netra/instrumentation/weaviate/__init__.py +121 -0
- netra/instrumentation/weaviate/version.py +1 -0
- netra/pii.py +757 -0
- netra/processors/__init__.py +4 -0
- netra/processors/session_span_processor.py +55 -0
- netra/processors/span_aggregation_processor.py +365 -0
- netra/scanner.py +104 -0
- netra/session.py +185 -0
- netra/session_manager.py +96 -0
- netra/tracer.py +99 -0
- netra/version.py +1 -0
- netra_sdk-0.1.0.dist-info/LICENCE +201 -0
- netra_sdk-0.1.0.dist-info/METADATA +573 -0
- netra_sdk-0.1.0.dist-info/RECORD +42 -0
- netra_sdk-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from opentelemetry import baggage
|
|
5
|
+
from opentelemetry import context as otel_context
|
|
6
|
+
from opentelemetry import trace
|
|
7
|
+
from opentelemetry.sdk.trace import SpanProcessor
|
|
8
|
+
|
|
9
|
+
from netra.config import Config
|
|
10
|
+
from netra.session_manager import SessionManager
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SessionSpanProcessor(SpanProcessor):  # type: ignore[misc]
    """OpenTelemetry span processor that automatically adds session attributes to spans.

    On span start it registers the span with SessionManager and copies
    session/user/tenant identifiers (plus any custom keys) from OpenTelemetry
    baggage onto the span as attributes.
    """

    def on_start(self, span: trace.Span, parent_context: Optional[otel_context.Context] = None) -> None:
        """Add session attributes to span when it starts and store current span.

        Args:
            span: The span that has just started.
            parent_context: Optional parent context (unused; baggage is read
                from the current context instead).
        """
        try:
            # Store the current span in SessionManager so other SDK
            # components can attach data to it later.
            SessionManager.set_current_span(span)

            ctx = otel_context.get_current()
            session_id = baggage.get_baggage("session_id", ctx)
            user_id = baggage.get_baggage("user_id", ctx)
            tenant_id = baggage.get_baggage("tenant_id", ctx)
            custom_keys = baggage.get_baggage("custom_keys", ctx)

            # Identify the SDK/library that produced this span.
            span.set_attribute("library.name", Config.LIBRARY_NAME)
            span.set_attribute("library.version", Config.LIBRARY_VERSION)
            span.set_attribute("sdk.name", Config.SDK_NAME)

            if session_id:
                span.set_attribute(f"{Config.LIBRARY_NAME}.session_id", session_id)
            if user_id:
                span.set_attribute(f"{Config.LIBRARY_NAME}.user_id", user_id)
            if tenant_id:
                span.set_attribute(f"{Config.LIBRARY_NAME}.tenant_id", tenant_id)
            if custom_keys:
                # "custom_keys" holds a comma-separated list of baggage keys
                # whose values live under the "custom.<key>" baggage namespace.
                for key in custom_keys.split(","):
                    value = baggage.get_baggage(f"custom.{key}", ctx)
                    if value:
                        span.set_attribute(f"{Config.LIBRARY_NAME}.custom.{key}", value)
        except Exception:
            # Never let attribute enrichment break span creation.
            # logger.exception already records the traceback.
            logger.exception("Error setting span attributes")

    def on_end(self, span: trace.Span) -> None:
        """No-op: this processor only enriches spans at start time."""
        pass

    def force_flush(self, timeout_millis: int = 30000) -> bool:
        """Nothing is buffered; always report success.

        The OTel SpanProcessor API expects a bool here (the previous version
        returned None, inconsistent with SpanAggregationProcessor).
        """
        return True

    def shutdown(self) -> None:
        """No resources to release."""
        pass
|
|
@@ -0,0 +1,365 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Span aggregation utilities for Netra SDK.
|
|
3
|
+
Handles aggregation of child span data into parent spans.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
from collections import defaultdict
|
|
9
|
+
from typing import Any, Dict, Optional, Set
|
|
10
|
+
|
|
11
|
+
import httpx
|
|
12
|
+
from opentelemetry import trace
|
|
13
|
+
from opentelemetry.sdk.trace import SpanProcessor
|
|
14
|
+
from opentelemetry.trace import Context, Span
|
|
15
|
+
|
|
16
|
+
from netra import Netra
|
|
17
|
+
from netra.config import Config
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class SpanAggregationData:
    """Container for the metrics aggregated across a span and its children."""

    def __init__(self) -> None:
        # Per-model token usage, e.g. {"gpt-4": {"prompt_tokens": 10}}.
        self.tokens: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
        # Every model name observed in this span subtree.
        self.models: Set[str] = set()
        # PII findings and the action taken per entity type.
        self.has_pii: bool = False
        self.pii_entities: Set[str] = set()
        self.pii_actions: Dict[str, Set[str]] = defaultdict(set)
        # Policy violations and the action taken per violation type.
        self.has_violation: bool = False
        self.violations: Set[str] = set()
        self.violation_actions: Dict[str, Set[str]] = defaultdict(set)
        # HTTP error tracking.
        self.has_error: bool = False
        self.status_codes: Set[int] = set()

    def merge_from_other(self, other: "SpanAggregationData") -> None:
        """Fold another instance's aggregated data into this one."""
        # Errors: sticky flag plus the union of observed status codes.
        if other.has_error:
            self.has_error = True
            self.status_codes.update(other.status_codes)

        # Tokens: keep the larger count per (model, token_type) pair.
        for model_name, counts in other.tokens.items():
            if model_name not in self.tokens:
                self.tokens[model_name] = {}
            mine = self.tokens[model_name]
            for kind, amount in counts.items():
                mine[kind] = max(mine.get(kind, 0), amount)

        self.models.update(other.models)

        # PII: sticky flag, union of entities, union of per-action entities.
        if other.has_pii:
            self.has_pii = True
            self.pii_entities.update(other.pii_entities)
            for action_name, entity_set in other.pii_actions.items():
                self.pii_actions[action_name].update(entity_set)

        # Violations: same merge strategy as PII.
        if other.has_violation:
            self.has_violation = True
            self.violations.update(other.violations)
            for action_name, violation_set in other.violation_actions.items():
                self.violation_actions[action_name].update(violation_set)

    def to_attributes(self) -> Dict[str, str]:
        """Serialize the aggregated data into flat, JSON-encoded span attributes."""
        attributes: Dict[str, str] = {}

        # Error Data
        attributes["has_error"] = str(self.has_error).lower()
        if self.has_error:
            attributes["status_codes"] = json.dumps(list(self.status_codes))

        # Token usage by model
        if self.tokens:
            attributes["tokens"] = json.dumps({m: dict(u) for m, u in self.tokens.items()})

        # Models used
        if self.models:
            attributes["models"] = json.dumps(sorted(self.models))

        # PII information
        attributes["has_pii"] = str(self.has_pii).lower()
        if self.pii_entities:
            attributes["pii_entities"] = json.dumps(sorted(self.pii_entities))
        if self.pii_actions:
            attributes["pii_actions"] = json.dumps(
                {action: sorted(entities) for action, entities in self.pii_actions.items()}
            )

        # Violation information
        attributes["has_violation"] = str(self.has_violation).lower()
        if self.violations:
            attributes["violations"] = json.dumps(sorted(self.violations))
        if self.violation_actions:
            attributes["violation_actions"] = json.dumps(
                {action: sorted(v) for action, v in self.violation_actions.items()}
            )

        return attributes
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
class SpanAggregationProcessor(SpanProcessor):  # type: ignore[misc]
    """
    OpenTelemetry span processor that aggregates data from child spans into parent spans.

    Attribute and event data captured on a span (token usage, models, PII
    findings, policy violations, HTTP errors) is merged upward into every
    still-recording ancestor span.
    """

    def __init__(self) -> None:
        # span_id -> aggregated metrics for that span and its descendants.
        self._span_data: Dict[str, SpanAggregationData] = {}
        self._span_hierarchy: Dict[str, Optional[str]] = {}  # child_id -> parent_id
        self._root_spans: Set[str] = set()
        self._captured_data: Dict[str, Dict[str, Any]] = {}  # span_id -> {attributes, events}
        self._active_spans: Dict[str, Span] = {}  # span_id -> original span reference

    def on_start(self, span: Span, parent_context: Optional[Context] = None) -> None:
        """Register the span, record its parent linkage, and wrap its mutators."""
        span_id = self._get_span_id(span)
        if not span_id:
            return

        # Store the original span so ancestors can be updated while recording.
        self._active_spans[span_id] = span

        # Initialize aggregation state for this span.
        self._span_data[span_id] = SpanAggregationData()
        self._captured_data[span_id] = {"attributes": {}, "events": []}

        # Check if this is a root span (no parent)
        if span.parent is None:
            self._root_spans.add(span_id)
        else:
            # Track parent-child relationship - span.parent is a SpanContext, not a Span
            try:
                parent_span_context = span.parent
                if parent_span_context and parent_span_context.span_id:
                    parent_span_id = f"{parent_span_context.trace_id:032x}-{parent_span_context.span_id:016x}"
                    self._span_hierarchy[span_id] = parent_span_id
                else:
                    logger.warning(f"DEBUG: Parent span context is invalid for child {span_id}")
            except Exception as e:
                logger.warning(f"DEBUG: Could not get parent span ID for child {span_id}: {e}")

        # Wrap span methods so set_attribute/add_event calls are captured.
        self._wrap_span_methods(span, span_id)

    def on_end(self, span: Span) -> None:
        """Aggregate this span's captured data, push it to ancestors, and clean up."""
        span_id = self._get_span_id(span)
        if not span_id or span_id not in self._span_data:
            return

        try:
            # Process this span's captured data
            captured = self._captured_data.get(span_id, {})
            self._process_attributes(self._span_data[span_id], captured.get("attributes", {}))

            # Set aggregated attributes on this span
            original_span = self._active_spans.get(span_id)
            if original_span and original_span.is_recording():
                self._set_span_attributes(original_span, self._span_data[span_id])

            # Handle parent-child aggregation for any remaining data
            self._aggregate_to_all_parents(span_id)

        except Exception as e:
            logger.error(f"Error during span aggregation for span {span_id}: {e}")
            # Even if there's an error, try to do basic aggregation
            try:
                original_span = self._active_spans.get(span_id)
                if original_span and original_span.is_recording():
                    self._set_span_attributes(original_span, self._span_data[span_id])
            except Exception as inner_e:
                logger.error(f"Failed to set basic aggregation attributes: {inner_e}")

        # Clean up all per-span bookkeeping to avoid unbounded growth.
        self._span_data.pop(span_id, None)
        self._captured_data.pop(span_id, None)
        self._active_spans.pop(span_id, None)
        self._root_spans.discard(span_id)
        self._span_hierarchy.pop(span_id, None)

    def _wrap_span_methods(self, span: Span, span_id: str) -> Any:
        """Wrap span methods to capture attributes and events."""
        # Wrap set_attribute
        original_set_attribute = span.set_attribute

        def wrapped_set_attribute(key: str, value: Any) -> Any:
            # Status code processing
            if key == "http.status_code":
                self._status_code_processing(value)

            # Capture every attribute for aggregation at span end.
            self._captured_data[span_id]["attributes"][key] = value
            return original_set_attribute(key, value)

        span.set_attribute = wrapped_set_attribute

        # Wrap add_event
        original_add_event = span.add_event

        # NOTE: defaults are None (not `{}` / `0`). A mutable `{}` default is
        # shared across calls, and forwarding `timestamp=0` stamped events at
        # the epoch; `None` lets the SDK use the current time.
        def wrapped_add_event(
            name: str, attributes: Optional[Dict[str, Any]] = None, timestamp: Optional[int] = None
        ) -> Any:
            # Only process PII and violation events
            if name == "pii_detected" and attributes:
                self._process_pii_event(self._span_data[span_id], attributes)
                if span.is_recording():
                    self._set_span_attributes(span, self._span_data[span_id])
                # Immediately aggregate to parent spans
                self._aggregate_to_all_parents(span_id)
            elif name == "violation_detected" and attributes:
                self._process_violation_event(self._span_data[span_id], attributes)
                if span.is_recording():
                    self._set_span_attributes(span, self._span_data[span_id])
                # Immediately aggregate to parent spans
                self._aggregate_to_all_parents(span_id)

            # Check if span is still recording before adding event
            if not span.is_recording():
                logger.debug(f"Attempted to add event to ended span {span_id}")
                return None
            return original_add_event(name, attributes, timestamp)

        span.add_event = wrapped_add_event

    def _process_attributes(self, data: SpanAggregationData, attributes: Dict[str, Any]) -> None:
        """Process span attributes for aggregation."""
        # Extract status code for error identification (default 200 = no error).
        status_code = attributes.get("http.status_code", 200)
        if httpx.codes.is_error(status_code):
            data.has_error = True
            data.status_codes.add(status_code)

        # Extract model information. Token usage is only recorded when a model
        # is identified; previously a spurious `None` model entry (serialized
        # as "null") was created for every span without GenAI attributes.
        model = attributes.get("gen_ai.request.model") or attributes.get("gen_ai.response.model")
        if not model:
            return
        data.models.add(model)

        # Extract token usage
        token_fields = {
            "prompt_tokens": attributes.get("gen_ai.usage.prompt_tokens", 0),
            "completion_tokens": attributes.get("gen_ai.usage.completion_tokens", 0),
            "total_tokens": attributes.get("llm.usage.total_tokens", 0),
            "cache_read_input_tokens": attributes.get("gen_ai.usage.cache_read_input_tokens", 0),
        }

        # Initialize token fields if they don't exist
        if model not in data.tokens:
            data.tokens[model] = {}

        # Add token values
        for field, value in token_fields.items():
            if isinstance(value, (int, str)):
                current_value = data.tokens[model].get(field, 0)
                data.tokens[model][field] = current_value + int(value)

    def _process_pii_event(self, data: SpanAggregationData, attrs: Dict[str, Any]) -> None:
        """Process pii_detected event."""
        if attrs.get("has_pii"):
            data.has_pii = True

        # Extract entities from pii_entities field (JSON string or dict).
        entity_counts_str = attrs.get("pii_entities")
        if entity_counts_str:
            try:
                entity_counts = (
                    json.loads(entity_counts_str) if isinstance(entity_counts_str, str) else entity_counts_str
                )
                if isinstance(entity_counts, dict):
                    entities = set(entity_counts.keys())
                    data.pii_entities.update(entities)

                    # Determine the action taken on the detected entities.
                    if attrs.get("is_blocked"):
                        data.pii_actions["BLOCK"].update(entities)
                    elif attrs.get("is_masked"):
                        data.pii_actions["MASK"].update(entities)
                    else:
                        data.pii_actions["FLAG"].update(entities)
            except (json.JSONDecodeError, TypeError):
                logger.error(f"Failed to parse pii_entities: {entity_counts_str}")

    def _process_violation_event(self, data: SpanAggregationData, attrs: Dict[str, Any]) -> None:
        """Process violation_detected event."""
        if attrs.get("has_violation"):
            data.has_violation = True
            violations = attrs.get("violations", [])
            if violations:
                data.violations.update(violations)
                # Set action based on is_blocked flag
                action = "BLOCK" if attrs.get("is_blocked") else "FLAG"
                data.violation_actions[action].update(violations)

    def _aggregate_to_all_parents(self, child_span_id: str) -> None:
        """Aggregate data from child span to all its parent spans in the hierarchy."""
        if child_span_id not in self._span_data:
            return

        child_data = self._span_data[child_span_id]
        current_span_id = child_span_id

        # Traverse up the parent hierarchy until a span with no tracked parent.
        while True:
            parent_id = self._span_hierarchy.get(current_span_id)
            if not parent_id or parent_id not in self._span_data:
                break

            # Merge child data into parent
            self._span_data[parent_id].merge_from_other(child_data)

            # Update parent span attributes if it's still active and recording
            parent_span = self._active_spans.get(parent_id)
            if parent_span and parent_span.is_recording():
                self._set_span_attributes(parent_span, self._span_data[parent_id])

            # Move up to the next parent
            current_span_id = parent_id

    def _set_span_attributes(self, span: Span, data: SpanAggregationData) -> None:
        """Set aggregated attributes on the given span."""
        try:
            aggregated_attrs = data.to_attributes()
            # Set all aggregated attributes under a single key as a JSON object
            span.set_attribute(f"{Config.LIBRARY_NAME}.aggregated_attributes", json.dumps(aggregated_attrs))
        except Exception as e:
            logger.error(f"Failed to set aggregated attributes: {e}")

    def _get_span_id(self, span: Span) -> Optional[str]:
        """Get a unique identifier for the span (trace_id-span_id, zero-padded hex)."""
        try:
            span_context = span.get_span_context()
            return f"{span_context.trace_id:032x}-{span_context.span_id:016x}"
        except Exception:
            return None

    def _get_span_id_from_context(self, context: Context) -> Optional[str]:
        """Extract span ID from context."""
        if context:
            span_context = trace.get_current_span(context).get_span_context()
            if span_context and span_context.span_id:
                return f"{span_context.trace_id:032x}-{span_context.span_id:016x}"
        return None

    def _status_code_processing(self, status_code: int) -> None:
        """Emit an error_detected custom event when the status code is an HTTP error."""
        if httpx.codes.is_error(status_code):
            event_attributes = {"has_error": True, "status_code": status_code}
            Netra.set_custom_event(event_name="error_detected", attributes=event_attributes)

    def force_flush(self, timeout_millis: int = 30000) -> bool:
        """Force flush any pending data."""
        return True

    def shutdown(self) -> bool:
        """Shutdown the processor and drop all per-span state."""
        self._span_data.clear()
        self._span_hierarchy.clear()
        self._root_spans.clear()
        self._captured_data.clear()
        self._active_spans.clear()
        return True
|
netra/scanner.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Scanner module for Netra SDK to implement various scanning capabilities.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from typing import Optional, Tuple
|
|
8
|
+
|
|
9
|
+
from netra.exceptions import InjectionException
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Scanner(ABC):
    """
    Abstract base class for scanner implementations.

    Concrete scanners analyze and process input prompts for purposes such as
    security checks or content moderation.
    """

    @abstractmethod
    def scan(self, prompt: str) -> Tuple[str, bool, float]:
        """
        Scan the given prompt and report the outcome.

        Args:
            prompt: The input prompt to scan

        Returns:
            A 3-tuple of:
                - sanitized_prompt: the (possibly modified) prompt after scanning
                - is_valid: whether the prompt passed the scan
                - risk_score: risk level between 0.0 and 1.0
        """
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class PromptInjection(Scanner):
    """
    A scanner implementation that detects and handles prompt injection attempts.

    This scanner uses llm_guard's PromptInjection scanner under the hood.
    """

    def __init__(self, threshold: float = 0.5, match_type: Optional[str] = None):
        """
        Initialize the PromptInjection scanner.

        Args:
            threshold: The threshold value (between 0.0 and 1.0) above which a prompt is considered risky
            match_type: The type of matching to use
                (from llm_guard.input_scanners.prompt_injection.MatchType)
        """
        self.threshold = threshold
        self.scanner = None
        self.llm_guard_available = False

        # llm-guard is an optional dependency; degrade gracefully without it.
        try:
            from llm_guard.input_scanners import PromptInjection as LLMGuardPromptInjection
            from llm_guard.input_scanners.prompt_injection import MatchType

            effective_match_type = MatchType.FULL if match_type is None else match_type
            self.scanner = LLMGuardPromptInjection(threshold=threshold, match_type=effective_match_type)
            self.llm_guard_available = True
        except ImportError:
            logger.warning(
                "llm-guard package is not installed. Prompt injection scanning will be limited. "
                "To enable full functionality, install with: pip install 'netra-sdk[llm_guard]'"
            )

    def scan(self, prompt: str) -> Tuple[str, bool, float]:
        """
        Scan the input prompt for potential prompt injection attempts.

        Args:
            prompt: The input prompt to scan

        Returns:
            A 3-tuple of:
                - sanitized_prompt: the (possibly modified) prompt after scanning
                - is_valid: whether the prompt passed the scan
                - risk_score: risk level between 0.0 and 1.0

        Raises:
            InjectionException: when llm-guard flags the prompt as an injection.
        """
        # Fallback path: without llm-guard every prompt passes (risk 0.0).
        if self.scanner is None or not self.llm_guard_available:
            logger.warning(
                "Using fallback prompt injection detection (llm-guard not available). "
                "Install the llm_guard optional dependency for full protection."
            )
            return prompt, True, 0.0

        assert self.scanner is not None  # narrows the type for mypy
        sanitized_prompt, is_valid, risk_score = self.scanner.scan(prompt)
        if is_valid:
            return sanitized_prompt, is_valid, risk_score
        raise InjectionException(
            message="Input blocked: detected prompt injection",
            has_violation=True,
            violations=["prompt_injection"],
        )
|