langgraph-agent-toolkit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langgraph_agent_toolkit/__init__.py +7 -0
- langgraph_agent_toolkit/agents/__init__.py +0 -0
- langgraph_agent_toolkit/agents/agent.py +14 -0
- langgraph_agent_toolkit/agents/agent_executor.py +415 -0
- langgraph_agent_toolkit/agents/blueprints/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/bg_task_agent/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/bg_task_agent/agent.py +69 -0
- langgraph_agent_toolkit/agents/blueprints/bg_task_agent/task.py +52 -0
- langgraph_agent_toolkit/agents/blueprints/bg_task_agent/utils.py +17 -0
- langgraph_agent_toolkit/agents/blueprints/chatbot/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/chatbot/agent.py +34 -0
- langgraph_agent_toolkit/agents/blueprints/command_agent/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/command_agent/agent.py +54 -0
- langgraph_agent_toolkit/agents/blueprints/interrupt_agent/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/interrupt_agent/agent.py +140 -0
- langgraph_agent_toolkit/agents/blueprints/react/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/react/agent.py +67 -0
- langgraph_agent_toolkit/agents/blueprints/react_so/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/react_so/agent.py +39 -0
- langgraph_agent_toolkit/agents/blueprints/supervisor_agent/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/supervisor_agent/agent.py +44 -0
- langgraph_agent_toolkit/agents/components/__init__.py +0 -0
- langgraph_agent_toolkit/agents/components/creators/__init__.py +4 -0
- langgraph_agent_toolkit/agents/components/creators/create_react_agent.py +459 -0
- langgraph_agent_toolkit/agents/components/tools.py +13 -0
- langgraph_agent_toolkit/agents/components/utils.py +42 -0
- langgraph_agent_toolkit/client/__init__.py +4 -0
- langgraph_agent_toolkit/client/client.py +344 -0
- langgraph_agent_toolkit/core/__init__.py +5 -0
- langgraph_agent_toolkit/core/memory/__init__.py +0 -0
- langgraph_agent_toolkit/core/memory/base.py +33 -0
- langgraph_agent_toolkit/core/memory/factory.py +30 -0
- langgraph_agent_toolkit/core/memory/postgres.py +76 -0
- langgraph_agent_toolkit/core/memory/sqlite.py +21 -0
- langgraph_agent_toolkit/core/memory/types.py +6 -0
- langgraph_agent_toolkit/core/models/__init__.py +5 -0
- langgraph_agent_toolkit/core/models/chat_openai.py +21 -0
- langgraph_agent_toolkit/core/models/factory.py +118 -0
- langgraph_agent_toolkit/core/models/fake.py +25 -0
- langgraph_agent_toolkit/core/observability/__init__.py +10 -0
- langgraph_agent_toolkit/core/observability/base.py +331 -0
- langgraph_agent_toolkit/core/observability/empty.py +67 -0
- langgraph_agent_toolkit/core/observability/factory.py +43 -0
- langgraph_agent_toolkit/core/observability/langfuse.py +118 -0
- langgraph_agent_toolkit/core/observability/langsmith.py +131 -0
- langgraph_agent_toolkit/core/observability/types.py +34 -0
- langgraph_agent_toolkit/core/prompts/__init__.py +0 -0
- langgraph_agent_toolkit/core/prompts/chat_prompt_template.py +528 -0
- langgraph_agent_toolkit/core/settings.py +164 -0
- langgraph_agent_toolkit/helper/__init__.py +0 -0
- langgraph_agent_toolkit/helper/constants.py +10 -0
- langgraph_agent_toolkit/helper/logging.py +111 -0
- langgraph_agent_toolkit/helper/types.py +7 -0
- langgraph_agent_toolkit/helper/utils.py +80 -0
- langgraph_agent_toolkit/run_agent.py +68 -0
- langgraph_agent_toolkit/run_client.py +55 -0
- langgraph_agent_toolkit/run_service.py +19 -0
- langgraph_agent_toolkit/schema/__init__.py +28 -0
- langgraph_agent_toolkit/schema/models.py +25 -0
- langgraph_agent_toolkit/schema/schema.py +210 -0
- langgraph_agent_toolkit/schema/task_data.py +72 -0
- langgraph_agent_toolkit/service/__init__.py +0 -0
- langgraph_agent_toolkit/service/exception_handlers.py +46 -0
- langgraph_agent_toolkit/service/factory.py +213 -0
- langgraph_agent_toolkit/service/handler.py +122 -0
- langgraph_agent_toolkit/service/middleware.py +18 -0
- langgraph_agent_toolkit/service/routes.py +169 -0
- langgraph_agent_toolkit/service/types.py +8 -0
- langgraph_agent_toolkit/service/utils.py +136 -0
- langgraph_agent_toolkit/streamlit_app.py +368 -0
- langgraph_agent_toolkit-0.1.0.dist-info/METADATA +424 -0
- langgraph_agent_toolkit-0.1.0.dist-info/RECORD +74 -0
- langgraph_agent_toolkit-0.1.0.dist-info/WHEEL +4 -0
- langgraph_agent_toolkit-0.1.0.dist-info/licenses/LICENSE +21 -0
|
File without changes
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from langgraph.func import Pregel
|
|
4
|
+
from langgraph.graph.state import CompiledStateGraph
|
|
5
|
+
|
|
6
|
+
from langgraph_agent_toolkit.core.observability.base import BaseObservabilityPlatform
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class Agent:
    """An agent registered with the toolkit.

    Bundles a runnable LangGraph graph with the metadata used to list it and an
    optional observability platform.
    """

    # Identifier the executor registers the agent under.
    name: str
    # Human-readable summary of what the agent does.
    description: str
    # The executable graph: a compiled StateGraph or a Pregel (e.g. one built with `@entrypoint`).
    graph: CompiledStateGraph | Pregel
    # Optional observability/tracing platform for this agent's runs.
    observability: BaseObservabilityPlatform | None = None
|
|
@@ -0,0 +1,415 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import functools
|
|
3
|
+
import importlib
|
|
4
|
+
import os
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Tuple, TypeVar
|
|
7
|
+
from uuid import UUID, uuid4
|
|
8
|
+
|
|
9
|
+
import joblib
|
|
10
|
+
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, ToolMessage
|
|
11
|
+
from langchain_core.runnables import RunnableConfig
|
|
12
|
+
from langgraph.errors import GraphRecursionError
|
|
13
|
+
from langgraph.graph.state import CompiledStateGraph
|
|
14
|
+
from langgraph.pregel import Pregel
|
|
15
|
+
from langgraph.types import Command, Interrupt
|
|
16
|
+
|
|
17
|
+
from langgraph_agent_toolkit.agents.agent import Agent
|
|
18
|
+
from langgraph_agent_toolkit.helper.constants import DEFAULT_AGENT, DEFAULT_RECURSION_LIMIT
|
|
19
|
+
from langgraph_agent_toolkit.helper.logging import logger
|
|
20
|
+
from langgraph_agent_toolkit.helper.utils import (
|
|
21
|
+
convert_message_content_to_string,
|
|
22
|
+
langchain_to_chat_message,
|
|
23
|
+
remove_tool_calls,
|
|
24
|
+
)
|
|
25
|
+
from langgraph_agent_toolkit.schema import AgentInfo, ChatMessage
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# Type variable for the decorator: preserves the wrapped callable's return type.
T = TypeVar("T")


class AgentExecutor:
    """Handles the loading, execution and saving logic for different LangGraph agents."""

    def __init__(self, *args):
        """Initialize the AgentExecutor by importing agents.

        Args:
            *args: Variable length strings specifying the agents to import,
                e.g., "langgraph_agent_toolkit.agents.blueprints.react.agent:react_agent".

        Raises:
            ValueError: If no agents are provided, or if the default agent
                (``DEFAULT_AGENT``) is not among the agents that loaded successfully.

        """
        self.agents: Dict[str, Agent] = {}

        # Fail fast: an executor without agents cannot serve anything.
        if not args:
            raise ValueError("At least one agent must be provided to AgentExecutor.")

        self.load_agents_from_imports(args)
        self._validate_default_agent_loaded()

    def load_agents_from_imports(self, args: tuple) -> None:
        """Dynamically import agents from ``"module.path:object_name"`` strings.

        Raw graphs (``CompiledStateGraph`` / ``Pregel``) are wrapped in an ``Agent`` named
        after the imported object. ``Agent`` instances are registered under their own
        ``name``. Entries that fail to import, or that are neither, are logged and skipped,
        so one bad entry does not prevent the others from loading.

        Args:
            args: Import strings of the form ``"module.path:object_name"``.

        """
        for import_str in args:
            try:
                module_path, object_name = import_str.split(":")
                module = importlib.import_module(module_path)
                agent_obj = getattr(module, object_name)
            except (ImportError, AttributeError, ValueError) as e:
                logger.error(f"Error loading agent from '{import_str}': {e}")
                continue

            if isinstance(agent_obj, (CompiledStateGraph, Pregel)):
                agent = Agent(name=object_name, description=f"Dynamically loaded {object_name}", graph=agent_obj)
                self.agents[agent.name] = agent
            elif isinstance(agent_obj, Agent):
                self.agents[agent_obj.name] = agent_obj
            else:
                logger.warning(f"Object '{object_name}' is neither a graph nor an Agent instance")

    def _validate_default_agent_loaded(self) -> None:
        """Validate that the default agent is loaded.

        Raises:
            ValueError: If the default agent is not loaded

        """
        if DEFAULT_AGENT not in self.agents:
            raise ValueError(
                f"Default agent '{DEFAULT_AGENT}' was not imported. Make sure to include it in your agent imports."
            )

    def get_agent(self, agent_id: str) -> Agent:
        """Get an agent by its ID.

        Args:
            agent_id: The ID of the agent to retrieve

        Returns:
            The requested Agent instance

        Raises:
            KeyError: If the agent_id is not found

        """
        try:
            return self.agents[agent_id]
        except KeyError:
            raise KeyError(f"Agent '{agent_id}' not found") from None

    def get_all_agent_info(self) -> list[AgentInfo]:
        """Get information about all available agents.

        Returns:
            A list of AgentInfo objects containing agent IDs and descriptions

        """
        return [AgentInfo(key=agent_id, description=agent.description) for agent_id, agent in self.agents.items()]

    def add_agent(self, agent_id: str, agent: Agent) -> None:
        """Add a new agent to the executor, replacing any agent with the same ID.

        Args:
            agent_id: The ID to assign to the agent
            agent: The Agent instance to add

        """
        self.agents[agent_id] = agent

    @staticmethod
    def handle_agent_errors(func: Callable[..., T]) -> Callable[..., T]:
        """Log errors raised during agent execution, then re-raise them unchanged.

        ``GraphRecursionError`` gets a dedicated log message, and any other exception is
        logged generically. Coroutine functions, async generator functions (such as
        ``stream``) and plain functions each get a matching wrapper. This means errors
        raised while an async generator is being iterated are logged too.

        Args:
            func: The function to decorate

        Returns:
            The decorated function

        """
        import inspect  # local import: only needed while decorating methods

        def _log_error(e: Exception) -> None:
            """Log ``e`` with a message specific to its type."""
            if isinstance(e, GraphRecursionError):
                logger.error(f"GraphRecursionError occurred: {e}")
            else:
                logger.error(f"Error during agent execution: {e}")

        # Async generator functions are NOT coroutine functions. They need their own
        # wrapper, otherwise exceptions raised during iteration bypass the logging.
        if inspect.isasyncgenfunction(func):

            @functools.wraps(func)
            async def async_gen_wrapper(self, *args, **kwargs):
                agen = func(self, *args, **kwargs)
                try:
                    async for item in agen:
                        yield item
                except Exception as e:
                    _log_error(e)
                    raise
                finally:
                    # Finalize the wrapped generator even if the consumer stops early.
                    await agen.aclose()

            return async_gen_wrapper

        if inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(self, *args, **kwargs):
                try:
                    return await func(self, *args, **kwargs)
                except Exception as e:
                    _log_error(e)
                    raise

            return async_wrapper

        @functools.wraps(func)
        def sync_wrapper(self, *args, **kwargs):
            try:
                return func(self, *args, **kwargs)
            except Exception as e:
                _log_error(e)
                raise

        return sync_wrapper

    async def _setup_agent_execution(
        self,
        agent_id: str,
        message: str,
        thread_id: Optional[str] = None,
        model: Optional[str] = None,
        agent_config: Optional[Dict[str, Any]] = None,
        recursion_limit: Optional[int] = None,
    ) -> Tuple[Agent, Any, Any, UUID]:
        """Apply common setup for agent execution that both invoke and stream methods share.

        Args:
            agent_id: ID of the agent to invoke
            message: User message to send to the agent
            thread_id: Optional thread ID for conversation history (a new UUID if omitted)
            model: Optional model name to override the default
            agent_config: Optional additional configuration merged into ``configurable``
            recursion_limit: Optional recursion limit (``DEFAULT_RECURSION_LIMIT`` if falsy)

        Returns:
            Tuple containing:
                - agent: The Agent instance
                - input_data: The properly formatted input for the agent
                - config: The RunnableConfig for the agent
                - run_id: The UUID for this run

        Raises:
            KeyError: If ``agent_id`` is not registered.

        """
        agent = self.get_agent(agent_id)
        agent_graph = agent.graph

        run_id = uuid4()
        thread_id = thread_id or str(uuid4())

        configurable: Dict[str, Any] = {"thread_id": thread_id}
        if model:
            configurable["model"] = model
        # agent_config is applied last, so it can override thread_id/model.
        if agent_config:
            configurable.update(agent_config)

        # Agents registered without an observability platform run without callbacks.
        callback = (
            agent.observability.get_callback_handler(session_id=thread_id)
            if agent.observability is not None
            else None
        )

        config = RunnableConfig(
            configurable=configurable,
            run_id=run_id,
            callbacks=[callback] if callback else None,
            recursion_limit=recursion_limit or DEFAULT_RECURSION_LIMIT,
        )

        # A thread can only be paused on an interrupt if its state is persisted. Without a
        # checkpointer, aget_state() raises "No checkpointer set", so skip the check.
        is_interrupted = False
        if getattr(agent_graph, "checkpointer", None):
            state = await agent_graph.aget_state(config=config)
            is_interrupted = any(getattr(task, "interrupts", None) for task in state.tasks)

        if is_interrupted:
            # User input is a response to resume agent execution from the interrupt.
            input_data = Command(resume=message)
        else:
            input_data = {"messages": [HumanMessage(content=message)]}

        return agent, input_data, config, run_id

    @handle_agent_errors
    async def invoke(
        self,
        agent_id: str,
        message: str,
        thread_id: Optional[str] = None,
        model: Optional[str] = None,
        agent_config: Optional[Dict[str, Any]] = None,
        recursion_limit: Optional[int] = None,
    ) -> ChatMessage:
        """Invoke an agent with a message and return the response.

        Args:
            agent_id: ID of the agent to invoke
            message: User message to send to the agent
            thread_id: Optional thread ID for conversation history
            model: Optional model name to override the default
            agent_config: Optional additional configuration for the agent
            recursion_limit: Optional recursion limit for the agent

        Returns:
            ChatMessage: The agent's final message or, if the run paused on an interrupt,
                the first interrupt value wrapped as an AI message.

        Raises:
            ValueError: If the run produced no events or ended with an unexpected event.

        """
        agent, input_data, config, run_id = await self._setup_agent_execution(
            agent_id=agent_id,
            message=message,
            thread_id=thread_id,
            model=model,
            agent_config=agent_config,
            recursion_limit=recursion_limit,
        )

        # With a list of stream modes, ainvoke returns (mode, payload) pairs.
        response_events = await agent.graph.ainvoke(input=input_data, config=config, stream_mode=["updates", "values"])
        if not response_events:
            raise ValueError("Agent produced no response events")

        response_type, response = response_events[-1]
        if response_type == "values":
            # Normal response, the agent completed successfully.
            output = langchain_to_chat_message(response["messages"][-1])
        elif response_type == "updates" and "__interrupt__" in response:
            # The last thing to occur was an interrupt:
            # return the value of the first interrupt as an AIMessage.
            output = langchain_to_chat_message(AIMessage(content=response["__interrupt__"][0].value))
        else:
            raise ValueError(f"Unexpected response type: {response_type}")

        output.run_id = str(run_id)
        return output

    @staticmethod
    def _collect_update_messages(event: Dict[str, Any]) -> List[Any]:
        """Turn one ``updates`` stream event into the LangChain messages to forward.

        Interrupt values become AI messages. For the supervisor node, only its last AI
        message is kept. Expert-agent output is converted into a ToolMessage so the UI
        renders it as a tool response.
        """
        new_messages: List[Any] = []
        for node, updates in event.items():
            # A simple approach to handle agent interrupts. In a more sophisticated
            # implementation, we could add some structured ChatMessage type to return
            # the interrupt value.
            if node == "__interrupt__":
                interrupt: Interrupt
                for interrupt in updates:
                    new_messages.append(AIMessage(content=interrupt.value))
                continue

            # A node that returns nothing yields a None update.
            update_messages = (updates or {}).get("messages", [])

            # Special case for supervisor agent.
            if node == "supervisor":
                # Get only the last AIMessage since supervisor includes all previous messages.
                ai_messages = [msg for msg in update_messages if isinstance(msg, AIMessage)]
                if ai_messages:
                    update_messages = [ai_messages[-1]]

            # Special case for expert agents.
            if node in ("research_expert", "math_expert") and update_messages:
                # Convert to ToolMessage so it displays in the UI as a tool response.
                update_messages = [
                    ToolMessage(
                        content=update_messages[0].content,
                        name=node,
                        tool_call_id="",
                    )
                ]

            new_messages.extend(update_messages)
        return new_messages

    @handle_agent_errors
    async def stream(
        self,
        agent_id: str,
        message: str,
        thread_id: Optional[str] = None,
        model: Optional[str] = None,
        stream_tokens: bool = True,
        agent_config: Optional[Dict[str, Any]] = None,
        recursion_limit: Optional[int] = None,
    ) -> AsyncGenerator[str | ChatMessage, None]:
        """Stream an agent's response to a message, yielding either tokens or messages.

        Args:
            agent_id: ID of the agent to invoke
            message: User message to send to the agent
            thread_id: Optional thread ID for conversation history
            model: Optional model name to override the default
            stream_tokens: Whether to stream individual tokens
            agent_config: Optional additional configuration for the agent
            recursion_limit: Optional recursion limit for the agent

        Yields:
            Either ChatMessage objects for full messages or strings for token chunks

        """
        agent, input_data, config, run_id = await self._setup_agent_execution(
            agent_id=agent_id,
            message=message,
            thread_id=thread_id,
            model=model,
            agent_config=agent_config,
            recursion_limit=recursion_limit,
        )

        # Node updates and custom events (e.g. background-task progress) are always
        # streamed. Token chunks ("messages") are streamed only when requested.
        stream_modes = ["updates", "custom"]
        if stream_tokens:
            stream_modes.append("messages")

        async for stream_event in agent.graph.astream(input=input_data, config=config, stream_mode=stream_modes):
            # With a list of stream modes, events arrive as (mode, payload) tuples.
            if not isinstance(stream_event, tuple):
                continue

            mode, event = stream_event
            new_messages: List[Any] = []

            if mode == "updates":
                new_messages = self._collect_update_messages(event)
            elif mode == "custom":
                new_messages = [event]
            elif mode == "messages" and stream_tokens:
                chunk, metadata = event
                if "skip_stream" in metadata.get("tags", []):
                    continue
                # Skip non-LLM nodes that might send messages.
                if not isinstance(chunk, AIMessageChunk):
                    continue
                content = remove_tool_calls(chunk.content)
                # Empty content in OpenAI context usually means the model is asking for a tool to be invoked.
                if content:
                    yield convert_message_content_to_string(content)
                continue

            # Process and yield all collected messages.
            for msg in new_messages:
                try:
                    chat_message = langchain_to_chat_message(msg)
                    chat_message.run_id = str(run_id)
                except Exception as e:
                    logger.error(f"Error parsing message: {e}")
                    continue
                # Skip the input message if it's repeated by LangGraph.
                if chat_message.type == "human" and chat_message.content == message:
                    continue
                yield chat_message

    def save(self, path: str, agent_ids: Optional[List[str]] = None) -> None:
        """Save agents to disk using joblib, one ``<agent_id>.joblib`` file per agent.

        Args:
            path: Directory path where to save agents (created if missing)
            agent_ids: List of agent IDs to save. If None, saves all agents.

        """
        target_dir = Path(path)
        target_dir.mkdir(exist_ok=True, parents=True)

        if agent_ids is None:
            agents_to_save = self.agents
        else:
            wanted = set(agent_ids)
            agents_to_save = {k: v for k, v in self.agents.items() if k in wanted}

        for agent_id, agent in agents_to_save.items():
            joblib.dump(agent, target_dir / f"{agent_id}.joblib")

    def load_saved_agents(self, path: str) -> None:
        """Load saved agents from disk using joblib.

        Each ``<agent_id>.joblib`` file is registered under ``<agent_id>`` (the file stem).
        This is the same ID ``save`` used, even if it differs from ``agent.name``.

        SECURITY: ``joblib.load`` unpickles arbitrary objects. Only load directories you trust.

        Args:
            path: Directory path from which to load agents

        Raises:
            ValueError: If the default agent is not loaded afterwards.

        """
        for file in sorted(Path(path).glob("*.joblib")):
            if file.is_file():
                self.agents[file.stem] = joblib.load(file)

        self._validate_default_agent_loaded()
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
|
|
3
|
+
from langchain_core.language_models.chat_models import BaseChatModel
|
|
4
|
+
from langchain_core.messages import AIMessage
|
|
5
|
+
from langchain_core.runnables import RunnableConfig, RunnableLambda, RunnableSerializable
|
|
6
|
+
from langgraph.checkpoint.memory import MemorySaver
|
|
7
|
+
from langgraph.graph import END, MessagesState, StateGraph
|
|
8
|
+
from langgraph.types import StreamWriter
|
|
9
|
+
|
|
10
|
+
from langgraph_agent_toolkit.agents.agent import Agent
|
|
11
|
+
from langgraph_agent_toolkit.agents.blueprints.bg_task_agent.task import Task
|
|
12
|
+
from langgraph_agent_toolkit.core import settings
|
|
13
|
+
from langgraph_agent_toolkit.core.models.factory import ModelFactory
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class AgentState(MessagesState, total=False):
    """Graph state for the background-task agent.

    Extends ``MessagesState``. ``total=False`` (PEP 589) makes any keys declared here
    optional. See https://typing.readthedocs.io/en/latest/spec/typeddict.html#totality
    """
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def wrap_model(model: BaseChatModel) -> RunnableSerializable[AgentState, AIMessage]:
    """Adapt a chat model so it can be invoked directly with the graph state.

    The returned runnable extracts ``state["messages"]`` and feeds that list to ``model``.
    """

    def _extract_messages(state):
        return state["messages"]

    state_modifier = RunnableLambda(_extract_messages, name="StateModifier")
    return state_modifier | model
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
async def acall_model(state: AgentState, config: RunnableConfig) -> AgentState:
    """Call the configured chat model on the conversation so far.

    The model name comes from ``config["configurable"]["model"]``. It falls back to
    ``settings.DEFAULT_MODEL`` when that key, or the whole ``configurable`` section,
    is absent.

    Args:
        state: Current graph state; its ``messages`` are sent to the model.
        config: Run configuration, forwarded to the model call.

    Returns:
        A state update holding the model's response.

    """
    # A config without a "configurable" section previously raised KeyError here.
    configurable = config.get("configurable") or {}
    model_name = configurable.get("model", settings.DEFAULT_MODEL)

    model_runnable = wrap_model(ModelFactory.create(model_name))
    response = await model_runnable.ainvoke(state, config)

    # We return a list, because this will get added to the existing list.
    return {"messages": [response]}
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
async def bg_task(state: AgentState, writer: StreamWriter) -> AgentState:
    """Simulate two overlapping background tasks, streaming their progress via ``writer``.

    Task 1 starts, then task 2 starts, then task 1 reports progress. Task 2 then fails,
    and finally task 1 succeeds, with a 2-second pause between steps. No new messages
    are added to the state.
    """
    pause_seconds = 2
    first = Task("Simple task 1...", writer)
    second = Task("Simple task 2...", writer)

    first.start()
    await asyncio.sleep(pause_seconds)

    second.start()
    await asyncio.sleep(pause_seconds)

    first.write_data(data={"status": "Still running..."})
    await asyncio.sleep(pause_seconds)

    second.finish(result="error", data={"output": 42})
    await asyncio.sleep(pause_seconds)

    first.finish(result="success", data={"output": 42})
    return {"messages": []}
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# Define the graph: the background tasks run first, then the model, then the run ends.
agent = StateGraph(AgentState)
agent.add_node("bg_task", bg_task)
agent.add_node("model", acall_model)

agent.set_entry_point("bg_task")
agent.add_edge("bg_task", "model")
agent.add_edge("model", END)

# In-memory checkpointer: thread state lives only as long as this process.
bg_task_agent = Agent(
    name="bg-task-agent",
    description="A background task agent.",
    graph=agent.compile(checkpointer=MemorySaver()),
)
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from typing import Any, Dict, Literal, Optional
|
|
2
|
+
from uuid import uuid4
|
|
3
|
+
|
|
4
|
+
from langchain_core.messages import BaseMessage
|
|
5
|
+
from langgraph.types import StreamWriter
|
|
6
|
+
|
|
7
|
+
from langgraph_agent_toolkit.agents.blueprints.bg_task_agent.utils import CustomData
|
|
8
|
+
from langgraph_agent_toolkit.schema.task_data import TaskData
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Task:
    """A unit of background work whose progress is streamed as custom data.

    Each state change (``start``, ``write_data``, ``finish``) builds a ``TaskData`` payload.
    The payload is sent through a LangGraph ``StreamWriter`` and also returned as a
    LangChain message.
    """

    def __init__(self, task_name: str, writer: StreamWriter | None = None) -> None:
        self.name = task_name
        # Unique id, reported as ``run_id`` in every TaskData payload.
        self.id = str(uuid4())
        self.state: Literal["new", "running", "complete"] = "new"
        self.result: Literal["success", "error"] | None = None
        # Default writer, used when a method is called without an explicit one.
        self.writer = writer

    def _generate_and_dispatch_message(self, writer: StreamWriter | None, data: Dict[str, Any]) -> BaseMessage:
        """Build the current TaskData payload, dispatch it, and return it as a message.

        Raises:
            TypeError: If neither ``writer`` nor the constructor's writer is available.

        """
        writer = writer or self.writer
        if writer is None:
            # Previously this surfaced as an obscure "'NoneType' object is not callable".
            raise TypeError(
                f"Task '{self.name}' has no StreamWriter; pass one to the constructor or to this method."
            )
        task_data = TaskData(name=self.name, run_id=self.id, state=self.state, data=data)
        if self.result:
            task_data.result = self.result
        task_custom_data = CustomData(
            type=self.name,
            data=task_data.model_dump(),
        )
        task_custom_data.dispatch(writer)
        return task_custom_data.to_langchain()

    def start(self, writer: StreamWriter | None = None, data: Optional[Dict[str, Any]] = None) -> BaseMessage:
        """Announce the task in state ``"new"`` with optional ``data``."""
        self.state = "new"
        return self._generate_and_dispatch_message(writer, {} if data is None else data)

    def write_data(self, writer: StreamWriter | None = None, data: Optional[Dict[str, Any]] = None) -> BaseMessage:
        """Report progress: move to ``"running"`` and dispatch ``data``.

        Raises:
            ValueError: If the task is already complete.

        """
        if self.state == "complete":
            raise ValueError("Only incomplete tasks can output data.")
        self.state = "running"
        return self._generate_and_dispatch_message(writer, {} if data is None else data)

    def finish(
        self,
        result: Literal["success", "error"],
        writer: StreamWriter | None = None,
        data: Optional[Dict[str, Any]] = None,
    ) -> BaseMessage:
        """Mark the task ``"complete"`` with ``result`` and dispatch the final payload."""
        self.state = "complete"
        self.result = result
        return self._generate_and_dispatch_message(writer, {} if data is None else data)
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from langchain_core.messages import ChatMessage
|
|
4
|
+
from langgraph.types import StreamWriter
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class CustomData(BaseModel):
    """Arbitrary payload an agent streams to the client as a ``custom`` chat message."""

    data: dict[str, Any] = Field(description="The custom data")

    def to_langchain(self) -> ChatMessage:
        """Wrap the payload in a LangChain ``ChatMessage`` whose role is ``"custom"``."""
        payload = [self.data]
        return ChatMessage(role="custom", content=payload)

    def dispatch(self, writer: StreamWriter) -> None:
        """Send the LangChain form of this payload through ``writer``."""
        message = self.to_langchain()
        writer(message)
|
|
File without changes
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from langchain_core.messages import BaseMessage
|
|
2
|
+
from langchain_core.runnables import RunnableConfig
|
|
3
|
+
from langgraph.checkpoint.memory import MemorySaver
|
|
4
|
+
from langgraph.func import entrypoint
|
|
5
|
+
from langgraph.graph import add_messages
|
|
6
|
+
|
|
7
|
+
from langgraph_agent_toolkit.agents.agent import Agent
|
|
8
|
+
from langgraph_agent_toolkit.core import settings
|
|
9
|
+
from langgraph_agent_toolkit.core.models.factory import ModelFactory
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@entrypoint(
    # checkpointer=MemorySaver(),  # Uncomment if you want to save the state of the agent
)
async def chatbot(
    inputs: dict[str, list[BaseMessage]],
    *,
    previous: dict[str, list[BaseMessage]],
    config: RunnableConfig,
):
    """Answer the latest messages with the configured chat model.

    Args:
        inputs: Must contain ``"messages"``, the new messages for this turn.
        previous: Value saved by the previous run. When truthy, its ``"messages"`` are
            merged in front of the new ones.
        config: Run configuration. ``configurable["model"]`` selects the model and falls
            back to ``settings.DEFAULT_MODEL`` (also when ``configurable`` is missing).

    Returns:
        ``entrypoint.final`` whose value holds only the model's reply and whose saved
        state holds the whole conversation, including that reply.

    """
    messages = inputs["messages"]
    if previous:
        messages = add_messages(previous["messages"], messages)

    # A config without a "configurable" section previously raised KeyError here.
    configurable = config.get("configurable") or {}
    model = ModelFactory.create(configurable.get("model", settings.DEFAULT_MODEL))
    response = await model.ainvoke(messages)
    return entrypoint.final(value={"messages": [response]}, save={"messages": add_messages(messages, response)})
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Register the functional-API chatbot as a toolkit agent.
chatbot_agent = Agent(
    graph=chatbot,
    description="A simple chatbot.",
    name="chatbot-agent",
)
|
|
File without changes
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
import random
|
|
2
|
+
from typing import Literal
|
|
3
|
+
|
|
4
|
+
from langchain_core.messages import AIMessage
|
|
5
|
+
from langgraph.graph import START, MessagesState, StateGraph
|
|
6
|
+
from langgraph.types import Command
|
|
7
|
+
|
|
8
|
+
from langgraph_agent_toolkit.agents.agent import Agent
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AgentState(MessagesState, total=False):
    """State for the command agent; it adds nothing beyond ``MessagesState``."""
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def node_a(state: AgentState) -> Command[Literal["node_b", "node_c"]]:
    """Greet with a randomly chosen letter and route to ``node_b`` or ``node_c``.

    Instead of using a conditional edge, the state update and the routing decision are
    returned together in a single ``Command``.
    """
    print("Called A")
    choice = random.choice(["a", "b"])
    # "a" routes to node_b; the other option ("b") routes to node_c.
    destination = "node_b" if choice == "a" else "node_c"
    greeting = AIMessage(content=f"Hello {choice}")
    return Command(update={"messages": [greeting]}, goto=destination)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def node_b(state: AgentState):
    """Branch taken when ``node_a`` picks ``"a"``: append a fixed greeting."""
    print("Called B")
    reply = AIMessage(content="Hello B")
    return {"messages": [reply]}
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def node_c(state: AgentState):
    """Branch taken when ``node_a`` picks ``"b"``: append a fixed greeting."""
    print("Called C")
    reply = AIMessage(content="Hello C")
    return {"messages": [reply]}
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
builder = StateGraph(AgentState)
builder.add_node(node_a)
builder.add_node(node_b)
builder.add_node(node_c)
# The only explicit edge is the entry point. node_a routes itself with a Command.
builder.add_edge(START, "node_a")
# NOTE: there are no edges between nodes A, B and C!

command_agent = Agent(
    graph=builder.compile(),
    name="command-agent",
    description="A command agent.",
)
|
|
File without changes
|