aitracer-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aitracer-0.1.0.dist-info/METADATA +234 -0
- aitracer-0.1.0.dist-info/RECORD +15 -0
- aitracer-0.1.0.dist-info/WHEEL +4 -0
- python/__init__.py +30 -0
- python/client.py +437 -0
- python/integrations/__init__.py +5 -0
- python/integrations/langchain.py +452 -0
- python/pii.py +223 -0
- python/queue.py +219 -0
- python/session.py +144 -0
- python/trace.py +65 -0
- python/wrappers/__init__.py +7 -0
- python/wrappers/anthropic_wrapper.py +208 -0
- python/wrappers/gemini_wrapper.py +350 -0
- python/wrappers/openai_wrapper.py +193 -0
python/integrations/langchain.py
ADDED

@@ -0,0 +1,452 @@

```python
"""
LangChain integration for AITracer.

This module provides a callback handler that automatically logs
LangChain LLM calls to AITracer.

Usage:
    from aitracer.integrations.langchain import AITracerCallbackHandler
    from langchain_openai import ChatOpenAI

    handler = AITracerCallbackHandler(api_key="your-api-key", project="my-project")

    llm = ChatOpenAI(callbacks=[handler])
    response = llm.invoke("Hello!")
"""

from __future__ import annotations

import time
import uuid
from decimal import Decimal
from typing import Any, Dict, List, Optional, Union

try:
    from langchain_core.callbacks import BaseCallbackHandler
    from langchain_core.outputs import LLMResult
    from langchain_core.messages import BaseMessage
    LANGCHAIN_AVAILABLE = True
except ImportError:
    LANGCHAIN_AVAILABLE = False
    # Create a dummy base class for type hints
    class BaseCallbackHandler:
        pass


class AITracerCallbackHandler(BaseCallbackHandler):
    """
    LangChain callback handler that logs LLM calls to AITracer.

    This handler captures:
    - LLM calls (start, end, error)
    - Chain execution
    - Tool usage
    - Token usage and costs

    Example:
        >>> from aitracer.integrations.langchain import AITracerCallbackHandler
        >>> from langchain_openai import ChatOpenAI
        >>>
        >>> handler = AITracerCallbackHandler(
        ...     api_key="at-xxx",
        ...     project="my-langchain-app"
        ... )
        >>> llm = ChatOpenAI(callbacks=[handler])
        >>> response = llm.invoke("What is the capital of France?")
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        project: Optional[str] = None,
        base_url: str = "https://api.aitracer.co",
        session_id: Optional[str] = None,
        user_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        flush_on_chain_end: bool = True,
    ):
        """
        Initialize the AITracer callback handler.

        Args:
            api_key: AITracer API key. If not provided, uses AITRACER_API_KEY env var.
            project: Project name for logging.
            base_url: AITracer API base URL.
            session_id: Optional session ID for grouping related calls.
            user_id: Optional user ID for tracking user interactions.
            metadata: Additional metadata to include with all logs.
            flush_on_chain_end: Whether to flush logs when a chain completes.
        """
        if not LANGCHAIN_AVAILABLE:
            raise ImportError(
                "LangChain is not installed. Install it with: pip install langchain-core"
            )

        super().__init__()

        # Import AITracer client
        from aitracer import AITracer

        self.client = AITracer(api_key=api_key, project=project, base_url=base_url)
        self.project = project
        self.session_id = session_id or str(uuid.uuid4())
        self.user_id = user_id
        self.default_metadata = metadata or {}
        self.flush_on_chain_end = flush_on_chain_end

        # Track active runs
        self._runs: Dict[str, Dict[str, Any]] = {}
        self._chain_stack: List[str] = []

    def _get_trace_id(self) -> str:
        """Get the current trace ID (from chain or generate new)."""
        if self._chain_stack:
            return self._chain_stack[0]
        return str(uuid.uuid4())

    def _get_parent_run_id(self) -> Optional[str]:
        """Get the parent run ID if in a chain."""
        if len(self._chain_stack) > 1:
            return self._chain_stack[-1]
        return None

    def _extract_model_info(self, serialized: Dict[str, Any]) -> Dict[str, str]:
        """Extract provider and model from serialized LLM info."""
        # Default values
        provider = "unknown"
        model = "unknown"

        # Try to extract from kwargs
        kwargs = serialized.get("kwargs", {})

        # Check for model name in various places
        if "model" in kwargs:
            model = kwargs["model"]
        elif "model_name" in kwargs:
            model = kwargs["model_name"]

        # Determine provider from class name or model name
        class_name = serialized.get("name", "").lower()
        if "openai" in class_name:
            provider = "openai"
        elif "anthropic" in class_name or "claude" in class_name:
            provider = "anthropic"
        elif "google" in class_name or "gemini" in class_name:
            provider = "google"
        elif "mistral" in class_name:
            provider = "mistral"
        elif "cohere" in class_name:
            provider = "cohere"

        # Also check model name for provider hints
        model_lower = model.lower()
        if "gpt" in model_lower or model_lower.startswith("o1"):
            provider = "openai"
        elif "claude" in model_lower:
            provider = "anthropic"
        elif "gemini" in model_lower:
            provider = "google"

        return {"provider": provider, "model": model}

    def _format_messages(self, messages: List[Any]) -> List[Dict[str, str]]:
        """Format messages for logging."""
        formatted = []
        for msg in messages:
            if isinstance(msg, dict):
                formatted.append(msg)
            elif hasattr(msg, "type") and hasattr(msg, "content"):
                formatted.append({
                    "role": msg.type,
                    "content": msg.content,
                })
            elif hasattr(msg, "role") and hasattr(msg, "content"):
                formatted.append({
                    "role": msg.role,
                    "content": msg.content,
                })
            else:
                formatted.append({"role": "unknown", "content": str(msg)})
        return formatted

    # ========== LLM Callbacks ==========

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when an LLM starts processing."""
        run_id_str = str(run_id)
        model_info = self._extract_model_info(serialized)

        self._runs[run_id_str] = {
            "type": "llm",
            "start_time": time.time(),
            "provider": model_info["provider"],
            "model": model_info["model"],
            "input": {"prompts": prompts},
            "trace_id": self._get_trace_id(),
            "parent_run_id": str(parent_run_id) if parent_run_id else self._get_parent_run_id(),
            "tags": tags or [],
            "metadata": {**self.default_metadata, **(metadata or {})},
        }

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a chat model starts processing."""
        run_id_str = str(run_id)
        model_info = self._extract_model_info(serialized)

        # Format messages for logging
        formatted_messages = []
        for msg_list in messages:
            formatted_messages.extend(self._format_messages(msg_list))

        self._runs[run_id_str] = {
            "type": "chat",
            "start_time": time.time(),
            "provider": model_info["provider"],
            "model": model_info["model"],
            "input": {"messages": formatted_messages},
            "trace_id": self._get_trace_id(),
            "parent_run_id": str(parent_run_id) if parent_run_id else self._get_parent_run_id(),
            "tags": tags or [],
            "metadata": {**self.default_metadata, **(metadata or {})},
        }

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when an LLM finishes processing."""
        run_id_str = str(run_id)
        run_data = self._runs.pop(run_id_str, None)

        if not run_data:
            return

        latency_ms = int((time.time() - run_data["start_time"]) * 1000)

        # Extract token usage
        token_usage = {}
        if response.llm_output:
            token_usage = response.llm_output.get("token_usage", {})

        # Extract output
        output_text = ""
        if response.generations:
            for gen_list in response.generations:
                for gen in gen_list:
                    output_text += gen.text

        # Log to AITracer
        self.client.log(
            trace_id=run_data["trace_id"],
            span_id=run_id_str,
            parent_span_id=run_data.get("parent_run_id"),
            provider=run_data["provider"],
            model=run_data["model"],
            input=run_data["input"],
            output={"content": output_text},
            input_tokens=token_usage.get("prompt_tokens", 0),
            output_tokens=token_usage.get("completion_tokens", 0),
            latency_ms=latency_ms,
            status="success",
            metadata={
                **run_data["metadata"],
                "tags": run_data["tags"],
                "session_id": self.session_id,
                "user_id": self.user_id,
                "source": "langchain",
            },
        )

    def on_llm_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when an LLM errors."""
        run_id_str = str(run_id)
        run_data = self._runs.pop(run_id_str, None)

        if not run_data:
            return

        latency_ms = int((time.time() - run_data["start_time"]) * 1000)

        # Log error to AITracer
        self.client.log(
            trace_id=run_data["trace_id"],
            span_id=run_id_str,
            parent_span_id=run_data.get("parent_run_id"),
            provider=run_data["provider"],
            model=run_data["model"],
            input=run_data["input"],
            output=None,
            latency_ms=latency_ms,
            status="error",
            error_message=str(error),
            metadata={
                **run_data["metadata"],
                "tags": run_data["tags"],
                "session_id": self.session_id,
                "user_id": self.user_id,
                "source": "langchain",
            },
        )

    # ========== Chain Callbacks ==========

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a chain starts."""
        run_id_str = str(run_id)
        self._chain_stack.append(run_id_str)

        self._runs[run_id_str] = {
            "type": "chain",
            "start_time": time.time(),
            "name": serialized.get("name", "unknown"),
            "inputs": inputs,
            "tags": tags or [],
            "metadata": {**self.default_metadata, **(metadata or {})},
        }

    def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a chain ends."""
        run_id_str = str(run_id)

        if run_id_str in self._chain_stack:
            self._chain_stack.remove(run_id_str)

        self._runs.pop(run_id_str, None)

        # Flush logs if configured
        if self.flush_on_chain_end and not self._chain_stack:
            self.client.flush()

    def on_chain_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a chain errors."""
        run_id_str = str(run_id)

        if run_id_str in self._chain_stack:
            self._chain_stack.remove(run_id_str)

        self._runs.pop(run_id_str, None)

    # ========== Tool Callbacks ==========

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a tool starts."""
        run_id_str = str(run_id)

        self._runs[run_id_str] = {
            "type": "tool",
            "start_time": time.time(),
            "name": serialized.get("name", "unknown"),
            "input": input_str,
            "tags": tags or [],
            "metadata": {**self.default_metadata, **(metadata or {})},
        }

    def on_tool_end(
        self,
        output: str,
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a tool ends."""
        run_id_str = str(run_id)
        run_data = self._runs.pop(run_id_str, None)

        # Tools are tracked but not logged as separate entries
        # They appear in the chain context

    def on_tool_error(
        self,
        error: Union[Exception, KeyboardInterrupt],
        *,
        run_id: uuid.UUID,
        parent_run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Called when a tool errors."""
        run_id_str = str(run_id)
        self._runs.pop(run_id_str, None)

    # ========== Utility Methods ==========

    def flush(self) -> None:
        """Flush any pending logs to AITracer."""
        self.client.flush()

    def set_session_id(self, session_id: str) -> None:
        """Set the session ID for subsequent calls."""
        self.session_id = session_id

    def set_user_id(self, user_id: str) -> None:
        """Set the user ID for subsequent calls."""
        self.user_id = user_id

    def add_metadata(self, key: str, value: Any) -> None:
        """Add metadata that will be included with all subsequent logs."""
        self.default_metadata[key] = value
```
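Taken together, the handler acts as a lifecycle recorder: a chain start pushes its run id (which becomes the trace id), each LLM or chat callback records a span against it, and the outermost chain end flushes the queue. For reference, here is a minimal usage sketch assembled from the docstrings above; the API key, project, user id, and model name are placeholders, and `langchain-openai` is assumed to be installed.

```python
# Minimal sketch of wiring the handler into a LangChain app.
# Placeholder credentials; AITRACER_API_KEY is used if api_key is omitted.
from aitracer.integrations.langchain import AITracerCallbackHandler
from langchain_openai import ChatOpenAI

handler = AITracerCallbackHandler(
    api_key="at-xxx",
    project="my-langchain-app",
    user_id="user-123",              # attached to every span's metadata
    metadata={"env": "staging"},     # merged into each run's metadata
)

llm = ChatOpenAI(model="gpt-4o-mini", callbacks=[handler])
response = llm.invoke("What is the capital of France?")

# A bare invoke() on the model fires the LLM callbacks but no chain-end
# event, so flush_on_chain_end never triggers here; flush explicitly.
handler.flush()
```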
python/pii.py
ADDED

@@ -0,0 +1,223 @@

```python
"""PII Detection and Masking utilities for SDK."""

from __future__ import annotations

import hashlib
import re
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class PIIMatch:
    """Represents a detected PII match."""

    pii_type: str
    value: str
    start: int
    end: int
    confidence: float = 1.0


# PII Detection Patterns
PII_PATTERNS = {
    "email": {
        "pattern": r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b",
        "mask": "[EMAIL]",
        "confidence": 0.95,
    },
    "phone": {
        # International and Japanese phone formats
        "pattern": r"\b(?:\+?[0-9]{1,4}[-.\s]?)?(?:\(?[0-9]{2,4}\)?[-.\s]?)?[0-9]{2,4}[-.\s]?[0-9]{3,4}[-.\s]?[0-9]{3,4}\b",
        "mask": "[PHONE]",
        "confidence": 0.85,
    },
    "credit_card": {
        # Major credit card formats (Visa, Mastercard, Amex, etc.)
        "pattern": r"\b(?:4[0-9]{3}[-\s]?[0-9]{4}[-\s]?[0-9]{4}[-\s]?[0-9]{4}|5[1-5][0-9]{2}[-\s]?[0-9]{4}[-\s]?[0-9]{4}[-\s]?[0-9]{4}|3[47][0-9]{2}[-\s]?[0-9]{6}[-\s]?[0-9]{5}|6(?:011|5[0-9]{2})[-\s]?[0-9]{4}[-\s]?[0-9]{4}[-\s]?[0-9]{4})\b",
        "mask": "[CREDIT_CARD]",
        "confidence": 0.95,
    },
    "ssn": {
        # US Social Security Number and Japanese My Number
        "pattern": r"\b(?:[0-9]{3}[-\s]?[0-9]{2}[-\s]?[0-9]{4}|[0-9]{4}[-\s]?[0-9]{4}[-\s]?[0-9]{4})\b",
        "mask": "[SSN]",
        "confidence": 0.90,
    },
    "ip_address": {
        # IPv4 and IPv6
        "pattern": r"\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b|(?:[A-Fa-f0-9]{1,4}:){7}[A-Fa-f0-9]{1,4}\b",
        "mask": "[IP_ADDRESS]",
        "confidence": 0.95,
    },
    "name": {
        # Japanese names (simplified pattern)
        "pattern": r"[一-龯]{1,4}[\s ][一-龯]{1,4}",
        "mask": "[NAME]",
        "confidence": 0.70,
    },
}


class PIIDetector:
    """PII Detection and Masking class."""

    def __init__(
        self,
        enabled_types: Optional[list[str]] = None,
        action: str = "mask",
        custom_patterns: Optional[list[dict]] = None,
    ):
        """
        Initialize PII Detector.

        Args:
            enabled_types: List of PII types to detect. None means all.
            action: Action to take - "mask", "redact", "hash", "none"
            custom_patterns: List of custom pattern dicts with 'name', 'pattern', 'mask' keys
        """
        self.enabled_types = enabled_types or list(PII_PATTERNS.keys())
        self.action = action
        self.custom_patterns = custom_patterns or []

        # Compile patterns
        self.compiled_patterns: dict[str, dict] = {}
        for pii_type in self.enabled_types:
            if pii_type in PII_PATTERNS:
                pattern_info = PII_PATTERNS[pii_type]
                self.compiled_patterns[pii_type] = {
                    "regex": re.compile(pattern_info["pattern"], re.IGNORECASE),
                    "mask": pattern_info["mask"],
                    "confidence": pattern_info["confidence"],
                }

        # Add custom patterns
        for custom in self.custom_patterns:
            name = custom.get("name", f"custom_{len(self.compiled_patterns)}")
            pattern = custom.get("pattern")
            mask = custom.get("mask", f"[{name.upper()}]")
            if pattern:
                try:
                    self.compiled_patterns[name] = {
                        "regex": re.compile(pattern, re.IGNORECASE),
                        "mask": mask,
                        "confidence": custom.get("confidence", 0.8),
                    }
                except re.error:
                    pass  # Skip invalid patterns

    def detect(self, text: str) -> list[PIIMatch]:
        """
        Detect PII in text.

        Args:
            text: Text to scan for PII

        Returns:
            List of PIIMatch objects
        """
        if not text:
            return []

        matches: list[PIIMatch] = []
        for pii_type, pattern_info in self.compiled_patterns.items():
            for match in pattern_info["regex"].finditer(text):
                matches.append(
                    PIIMatch(
                        pii_type=pii_type,
                        value=match.group(),
                        start=match.start(),
                        end=match.end(),
                        confidence=pattern_info["confidence"],
                    )
                )

        # Sort by position and remove overlapping matches (keep higher confidence)
        matches.sort(key=lambda m: (m.start, -m.confidence))
        filtered_matches: list[PIIMatch] = []
        last_end = -1
        for match in matches:
            if match.start >= last_end:
                filtered_matches.append(match)
                last_end = match.end

        return filtered_matches

    def mask_text(self, text: str) -> tuple[str, list[PIIMatch]]:
        """
        Detect and mask PII in text.

        Args:
            text: Text to process

        Returns:
            Tuple of (masked_text, list of matches)
        """
        matches = self.detect(text)
        if not matches:
            return text, []

        # Apply masking based on action
        result = text
        offset = 0

        for match in sorted(matches, key=lambda m: m.start):
            start = match.start + offset
            end = match.end + offset

            if self.action == "mask":
                replacement = self.compiled_patterns[match.pii_type]["mask"]
            elif self.action == "redact":
                replacement = "***"
            elif self.action == "hash":
                hash_value = hashlib.sha256(match.value.encode()).hexdigest()[:12]
                replacement = f"[HASH:{hash_value}]"
            else:  # none
                continue

            result = result[:start] + replacement + result[end:]
            offset += len(replacement) - len(match.value)

        return result, matches

    def process_json(
        self, data: Any, path: str = ""
    ) -> tuple[Any, list[tuple[str, PIIMatch]]]:
        """
        Recursively process JSON data for PII.

        Args:
            data: JSON data (dict, list, or primitive)
            path: Current path in the JSON structure

        Returns:
            Tuple of (processed_data, list of (path, match) tuples)
        """
        all_matches: list[tuple[str, PIIMatch]] = []

        if isinstance(data, dict):
            result = {}
            for key, value in data.items():
                new_path = f"{path}.{key}" if path else key
                processed, matches = self.process_json(value, new_path)
                result[key] = processed
                all_matches.extend(matches)
            return result, all_matches

        elif isinstance(data, list):
            result = []
            for i, item in enumerate(data):
                new_path = f"{path}[{i}]"
                processed, matches = self.process_json(item, new_path)
                result.append(processed)
                all_matches.extend(matches)
            return result, all_matches

        elif isinstance(data, str):
            masked_text, matches = self.mask_text(data)
            for match in matches:
                all_matches.append((path, match))
            return masked_text, all_matches

        else:
            return data, []
```
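For illustration, here is a short sketch of how `PIIDetector` fits together: detection, overlap resolution, masking, and recursive JSON traversal. The import path is an assumption (the integration module's own `from aitracer import AITracer` suggests the `python/` directory installs as the `aitracer` package), and the employee-ID rule is a hypothetical custom pattern, not one shipped with the SDK.

```python
# Sketch: scrubbing text and nested payloads before they are logged.
# Import path assumed from the package's own absolute imports.
from aitracer.pii import PIIDetector

detector = PIIDetector(
    enabled_types=["email", "phone", "credit_card"],
    action="mask",  # alternatives: "redact", "hash", "none"
    custom_patterns=[
        # Hypothetical company-specific rule, for illustration only.
        {"name": "employee_id", "pattern": r"\bEMP-[0-9]{6}\b", "mask": "[EMPLOYEE_ID]"},
    ],
)

masked, matches = detector.mask_text("Reach jane@example.com or 555-123-4567")
print(masked)  # -> "Reach [EMAIL] or [PHONE]"
print([(m.pii_type, m.confidence) for m in matches])

# Dicts and lists are walked recursively; each hit carries its JSON path,
# and overlapping matches resolve to the higher-confidence type
# (e.g. credit_card beats phone on the same digits).
payload = {"user": {"email": "jane@example.com"},
           "notes": ["card 4111 1111 1111 1111"]}
clean, hits = detector.process_json(payload)
for json_path, match in hits:
    print(json_path, match.pii_type)
```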