langchain 1.0.0a12__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain/__init__.py +1 -1
- langchain/agents/__init__.py +7 -1
- langchain/agents/factory.py +722 -226
- langchain/agents/middleware/__init__.py +36 -9
- langchain/agents/middleware/_execution.py +388 -0
- langchain/agents/middleware/_redaction.py +350 -0
- langchain/agents/middleware/context_editing.py +46 -17
- langchain/agents/middleware/file_search.py +382 -0
- langchain/agents/middleware/human_in_the_loop.py +220 -173
- langchain/agents/middleware/model_call_limit.py +43 -10
- langchain/agents/middleware/model_fallback.py +79 -36
- langchain/agents/middleware/pii.py +68 -504
- langchain/agents/middleware/shell_tool.py +718 -0
- langchain/agents/middleware/summarization.py +2 -2
- langchain/agents/middleware/{planning.py → todo.py} +35 -16
- langchain/agents/middleware/tool_call_limit.py +308 -114
- langchain/agents/middleware/tool_emulator.py +200 -0
- langchain/agents/middleware/tool_retry.py +384 -0
- langchain/agents/middleware/tool_selection.py +25 -21
- langchain/agents/middleware/types.py +714 -257
- langchain/agents/structured_output.py +37 -27
- langchain/chat_models/__init__.py +7 -1
- langchain/chat_models/base.py +192 -190
- langchain/embeddings/__init__.py +13 -3
- langchain/embeddings/base.py +49 -29
- langchain/messages/__init__.py +50 -1
- langchain/tools/__init__.py +9 -7
- langchain/tools/tool_node.py +16 -1174
- langchain-1.0.4.dist-info/METADATA +92 -0
- langchain-1.0.4.dist-info/RECORD +34 -0
- langchain/_internal/__init__.py +0 -0
- langchain/_internal/_documents.py +0 -35
- langchain/_internal/_lazy_import.py +0 -35
- langchain/_internal/_prompts.py +0 -158
- langchain/_internal/_typing.py +0 -70
- langchain/_internal/_utils.py +0 -7
- langchain/agents/_internal/__init__.py +0 -1
- langchain/agents/_internal/_typing.py +0 -13
- langchain/agents/middleware/prompt_caching.py +0 -86
- langchain/documents/__init__.py +0 -7
- langchain/embeddings/cache.py +0 -361
- langchain/storage/__init__.py +0 -22
- langchain/storage/encoder_backed.py +0 -123
- langchain/storage/exceptions.py +0 -5
- langchain/storage/in_memory.py +0 -13
- langchain-1.0.0a12.dist-info/METADATA +0 -122
- langchain-1.0.0a12.dist-info/RECORD +0 -43
- {langchain-1.0.0a12.dist-info → langchain-1.0.4.dist-info}/WHEEL +0 -0
- {langchain-1.0.0a12.dist-info → langchain-1.0.4.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,350 @@
|
|
|
1
|
+
"""Shared redaction utilities for middleware components."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import hashlib
|
|
6
|
+
import ipaddress
|
|
7
|
+
import re
|
|
8
|
+
from collections.abc import Callable, Sequence
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from typing import Literal
|
|
11
|
+
from urllib.parse import urlparse
|
|
12
|
+
|
|
13
|
+
from typing_extensions import TypedDict
|
|
14
|
+
|
|
15
|
+
RedactionStrategy = Literal["block", "redact", "mask", "hash"]
|
|
16
|
+
"""Supported strategies for handling detected sensitive values."""
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class PIIMatch(TypedDict):
    """Represents an individual match of sensitive data."""

    # Detector category that produced the match (the built-in detectors use
    # "email", "credit_card", "ip", "mac_address" and "url").
    type: str
    # The exact substring that was matched.
    value: str
    # Offset of the first matched character in the scanned text.
    start: int
    # Offset one past the last matched character (slice-style end).
    end: int
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class PIIDetectionError(Exception):
    """Error raised when sensitive values are found and the strategy is to block."""

    def __init__(self, pii_type: str, matches: Sequence[PIIMatch]) -> None:
        """Build the error and keep the offending matches for inspection.

        Args:
            pii_type: Name of the sensitive type that was detected.
            matches: Every match that was found for that type.
        """
        super().__init__(
            f"Detected {len(matches)} instance(s) of {pii_type} in text content"
        )
        self.pii_type = pii_type
        # Stored as a new list, independent of the sequence that was passed in.
        self.matches = list(matches)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
Detector = Callable[[str], list[PIIMatch]]
|
|
46
|
+
"""Callable signature for detectors that locate sensitive values."""
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def detect_email(content: str) -> list[PIIMatch]:
    """Detect email addresses in content.

    Args:
        content: Text to scan.

    Returns:
        One match per email address found, in order of appearance, with
        `start`/`end` offsets into `content`.
    """
    # The top-level-domain class is `[A-Za-z]`. A `|` inside a character class
    # is a literal pipe rather than alternation, so it must not appear there;
    # otherwise text such as "a@b.co|uk" is swallowed as a single address.
    pattern = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
    return [
        PIIMatch(
            type="email",
            value=match.group(),
            start=match.start(),
            end=match.end(),
        )
        for match in re.finditer(pattern, content)
    ]
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def detect_credit_card(content: str) -> list[PIIMatch]:
    """Find credit card numbers in content, keeping only Luhn-valid candidates.

    A candidate is four groups of four digits, each group optionally followed
    by a single whitespace character or hyphen.
    """
    candidate_pattern = r"\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b"
    return [
        PIIMatch(
            type="credit_card",
            value=hit.group(),
            start=hit.start(),
            end=hit.end(),
        )
        for hit in re.finditer(candidate_pattern, content)
        # Discard digit runs that merely look like a card number.
        if _passes_luhn(hit.group())
    ]
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def detect_ip(content: str) -> list[PIIMatch]:
    """Detect IPv4 addresses in content.

    Dotted-quad candidates are located with a regex and then validated with
    `ipaddress.ip_address`; candidates that fail validation are dropped.
    Only the dotted-quad pattern is searched, so IPv6 literals are not found.
    """
    dotted_quad = r"\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b"
    found: list[PIIMatch] = []

    for hit in re.finditer(dotted_quad, content):
        candidate = hit.group()
        try:
            ipaddress.ip_address(candidate)
        except ValueError:
            # Shaped like an address but not a valid one (e.g. octet > 255).
            continue
        else:
            found.append(
                PIIMatch(
                    type="ip",
                    value=candidate,
                    start=hit.start(),
                    end=hit.end(),
                )
            )

    return found
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def detect_mac_address(content: str) -> list[PIIMatch]:
    """Find MAC addresses in content.

    A MAC address here is six two-character hexadecimal groups, each of the
    first five followed by `:` or `-`.
    """
    mac_pattern = r"\b([0-9A-Fa-f]{2}[:-]){5}[0-9A-Fa-f]{2}\b"
    found = []
    for hit in re.finditer(mac_pattern, content):
        found.append(
            PIIMatch(
                type="mac_address",
                value=hit.group(),
                start=hit.start(),
                end=hit.end(),
            )
        )
    return found
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def detect_url(content: str) -> list[PIIMatch]:
    """Detect URLs in content using regex and stdlib validation.

    Two passes are made over `content`:

    1. URLs with an explicit `http://` or `https://` scheme.
    2. Scheme-less URLs, accepted only when they start with `www.` or contain
       a path, and only when they do not overlap a URL found in pass 1.

    Args:
        content: Text to scan.

    Returns:
        All accepted matches: scheme URLs first, then bare URLs, each group in
        order of appearance. Returned spans never overlap one another.
    """
    matches: list[PIIMatch] = []

    # Pattern 1: URLs with scheme (http:// or https://)
    scheme_pattern = r"https?://[^\s<>\"{}|\\^`\[\]]+"

    for match in re.finditer(scheme_pattern, content):
        url = match.group()
        result = urlparse(url)
        if result.scheme in ("http", "https") and result.netloc:
            matches.append(
                PIIMatch(
                    type="url",
                    value=url,
                    start=match.start(),
                    end=match.end(),
                )
            )

    # Pattern 2: URLs without scheme (www.example.com or example.com/path)
    # More conservative to avoid false positives
    bare_pattern = (
        r"\b(?:www\.)?[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?"
        r"(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)+(?:/[^\s]*)?"
    )

    for match in re.finditer(bare_pattern, content):
        start, end = match.start(), match.end()
        # Skip anything that overlaps an already accepted match. This is the
        # standard half-open interval test, so it also catches a bare match
        # that strictly contains a scheme match (e.g. "a.b/https://x.com|y").
        # Overlapping spans would corrupt the offset-based replacement that
        # the redaction strategies perform.
        if any(start < m["end"] and m["start"] < end for m in matches):
            continue

        url = match.group()
        # Only accept if it has a path or starts with www
        # This reduces false positives like "example.com" in prose
        if "/" in url or url.startswith("www."):
            # Add scheme for validation (required for urlparse to work correctly)
            test_url = f"http://{url}"
            result = urlparse(test_url)
            if result.netloc and "." in result.netloc:
                matches.append(
                    PIIMatch(
                        type="url",
                        value=url,
                        start=start,
                        end=end,
                    )
                )

    return matches
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
BUILTIN_DETECTORS: dict[str, Detector] = {
|
|
174
|
+
"email": detect_email,
|
|
175
|
+
"credit_card": detect_credit_card,
|
|
176
|
+
"ip": detect_ip,
|
|
177
|
+
"mac_address": detect_mac_address,
|
|
178
|
+
"url": detect_url,
|
|
179
|
+
}
|
|
180
|
+
"""Registry of built-in detectors keyed by type name."""
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def _passes_luhn(card_number: str) -> bool:
|
|
184
|
+
"""Validate credit card number using the Luhn checksum."""
|
|
185
|
+
digits = [int(d) for d in card_number if d.isdigit()]
|
|
186
|
+
if not 13 <= len(digits) <= 19:
|
|
187
|
+
return False
|
|
188
|
+
|
|
189
|
+
checksum = 0
|
|
190
|
+
for index, digit in enumerate(reversed(digits)):
|
|
191
|
+
value = digit
|
|
192
|
+
if index % 2 == 1:
|
|
193
|
+
value *= 2
|
|
194
|
+
if value > 9:
|
|
195
|
+
value -= 9
|
|
196
|
+
checksum += value
|
|
197
|
+
return checksum % 10 == 0
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def _apply_redact_strategy(content: str, matches: list[PIIMatch]) -> str:
|
|
201
|
+
result = content
|
|
202
|
+
for match in sorted(matches, key=lambda item: item["start"], reverse=True):
|
|
203
|
+
replacement = f"[REDACTED_{match['type'].upper()}]"
|
|
204
|
+
result = result[: match["start"]] + replacement + result[match["end"] :]
|
|
205
|
+
return result
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _apply_mask_strategy(content: str, matches: list[PIIMatch]) -> str:
|
|
209
|
+
result = content
|
|
210
|
+
for match in sorted(matches, key=lambda item: item["start"], reverse=True):
|
|
211
|
+
value = match["value"]
|
|
212
|
+
pii_type = match["type"]
|
|
213
|
+
if pii_type == "email":
|
|
214
|
+
parts = value.split("@")
|
|
215
|
+
if len(parts) == 2:
|
|
216
|
+
domain_parts = parts[1].split(".")
|
|
217
|
+
masked = (
|
|
218
|
+
f"{parts[0]}@****.{domain_parts[-1]}"
|
|
219
|
+
if len(domain_parts) >= 2
|
|
220
|
+
else f"{parts[0]}@****"
|
|
221
|
+
)
|
|
222
|
+
else:
|
|
223
|
+
masked = "****"
|
|
224
|
+
elif pii_type == "credit_card":
|
|
225
|
+
digits_only = "".join(c for c in value if c.isdigit())
|
|
226
|
+
separator = "-" if "-" in value else " " if " " in value else ""
|
|
227
|
+
if separator:
|
|
228
|
+
masked = f"****{separator}****{separator}****{separator}{digits_only[-4:]}"
|
|
229
|
+
else:
|
|
230
|
+
masked = f"************{digits_only[-4:]}"
|
|
231
|
+
elif pii_type == "ip":
|
|
232
|
+
octets = value.split(".")
|
|
233
|
+
masked = f"*.*.*.{octets[-1]}" if len(octets) == 4 else "****"
|
|
234
|
+
elif pii_type == "mac_address":
|
|
235
|
+
separator = ":" if ":" in value else "-"
|
|
236
|
+
masked = (
|
|
237
|
+
f"**{separator}**{separator}**{separator}**{separator}**{separator}{value[-2:]}"
|
|
238
|
+
)
|
|
239
|
+
elif pii_type == "url":
|
|
240
|
+
masked = "[MASKED_URL]"
|
|
241
|
+
else:
|
|
242
|
+
masked = f"****{value[-4:]}" if len(value) > 4 else "****"
|
|
243
|
+
result = result[: match["start"]] + masked + result[match["end"] :]
|
|
244
|
+
return result
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def _apply_hash_strategy(content: str, matches: list[PIIMatch]) -> str:
|
|
248
|
+
result = content
|
|
249
|
+
for match in sorted(matches, key=lambda item: item["start"], reverse=True):
|
|
250
|
+
digest = hashlib.sha256(match["value"].encode()).hexdigest()[:8]
|
|
251
|
+
replacement = f"<{match['type']}_hash:{digest}>"
|
|
252
|
+
result = result[: match["start"]] + replacement + result[match["end"] :]
|
|
253
|
+
return result
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def apply_strategy(
    content: str,
    matches: list[PIIMatch],
    strategy: RedactionStrategy,
) -> str:
    """Transform content according to the chosen strategy.

    Args:
        content: Text in which the matches were found.
        matches: Matches to act on; when empty, `content` is returned as is.
        strategy: One of `"redact"`, `"mask"`, `"hash"` or `"block"`.

    Returns:
        The rewritten text.

    Raises:
        PIIDetectionError: If `strategy` is `"block"` and there are matches.
        ValueError: If `strategy` is not a known strategy and there are matches.
    """
    if not matches:
        return content

    if strategy == "block":
        raise PIIDetectionError(matches[0]["type"], matches)

    if strategy not in ("redact", "mask", "hash"):
        msg = f"Unknown redaction strategy: {strategy}"
        raise ValueError(msg)

    rewriters = {
        "redact": _apply_redact_strategy,
        "mask": _apply_mask_strategy,
        "hash": _apply_hash_strategy,
    }
    return rewriters[strategy](content, matches)
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def resolve_detector(pii_type: str, detector: Detector | str | None) -> Detector:
    """Turn a detector configuration into a callable detector.

    Args:
        pii_type: Type name; selects the built-in detector when `detector` is
            `None`, and labels the matches of a regex detector.
        detector: `None` to use a built-in detector, a regex pattern string,
            or a ready-made callable.

    Returns:
        The built-in detector, a regex-backed detector, or `detector` itself.

    Raises:
        ValueError: If `detector` is `None` and `pii_type` is not built in.
    """
    if detector is None:
        builtin = BUILTIN_DETECTORS.get(pii_type)
        if builtin is None:
            msg = (
                f"Unknown PII type: {pii_type}. "
                f"Must be one of {list(BUILTIN_DETECTORS.keys())} or provide a custom detector."
            )
            raise ValueError(msg)
        return builtin

    if not isinstance(detector, str):
        # Already a callable; hand it back untouched.
        return detector

    # Compile once here so the returned detector reuses the compiled pattern.
    compiled = re.compile(detector)

    def regex_detector(content: str) -> list[PIIMatch]:
        found = []
        for hit in compiled.finditer(content):
            found.append(
                PIIMatch(
                    type=pii_type,
                    value=hit.group(),
                    start=hit.start(),
                    end=hit.end(),
                )
            )
        return found

    return regex_detector
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
@dataclass(frozen=True)
class RedactionRule:
    """Declarative settings describing how one PII type should be handled."""

    # Name of the sensitive-data type this rule covers.
    pii_type: str
    # What to do with detected values; placeholder redaction by default.
    strategy: RedactionStrategy = "redact"
    # Custom callable, regex pattern string, or None for the built-in detector.
    detector: Detector | str | None = None

    def resolve(self) -> ResolvedRedactionRule:
        """Produce an executable rule with the detector resolved to a callable."""
        return ResolvedRedactionRule(
            pii_type=self.pii_type,
            strategy=self.strategy,
            detector=resolve_detector(self.pii_type, self.detector),
        )
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
@dataclass(frozen=True)
class ResolvedRedactionRule:
    """Redaction rule whose detector is already a callable."""

    # Name of the sensitive-data type this rule covers.
    pii_type: str
    # Strategy applied to whatever the detector finds.
    strategy: RedactionStrategy
    # Callable that locates sensitive values in a string.
    detector: Detector

    def apply(self, content: str) -> tuple[str, list[PIIMatch]]:
        """Run the detector on content and apply the strategy to its matches.

        Returns:
            The (possibly rewritten) content together with the detected
            matches. When nothing is detected, the original content and a new
            empty list are returned.
        """
        found = self.detector(content)
        if found:
            return apply_strategy(content, found, self.strategy), found
        return content, []
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
__all__ = [
|
|
340
|
+
"PIIDetectionError",
|
|
341
|
+
"PIIMatch",
|
|
342
|
+
"RedactionRule",
|
|
343
|
+
"ResolvedRedactionRule",
|
|
344
|
+
"apply_strategy",
|
|
345
|
+
"detect_credit_card",
|
|
346
|
+
"detect_email",
|
|
347
|
+
"detect_ip",
|
|
348
|
+
"detect_mac_address",
|
|
349
|
+
"detect_url",
|
|
350
|
+
]
|
|
@@ -8,9 +8,9 @@ with any LangChain chat model.
|
|
|
8
8
|
|
|
9
9
|
from __future__ import annotations
|
|
10
10
|
|
|
11
|
-
from collections.abc import Callable, Iterable, Sequence
|
|
11
|
+
from collections.abc import Awaitable, Callable, Iterable, Sequence
|
|
12
12
|
from dataclasses import dataclass
|
|
13
|
-
from typing import
|
|
13
|
+
from typing import Literal
|
|
14
14
|
|
|
15
15
|
from langchain_core.messages import (
|
|
16
16
|
AIMessage,
|
|
@@ -22,10 +22,12 @@ from langchain_core.messages import (
|
|
|
22
22
|
from langchain_core.messages.utils import count_tokens_approximately
|
|
23
23
|
from typing_extensions import Protocol
|
|
24
24
|
|
|
25
|
-
from langchain.agents.middleware.types import
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
25
|
+
from langchain.agents.middleware.types import (
|
|
26
|
+
AgentMiddleware,
|
|
27
|
+
ModelCallResult,
|
|
28
|
+
ModelRequest,
|
|
29
|
+
ModelResponse,
|
|
30
|
+
)
|
|
29
31
|
|
|
30
32
|
DEFAULT_TOOL_PLACEHOLDER = "[cleared]"
|
|
31
33
|
|
|
@@ -180,11 +182,11 @@ class ClearToolUsesEdit(ContextEdit):
|
|
|
180
182
|
|
|
181
183
|
|
|
182
184
|
class ContextEditingMiddleware(AgentMiddleware):
|
|
183
|
-
"""
|
|
185
|
+
"""Automatically prunes tool results to manage context size.
|
|
184
186
|
|
|
185
187
|
The middleware applies a sequence of edits when the total input token count
|
|
186
|
-
exceeds configured thresholds. Currently the
|
|
187
|
-
supported, aligning with Anthropic's
|
|
188
|
+
exceeds configured thresholds. Currently the `ClearToolUsesEdit` strategy is
|
|
189
|
+
supported, aligning with Anthropic's `clear_tool_uses_20250919` behaviour.
|
|
188
190
|
"""
|
|
189
191
|
|
|
190
192
|
edits: list[ContextEdit]
|
|
@@ -196,7 +198,7 @@ class ContextEditingMiddleware(AgentMiddleware):
|
|
|
196
198
|
edits: Iterable[ContextEdit] | None = None,
|
|
197
199
|
token_count_method: Literal["approximate", "model"] = "approximate", # noqa: S107
|
|
198
200
|
) -> None:
|
|
199
|
-
"""
|
|
201
|
+
"""Initializes a context editing middleware instance.
|
|
200
202
|
|
|
201
203
|
Args:
|
|
202
204
|
edits: Sequence of edit strategies to apply. Defaults to a single
|
|
@@ -209,15 +211,42 @@ class ContextEditingMiddleware(AgentMiddleware):
|
|
|
209
211
|
self.edits = list(edits or (ClearToolUsesEdit(),))
|
|
210
212
|
self.token_count_method = token_count_method
|
|
211
213
|
|
|
212
|
-
def
|
|
214
|
+
def wrap_model_call(
    self,
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
    """Apply context edits before invoking the model via handler.

    Args:
        request: Model request; its `messages` are handed to every configured
            edit before the model is called.
        handler: Callable that performs the model call for a request.

    Returns:
        Whatever `handler(request)` returns.
    """
    # No messages means nothing to edit; call straight through.
    if not request.messages:
        return handler(request)

    if self.token_count_method == "approximate":  # noqa: S105

        def count_tokens(messages: Sequence[BaseMessage]) -> int:
            return count_tokens_approximately(messages)
    else:
        # Counting with the model includes the system prompt (when set) and
        # the request's tools alongside the messages.
        system_msg = (
            [SystemMessage(content=request.system_prompt)] if request.system_prompt else []
        )

        def count_tokens(messages: Sequence[BaseMessage]) -> int:
            return request.model.get_num_tokens_from_messages(
                system_msg + list(messages), request.tools
            )

    # Edits run in their configured order against the same message list.
    # NOTE(review): edits appear to modify `request.messages` in place (their
    # return value is ignored) - confirm against `ContextEdit.apply`.
    for edit in self.edits:
        edit.apply(request.messages, count_tokens=count_tokens)

    return handler(request)
|
|
241
|
+
|
|
242
|
+
async def awrap_model_call(
|
|
213
243
|
self,
|
|
214
244
|
request: ModelRequest,
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
"""Modify the model request by applying context edits before invocation."""
|
|
245
|
+
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
|
246
|
+
) -> ModelCallResult:
|
|
247
|
+
"""Apply context edits before invoking the model via handler (async version)."""
|
|
219
248
|
if not request.messages:
|
|
220
|
-
return request
|
|
249
|
+
return await handler(request)
|
|
221
250
|
|
|
222
251
|
if self.token_count_method == "approximate": # noqa: S105
|
|
223
252
|
|
|
@@ -236,7 +265,7 @@ class ContextEditingMiddleware(AgentMiddleware):
|
|
|
236
265
|
for edit in self.edits:
|
|
237
266
|
edit.apply(request.messages, count_tokens=count_tokens)
|
|
238
267
|
|
|
239
|
-
return request
|
|
268
|
+
return await handler(request)
|
|
240
269
|
|
|
241
270
|
|
|
242
271
|
__all__ = [
|