dao-ai 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/cli.py +3 -0
- dao_ai/config.py +21 -3
- dao_ai/graph.py +31 -31
- dao_ai/hooks/__init__.py +2 -0
- dao_ai/hooks/core.py +96 -30
- dao_ai/memory/postgres.py +6 -6
- dao_ai/messages.py +6 -0
- dao_ai/models.py +66 -32
- dao_ai/nodes.py +12 -10
- dao_ai/providers/databricks.py +83 -3
- dao_ai/state.py +7 -0
- dao_ai/tools/__init__.py +3 -4
- dao_ai/tools/core.py +1 -294
- dao_ai/tools/human_in_the_loop.py +96 -0
- dao_ai/tools/mcp.py +118 -0
- dao_ai/tools/python.py +60 -0
- dao_ai/tools/unity_catalog.py +50 -0
- {dao_ai-0.0.5.dist-info → dao_ai-0.0.7.dist-info}/METADATA +11 -12
- dao_ai-0.0.7.dist-info/RECORD +40 -0
- dao_ai-0.0.5.dist-info/RECORD +0 -36
- {dao_ai-0.0.5.dist-info → dao_ai-0.0.7.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.5.dist-info → dao_ai-0.0.7.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.5.dist-info → dao_ai-0.0.7.dist-info}/licenses/LICENSE +0 -0
dao_ai/cli.py
CHANGED
|
@@ -3,6 +3,7 @@ import json
|
|
|
3
3
|
import os
|
|
4
4
|
import subprocess
|
|
5
5
|
import sys
|
|
6
|
+
import traceback
|
|
6
7
|
from argparse import ArgumentParser, Namespace
|
|
7
8
|
from pathlib import Path
|
|
8
9
|
from typing import Optional, Sequence
|
|
@@ -388,7 +389,9 @@ def handle_chat_command(options: Namespace) -> None:
|
|
|
388
389
|
|
|
389
390
|
except Exception as e:
|
|
390
391
|
print(f"\n❌ Error during streaming: {e}")
|
|
392
|
+
print(f"Stack trace:\n{traceback.format_exc()}")
|
|
391
393
|
logger.error(f"Streaming error: {e}")
|
|
394
|
+
logger.error(f"Stack trace: {traceback.format_exc()}")
|
|
392
395
|
|
|
393
396
|
except EOFError:
|
|
394
397
|
# Handle Ctrl-D
|
dao_ai/config.py
CHANGED
|
@@ -43,8 +43,6 @@ from mlflow.models.resources import (
|
|
|
43
43
|
)
|
|
44
44
|
from pydantic import BaseModel, ConfigDict, Field, field_serializer, model_validator
|
|
45
45
|
|
|
46
|
-
from dao_ai.chat_models import ChatDatabricksFiltered
|
|
47
|
-
|
|
48
46
|
|
|
49
47
|
class HasValue(ABC):
|
|
50
48
|
@abstractmethod
|
|
@@ -275,9 +273,14 @@ class LLMModel(BaseModel, IsDatabricksResource):
|
|
|
275
273
|
# chat_client: LanguageModelLike = self.as_open_ai_client()
|
|
276
274
|
|
|
277
275
|
# Create ChatDatabricksWrapper instance directly
|
|
276
|
+
from dao_ai.chat_models import ChatDatabricksFiltered
|
|
277
|
+
|
|
278
278
|
chat_client: LanguageModelLike = ChatDatabricksFiltered(
|
|
279
279
|
model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
|
|
280
280
|
)
|
|
281
|
+
# chat_client: LanguageModelLike = ChatDatabricks(
|
|
282
|
+
# model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
|
|
283
|
+
# )
|
|
281
284
|
|
|
282
285
|
fallbacks: Sequence[LanguageModelLike] = []
|
|
283
286
|
for fallback in self.fallbacks:
|
|
@@ -1001,7 +1004,15 @@ class ChatHistoryModel(BaseModel):
|
|
|
1001
1004
|
max_tokens: int = 256
|
|
1002
1005
|
max_tokens_before_summary: Optional[int] = None
|
|
1003
1006
|
max_messages_before_summary: Optional[int] = None
|
|
1004
|
-
max_summary_tokens:
|
|
1007
|
+
max_summary_tokens: int = 255
|
|
1008
|
+
|
|
1009
|
+
@model_validator(mode="after")
|
|
1010
|
+
def validate_max_summary_tokens(self) -> "ChatHistoryModel":
|
|
1011
|
+
if self.max_summary_tokens >= self.max_tokens:
|
|
1012
|
+
raise ValueError(
|
|
1013
|
+
f"max_summary_tokens ({self.max_summary_tokens}) must be less than max_tokens ({self.max_tokens})"
|
|
1014
|
+
)
|
|
1015
|
+
return self
|
|
1005
1016
|
|
|
1006
1017
|
|
|
1007
1018
|
class AppModel(BaseModel):
|
|
@@ -1057,12 +1068,19 @@ class AppModel(BaseModel):
|
|
|
1057
1068
|
return self
|
|
1058
1069
|
|
|
1059
1070
|
|
|
1071
|
+
class GuidelineModel(BaseModel):
|
|
1072
|
+
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1073
|
+
name: str
|
|
1074
|
+
guidelines: list[str]
|
|
1075
|
+
|
|
1076
|
+
|
|
1060
1077
|
class EvaluationModel(BaseModel):
|
|
1061
1078
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1062
1079
|
model: LLMModel
|
|
1063
1080
|
table: TableModel
|
|
1064
1081
|
num_evals: int
|
|
1065
1082
|
custom_inputs: dict[str, Any] = Field(default_factory=dict)
|
|
1083
|
+
guidelines: list[GuidelineModel] = Field(default_factory=list)
|
|
1066
1084
|
|
|
1067
1085
|
|
|
1068
1086
|
class DatasetFormat(str, Enum):
|
dao_ai/graph.py
CHANGED
|
@@ -3,13 +3,10 @@ from typing import Sequence
|
|
|
3
3
|
from langchain_core.language_models import LanguageModelLike
|
|
4
4
|
from langchain_core.runnables import RunnableConfig
|
|
5
5
|
from langchain_core.tools import BaseTool
|
|
6
|
-
from langgraph.cache.base import BaseCache
|
|
7
|
-
from langgraph.cache.memory import InMemoryCache
|
|
8
6
|
from langgraph.checkpoint.base import BaseCheckpointSaver
|
|
9
7
|
from langgraph.graph import END, StateGraph
|
|
10
8
|
from langgraph.graph.state import CompiledStateGraph
|
|
11
9
|
from langgraph.store.base import BaseStore
|
|
12
|
-
from langgraph.types import CachePolicy
|
|
13
10
|
from langgraph_supervisor import create_handoff_tool as supervisor_handoff_tool
|
|
14
11
|
from langgraph_supervisor import create_supervisor
|
|
15
12
|
from langgraph_swarm import create_handoff_tool as swarm_handoff_tool
|
|
@@ -29,7 +26,7 @@ from dao_ai.nodes import (
|
|
|
29
26
|
message_hook_node,
|
|
30
27
|
)
|
|
31
28
|
from dao_ai.prompts import make_prompt
|
|
32
|
-
from dao_ai.state import IncomingState, OutgoingState, SharedState
|
|
29
|
+
from dao_ai.state import Context, IncomingState, OutgoingState, SharedState
|
|
33
30
|
|
|
34
31
|
|
|
35
32
|
def route_message(state: SharedState) -> str:
|
|
@@ -113,9 +110,6 @@ def _create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
|
|
|
113
110
|
checkpointer = orchestration.memory.checkpointer.as_checkpointer()
|
|
114
111
|
logger.debug(f"Using checkpointer: {checkpointer}")
|
|
115
112
|
|
|
116
|
-
cache: BaseCache = None
|
|
117
|
-
cache = InMemoryCache()
|
|
118
|
-
|
|
119
113
|
prompt: str = supervisor.prompt
|
|
120
114
|
|
|
121
115
|
model: LanguageModelLike = supervisor.model.as_chat_model()
|
|
@@ -127,22 +121,29 @@ def _create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
|
|
|
127
121
|
tools=tools,
|
|
128
122
|
state_schema=SharedState,
|
|
129
123
|
config_schema=RunnableConfig,
|
|
124
|
+
output_mode="last_message",
|
|
125
|
+
add_handoff_messages=False,
|
|
126
|
+
add_handoff_back_messages=False,
|
|
127
|
+
context_schema=Context,
|
|
128
|
+
# output_mode="full",
|
|
129
|
+
# add_handoff_messages=True,
|
|
130
|
+
# add_handoff_back_messages=True,
|
|
130
131
|
)
|
|
131
132
|
|
|
132
|
-
supervisor_node: CompiledStateGraph = supervisor_workflow.compile(
|
|
133
|
+
supervisor_node: CompiledStateGraph = supervisor_workflow.compile(
|
|
134
|
+
checkpointer=checkpointer, store=store
|
|
135
|
+
)
|
|
133
136
|
|
|
134
137
|
workflow: StateGraph = StateGraph(
|
|
135
138
|
SharedState,
|
|
136
|
-
config_schema=RunnableConfig,
|
|
137
139
|
input=IncomingState,
|
|
138
140
|
output=OutgoingState,
|
|
141
|
+
context_schema=Context,
|
|
139
142
|
)
|
|
140
143
|
|
|
141
144
|
workflow.add_node("message_hook", message_hook_node(config=config))
|
|
142
145
|
|
|
143
|
-
workflow.add_node(
|
|
144
|
-
"orchestration", supervisor_node, cache_policy=CachePolicy(ttl=60)
|
|
145
|
-
)
|
|
146
|
+
workflow.add_node("orchestration", supervisor_node)
|
|
146
147
|
workflow.add_conditional_edges(
|
|
147
148
|
"message_hook",
|
|
148
149
|
route_message,
|
|
@@ -153,7 +154,7 @@ def _create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
|
|
|
153
154
|
)
|
|
154
155
|
workflow.set_entry_point("message_hook")
|
|
155
156
|
|
|
156
|
-
return workflow.compile(checkpointer=checkpointer, store=store
|
|
157
|
+
return workflow.compile(checkpointer=checkpointer, store=store)
|
|
157
158
|
|
|
158
159
|
|
|
159
160
|
def _create_swarm_graph(config: AppConfig) -> CompiledStateGraph:
|
|
@@ -172,6 +173,16 @@ def _create_swarm_graph(config: AppConfig) -> CompiledStateGraph:
|
|
|
172
173
|
orchestration: OrchestrationModel = config.app.orchestration
|
|
173
174
|
swarm: SwarmModel = orchestration.swarm
|
|
174
175
|
|
|
176
|
+
store: BaseStore = None
|
|
177
|
+
if orchestration.memory and orchestration.memory.store:
|
|
178
|
+
store = orchestration.memory.store.as_store()
|
|
179
|
+
logger.debug(f"Using memory store: {store}")
|
|
180
|
+
|
|
181
|
+
checkpointer: BaseCheckpointSaver = None
|
|
182
|
+
if orchestration.memory and orchestration.memory.checkpointer:
|
|
183
|
+
checkpointer = orchestration.memory.checkpointer.as_checkpointer()
|
|
184
|
+
logger.debug(f"Using checkpointer: {checkpointer}")
|
|
185
|
+
|
|
175
186
|
default_agent: AgentModel = swarm.default_agent
|
|
176
187
|
if isinstance(default_agent, AgentModel):
|
|
177
188
|
default_agent = default_agent.name
|
|
@@ -180,20 +191,22 @@ def _create_swarm_graph(config: AppConfig) -> CompiledStateGraph:
|
|
|
180
191
|
agents=agents,
|
|
181
192
|
default_active_agent=default_agent,
|
|
182
193
|
state_schema=SharedState,
|
|
183
|
-
|
|
194
|
+
context_schema=Context,
|
|
184
195
|
)
|
|
185
196
|
|
|
186
|
-
swarm_node: CompiledStateGraph = swarm_workflow.compile(
|
|
197
|
+
swarm_node: CompiledStateGraph = swarm_workflow.compile(
|
|
198
|
+
checkpointer=checkpointer, store=store
|
|
199
|
+
)
|
|
187
200
|
|
|
188
201
|
workflow: StateGraph = StateGraph(
|
|
189
202
|
SharedState,
|
|
190
|
-
config_schema=RunnableConfig,
|
|
191
203
|
input=IncomingState,
|
|
192
204
|
output=OutgoingState,
|
|
205
|
+
context_schema=Context,
|
|
193
206
|
)
|
|
194
207
|
|
|
195
208
|
workflow.add_node("message_hook", message_hook_node(config=config))
|
|
196
|
-
workflow.add_node("orchestration", swarm_node
|
|
209
|
+
workflow.add_node("orchestration", swarm_node)
|
|
197
210
|
|
|
198
211
|
workflow.add_conditional_edges(
|
|
199
212
|
"message_hook",
|
|
@@ -206,20 +219,7 @@ def _create_swarm_graph(config: AppConfig) -> CompiledStateGraph:
|
|
|
206
219
|
|
|
207
220
|
workflow.set_entry_point("message_hook")
|
|
208
221
|
|
|
209
|
-
|
|
210
|
-
if orchestration.memory and orchestration.memory.store:
|
|
211
|
-
store = orchestration.memory.store.as_store()
|
|
212
|
-
logger.debug(f"Using memory store: {store}")
|
|
213
|
-
|
|
214
|
-
checkpointer: BaseCheckpointSaver = None
|
|
215
|
-
if orchestration.memory and orchestration.memory.checkpointer:
|
|
216
|
-
checkpointer = orchestration.memory.checkpointer.as_checkpointer()
|
|
217
|
-
logger.debug(f"Using checkpointer: {checkpointer}")
|
|
218
|
-
|
|
219
|
-
cache: BaseCache = None
|
|
220
|
-
cache = InMemoryCache()
|
|
221
|
-
|
|
222
|
-
return workflow.compile(checkpointer=checkpointer, store=store, cache=cache)
|
|
222
|
+
return workflow.compile(checkpointer=checkpointer, store=store)
|
|
223
223
|
|
|
224
224
|
|
|
225
225
|
def create_dao_ai_graph(config: AppConfig) -> CompiledStateGraph:
|
dao_ai/hooks/__init__.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from dao_ai.hooks.core import (
|
|
2
2
|
create_hooks,
|
|
3
|
+
filter_last_human_message_hook,
|
|
3
4
|
null_hook,
|
|
4
5
|
null_initialization_hook,
|
|
5
6
|
null_shutdown_hook,
|
|
@@ -14,4 +15,5 @@ __all__ = [
|
|
|
14
15
|
"null_shutdown_hook",
|
|
15
16
|
"require_thread_id_hook",
|
|
16
17
|
"require_user_id_hook",
|
|
18
|
+
"filter_last_human_message_hook",
|
|
17
19
|
]
|
dao_ai/hooks/core.py
CHANGED
|
@@ -1,9 +1,13 @@
|
|
|
1
1
|
import json
|
|
2
2
|
from typing import Any, Callable, Sequence
|
|
3
3
|
|
|
4
|
+
from langchain_core.messages import BaseMessage, HumanMessage, RemoveMessage
|
|
5
|
+
from langgraph.runtime import Runtime
|
|
4
6
|
from loguru import logger
|
|
5
7
|
|
|
6
8
|
from dao_ai.config import AppConfig, FunctionHook, PythonFunctionModel
|
|
9
|
+
from dao_ai.messages import last_human_message
|
|
10
|
+
from dao_ai.state import Context
|
|
7
11
|
|
|
8
12
|
|
|
9
13
|
def create_hooks(
|
|
@@ -23,7 +27,7 @@ def create_hooks(
|
|
|
23
27
|
return hooks
|
|
24
28
|
|
|
25
29
|
|
|
26
|
-
def null_hook(state: dict[str, Any],
|
|
30
|
+
def null_hook(state: dict[str, Any], runtime: Runtime[Context]) -> dict[str, Any]:
|
|
27
31
|
logger.debug("Executing null hook")
|
|
28
32
|
return {}
|
|
29
33
|
|
|
@@ -36,19 +40,79 @@ def null_shutdown_hook(config: AppConfig) -> None:
|
|
|
36
40
|
logger.debug("Executing null shutdown hook")
|
|
37
41
|
|
|
38
42
|
|
|
43
|
+
def filter_last_human_message_hook(
|
|
44
|
+
state: dict[str, Any], runtime: Runtime[Context]
|
|
45
|
+
) -> dict[str, Any]:
|
|
46
|
+
"""
|
|
47
|
+
Filter messages to keep only the last human message.
|
|
48
|
+
|
|
49
|
+
This hook removes all messages except for the most recent human message,
|
|
50
|
+
which can be useful for scenarios where you want to process only the
|
|
51
|
+
latest user input without conversation history.
|
|
52
|
+
|
|
53
|
+
Args:
|
|
54
|
+
state: The current state containing messages
|
|
55
|
+
runtime: The runtime context (unused in this hook)
|
|
56
|
+
|
|
57
|
+
Returns:
|
|
58
|
+
Updated state with filtered messages
|
|
59
|
+
"""
|
|
60
|
+
logger.debug("Executing filter_last_human_message hook")
|
|
61
|
+
|
|
62
|
+
messages: list[BaseMessage] = state.get("messages", [])
|
|
63
|
+
|
|
64
|
+
if not messages:
|
|
65
|
+
logger.debug("No messages found in state")
|
|
66
|
+
return state
|
|
67
|
+
|
|
68
|
+
# Use the helper function to find the last human message
|
|
69
|
+
last_message: HumanMessage = last_human_message(messages)
|
|
70
|
+
|
|
71
|
+
if last_message is None:
|
|
72
|
+
logger.debug("No human messages found in state")
|
|
73
|
+
# Return empty messages if no human message found
|
|
74
|
+
updated_state = state.copy()
|
|
75
|
+
updated_state["messages"] = []
|
|
76
|
+
return updated_state
|
|
77
|
+
|
|
78
|
+
logger.debug(f"Filtered {len(messages)} messages down to 1 (last human message)")
|
|
79
|
+
|
|
80
|
+
removed_messages: Sequence[BaseMessage] = [
|
|
81
|
+
RemoveMessage(id=message.id)
|
|
82
|
+
for message in messages
|
|
83
|
+
if message.id != last_message.id
|
|
84
|
+
]
|
|
85
|
+
|
|
86
|
+
updated_state: dict[str, Sequence[BaseMessage]] = {"messages": removed_messages}
|
|
87
|
+
|
|
88
|
+
return updated_state
|
|
89
|
+
|
|
90
|
+
|
|
39
91
|
def require_user_id_hook(
|
|
40
|
-
state: dict[str, Any],
|
|
92
|
+
state: dict[str, Any], runtime: Runtime[Context]
|
|
41
93
|
) -> dict[str, Any]:
|
|
42
94
|
logger.debug("Executing user_id validation hook")
|
|
43
95
|
|
|
44
|
-
|
|
96
|
+
context: Context = runtime.context or Context()
|
|
45
97
|
|
|
46
|
-
|
|
98
|
+
user_id: str | None = context.user_id
|
|
47
99
|
|
|
48
|
-
if
|
|
100
|
+
if not user_id:
|
|
49
101
|
logger.error("User ID is required but not provided in the configuration.")
|
|
50
102
|
|
|
51
|
-
|
|
103
|
+
# Create corrected configuration using any provided context parameters
|
|
104
|
+
corrected_config = {
|
|
105
|
+
"configurable": {
|
|
106
|
+
"thread_id": context.thread_id or "1",
|
|
107
|
+
"user_id": "my_user_id",
|
|
108
|
+
"store_num": context.store_num or 87887,
|
|
109
|
+
}
|
|
110
|
+
}
|
|
111
|
+
|
|
112
|
+
# Format as JSON for copy-paste
|
|
113
|
+
corrected_config_json = json.dumps(corrected_config, indent=2)
|
|
114
|
+
|
|
115
|
+
error_message = f"""
|
|
52
116
|
## Authentication Required
|
|
53
117
|
|
|
54
118
|
A **user_id** is required to process your request. Please provide your user ID in the configuration.
|
|
@@ -58,25 +122,19 @@ A **user_id** is required to process your request. Please provide your user ID i
|
|
|
58
122
|
Please include the following JSON in your request configuration:
|
|
59
123
|
|
|
60
124
|
```json
|
|
61
|
-
{
|
|
62
|
-
"configurable": {
|
|
63
|
-
"thread_id": "1",
|
|
64
|
-
"user_id": "my_user_id"
|
|
65
|
-
}
|
|
66
|
-
}
|
|
125
|
+
{corrected_config_json}
|
|
67
126
|
```
|
|
68
127
|
|
|
69
128
|
### Field Descriptions
|
|
70
129
|
- **user_id**: Your unique user identifier (required)
|
|
71
|
-
- **thread_id**: Conversation thread identifier (
|
|
130
|
+
- **thread_id**: Conversation thread identifier (required)
|
|
131
|
+
- **store_num**: Your store number (required)
|
|
72
132
|
|
|
73
133
|
Please update your configuration and try again.
|
|
74
134
|
""".strip()
|
|
75
135
|
|
|
76
136
|
raise ValueError(error_message)
|
|
77
137
|
|
|
78
|
-
# Validate that user_id doesn't contain dots
|
|
79
|
-
user_id = configurable["user_id"]
|
|
80
138
|
if "." in user_id:
|
|
81
139
|
logger.error(f"User ID '{user_id}' contains invalid character '.'")
|
|
82
140
|
|
|
@@ -88,9 +146,9 @@ Please update your configuration and try again.
|
|
|
88
146
|
# Corrected config with fixed user_id
|
|
89
147
|
corrected_config = {
|
|
90
148
|
"configurable": {
|
|
91
|
-
"thread_id":
|
|
149
|
+
"thread_id": context.thread_id or "1",
|
|
92
150
|
"user_id": corrected_user_id,
|
|
93
|
-
"store_num":
|
|
151
|
+
"store_num": context.store_num or 87887,
|
|
94
152
|
}
|
|
95
153
|
}
|
|
96
154
|
|
|
@@ -116,38 +174,46 @@ Please update your user_id and try again.
|
|
|
116
174
|
|
|
117
175
|
|
|
118
176
|
def require_thread_id_hook(
|
|
119
|
-
state: dict[str, Any],
|
|
177
|
+
state: dict[str, Any], runtime: Runtime[Context]
|
|
120
178
|
) -> dict[str, Any]:
|
|
121
179
|
logger.debug("Executing thread_id validation hook")
|
|
122
180
|
|
|
123
|
-
|
|
181
|
+
context: Context = runtime.context or Context()
|
|
124
182
|
|
|
125
|
-
|
|
183
|
+
thread_id: str | None = context.thread_id
|
|
126
184
|
|
|
127
|
-
if
|
|
185
|
+
if not thread_id:
|
|
128
186
|
logger.error("Thread ID is required but not provided in the configuration.")
|
|
129
187
|
|
|
130
|
-
|
|
188
|
+
# Create corrected configuration using any provided context parameters
|
|
189
|
+
corrected_config = {
|
|
190
|
+
"configurable": {
|
|
191
|
+
"thread_id": "1",
|
|
192
|
+
"user_id": context.user_id or "my_user_id",
|
|
193
|
+
"store_num": context.store_num or 87887,
|
|
194
|
+
}
|
|
195
|
+
}
|
|
196
|
+
|
|
197
|
+
# Format as JSON for copy-paste
|
|
198
|
+
corrected_config_json = json.dumps(corrected_config, indent=2)
|
|
199
|
+
|
|
200
|
+
error_message = f"""
|
|
131
201
|
## Authentication Required
|
|
132
202
|
|
|
133
|
-
A **thread_id** is required to process your request. Please provide your
|
|
203
|
+
A **thread_id** is required to process your request. Please provide your thread ID in the configuration.
|
|
134
204
|
|
|
135
205
|
### Required Configuration Format
|
|
136
206
|
|
|
137
207
|
Please include the following JSON in your request configuration:
|
|
138
208
|
|
|
139
209
|
```json
|
|
140
|
-
{
|
|
141
|
-
"configurable": {
|
|
142
|
-
"thread_id": "1",
|
|
143
|
-
"user_id": "my_user_id"
|
|
144
|
-
}
|
|
145
|
-
}
|
|
210
|
+
{corrected_config_json}
|
|
146
211
|
```
|
|
147
212
|
|
|
148
213
|
### Field Descriptions
|
|
214
|
+
- **thread_id**: Conversation thread identifier (required)
|
|
149
215
|
- **user_id**: Your unique user identifier (required)
|
|
150
|
-
- **
|
|
216
|
+
- **store_num**: Your store number (required)
|
|
151
217
|
|
|
152
218
|
Please update your configuration and try again.
|
|
153
219
|
""".strip()
|
dao_ai/memory/postgres.py
CHANGED
|
@@ -4,8 +4,8 @@ import threading
|
|
|
4
4
|
from typing import Any, Optional
|
|
5
5
|
|
|
6
6
|
from langgraph.checkpoint.base import BaseCheckpointSaver
|
|
7
|
-
from langgraph.checkpoint.postgres import
|
|
8
|
-
from langgraph.checkpoint.postgres.aio import
|
|
7
|
+
from langgraph.checkpoint.postgres import ShallowPostgresSaver
|
|
8
|
+
from langgraph.checkpoint.postgres.aio import AsyncShallowPostgresSaver
|
|
9
9
|
from langgraph.store.base import BaseStore
|
|
10
10
|
from langgraph.store.postgres import PostgresStore
|
|
11
11
|
from langgraph.store.postgres.aio import AsyncPostgresStore
|
|
@@ -141,7 +141,7 @@ class AsyncPostgresCheckpointerManager(CheckpointManagerBase):
|
|
|
141
141
|
def __init__(self, checkpointer_model: CheckpointerModel):
|
|
142
142
|
self.checkpointer_model = checkpointer_model
|
|
143
143
|
self.pool: Optional[AsyncConnectionPool] = None
|
|
144
|
-
self._checkpointer: Optional[
|
|
144
|
+
self._checkpointer: Optional[AsyncShallowPostgresSaver] = None
|
|
145
145
|
self._setup_complete = False
|
|
146
146
|
|
|
147
147
|
def checkpointer(self) -> BaseCheckpointSaver:
|
|
@@ -183,7 +183,7 @@ class AsyncPostgresCheckpointerManager(CheckpointManagerBase):
|
|
|
183
183
|
)
|
|
184
184
|
|
|
185
185
|
# Create checkpointer with the shared pool
|
|
186
|
-
self._checkpointer =
|
|
186
|
+
self._checkpointer = AsyncShallowPostgresSaver(conn=self.pool)
|
|
187
187
|
await self._checkpointer.setup()
|
|
188
188
|
|
|
189
189
|
self._setup_complete = True
|
|
@@ -315,7 +315,7 @@ class PostgresCheckpointerManager(CheckpointManagerBase):
|
|
|
315
315
|
def __init__(self, checkpointer_model: CheckpointerModel):
|
|
316
316
|
self.checkpointer_model = checkpointer_model
|
|
317
317
|
self.pool: Optional[ConnectionPool] = None
|
|
318
|
-
self._checkpointer: Optional[
|
|
318
|
+
self._checkpointer: Optional[ShallowPostgresSaver] = None
|
|
319
319
|
self._setup_complete = False
|
|
320
320
|
|
|
321
321
|
def checkpointer(self) -> BaseCheckpointSaver:
|
|
@@ -345,7 +345,7 @@ class PostgresCheckpointerManager(CheckpointManagerBase):
|
|
|
345
345
|
self.pool = PostgresPoolManager.get_pool(self.checkpointer_model.database)
|
|
346
346
|
|
|
347
347
|
# Create checkpointer with the shared pool
|
|
348
|
-
self._checkpointer =
|
|
348
|
+
self._checkpointer = ShallowPostgresSaver(conn=self.pool)
|
|
349
349
|
self._checkpointer.setup()
|
|
350
350
|
|
|
351
351
|
self._setup_complete = True
|
dao_ai/messages.py
CHANGED
|
@@ -78,6 +78,12 @@ def convert_to_langchain_messages(messages: dict[str, Any]) -> Sequence[BaseMess
|
|
|
78
78
|
return langchain_messages
|
|
79
79
|
|
|
80
80
|
|
|
81
|
+
def has_human_message(messages: BaseMessage | Sequence[BaseMessage]) -> bool:
|
|
82
|
+
if isinstance(messages, BaseMessage):
|
|
83
|
+
messages = [messages]
|
|
84
|
+
return any(isinstance(m, HumanMessage) for m in messages)
|
|
85
|
+
|
|
86
|
+
|
|
81
87
|
def has_langchain_messages(messages: BaseMessage | Sequence[BaseMessage]) -> bool:
|
|
82
88
|
if isinstance(messages, BaseMessage):
|
|
83
89
|
messages = [messages]
|
dao_ai/models.py
CHANGED
|
@@ -4,7 +4,6 @@ from pathlib import Path
|
|
|
4
4
|
from typing import Any, Generator, Optional, Sequence
|
|
5
5
|
|
|
6
6
|
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
|
7
|
-
from langchain_core.runnables import RunnableConfig
|
|
8
7
|
from langgraph.graph.state import CompiledStateGraph
|
|
9
8
|
from loguru import logger
|
|
10
9
|
from mlflow import MlflowClient
|
|
@@ -20,7 +19,7 @@ from mlflow.types.llm import (
|
|
|
20
19
|
)
|
|
21
20
|
|
|
22
21
|
from dao_ai.messages import has_langchain_messages, has_mlflow_messages
|
|
23
|
-
from dao_ai.state import
|
|
22
|
+
from dao_ai.state import Context
|
|
24
23
|
|
|
25
24
|
|
|
26
25
|
def get_latest_model_version(model_name: str) -> int:
|
|
@@ -63,10 +62,11 @@ class LanggraphChatModel(ChatModel):
|
|
|
63
62
|
|
|
64
63
|
request = {"messages": self._convert_messages_to_dict(messages)}
|
|
65
64
|
|
|
66
|
-
|
|
65
|
+
context: Context = self._convert_to_context(params)
|
|
66
|
+
custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}
|
|
67
67
|
|
|
68
68
|
response: dict[str, Sequence[BaseMessage]] = self.graph.invoke(
|
|
69
|
-
request, config=
|
|
69
|
+
request, context=context, config=custom_inputs
|
|
70
70
|
)
|
|
71
71
|
logger.trace(f"response: {response}")
|
|
72
72
|
|
|
@@ -75,12 +75,9 @@ class LanggraphChatModel(ChatModel):
|
|
|
75
75
|
response_message = ChatMessage(role="assistant", content=last_message.content)
|
|
76
76
|
return ChatCompletionResponse(choices=[ChatChoice(message=response_message)])
|
|
77
77
|
|
|
78
|
-
def
|
|
78
|
+
def _convert_to_context(
|
|
79
79
|
self, params: Optional[ChatParams | dict[str, Any]]
|
|
80
|
-
) ->
|
|
81
|
-
if not params:
|
|
82
|
-
return {}
|
|
83
|
-
|
|
80
|
+
) -> Context:
|
|
84
81
|
input_data = params
|
|
85
82
|
if isinstance(params, ChatParams):
|
|
86
83
|
input_data = params.to_dict()
|
|
@@ -102,8 +99,8 @@ class LanggraphChatModel(ChatModel):
|
|
|
102
99
|
if "thread_id" not in configurable:
|
|
103
100
|
configurable["thread_id"] = str(uuid.uuid4())
|
|
104
101
|
|
|
105
|
-
|
|
106
|
-
return
|
|
102
|
+
context: Context = Context(**configurable)
|
|
103
|
+
return context
|
|
107
104
|
|
|
108
105
|
def predict_stream(
|
|
109
106
|
self, context, messages: list[ChatMessage], params: ChatParams
|
|
@@ -114,25 +111,36 @@ class LanggraphChatModel(ChatModel):
|
|
|
114
111
|
|
|
115
112
|
request = {"messages": self._convert_messages_to_dict(messages)}
|
|
116
113
|
|
|
117
|
-
|
|
114
|
+
context: Context = self._convert_to_context(params)
|
|
115
|
+
custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}
|
|
118
116
|
|
|
119
|
-
for
|
|
120
|
-
request,
|
|
117
|
+
for nodes, stream_mode, messages_batch in self.graph.stream(
|
|
118
|
+
request,
|
|
119
|
+
context=context,
|
|
120
|
+
config=custom_inputs,
|
|
121
|
+
stream_mode=["messages", "custom"],
|
|
122
|
+
subgraphs=True,
|
|
121
123
|
):
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
124
|
+
nodes: tuple[str, ...]
|
|
125
|
+
stream_mode: str
|
|
126
|
+
messages_batch: Sequence[BaseMessage]
|
|
127
|
+
logger.trace(
|
|
128
|
+
f"nodes: {nodes}, stream_mode: {stream_mode}, messages: {messages_batch}"
|
|
129
|
+
)
|
|
130
|
+
for message in messages_batch:
|
|
131
|
+
if (
|
|
132
|
+
isinstance(
|
|
133
|
+
message,
|
|
134
|
+
(
|
|
135
|
+
AIMessageChunk,
|
|
136
|
+
AIMessage,
|
|
137
|
+
),
|
|
138
|
+
)
|
|
139
|
+
and message.content
|
|
140
|
+
and "summarization" not in nodes
|
|
141
|
+
):
|
|
142
|
+
content = message.content
|
|
143
|
+
yield self._create_chat_completion_chunk(content)
|
|
136
144
|
|
|
137
145
|
def _create_chat_completion_chunk(self, content: str) -> ChatCompletionChunk:
|
|
138
146
|
return ChatCompletionChunk(
|
|
@@ -183,11 +191,37 @@ def _process_langchain_messages_stream(
|
|
|
183
191
|
if isinstance(app, LanggraphChatModel):
|
|
184
192
|
app = app.graph
|
|
185
193
|
|
|
186
|
-
|
|
187
|
-
|
|
194
|
+
logger.debug(f"Processing messages: {messages}, custom_inputs: {custom_inputs}")
|
|
195
|
+
|
|
196
|
+
custom_inputs = custom_inputs.get("configurable", custom_inputs or {})
|
|
197
|
+
context: Context = Context(**custom_inputs)
|
|
198
|
+
|
|
199
|
+
for nodes, stream_mode, messages in app.stream(
|
|
200
|
+
{"messages": messages},
|
|
201
|
+
context=context,
|
|
202
|
+
config=custom_inputs,
|
|
203
|
+
stream_mode=["messages", "custom"],
|
|
204
|
+
subgraphs=True,
|
|
188
205
|
):
|
|
189
|
-
|
|
190
|
-
|
|
206
|
+
nodes: tuple[str, ...]
|
|
207
|
+
stream_mode: str
|
|
208
|
+
messages: Sequence[BaseMessage]
|
|
209
|
+
logger.trace(
|
|
210
|
+
f"nodes: {nodes}, stream_mode: {stream_mode}, messages: {messages}"
|
|
211
|
+
)
|
|
212
|
+
for message in messages:
|
|
213
|
+
if (
|
|
214
|
+
isinstance(
|
|
215
|
+
message,
|
|
216
|
+
(
|
|
217
|
+
AIMessageChunk,
|
|
218
|
+
AIMessage,
|
|
219
|
+
),
|
|
220
|
+
)
|
|
221
|
+
and message.content
|
|
222
|
+
and "summarization" not in nodes
|
|
223
|
+
):
|
|
224
|
+
yield message
|
|
191
225
|
|
|
192
226
|
|
|
193
227
|
def _process_mlflow_messages(
|