langchain 1.0.0a13__py3-none-any.whl → 1.0.0a15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- langchain/__init__.py +1 -1
- langchain/agents/factory.py +115 -29
- langchain/agents/middleware/__init__.py +6 -5
- langchain/agents/middleware/context_editing.py +29 -1
- langchain/agents/middleware/human_in_the_loop.py +13 -13
- langchain/agents/middleware/model_call_limit.py +38 -4
- langchain/agents/middleware/model_fallback.py +36 -1
- langchain/agents/middleware/pii.py +6 -8
- langchain/agents/middleware/{planning.py → todo.py} +18 -5
- langchain/agents/middleware/tool_call_limit.py +88 -15
- langchain/agents/middleware/types.py +196 -18
- langchain/embeddings/__init__.py +0 -2
- langchain/messages/__init__.py +32 -0
- langchain/tools/__init__.py +1 -6
- langchain/tools/tool_node.py +62 -11
- langchain-1.0.0a15.dist-info/METADATA +85 -0
- langchain-1.0.0a15.dist-info/RECORD +29 -0
- langchain/agents/middleware/prompt_caching.py +0 -89
- langchain/documents/__init__.py +0 -7
- langchain/embeddings/cache.py +0 -361
- langchain/storage/__init__.py +0 -22
- langchain/storage/encoder_backed.py +0 -122
- langchain/storage/exceptions.py +0 -5
- langchain/storage/in_memory.py +0 -13
- langchain-1.0.0a13.dist-info/METADATA +0 -125
- langchain-1.0.0a13.dist-info/RECORD +0 -36
- {langchain-1.0.0a13.dist-info → langchain-1.0.0a15.dist-info}/WHEEL +0 -0
- {langchain-1.0.0a13.dist-info → langchain-1.0.0a15.dist-info}/licenses/LICENSE +0 -0
langchain/__init__.py
CHANGED
langchain/agents/factory.py
CHANGED
```diff
@@ -13,9 +13,6 @@ from typing import (
     get_type_hints,
 )
 
-if TYPE_CHECKING:
-    from collections.abc import Awaitable
-
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
 from langchain_core.tools import BaseTool
```
```diff
@@ -47,11 +44,10 @@ from langchain.agents.structured_output import (
     ToolStrategy,
 )
 from langchain.chat_models import init_chat_model
-from langchain.tools import ToolNode
-from langchain.tools.tool_node import ToolCallWithContext
+from langchain.tools.tool_node import ToolCallWithContext, _ToolNode
 
 if TYPE_CHECKING:
-    from collections.abc import Callable, Sequence
+    from collections.abc import Awaitable, Callable, Sequence
 
     from langchain_core.runnables import Runnable
     from langgraph.cache.base import BaseCache
```
```diff
@@ -449,6 +445,70 @@ def _chain_tool_call_wrappers(
     return result
 
 
+def _chain_async_tool_call_wrappers(
+    wrappers: Sequence[
+        Callable[
+            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
+            Awaitable[ToolMessage | Command],
+        ]
+    ],
+) -> (
+    Callable[
+        [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
+        Awaitable[ToolMessage | Command],
+    ]
+    | None
+):
+    """Compose async wrappers into middleware stack (first = outermost).
+
+    Args:
+        wrappers: Async wrappers in middleware order.
+
+    Returns:
+        Composed async wrapper, or None if empty.
+    """
+    if not wrappers:
+        return None
+
+    if len(wrappers) == 1:
+        return wrappers[0]
+
+    def compose_two(
+        outer: Callable[
+            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
+            Awaitable[ToolMessage | Command],
+        ],
+        inner: Callable[
+            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
+            Awaitable[ToolMessage | Command],
+        ],
+    ) -> Callable[
+        [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
+        Awaitable[ToolMessage | Command],
+    ]:
+        """Compose two async wrappers where outer wraps inner."""
+
+        async def composed(
+            request: ToolCallRequest,
+            execute: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
+        ) -> ToolMessage | Command:
+            # Create an async callable that invokes inner with the original execute
+            async def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
+                return await inner(req, execute)
+
+            # Outer can call call_inner multiple times
+            return await outer(request, call_inner)
+
+        return composed
+
+    # Chain all wrappers: first -> second -> ... -> last
+    result = wrappers[-1]
+    for wrapper in reversed(wrappers[:-1]):
+        result = compose_two(wrapper, result)
+
+    return result
+
+
 def create_agent(  # noqa: PLR0915
     model: str | BaseChatModel,
     tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
```
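The composition here is the usual onion pattern: the first wrapper in the middleware list becomes the outermost layer, and each layer receives a callable that runs the remainder of the stack (ultimately the real tool execution). A self-contained sketch of the same chaining logic, with plain strings standing in for `ToolCallRequest` and `ToolMessage` (the helper names here are hypothetical stand-ins, not langchain types):

```python
import asyncio
from collections.abc import Awaitable, Callable

Handler = Callable[[str], Awaitable[str]]
Wrapper = Callable[[str, Handler], Awaitable[str]]


def chain(wrappers: list[Wrapper]) -> Wrapper:
    """Mirror of _chain_async_tool_call_wrappers: first wrapper ends up outermost."""

    def compose_two(outer: Wrapper, inner: Wrapper) -> Wrapper:
        async def composed(request: str, execute: Handler) -> str:
            # Expose the inner layer to the outer one as a plain handler.
            async def call_inner(req: str) -> str:
                return await inner(req, execute)

            return await outer(request, call_inner)

        return composed

    result = wrappers[-1]
    for wrapper in reversed(wrappers[:-1]):
        result = compose_two(wrapper, result)
    return result


async def outer_wrapper(req: str, handler: Handler) -> str:
    return "outer(" + await handler(req) + ")"


async def inner_wrapper(req: str, handler: Handler) -> str:
    return "inner(" + await handler(req) + ")"


async def main() -> None:
    composed = chain([outer_wrapper, inner_wrapper])

    async def execute(req: str) -> str:
        # Stands in for the real tool invocation at the bottom of the stack.
        return req

    print(await composed("call", execute))  # -> outer(inner(call))


asyncio.run(main())
```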
```diff
@@ -477,13 +537,12 @@ def create_agent(  # noqa: PLR0915
             (e.g., `"openai:gpt-4"`), a chat model instance (e.g., `ChatOpenAI()`).
         tools: A list of tools, dicts, or callables. If `None` or an empty list,
             the agent will consist of a model node without a tool calling loop.
-        system_prompt: An optional system prompt for the LLM.
-            Prompts are converted to a SystemMessage and added to the beginning
-            of the message list.
+        system_prompt: An optional system prompt for the LLM. Prompts are converted to a
+            `SystemMessage` and added to the beginning of the message list.
         middleware: A sequence of middleware instances to apply to the agent.
             Middleware can intercept and modify agent behavior at various stages.
         response_format: An optional configuration for structured responses.
-            Can be a ToolStrategy, ProviderStrategy, or a Pydantic model class.
+            Can be a `ToolStrategy`, `ProviderStrategy`, or a Pydantic model class.
             If provided, the agent will handle structured output during the
             conversation flow. Raw schemas will be wrapped in an appropriate strategy
             based on model capabilities.
```
```diff
@@ -500,14 +559,14 @@ def create_agent(  # noqa: PLR0915
             This is useful if you want to return directly or run additional processing
             on an output.
         debug: A flag indicating whether to enable debug mode.
-        name: An optional name for the CompiledStateGraph
+        name: An optional name for the `CompiledStateGraph`.
             This name will be automatically used when adding the agent graph to
             another graph as a subgraph node - particularly useful for building
             multi-agent systems.
-        cache: An optional BaseCache instance to enable caching of graph execution.
+        cache: An optional `BaseCache` instance to enable caching of graph execution.
 
     Returns:
-        A compiled StateGraph that can be used for chat interactions.
+        A compiled `StateGraph` that can be used for chat interactions.
 
     The agent node calls the language model with the messages list (after applying
     the system prompt). If the resulting AIMessage contains `tool_calls`, the graph will
```
```diff
@@ -576,9 +635,14 @@ def create_agent(  # noqa: PLR0915
         structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
     middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
 
-    # Collect middleware with wrap_tool_call hooks
+    # Collect middleware with wrap_tool_call or awrap_tool_call hooks
+    # Include middleware with either implementation to ensure NotImplementedError is raised
+    # when middleware doesn't support the execution path
     middleware_w_wrap_tool_call = [
-        m for m in middleware if m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
+        m
+        for m in middleware
+        if m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
+        or m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
     ]
 
     # Chain all wrap_tool_call handlers into a single composed handler
```
```diff
@@ -587,8 +651,24 @@ def create_agent(  # noqa: PLR0915
         wrappers = [m.wrap_tool_call for m in middleware_w_wrap_tool_call]
         wrap_tool_call_wrapper = _chain_tool_call_wrappers(wrappers)
 
+    # Collect middleware with awrap_tool_call or wrap_tool_call hooks
+    # Include middleware with either implementation to ensure NotImplementedError is raised
+    # when middleware doesn't support the execution path
+    middleware_w_awrap_tool_call = [
+        m
+        for m in middleware
+        if m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
+        or m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
+    ]
+
+    # Chain all awrap_tool_call handlers into a single composed async handler
+    awrap_tool_call_wrapper = None
+    if middleware_w_awrap_tool_call:
+        async_wrappers = [m.awrap_tool_call for m in middleware_w_awrap_tool_call]
+        awrap_tool_call_wrapper = _chain_async_tool_call_wrappers(async_wrappers)
+
     # Setup tools
-    tool_node: ToolNode | None = None
+    tool_node: _ToolNode | None = None
     # Extract built-in provider tools (dict format) and regular tools (BaseTool/callables)
     built_in_tools = [t for t in tools if isinstance(t, dict)]
     regular_tools = [t for t in tools if not isinstance(t, dict)]
```
```diff
@@ -598,7 +678,11 @@ def create_agent(  # noqa: PLR0915
 
     # Only create ToolNode if we have client-side tools
     tool_node = (
-        ToolNode(tools=available_tools, wrap_tool_call=wrap_tool_call_wrapper)
+        _ToolNode(
+            tools=available_tools,
+            wrap_tool_call=wrap_tool_call_wrapper,
+            awrap_tool_call=awrap_tool_call_wrapper,
+        )
         if available_tools
         else None
     )
```
```diff
@@ -640,13 +724,23 @@ def create_agent(  # noqa: PLR0915
         if m.__class__.after_agent is not AgentMiddleware.after_agent
         or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
     ]
+    # Collect middleware with wrap_model_call or awrap_model_call hooks
+    # Include middleware with either implementation to ensure NotImplementedError is raised
+    # when middleware doesn't support the execution path
     middleware_w_wrap_model_call = [
-        m for m in middleware if m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
+        m
+        for m in middleware
+        if m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
+        or m.__class__.awrap_model_call is not AgentMiddleware.awrap_model_call
     ]
+    # Collect middleware with awrap_model_call or wrap_model_call hooks
+    # Include middleware with either implementation to ensure NotImplementedError is raised
+    # when middleware doesn't support the execution path
     middleware_w_awrap_model_call = [
         m
         for m in middleware
         if m.__class__.awrap_model_call is not AgentMiddleware.awrap_model_call
+        or m.__class__.wrap_model_call is not AgentMiddleware.wrap_model_call
     ]
 
     # Compose wrap_model_call handlers into a single middleware stack (sync)
```
```diff
@@ -937,11 +1031,7 @@ def create_agent(  # noqa: PLR0915
         if response.structured_response is not None:
             state_updates["structured_response"] = response.structured_response
 
-        return {
-            "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
-            "run_model_call_count": state.get("run_model_call_count", 0) + 1,
-            **state_updates,
-        }
+        return state_updates
 
     async def _execute_model_async(request: ModelRequest) -> ModelResponse:
         """Execute model asynchronously and return response.
```
```diff
@@ -992,11 +1082,7 @@ def create_agent(  # noqa: PLR0915
         if response.structured_response is not None:
             state_updates["structured_response"] = response.structured_response
 
-        return {
-            "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
-            "run_model_call_count": state.get("run_model_call_count", 0) + 1,
-            **state_updates,
-        }
+        return state_updates
 
     # Use sync or async based on model capabilities
     graph.add_node("model", RunnableCallable(model_node, amodel_node, trace=False))
```
```diff
@@ -1378,7 +1464,7 @@ def _make_model_to_model_edge(
 
 def _make_tools_to_model_edge(
     *,
-    tool_node: ToolNode,
+    tool_node: _ToolNode,
     model_destination: str,
     structured_output_tools: dict[str, OutputToolBinding],
     end_destination: str,
```
langchain/agents/middleware/__init__.py
CHANGED
```diff
@@ -11,9 +11,8 @@ from .human_in_the_loop import (
 from .model_call_limit import ModelCallLimitMiddleware
 from .model_fallback import ModelFallbackMiddleware
 from .pii import PIIDetectionError, PIIMiddleware
-from .planning import PlanningMiddleware
-from .prompt_caching import AnthropicPromptCachingMiddleware
 from .summarization import SummarizationMiddleware
+from .todo import TodoListMiddleware
 from .tool_call_limit import ToolCallLimitMiddleware
 from .tool_emulator import LLMToolEmulator
 from .tool_selection import LLMToolSelectorMiddleware
@@ -21,6 +20,7 @@ from .types import (
     AgentMiddleware,
     AgentState,
     ModelRequest,
+    ModelResponse,
     after_agent,
     after_model,
     before_agent,
@@ -28,13 +28,12 @@ from .types import (
     dynamic_prompt,
     hook_config,
     wrap_model_call,
+    wrap_tool_call,
 )
 
 __all__ = [
     "AgentMiddleware",
     "AgentState",
-    # should move to langchain-anthropic if we decide to keep it
-    "AnthropicPromptCachingMiddleware",
     "ClearToolUsesEdit",
     "ContextEditingMiddleware",
     "HumanInTheLoopMiddleware",
@@ -44,10 +43,11 @@ __all__ = [
     "ModelCallLimitMiddleware",
     "ModelFallbackMiddleware",
     "ModelRequest",
+    "ModelResponse",
     "PIIDetectionError",
     "PIIMiddleware",
-    "PlanningMiddleware",
     "SummarizationMiddleware",
+    "TodoListMiddleware",
     "ToolCallLimitMiddleware",
     "after_agent",
     "after_model",
@@ -56,4 +56,5 @@ __all__ = [
     "dynamic_prompt",
     "hook_config",
     "wrap_model_call",
+    "wrap_tool_call",
 ]
```
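For downstream code, the practical impact of these hunks is the rename `PlanningMiddleware` → `TodoListMiddleware` and the removal of `AnthropicPromptCachingMiddleware` from this package. A minimal usage sketch against the a15 exports (it mirrors the example in the `todo.py` docstring below; running it requires a configured OpenAI key):

```python
from langchain.agents import create_agent
from langchain.agents.middleware import TodoListMiddleware  # formerly PlanningMiddleware
from langchain_core.messages import HumanMessage

# The todo middleware is constructed the same way the old planning middleware was.
agent = create_agent("openai:gpt-4o", middleware=[TodoListMiddleware()])
result = agent.invoke({"messages": [HumanMessage("Help me refactor my codebase")]})
```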
langchain/agents/middleware/context_editing.py
CHANGED
```diff
@@ -8,7 +8,7 @@ with any LangChain chat model.
 
 from __future__ import annotations
 
-from collections.abc import Callable, Iterable, Sequence
+from collections.abc import Awaitable, Callable, Iterable, Sequence
 from dataclasses import dataclass
 from typing import Literal
 
@@ -239,6 +239,34 @@ class ContextEditingMiddleware(AgentMiddleware):
 
         return handler(request)
 
+    async def awrap_model_call(
+        self,
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+    ) -> ModelCallResult:
+        """Apply context edits before invoking the model via handler (async version)."""
+        if not request.messages:
+            return await handler(request)
+
+        if self.token_count_method == "approximate":  # noqa: S105
+
+            def count_tokens(messages: Sequence[BaseMessage]) -> int:
+                return count_tokens_approximately(messages)
+        else:
+            system_msg = (
+                [SystemMessage(content=request.system_prompt)] if request.system_prompt else []
+            )
+
+            def count_tokens(messages: Sequence[BaseMessage]) -> int:
+                return request.model.get_num_tokens_from_messages(
+                    system_msg + list(messages), request.tools
+                )
+
+        for edit in self.edits:
+            edit.apply(request.messages, count_tokens=count_tokens)
+
+        return await handler(request)
+
 
 __all__ = [
     "ClearToolUsesEdit",
```
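With `awrap_model_call` in place, the same context edits now run on the async execution path (`ainvoke`/`astream`), not just the sync one. A usage sketch; the `edits` and `trigger` parameter names are assumptions about the constructor, which this diff does not show:

```python
import asyncio

from langchain.agents import create_agent
from langchain.agents.middleware import ClearToolUsesEdit, ContextEditingMiddleware
from langchain_core.messages import HumanMessage

# Assumed constructor shape: a sequence of edits, each with a token threshold.
middleware = ContextEditingMiddleware(edits=[ClearToolUsesEdit(trigger=100_000)])
agent = create_agent("openai:gpt-4o", middleware=[middleware])


async def main() -> None:
    # Before a15, only the sync path applied the edits; now ainvoke does too.
    await agent.ainvoke({"messages": [HumanMessage("Summarize our findings")]})


asyncio.run(main())
```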
langchain/agents/middleware/human_in_the_loop.py
CHANGED
```diff
@@ -11,23 +11,23 @@ from langchain.agents.middleware.types import AgentMiddleware, AgentState
 
 
 class Action(TypedDict):
-    """Represents an action with a name and arguments."""
+    """Represents an action with a name and args."""
 
     name: str
     """The type or name of action being requested (e.g., "add_numbers")."""
 
-    arguments: dict[str, Any]
-    """Key-value pairs of arguments needed for the action (e.g., {"a": 1, "b": 2})."""
+    args: dict[str, Any]
+    """Key-value pairs of args needed for the action (e.g., {"a": 1, "b": 2})."""
 
 
 class ActionRequest(TypedDict):
-    """Represents an action request with a name, arguments, and description."""
+    """Represents an action request with a name, args, and description."""
 
     name: str
     """The name of the action being requested."""
 
-    arguments: dict[str, Any]
-    """Key-value pairs of arguments needed for the action (e.g., {"a": 1, "b": 2})."""
+    args: dict[str, Any]
+    """Key-value pairs of args needed for the action (e.g., {"a": 1, "b": 2})."""
 
     description: NotRequired[str]
     """The description of the action to be reviewed."""
@@ -45,8 +45,8 @@ class ReviewConfig(TypedDict):
     allowed_decisions: list[DecisionType]
     """The decisions that are allowed for this request."""
 
-    arguments_schema: NotRequired[dict[str, Any]]
-    """JSON schema for the arguments associated with the action, if edits are allowed."""
+    args_schema: NotRequired[dict[str, Any]]
+    """JSON schema for the args associated with the action, if edits are allowed."""
 
 
 class HITLRequest(TypedDict):
@@ -150,8 +150,8 @@ class InterruptOnConfig(TypedDict):
     )
     ```
     """
-    arguments_schema: NotRequired[dict[str, Any]]
-    """JSON schema for the arguments associated with the action, if edits are allowed."""
+    args_schema: NotRequired[dict[str, Any]]
+    """JSON schema for the args associated with the action, if edits are allowed."""
 
 
 class HumanInTheLoopMiddleware(AgentMiddleware):
@@ -214,12 +214,12 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
             # Create ActionRequest with description
             action_request = ActionRequest(
                 name=tool_name,
-                arguments=tool_args,
+                args=tool_args,
                 description=description,
             )
 
             # Create ReviewConfig
-            # eventually can get tool information and populate arguments_schema from there
+            # eventually can get tool information and populate args_schema from there
             review_config = ReviewConfig(
                 action_name=tool_name,
                 allowed_decisions=config["allowed_decisions"],
@@ -244,7 +244,7 @@ class HumanInTheLoopMiddleware(AgentMiddleware):
                     ToolCall(
                         type="tool_call",
                         name=edited_action["name"],
-                        args=edited_action["arguments"],
+                        args=edited_action["args"],
                         id=tool_call["id"],
                     ),
                     None,
```
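The rename is a breaking change for anyone constructing or consuming these TypedDicts: `arguments` and `arguments_schema` become `args` and `args_schema`. A small sketch of the a15 shapes (the dict literals follow the TypedDict definitions above; the decision strings are inferred from the `allowed_decisions` usage and are otherwise an assumption):

```python
from langchain.agents.middleware.human_in_the_loop import ActionRequest, ReviewConfig

# a15: the key is `args` (was `arguments` in a13).
request = ActionRequest(
    name="send_email",
    args={"to": "ops@example.com", "body": "Deploy finished."},
    description="Send a status email",
)

config = ReviewConfig(
    action_name="send_email",
    allowed_decisions=["approve", "edit", "reject"],  # assumed decision values
    args_schema={"type": "object"},  # was `arguments_schema` in a13
)
```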
langchain/agents/middleware/model_call_limit.py
CHANGED
```diff
@@ -2,16 +2,33 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Annotated, Any, Literal
 
 from langchain_core.messages import AIMessage
+from langgraph.channels.untracked_value import UntrackedValue
+from typing_extensions import NotRequired
 
-from langchain.agents.middleware.types import AgentMiddleware, AgentState, hook_config
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    AgentState,
+    PrivateStateAttr,
+    hook_config,
+)
 
 if TYPE_CHECKING:
     from langgraph.runtime import Runtime
 
 
+class ModelCallLimitState(AgentState):
+    """State schema for ModelCallLimitMiddleware.
+
+    Extends AgentState with model call tracking fields.
+    """
+
+    thread_model_call_count: NotRequired[Annotated[int, PrivateStateAttr]]
+    run_model_call_count: NotRequired[Annotated[int, UntrackedValue, PrivateStateAttr]]
+
+
 def _build_limit_exceeded_message(
     thread_count: int,
     run_count: int,
@@ -69,7 +86,7 @@ class ModelCallLimitExceededError(Exception):
         super().__init__(msg)
 
 
-class ModelCallLimitMiddleware(AgentMiddleware):
+class ModelCallLimitMiddleware(AgentMiddleware[ModelCallLimitState, Any]):
     """Middleware that tracks model call counts and enforces limits.
 
     This middleware monitors the number of model calls made during agent execution
@@ -97,6 +114,8 @@ class ModelCallLimitMiddleware(AgentMiddleware):
     ```
     """
 
+    state_schema = ModelCallLimitState
+
     def __init__(
         self,
         *,
@@ -135,7 +154,7 @@ class ModelCallLimitMiddleware(AgentMiddleware):
         self.exit_behavior = exit_behavior
 
     @hook_config(can_jump_to=["end"])
-    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
+    def before_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
         """Check model call limits before making a model call.
 
         Args:
@@ -175,3 +194,18 @@ class ModelCallLimitMiddleware(AgentMiddleware):
             return {"jump_to": "end", "messages": [limit_ai_message]}
 
         return None
+
+    def after_model(self, state: ModelCallLimitState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
+        """Increment model call counts after a model call.
+
+        Args:
+            state: The current agent state.
+            runtime: The langgraph runtime.
+
+        Returns:
+            State updates with incremented call counts.
+        """
+        return {
+            "thread_model_call_count": state.get("thread_model_call_count", 0) + 1,
+            "run_model_call_count": state.get("run_model_call_count", 0) + 1,
+        }
```
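Two things changed here: the call counters moved into a dedicated `ModelCallLimitState` (private, and untracked for the per-run counter), and the increment now happens in the middleware's own `after_model` hook instead of the factory's model node (see the factory hunks above). A typical construction, with keyword names inferred from `thread_count`/`run_count`/`exit_behavior` in this file rather than shown verbatim in the diff:

```python
from langchain.agents import create_agent
from langchain.agents.middleware import ModelCallLimitMiddleware

# Assumed keyword names; "end" matches the can_jump_to=["end"] hook config above.
limiter = ModelCallLimitMiddleware(thread_limit=10, run_limit=5, exit_behavior="end")
agent = create_agent("openai:gpt-4o", middleware=[limiter])
```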
langchain/agents/middleware/model_fallback.py
CHANGED
```diff
@@ -13,7 +13,7 @@ from langchain.agents.middleware.types import (
 from langchain.chat_models import init_chat_model
 
 if TYPE_CHECKING:
-    from collections.abc import Callable
+    from collections.abc import Awaitable, Callable
 
     from langchain_core.language_models.chat_models import BaseChatModel
 
@@ -102,3 +102,38 @@ class ModelFallbackMiddleware(AgentMiddleware):
                 continue
 
         raise last_exception
+
+    async def awrap_model_call(
+        self,
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+    ) -> ModelCallResult:
+        """Try fallback models in sequence on errors (async version).
+
+        Args:
+            request: Initial model request.
+            handler: Async callback to execute the model.
+
+        Returns:
+            AIMessage from successful model call.
+
+        Raises:
+            Exception: If all models fail, re-raises last exception.
+        """
+        # Try primary model first
+        last_exception: Exception
+        try:
+            return await handler(request)
+        except Exception as e:  # noqa: BLE001
+            last_exception = e
+
+        # Try fallback models
+        for fallback_model in self.models:
+            request.model = fallback_model
+            try:
+                return await handler(request)
+            except Exception as e:  # noqa: BLE001
+                last_exception = e
+                continue
+
+        raise last_exception
```
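The async variant mirrors the sync `wrap_model_call` exactly: try the primary request, then mutate `request.model` through `self.models` until one handler call succeeds. A construction sketch; the positional fallback-model arguments are an assumption (this diff only shows `self.models` being iterated), and the model identifiers are illustrative:

```python
from langchain.agents import create_agent
from langchain.agents.middleware import ModelFallbackMiddleware

# Fallbacks are tried in order after the primary model raises.
fallbacks = ModelFallbackMiddleware("openai:gpt-4o-mini", "anthropic:claude-3-5-sonnet-latest")
agent = create_agent("openai:gpt-4o", middleware=[fallbacks])

# agent.invoke(...) exercises wrap_model_call; await agent.ainvoke(...) now
# exercises awrap_model_call with the same fallback order.
```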
langchain/agents/middleware/pii.py
CHANGED
```diff
@@ -431,14 +431,12 @@ class PIIMiddleware(AgentMiddleware):
 
     Strategy Selection Guide:
 
-    ======== =================== =======================================
-    Strategy Preserves Identity? Best For
-    ======== =================== =======================================
-    `block`  N/A                 Avoid PII completely
-    `redact` No                  General compliance, log sanitization
-    `mask`   No                  Human readability, customer service UIs
-    `hash`   Yes (pseudonymous)  Analytics, debugging
-    ======== =================== =======================================
+    | Strategy | Preserves Identity? | Best For                                 |
+    | -------- | ------------------- | ---------------------------------------- |
+    | `block`  | N/A                 | Avoid PII completely                     |
+    | `redact` | No                  | General compliance, log sanitization     |
+    | `mask`   | No                  | Human readability, customer service UIs  |
+    | `hash`   | Yes (pseudonymous)  | Analytics, debugging                     |
 
     Example:
         ```python
```
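The strategy guide is now a Markdown table rather than an RST grid; the strategies themselves are unchanged. A configuration sketch; the positional PII-type string and the `strategy` keyword are assumptions based on the strategies listed above:

```python
from langchain.agents import create_agent
from langchain.agents.middleware import PIIMiddleware

agent = create_agent(
    "openai:gpt-4o",
    middleware=[
        PIIMiddleware("email", strategy="redact"),  # remove the value entirely
        PIIMiddleware("credit_card", strategy="mask"),  # keep shape, hide digits
    ],
)
```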
langchain/agents/middleware/{planning.py → todo.py}
RENAMED
```diff
@@ -6,7 +6,7 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, Annotated, Literal
 
 if TYPE_CHECKING:
-    from collections.abc import Callable
+    from collections.abc import Awaitable, Callable
 
 from langchain_core.messages import ToolMessage
 from langchain_core.tools import tool
@@ -126,7 +126,7 @@ def write_todos(todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCall
     )
 
 
-class PlanningMiddleware(AgentMiddleware):
+class TodoListMiddleware(AgentMiddleware):
     """Middleware that provides todo list management capabilities to agents.
 
     This middleware adds a `write_todos` tool that allows agents to create and manage
@@ -139,10 +139,10 @@ class PlanningMiddleware(AgentMiddleware):
 
     Example:
         ```python
-        from langchain.agents.middleware.planning import PlanningMiddleware
+        from langchain.agents.middleware.todo import TodoListMiddleware
        from langchain.agents import create_agent
 
-        agent = create_agent("openai:gpt-4o", middleware=[PlanningMiddleware()])
+        agent = create_agent("openai:gpt-4o", middleware=[TodoListMiddleware()])
 
         # Agent now has access to write_todos tool and todo state tracking
         result = await agent.invoke({"messages": [HumanMessage("Help me refactor my codebase")]})
@@ -165,7 +165,7 @@ class PlanningMiddleware(AgentMiddleware):
         system_prompt: str = WRITE_TODOS_SYSTEM_PROMPT,
         tool_description: str = WRITE_TODOS_TOOL_DESCRIPTION,
     ) -> None:
-        """Initialize the PlanningMiddleware with optional custom prompts.
+        """Initialize the TodoListMiddleware with optional custom prompts.
 
         Args:
             system_prompt: Custom system prompt to guide the agent on using the todo tool.
@@ -204,3 +204,16 @@ class PlanningMiddleware(AgentMiddleware):
             else self.system_prompt
         )
         return handler(request)
+
+    async def awrap_model_call(
+        self,
+        request: ModelRequest,
+        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
+    ) -> ModelCallResult:
+        """Update the system prompt to include the todo system prompt (async version)."""
+        request.system_prompt = (
+            request.system_prompt + "\n\n" + self.system_prompt
+            if request.system_prompt
+            else self.system_prompt
+        )
+        return await handler(request)
```
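A pattern worth noting across this release: every built-in middleware now ships `wrap_model_call` and `awrap_model_call` as a pair, and the factory registers a middleware for a path if it overrides either hook, so an unsupported path fails loudly instead of being silently skipped. A sketch of a custom middleware following the same convention (the hook names and `ModelRequest`/`ModelResponse` come from this diff's exports; the diff's own methods annotate their returns as `ModelCallResult`, which is treated as interchangeable with `ModelResponse` for this sketch):

```python
from collections.abc import Awaitable, Callable

from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse


class PrefixPromptMiddleware(AgentMiddleware):
    """Hypothetical middleware: prefixes the system prompt on both execution paths."""

    prefix = "Answer concisely."

    def _apply(self, request: ModelRequest) -> None:
        # Same mutation style as TodoListMiddleware above: extend or set the prompt.
        request.system_prompt = (
            self.prefix + "\n\n" + request.system_prompt
            if request.system_prompt
            else self.prefix
        )

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        self._apply(request)
        return handler(request)

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelResponse:
        self._apply(request)
        return await handler(request)
```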