langgraph-agent-toolkit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langgraph_agent_toolkit/__init__.py +7 -0
- langgraph_agent_toolkit/agents/__init__.py +0 -0
- langgraph_agent_toolkit/agents/agent.py +14 -0
- langgraph_agent_toolkit/agents/agent_executor.py +415 -0
- langgraph_agent_toolkit/agents/blueprints/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/bg_task_agent/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/bg_task_agent/agent.py +69 -0
- langgraph_agent_toolkit/agents/blueprints/bg_task_agent/task.py +52 -0
- langgraph_agent_toolkit/agents/blueprints/bg_task_agent/utils.py +17 -0
- langgraph_agent_toolkit/agents/blueprints/chatbot/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/chatbot/agent.py +34 -0
- langgraph_agent_toolkit/agents/blueprints/command_agent/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/command_agent/agent.py +54 -0
- langgraph_agent_toolkit/agents/blueprints/interrupt_agent/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/interrupt_agent/agent.py +140 -0
- langgraph_agent_toolkit/agents/blueprints/react/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/react/agent.py +67 -0
- langgraph_agent_toolkit/agents/blueprints/react_so/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/react_so/agent.py +39 -0
- langgraph_agent_toolkit/agents/blueprints/supervisor_agent/__init__.py +0 -0
- langgraph_agent_toolkit/agents/blueprints/supervisor_agent/agent.py +44 -0
- langgraph_agent_toolkit/agents/components/__init__.py +0 -0
- langgraph_agent_toolkit/agents/components/creators/__init__.py +4 -0
- langgraph_agent_toolkit/agents/components/creators/create_react_agent.py +459 -0
- langgraph_agent_toolkit/agents/components/tools.py +13 -0
- langgraph_agent_toolkit/agents/components/utils.py +42 -0
- langgraph_agent_toolkit/client/__init__.py +4 -0
- langgraph_agent_toolkit/client/client.py +344 -0
- langgraph_agent_toolkit/core/__init__.py +5 -0
- langgraph_agent_toolkit/core/memory/__init__.py +0 -0
- langgraph_agent_toolkit/core/memory/base.py +33 -0
- langgraph_agent_toolkit/core/memory/factory.py +30 -0
- langgraph_agent_toolkit/core/memory/postgres.py +76 -0
- langgraph_agent_toolkit/core/memory/sqlite.py +21 -0
- langgraph_agent_toolkit/core/memory/types.py +6 -0
- langgraph_agent_toolkit/core/models/__init__.py +5 -0
- langgraph_agent_toolkit/core/models/chat_openai.py +21 -0
- langgraph_agent_toolkit/core/models/factory.py +118 -0
- langgraph_agent_toolkit/core/models/fake.py +25 -0
- langgraph_agent_toolkit/core/observability/__init__.py +10 -0
- langgraph_agent_toolkit/core/observability/base.py +331 -0
- langgraph_agent_toolkit/core/observability/empty.py +67 -0
- langgraph_agent_toolkit/core/observability/factory.py +43 -0
- langgraph_agent_toolkit/core/observability/langfuse.py +118 -0
- langgraph_agent_toolkit/core/observability/langsmith.py +131 -0
- langgraph_agent_toolkit/core/observability/types.py +34 -0
- langgraph_agent_toolkit/core/prompts/__init__.py +0 -0
- langgraph_agent_toolkit/core/prompts/chat_prompt_template.py +528 -0
- langgraph_agent_toolkit/core/settings.py +164 -0
- langgraph_agent_toolkit/helper/__init__.py +0 -0
- langgraph_agent_toolkit/helper/constants.py +10 -0
- langgraph_agent_toolkit/helper/logging.py +111 -0
- langgraph_agent_toolkit/helper/types.py +7 -0
- langgraph_agent_toolkit/helper/utils.py +80 -0
- langgraph_agent_toolkit/run_agent.py +68 -0
- langgraph_agent_toolkit/run_client.py +55 -0
- langgraph_agent_toolkit/run_service.py +19 -0
- langgraph_agent_toolkit/schema/__init__.py +28 -0
- langgraph_agent_toolkit/schema/models.py +25 -0
- langgraph_agent_toolkit/schema/schema.py +210 -0
- langgraph_agent_toolkit/schema/task_data.py +72 -0
- langgraph_agent_toolkit/service/__init__.py +0 -0
- langgraph_agent_toolkit/service/exception_handlers.py +46 -0
- langgraph_agent_toolkit/service/factory.py +213 -0
- langgraph_agent_toolkit/service/handler.py +122 -0
- langgraph_agent_toolkit/service/middleware.py +18 -0
- langgraph_agent_toolkit/service/routes.py +169 -0
- langgraph_agent_toolkit/service/types.py +8 -0
- langgraph_agent_toolkit/service/utils.py +136 -0
- langgraph_agent_toolkit/streamlit_app.py +368 -0
- langgraph_agent_toolkit-0.1.0.dist-info/METADATA +424 -0
- langgraph_agent_toolkit-0.1.0.dist-info/RECORD +74 -0
- langgraph_agent_toolkit-0.1.0.dist-info/WHEEL +4 -0
- langgraph_agent_toolkit-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,344 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
from collections.abc import AsyncGenerator, Generator
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import httpx
|
|
7
|
+
|
|
8
|
+
from langgraph_agent_toolkit.schema import (
|
|
9
|
+
ChatHistory,
|
|
10
|
+
ChatHistoryInput,
|
|
11
|
+
ChatMessage,
|
|
12
|
+
Feedback,
|
|
13
|
+
ServiceMetadata,
|
|
14
|
+
StreamInput,
|
|
15
|
+
UserInput,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class AgentClientError(Exception):
    """Error raised by ``AgentClient``, e.g. when an HTTP request to the agent service fails or no valid agent is selected."""
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class AgentClient:
|
|
24
|
+
"""Client for interacting with the agent service."""
|
|
25
|
+
|
|
26
|
+
def __init__(
|
|
27
|
+
self,
|
|
28
|
+
base_url: str = "http://0.0.0.0",
|
|
29
|
+
agent: str = None,
|
|
30
|
+
timeout: float | None = None,
|
|
31
|
+
get_info: bool = True,
|
|
32
|
+
verify: bool = False,
|
|
33
|
+
) -> None:
|
|
34
|
+
"""Initialize the client.
|
|
35
|
+
|
|
36
|
+
Args:
|
|
37
|
+
base_url (str): The base URL of the agent service.
|
|
38
|
+
agent (str): The name of the default agent to use.
|
|
39
|
+
timeout (float, optional): The timeout for requests.
|
|
40
|
+
get_info (bool, optional): Whether to fetch agent information on init.
|
|
41
|
+
Default: True
|
|
42
|
+
verify (bool, optional): Whether to verify the agent information.
|
|
43
|
+
Default: False
|
|
44
|
+
|
|
45
|
+
"""
|
|
46
|
+
self.base_url = base_url
|
|
47
|
+
self.auth_secret = os.getenv("AUTH_SECRET")
|
|
48
|
+
self.timeout = timeout
|
|
49
|
+
self.info: ServiceMetadata | None = None
|
|
50
|
+
self.agent: str | None = None
|
|
51
|
+
if get_info:
|
|
52
|
+
self.retrieve_info()
|
|
53
|
+
if agent:
|
|
54
|
+
self.update_agent(agent, verify=verify)
|
|
55
|
+
|
|
56
|
+
@property
|
|
57
|
+
def _headers(self) -> dict[str, str]:
|
|
58
|
+
headers = {}
|
|
59
|
+
if self.auth_secret:
|
|
60
|
+
headers["Authorization"] = f"Bearer {self.auth_secret}"
|
|
61
|
+
return headers
|
|
62
|
+
|
|
63
|
+
def retrieve_info(self) -> None:
|
|
64
|
+
try:
|
|
65
|
+
response = httpx.get(
|
|
66
|
+
f"{self.base_url}/info",
|
|
67
|
+
headers=self._headers,
|
|
68
|
+
timeout=self.timeout,
|
|
69
|
+
)
|
|
70
|
+
response.raise_for_status()
|
|
71
|
+
except httpx.HTTPError as e:
|
|
72
|
+
raise AgentClientError(f"Error getting service info: {e}")
|
|
73
|
+
|
|
74
|
+
self.info: ServiceMetadata = ServiceMetadata.model_validate(response.json())
|
|
75
|
+
if not self.agent or self.agent not in [a.key for a in self.info.agents]:
|
|
76
|
+
self.agent = self.info.default_agent
|
|
77
|
+
|
|
78
|
+
def update_agent(self, agent: str, verify: bool = True) -> None:
|
|
79
|
+
if verify:
|
|
80
|
+
if not self.info:
|
|
81
|
+
self.retrieve_info()
|
|
82
|
+
agent_keys = [a.key for a in self.info.agents]
|
|
83
|
+
if agent not in agent_keys:
|
|
84
|
+
raise AgentClientError(f"Agent {agent} not found in available agents: {', '.join(agent_keys)}")
|
|
85
|
+
self.agent = agent
|
|
86
|
+
|
|
87
|
+
async def ainvoke(
|
|
88
|
+
self,
|
|
89
|
+
message: str,
|
|
90
|
+
model: str | None = None,
|
|
91
|
+
thread_id: str | None = None,
|
|
92
|
+
agent_config: dict[str, Any] | None = None,
|
|
93
|
+
) -> ChatMessage:
|
|
94
|
+
"""Invoke the agent asynchronously. Only the final message is returned.
|
|
95
|
+
|
|
96
|
+
Args:
|
|
97
|
+
message (str): The message to send to the agent
|
|
98
|
+
model (str, optional): LLM model to use for the agent
|
|
99
|
+
thread_id (str, optional): Thread ID for continuing a conversation
|
|
100
|
+
agent_config (dict[str, Any], optional): Additional configuration to pass through to the agent
|
|
101
|
+
|
|
102
|
+
Returns:
|
|
103
|
+
AnyMessage: The response from the agent
|
|
104
|
+
|
|
105
|
+
"""
|
|
106
|
+
if not self.agent:
|
|
107
|
+
raise AgentClientError("No agent selected. Use update_agent() to select an agent.")
|
|
108
|
+
request = UserInput(message=message)
|
|
109
|
+
if thread_id:
|
|
110
|
+
request.thread_id = thread_id
|
|
111
|
+
if model:
|
|
112
|
+
request.model = model
|
|
113
|
+
if agent_config:
|
|
114
|
+
request.agent_config = agent_config
|
|
115
|
+
async with httpx.AsyncClient() as client:
|
|
116
|
+
try:
|
|
117
|
+
response = await client.post(
|
|
118
|
+
f"{self.base_url}/{self.agent}/invoke",
|
|
119
|
+
json=request.model_dump(),
|
|
120
|
+
headers=self._headers,
|
|
121
|
+
timeout=self.timeout,
|
|
122
|
+
)
|
|
123
|
+
response.raise_for_status()
|
|
124
|
+
except httpx.HTTPError as e:
|
|
125
|
+
raise AgentClientError(f"Error: {e}")
|
|
126
|
+
|
|
127
|
+
return ChatMessage.model_validate(response.json())
|
|
128
|
+
|
|
129
|
+
def invoke(
|
|
130
|
+
self,
|
|
131
|
+
message: str,
|
|
132
|
+
model: str | None = None,
|
|
133
|
+
thread_id: str | None = None,
|
|
134
|
+
agent_config: dict[str, Any] | None = None,
|
|
135
|
+
) -> ChatMessage:
|
|
136
|
+
"""Invoke the agent synchronously. Only the final message is returned.
|
|
137
|
+
|
|
138
|
+
Args:
|
|
139
|
+
message (str): The message to send to the agent
|
|
140
|
+
model (str, optional): LLM model to use for the agent
|
|
141
|
+
thread_id (str, optional): Thread ID for continuing a conversation
|
|
142
|
+
agent_config (dict[str, Any], optional): Additional configuration to pass through to the agent
|
|
143
|
+
|
|
144
|
+
Returns:
|
|
145
|
+
ChatMessage: The response from the agent
|
|
146
|
+
|
|
147
|
+
"""
|
|
148
|
+
if not self.agent:
|
|
149
|
+
raise AgentClientError("No agent selected. Use update_agent() to select an agent.")
|
|
150
|
+
request = UserInput(message=message)
|
|
151
|
+
if thread_id:
|
|
152
|
+
request.thread_id = thread_id
|
|
153
|
+
if model:
|
|
154
|
+
request.model = model
|
|
155
|
+
if agent_config:
|
|
156
|
+
request.agent_config = agent_config
|
|
157
|
+
try:
|
|
158
|
+
response = httpx.post(
|
|
159
|
+
f"{self.base_url}/{self.agent}/invoke",
|
|
160
|
+
json=request.model_dump(),
|
|
161
|
+
headers=self._headers,
|
|
162
|
+
timeout=self.timeout,
|
|
163
|
+
)
|
|
164
|
+
response.raise_for_status()
|
|
165
|
+
except httpx.HTTPError as e:
|
|
166
|
+
raise AgentClientError(f"Error: {e}")
|
|
167
|
+
|
|
168
|
+
return ChatMessage.model_validate(response.json())
|
|
169
|
+
|
|
170
|
+
def _parse_stream_line(self, line: str) -> ChatMessage | str | None:
|
|
171
|
+
line = line.strip()
|
|
172
|
+
if line.startswith("data: "):
|
|
173
|
+
data = line[6:]
|
|
174
|
+
if data == "[DONE]":
|
|
175
|
+
return None
|
|
176
|
+
try:
|
|
177
|
+
parsed = json.loads(data)
|
|
178
|
+
except Exception as e:
|
|
179
|
+
raise Exception(f"Error JSON parsing message from server: {e}")
|
|
180
|
+
match parsed["type"]:
|
|
181
|
+
case "message":
|
|
182
|
+
# Convert the JSON formatted message to an AnyMessage
|
|
183
|
+
try:
|
|
184
|
+
return ChatMessage.model_validate(parsed["content"])
|
|
185
|
+
except Exception as e:
|
|
186
|
+
raise Exception(f"Server returned invalid message: {e}")
|
|
187
|
+
case "token":
|
|
188
|
+
# Yield the str token directly
|
|
189
|
+
return parsed["content"]
|
|
190
|
+
case "error":
|
|
191
|
+
raise Exception(parsed["content"])
|
|
192
|
+
return None
|
|
193
|
+
|
|
194
|
+
def stream(
|
|
195
|
+
self,
|
|
196
|
+
message: str,
|
|
197
|
+
model: str | None = None,
|
|
198
|
+
thread_id: str | None = None,
|
|
199
|
+
agent_config: dict[str, Any] | None = None,
|
|
200
|
+
stream_tokens: bool = True,
|
|
201
|
+
) -> Generator[ChatMessage | str, None, None]:
|
|
202
|
+
"""Stream the agent's response synchronously.
|
|
203
|
+
|
|
204
|
+
Each intermediate message of the agent process is yielded as a ChatMessage.
|
|
205
|
+
If stream_tokens is True (the default value), the response will also yield
|
|
206
|
+
content tokens from streaming models as they are generated.
|
|
207
|
+
|
|
208
|
+
Args:
|
|
209
|
+
message (str): The message to send to the agent
|
|
210
|
+
model (str, optional): LLM model to use for the agent
|
|
211
|
+
thread_id (str, optional): Thread ID for continuing a conversation
|
|
212
|
+
agent_config (dict[str, Any], optional): Additional configuration to pass through to the agent
|
|
213
|
+
stream_tokens (bool, optional): Stream tokens as they are generated
|
|
214
|
+
Default: True
|
|
215
|
+
|
|
216
|
+
Returns:
|
|
217
|
+
Generator[ChatMessage | str, None, None]: The response from the agent
|
|
218
|
+
|
|
219
|
+
"""
|
|
220
|
+
if not self.agent:
|
|
221
|
+
raise AgentClientError("No agent selected. Use update_agent() to select an agent.")
|
|
222
|
+
request = StreamInput(message=message, stream_tokens=stream_tokens)
|
|
223
|
+
if thread_id:
|
|
224
|
+
request.thread_id = thread_id
|
|
225
|
+
if model:
|
|
226
|
+
request.model = model
|
|
227
|
+
if agent_config:
|
|
228
|
+
request.agent_config = agent_config
|
|
229
|
+
try:
|
|
230
|
+
with httpx.stream(
|
|
231
|
+
"POST",
|
|
232
|
+
f"{self.base_url}/{self.agent}/stream",
|
|
233
|
+
json=request.model_dump(),
|
|
234
|
+
headers=self._headers,
|
|
235
|
+
timeout=self.timeout,
|
|
236
|
+
) as response:
|
|
237
|
+
response.raise_for_status()
|
|
238
|
+
for line in response.iter_lines():
|
|
239
|
+
if line.strip():
|
|
240
|
+
parsed = self._parse_stream_line(line)
|
|
241
|
+
if parsed is None:
|
|
242
|
+
break
|
|
243
|
+
yield parsed
|
|
244
|
+
except httpx.HTTPError as e:
|
|
245
|
+
raise AgentClientError(f"Error: {e}")
|
|
246
|
+
|
|
247
|
+
async def astream(
|
|
248
|
+
self,
|
|
249
|
+
message: str,
|
|
250
|
+
model: str | None = None,
|
|
251
|
+
thread_id: str | None = None,
|
|
252
|
+
agent_config: dict[str, Any] | None = None,
|
|
253
|
+
stream_tokens: bool = True,
|
|
254
|
+
) -> AsyncGenerator[ChatMessage | str, None]:
|
|
255
|
+
"""Stream the agent's response asynchronously.
|
|
256
|
+
|
|
257
|
+
Each intermediate message of the agent process is yielded as an AnyMessage.
|
|
258
|
+
If stream_tokens is True (the default value), the response will also yield
|
|
259
|
+
content tokens from streaming modelsas they are generated.
|
|
260
|
+
|
|
261
|
+
Args:
|
|
262
|
+
message (str): The message to send to the agent
|
|
263
|
+
model (str, optional): LLM model to use for the agent
|
|
264
|
+
thread_id (str, optional): Thread ID for continuing a conversation
|
|
265
|
+
agent_config (dict[str, Any], optional): Additional configuration to pass through to the agent
|
|
266
|
+
stream_tokens (bool, optional): Stream tokens as they are generated
|
|
267
|
+
Default: True
|
|
268
|
+
|
|
269
|
+
Returns:
|
|
270
|
+
AsyncGenerator[ChatMessage | str, None]: The response from the agent
|
|
271
|
+
|
|
272
|
+
"""
|
|
273
|
+
if not self.agent:
|
|
274
|
+
raise AgentClientError("No agent selected. Use update_agent() to select an agent.")
|
|
275
|
+
request = StreamInput(message=message, stream_tokens=stream_tokens)
|
|
276
|
+
if thread_id:
|
|
277
|
+
request.thread_id = thread_id
|
|
278
|
+
if model:
|
|
279
|
+
request.model = model
|
|
280
|
+
if agent_config:
|
|
281
|
+
request.agent_config = agent_config
|
|
282
|
+
async with httpx.AsyncClient() as client:
|
|
283
|
+
try:
|
|
284
|
+
async with client.stream(
|
|
285
|
+
"POST",
|
|
286
|
+
f"{self.base_url}/{self.agent}/stream",
|
|
287
|
+
json=request.model_dump(),
|
|
288
|
+
headers=self._headers,
|
|
289
|
+
timeout=self.timeout,
|
|
290
|
+
) as response:
|
|
291
|
+
response.raise_for_status()
|
|
292
|
+
async for line in response.aiter_lines():
|
|
293
|
+
if line.strip():
|
|
294
|
+
parsed = self._parse_stream_line(line)
|
|
295
|
+
if parsed is None:
|
|
296
|
+
break
|
|
297
|
+
yield parsed
|
|
298
|
+
except httpx.HTTPError as e:
|
|
299
|
+
raise AgentClientError(f"Error: {e}")
|
|
300
|
+
|
|
301
|
+
async def acreate_feedback(self, run_id: str, key: str, score: float, kwargs: dict[str, Any] = {}) -> None:
|
|
302
|
+
"""Create a feedback record for a run.
|
|
303
|
+
|
|
304
|
+
This is a simple wrapper for the LangSmith create_feedback API, so the
|
|
305
|
+
credentials can be stored and managed in the service rather than the client.
|
|
306
|
+
See: https://api.smith.langchain.com/redoc#tag/feedback/operation/create_feedback_api_v1_feedback_post
|
|
307
|
+
"""
|
|
308
|
+
request = Feedback(run_id=run_id, key=key, score=score, kwargs=kwargs)
|
|
309
|
+
async with httpx.AsyncClient() as client:
|
|
310
|
+
try:
|
|
311
|
+
response = await client.post(
|
|
312
|
+
f"{self.base_url}/feedback",
|
|
313
|
+
json=request.model_dump(),
|
|
314
|
+
headers=self._headers,
|
|
315
|
+
timeout=self.timeout,
|
|
316
|
+
)
|
|
317
|
+
response.raise_for_status()
|
|
318
|
+
response.json()
|
|
319
|
+
except httpx.HTTPError as e:
|
|
320
|
+
raise AgentClientError(f"Error: {e}")
|
|
321
|
+
|
|
322
|
+
def get_history(
|
|
323
|
+
self,
|
|
324
|
+
thread_id: str,
|
|
325
|
+
) -> ChatHistory:
|
|
326
|
+
"""Get chat history.
|
|
327
|
+
|
|
328
|
+
Args:
|
|
329
|
+
thread_id (str, optional): Thread ID for identifying a conversation
|
|
330
|
+
|
|
331
|
+
"""
|
|
332
|
+
request = ChatHistoryInput(thread_id=thread_id)
|
|
333
|
+
try:
|
|
334
|
+
response = httpx.post(
|
|
335
|
+
f"{self.base_url}/history",
|
|
336
|
+
json=request.model_dump(),
|
|
337
|
+
headers=self._headers,
|
|
338
|
+
timeout=self.timeout,
|
|
339
|
+
)
|
|
340
|
+
response.raise_for_status()
|
|
341
|
+
except httpx.HTTPError as e:
|
|
342
|
+
raise AgentClientError(f"Error: {e}")
|
|
343
|
+
|
|
344
|
+
return ChatHistory.model_validate(response.json())
|
|
File without changes
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from contextlib import AbstractAsyncContextManager
|
|
3
|
+
from typing import Any, TypeVar
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# Type parameter for the saver yielded by a backend's async context manager.
T = TypeVar("T", bound=Any)


class BaseMemoryBackend(ABC):
    """Abstract interface implemented by every memory backend.

    Concrete backends check their own settings in :meth:`validate_config` and
    hand out a checkpoint saver through :meth:`get_checkpoint_saver`.
    """

    @abstractmethod
    def validate_config(self) -> bool:
        """Check that the settings this backend needs are present.

        Returns:
            True when the configuration is usable.

        Raises:
            ValueError: When a required setting is absent.

        """
        ...

    @abstractmethod
    def get_checkpoint_saver(self) -> AbstractAsyncContextManager[T]:
        """Build the checkpoint saver for this backend.

        Returns:
            An async context manager providing the configured checkpoint saver.

        """
        ...
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from langgraph_agent_toolkit.core.memory.base import BaseMemoryBackend
|
|
2
|
+
from langgraph_agent_toolkit.core.memory.postgres import PostgresMemoryBackend
|
|
3
|
+
from langgraph_agent_toolkit.core.memory.sqlite import SQLiteMemoryBackend
|
|
4
|
+
from langgraph_agent_toolkit.core.memory.types import MemoryBackends
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class MemoryFactory:
    """Build memory backend instances from a ``MemoryBackends`` value."""

    @staticmethod
    def create(backend: MemoryBackends) -> BaseMemoryBackend:
        """Instantiate the memory backend matching ``backend``.

        Args:
            backend: Which memory backend to build.

        Returns:
            A new instance of the matching backend class.

        Raises:
            ValueError: If ``backend`` does not name a supported backend.

        """
        if backend == MemoryBackends.POSTGRES:
            return PostgresMemoryBackend()
        if backend == MemoryBackends.SQLITE:
            return SQLiteMemoryBackend()
        raise ValueError(f"Unsupported memory backend: {backend}")
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
from collections.abc import AsyncGenerator
|
|
2
|
+
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
|
3
|
+
|
|
4
|
+
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
|
5
|
+
from psycopg.rows import dict_row
|
|
6
|
+
from psycopg_pool import AsyncConnectionPool
|
|
7
|
+
|
|
8
|
+
from langgraph_agent_toolkit.core.memory.base import BaseMemoryBackend
|
|
9
|
+
from langgraph_agent_toolkit.core.settings import settings
|
|
10
|
+
from langgraph_agent_toolkit.helper.logging import logger
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class PostgresMemoryBackend(BaseMemoryBackend):
    """PostgreSQL implementation of memory backend.

    Connection details and pool sizing are read from ``settings``; checkpoints are
    handled by an ``AsyncPostgresSaver`` built on a psycopg ``AsyncConnectionPool``.
    """

    def validate_config(self) -> bool:
        """Validate that all required PostgreSQL configuration is present.

        Returns:
            True when every required setting is set and non-empty.

        Raises:
            ValueError: If one or more required settings are missing or empty.

        """
        required_vars = [
            "POSTGRES_USER",
            "POSTGRES_PASSWORD",
            "POSTGRES_HOST",
            "POSTGRES_PORT",
            "POSTGRES_DB",
        ]

        missing = [var for var in required_vars if not getattr(settings, var, None)]
        if missing:
            raise ValueError(
                f"Missing required PostgreSQL configuration: {', '.join(missing)}. "
                "These environment variables must be set to use PostgreSQL persistence."
            )
        return True

    def get_connection_string(self) -> str:
        """Build and return the PostgreSQL connection URI from settings.

        The user name and password are percent-encoded (with no "safe" characters)
        so reserved URI characters such as ``@``, ``:``, ``/``, ``#`` or ``?`` inside
        credentials cannot corrupt the structure of the URI.
        """
        from urllib.parse import quote

        user = quote(str(settings.POSTGRES_USER), safe="")
        password = quote(settings.POSTGRES_PASSWORD.get_secret_value(), safe="")
        return (
            f"postgresql://{user}:"
            f"{password}@"
            f"{settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/"
            f"{settings.POSTGRES_DB}"
        )

    @asynccontextmanager
    async def get_saver(self) -> AsyncGenerator[AsyncPostgresSaver, None]:
        """Asynchronous context manager for acquiring and releasing the PostgreSQL connection pool.

        The pool is opened with autocommit, ``prepare_threshold=0`` and the
        ``dict_row`` row factory, and is closed when the ``async with`` block exits.

        Yields:
            AsyncPostgresSaver: The database saver instance

        """
        conn_string = self.get_connection_string()

        logger.info(
            f"Creating PostgreSQL connection pool: min_size={settings.POSTGRES_MIN_SIZE}, "
            f"max_size={settings.POSTGRES_POOL_SIZE}, max_idle={settings.POSTGRES_MAX_IDLE}"
        )

        # Use AsyncConnectionPool as an async context manager so it is always closed
        async with AsyncConnectionPool(
            conn_string,
            min_size=settings.POSTGRES_MIN_SIZE,
            max_size=settings.POSTGRES_POOL_SIZE,
            max_idle=settings.POSTGRES_MAX_IDLE,
            kwargs=dict(autocommit=True, prepare_threshold=0, row_factory=dict_row),
        ) as pool:
            logger.info("PostgreSQL connection pool opened successfully")

            try:
                yield AsyncPostgresSaver(conn=pool)
            finally:
                logger.info("PostgreSQL connection pool will be closed automatically")

    def get_checkpoint_saver(self) -> AbstractAsyncContextManager[AsyncPostgresSaver]:
        """Validate configuration and return the saver context manager.

        Raises:
            ValueError: If required PostgreSQL settings are missing.

        """
        self.validate_config()
        return self.get_saver()
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from contextlib import AbstractAsyncContextManager
|
|
2
|
+
|
|
3
|
+
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
|
4
|
+
|
|
5
|
+
from langgraph_agent_toolkit.core.memory.base import BaseMemoryBackend
|
|
6
|
+
from langgraph_agent_toolkit.core.settings import settings
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class SQLiteMemoryBackend(BaseMemoryBackend):
    """Memory backend that persists checkpoints in a SQLite database file."""

    def validate_config(self) -> bool:
        """Ensure ``settings.SQLITE_DB_PATH`` is set to a non-empty value.

        Returns:
            True when the path is configured.

        Raises:
            ValueError: If the path is missing or empty.

        """
        db_path = getattr(settings, "SQLITE_DB_PATH", None)
        if db_path:
            return True
        raise ValueError("Missing SQLITE_DB_PATH configuration. This must be set to use SQLite persistence.")

    def get_checkpoint_saver(self) -> AbstractAsyncContextManager[AsyncSqliteSaver]:
        """Validate the configuration, then build an ``AsyncSqliteSaver`` for the configured path."""
        self.validate_config()
        db_path = settings.SQLITE_DB_PATH
        return AsyncSqliteSaver.from_conn_string(db_path)
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from typing import Dict, Optional, Union
|
|
2
|
+
|
|
3
|
+
import openai
|
|
4
|
+
from langchain_core.outputs import ChatResult
|
|
5
|
+
from langchain_openai import ChatOpenAI
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ChatOpenAIPatched(ChatOpenAI):
    """ChatOpenAI that normalizes assistant message roles before building results.

    Any choice whose message role is a string starting with ``"assistant"`` has the
    role rewritten to exactly ``"assistant"`` before the response is passed on to
    ``ChatOpenAI._create_chat_result``. Both response shapes named in the
    ``response`` annotation (``dict`` and ``openai.BaseModel``) are handled.
    """

    def _create_chat_result(
        self,
        response: Union[dict, openai.BaseModel],
        generation_info: Optional[Dict] = None,
    ) -> ChatResult:
        """Normalize assistant roles in ``response`` in place, then defer to ChatOpenAI.

        Args:
            response: Raw completion response, either a ``dict`` or an ``openai`` model.
            generation_info: Extra generation info, forwarded unchanged.

        Returns:
            ChatResult: The result built by ``ChatOpenAI._create_chat_result`` from
            the normalized response.

        """

        def _needs_fix(role: object) -> bool:
            # Roles may be None or vendor-specific strings such as "assistant..."
            return isinstance(role, str) and role.startswith("assistant")

        if isinstance(response, dict):
            choices = response.get("choices") or []
        else:
            choices = getattr(response, "choices", None) or []

        for choice in choices:
            if isinstance(choice, dict):
                message = choice.get("message")
                if isinstance(message, dict) and _needs_fix(message.get("role")):
                    message["role"] = "assistant"
            else:
                message = getattr(choice, "message", None)
                if message is not None and _needs_fix(getattr(message, "role", None)):
                    message.role = "assistant"

        return super()._create_chat_result(response, generation_info)
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
from functools import cache
|
|
3
|
+
from typing import (
|
|
4
|
+
Any,
|
|
5
|
+
List,
|
|
6
|
+
Literal,
|
|
7
|
+
Optional,
|
|
8
|
+
Tuple,
|
|
9
|
+
Union,
|
|
10
|
+
cast,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
from langchain.chat_models.base import _ConfigurableModel, _init_chat_model_helper
|
|
14
|
+
from langchain_core.language_models import BaseChatModel
|
|
15
|
+
from langchain_core.runnables import RunnableSerializable
|
|
16
|
+
from typing_extensions import TypeAlias
|
|
17
|
+
|
|
18
|
+
from langgraph_agent_toolkit.core.models import ChatOpenAIPatched, FakeToolModel
|
|
19
|
+
from langgraph_agent_toolkit.core.settings import settings
|
|
20
|
+
from langgraph_agent_toolkit.helper.constants import DEFAULT_OPENAI_COMPATIBLE_MODEL_PARAMS
|
|
21
|
+
from langgraph_agent_toolkit.schema.models import (
|
|
22
|
+
AllModelEnum,
|
|
23
|
+
FakeModelName,
|
|
24
|
+
OpenAICompatibleName,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
ModelT: TypeAlias = ChatOpenAIPatched | FakeToolModel | RunnableSerializable | _ConfigurableModel
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ModelFactory:
|
|
32
|
+
"""Factory for creating model instances."""
|
|
33
|
+
|
|
34
|
+
# Map model enum names to their respective API model names
|
|
35
|
+
_MODEL_TABLE = {
|
|
36
|
+
OpenAICompatibleName.OPENAI_COMPATIBLE: settings.COMPATIBLE_MODEL,
|
|
37
|
+
FakeModelName.FAKE: "fake",
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
@staticmethod
|
|
41
|
+
def __init_chat_model_helper(model: str, *, model_provider: Optional[str] = None, **kwargs: Any) -> BaseChatModel:
|
|
42
|
+
if model_provider == "openai":
|
|
43
|
+
return ChatOpenAIPatched(model_name=model, **kwargs)
|
|
44
|
+
else:
|
|
45
|
+
return _init_chat_model_helper(model, model_provider=model_provider, **kwargs)
|
|
46
|
+
|
|
47
|
+
@staticmethod
|
|
48
|
+
def init_chat_model(
|
|
49
|
+
model: Optional[str] = None,
|
|
50
|
+
*,
|
|
51
|
+
model_provider: Optional[str] = None,
|
|
52
|
+
configurable_fields: Optional[Union[Literal["any"], List[str], Tuple[str, ...]]] = None,
|
|
53
|
+
config_prefix: Optional[str] = None,
|
|
54
|
+
**kwargs: Any,
|
|
55
|
+
) -> Union[BaseChatModel, _ConfigurableModel]:
|
|
56
|
+
if not model and not configurable_fields:
|
|
57
|
+
configurable_fields = ("model", "model_provider")
|
|
58
|
+
config_prefix = config_prefix or ""
|
|
59
|
+
|
|
60
|
+
if config_prefix and not configurable_fields:
|
|
61
|
+
warnings.warn(
|
|
62
|
+
f"{config_prefix=} has been set but no fields are configurable. Set "
|
|
63
|
+
f"`configurable_fields=(...)` to specify the model params that are "
|
|
64
|
+
f"configurable."
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
if not configurable_fields:
|
|
68
|
+
return ModelFactory.__init_chat_model_helper(cast(str, model), model_provider=model_provider, **kwargs)
|
|
69
|
+
else:
|
|
70
|
+
if model:
|
|
71
|
+
kwargs["model"] = model
|
|
72
|
+
if model_provider:
|
|
73
|
+
kwargs["model_provider"] = model_provider
|
|
74
|
+
return _ConfigurableModel(
|
|
75
|
+
default_config=kwargs,
|
|
76
|
+
config_prefix=config_prefix,
|
|
77
|
+
configurable_fields=configurable_fields,
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
@staticmethod
|
|
81
|
+
@cache
|
|
82
|
+
def create(model_name: AllModelEnum) -> ModelT:
|
|
83
|
+
"""Create and return a model instance.
|
|
84
|
+
|
|
85
|
+
Args:
|
|
86
|
+
model_name: The model to create from AllModelEnum
|
|
87
|
+
|
|
88
|
+
Returns:
|
|
89
|
+
An instance of the requested model
|
|
90
|
+
|
|
91
|
+
Raises:
|
|
92
|
+
ValueError: If the requested model is not supported
|
|
93
|
+
|
|
94
|
+
"""
|
|
95
|
+
api_model_name = ModelFactory._MODEL_TABLE.get(model_name)
|
|
96
|
+
if not api_model_name:
|
|
97
|
+
raise ValueError(f"Unsupported model: {model_name}")
|
|
98
|
+
|
|
99
|
+
match model_name:
|
|
100
|
+
case name if name in OpenAICompatibleName:
|
|
101
|
+
if not settings.COMPATIBLE_BASE_URL or not settings.COMPATIBLE_MODEL:
|
|
102
|
+
raise ValueError("OpenAICompatible base url and endpoint must be configured")
|
|
103
|
+
|
|
104
|
+
model = ModelFactory.init_chat_model(
|
|
105
|
+
model=settings.COMPATIBLE_MODEL,
|
|
106
|
+
model_provider="openai",
|
|
107
|
+
configurable_fields=("temperature", "max_tokens", "top_p", "streaming"),
|
|
108
|
+
config_prefix="agent",
|
|
109
|
+
openai_api_base=settings.COMPATIBLE_BASE_URL,
|
|
110
|
+
openai_api_key=settings.COMPATIBLE_API_KEY,
|
|
111
|
+
**DEFAULT_OPENAI_COMPATIBLE_MODEL_PARAMS,
|
|
112
|
+
)
|
|
113
|
+
|
|
114
|
+
return model
|
|
115
|
+
case name if name in FakeModelName:
|
|
116
|
+
return FakeToolModel(responses=["This is a test response from the fake model."])
|
|
117
|
+
case _:
|
|
118
|
+
raise ValueError(f"Unsupported model: {model_name}")
|