dao-ai 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dao_ai/cli.py CHANGED
@@ -3,6 +3,7 @@ import json
  import os
  import subprocess
  import sys
+ import traceback
  from argparse import ArgumentParser, Namespace
  from pathlib import Path
  from typing import Optional, Sequence
@@ -388,7 +389,9 @@ def handle_chat_command(options: Namespace) -> None:

  except Exception as e:
  print(f"\n❌ Error during streaming: {e}")
+ print(f"Stack trace:\n{traceback.format_exc()}")
  logger.error(f"Streaming error: {e}")
+ logger.error(f"Stack trace: {traceback.format_exc()}")

  except EOFError:
  # Handle Ctrl-D
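The chat command now surfaces the full stack trace when streaming fails, both on stdout and in the logs. A minimal sketch of the same pattern outside the packaged CLI (the `stream_with_diagnostics` helper is hypothetical):

```python
import traceback

from loguru import logger


def stream_with_diagnostics(stream) -> None:
    """Consume a token stream and report the full stack trace on failure."""
    try:
        for chunk in stream:
            print(chunk, end="", flush=True)
    except Exception as e:
        # Mirrors the 0.0.8 behavior: echo the error and keep the traceback in the logs.
        print(f"\n❌ Error during streaming: {e}")
        print(f"Stack trace:\n{traceback.format_exc()}")
        logger.error(f"Streaming error: {e}")
        logger.error(f"Stack trace: {traceback.format_exc()}")
```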
dao_ai/config.py CHANGED
@@ -28,6 +28,7 @@ from langchain_core.language_models import LanguageModelLike
  from langchain_core.runnables.base import RunnableLike
  from langchain_openai import ChatOpenAI
  from langgraph.checkpoint.base import BaseCheckpointSaver
+ from langgraph.graph.state import CompiledStateGraph
  from langgraph.store.base import BaseStore
  from loguru import logger
  from mlflow.models import ModelConfig
@@ -41,10 +42,9 @@ from mlflow.models.resources import (
  DatabricksUCConnection,
  DatabricksVectorSearchIndex,
  )
+ from mlflow.pyfunc import ChatModel
  from pydantic import BaseModel, ConfigDict, Field, field_serializer, model_validator

- from dao_ai.chat_models import ChatDatabricksFiltered
-

  class HasValue(ABC):
  @abstractmethod
@@ -275,9 +275,14 @@ class LLMModel(BaseModel, IsDatabricksResource):
  # chat_client: LanguageModelLike = self.as_open_ai_client()

  # Create ChatDatabricksWrapper instance directly
+ from dao_ai.chat_models import ChatDatabricksFiltered
+
  chat_client: LanguageModelLike = ChatDatabricksFiltered(
  model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
  )
+ # chat_client: LanguageModelLike = ChatDatabricks(
+ # model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
+ # )

  fallbacks: Sequence[LanguageModelLike] = []
  for fallback in self.fallbacks:
@@ -1001,7 +1006,15 @@ class ChatHistoryModel(BaseModel):
  max_tokens: int = 256
  max_tokens_before_summary: Optional[int] = None
  max_messages_before_summary: Optional[int] = None
- max_summary_tokens: Optional[int] = None
+ max_summary_tokens: int = 255
+
+ @model_validator(mode="after")
+ def validate_max_summary_tokens(self) -> "ChatHistoryModel":
+ if self.max_summary_tokens >= self.max_tokens:
+ raise ValueError(
+ f"max_summary_tokens ({self.max_summary_tokens}) must be less than max_tokens ({self.max_tokens})"
+ )
+ return self


  class AppModel(BaseModel):
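ChatHistoryModel now defaults max_summary_tokens to 255 and enforces that it stays strictly below max_tokens via a Pydantic after-validator. A quick illustration of the new constraint (values are hypothetical, and this assumes the model has no other required fields outside this hunk):

```python
from dao_ai.config import ChatHistoryModel

ChatHistoryModel(max_tokens=256, max_summary_tokens=128)  # accepted

# Raises ValueError: max_summary_tokens (512) must be less than max_tokens (256)
ChatHistoryModel(max_tokens=256, max_summary_tokens=512)
```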
@@ -1057,12 +1070,19 @@ class AppModel(BaseModel):
  return self


+ class GuidelineModel(BaseModel):
+ model_config = ConfigDict(use_enum_values=True, extra="forbid")
+ name: str
+ guidelines: list[str]
+
+
  class EvaluationModel(BaseModel):
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
  model: LLMModel
  table: TableModel
  num_evals: int
  custom_inputs: dict[str, Any] = Field(default_factory=dict)
+ guidelines: list[GuidelineModel] = Field(default_factory=list)


  class DatasetFormat(str, Enum):
@@ -1302,3 +1322,16 @@ class AppConfig(BaseModel):
  return [
  guardrail for guardrail in self.guardrails.values() if predicate(guardrail)
  ]
+
+ def as_graph(self) -> CompiledStateGraph:
+ from dao_ai.graph import create_dao_ai_graph
+
+ graph: CompiledStateGraph = create_dao_ai_graph(config=self)
+ return graph
+
+ def as_chat_model(self) -> ChatModel:
+ from dao_ai.models import create_agent
+
+ graph: CompiledStateGraph = self.as_graph()
+ app: ChatModel = create_agent(graph)
+ return app
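The new AppConfig.as_graph() and as_chat_model() helpers give a direct path from a loaded configuration to a compiled LangGraph graph or an MLflow ChatModel. A minimal usage sketch; how the AppConfig instance itself is loaded is outside this diff, so that step is left abstract:

```python
from dao_ai.config import AppConfig

# Assume `config` is an already-constructed AppConfig instance
# (e.g. parsed from the application's YAML/ModelConfig source).
config: AppConfig = ...

graph = config.as_graph()        # CompiledStateGraph via create_dao_ai_graph(config=...)
app = config.as_chat_model()     # mlflow.pyfunc.ChatModel wrapping the compiled graph

# `graph` can be invoked or streamed directly; `app` is suitable for MLflow logging/serving.
```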
dao_ai/graph.py CHANGED
@@ -3,13 +3,10 @@ from typing import Sequence
  from langchain_core.language_models import LanguageModelLike
  from langchain_core.runnables import RunnableConfig
  from langchain_core.tools import BaseTool
- from langgraph.cache.base import BaseCache
- from langgraph.cache.memory import InMemoryCache
  from langgraph.checkpoint.base import BaseCheckpointSaver
  from langgraph.graph import END, StateGraph
  from langgraph.graph.state import CompiledStateGraph
  from langgraph.store.base import BaseStore
- from langgraph.types import CachePolicy
  from langgraph_supervisor import create_handoff_tool as supervisor_handoff_tool
  from langgraph_supervisor import create_supervisor
  from langgraph_swarm import create_handoff_tool as swarm_handoff_tool
@@ -29,7 +26,7 @@ from dao_ai.nodes import (
  message_hook_node,
  )
  from dao_ai.prompts import make_prompt
- from dao_ai.state import IncomingState, OutgoingState, SharedState
+ from dao_ai.state import Context, IncomingState, OutgoingState, SharedState


  def route_message(state: SharedState) -> str:
@@ -113,9 +110,6 @@ def _create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
  checkpointer = orchestration.memory.checkpointer.as_checkpointer()
  logger.debug(f"Using checkpointer: {checkpointer}")

- cache: BaseCache = None
- cache = InMemoryCache()
-
  prompt: str = supervisor.prompt

  model: LanguageModelLike = supervisor.model.as_chat_model()
@@ -127,22 +121,29 @@ def _create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
  tools=tools,
  state_schema=SharedState,
  config_schema=RunnableConfig,
+ output_mode="last_message",
+ add_handoff_messages=False,
+ add_handoff_back_messages=False,
+ context_schema=Context,
+ # output_mode="full",
+ # add_handoff_messages=True,
+ # add_handoff_back_messages=True,
  )

- supervisor_node: CompiledStateGraph = supervisor_workflow.compile()
+ supervisor_node: CompiledStateGraph = supervisor_workflow.compile(
+ checkpointer=checkpointer, store=store
+ )

  workflow: StateGraph = StateGraph(
  SharedState,
- config_schema=RunnableConfig,
  input=IncomingState,
  output=OutgoingState,
+ context_schema=Context,
  )

  workflow.add_node("message_hook", message_hook_node(config=config))

- workflow.add_node(
- "orchestration", supervisor_node, cache_policy=CachePolicy(ttl=60)
- )
+ workflow.add_node("orchestration", supervisor_node)
  workflow.add_conditional_edges(
  "message_hook",
  route_message,
@@ -153,7 +154,7 @@ def _create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
  )
  workflow.set_entry_point("message_hook")

- return workflow.compile(checkpointer=checkpointer, store=store, cache=cache)
+ return workflow.compile(checkpointer=checkpointer, store=store)


  def _create_swarm_graph(config: AppConfig) -> CompiledStateGraph:
@@ -172,6 +173,16 @@ def _create_swarm_graph(config: AppConfig) -> CompiledStateGraph:
  orchestration: OrchestrationModel = config.app.orchestration
  swarm: SwarmModel = orchestration.swarm

+ store: BaseStore = None
+ if orchestration.memory and orchestration.memory.store:
+ store = orchestration.memory.store.as_store()
+ logger.debug(f"Using memory store: {store}")
+
+ checkpointer: BaseCheckpointSaver = None
+ if orchestration.memory and orchestration.memory.checkpointer:
+ checkpointer = orchestration.memory.checkpointer.as_checkpointer()
+ logger.debug(f"Using checkpointer: {checkpointer}")
+
  default_agent: AgentModel = swarm.default_agent
  if isinstance(default_agent, AgentModel):
  default_agent = default_agent.name
@@ -180,20 +191,22 @@ def _create_swarm_graph(config: AppConfig) -> CompiledStateGraph:
  agents=agents,
  default_active_agent=default_agent,
  state_schema=SharedState,
- config_schema=RunnableConfig,
+ context_schema=Context,
  )

- swarm_node: CompiledStateGraph = swarm_workflow.compile()
+ swarm_node: CompiledStateGraph = swarm_workflow.compile(
+ checkpointer=checkpointer, store=store
+ )

  workflow: StateGraph = StateGraph(
  SharedState,
- config_schema=RunnableConfig,
  input=IncomingState,
  output=OutgoingState,
+ context_schema=Context,
  )

  workflow.add_node("message_hook", message_hook_node(config=config))
- workflow.add_node("orchestration", swarm_node, cache_policy=CachePolicy(ttl=60))
+ workflow.add_node("orchestration", swarm_node)

  workflow.add_conditional_edges(
  "message_hook",
@@ -206,20 +219,7 @@ def _create_swarm_graph(config: AppConfig) -> CompiledStateGraph:

  workflow.set_entry_point("message_hook")

- store: BaseStore = None
- if orchestration.memory and orchestration.memory.store:
- store = orchestration.memory.store.as_store()
- logger.debug(f"Using memory store: {store}")
-
- checkpointer: BaseCheckpointSaver = None
- if orchestration.memory and orchestration.memory.checkpointer:
- checkpointer = orchestration.memory.checkpointer.as_checkpointer()
- logger.debug(f"Using checkpointer: {checkpointer}")
-
- cache: BaseCache = None
- cache = InMemoryCache()
-
- return workflow.compile(checkpointer=checkpointer, store=store, cache=cache)
+ return workflow.compile(checkpointer=checkpointer, store=store)


  def create_dao_ai_graph(config: AppConfig) -> CompiledStateGraph:
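The graph builders now compile with context_schema=Context (LangGraph's runtime context API) instead of config_schema=RunnableConfig, and the supervisor/swarm subgraphs are compiled with the shared checkpointer and store. A minimal sketch of how a node reads that context, modeled on the hook signatures later in this diff; the Context fields (user_id, thread_id, store_num) are inferred from their use in dao_ai/hooks/core.py, since dao_ai/state.py is not shown here:

```python
from typing import Any

from langgraph.runtime import Runtime

from dao_ai.state import Context


def my_node(state: dict[str, Any], runtime: Runtime[Context]) -> dict[str, Any]:
    # Runtime-scoped values now arrive via runtime.context rather than config["configurable"].
    context = runtime.context or Context()
    print(context.user_id, context.thread_id)
    return {}


# At invocation time the context is supplied alongside (not inside) the config:
# graph.invoke({"messages": [...]}, context=Context(user_id="u1", thread_id="1"))
```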
dao_ai/hooks/__init__.py CHANGED
@@ -1,5 +1,6 @@
  from dao_ai.hooks.core import (
  create_hooks,
+ filter_last_human_message_hook,
  null_hook,
  null_initialization_hook,
  null_shutdown_hook,
@@ -14,4 +15,5 @@ __all__ = [
  "null_shutdown_hook",
  "require_thread_id_hook",
  "require_user_id_hook",
+ "filter_last_human_message_hook",
  ]
dao_ai/hooks/core.py CHANGED
@@ -1,9 +1,13 @@
  import json
  from typing import Any, Callable, Sequence

+ from langchain_core.messages import BaseMessage, HumanMessage, RemoveMessage
+ from langgraph.runtime import Runtime
  from loguru import logger

  from dao_ai.config import AppConfig, FunctionHook, PythonFunctionModel
+ from dao_ai.messages import last_human_message
+ from dao_ai.state import Context


  def create_hooks(
@@ -23,7 +27,7 @@ def create_hooks(
  return hooks


- def null_hook(state: dict[str, Any], config: dict[str, Any]) -> dict[str, Any]:
+ def null_hook(state: dict[str, Any], runtime: Runtime[Context]) -> dict[str, Any]:
  logger.debug("Executing null hook")
  return {}

@@ -36,19 +40,79 @@ def null_shutdown_hook(config: AppConfig) -> None:
  logger.debug("Executing null shutdown hook")


+ def filter_last_human_message_hook(
+ state: dict[str, Any], runtime: Runtime[Context]
+ ) -> dict[str, Any]:
+ """
+ Filter messages to keep only the last human message.
+
+ This hook removes all messages except for the most recent human message,
+ which can be useful for scenarios where you want to process only the
+ latest user input without conversation history.
+
+ Args:
+ state: The current state containing messages
+ runtime: The runtime context (unused in this hook)
+
+ Returns:
+ Updated state with filtered messages
+ """
+ logger.debug("Executing filter_last_human_message hook")
+
+ messages: list[BaseMessage] = state.get("messages", [])
+
+ if not messages:
+ logger.debug("No messages found in state")
+ return state
+
+ # Use the helper function to find the last human message
+ last_message: HumanMessage = last_human_message(messages)
+
+ if last_message is None:
+ logger.debug("No human messages found in state")
+ # Return empty messages if no human message found
+ updated_state = state.copy()
+ updated_state["messages"] = []
+ return updated_state
+
+ logger.debug(f"Filtered {len(messages)} messages down to 1 (last human message)")
+
+ removed_messages: Sequence[BaseMessage] = [
+ RemoveMessage(id=message.id)
+ for message in messages
+ if message.id != last_message.id
+ ]
+
+ updated_state: dict[str, Sequence[BaseMessage]] = {"messages": removed_messages}
+
+ return updated_state
+
+
  def require_user_id_hook(
- state: dict[str, Any], config: dict[str, Any]
+ state: dict[str, Any], runtime: Runtime[Context]
  ) -> dict[str, Any]:
  logger.debug("Executing user_id validation hook")

- config = config.get("custom_inputs", config)
+ context: Context = runtime.context or Context()

- configurable: dict[str, Any] = config.get("configurable", {})
+ user_id: str | None = context.user_id

- if "user_id" not in configurable or not configurable["user_id"]:
+ if not user_id:
  logger.error("User ID is required but not provided in the configuration.")

- error_message = """
+ # Create corrected configuration using any provided context parameters
+ corrected_config = {
+ "configurable": {
+ "thread_id": context.thread_id or "1",
+ "user_id": "my_user_id",
+ "store_num": context.store_num or 87887,
+ }
+ }
+
+ # Format as JSON for copy-paste
+ corrected_config_json = json.dumps(corrected_config, indent=2)
+
+ error_message = f"""
  ## Authentication Required

  A **user_id** is required to process your request. Please provide your user ID in the configuration.
@@ -58,25 +122,19 @@ A **user_id** is required to process your request. Please provide your user ID i
  Please include the following JSON in your request configuration:

  ```json
- {
- "configurable": {
- "thread_id": "1",
- "user_id": "my_user_id"
- }
- }
+ {corrected_config_json}
  ```

  ### Field Descriptions
  - **user_id**: Your unique user identifier (required)
- - **thread_id**: Conversation thread identifier (optional)
+ - **thread_id**: Conversation thread identifier (required)
+ - **store_num**: Your store number (required)

  Please update your configuration and try again.
  """.strip()

  raise ValueError(error_message)

- # Validate that user_id doesn't contain dots
- user_id = configurable["user_id"]
  if "." in user_id:
  logger.error(f"User ID '{user_id}' contains invalid character '.'")

@@ -88,9 +146,9 @@ Please update your configuration and try again.
  # Corrected config with fixed user_id
  corrected_config = {
  "configurable": {
- "thread_id": configurable.get("thread_id", "1"),
+ "thread_id": context.thread_id or "1",
  "user_id": corrected_user_id,
- "store_num": configurable.get("store_num", 87887),
+ "store_num": context.store_num or 87887,
  }
  }

@@ -116,38 +174,46 @@ Please update your user_id and try again.


  def require_thread_id_hook(
- state: dict[str, Any], config: dict[str, Any]
+ state: dict[str, Any], runtime: Runtime[Context]
  ) -> dict[str, Any]:
  logger.debug("Executing thread_id validation hook")

- config = config.get("custom_inputs", config)
+ context: Context = runtime.context or Context()

- configurable: dict[str, Any] = config.get("configurable", {})
+ thread_id: str | None = context.thread_id

- if "thread_id" not in configurable or not configurable["thread_id"]:
+ if not thread_id:
  logger.error("Thread ID is required but not provided in the configuration.")

- error_message = """
+ # Create corrected configuration using any provided context parameters
+ corrected_config = {
+ "configurable": {
+ "thread_id": "1",
+ "user_id": context.user_id or "my_user_id",
+ "store_num": context.store_num or 87887,
+ }
+ }
+
+ # Format as JSON for copy-paste
+ corrected_config_json = json.dumps(corrected_config, indent=2)
+
+ error_message = f"""
  ## Authentication Required

- A **thread_id** is required to process your request. Please provide your user ID in the configuration.
+ A **thread_id** is required to process your request. Please provide your thread ID in the configuration.

  ### Required Configuration Format

  Please include the following JSON in your request configuration:

  ```json
- {
- "configurable": {
- "thread_id": "1",
- "user_id": "my_user_id"
- }
- }
+ {corrected_config_json}
  ```

  ### Field Descriptions
+ - **thread_id**: Conversation thread identifier (required)
  - **user_id**: Your unique user identifier (required)
- - **thread_id**: Conversation thread identifier (optional)
+ - **store_num**: Your store number (required)

  Please update your configuration and try again.
  """.strip()
dao_ai/memory/postgres.py CHANGED
@@ -4,8 +4,8 @@ import threading
  from typing import Any, Optional

  from langgraph.checkpoint.base import BaseCheckpointSaver
- from langgraph.checkpoint.postgres import PostgresSaver
- from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
+ from langgraph.checkpoint.postgres import ShallowPostgresSaver
+ from langgraph.checkpoint.postgres.aio import AsyncShallowPostgresSaver
  from langgraph.store.base import BaseStore
  from langgraph.store.postgres import PostgresStore
  from langgraph.store.postgres.aio import AsyncPostgresStore
@@ -141,7 +141,7 @@ class AsyncPostgresCheckpointerManager(CheckpointManagerBase):
  def __init__(self, checkpointer_model: CheckpointerModel):
  self.checkpointer_model = checkpointer_model
  self.pool: Optional[AsyncConnectionPool] = None
- self._checkpointer: Optional[AsyncPostgresSaver] = None
+ self._checkpointer: Optional[AsyncShallowPostgresSaver] = None
  self._setup_complete = False

  def checkpointer(self) -> BaseCheckpointSaver:
@@ -183,7 +183,7 @@ class AsyncPostgresCheckpointerManager(CheckpointManagerBase):
  )

  # Create checkpointer with the shared pool
- self._checkpointer = AsyncPostgresSaver(conn=self.pool)
+ self._checkpointer = AsyncShallowPostgresSaver(conn=self.pool)
  await self._checkpointer.setup()

  self._setup_complete = True
@@ -315,7 +315,7 @@ class PostgresCheckpointerManager(CheckpointManagerBase):
  def __init__(self, checkpointer_model: CheckpointerModel):
  self.checkpointer_model = checkpointer_model
  self.pool: Optional[ConnectionPool] = None
- self._checkpointer: Optional[PostgresSaver] = None
+ self._checkpointer: Optional[ShallowPostgresSaver] = None
  self._setup_complete = False

  def checkpointer(self) -> BaseCheckpointSaver:
@@ -345,7 +345,7 @@ class PostgresCheckpointerManager(CheckpointManagerBase):
  self.pool = PostgresPoolManager.get_pool(self.checkpointer_model.database)

  # Create checkpointer with the shared pool
- self._checkpointer = PostgresSaver(conn=self.pool)
+ self._checkpointer = ShallowPostgresSaver(conn=self.pool)
  self._checkpointer.setup()

  self._setup_complete = True
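ShallowPostgresSaver keeps only the latest checkpoint per thread rather than the full checkpoint history, trading time-travel/replay for a smaller Postgres footprint on long conversations. A minimal sketch of the construction pattern used above, assuming a psycopg connection pool is already available (the DSN below is hypothetical; in dao-ai the pool comes from PostgresPoolManager):

```python
from langgraph.checkpoint.postgres import ShallowPostgresSaver
from psycopg_pool import ConnectionPool

pool = ConnectionPool(conninfo="postgresql://user:pass@localhost:5432/dao_ai")

checkpointer = ShallowPostgresSaver(conn=pool)
checkpointer.setup()  # create the checkpoint tables on first use
```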
dao_ai/messages.py CHANGED
@@ -78,6 +78,12 @@ def convert_to_langchain_messages(messages: dict[str, Any]) -> Sequence[BaseMess
  return langchain_messages


+ def has_human_message(messages: BaseMessage | Sequence[BaseMessage]) -> bool:
+ if isinstance(messages, BaseMessage):
+ messages = [messages]
+ return any(isinstance(m, HumanMessage) for m in messages)
+
+
  def has_langchain_messages(messages: BaseMessage | Sequence[BaseMessage]) -> bool:
  if isinstance(messages, BaseMessage):
  messages = [messages]
dao_ai/models.py CHANGED
@@ -4,7 +4,6 @@ from pathlib import Path
  from typing import Any, Generator, Optional, Sequence

  from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
- from langchain_core.runnables import RunnableConfig
  from langgraph.graph.state import CompiledStateGraph
  from loguru import logger
  from mlflow import MlflowClient
@@ -20,7 +19,7 @@ from mlflow.types.llm import (
  )

  from dao_ai.messages import has_langchain_messages, has_mlflow_messages
- from dao_ai.state import SharedState
+ from dao_ai.state import Context


  def get_latest_model_version(model_name: str) -> int:
@@ -63,10 +62,11 @@ class LanggraphChatModel(ChatModel):

  request = {"messages": self._convert_messages_to_dict(messages)}

- config: SharedState = self._convert_to_config(params)
+ context: Context = self._convert_to_context(params)
+ custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}

  response: dict[str, Sequence[BaseMessage]] = self.graph.invoke(
- request, config=config
+ request, context=context, config=custom_inputs
  )
  logger.trace(f"response: {response}")

@@ -75,12 +75,9 @@ class LanggraphChatModel(ChatModel):
  response_message = ChatMessage(role="assistant", content=last_message.content)
  return ChatCompletionResponse(choices=[ChatChoice(message=response_message)])

- def _convert_to_config(
+ def _convert_to_context(
  self, params: Optional[ChatParams | dict[str, Any]]
- ) -> RunnableConfig:
- if not params:
- return {}
-
+ ) -> Context:
  input_data = params
  if isinstance(params, ChatParams):
  input_data = params.to_dict()
@@ -102,8 +99,8 @@ class LanggraphChatModel(ChatModel):
  if "thread_id" not in configurable:
  configurable["thread_id"] = str(uuid.uuid4())

- agent_config: RunnableConfig = RunnableConfig(**{"configurable": configurable})
- return agent_config
+ context: Context = Context(**configurable)
+ return context

  def predict_stream(
  self, context, messages: list[ChatMessage], params: ChatParams
@@ -114,25 +111,36 @@ class LanggraphChatModel(ChatModel):

  request = {"messages": self._convert_messages_to_dict(messages)}

- config: SharedState = self._convert_to_config(params)
+ context: Context = self._convert_to_context(params)
+ custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}

- for message, metadata in self.graph.stream(
- request, config=config, stream_mode="messages"
+ for nodes, stream_mode, messages_batch in self.graph.stream(
+ request,
+ context=context,
+ config=custom_inputs,
+ stream_mode=["messages", "custom"],
+ subgraphs=True,
  ):
- logger.trace(f"message_type: {type(message)}, message: {message}")
- if (
- isinstance(
- message,
- (
- AIMessageChunk,
- AIMessage,
- ),
- )
- and message.content
- and metadata["langgraph_node"] not in ["summarization"]
- ):
- content = message.content
- yield self._create_chat_completion_chunk(content)
+ nodes: tuple[str, ...]
+ stream_mode: str
+ messages_batch: Sequence[BaseMessage]
+ logger.trace(
+ f"nodes: {nodes}, stream_mode: {stream_mode}, messages: {messages_batch}"
+ )
+ for message in messages_batch:
+ if (
+ isinstance(
+ message,
+ (
+ AIMessageChunk,
+ AIMessage,
+ ),
+ )
+ and message.content
+ and "summarization" not in nodes
+ ):
+ content = message.content
+ yield self._create_chat_completion_chunk(content)

  def _create_chat_completion_chunk(self, content: str) -> ChatCompletionChunk:
  return ChatCompletionChunk(
@@ -183,11 +191,37 @@ def _process_langchain_messages_stream(
  if isinstance(app, LanggraphChatModel):
  app = app.graph

- for message, _ in app.stream(
- {"messages": messages}, config=custom_inputs, stream_mode="messages"
+ logger.debug(f"Processing messages: {messages}, custom_inputs: {custom_inputs}")
+
+ custom_inputs = custom_inputs.get("configurable", custom_inputs or {})
+ context: Context = Context(**custom_inputs)
+
+ for nodes, stream_mode, messages in app.stream(
+ {"messages": messages},
+ context=context,
+ config=custom_inputs,
+ stream_mode=["messages", "custom"],
+ subgraphs=True,
  ):
- message: AIMessageChunk
- yield message
+ nodes: tuple[str, ...]
+ stream_mode: str
+ messages: Sequence[BaseMessage]
+ logger.trace(
+ f"nodes: {nodes}, stream_mode: {stream_mode}, messages: {messages}"
+ )
+ for message in messages:
+ if (
+ isinstance(
+ message,
+ (
+ AIMessageChunk,
+ AIMessage,
+ ),
+ )
+ and message.content
+ and "summarization" not in nodes
+ ):
+ yield message


  def _process_mlflow_messages(
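Streaming now requests both "messages" and "custom" stream modes with subgraphs enabled, so each item yielded by graph.stream is a (namespace, mode, payload) triple rather than a bare (message, metadata) pair. A minimal consumer sketch following the same unpacking shape used above; `graph` and `context` are assumed to exist (e.g. from AppConfig.as_graph() and dao_ai.state.Context), and the exact payload type per mode depends on the installed langgraph version:

```python
from langchain_core.messages import AIMessage, AIMessageChunk

for namespace, mode, payload in graph.stream(
    {"messages": [{"role": "user", "content": "hello"}]},
    context=context,
    stream_mode=["messages", "custom"],
    subgraphs=True,
):
    # namespace identifies the (sub)graph node path; mode is "messages" or "custom".
    if mode != "messages":
        continue
    for item in payload:
        if isinstance(item, (AIMessageChunk, AIMessage)) and item.content:
            print(item.content, end="", flush=True)
```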