langgraph-agent-toolkit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langgraph_agent_toolkit/__init__.py +7 -0
- langgraph_agent_toolkit/agents/__init__.py +0 -0
- langgraph_agent_toolkit/agents/agent.py +14 -0
- langgraph_agent_toolkit/agents/agent_executor.py +415 -0
- langgraph_agent_toolkit/agents/blueprints/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/bg_task_agent/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/bg_task_agent/agent.py +69 -0
- langgraph_agent_toolkit/agents/blueprints/bg_task_agent/task.py +52 -0
- langgraph_agent_toolkit/agents/blueprints/bg_task_agent/utils.py +17 -0
- langgraph_agent_toolkit/agents/blueprints/chatbot/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/chatbot/agent.py +34 -0
- langgraph_agent_toolkit/agents/blueprints/command_agent/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/command_agent/agent.py +54 -0
- langgraph_agent_toolkit/agents/blueprints/interrupt_agent/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/interrupt_agent/agent.py +140 -0
- langgraph_agent_toolkit/agents/blueprints/react/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/react/agent.py +67 -0
- langgraph_agent_toolkit/agents/blueprints/react_so/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/react_so/agent.py +39 -0
- langgraph_agent_toolkit/agents/blueprints/supervisor_agent/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/supervisor_agent/agent.py +44 -0
- langgraph_agent_toolkit/agents/components/__init__.py +0 -0
- langgraph_agent_toolkit/agents/components/creators/__init__.py +4 -0
- langgraph_agent_toolkit/agents/components/creators/create_react_agent.py +459 -0
- langgraph_agent_toolkit/agents/components/tools.py +13 -0
- langgraph_agent_toolkit/agents/components/utils.py +42 -0
- langgraph_agent_toolkit/client/__init__.py +4 -0
- langgraph_agent_toolkit/client/client.py +344 -0
- langgraph_agent_toolkit/core/__init__.py +5 -0
- langgraph_agent_toolkit/core/memory/__init__.py +0 -0
- langgraph_agent_toolkit/core/memory/base.py +33 -0
- langgraph_agent_toolkit/core/memory/factory.py +30 -0
- langgraph_agent_toolkit/core/memory/postgres.py +76 -0
- langgraph_agent_toolkit/core/memory/sqlite.py +21 -0
- langgraph_agent_toolkit/core/memory/types.py +6 -0
- langgraph_agent_toolkit/core/models/__init__.py +5 -0
- langgraph_agent_toolkit/core/models/chat_openai.py +21 -0
- langgraph_agent_toolkit/core/models/factory.py +118 -0
- langgraph_agent_toolkit/core/models/fake.py +25 -0
- langgraph_agent_toolkit/core/observability/__init__.py +10 -0
- langgraph_agent_toolkit/core/observability/base.py +331 -0
- langgraph_agent_toolkit/core/observability/empty.py +67 -0
- langgraph_agent_toolkit/core/observability/factory.py +43 -0
- langgraph_agent_toolkit/core/observability/langfuse.py +118 -0
- langgraph_agent_toolkit/core/observability/langsmith.py +131 -0
- langgraph_agent_toolkit/core/observability/types.py +34 -0
- langgraph_agent_toolkit/core/prompts/__init__.py +0 -0
- langgraph_agent_toolkit/core/prompts/chat_prompt_template.py +528 -0
- langgraph_agent_toolkit/core/settings.py +164 -0
- langgraph_agent_toolkit/helper/__init__.py +0 -0
- langgraph_agent_toolkit/helper/constants.py +10 -0
- langgraph_agent_toolkit/helper/logging.py +111 -0
- langgraph_agent_toolkit/helper/types.py +7 -0
- langgraph_agent_toolkit/helper/utils.py +80 -0
- langgraph_agent_toolkit/run_agent.py +68 -0
- langgraph_agent_toolkit/run_client.py +55 -0
- langgraph_agent_toolkit/run_service.py +19 -0
- langgraph_agent_toolkit/schema/__init__.py +28 -0
- langgraph_agent_toolkit/schema/models.py +25 -0
- langgraph_agent_toolkit/schema/schema.py +210 -0
- langgraph_agent_toolkit/schema/task_data.py +72 -0
- langgraph_agent_toolkit/service/__init__.py +0 -0
- langgraph_agent_toolkit/service/exception_handlers.py +46 -0
- langgraph_agent_toolkit/service/factory.py +213 -0
- langgraph_agent_toolkit/service/handler.py +122 -0
- langgraph_agent_toolkit/service/middleware.py +18 -0
- langgraph_agent_toolkit/service/routes.py +169 -0
- langgraph_agent_toolkit/service/types.py +8 -0
- langgraph_agent_toolkit/service/utils.py +136 -0
- langgraph_agent_toolkit/streamlit_app.py +368 -0
- langgraph_agent_toolkit-0.1.0.dist-info/METADATA +424 -0
- langgraph_agent_toolkit-0.1.0.dist-info/RECORD +74 -0
- langgraph_agent_toolkit-0.1.0.dist-info/WHEEL +4 -0
- langgraph_agent_toolkit-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,459 @@
|
|
|
1
|
+
from typing import (
|
|
2
|
+
Any,
|
|
3
|
+
Callable,
|
|
4
|
+
Literal,
|
|
5
|
+
Optional,
|
|
6
|
+
Sequence,
|
|
7
|
+
Type,
|
|
8
|
+
Union,
|
|
9
|
+
cast,
|
|
10
|
+
get_type_hints,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from langchain_core.language_models import (
|
|
14
|
+
BaseChatModel,
|
|
15
|
+
LanguageModelInput,
|
|
16
|
+
LanguageModelLike,
|
|
17
|
+
)
|
|
18
|
+
from langchain_core.messages import (
|
|
19
|
+
AIMessage,
|
|
20
|
+
AnyMessage,
|
|
21
|
+
BaseMessage,
|
|
22
|
+
SystemMessage,
|
|
23
|
+
ToolMessage,
|
|
24
|
+
)
|
|
25
|
+
from langchain_core.runnables import (
|
|
26
|
+
Runnable,
|
|
27
|
+
RunnableConfig,
|
|
28
|
+
)
|
|
29
|
+
from langchain_core.tools import BaseTool
|
|
30
|
+
from langgraph.graph import END, StateGraph
|
|
31
|
+
from langgraph.graph.graph import CompiledGraph
|
|
32
|
+
from langgraph.prebuilt.chat_agent_executor import (
|
|
33
|
+
AgentState,
|
|
34
|
+
AgentStateWithStructuredResponse,
|
|
35
|
+
StructuredResponseSchema,
|
|
36
|
+
_get_model,
|
|
37
|
+
_get_prompt_runnable,
|
|
38
|
+
_get_state_value,
|
|
39
|
+
_should_bind_tools,
|
|
40
|
+
_validate_chat_history,
|
|
41
|
+
)
|
|
42
|
+
from langgraph.prebuilt.tool_node import ToolNode
|
|
43
|
+
from langgraph.store.base import BaseStore
|
|
44
|
+
from langgraph.types import Checkpointer, Send
|
|
45
|
+
from langgraph.utils.runnable import RunnableCallable, RunnableLike
|
|
46
|
+
from pydantic import BaseModel
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def create_react_agent(
    model: Union[str, LanguageModelLike],
    tools: Union[Sequence[Union[BaseTool, Callable]], ToolNode],
    *,
    prompt: Optional[
        Union[SystemMessage, str, Callable[[Any], LanguageModelInput], Runnable[Any, LanguageModelInput]]
    ] = None,
    response_format: Optional[Union[StructuredResponseSchema, tuple[str, StructuredResponseSchema]]] = None,
    pre_model_hook: Optional[RunnableLike] = None,
    state_schema: Optional[Type[Any]] = None,
    config_schema: Optional[Type[Any]] = None,
    checkpointer: Optional[Checkpointer] = None,
    store: Optional[BaseStore] = None,
    interrupt_before: Optional[list[str]] = None,
    interrupt_after: Optional[list[str]] = None,
    debug: bool = False,
    version: Literal["v1", "v2"] = "v1",
    name: Optional[str] = None,
    immediate_step_threshold: int = 5,  # Remaining-step count below which immediate generation is used
    immediate_generation_prompt: Optional[str] = None,  # Custom system prompt for immediate generation
) -> CompiledGraph:
    """Create a tool-calling chat-model graph with a step-budget router in front of the model.

    This implementation extends the original create_react_agent by adding a router that checks
    ``remaining_steps`` before every model call. It routes either to the regular ``agent`` node
    (the model, with tools bound when tools are given) or, when fewer than
    ``immediate_step_threshold`` steps remain, to an ``immediate_generation`` node. That node
    calls the underlying model without bound tools and with a system instruction to answer
    directly.

    Graph topology (``[...]`` nodes exist only when the matching option is set)::

        START -> [pre_model_hook] -> router -> agent | immediate_generation
        agent -> tools (only when the model requested tool calls) -> [pre_model_hook] -> router
        tools -> END (when a ``return_direct`` tool ran)
        agent | immediate_generation -> [generate_structured_response] -> END

    Args:
        model: The `LangChain` chat model that supports tool calling, or a
            ``'<provider>:<model>'`` string (requires the ``langchain`` package).
        tools: A list of tools or a ToolNode instance. An empty list disables the tools node.
        prompt: An optional prompt for the LLM used by the ``agent`` node.
        response_format: An optional schema for the final agent output, or a
            ``(system_prompt, schema)`` tuple. When set, a ``generate_structured_response``
            node runs before END.
        pre_model_hook: An optional node to add before the router/``agent`` node. It may return
            ``llm_input_messages`` to override the messages sent to the model.
        state_schema: An optional state schema that defines graph state. It must contain
            ``messages`` and ``remaining_steps`` (and ``structured_response`` when
            ``response_format`` is set).
        config_schema: An optional schema for configuration.
        checkpointer: An optional checkpoint saver object.
        store: An optional store object.
        interrupt_before: An optional list of node names to interrupt before.
        interrupt_after: An optional list of node names to interrupt after.
        debug: A flag indicating whether to enable debug mode.
        version: ``'v1'`` runs all tool calls of one message in a single tools node call;
            ``'v2'`` dispatches each tool call separately via ``Send``.
        name: An optional name for the compiled graph; also set as ``name`` on AI responses.
        immediate_step_threshold: Number of remaining steps strictly below which the router
            uses immediate generation.
        immediate_generation_prompt: Optional custom system prompt for the immediate generation
            mode. If not provided, a default prompt instructs the model to answer directly.

    Returns:
        The compiled graph.

    Raises:
        ValueError: If ``version`` is invalid or ``state_schema`` lacks required keys.
        ImportError: If ``model`` is a string and ``langchain`` is not installed.

    """
    # The graph's virtual entry node must be referenced through the ``START`` sentinel. The
    # literal string "START" would be treated as an undefined node name and fail validation
    # at compile time.
    from langgraph.graph import START

    if version not in ("v1", "v2"):
        raise ValueError(f"Invalid version {version}. Supported versions are 'v1' and 'v2'.")

    if state_schema is not None:
        required_keys = {"messages", "remaining_steps"}
        if response_format is not None:
            required_keys.add("structured_response")

        schema_keys = set(get_type_hints(state_schema))
        if missing_keys := required_keys - schema_keys:
            raise ValueError(f"Missing required key(s) {missing_keys} in state_schema")
    else:
        state_schema = AgentStateWithStructuredResponse if response_format is not None else AgentState

    tool_node = tools if isinstance(tools, ToolNode) else ToolNode(tools)
    # The tool functions wrapped in tool classes by the ToolNode.
    tool_classes = list(tool_node.tools_by_name.values())

    if isinstance(model, str):
        try:
            from langchain.chat_models import (  # type: ignore[import-not-found]
                init_chat_model,
            )
        except ImportError as err:
            raise ImportError(
                "Please install langchain (`pip install langchain`) to use '<provider>:<model>' "
                "string syntax for `model` parameter."
            ) from err

        model = cast(BaseChatModel, init_chat_model(model))

    tool_calling_enabled = len(tool_classes) > 0

    if _should_bind_tools(model, tool_classes) and tool_calling_enabled:
        model = cast(BaseChatModel, model).bind_tools(tool_classes)

    model_runnable = _get_prompt_runnable(prompt) | model

    # If any of the tools are configured to return_direct after running,
    # the graph needs to check whether these were called.
    should_return_direct = {t.name for t in tool_classes if t.return_direct}

    default_immediate_prompt = (
        "You need to generate a direct answer based on the information you already have. "
        "DO NOT make any tool calls. Synthesize what you know and respond directly."
    )

    def _are_more_steps_needed(state: Any, response: BaseMessage) -> bool:
        """Return True when the step budget is too small to act on ``response``."""
        has_tool_calls = isinstance(response, AIMessage) and response.tool_calls
        all_tools_return_direct = (
            all(call["name"] in should_return_direct for call in response.tool_calls)
            if isinstance(response, AIMessage)
            else False
        )
        remaining_steps = _get_state_value(state, "remaining_steps", None)
        is_last_step = _get_state_value(state, "is_last_step", False)
        return (
            (remaining_steps is None and is_last_step and has_tool_calls)
            or (remaining_steps is not None and remaining_steps < 1 and all_tools_return_direct)
            or (remaining_steps is not None and remaining_steps < 2 and has_tool_calls)
        )

    def _get_model_input_state(state: Any) -> Any:
        """Put the messages the model should see under ``messages`` and validate them.

        When a pre_model_hook is configured, its ``llm_input_messages`` take precedence over
        ``messages``. Note that ``state`` is updated in place and returned.
        """
        if pre_model_hook is not None:
            messages = (_get_state_value(state, "llm_input_messages")) or _get_state_value(state, "messages")
            error_msg = f"Expected input to call_model to have 'llm_input_messages' or 'messages' key, but got {state}"
        else:
            messages = _get_state_value(state, "messages")
            error_msg = f"Expected input to call_model to have 'messages' key, but got {state}"

        if messages is None:
            raise ValueError(error_msg)

        _validate_chat_history(messages)
        # Messages are passed under the `messages` key because the prompt expects that key.
        if isinstance(state_schema, type) and issubclass(state_schema, BaseModel):
            state.messages = messages  # type: ignore
        else:
            state["messages"] = messages  # type: ignore

        return state

    def _to_agent_update(state: Any, response: AIMessage) -> dict:
        """Tag ``response`` with the agent name and turn it into a state update.

        If the step budget cannot accommodate the response, it is replaced by an apology
        message that keeps the same id.
        """
        response.name = name
        if _are_more_steps_needed(state, response):
            return {
                "messages": [
                    AIMessage(
                        id=response.id,
                        content="Sorry, need more steps to process this request.",
                    )
                ]
            }
        # A list is returned because the messages reducer appends it to the existing list.
        return {"messages": [response]}

    def call_model(state: Any, config: RunnableConfig) -> Any:
        state = _get_model_input_state(state)
        response = cast(AIMessage, model_runnable.invoke(state, config))
        return _to_agent_update(state, response)

    async def acall_model(state: Any, config: RunnableConfig) -> Any:
        state = _get_model_input_state(state)
        response = cast(AIMessage, await model_runnable.ainvoke(state, config))
        return _to_agent_update(state, response)

    def _immediate_generation_input(state: Any) -> list:
        """Prepend the 'answer directly, no tool calls' system instruction to the messages."""
        immediate_prompt = SystemMessage(content=immediate_generation_prompt or default_immediate_prompt)
        return [immediate_prompt] + list(_get_state_value(state, "messages"))

    # Immediate generation: like call_model, but the underlying model (without bound tools)
    # is invoked with an instruction to respond directly.
    def immediate_generation(state: Any, config: RunnableConfig) -> Any:
        state = _get_model_input_state(state)
        base_model = _get_model(model)
        response = cast(AIMessage, base_model.invoke(_immediate_generation_input(state), config))
        response.name = name
        return {"messages": [response]}

    async def aimmediate_generation(state: Any, config: RunnableConfig) -> Any:
        state = _get_model_input_state(state)
        base_model = _get_model(model)
        response = cast(AIMessage, await base_model.ainvoke(_immediate_generation_input(state), config))
        response.name = name
        return {"messages": [response]}

    def router_condition(state: Any) -> str:
        """Pick ``immediate_generation`` when the remaining step budget is below the threshold."""
        remaining_steps = _get_state_value(state, "remaining_steps", None)
        if remaining_steps is not None and remaining_steps < immediate_step_threshold:
            return "immediate_generation"
        return "agent"

    input_schema = state_schema
    if pre_model_hook is not None:
        # Dynamically create a schema that inherits from state_schema and adds 'llm_input_messages'.
        if isinstance(state_schema, type) and issubclass(state_schema, BaseModel):
            # For Pydantic schemas
            from pydantic import create_model

            input_schema = create_model(
                "CallModelInputSchema",
                llm_input_messages=(list[AnyMessage], ...),
                __base__=state_schema,
            )
        else:
            # For TypedDict schemas
            class CallModelInputSchema(state_schema):  # type: ignore
                llm_input_messages: list[AnyMessage]

            input_schema = CallModelInputSchema

    def _structured_output_call(state: Any) -> tuple:
        """Return ``(structured-output model, messages to send)`` for the final structured step."""
        messages = _get_state_value(state, "messages")
        structured_response_schema = response_format
        if isinstance(response_format, tuple):
            system_prompt, structured_response_schema = response_format
            messages = [SystemMessage(content=system_prompt)] + list(messages)

        model_with_structured_output = _get_model(model).with_structured_output(
            cast(StructuredResponseSchema, structured_response_schema)
        )
        return model_with_structured_output, messages

    def generate_structured_response(state: Any, config: RunnableConfig) -> Any:
        structured_model, messages = _structured_output_call(state)
        return {"structured_response": structured_model.invoke(messages, config)}

    async def agenerate_structured_response(state: Any, config: RunnableConfig) -> Any:
        structured_model, messages = _structured_output_call(state)
        return {"structured_response": await structured_model.ainvoke(messages, config)}

    # ---- Graph wiring shared by the tool and no-tool variants ----
    workflow = StateGraph(state_schema, config_schema=config_schema)
    workflow.add_node("agent", RunnableCallable(call_model, acall_model), input=input_schema)
    workflow.add_node(
        "immediate_generation",
        RunnableCallable(immediate_generation, aimmediate_generation),
        input=input_schema,
    )

    router_destinations = ["agent", "immediate_generation"]
    if pre_model_hook is not None:
        workflow.add_node("pre_model_hook", pre_model_hook)
        workflow.add_edge(START, "pre_model_hook")
        workflow.add_conditional_edges("pre_model_hook", router_condition, router_destinations)
    else:
        # Without a pre_model_hook, the router decides directly at the graph entry.
        workflow.add_conditional_edges(START, router_condition, router_destinations)

    if response_format is not None:
        workflow.add_node(
            "generate_structured_response",
            RunnableCallable(generate_structured_response, agenerate_structured_response),
        )
        workflow.add_edge("generate_structured_response", END)
        final_node = "generate_structured_response"
    else:
        final_node = END

    workflow.add_edge("immediate_generation", final_node)

    if not tool_calling_enabled:
        workflow.add_edge("agent", final_node)
    else:

        def should_continue(state: Any) -> Union[str, list]:
            """Route to the tools node when the last AI message requested tool calls."""
            messages = _get_state_value(state, "messages")
            last_message = messages[-1]
            # If there is no function call, then we finish.
            if not isinstance(last_message, AIMessage) or not last_message.tool_calls:
                return final_node
            if version == "v1":
                return "tools"
            # v2: one Send per tool call, with injected state/store arguments.
            tool_calls = [
                tool_node.inject_tool_args(call, state, store)  # type: ignore[arg-type]
                for call in last_message.tool_calls
            ]
            return [Send("tools", [tool_call]) for tool_call in tool_calls]

        def route_tool_responses(state: Any) -> str:
            """After tools: stop if a return_direct tool ran, otherwise loop back via the router."""
            for message in reversed(_get_state_value(state, "messages")):
                if not isinstance(message, ToolMessage):
                    break
                if message.name in should_return_direct:
                    return END
            if pre_model_hook is not None:
                return "pre_model_hook"
            # Edges cannot point back into START, so the router is evaluated here directly.
            return router_condition(state)

        workflow.add_node("tools", tool_node)
        workflow.add_conditional_edges("agent", should_continue, ["tools", final_node])

        tool_destinations = ["pre_model_hook"] if pre_model_hook is not None else list(router_destinations)
        if should_return_direct:
            tool_destinations.append(END)
        workflow.add_conditional_edges("tools", route_tool_responses, tool_destinations)

    return workflow.compile(
        checkpointer=checkpointer,
        store=store,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        debug=debug,
        name=name,
    )
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
from typing import TypeVar
|
|
2
|
+
|
|
3
|
+
from langchain_core.messages.utils import trim_messages
|
|
4
|
+
from langchain_core.runnables import RunnableConfig
|
|
5
|
+
from langgraph.managed.is_last_step import RemainingSteps
|
|
6
|
+
from langgraph.prebuilt.chat_agent_executor import AgentState
|
|
7
|
+
|
|
8
|
+
from langgraph_agent_toolkit.helper.constants import DEFAULT_MAX_MESSAGE_HISTORY_LENGTH
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Generic type variable used to annotate the graph state accepted by pre_model_hook_standard.
T = TypeVar("T")
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class AgentStateWithRemainingSteps(AgentState):
    """``AgentState`` extended with LangGraph's ``RemainingSteps`` managed value."""

    # Managed value filled in by LangGraph; presumably the number of steps left before the
    # recursion limit is reached (see LangGraph's managed-values docs to confirm).
    remaining_steps: RemainingSteps
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def pre_model_hook_standard(state: T, config: RunnableConfig):
    """Build the (filtered and trimmed) message history sent to the model.

    On a fresh turn (the last message is from the human, or this is graph step 1), the
    following are dropped from the model input:

    * tool-related messages (tool/function);
    * messages with empty content;
    * AI messages that issued tool calls.

    The history is then trimmed to the last ``k`` messages. ``k`` comes from
    ``config["configurable"]["memory_saver_params"]["k"]`` and falls back to
    ``DEFAULT_MAX_MESSAGE_HISTORY_LENGTH``. Trimming keeps system messages, starts on a human
    message and ends on a human or tool message.

    Args:
        state: Graph state; ``state["messages"]`` must be the list of chat messages.
        config: Runnable config for the current step.

    Returns:
        ``{"llm_input_messages": [...]}``. The stored ``messages`` key is left untouched, so
        only the model input is affected.

    """
    # https://langchain-ai.github.io/langgraph/how-tos/create-react-agent-manage-message-history/
    updated_messages = state["messages"]
    metadata = config.get("metadata") or {}
    is_new_human_turn = bool(updated_messages) and updated_messages[-1].type == "human"

    if is_new_human_turn or metadata.get("langgraph_step") == 1:
        updated_messages = [
            message
            for message in updated_messages
            if message.type not in {"tool", "tool_call", "function"}
            and message.content
            # An AI message that requested tools would lose its ToolMessages here, leaving
            # dangling tool calls that fail chat-history validation; drop it as well.
            and not getattr(message, "tool_calls", None)
        ]

    memory_saver_params = (config.get("configurable") or {}).get("memory_saver_params") or {}
    _max_messages = memory_saver_params.get("k", None)

    updated_messages = trim_messages(
        updated_messages,
        token_counter=len,  # count messages, not tokens
        max_tokens=_max_messages or DEFAULT_MAX_MESSAGE_HISTORY_LENGTH,
        strategy="last",
        start_on="human",
        end_on=("human", "tool"),
        include_system=True,
        allow_partial=False,
    )

    return {"llm_input_messages": updated_messages}
|