dao-ai 0.0.36__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to the supported public registries. It is provided for informational purposes only.
- dao_ai/__init__.py +29 -0
- dao_ai/cli.py +195 -30
- dao_ai/config.py +770 -244
- dao_ai/genie/__init__.py +1 -22
- dao_ai/genie/cache/__init__.py +1 -2
- dao_ai/genie/cache/base.py +20 -70
- dao_ai/genie/cache/core.py +75 -0
- dao_ai/genie/cache/lru.py +44 -21
- dao_ai/genie/cache/semantic.py +390 -109
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +22 -190
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +23 -5
- dao_ai/memory/databricks.py +389 -0
- dao_ai/memory/postgres.py +2 -2
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +778 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +61 -0
- dao_ai/middleware/guardrails.py +415 -0
- dao_ai/middleware/human_in_the_loop.py +228 -0
- dao_ai/middleware/message_validation.py +554 -0
- dao_ai/middleware/summarization.py +192 -0
- dao_ai/models.py +1177 -108
- dao_ai/nodes.py +118 -161
- dao_ai/optimization.py +664 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +287 -0
- dao_ai/orchestration/supervisor.py +264 -0
- dao_ai/orchestration/swarm.py +226 -0
- dao_ai/prompts.py +126 -29
- dao_ai/providers/databricks.py +126 -381
- dao_ai/state.py +139 -21
- dao_ai/tools/__init__.py +8 -5
- dao_ai/tools/core.py +57 -4
- dao_ai/tools/email.py +280 -0
- dao_ai/tools/genie.py +47 -24
- dao_ai/tools/mcp.py +4 -3
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +4 -12
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +1 -1
- dao_ai/tools/unity_catalog.py +8 -6
- dao_ai/tools/vector_search.py +16 -9
- dao_ai/utils.py +72 -8
- dao_ai-0.1.1.dist-info/METADATA +1878 -0
- dao_ai-0.1.1.dist-info/RECORD +62 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/genie/__init__.py +0 -236
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.36.dist-info/METADATA +0 -951
- dao_ai-0.0.36.dist-info/RECORD +0 -47
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.1.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.1.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.1.dist-info}/licenses/LICENSE +0 -0
dao_ai-0.1.1.dist-info/RECORD ADDED
@@ -0,0 +1,62 @@
+dao_ai/__init__.py,sha256=18P98ExEgUaJ1Byw440Ct1ty59v6nxyWtc5S6Uq2m9Q,1062
+dao_ai/agent_as_code.py,sha256=sviZQV7ZPxE5zkZ9jAbfegI681nra5i8yYxw05e3X7U,552
+dao_ai/catalog.py,sha256=sPZpHTD3lPx4EZUtIWeQV7VQM89WJ6YH__wluk1v2lE,4947
+dao_ai/cli.py,sha256=AoGw4erjF1C8_2XEppj9yU8agQKSRnJr-x-I8Y-oAx8,31121
+dao_ai/config.py,sha256=VjFFoBNer7ylE6u3lKgqGVUN4VDlIyeUlTcXuMkNHVo,101027
+dao_ai/graph.py,sha256=1-uQlo7iXZQTT3uU8aYu0N5rnhw5_g_2YLwVsAs6M-U,1119
+dao_ai/messages.py,sha256=4ZBzO4iFdktGSLrmhHzFjzMIt2tpaL-aQLHOQJysGnY,6959
+dao_ai/models.py,sha256=hOUEPyWZNOL0SmA0V9e0fyaO2ZtuAd-rr59M_3CoggQ,77734
+dao_ai/nodes.py,sha256=2DeiR6WmLlfiWMIaT8Gj9FOC2Ewk5bammMKXmVb1VTc,7299
+dao_ai/optimization.py,sha256=KwdKR9njYXe1aIlwGXJ8f5dE7G83aV1CxNqhvjBfGis,22780
+dao_ai/prompts.py,sha256=5Mh-8OYMASMHid58iFAj-rr6K1B41gmNN91Ia8Ciq_I,4846
+dao_ai/state.py,sha256=fNRZ8_M2VBqO8lXS3Egh53PH6LsD3bB8npOG44XWxGA,5641
+dao_ai/types.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+dao_ai/utils.py,sha256=elkkhNe0L976oTdaEjxHsmW5Fc9VLLS4IjS7q5HgRdA,11401
+dao_ai/vector_search.py,sha256=jlaFS_iizJ55wblgzZmswMM3UOL-qOp2BGJc0JqXYSg,2839
+dao_ai/genie/__init__.py,sha256=vdEyGhrt6L8GlK75SyYvTnl8QpHKDCJC5hJKLg4DesQ,1063
+dao_ai/genie/core.py,sha256=HPKbocvhnnw_PkQwfoq5bpgQmL9lZyyS6_goTJL8yiY,1073
+dao_ai/genie/cache/__init__.py,sha256=JfgCJl1NYQ1aZvZ4kly4T6uQK6ZCJ6PX_htuq7nJF50,1203
+dao_ai/genie/cache/base.py,sha256=_MhHqYrHejVGrJjSLX26TdHwvQZb-HgiantRYSB8fJY,1961
+dao_ai/genie/cache/core.py,sha256=Oe7MSreefcaOri4eSpJJFQBBfNP0MoW7qQdbA0vq89U,2439
+dao_ai/genie/cache/lru.py,sha256=4oXINJnTb-5FvKVd-U3J6gUI9o9jUEoBLOHsxlkvW34,11465
+dao_ai/genie/cache/semantic.py,sha256=RZAzrniuW_VmeWZoHWfbPr2eP3m9gj4fi4vEDM_fvYA,36892
+dao_ai/hooks/__init__.py,sha256=uA4DQdP9gDf4SyNjNx9mWPoI8UZOcTyFsCXV0NraFvQ,463
+dao_ai/hooks/core.py,sha256=upTAI11RncWZ5uu2wBA6nJMaqRMVHJFuzci-iVL8MFw,1700
+dao_ai/memory/__init__.py,sha256=Us3wFehvug_h83m-UJ7OXdq2qZ0e9nHBQE7m5RwoAd8,559
+dao_ai/memory/base.py,sha256=99nfr2UZJ4jmfTL_KrqUlRSCoRxzkZyWyx5WqeUoMdQ,338
+dao_ai/memory/core.py,sha256=vhN-cDZxhtykwXFnRK5HWBbMvlbYyafRoIgew-7ufqI,5233
+dao_ai/memory/databricks.py,sha256=PJSNKD1C48-jKXNg3jPYEANKbP4bH-zWLpMZG5SHo0A,14136
+dao_ai/memory/postgres.py,sha256=5T9qTSijxUl2zCEiW3REqgsdsPjgdxfQTGTf6gBY2M8,15362
+dao_ai/middleware/__init__.py,sha256=epSCtCtttIogl21nVK768Ln35L0mOShVczyURtR6Ln8,3609
+dao_ai/middleware/assertions.py,sha256=EgFVqxtD_UWBTQkXHmoduVYRuBYZtBbfcDCieMj3PXM,26651
+dao_ai/middleware/base.py,sha256=uG2tpdnjL5xY5jCKvb_m3UTBtl4ZC6fJQUkDsQvV8S4,1279
+dao_ai/middleware/core.py,sha256=Umq4wgxlsDHN_-LYMjlrQur2R6j9LBYozP7uCEFC3uI,1852
+dao_ai/middleware/guardrails.py,sha256=U2U5bzWTJ5UdceXbEwSZS4brhpFZhFhRmx9RmbcikJE,13967
+dao_ai/middleware/human_in_the_loop.py,sha256=hivU5CugD4D-9j_d5AHqu6atIc6s6GHel5VxHTbgcys,7435
+dao_ai/middleware/message_validation.py,sha256=zBAAlI_th1SdxX06p4_4hCKKkM4v0Dkh8TEWGHeCDYc,18471
+dao_ai/middleware/summarization.py,sha256=VmpFXtUmFA5UybGvrODzzerZ1aoVQEFvgTTU90clo4M,7152
+dao_ai/orchestration/__init__.py,sha256=i85CLfRR335NcCFhaXABcMkn6WZfXnJ8cHH4YZsZN0s,1622
+dao_ai/orchestration/core.py,sha256=4EdorNTs-NV3KkX_K3RtwRsKH3-K4sT_ORO-lsMYpYo,9418
+dao_ai/orchestration/supervisor.py,sha256=tjZwwcwXzKHjwf1x7x0MT0QqywAWP6ODvQhrK9jJ3Jk,9433
+dao_ai/orchestration/swarm.py,sha256=d7lq5KFFBX3qxSoTAIEKVZJtDJiEjSLviGUtas3-4pY,7603
+dao_ai/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+dao_ai/providers/base.py,sha256=-fjKypCOk28h6vioPfMj9YZSw_3Kcbi2nMuAyY7vX9k,1383
+dao_ai/providers/databricks.py,sha256=rzFNVLmZYO2Lm3U3kGGQsEdn4W8mm04NUdnSoiPD12o,54714
+dao_ai/tools/__init__.py,sha256=WLb4mgC7WUbaSDOVpjlf9tJkm5Dr4vEmxG5hnyCgeSc,1568
+dao_ai/tools/agent.py,sha256=WbQnyziiT12TLMrA7xK0VuOU029tdmUBXbUl-R1VZ0Q,1886
+dao_ai/tools/core.py,sha256=f6unLEtYfLZ9eDhoqgaREFylrb9JZ6zlZBYV_45Albo,3785
+dao_ai/tools/email.py,sha256=p7FHlA9gq0LR9-YLuJL1Glq3n465wFu2tqDPX-UW0aY,9662
+dao_ai/tools/genie.py,sha256=zfHAosWMLl-us2vG1c7ziEe8oQehBrxZISWqmC7GZ5s,8942
+dao_ai/tools/mcp.py,sha256=C1NAUKkgl3QUkMTkCZKO1uIMU0gViWOFSjiq-t_tNGk,8454
+dao_ai/tools/memory.py,sha256=lwObKimAand22Nq3Y63tsv-AXQ5SXUigN9PqRjoWKes,1836
+dao_ai/tools/python.py,sha256=D021e4kPlH-CkrQGpzM3n7DZJtTCv2LCW0VPVQGb5Sk,1734
+dao_ai/tools/search.py,sha256=AdQf8UouvJjw5UdCmD1-jrX_PJ-cNZ1oY_jh01dHGxY,433
+dao_ai/tools/slack.py,sha256=NWtzxpUJOxXbngThpb2KYfVvqSOi312crLCCA0xtv-w,4775
+dao_ai/tools/time.py,sha256=Y-23qdnNHzwjvnfkWvYsE7PoWS1hfeKy44tA7sCnNac,8759
+dao_ai/tools/unity_catalog.py,sha256=5Wf47t810HOLMue7qhixuy041B8ykRYwzovbLTL13No,14579
+dao_ai/tools/vector_search.py,sha256=tiump9N6o558Y2tIsSox7-R2nQYmdJh0XTyxgRWDmlY,12881
+dao_ai-0.1.1.dist-info/METADATA,sha256=PK7dUXOuoWkCAYg-KnAgzoFVQMt6uzGHvoc494oMKfI,79539
+dao_ai-0.1.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+dao_ai-0.1.1.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
+dao_ai-0.1.1.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
+dao_ai-0.1.1.dist-info/RECORD,,
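For reference, each RECORD entry has the form path,sha256=digest,size, where the digest is the urlsafe-base64 encoding of the file's SHA-256 hash with the trailing "=" padding stripped (per PEP 376/427). A minimal sketch, not part of the package, for checking one entry against an unpacked wheel:

import base64
import hashlib
from pathlib import Path


def record_entry(path: Path) -> tuple[str, int]:
    """Return the RECORD-style digest and size for a file."""
    data = path.read_bytes()
    # RECORD uses urlsafe base64 of the raw digest, with '=' padding removed.
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
    return f"sha256={digest.decode('ascii')}", len(data)


# Compare against the first entry above, relative to the unpacked wheel root.
expected = ("sha256=18P98ExEgUaJ1Byw440Ct1ty59v6nxyWtc5S6Uq2m9Q", 1062)
print(record_entry(Path("dao_ai/__init__.py")) == expected)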
dao_ai/chat_models.py DELETED
@@ -1,204 +0,0 @@
-import json
-from typing import Any, Iterator, Optional, Sequence
-
-from databricks_langchain import ChatDatabricks
-from langchain_core.callbacks import CallbackManagerForLLMRun
-from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
-from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from loguru import logger
-
-
-class ChatDatabricksFiltered(ChatDatabricks):
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
-    def _preprocess_messages(
-        self, messages: Sequence[BaseMessage]
-    ) -> Sequence[BaseMessage]:
-        logger.debug(f"Preprocessing {len(messages)} messages for filtering")
-
-        logger.trace(
-            f"Original messages:\n{json.dumps([msg.model_dump() for msg in messages], indent=2)}"
-        )
-
-        # Diagnostic logging to understand what types of messages we're getting
-        message_types = {}
-        remove_message_count = 0
-        empty_content_count = 0
-
-        for msg in messages:
-            msg_type = msg.__class__.__name__
-            message_types[msg_type] = message_types.get(msg_type, 0) + 1
-
-            if msg_type == "RemoveMessage":
-                remove_message_count += 1
-            elif hasattr(msg, "content") and (msg.content == "" or msg.content is None):
-                empty_content_count += 1
-
-        logger.debug(f"Message type breakdown: {message_types}")
-        logger.debug(
-            f"RemoveMessage count: {remove_message_count}, Empty content count: {empty_content_count}"
-        )
-
-        filtered_messages = []
-        for i, msg in enumerate(messages):
-            # First, filter out RemoveMessage objects completely - they're LangGraph-specific
-            # and should never be sent to an LLM
-            if hasattr(msg, "__class__") and msg.__class__.__name__ == "RemoveMessage":
-                logger.debug(f"Filtering out RemoveMessage at index {i}")
-                continue
-
-            # Be very conservative with filtering - only filter out messages that are:
-            # 1. Have empty or None content AND
-            # 2. Are not tool-related messages AND
-            # 3. Don't break tool_use/tool_result pairing
-            # 4. Are not the only remaining message (to avoid filtering everything)
-            has_empty_content = hasattr(msg, "content") and (
-                msg.content == "" or msg.content is None
-            )
-
-            # Check if this message has tool calls (non-empty list)
-            has_tool_calls = (
-                hasattr(msg, "tool_calls")
-                and msg.tool_calls
-                and len(msg.tool_calls) > 0
-            )
-
-            # Check if this is a tool result message
-            is_tool_result = hasattr(msg, "tool_call_id") or isinstance(
-                msg, ToolMessage
-            )
-
-            # Check if the previous message had tool calls (this message might be a tool result)
-            prev_had_tool_calls = False
-            if i > 0:
-                prev_msg = messages[i - 1]
-                prev_had_tool_calls = (
-                    hasattr(prev_msg, "tool_calls")
-                    and prev_msg.tool_calls
-                    and len(prev_msg.tool_calls) > 0
-                )
-
-            # Check if the next message is a tool result (this message might be a tool use)
-            next_is_tool_result = False
-            if i < len(messages) - 1:
-                next_msg = messages[i + 1]
-                next_is_tool_result = hasattr(next_msg, "tool_call_id") or isinstance(
-                    next_msg, ToolMessage
-                )
-
-            # Special handling for empty AIMessages - they might be placeholders or incomplete responses
-            # Don't filter them if they're the only AI response or seem important to the conversation flow
-            is_empty_ai_message = has_empty_content and isinstance(msg, AIMessage)
-
-            # Only filter out messages with empty content that are definitely not needed
-            should_filter = (
-                has_empty_content
-                and not has_tool_calls
-                and not is_tool_result
-                and not prev_had_tool_calls  # Don't filter if previous message had tool calls
-                and not next_is_tool_result  # Don't filter if next message is a tool result
-                and not (
-                    is_empty_ai_message and len(messages) <= 2
-                )  # Don't filter empty AI messages in short conversations
-            )
-
-            if should_filter:
-                logger.debug(f"Filtering out message at index {i}: {msg.model_dump()}")
-                continue
-            else:
-                filtered_messages.append(msg)
-
-        logger.debug(
-            f"Filtered {len(messages)} messages down to {len(filtered_messages)} messages"
-        )
-
-        # Log diagnostic information if all messages were filtered out
-        if len(filtered_messages) == 0:
-            logger.warning(
-                f"All {len(messages)} messages were filtered out! This indicates a problem with the conversation state."
-            )
-            logger.debug(f"Original message types: {message_types}")
-
-            if remove_message_count == len(messages):
-                logger.warning(
-                    "All messages were RemoveMessage objects - this suggests a bug in summarization logic"
-                )
-            elif empty_content_count > 0:
-                logger.debug(f"{empty_content_count} messages had empty content")
-
-        return filtered_messages
-
-    def _postprocess_message(self, message: BaseMessage) -> BaseMessage:
-        return message
-
-    def _generate(
-        self,
-        messages: Sequence[BaseMessage],
-        stop: Optional[Sequence[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> ChatResult:
-        """Override _generate to apply message preprocessing and postprocessing."""
-        # Apply message preprocessing
-        processed_messages: Sequence[BaseMessage] = self._preprocess_messages(messages)
-
-        if len(processed_messages) == 0:
-            logger.error(
-                "All messages were filtered out during preprocessing. This indicates a serious issue with the conversation state."
-            )
-            empty_generation = ChatGeneration(
-                message=AIMessage(content="", id="empty-response")
-            )
-            return ChatResult(generations=[empty_generation])
-
-        logger.trace(
-            f"Processed messages:\n{json.dumps([msg.model_dump() for msg in processed_messages], indent=2)}"
-        )
-
-        result: ChatResult = super()._generate(
-            processed_messages, stop, run_manager, **kwargs
-        )
-
-        if result.generations:
-            for generation in result.generations:
-                if isinstance(generation, ChatGeneration) and generation.message:
-                    generation.message = self._postprocess_message(generation.message)
-
-        return result
-
-    def _stream(
-        self,
-        messages: Sequence[BaseMessage],
-        stop: Optional[Sequence[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> Iterator[ChatGeneration]:
-        """Override _stream to apply message preprocessing and postprocessing."""
-        # Apply message preprocessing
-        processed_messages: Sequence[BaseMessage] = self._preprocess_messages(messages)
-
-        # Handle the edge case where all messages were filtered out
-        if len(processed_messages) == 0:
-            logger.error(
-                "All messages were filtered out during preprocessing. This indicates a serious issue with the conversation state."
-            )
-            # Return an empty streaming result without calling the underlying API
-            # This prevents API errors while making the issue visible through an empty response
-            empty_chunk = ChatGenerationChunk(
-                message=AIMessage(content="", id="empty-response")
-            )
-            yield empty_chunk
-            return
-
-        logger.trace(
-            f"Processed messages:\n{json.dumps([msg.model_dump() for msg in processed_messages], indent=2)}"
-        )
-
-        # Call the parent ChatDatabricks implementation
-        for chunk in super()._stream(processed_messages, stop, run_manager, **kwargs):
-            chunk: ChatGenerationChunk
-            # Apply message postprocessing to each chunk
-            if isinstance(chunk, ChatGeneration) and chunk.message:
-                chunk.message = self._postprocess_message(chunk.message)
-            yield chunk
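The class removed above filtered LangGraph RemoveMessage placeholders and empty messages out of the conversation before delegating to ChatDatabricks; in 0.1.1 this cleanup concern appears to move into the new middleware package (dao_ai/middleware/message_validation.py). A minimal usage sketch of the removed wrapper as it would have looked in 0.0.36, with a placeholder endpoint name:

from dao_ai.chat_models import ChatDatabricksFiltered  # module deleted in 0.1.1
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.graph.message import RemoveMessage

# Placeholder endpoint; any Databricks model serving endpoint name would do.
llm = ChatDatabricksFiltered(endpoint="databricks-meta-llama-3-3-70b-instruct")

messages = [
    HumanMessage(content="What were yesterday's sales?"),
    RemoveMessage(id="stale-summary"),  # stripped by _preprocess_messages
    AIMessage(content=""),              # empty content, another filtering candidate
]
response = llm.invoke(messages)  # only the surviving messages reach the endpoint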
dao_ai/guardrails.py DELETED
@@ -1,112 +0,0 @@
-from typing import Any, Literal, Optional, Type
-
-from langchain_core.language_models import LanguageModelLike
-from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
-from langchain_core.runnables import RunnableConfig
-from langchain_core.runnables.base import RunnableLike
-from langgraph.graph import END, START, MessagesState, StateGraph
-from langgraph.graph.state import CompiledStateGraph
-from langgraph.managed import RemainingSteps
-from loguru import logger
-from openevals.llm import create_llm_as_judge
-
-from dao_ai.config import GuardrailModel
-from dao_ai.messages import last_ai_message, last_human_message
-from dao_ai.state import SharedState
-
-
-class MessagesWithSteps(MessagesState):
-    guardrails_remaining_steps: RemainingSteps
-
-
-def end_or_reflect(state: MessagesWithSteps) -> Literal[END, "graph"]:
-    if state["guardrails_remaining_steps"] < 2:
-        return END
-    if len(state["messages"]) == 0:
-        return END
-    last_message = state["messages"][-1]
-    if isinstance(last_message, HumanMessage):
-        return "graph"
-    else:
-        return END
-
-
-def create_reflection_graph(
-    graph: CompiledStateGraph,
-    reflection: CompiledStateGraph,
-    state_schema: Optional[Type[Any]] = None,
-    config_schema: Optional[Type[Any]] = None,
-) -> StateGraph:
-    logger.debug("Creating reflection graph")
-    _state_schema = state_schema or graph.builder.schema
-
-    if "guardrails_remaining_steps" in _state_schema.__annotations__:
-        raise ValueError(
-            "Has key 'guardrails_remaining_steps' in state_schema, this shadows a built in key"
-        )
-
-    if "messages" not in _state_schema.__annotations__:
-        raise ValueError("Missing required key 'messages' in state_schema")
-
-    class StateSchema(_state_schema):
-        guardrails_remaining_steps: RemainingSteps
-
-    rgraph = StateGraph(StateSchema, config_schema=config_schema)
-    rgraph.add_node("graph", graph)
-    rgraph.add_node("reflection", reflection)
-    rgraph.add_edge(START, "graph")
-    rgraph.add_edge("graph", "reflection")
-    rgraph.add_conditional_edges("reflection", end_or_reflect)
-    return rgraph
-
-
-def with_guardrails(
-    graph: CompiledStateGraph, guardrail: CompiledStateGraph
-) -> CompiledStateGraph:
-    logger.debug("Creating graph with guardrails")
-    return create_reflection_graph(
-        graph, guardrail, state_schema=SharedState, config_schema=RunnableConfig
-    ).compile()
-
-
-def judge_node(guardrails: GuardrailModel) -> RunnableLike:
-    def judge(state: SharedState, config: RunnableConfig) -> dict[str, BaseMessage]:
-        llm: LanguageModelLike = guardrails.model.as_chat_model()
-
-        evaluator = create_llm_as_judge(
-            prompt=guardrails.prompt,
-            judge=llm,
-        )
-
-        ai_message: AIMessage = last_ai_message(state["messages"])
-        human_message: HumanMessage = last_human_message(state["messages"])
-
-        logger.debug(f"Evaluating response: {ai_message.content}")
-        eval_result = evaluator(
-            inputs=human_message.content, outputs=ai_message.content
-        )
-
-        if eval_result["score"]:
-            logger.debug("Response approved by judge")
-            logger.debug(f"Judge's comment: {eval_result['comment']}")
-            return
-        else:
-            # Otherwise, return the judge's critique as a new user message
-            logger.warning("Judge requested improvements")
-            comment: str = eval_result["comment"]
-            logger.warning(f"Judge's critique: {comment}")
-            content: str = "\n".join([human_message.content, comment])
-            return {"messages": [HumanMessage(content=content)]}
-
-    return judge
-
-
-def reflection_guardrail(guardrails: GuardrailModel) -> CompiledStateGraph:
-    judge: CompiledStateGraph = (
-        StateGraph(SharedState, config_schema=RunnableConfig)
-        .add_node("judge", judge_node(guardrails=guardrails))
-        .add_edge(START, "judge")
-        .add_edge("judge", END)
-        .compile()
-    )
-    return judge
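The module removed above composed a compiled agent graph with an LLM-as-judge reflection graph: the judge either approves the answer (ending the run) or feeds its critique back in as a new HumanMessage until guardrails_remaining_steps runs out. A minimal sketch of how those helpers were wired together, assuming agent_graph is an already-compiled graph and guardrail_config is a GuardrailModel loaded from the project configuration (both placeholders; 0.1.1 replaces this module with dao_ai/middleware/guardrails.py):

from langchain_core.messages import HumanMessage
from langgraph.graph.state import CompiledStateGraph

from dao_ai.config import GuardrailModel
from dao_ai.guardrails import reflection_guardrail, with_guardrails  # module deleted in 0.1.1


def guard(agent_graph: CompiledStateGraph, guardrail_config: GuardrailModel) -> CompiledStateGraph:
    # Build the judge-only subgraph, then loop it with the agent: agent -> judge -> (END | agent).
    judge_graph = reflection_guardrail(guardrail_config)
    return with_guardrails(agent_graph, judge_graph)


# guarded = guard(agent_graph, guardrail_config)
# guarded.invoke({"messages": [HumanMessage(content="Summarize yesterday's sales.")]})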
dao_ai/tools/genie/__init__.py DELETED
@@ -1,236 +0,0 @@
-"""
-Genie tools for natural language queries to databases.
-
-This package provides tools for interacting with Databricks Genie to translate
-natural language questions into SQL queries.
-
-Main exports:
-- create_genie_tool: Factory function to create a Genie tool with optional caching
-
-Cache implementations are available in the genie cache package:
-- dao_ai.genie.cache.lru: LRU (Least Recently Used) cache
-- dao_ai.genie.cache.semantic: Semantic similarity cache using pg_vector
-"""
-
-import json
-import os
-from textwrap import dedent
-from typing import Annotated, Any, Callable
-
-import pandas as pd
-from databricks_ai_bridge.genie import Genie, GenieResponse
-from langchain.tools import tool
-from langchain_core.messages import ToolMessage
-from langchain_core.tools import InjectedToolCallId
-from langgraph.prebuilt import InjectedState
-from langgraph.types import Command
-from loguru import logger
-from pydantic import BaseModel
-
-from dao_ai.config import (
-    AnyVariable,
-    CompositeVariableModel,
-    GenieLRUCacheParametersModel,
-    GenieRoomModel,
-    GenieSemanticCacheParametersModel,
-    value_of,
-)
-from dao_ai.genie import GenieService
-from dao_ai.genie.cache import (
-    CacheResult,
-    GenieServiceBase,
-    LRUCacheService,
-    SemanticCacheService,
-    SQLCacheEntry,
-)
-
-
-class GenieToolInput(BaseModel):
-    """Input schema for Genie tool - only includes user-facing parameters."""
-
-    question: str
-
-
-def _response_to_json(response: GenieResponse) -> str:
-    """Convert GenieResponse to JSON string, handling DataFrame results."""
-    # Convert result to string if it's a DataFrame
-    result: str | pd.DataFrame = response.result
-    if isinstance(result, pd.DataFrame):
-        result = result.to_markdown()
-
-    data: dict[str, Any] = {
-        "result": result,
-        "query": response.query,
-        "description": response.description,
-        "conversation_id": response.conversation_id,
-    }
-    return json.dumps(data)
-
-
-def create_genie_tool(
-    genie_room: GenieRoomModel | dict[str, Any],
-    name: str | None = None,
-    description: str | None = None,
-    persist_conversation: bool = True,
-    truncate_results: bool = False,
-    lru_cache_parameters: GenieLRUCacheParametersModel | dict[str, Any] | None = None,
-    semantic_cache_parameters: GenieSemanticCacheParametersModel
-    | dict[str, Any]
-    | None = None,
-) -> Callable[..., Command]:
-    """
-    Create a tool for interacting with Databricks Genie for natural language queries to databases.
-
-    This factory function generates a tool that leverages Databricks Genie to translate natural
-    language questions into SQL queries and execute them against retail databases. This enables
-    answering questions about inventory, sales, and other structured retail data.
-
-    Args:
-        genie_room: GenieRoomModel or dict containing Genie configuration
-        name: Optional custom name for the tool. If None, uses default "genie_tool"
-        description: Optional custom description for the tool. If None, uses default description
-        persist_conversation: Whether to persist conversation IDs across tool calls for
-            multi-turn conversations within the same Genie space
-        truncate_results: Whether to truncate large query results to fit token limits
-        lru_cache_parameters: Optional LRU cache configuration for SQL query caching
-        semantic_cache_parameters: Optional semantic cache configuration using pg_vector
-            for similarity-based query matching
-
-    Returns:
-        A LangGraph tool that processes natural language queries through Genie
-    """
-    logger.debug("create_genie_tool")
-    logger.debug(f"genie_room type: {type(genie_room)}")
-    logger.debug(f"genie_room: {genie_room}")
-    logger.debug(f"persist_conversation: {persist_conversation}")
-    logger.debug(f"truncate_results: {truncate_results}")
-    logger.debug(f"name: {name}")
-    logger.debug(f"description: {description}")
-    logger.debug(f"genie_room: {genie_room}")
-    logger.debug(f"persist_conversation: {persist_conversation}")
-    logger.debug(f"truncate_results: {truncate_results}")
-    logger.debug(f"lru_cache_parameters: {lru_cache_parameters}")
-    logger.debug(f"semantic_cache_parameters: {semantic_cache_parameters}")
-
-    if isinstance(genie_room, dict):
-        genie_room = GenieRoomModel(**genie_room)
-
-    if isinstance(lru_cache_parameters, dict):
-        lru_cache_parameters = GenieLRUCacheParametersModel(**lru_cache_parameters)
-
-    if isinstance(semantic_cache_parameters, dict):
-        semantic_cache_parameters = GenieSemanticCacheParametersModel(
-            **semantic_cache_parameters
-        )
-
-    space_id: AnyVariable = genie_room.space_id or os.environ.get(
-        "DATABRICKS_GENIE_SPACE_ID"
-    )
-    if isinstance(space_id, dict):
-        space_id = CompositeVariableModel(**space_id)
-    space_id = value_of(space_id)
-
-    default_description: str = dedent("""
-        This tool lets you have a conversation and chat with tabular data about <topic>. You should ask
-        questions about the data and the tool will try to answer them.
-        Please ask simple clear questions that can be answer by sql queries. If you need to do statistics or other forms of testing defer to using another tool.
-        Try to ask for aggregations on the data and ask very simple questions.
-        Prefer to call this tool multiple times rather than asking a complex question.
-    """)
-
-    tool_description: str = (
-        description if description is not None else default_description
-    )
-    tool_name: str = name if name is not None else "genie_tool"
-
-    function_docs = """
-
-    Args:
-        question (str): The question to ask to ask Genie about your data. Ask simple, clear questions about your tabular data. For complex analysis, ask multiple simple questions rather than one complex question.
-
-    Returns:
-        GenieResponse: A response object containing the conversation ID and result from Genie."""
-    tool_description = tool_description + function_docs
-
-    genie: Genie = Genie(
-        space_id=space_id,
-        client=genie_room.workspace_client,
-        truncate_results=truncate_results,
-    )
-
-    genie_service: GenieServiceBase = GenieService(genie)
-
-    # Wrap with semantic cache first (checked second due to decorator pattern)
-    if semantic_cache_parameters is not None:
-        genie_service = SemanticCacheService(
-            impl=genie_service,
-            parameters=semantic_cache_parameters,
-            genie_space_id=space_id,
-        )
-
-    # Wrap with LRU cache last (checked first - fast O(1) exact match)
-    if lru_cache_parameters is not None:
-        genie_service = LRUCacheService(
-            impl=genie_service,
-            parameters=lru_cache_parameters,
-        )
-
-    @tool(
-        name_or_callable=tool_name,
-        description=tool_description,
-    )
-    def genie_tool(
-        question: Annotated[str, "The question to ask Genie about your data"],
-        state: Annotated[dict, InjectedState],
-        tool_call_id: Annotated[str, InjectedToolCallId],
-    ) -> Command:
-        """Process a natural language question through Databricks Genie."""
-        # Get existing conversation mapping and retrieve conversation ID for this space
-        conversation_ids: dict[str, str] = state.get("genie_conversation_ids", {})
-        existing_conversation_id: str | None = conversation_ids.get(space_id)
-        logger.debug(
-            f"Existing conversation ID for space {space_id}: {existing_conversation_id}"
-        )
-
-        response: GenieResponse = genie_service.ask_question(
-            question, conversation_id=existing_conversation_id
-        )
-
-        current_conversation_id: str = response.conversation_id
-        logger.debug(
-            f"Current conversation ID for space {space_id}: {current_conversation_id}"
-        )
-
-        # Update the conversation mapping with the new conversation ID for this space
-
-        update: dict[str, Any] = {
-            "messages": [
-                ToolMessage(_response_to_json(response), tool_call_id=tool_call_id)
-            ],
-        }
-
-        if persist_conversation:
-            updated_conversation_ids: dict[str, str] = conversation_ids.copy()
-            updated_conversation_ids[space_id] = current_conversation_id
-            update["genie_conversation_ids"] = updated_conversation_ids
-
-        return Command(update=update)
-
-    return genie_tool
-
-
-# Re-export cache types for convenience
-__all__ = [
-    # Main tool
-    "create_genie_tool",
-    # Input types
-    "GenieToolInput",
-    # Service base classes
-    "GenieService",
-    "GenieServiceBase",
-    # Cache types (from cache subpackage)
-    "CacheResult",
-    "LRUCacheService",
-    "SemanticCacheService",
-    "SQLCacheEntry",
-]
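In 0.1.1 the Genie tool moves to dao_ai/tools/genie.py and the cache layers to dao_ai/genie/cache/, but the factory removed above already shows the layering: the Genie service is wrapped by an optional semantic cache and then an optional LRU cache, so exact-match hits are answered first and only misses fall through to Genie. A minimal 0.0.36-style usage sketch, with a placeholder space ID and the optional cache parameters omitted because their exact schemas live in dao_ai.config:

from dao_ai.tools.genie import create_genie_tool  # package removed in 0.1.1

# Placeholder Genie space ID; persist_conversation keeps one conversation per
# space in the agent state so follow-up questions stay in context.
sales_genie = create_genie_tool(
    genie_room={"space_id": "01ef00000000000000000000000000"},
    name="sales_genie",
    description="Ask questions about the sales tables.",
    persist_conversation=True,
)

# The returned callable is a LangGraph tool; the framework injects state and
# tool_call_id, so an agent invokes it with just a natural-language question.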