langchain-dev-utils 1.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langchain_dev_utils/__init__.py +1 -0
- langchain_dev_utils/_utils.py +131 -0
- langchain_dev_utils/agents/__init__.py +4 -0
- langchain_dev_utils/agents/factory.py +99 -0
- langchain_dev_utils/agents/file_system.py +252 -0
- langchain_dev_utils/agents/middleware/__init__.py +21 -0
- langchain_dev_utils/agents/middleware/format_prompt.py +66 -0
- langchain_dev_utils/agents/middleware/handoffs.py +214 -0
- langchain_dev_utils/agents/middleware/model_fallback.py +49 -0
- langchain_dev_utils/agents/middleware/model_router.py +200 -0
- langchain_dev_utils/agents/middleware/plan.py +367 -0
- langchain_dev_utils/agents/middleware/summarization.py +85 -0
- langchain_dev_utils/agents/middleware/tool_call_repair.py +96 -0
- langchain_dev_utils/agents/middleware/tool_emulator.py +60 -0
- langchain_dev_utils/agents/middleware/tool_selection.py +82 -0
- langchain_dev_utils/agents/plan.py +188 -0
- langchain_dev_utils/agents/wrap.py +324 -0
- langchain_dev_utils/chat_models/__init__.py +11 -0
- langchain_dev_utils/chat_models/adapters/__init__.py +3 -0
- langchain_dev_utils/chat_models/adapters/create_utils.py +53 -0
- langchain_dev_utils/chat_models/adapters/openai_compatible.py +715 -0
- langchain_dev_utils/chat_models/adapters/register_profiles.py +15 -0
- langchain_dev_utils/chat_models/base.py +282 -0
- langchain_dev_utils/chat_models/types.py +27 -0
- langchain_dev_utils/embeddings/__init__.py +11 -0
- langchain_dev_utils/embeddings/adapters/__init__.py +3 -0
- langchain_dev_utils/embeddings/adapters/create_utils.py +45 -0
- langchain_dev_utils/embeddings/adapters/openai_compatible.py +91 -0
- langchain_dev_utils/embeddings/base.py +234 -0
- langchain_dev_utils/message_convert/__init__.py +15 -0
- langchain_dev_utils/message_convert/content.py +201 -0
- langchain_dev_utils/message_convert/format.py +69 -0
- langchain_dev_utils/pipeline/__init__.py +7 -0
- langchain_dev_utils/pipeline/parallel.py +135 -0
- langchain_dev_utils/pipeline/sequential.py +101 -0
- langchain_dev_utils/pipeline/types.py +3 -0
- langchain_dev_utils/py.typed +0 -0
- langchain_dev_utils/tool_calling/__init__.py +14 -0
- langchain_dev_utils/tool_calling/human_in_the_loop.py +284 -0
- langchain_dev_utils/tool_calling/utils.py +81 -0
- langchain_dev_utils-1.3.7.dist-info/METADATA +103 -0
- langchain_dev_utils-1.3.7.dist-info/RECORD +44 -0
- langchain_dev_utils-1.3.7.dist-info/WHEEL +4 -0
- langchain_dev_utils-1.3.7.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Released package version string; keep in sync with the distribution metadata.
__version__ = "1.3.7"
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
from importlib import util
|
|
2
|
+
from typing import Literal, Optional
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def _check_pkg_install(
    pkg: Literal["langchain_openai", "json_repair"],
) -> None:
    """Fail fast when an optional dependency is not importable.

    Args:
        pkg: Top-level module name of the optional dependency to look up.

    Raises:
        ImportError: If ``importlib.util.find_spec`` cannot locate ``pkg``. The
            message names the extra to install and the feature that needs it.
    """
    if util.find_spec(pkg):
        return

    install_hint = (
        "Please install langchain_dev_utils[standard],when use 'openai-compatible'"
        if pkg == "langchain_openai"
        else "Please install langchain_dev_utils[standard] to use ToolCallRepairMiddleware."
    )
    raise ImportError(install_hint)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _get_base_url_field_name(model_cls: type[BaseModel]) -> str | None:
    """Return the keyword under which ``model_cls`` accepts an API base URL.

    ``"base_url"`` is preferred over ``"api_base"``. A candidate matches when
    the class declares a field with that exact name, or a field whose alias is
    that name. This lets a field such as ``openai_api_base`` aliased to
    ``"base_url"`` still yield ``"base_url"``.

    Args:
        model_cls: Pydantic model class to inspect (its ``model_fields``).

    Returns:
        ``"base_url"``, ``"api_base"``, or ``None`` when neither matches.
    """
    model_fields = model_cls.model_fields
    aliases = {field_info.alias for field_info in model_fields.values()}

    # Check each candidate against both field names and aliases before moving
    # on. Otherwise a field *named* "api_base" would beat a field *aliased*
    # "base_url", contradicting the preference stated above.
    for candidate in ("base_url", "api_base"):
        if candidate in model_fields or candidate in aliases:
            return candidate

    return None
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _validate_base_url(base_url: Optional[str] = None) -> None:
    """Reject a base URL that is not an absolute HTTP(S) URL.

    ``None`` means "not configured" and is accepted silently. Surrounding
    whitespace is ignored when parsing.

    Args:
        base_url: Base URL to validate, or ``None``.

    Raises:
        ValueError: If the URL lacks a scheme or network location, or if its
            scheme is neither ``http`` nor ``https``.
    """
    if base_url is not None:
        from urllib.parse import urlparse

        pieces = urlparse(base_url.strip())

        if not (pieces.scheme and pieces.netloc):
            raise ValueError(
                f"base_url must be a valid HTTP or HTTPS URL. Received: {base_url}"
            )

        if pieces.scheme not in {"http", "https"}:
            raise ValueError(
                f"base_url must use HTTP or HTTPS protocol. Received: {pieces.scheme}"
            )
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _validate_model_cls_name(model_cls_name: str) -> None:
    """Check that ``model_cls_name`` is usable as a PEP 8 style class name.

    The rules are checked in this order and the first violation is reported:
    non-empty; first character is a letter; only letters, digits and
    underscores (Unicode-aware, per ``str.isalpha``/``str.isalnum``); first
    character is not lowercase; at most 30 characters.

    Args:
        model_cls_name: Candidate class name.

    Raises:
        ValueError: Describing the first rule the name breaks.
    """
    if not model_cls_name:
        raise ValueError("model_cls_name cannot be empty")

    leading = model_cls_name[0]
    problem = None
    if not leading.isalpha():
        problem = f"model_cls_name must start with a letter. Received: {model_cls_name}"
    elif any(not (ch.isalnum() or ch == "_") for ch in model_cls_name):
        problem = f"model_cls_name can only contain letters, numbers, and underscores. Received: {model_cls_name}"
    elif leading.islower():
        problem = f"model_cls_name should start with an uppercase letter (PEP 8). Received: {model_cls_name}"
    elif len(model_cls_name) > 30:
        problem = f"model_cls_name must be 30 characters or fewer. Received: {model_cls_name}"

    if problem is not None:
        raise ValueError(problem)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _validate_provider_name(provider_name: str) -> None:
    """Ensure ``provider_name`` is a short, identifier-like token.

    The checks run in this order and the first violation is reported:
    non-empty; first character is a letter or digit; only letters, digits and
    underscores (Unicode-aware, per ``str.isalnum``); at most 20 characters.

    Args:
        provider_name: Candidate provider name.

    Raises:
        ValueError: Describing the first rule the name breaks.
    """
    if not provider_name:
        raise ValueError("provider_name cannot be empty")

    starts_ok = provider_name[0].isalnum()
    if not starts_ok:
        raise ValueError(
            f"provider_name must start with a letter or number. Received: {provider_name}"
        )

    has_bad_char = any(not (ch.isalnum() or ch == "_") for ch in provider_name)
    if has_bad_char:
        raise ValueError(
            f"provider_name can only contain letters, numbers, underscores. Received: {provider_name}"
        )

    too_long = len(provider_name) > 20
    if too_long:
        raise ValueError(
            f"provider_name must be 20 characters or fewer. Received: {provider_name}"
        )
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
from typing import Any, Callable, Sequence
|
|
2
|
+
|
|
3
|
+
from langchain.agents import create_agent as _create_agent
|
|
4
|
+
from langchain.agents.middleware.types import (
|
|
5
|
+
AgentMiddleware,
|
|
6
|
+
AgentState,
|
|
7
|
+
ResponseT,
|
|
8
|
+
StateT_co,
|
|
9
|
+
_InputAgentState,
|
|
10
|
+
_OutputAgentState,
|
|
11
|
+
)
|
|
12
|
+
from langchain.agents.structured_output import ResponseFormat
|
|
13
|
+
from langchain_core.messages import SystemMessage
|
|
14
|
+
from langchain_core.tools import BaseTool
|
|
15
|
+
from langgraph.cache.base import BaseCache
|
|
16
|
+
from langgraph.graph.state import CompiledStateGraph
|
|
17
|
+
from langgraph.store.base import BaseStore
|
|
18
|
+
from langgraph.types import Checkpointer
|
|
19
|
+
from langgraph.typing import ContextT
|
|
20
|
+
|
|
21
|
+
from ..chat_models import load_chat_model
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def create_agent(
    model: str,
    tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
    *,
    system_prompt: str | SystemMessage | None = None,
    response_format: ResponseFormat[ResponseT] | type[ResponseT] | None = None,
    middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
    state_schema: type[AgentState[ResponseT]] | None = None,
    context_schema: type[ContextT] | None = None,
    checkpointer: Checkpointer | None = None,
    store: BaseStore | None = None,
    interrupt_before: list[str] | None = None,
    interrupt_after: list[str] | None = None,
    debug: bool = False,
    name: str | None = None,
    cache: BaseCache | None = None,
) -> CompiledStateGraph[
    AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
]:
    """Create an agent whose model is given as a registered model string.

    This is a thin wrapper around ``langchain.agents.create_agent``. The
    ``model`` string is resolved to a chat model with ``load_chat_model``, and
    every other argument is forwarded unchanged. This lets agents use models
    from providers registered via ``register_model_provider``.

    Args:
        model: Model identifier accepted by ``load_chat_model``, typically in
            ``"provider:model-name"`` form.
        tools, system_prompt, response_format, middleware, state_schema,
        context_schema, checkpointer, store, interrupt_before, interrupt_after,
        debug, name, cache: Forwarded verbatim to
            ``langchain.agents.create_agent``; see its documentation.

    Returns:
        CompiledStateGraph: The compiled agent graph produced by
        ``langchain.agents.create_agent``.

    Raises:
        Whatever ``load_chat_model`` raises for a model string it cannot
        resolve (presumably ``ValueError`` for an unknown provider; confirm
        against ``load_chat_model``), plus any error raised by
        ``langchain.agents.create_agent``.

    Example:
        >>> from langchain_dev_utils.chat_models import register_model_provider
        >>> from langchain_dev_utils.agents import create_agent
        >>> # Register the model provider before creating the agent.
        >>> register_model_provider(
        ...     provider_name="vllm",
        ...     chat_model="openai-compatible",
        ...     base_url="http://localhost:8000/v1",
        ... )
        >>> agent = create_agent(
        ...     "vllm:qwen3-4b",
        ...     tools=[get_current_time],  # any tool you have defined
        ...     name="time-agent",
        ... )
        >>> response = agent.invoke(
        ...     {"messages": [{"role": "user", "content": "What's the time?"}]}
        ... )
    """
    return _create_agent(
        model=load_chat_model(model),
        tools=tools,
        system_prompt=system_prompt,
        middleware=middleware,
        response_format=response_format,
        state_schema=state_schema,
        context_schema=context_schema,
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        debug=debug,
        name=name,
        cache=cache,
    )
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
from typing import Annotated, Literal, Optional
|
|
3
|
+
|
|
4
|
+
from langchain.tools import BaseTool, ToolRuntime, tool
|
|
5
|
+
from langchain_core.messages import ToolMessage
|
|
6
|
+
from langgraph.types import Command
|
|
7
|
+
from typing_extensions import TypedDict
|
|
8
|
+
|
|
9
|
+
# The whole module is deprecated, so the warning is emitted when the module
# is executed (i.e. on first import), not each time a tool is constructed.
warnings.warn(
    message="langchain_dev_utils.agents.file_system is deprecated, and it will be removed in a future version. Please use middleware in deepagents instead.",
    category=DeprecationWarning,
)
|
|
13
|
+
|
|
14
|
+
# Default tool descriptions shown to the model. They list every argument the
# corresponding tool accepts so the model can call the tool correctly.
_DEFAULT_WRITE_FILE_DESCRIPTION = """
A tool for writing files.

Args:
    file_name: The name of the file
    content: The content of the file
    write_mode: "write" to store the content as a new file (if the name is already taken, it is saved under a new, suffixed name that is reported back) or "append" to add the content to the end of the file
"""

_DEFAULT_LS_DESCRIPTION = """List all the saved file names."""


_DEFAULT_QUERY_FILE_DESCRIPTION = """
Query the content of a file.

Args:
    file_name: The name of the file
"""

_DEFAULT_UPDATE_FILE_DESCRIPTION = """
Update the content of a file.

Args:
    file_name: The name of the file
    origin_content: The exact, non-empty text to replace; it must appear verbatim in the file
    new_content: The text to put in place of origin_content (not the whole new file)
    replace_all: Whether to replace every occurrence of origin_content instead of only the first one
"""
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def file_reducer(left: dict | None, right: dict | None):
|
|
43
|
+
if left is None:
|
|
44
|
+
return right
|
|
45
|
+
elif right is None:
|
|
46
|
+
return left
|
|
47
|
+
else:
|
|
48
|
+
return {**left, **right}
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class FileStateMixin(TypedDict):
    """State fragment holding the in-memory files used by this module's tools.

    The tools read ``state["file"]`` and write to it by returning
    ``Command(update={"file": {...}})``. Judging by the name, this is meant to
    be combined with an agent state schema; confirm with callers.
    """

    # Maps file name -> file content. Updates are merged by ``file_reducer``
    # (right-hand entries win), so a tool only needs to return the files it
    # changed.
    file: Annotated[dict[str, str], file_reducer]
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def create_write_file_tool(
    name: Optional[str] = None,
    description: Optional[str] = None,
    message_key: Optional[str] = None,
) -> BaseTool:
    """Create a tool for writing files.

    The tool stores files in the agent state under ``"file"`` as a
    ``file_name -> content`` mapping. It returns a ``Command`` that updates
    that mapping and appends a ``ToolMessage`` reporting the result.

    Write modes:
        * ``"append"``: the content is appended to the existing file (or
          creates it when absent).
        * ``"write"``: an existing file is never overwritten. If
          ``file_name`` is taken, the content is stored under
          ``<file_name>_<n>``, using the smallest ``n >= 1`` not already in
          use. The final name is echoed back in the tool message.

    Args:
        name: The name of the tool. Defaults to "write_file".
        description: The description of the tool. Uses default description if not provided.
        message_key: The key of the message to be updated. Defaults to "messages".

    Returns:
        BaseTool: The tool for writing files.

    Example:
        Basic usage:
        >>> from langchain_dev_utils.agents.file_system import create_write_file_tool
        >>> write_file = create_write_file_tool()
    """
    msg_key = message_key or "messages"

    @tool(
        name_or_callable=name or "write_file",
        description=description or _DEFAULT_WRITE_FILE_DESCRIPTION,
    )
    def write_file(
        file_name: Annotated[str, "the name of the file"],
        content: Annotated[str, "the content of the file"],
        runtime: ToolRuntime,
        write_mode: Annotated[
            Literal["write", "append"], "the write mode of the file"
        ] = "write",
    ):
        files = runtime.state.get("file", {})
        if write_mode == "append":
            content = files.get(file_name, "") + content
        elif write_mode == "write" and file_name in files:
            # Pick the first free "<name>_<n>". A suffix derived from the old
            # content's length would repeat on a second write of the same
            # name and silently clobber the earlier copy.
            suffix = 1
            while f"{file_name}_{suffix}" in files:
                suffix += 1
            file_name = f"{file_name}_{suffix}"

        return Command(
            update={
                "file": {file_name: content},
                msg_key: [
                    ToolMessage(
                        content=f"file {file_name} written successfully, content is {content}",
                        tool_call_id=runtime.tool_call_id,
                    )
                ],
            }
        )

    return write_file
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def create_ls_file_tool(
    name: Optional[str] = None, description: Optional[str] = None
) -> BaseTool:
    """Build a tool that lists the names of all files held in agent state.

    Files live in the ``"file"`` entry of the agent state as a
    ``file_name -> content`` mapping. The tool returns that mapping's keys, so
    an agent can discover existing files before querying or updating them.

    Args:
        name: Tool name; ``"ls"`` when omitted or empty.
        description: Tool description; the module default when omitted or empty.

    Returns:
        BaseTool: The file-listing tool.

    Example:
        Basic usage:
        >>> from langchain_dev_utils.agents.file_system import create_ls_file_tool
        >>> ls = create_ls_file_tool()
    """
    tool_name = name or "ls"
    tool_description = description or _DEFAULT_LS_DESCRIPTION

    @tool(name_or_callable=tool_name, description=tool_description)
    def ls(runtime: ToolRuntime):
        stored_files = runtime.state.get("file", {})
        return [file_name for file_name in stored_files]

    return ls
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def create_query_file_tool(
    name: Optional[str] = None, description: Optional[str] = None
) -> BaseTool:
    """Build a tool that returns the stored content of one file.

    The content is looked up by name in the ``"file"`` mapping of the agent
    state, so the agent can recall information it saved earlier.

    Args:
        name: Tool name; ``"query_file"`` when omitted or empty.
        description: Tool description; the module default when omitted or empty.

    Returns:
        BaseTool: The file-query tool. When invoked, it raises ``ValueError``
        if the file does not exist or holds only whitespace.

    Example:
        Basic usage:
        >>> from langchain_dev_utils.agents.file_system import create_query_file_tool
        >>> query_file = create_query_file_tool()
    """
    tool_name = name or "query_file"
    tool_description = description or _DEFAULT_QUERY_FILE_DESCRIPTION

    @tool(name_or_callable=tool_name, description=tool_description)
    def query_file(file_name: str, runtime: ToolRuntime):
        stored_files = runtime.state.get("file", {})
        if file_name not in stored_files:
            raise ValueError(f"Error: File {file_name} not found")

        text = stored_files[file_name]
        if not text or not text.strip():
            raise ValueError(f"Error: File {file_name} is empty")

        return text

    return query_file
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def create_update_file_tool(
    name: Optional[str] = None,
    description: Optional[str] = None,
    message_key: Optional[str] = None,
) -> BaseTool:
    """Create a tool for updating files.

    The tool edits a file stored in the ``"file"`` mapping of the agent state
    by substituting ``new_content`` for ``origin_content``. It replaces either
    the first occurrence or, with ``replace_all=True``, every occurrence. It
    returns a ``Command`` that writes the updated file back and appends a
    ``ToolMessage`` reporting the result.

    When invoked, the tool raises ``ValueError`` if the file does not exist,
    if ``origin_content`` is empty, or if ``origin_content`` does not occur in
    the file.

    Args:
        name: The name of the tool. Defaults to "update_file".
        description: The description of the tool. Uses default description if not provided.
        message_key: The key of the message to be updated. Defaults to "messages".

    Returns:
        BaseTool: The tool for updating files.

    Example:
        Basic usage:
        >>> from langchain_dev_utils.agents.file_system import create_update_file_tool
        >>> update_file_tool = create_update_file_tool()
    """
    msg_key = message_key or "messages"

    @tool(
        name_or_callable=name or "update_file",
        description=description or _DEFAULT_UPDATE_FILE_DESCRIPTION,
    )
    def update_file(
        file_name: Annotated[str, "the name of the file"],
        origin_content: Annotated[str, "the original content of the file"],
        new_content: Annotated[str, "the new content of the file"],
        runtime: ToolRuntime,
        replace_all: Annotated[bool, "replace all the origin content"] = False,
    ):
        files = runtime.state.get("file", {})
        if file_name not in files:
            raise ValueError(f"Error: File {file_name} not found")

        # "" is a substring of every string, so it would pass the check below.
        # str.replace would then splice new_content between every character
        # (replace_all) or prepend it, corrupting the file.
        if not origin_content:
            raise ValueError("Error: Origin content must not be empty")

        current = files[file_name]
        if origin_content not in current:
            raise ValueError(
                f"Error: Origin content {origin_content} not found in file {file_name}"
            )

        if replace_all:
            updated = current.replace(origin_content, new_content)
        else:
            updated = current.replace(origin_content, new_content, 1)

        return Command(
            update={
                "file": {file_name: updated},
                msg_key: [
                    ToolMessage(
                        content=f"file {file_name} updated successfully, content is {updated}",
                        tool_call_id=runtime.tool_call_id,
                    )
                ],
            }
        )

    return update_file
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from .format_prompt import format_prompt
|
|
2
|
+
from .handoffs import HandoffAgentMiddleware
|
|
3
|
+
from .model_fallback import ModelFallbackMiddleware
|
|
4
|
+
from .model_router import ModelRouterMiddleware
|
|
5
|
+
from .plan import PlanMiddleware
|
|
6
|
+
from .summarization import SummarizationMiddleware
|
|
7
|
+
from .tool_call_repair import ToolCallRepairMiddleware
|
|
8
|
+
from .tool_emulator import LLMToolEmulator
|
|
9
|
+
from .tool_selection import LLMToolSelectorMiddleware
|
|
10
|
+
|
|
11
|
+
# Public API of the middleware package; every name below is imported from its
# submodule at the top of this file.
__all__ = [
    "SummarizationMiddleware",
    "LLMToolSelectorMiddleware",
    "PlanMiddleware",
    "ModelFallbackMiddleware",
    "LLMToolEmulator",
    "ModelRouterMiddleware",
    "ToolCallRepairMiddleware",
    "format_prompt",
    "HandoffAgentMiddleware",
]
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from langchain.agents.middleware import ModelRequest, dynamic_prompt
|
|
2
|
+
from langchain_core.prompts.string import get_template_variables
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
@dynamic_prompt
def format_prompt(request: ModelRequest) -> str:
    """Format the system prompt with variables from state and context.

    This middleware extracts f-string template variables from the system
    prompt and fills them with values from the agent's state and runtime
    context. Each variable is looked up in the state first and then in the
    context.

    A value counts as present unless it is ``None``, so falsy values such as
    ``0``, ``""`` or ``False`` are substituted rather than treated as missing.
    The context may be an attribute-style object (e.g. a dataclass) or a
    ``dict`` (e.g. a TypedDict instance).

    Raises:
        ValueError: If the request has no system message.
        KeyError: If a template variable is found in neither the state nor
            the context (raised by ``str.format``).

    Example:
        >>> from langchain_dev_utils.agents.middleware import format_prompt
        >>> from langchain.agents import create_agent
        >>> from langchain_core.messages import HumanMessage
        >>> from dataclasses import dataclass
        >>>
        >>> @dataclass
        ... class Context:
        ...     name: str
        ...     user: str
        >>>
        >>> agent = create_agent(
        ...     model=model,
        ...     tools=tools,
        ...     system_prompt="You are a helpful assistant. Your name is {name}. Your user is {user}.",
        ...     middleware=[format_prompt],
        ...     context_schema=Context,
        ... )
        >>> agent.invoke(
        ...     {"messages": [HumanMessage(content="Hello")]},
        ...     context=Context(name="assistant", user="Tom"),
        ... )
    """
    system_msg = request.system_message
    if system_msg is None:
        raise ValueError(
            "system_message must be provided,while use format_prompt in middleware."
        )

    system_prompt = "\n".join(
        [content.get("text", "") for content in system_msg.content_blocks]
    )
    variables = get_template_variables(system_prompt, "f-string")

    format_params = {}

    # Resolve from state first. Compare against None instead of truthiness so
    # legitimate falsy values (0, "", False, []) are not dropped.
    state = request.state
    for key in variables:
        value = state.get(key, None)
        if value is not None:
            format_params[key] = value

    missing = [key for key in variables if key not in format_params]
    if missing:
        context = request.runtime.context
        if context is not None:
            for key in missing:
                # Support both mapping contexts (TypedDict) and attribute-style
                # contexts (dataclass / plain object).
                if isinstance(context, dict):
                    value = context.get(key, None)
                else:
                    value = getattr(context, key, None)
                if value is not None:
                    format_params[key] = value

    return system_prompt.format(**format_params)
|