langchain 1.0.0a9__py3-none-any.whl → 1.0.0a11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of langchain might be problematic. Click here for more details.
- langchain/__init__.py +1 -24
- langchain/_internal/_documents.py +1 -1
- langchain/_internal/_prompts.py +2 -2
- langchain/_internal/_typing.py +1 -1
- langchain/agents/__init__.py +2 -3
- langchain/agents/factory.py +1126 -0
- langchain/agents/middleware/__init__.py +38 -1
- langchain/agents/middleware/context_editing.py +245 -0
- langchain/agents/middleware/human_in_the_loop.py +67 -20
- langchain/agents/middleware/model_call_limit.py +177 -0
- langchain/agents/middleware/model_fallback.py +94 -0
- langchain/agents/middleware/pii.py +753 -0
- langchain/agents/middleware/planning.py +201 -0
- langchain/agents/middleware/prompt_caching.py +7 -4
- langchain/agents/middleware/summarization.py +2 -1
- langchain/agents/middleware/tool_call_limit.py +260 -0
- langchain/agents/middleware/tool_selection.py +306 -0
- langchain/agents/middleware/types.py +708 -127
- langchain/agents/structured_output.py +15 -1
- langchain/chat_models/base.py +22 -25
- langchain/embeddings/base.py +3 -4
- langchain/embeddings/cache.py +0 -1
- langchain/messages/__init__.py +29 -0
- langchain/rate_limiters/__init__.py +13 -0
- langchain/tools/__init__.py +9 -0
- langchain/{agents → tools}/tool_node.py +8 -10
- {langchain-1.0.0a9.dist-info → langchain-1.0.0a11.dist-info}/METADATA +29 -35
- langchain-1.0.0a11.dist-info/RECORD +43 -0
- {langchain-1.0.0a9.dist-info → langchain-1.0.0a11.dist-info}/WHEEL +1 -1
- langchain/agents/middleware_agent.py +0 -617
- langchain/agents/react_agent.py +0 -1228
- langchain/globals.py +0 -18
- langchain/text_splitter.py +0 -50
- langchain-1.0.0a9.dist-info/RECORD +0 -38
- langchain-1.0.0a9.dist-info/entry_points.txt +0 -4
- {langchain-1.0.0a9.dist-info → langchain-1.0.0a11.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,753 @@
|
|
|
1
|
+
"""PII detection and handling middleware for agents."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import hashlib
|
|
6
|
+
import ipaddress
|
|
7
|
+
import re
|
|
8
|
+
from typing import TYPE_CHECKING, Any, Literal
|
|
9
|
+
from urllib.parse import urlparse
|
|
10
|
+
|
|
11
|
+
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage
|
|
12
|
+
from typing_extensions import TypedDict
|
|
13
|
+
|
|
14
|
+
from langchain.agents.middleware.types import AgentMiddleware, AgentState, hook_config
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from collections.abc import Callable
|
|
18
|
+
|
|
19
|
+
from langgraph.runtime import Runtime
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class PIIMatch(TypedDict):
    """A single span of text that a detector flagged as PII.

    Attributes:
        type: Label of the PII category (e.g., 'email', 'ssn', 'credit_card').
        value: The exact text that was matched.
        start: Index in the scanned text where the match begins.
        end: Index in the scanned text where the match ends.
    """

    type: str
    value: str
    start: int
    end: int
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class PIIDetectionError(Exception):
    """Error signalling that PII was found while the 'block' strategy is in use.

    Attributes:
        pii_type: The PII category that triggered the error.
        matches: The matches that were found.
    """

    def __init__(self, pii_type: str, matches: list[PIIMatch]) -> None:
        """Store the detection details and build the exception message.

        Args:
            pii_type: The type of PII that was detected.
            matches: List of PII matches found.
        """
        self.pii_type = pii_type
        self.matches = matches
        # The message reports only the count and type, never the matched values.
        super().__init__(
            f"Detected {len(matches)} instance(s) of {pii_type} in message content"
        )
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# ============================================================================
|
|
53
|
+
# PII Detection Functions
|
|
54
|
+
# ============================================================================
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _luhn_checksum(card_number: str) -> bool:
|
|
58
|
+
"""Validate credit card number using Luhn algorithm.
|
|
59
|
+
|
|
60
|
+
Args:
|
|
61
|
+
card_number: Credit card number string (digits only).
|
|
62
|
+
|
|
63
|
+
Returns:
|
|
64
|
+
True if the number passes Luhn validation, False otherwise.
|
|
65
|
+
"""
|
|
66
|
+
digits = [int(d) for d in card_number if d.isdigit()]
|
|
67
|
+
|
|
68
|
+
if len(digits) < 13 or len(digits) > 19:
|
|
69
|
+
return False
|
|
70
|
+
|
|
71
|
+
checksum = 0
|
|
72
|
+
for i, digit in enumerate(reversed(digits)):
|
|
73
|
+
d = digit
|
|
74
|
+
if i % 2 == 1:
|
|
75
|
+
d *= 2
|
|
76
|
+
if d > 9:
|
|
77
|
+
d -= 9
|
|
78
|
+
checksum += d
|
|
79
|
+
|
|
80
|
+
return checksum % 10 == 0
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def detect_email(content: str) -> list[PIIMatch]:
    """Detect email addresses in content.

    Args:
        content: Text content to scan.

    Returns:
        List of detected email matches, in order of appearance.
    """
    # The top-level-domain class is ``[A-Za-z]``. Inside a character class a
    # ``|`` is a literal pipe, not alternation, so the previous ``[A-Z|a-z]``
    # wrongly allowed ``|`` in the TLD (e.g. "a@b.com|x" matched in full).
    pattern = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
    return [
        PIIMatch(
            type="email",
            value=match.group(),
            start=match.start(),
            end=match.end(),
        )
        for match in re.finditer(pattern, content)
    ]
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def detect_credit_card(content: str) -> list[PIIMatch]:
    """Find Luhn-valid 16-digit card numbers in content.

    Recognised layouts (four groups of four digits):
        - 1234567890123456
        - 1234 5678 9012 3456
        - 1234-5678-9012-3456

    Args:
        content: Text content to scan.

    Returns:
        List of detected credit card matches; candidates that fail the Luhn
        check are dropped.
    """
    four_by_four = r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b"
    return [
        PIIMatch(
            type="credit_card",
            value=candidate.group(),
            start=candidate.start(),
            end=candidate.end(),
        )
        for candidate in re.finditer(four_by_four, content)
        # Only keep candidates whose digits pass the Luhn checksum.
        if _luhn_checksum(candidate.group())
    ]
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def detect_ip(content: str) -> list[PIIMatch]:
    """Detect IPv4 addresses in content using stdlib validation.

    Candidates are located with a dotted-quad regex and then confirmed with
    ``ipaddress.ip_address``, so strings that merely look like an address
    (e.g. ``999.1.1.1``) are discarded.

    Note:
        Only dotted-quad (IPv4) candidates are scanned for. No pattern in this
        function produces IPv6 candidates, so IPv6 literals are not detected.

    Args:
        content: Text content to scan.

    Returns:
        List of detected IP address matches.
    """
    matches = []

    # IPv4 pattern: four groups of 1-3 digits separated by dots.
    ipv4_pattern = r"\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b"

    for match in re.finditer(ipv4_pattern, content):
        ip_str = match.group()
        try:
            # Validate with stdlib; raises ValueError for non-addresses.
            ipaddress.ip_address(ip_str)
        except ValueError:
            # Not a valid IP address, skip this candidate.
            continue
        else:
            matches.append(
                PIIMatch(
                    type="ip",
                    value=ip_str,
                    start=match.start(),
                    end=match.end(),
                )
            )

    return matches
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def detect_mac_address(content: str) -> list[PIIMatch]:
    """Find MAC addresses in content.

    Recognised layouts (six hexadecimal byte pairs):
        - 00:1A:2B:3C:4D:5E
        - 00-1A-2B-3C-4D-5E

    Args:
        content: Text content to scan.

    Returns:
        List of detected MAC address matches.
    """
    mac_regex = r"\b([0-9A-Fa-f]{2}[:-]){5}[0-9A-Fa-f]{2}\b"

    found: list[PIIMatch] = []
    for hit in re.finditer(mac_regex, content):
        # ``group()`` with no argument is the whole match, not the inner group.
        found.append(
            PIIMatch(
                type="mac_address",
                value=hit.group(),
                start=hit.start(),
                end=hit.end(),
            )
        )
    return found
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def detect_url(content: str) -> list[PIIMatch]:
    """Detect URLs in content using regex and stdlib validation.

    Detects:
        - http://example.com
        - https://example.com/path
        - www.example.com
        - example.com/path

    Scheme-qualified URLs are collected first. Bare URLs are then added only
    when they do not overlap a match that was already recorded, so the
    returned spans never overlap each other.

    Args:
        content: Text content to scan.

    Returns:
        List of detected URL matches (scheme URLs first, then bare URLs).
    """
    matches: list[PIIMatch] = []

    # Pattern 1: URLs with scheme (http:// or https://)
    scheme_pattern = r"https?://[^\s<>\"{}|\\^`\[\]]+"

    for match in re.finditer(scheme_pattern, content):
        url = match.group()
        try:
            result = urlparse(url)
        except ValueError:
            # urlparse rejects some malformed network locations; skip those.
            continue
        if result.scheme in ("http", "https") and result.netloc:
            matches.append(
                PIIMatch(
                    type="url",
                    value=url,
                    start=match.start(),
                    end=match.end(),
                )
            )

    # Pattern 2: URLs without scheme (www.example.com or example.com/path)
    # More conservative to avoid false positives
    bare_pattern = r"\b(?:www\.)?[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)+(?:/[^\s]*)?"  # noqa: E501

    for match in re.finditer(bare_pattern, content):
        start, end = match.span()

        # Skip anything that overlaps an already-recorded match. This is the
        # standard half-open interval test; it also catches a bare match that
        # strictly contains an earlier match, which an endpoint-only check
        # would let through and which would corrupt later replacements.
        if any(start < m["end"] and m["start"] < end for m in matches):
            continue

        url = match.group()
        # Only accept if it has a path or starts with www
        # This reduces false positives like "example.com" in prose
        if "/" not in url and not url.startswith("www."):
            continue

        try:
            # Add scheme for validation (required for urlparse to work correctly)
            result = urlparse(f"http://{url}")
        except ValueError:
            # Invalid URL, skip
            continue
        if result.netloc and "." in result.netloc:
            matches.append(
                PIIMatch(
                    type="url",
                    value=url,
                    start=start,
                    end=end,
                )
            )

    return matches
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
# Built-in detector registry
|
|
273
|
+
# Lookup table from built-in PII type name to its detector function.
# PIIMiddleware consults it when no custom detector is supplied, and lists
# its keys (in this order) in the "Unknown PII type" error message.
_BUILTIN_DETECTORS: dict[str, Callable[[str], list[PIIMatch]]] = {
    "email": detect_email,  # email addresses
    "credit_card": detect_credit_card,  # Luhn-validated card numbers
    "ip": detect_ip,  # stdlib-validated IP addresses
    "mac_address": detect_mac_address,  # colon- or dash-separated MACs
    "url": detect_url,  # scheme-qualified and bare URLs
}
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
# ============================================================================
|
|
283
|
+
# Strategy Implementations
|
|
284
|
+
# ============================================================================
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def _apply_redact_strategy(content: str, matches: list[PIIMatch]) -> str:
|
|
288
|
+
"""Replace PII with [REDACTED_TYPE] placeholders.
|
|
289
|
+
|
|
290
|
+
Args:
|
|
291
|
+
content: Original content.
|
|
292
|
+
matches: List of PII matches to redact.
|
|
293
|
+
|
|
294
|
+
Returns:
|
|
295
|
+
Content with PII redacted.
|
|
296
|
+
"""
|
|
297
|
+
if not matches:
|
|
298
|
+
return content
|
|
299
|
+
|
|
300
|
+
# Sort matches by start position in reverse to avoid offset issues
|
|
301
|
+
sorted_matches = sorted(matches, key=lambda m: m["start"], reverse=True)
|
|
302
|
+
|
|
303
|
+
result = content
|
|
304
|
+
for match in sorted_matches:
|
|
305
|
+
replacement = f"[REDACTED_{match['type'].upper()}]"
|
|
306
|
+
result = result[: match["start"]] + replacement + result[match["end"] :]
|
|
307
|
+
|
|
308
|
+
return result
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def _apply_mask_strategy(content: str, matches: list[PIIMatch]) -> str:
|
|
312
|
+
"""Partially mask PII, showing only last few characters.
|
|
313
|
+
|
|
314
|
+
Args:
|
|
315
|
+
content: Original content.
|
|
316
|
+
matches: List of PII matches to mask.
|
|
317
|
+
|
|
318
|
+
Returns:
|
|
319
|
+
Content with PII masked.
|
|
320
|
+
"""
|
|
321
|
+
if not matches:
|
|
322
|
+
return content
|
|
323
|
+
|
|
324
|
+
# Sort matches by start position in reverse
|
|
325
|
+
sorted_matches = sorted(matches, key=lambda m: m["start"], reverse=True)
|
|
326
|
+
|
|
327
|
+
result = content
|
|
328
|
+
for match in sorted_matches:
|
|
329
|
+
value = match["value"]
|
|
330
|
+
pii_type = match["type"]
|
|
331
|
+
|
|
332
|
+
# Different masking strategies by type
|
|
333
|
+
if pii_type == "email":
|
|
334
|
+
# Show only domain: user@****.com
|
|
335
|
+
parts = value.split("@")
|
|
336
|
+
if len(parts) == 2:
|
|
337
|
+
domain_parts = parts[1].split(".")
|
|
338
|
+
if len(domain_parts) >= 2:
|
|
339
|
+
masked = f"{parts[0]}@****.{domain_parts[-1]}"
|
|
340
|
+
else:
|
|
341
|
+
masked = f"{parts[0]}@****"
|
|
342
|
+
else:
|
|
343
|
+
masked = "****"
|
|
344
|
+
|
|
345
|
+
elif pii_type == "credit_card":
|
|
346
|
+
# Show last 4: ****-****-****-1234
|
|
347
|
+
digits_only = "".join(c for c in value if c.isdigit())
|
|
348
|
+
separator = "-" if "-" in value else " " if " " in value else ""
|
|
349
|
+
if separator:
|
|
350
|
+
masked = f"****{separator}****{separator}****{separator}{digits_only[-4:]}"
|
|
351
|
+
else:
|
|
352
|
+
masked = f"************{digits_only[-4:]}"
|
|
353
|
+
|
|
354
|
+
elif pii_type == "ip":
|
|
355
|
+
# Show last octet: *.*.*. 123
|
|
356
|
+
parts = value.split(".")
|
|
357
|
+
masked = f"*.*.*.{parts[-1]}" if len(parts) == 4 else "****"
|
|
358
|
+
|
|
359
|
+
elif pii_type == "mac_address":
|
|
360
|
+
# Show last byte: **:**:**:**:**:5E
|
|
361
|
+
separator = ":" if ":" in value else "-"
|
|
362
|
+
masked = (
|
|
363
|
+
f"**{separator}**{separator}**{separator}**{separator}**{separator}{value[-2:]}"
|
|
364
|
+
)
|
|
365
|
+
|
|
366
|
+
elif pii_type == "url":
|
|
367
|
+
# Mask everything: [MASKED_URL]
|
|
368
|
+
masked = "[MASKED_URL]"
|
|
369
|
+
|
|
370
|
+
else:
|
|
371
|
+
# Default: show last 4 chars
|
|
372
|
+
masked = f"****{value[-4:]}" if len(value) > 4 else "****"
|
|
373
|
+
|
|
374
|
+
result = result[: match["start"]] + masked + result[match["end"] :]
|
|
375
|
+
|
|
376
|
+
return result
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
def _apply_hash_strategy(content: str, matches: list[PIIMatch]) -> str:
|
|
380
|
+
"""Replace PII with deterministic hash including type information.
|
|
381
|
+
|
|
382
|
+
Args:
|
|
383
|
+
content: Original content.
|
|
384
|
+
matches: List of PII matches to hash.
|
|
385
|
+
|
|
386
|
+
Returns:
|
|
387
|
+
Content with PII replaced by hashes in format <type_hash:digest>.
|
|
388
|
+
"""
|
|
389
|
+
if not matches:
|
|
390
|
+
return content
|
|
391
|
+
|
|
392
|
+
# Sort matches by start position in reverse
|
|
393
|
+
sorted_matches = sorted(matches, key=lambda m: m["start"], reverse=True)
|
|
394
|
+
|
|
395
|
+
result = content
|
|
396
|
+
for match in sorted_matches:
|
|
397
|
+
value = match["value"]
|
|
398
|
+
pii_type = match["type"]
|
|
399
|
+
# Create deterministic hash
|
|
400
|
+
hash_digest = hashlib.sha256(value.encode()).hexdigest()[:8]
|
|
401
|
+
replacement = f"<{pii_type}_hash:{hash_digest}>"
|
|
402
|
+
result = result[: match["start"]] + replacement + result[match["end"] :]
|
|
403
|
+
|
|
404
|
+
return result
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
# ============================================================================
|
|
408
|
+
# PIIMiddleware
|
|
409
|
+
# ============================================================================
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
class PIIMiddleware(AgentMiddleware):
    """Detect and handle Personally Identifiable Information (PII) in agent conversations.

    This middleware detects common PII types and applies configurable strategies
    to handle them. It can detect emails, credit cards, IP addresses,
    MAC addresses, and URLs in both user input and agent output.

    Built-in PII types:
        - ``email``: Email addresses
        - ``credit_card``: Credit card numbers (validated with Luhn algorithm)
        - ``ip``: IP addresses (validated with stdlib)
        - ``mac_address``: MAC addresses
        - ``url``: URLs (both http/https and bare URLs)

    Strategies:
        - ``block``: Raise an exception when PII is detected
        - ``redact``: Replace PII with ``[REDACTED_TYPE]`` placeholders
        - ``mask``: Partially mask PII (e.g., ``****-****-****-1234`` for credit card)
        - ``hash``: Replace PII with deterministic hash (e.g., ``<email_hash:a1b2c3d4>``)

    Strategy Selection Guide:

    ======== =================== =======================================
    Strategy Preserves Identity? Best For
    ======== =================== =======================================
    `block`  N/A                 Avoid PII completely
    `redact` No                  General compliance, log sanitization
    `mask`   No                  Human readability, customer service UIs
    `hash`   Yes (pseudonymous)  Analytics, debugging
    ======== =================== =======================================

    Note:
        Message content is converted with ``str()`` before scanning, and a
        message in which PII was handled is rebuilt with the resulting string
        as its content.

    Example:
        ```python
        from langchain.agents.middleware import PIIMiddleware
        from langchain.agents import create_agent

        # Redact all emails in user input
        agent = create_agent(
            "openai:gpt-5",
            middleware=[
                PIIMiddleware("email", strategy="redact"),
            ],
        )

        # Use different strategies for different PII types
        agent = create_agent(
            "openai:gpt-4o",
            middleware=[
                PIIMiddleware("credit_card", strategy="mask"),
                PIIMiddleware("url", strategy="redact"),
                PIIMiddleware("ip", strategy="hash"),
            ],
        )

        # Custom PII type with regex
        agent = create_agent(
            "openai:gpt-5",
            middleware=[
                PIIMiddleware("api_key", detector=r"sk-[a-zA-Z0-9]{32}", strategy="block"),
            ],
        )
        ```
    """

    def __init__(
        self,
        pii_type: Literal["email", "credit_card", "ip", "mac_address", "url"] | str,  # noqa: PYI051
        *,
        strategy: Literal["block", "redact", "mask", "hash"] = "redact",
        detector: Callable[[str], list[PIIMatch]] | str | None = None,
        apply_to_input: bool = True,
        apply_to_output: bool = False,
        apply_to_tool_results: bool = False,
    ) -> None:
        """Initialize the PII detection middleware.

        Args:
            pii_type: Type of PII to detect. Can be a built-in type
                (``email``, ``credit_card``, ``ip``, ``mac_address``, ``url``)
                or a custom type name.
            strategy: How to handle detected PII:

                * ``block``: Raise PIIDetectionError when PII is detected
                * ``redact``: Replace with ``[REDACTED_TYPE]`` placeholders
                * ``mask``: Partially mask PII (show last few characters)
                * ``hash``: Replace with deterministic hash (format: ``<type_hash:digest>``)

            detector: Custom detector function or regex pattern.

                * If ``Callable``: Function that takes content string and returns
                  list of PIIMatch objects
                * If ``str``: Regex pattern to match PII
                * If ``None``: Uses built-in detector for the pii_type

            apply_to_input: Whether to check user messages before model call.
            apply_to_output: Whether to check AI messages after model call.
            apply_to_tool_results: Whether to check tool result messages after tool execution.

        Raises:
            ValueError: If pii_type is not built-in and no detector is provided.
            re.error: If ``detector`` is a string that is not a valid regex.
        """
        super().__init__()

        self.pii_type = pii_type
        self.strategy = strategy
        self.apply_to_input = apply_to_input
        self.apply_to_output = apply_to_output
        self.apply_to_tool_results = apply_to_tool_results

        # Resolve detector
        if detector is None:
            # Use built-in detector
            if pii_type not in _BUILTIN_DETECTORS:
                msg = (
                    f"Unknown PII type: {pii_type}. "
                    f"Must be one of {list(_BUILTIN_DETECTORS.keys())} "
                    "or provide a custom detector."
                )
                raise ValueError(msg)
            self.detector = _BUILTIN_DETECTORS[pii_type]
        elif isinstance(detector, str):
            # Custom regex pattern. Compile it once here so a malformed
            # pattern is reported at construction time rather than on the
            # first message that gets scanned.
            compiled = re.compile(detector)

            def regex_detector(content: str) -> list[PIIMatch]:
                return [
                    PIIMatch(
                        type=pii_type,
                        value=match.group(),
                        start=match.start(),
                        end=match.end(),
                    )
                    for match in compiled.finditer(content)
                ]

            self.detector = regex_detector
        else:
            # Custom callable detector
            self.detector = detector

    @property
    def name(self) -> str:
        """Name of the middleware."""
        return f"{self.__class__.__name__}[{self.pii_type}]"

    @staticmethod
    def _last_index_of(
        messages: list[AnyMessage],
        message_type: type[HumanMessage] | type[AIMessage],
    ) -> int | None:
        """Return the index of the last message of ``message_type``, or None."""
        for idx in reversed(range(len(messages))):
            if isinstance(messages[idx], message_type):
                return idx
        return None

    def _apply_strategy(self, content: str, matches: list[PIIMatch]) -> str:
        """Rewrite ``content`` using the configured non-blocking strategy.

        Raises:
            ValueError: If the configured strategy is not redact, mask or hash.
        """
        if self.strategy == "redact":
            return _apply_redact_strategy(content, matches)
        if self.strategy == "mask":
            return _apply_mask_strategy(content, matches)
        if self.strategy == "hash":
            return _apply_hash_strategy(content, matches)
        # Should not reach here due to type hints
        msg = f"Unknown strategy: {self.strategy}"
        raise ValueError(msg)

    def _handle_content(self, raw_content: Any) -> str | None:
        """Scan one message's content and apply the configured strategy.

        Args:
            raw_content: The ``content`` attribute of a message.

        Returns:
            The rewritten content, or None when the content is empty or the
            detector found nothing (i.e. the message should be left as is).

        Raises:
            PIIDetectionError: If PII is detected and strategy is "block".
        """
        if not raw_content:
            return None

        content = str(raw_content)
        matches = self.detector(content)
        if not matches:
            return None

        if self.strategy == "block":
            raise PIIDetectionError(self.pii_type, matches)
        return self._apply_strategy(content, matches)

    @hook_config(can_jump_to=["end"])
    def before_model(
        self,
        state: AgentState,
        runtime: Runtime,  # noqa: ARG002
    ) -> dict[str, Any] | None:
        """Check user messages and tool results for PII before model invocation.

        Only the most recent user message is checked, and only the tool
        messages that come after the most recent AI message.

        Args:
            state: The current agent state.
            runtime: The langgraph runtime.

        Returns:
            Updated state with PII handled according to strategy, or None if no PII detected.

        Raises:
            PIIDetectionError: If PII is detected and strategy is "block".
        """
        if not self.apply_to_input and not self.apply_to_tool_results:
            return None

        messages = state["messages"]
        if not messages:
            return None

        new_messages = list(messages)
        any_modified = False

        # Check user input if enabled
        if self.apply_to_input:
            last_user_idx = self._last_index_of(messages, HumanMessage)
            if last_user_idx is not None:
                last_user_msg = messages[last_user_idx]
                new_content = self._handle_content(last_user_msg.content)
                if new_content is not None:
                    # Rebuild the message with the same id and name.
                    new_messages[last_user_idx] = HumanMessage(
                        content=new_content,
                        id=last_user_msg.id,
                        name=last_user_msg.name,
                    )
                    any_modified = True

        # Check tool results if enabled
        if self.apply_to_tool_results:
            # Find the last AIMessage, then process all ToolMessages after it
            last_ai_idx = self._last_index_of(messages, AIMessage)
            if last_ai_idx is not None:
                for i in range(last_ai_idx + 1, len(messages)):
                    tool_msg = messages[i]
                    if not isinstance(tool_msg, ToolMessage):
                        continue

                    new_content = self._handle_content(tool_msg.content)
                    if new_content is None:
                        continue

                    # Rebuild the tool message, keeping its link to the tool call.
                    new_messages[i] = ToolMessage(
                        content=new_content,
                        id=tool_msg.id,
                        name=tool_msg.name,
                        tool_call_id=tool_msg.tool_call_id,
                    )
                    any_modified = True

        if any_modified:
            return {"messages": new_messages}

        return None

    def after_model(
        self,
        state: AgentState,
        runtime: Runtime,  # noqa: ARG002
    ) -> dict[str, Any] | None:
        """Check AI messages for PII after model invocation.

        Only the most recent AI message is checked.

        Args:
            state: The current agent state.
            runtime: The langgraph runtime.

        Returns:
            Updated state with PII handled according to strategy, or None if no PII detected.

        Raises:
            PIIDetectionError: If PII is detected and strategy is "block".
        """
        if not self.apply_to_output:
            return None

        messages = state["messages"]
        if not messages:
            return None

        # Get last AI message
        last_ai_idx = self._last_index_of(messages, AIMessage)
        if last_ai_idx is None:
            return None
        last_ai_msg = messages[last_ai_idx]

        new_content = self._handle_content(last_ai_msg.content)
        if new_content is None:
            return None

        # Return updated messages, with the AI message rebuilt around the new
        # content and the same id, name and tool calls.
        new_messages = list(messages)
        new_messages[last_ai_idx] = AIMessage(
            content=new_content,
            id=last_ai_msg.id,
            name=last_ai_msg.name,
            tool_calls=last_ai_msg.tool_calls,
        )

        return {"messages": new_messages}
|