langgraph-agent-toolkit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. langgraph_agent_toolkit/__init__.py +7 -0
  2. langgraph_agent_toolkit/agents/__init__.py +0 -0
  3. langgraph_agent_toolkit/agents/agent.py +14 -0
  4. langgraph_agent_toolkit/agents/agent_executor.py +415 -0
  5. langgraph_agent_toolkit/agents/blueprints/__init__.py +0 -0
  6. langgraph_agent_toolkit/agents/blueprints/bg_task_agent/__init__.py +0 -0
  7. langgraph_agent_toolkit/agents/blueprints/bg_task_agent/agent.py +69 -0
  8. langgraph_agent_toolkit/agents/blueprints/bg_task_agent/task.py +52 -0
  9. langgraph_agent_toolkit/agents/blueprints/bg_task_agent/utils.py +17 -0
  10. langgraph_agent_toolkit/agents/blueprints/chatbot/__init__.py +0 -0
  11. langgraph_agent_toolkit/agents/blueprints/chatbot/agent.py +34 -0
  12. langgraph_agent_toolkit/agents/blueprints/command_agent/__init__.py +0 -0
  13. langgraph_agent_toolkit/agents/blueprints/command_agent/agent.py +54 -0
  14. langgraph_agent_toolkit/agents/blueprints/interrupt_agent/__init__.py +0 -0
  15. langgraph_agent_toolkit/agents/blueprints/interrupt_agent/agent.py +140 -0
  16. langgraph_agent_toolkit/agents/blueprints/react/__init__.py +0 -0
  17. langgraph_agent_toolkit/agents/blueprints/react/agent.py +67 -0
  18. langgraph_agent_toolkit/agents/blueprints/react_so/__init__.py +0 -0
  19. langgraph_agent_toolkit/agents/blueprints/react_so/agent.py +39 -0
  20. langgraph_agent_toolkit/agents/blueprints/supervisor_agent/__init__.py +0 -0
  21. langgraph_agent_toolkit/agents/blueprints/supervisor_agent/agent.py +44 -0
  22. langgraph_agent_toolkit/agents/components/__init__.py +0 -0
  23. langgraph_agent_toolkit/agents/components/creators/__init__.py +4 -0
  24. langgraph_agent_toolkit/agents/components/creators/create_react_agent.py +459 -0
  25. langgraph_agent_toolkit/agents/components/tools.py +13 -0
  26. langgraph_agent_toolkit/agents/components/utils.py +42 -0
  27. langgraph_agent_toolkit/client/__init__.py +4 -0
  28. langgraph_agent_toolkit/client/client.py +344 -0
  29. langgraph_agent_toolkit/core/__init__.py +5 -0
  30. langgraph_agent_toolkit/core/memory/__init__.py +0 -0
  31. langgraph_agent_toolkit/core/memory/base.py +33 -0
  32. langgraph_agent_toolkit/core/memory/factory.py +30 -0
  33. langgraph_agent_toolkit/core/memory/postgres.py +76 -0
  34. langgraph_agent_toolkit/core/memory/sqlite.py +21 -0
  35. langgraph_agent_toolkit/core/memory/types.py +6 -0
  36. langgraph_agent_toolkit/core/models/__init__.py +5 -0
  37. langgraph_agent_toolkit/core/models/chat_openai.py +21 -0
  38. langgraph_agent_toolkit/core/models/factory.py +118 -0
  39. langgraph_agent_toolkit/core/models/fake.py +25 -0
  40. langgraph_agent_toolkit/core/observability/__init__.py +10 -0
  41. langgraph_agent_toolkit/core/observability/base.py +331 -0
  42. langgraph_agent_toolkit/core/observability/empty.py +67 -0
  43. langgraph_agent_toolkit/core/observability/factory.py +43 -0
  44. langgraph_agent_toolkit/core/observability/langfuse.py +118 -0
  45. langgraph_agent_toolkit/core/observability/langsmith.py +131 -0
  46. langgraph_agent_toolkit/core/observability/types.py +34 -0
  47. langgraph_agent_toolkit/core/prompts/__init__.py +0 -0
  48. langgraph_agent_toolkit/core/prompts/chat_prompt_template.py +528 -0
  49. langgraph_agent_toolkit/core/settings.py +164 -0
  50. langgraph_agent_toolkit/helper/__init__.py +0 -0
  51. langgraph_agent_toolkit/helper/constants.py +10 -0
  52. langgraph_agent_toolkit/helper/logging.py +111 -0
  53. langgraph_agent_toolkit/helper/types.py +7 -0
  54. langgraph_agent_toolkit/helper/utils.py +80 -0
  55. langgraph_agent_toolkit/run_agent.py +68 -0
  56. langgraph_agent_toolkit/run_client.py +55 -0
  57. langgraph_agent_toolkit/run_service.py +19 -0
  58. langgraph_agent_toolkit/schema/__init__.py +28 -0
  59. langgraph_agent_toolkit/schema/models.py +25 -0
  60. langgraph_agent_toolkit/schema/schema.py +210 -0
  61. langgraph_agent_toolkit/schema/task_data.py +72 -0
  62. langgraph_agent_toolkit/service/__init__.py +0 -0
  63. langgraph_agent_toolkit/service/exception_handlers.py +46 -0
  64. langgraph_agent_toolkit/service/factory.py +213 -0
  65. langgraph_agent_toolkit/service/handler.py +122 -0
  66. langgraph_agent_toolkit/service/middleware.py +18 -0
  67. langgraph_agent_toolkit/service/routes.py +169 -0
  68. langgraph_agent_toolkit/service/types.py +8 -0
  69. langgraph_agent_toolkit/service/utils.py +136 -0
  70. langgraph_agent_toolkit/streamlit_app.py +368 -0
  71. langgraph_agent_toolkit-0.1.0.dist-info/METADATA +424 -0
  72. langgraph_agent_toolkit-0.1.0.dist-info/RECORD +74 -0
  73. langgraph_agent_toolkit-0.1.0.dist-info/WHEEL +4 -0
  74. langgraph_agent_toolkit-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,344 @@
1
+ import json
2
+ import os
3
+ from collections.abc import AsyncGenerator, Generator
4
+ from typing import Any
5
+
6
+ import httpx
7
+
8
+ from langgraph_agent_toolkit.schema import (
9
+ ChatHistory,
10
+ ChatHistoryInput,
11
+ ChatMessage,
12
+ Feedback,
13
+ ServiceMetadata,
14
+ StreamInput,
15
+ UserInput,
16
+ )
17
+
18
+
19
class AgentClientError(Exception):
    """Raised by AgentClient when a service request fails or no usable agent is selected."""
21
+
22
+
23
+ class AgentClient:
24
+ """Client for interacting with the agent service."""
25
+
26
+ def __init__(
27
+ self,
28
+ base_url: str = "http://0.0.0.0",
29
+ agent: str = None,
30
+ timeout: float | None = None,
31
+ get_info: bool = True,
32
+ verify: bool = False,
33
+ ) -> None:
34
+ """Initialize the client.
35
+
36
+ Args:
37
+ base_url (str): The base URL of the agent service.
38
+ agent (str): The name of the default agent to use.
39
+ timeout (float, optional): The timeout for requests.
40
+ get_info (bool, optional): Whether to fetch agent information on init.
41
+ Default: True
42
+ verify (bool, optional): Whether to verify the agent information.
43
+ Default: False
44
+
45
+ """
46
+ self.base_url = base_url
47
+ self.auth_secret = os.getenv("AUTH_SECRET")
48
+ self.timeout = timeout
49
+ self.info: ServiceMetadata | None = None
50
+ self.agent: str | None = None
51
+ if get_info:
52
+ self.retrieve_info()
53
+ if agent:
54
+ self.update_agent(agent, verify=verify)
55
+
56
+ @property
57
+ def _headers(self) -> dict[str, str]:
58
+ headers = {}
59
+ if self.auth_secret:
60
+ headers["Authorization"] = f"Bearer {self.auth_secret}"
61
+ return headers
62
+
63
+ def retrieve_info(self) -> None:
64
+ try:
65
+ response = httpx.get(
66
+ f"{self.base_url}/info",
67
+ headers=self._headers,
68
+ timeout=self.timeout,
69
+ )
70
+ response.raise_for_status()
71
+ except httpx.HTTPError as e:
72
+ raise AgentClientError(f"Error getting service info: {e}")
73
+
74
+ self.info: ServiceMetadata = ServiceMetadata.model_validate(response.json())
75
+ if not self.agent or self.agent not in [a.key for a in self.info.agents]:
76
+ self.agent = self.info.default_agent
77
+
78
+ def update_agent(self, agent: str, verify: bool = True) -> None:
79
+ if verify:
80
+ if not self.info:
81
+ self.retrieve_info()
82
+ agent_keys = [a.key for a in self.info.agents]
83
+ if agent not in agent_keys:
84
+ raise AgentClientError(f"Agent {agent} not found in available agents: {', '.join(agent_keys)}")
85
+ self.agent = agent
86
+
87
+ async def ainvoke(
88
+ self,
89
+ message: str,
90
+ model: str | None = None,
91
+ thread_id: str | None = None,
92
+ agent_config: dict[str, Any] | None = None,
93
+ ) -> ChatMessage:
94
+ """Invoke the agent asynchronously. Only the final message is returned.
95
+
96
+ Args:
97
+ message (str): The message to send to the agent
98
+ model (str, optional): LLM model to use for the agent
99
+ thread_id (str, optional): Thread ID for continuing a conversation
100
+ agent_config (dict[str, Any], optional): Additional configuration to pass through to the agent
101
+
102
+ Returns:
103
+ AnyMessage: The response from the agent
104
+
105
+ """
106
+ if not self.agent:
107
+ raise AgentClientError("No agent selected. Use update_agent() to select an agent.")
108
+ request = UserInput(message=message)
109
+ if thread_id:
110
+ request.thread_id = thread_id
111
+ if model:
112
+ request.model = model
113
+ if agent_config:
114
+ request.agent_config = agent_config
115
+ async with httpx.AsyncClient() as client:
116
+ try:
117
+ response = await client.post(
118
+ f"{self.base_url}/{self.agent}/invoke",
119
+ json=request.model_dump(),
120
+ headers=self._headers,
121
+ timeout=self.timeout,
122
+ )
123
+ response.raise_for_status()
124
+ except httpx.HTTPError as e:
125
+ raise AgentClientError(f"Error: {e}")
126
+
127
+ return ChatMessage.model_validate(response.json())
128
+
129
+ def invoke(
130
+ self,
131
+ message: str,
132
+ model: str | None = None,
133
+ thread_id: str | None = None,
134
+ agent_config: dict[str, Any] | None = None,
135
+ ) -> ChatMessage:
136
+ """Invoke the agent synchronously. Only the final message is returned.
137
+
138
+ Args:
139
+ message (str): The message to send to the agent
140
+ model (str, optional): LLM model to use for the agent
141
+ thread_id (str, optional): Thread ID for continuing a conversation
142
+ agent_config (dict[str, Any], optional): Additional configuration to pass through to the agent
143
+
144
+ Returns:
145
+ ChatMessage: The response from the agent
146
+
147
+ """
148
+ if not self.agent:
149
+ raise AgentClientError("No agent selected. Use update_agent() to select an agent.")
150
+ request = UserInput(message=message)
151
+ if thread_id:
152
+ request.thread_id = thread_id
153
+ if model:
154
+ request.model = model
155
+ if agent_config:
156
+ request.agent_config = agent_config
157
+ try:
158
+ response = httpx.post(
159
+ f"{self.base_url}/{self.agent}/invoke",
160
+ json=request.model_dump(),
161
+ headers=self._headers,
162
+ timeout=self.timeout,
163
+ )
164
+ response.raise_for_status()
165
+ except httpx.HTTPError as e:
166
+ raise AgentClientError(f"Error: {e}")
167
+
168
+ return ChatMessage.model_validate(response.json())
169
+
170
+ def _parse_stream_line(self, line: str) -> ChatMessage | str | None:
171
+ line = line.strip()
172
+ if line.startswith("data: "):
173
+ data = line[6:]
174
+ if data == "[DONE]":
175
+ return None
176
+ try:
177
+ parsed = json.loads(data)
178
+ except Exception as e:
179
+ raise Exception(f"Error JSON parsing message from server: {e}")
180
+ match parsed["type"]:
181
+ case "message":
182
+ # Convert the JSON formatted message to an AnyMessage
183
+ try:
184
+ return ChatMessage.model_validate(parsed["content"])
185
+ except Exception as e:
186
+ raise Exception(f"Server returned invalid message: {e}")
187
+ case "token":
188
+ # Yield the str token directly
189
+ return parsed["content"]
190
+ case "error":
191
+ raise Exception(parsed["content"])
192
+ return None
193
+
194
+ def stream(
195
+ self,
196
+ message: str,
197
+ model: str | None = None,
198
+ thread_id: str | None = None,
199
+ agent_config: dict[str, Any] | None = None,
200
+ stream_tokens: bool = True,
201
+ ) -> Generator[ChatMessage | str, None, None]:
202
+ """Stream the agent's response synchronously.
203
+
204
+ Each intermediate message of the agent process is yielded as a ChatMessage.
205
+ If stream_tokens is True (the default value), the response will also yield
206
+ content tokens from streaming models as they are generated.
207
+
208
+ Args:
209
+ message (str): The message to send to the agent
210
+ model (str, optional): LLM model to use for the agent
211
+ thread_id (str, optional): Thread ID for continuing a conversation
212
+ agent_config (dict[str, Any], optional): Additional configuration to pass through to the agent
213
+ stream_tokens (bool, optional): Stream tokens as they are generated
214
+ Default: True
215
+
216
+ Returns:
217
+ Generator[ChatMessage | str, None, None]: The response from the agent
218
+
219
+ """
220
+ if not self.agent:
221
+ raise AgentClientError("No agent selected. Use update_agent() to select an agent.")
222
+ request = StreamInput(message=message, stream_tokens=stream_tokens)
223
+ if thread_id:
224
+ request.thread_id = thread_id
225
+ if model:
226
+ request.model = model
227
+ if agent_config:
228
+ request.agent_config = agent_config
229
+ try:
230
+ with httpx.stream(
231
+ "POST",
232
+ f"{self.base_url}/{self.agent}/stream",
233
+ json=request.model_dump(),
234
+ headers=self._headers,
235
+ timeout=self.timeout,
236
+ ) as response:
237
+ response.raise_for_status()
238
+ for line in response.iter_lines():
239
+ if line.strip():
240
+ parsed = self._parse_stream_line(line)
241
+ if parsed is None:
242
+ break
243
+ yield parsed
244
+ except httpx.HTTPError as e:
245
+ raise AgentClientError(f"Error: {e}")
246
+
247
+ async def astream(
248
+ self,
249
+ message: str,
250
+ model: str | None = None,
251
+ thread_id: str | None = None,
252
+ agent_config: dict[str, Any] | None = None,
253
+ stream_tokens: bool = True,
254
+ ) -> AsyncGenerator[ChatMessage | str, None]:
255
+ """Stream the agent's response asynchronously.
256
+
257
+ Each intermediate message of the agent process is yielded as an AnyMessage.
258
+ If stream_tokens is True (the default value), the response will also yield
259
+ content tokens from streaming modelsas they are generated.
260
+
261
+ Args:
262
+ message (str): The message to send to the agent
263
+ model (str, optional): LLM model to use for the agent
264
+ thread_id (str, optional): Thread ID for continuing a conversation
265
+ agent_config (dict[str, Any], optional): Additional configuration to pass through to the agent
266
+ stream_tokens (bool, optional): Stream tokens as they are generated
267
+ Default: True
268
+
269
+ Returns:
270
+ AsyncGenerator[ChatMessage | str, None]: The response from the agent
271
+
272
+ """
273
+ if not self.agent:
274
+ raise AgentClientError("No agent selected. Use update_agent() to select an agent.")
275
+ request = StreamInput(message=message, stream_tokens=stream_tokens)
276
+ if thread_id:
277
+ request.thread_id = thread_id
278
+ if model:
279
+ request.model = model
280
+ if agent_config:
281
+ request.agent_config = agent_config
282
+ async with httpx.AsyncClient() as client:
283
+ try:
284
+ async with client.stream(
285
+ "POST",
286
+ f"{self.base_url}/{self.agent}/stream",
287
+ json=request.model_dump(),
288
+ headers=self._headers,
289
+ timeout=self.timeout,
290
+ ) as response:
291
+ response.raise_for_status()
292
+ async for line in response.aiter_lines():
293
+ if line.strip():
294
+ parsed = self._parse_stream_line(line)
295
+ if parsed is None:
296
+ break
297
+ yield parsed
298
+ except httpx.HTTPError as e:
299
+ raise AgentClientError(f"Error: {e}")
300
+
301
+ async def acreate_feedback(self, run_id: str, key: str, score: float, kwargs: dict[str, Any] = {}) -> None:
302
+ """Create a feedback record for a run.
303
+
304
+ This is a simple wrapper for the LangSmith create_feedback API, so the
305
+ credentials can be stored and managed in the service rather than the client.
306
+ See: https://api.smith.langchain.com/redoc#tag/feedback/operation/create_feedback_api_v1_feedback_post
307
+ """
308
+ request = Feedback(run_id=run_id, key=key, score=score, kwargs=kwargs)
309
+ async with httpx.AsyncClient() as client:
310
+ try:
311
+ response = await client.post(
312
+ f"{self.base_url}/feedback",
313
+ json=request.model_dump(),
314
+ headers=self._headers,
315
+ timeout=self.timeout,
316
+ )
317
+ response.raise_for_status()
318
+ response.json()
319
+ except httpx.HTTPError as e:
320
+ raise AgentClientError(f"Error: {e}")
321
+
322
+ def get_history(
323
+ self,
324
+ thread_id: str,
325
+ ) -> ChatHistory:
326
+ """Get chat history.
327
+
328
+ Args:
329
+ thread_id (str, optional): Thread ID for identifying a conversation
330
+
331
+ """
332
+ request = ChatHistoryInput(thread_id=thread_id)
333
+ try:
334
+ response = httpx.post(
335
+ f"{self.base_url}/history",
336
+ json=request.model_dump(),
337
+ headers=self._headers,
338
+ timeout=self.timeout,
339
+ )
340
+ response.raise_for_status()
341
+ except httpx.HTTPError as e:
342
+ raise AgentClientError(f"Error: {e}")
343
+
344
+ return ChatHistory.model_validate(response.json())
@@ -0,0 +1,5 @@
1
+ from langgraph_agent_toolkit.core.models.factory import ModelFactory
2
+ from langgraph_agent_toolkit.core.settings import settings
3
+
4
+
5
# Public names of this package (also what ``from <package> import *`` exports).
__all__ = ["settings", "ModelFactory"]
File without changes
@@ -0,0 +1,33 @@
1
+ from abc import ABC, abstractmethod
2
+ from contextlib import AbstractAsyncContextManager
3
+ from typing import Any, TypeVar
4
+
5
+
6
# Type of the checkpoint saver produced by a backend's context manager.
T = TypeVar("T", bound=Any)


class BaseMemoryBackend(ABC):
    """Abstract interface that every checkpoint-memory backend implements."""

    @abstractmethod
    def validate_config(self) -> bool:
        """Check that the settings this backend relies on are present.

        Returns:
            True when the configuration is usable.

        Raises:
            ValueError: When a required setting is missing.

        """
        ...

    @abstractmethod
    def get_checkpoint_saver(self) -> AbstractAsyncContextManager[T]:
        """Build the checkpoint saver for this backend.

        Returns:
            An async context manager that yields the configured checkpoint saver.

        """
        ...
@@ -0,0 +1,30 @@
1
+ from langgraph_agent_toolkit.core.memory.base import BaseMemoryBackend
2
+ from langgraph_agent_toolkit.core.memory.postgres import PostgresMemoryBackend
3
+ from langgraph_agent_toolkit.core.memory.sqlite import SQLiteMemoryBackend
4
+ from langgraph_agent_toolkit.core.memory.types import MemoryBackends
5
+
6
+
7
class MemoryFactory:
    """Builds memory backend instances from a ``MemoryBackends`` value."""

    @staticmethod
    def create(backend: MemoryBackends) -> BaseMemoryBackend:
        """Instantiate the memory backend selected by ``backend``.

        Args:
            backend: Which memory backend to build.

        Returns:
            A new ``PostgresMemoryBackend`` or ``SQLiteMemoryBackend`` instance.

        Raises:
            ValueError: If ``backend`` is not a supported value.

        """
        if backend == MemoryBackends.POSTGRES:
            return PostgresMemoryBackend()
        if backend == MemoryBackends.SQLITE:
            return SQLiteMemoryBackend()
        raise ValueError(f"Unsupported memory backend: {backend}")
@@ -0,0 +1,76 @@
1
+ from collections.abc import AsyncGenerator
2
+ from contextlib import AbstractAsyncContextManager, asynccontextmanager
3
+
4
+ from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
5
+ from psycopg.rows import dict_row
6
+ from psycopg_pool import AsyncConnectionPool
7
+
8
+ from langgraph_agent_toolkit.core.memory.base import BaseMemoryBackend
9
+ from langgraph_agent_toolkit.core.settings import settings
10
+ from langgraph_agent_toolkit.helper.logging import logger
11
+
12
+
13
class PostgresMemoryBackend(BaseMemoryBackend):
    """PostgreSQL implementation of memory backend.

    Checkpoints are stored through ``AsyncPostgresSaver`` on top of an
    ``AsyncConnectionPool`` whose connection details and sizing come from ``settings``.
    """

    def validate_config(self) -> bool:
        """Validate that all required PostgreSQL configuration is present.

        Returns:
            True if every required ``POSTGRES_*`` setting is set (truthy).

        Raises:
            ValueError: Listing every missing required setting.

        """
        required_vars = [
            "POSTGRES_USER",
            "POSTGRES_PASSWORD",
            "POSTGRES_HOST",
            "POSTGRES_PORT",
            "POSTGRES_DB",
        ]

        missing = [var for var in required_vars if not getattr(settings, var, None)]
        if missing:
            raise ValueError(
                f"Missing required PostgreSQL configuration: {', '.join(missing)}. "
                "These environment variables must be set to use PostgreSQL persistence."
            )
        return True

    def get_connection_string(self) -> str:
        """Build and return the PostgreSQL connection URI from settings.

        The user name, password and database name are percent-encoded. Otherwise,
        characters with special meaning in a URI (``@``, ``:``, ``/``, ``%``, ``?``,
        ``#``) inside credentials would corrupt the URI. libpq decodes them again.
        """
        from urllib.parse import quote  # only needed here; keeps module imports unchanged

        user = quote(str(settings.POSTGRES_USER), safe="")
        password = quote(settings.POSTGRES_PASSWORD.get_secret_value(), safe="")
        database = quote(str(settings.POSTGRES_DB), safe="")
        return f"postgresql://{user}:{password}@{settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/{database}"

    @asynccontextmanager
    async def get_saver(self) -> AsyncGenerator[AsyncPostgresSaver, None]:
        """Asynchronous context manager for acquiring and releasing the PostgreSQL connection pool.

        The pool is entered with ``async with``, so it is closed when this context exits.

        Yields:
            AsyncPostgresSaver: The database saver instance

        """
        conn_string = self.get_connection_string()

        logger.info(
            f"Creating PostgreSQL connection pool: min_size={settings.POSTGRES_MIN_SIZE}, "
            f"max_size={settings.POSTGRES_POOL_SIZE}, max_idle={settings.POSTGRES_MAX_IDLE}"
        )

        # Use AsyncConnectionPool as an async context manager
        async with AsyncConnectionPool(
            conn_string,
            min_size=settings.POSTGRES_MIN_SIZE,
            max_size=settings.POSTGRES_POOL_SIZE,
            max_idle=settings.POSTGRES_MAX_IDLE,
            kwargs=dict(autocommit=True, prepare_threshold=0, row_factory=dict_row),
        ) as pool:
            logger.info("PostgreSQL connection pool opened successfully")

            try:
                yield AsyncPostgresSaver(conn=pool)
            finally:
                logger.info("PostgreSQL connection pool will be closed automatically")

    def get_checkpoint_saver(self) -> AbstractAsyncContextManager[AsyncPostgresSaver]:
        """Validate the configuration and return the ``get_saver()`` context manager.

        Raises:
            ValueError: If required PostgreSQL settings are missing.

        """
        self.validate_config()
        return self.get_saver()
@@ -0,0 +1,21 @@
1
+ from contextlib import AbstractAsyncContextManager
2
+
3
+ from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
4
+
5
+ from langgraph_agent_toolkit.core.memory.base import BaseMemoryBackend
6
+ from langgraph_agent_toolkit.core.settings import settings
7
+
8
+
9
class SQLiteMemoryBackend(BaseMemoryBackend):
    """SQLite-backed memory backend; the database location is ``settings.SQLITE_DB_PATH``."""

    def validate_config(self) -> bool:
        """Ensure ``settings.SQLITE_DB_PATH`` is set.

        Returns:
            True when the path setting is present and non-empty.

        Raises:
            ValueError: If the path setting is missing or empty.

        """
        db_path = getattr(settings, "SQLITE_DB_PATH", None)
        if db_path:
            return True
        raise ValueError("Missing SQLITE_DB_PATH configuration. This must be set to use SQLite persistence.")

    def get_checkpoint_saver(self) -> AbstractAsyncContextManager[AsyncSqliteSaver]:
        """Validate the configuration, then return ``AsyncSqliteSaver.from_conn_string`` for the configured path."""
        self.validate_config()
        db_path = settings.SQLITE_DB_PATH
        return AsyncSqliteSaver.from_conn_string(db_path)
@@ -0,0 +1,6 @@
1
+ from enum import StrEnum, auto
2
+
3
+
4
+ class MemoryBackends(StrEnum):
5
+ POSTGRES = auto()
6
+ SQLITE = auto()
@@ -0,0 +1,5 @@
1
+ from langgraph_agent_toolkit.core.models.chat_openai import ChatOpenAIPatched
2
+ from langgraph_agent_toolkit.core.models.fake import FakeToolModel
3
+
4
+
5
# Public names of this package (also what ``from <package> import *`` exports).
__all__ = ["ChatOpenAIPatched", "FakeToolModel"]
@@ -0,0 +1,21 @@
1
+ from typing import Dict, Optional, Union
2
+
3
+ import openai
4
+ from langchain_core.outputs import ChatResult
5
+ from langchain_openai import ChatOpenAI
6
+
7
+
8
class ChatOpenAIPatched(ChatOpenAI):
    """``ChatOpenAI`` that normalizes assistant roles before building the chat result.

    Any response message whose role *starts with* ``"assistant"`` is rewritten to
    exactly ``"assistant"`` before delegating to ``ChatOpenAI._create_chat_result``.
    NOTE(review): presumably some OpenAI-compatible servers emit suffixed assistant
    roles; confirm which server motivated this patch.
    """

    def _create_chat_result(
        self,
        response: Union[dict, openai.BaseModel],
        generation_info: Optional[Dict] = None,
    ) -> ChatResult:
        """Normalize assistant roles in ``response`` in place, then defer to ``ChatOpenAI``.

        Both response shapes allowed by the signature are handled: plain dicts
        (``response["choices"][i]["message"]["role"]``) and openai SDK models
        (``response.choices[i].message.role``). Missing/``None`` choices, messages
        without a role, and non-string roles are left untouched rather than raising.
        """
        if isinstance(response, dict):
            for choice in response.get("choices") or []:
                message = choice.get("message") if isinstance(choice, dict) else None
                if isinstance(message, dict) and self._is_assistant_role(message.get("role")):
                    message["role"] = "assistant"
        else:
            for choice in getattr(response, "choices", None) or []:
                message = getattr(choice, "message", None)
                if message is not None and self._is_assistant_role(getattr(message, "role", None)):
                    message.role = "assistant"

        return super()._create_chat_result(response, generation_info)

    @staticmethod
    def _is_assistant_role(role: object) -> bool:
        """Return True if ``role`` is a string starting with ``"assistant"``."""
        return isinstance(role, str) and role.startswith("assistant")
@@ -0,0 +1,118 @@
1
+ import warnings
2
+ from functools import cache
3
+ from typing import (
4
+ Any,
5
+ List,
6
+ Literal,
7
+ Optional,
8
+ Tuple,
9
+ Union,
10
+ cast,
11
+ )
12
+
13
+ from langchain.chat_models.base import _ConfigurableModel, _init_chat_model_helper
14
+ from langchain_core.language_models import BaseChatModel
15
+ from langchain_core.runnables import RunnableSerializable
16
+ from typing_extensions import TypeAlias
17
+
18
+ from langgraph_agent_toolkit.core.models import ChatOpenAIPatched, FakeToolModel
19
+ from langgraph_agent_toolkit.core.settings import settings
20
+ from langgraph_agent_toolkit.helper.constants import DEFAULT_OPENAI_COMPATIBLE_MODEL_PARAMS
21
+ from langgraph_agent_toolkit.schema.models import (
22
+ AllModelEnum,
23
+ FakeModelName,
24
+ OpenAICompatibleName,
25
+ )
26
+
27
+
28
# Union of the object types ModelFactory.create() is annotated to return.
ModelT: TypeAlias = ChatOpenAIPatched | FakeToolModel | RunnableSerializable | _ConfigurableModel
29
+
30
+
31
+ class ModelFactory:
32
+ """Factory for creating model instances."""
33
+
34
+ # Map model enum names to their respective API model names
35
+ _MODEL_TABLE = {
36
+ OpenAICompatibleName.OPENAI_COMPATIBLE: settings.COMPATIBLE_MODEL,
37
+ FakeModelName.FAKE: "fake",
38
+ }
39
+
40
+ @staticmethod
41
+ def __init_chat_model_helper(model: str, *, model_provider: Optional[str] = None, **kwargs: Any) -> BaseChatModel:
42
+ if model_provider == "openai":
43
+ return ChatOpenAIPatched(model_name=model, **kwargs)
44
+ else:
45
+ return _init_chat_model_helper(model, model_provider=model_provider, **kwargs)
46
+
47
+ @staticmethod
48
+ def init_chat_model(
49
+ model: Optional[str] = None,
50
+ *,
51
+ model_provider: Optional[str] = None,
52
+ configurable_fields: Optional[Union[Literal["any"], List[str], Tuple[str, ...]]] = None,
53
+ config_prefix: Optional[str] = None,
54
+ **kwargs: Any,
55
+ ) -> Union[BaseChatModel, _ConfigurableModel]:
56
+ if not model and not configurable_fields:
57
+ configurable_fields = ("model", "model_provider")
58
+ config_prefix = config_prefix or ""
59
+
60
+ if config_prefix and not configurable_fields:
61
+ warnings.warn(
62
+ f"{config_prefix=} has been set but no fields are configurable. Set "
63
+ f"`configurable_fields=(...)` to specify the model params that are "
64
+ f"configurable."
65
+ )
66
+
67
+ if not configurable_fields:
68
+ return ModelFactory.__init_chat_model_helper(cast(str, model), model_provider=model_provider, **kwargs)
69
+ else:
70
+ if model:
71
+ kwargs["model"] = model
72
+ if model_provider:
73
+ kwargs["model_provider"] = model_provider
74
+ return _ConfigurableModel(
75
+ default_config=kwargs,
76
+ config_prefix=config_prefix,
77
+ configurable_fields=configurable_fields,
78
+ )
79
+
80
+ @staticmethod
81
+ @cache
82
+ def create(model_name: AllModelEnum) -> ModelT:
83
+ """Create and return a model instance.
84
+
85
+ Args:
86
+ model_name: The model to create from AllModelEnum
87
+
88
+ Returns:
89
+ An instance of the requested model
90
+
91
+ Raises:
92
+ ValueError: If the requested model is not supported
93
+
94
+ """
95
+ api_model_name = ModelFactory._MODEL_TABLE.get(model_name)
96
+ if not api_model_name:
97
+ raise ValueError(f"Unsupported model: {model_name}")
98
+
99
+ match model_name:
100
+ case name if name in OpenAICompatibleName:
101
+ if not settings.COMPATIBLE_BASE_URL or not settings.COMPATIBLE_MODEL:
102
+ raise ValueError("OpenAICompatible base url and endpoint must be configured")
103
+
104
+ model = ModelFactory.init_chat_model(
105
+ model=settings.COMPATIBLE_MODEL,
106
+ model_provider="openai",
107
+ configurable_fields=("temperature", "max_tokens", "top_p", "streaming"),
108
+ config_prefix="agent",
109
+ openai_api_base=settings.COMPATIBLE_BASE_URL,
110
+ openai_api_key=settings.COMPATIBLE_API_KEY,
111
+ **DEFAULT_OPENAI_COMPATIBLE_MODEL_PARAMS,
112
+ )
113
+
114
+ return model
115
+ case name if name in FakeModelName:
116
+ return FakeToolModel(responses=["This is a test response from the fake model."])
117
+ case _:
118
+ raise ValueError(f"Unsupported model: {model_name}")