dao-ai 0.0.36__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/cli.py +195 -30
  3. dao_ai/config.py +770 -244
  4. dao_ai/genie/__init__.py +1 -22
  5. dao_ai/genie/cache/__init__.py +1 -2
  6. dao_ai/genie/cache/base.py +20 -70
  7. dao_ai/genie/cache/core.py +75 -0
  8. dao_ai/genie/cache/lru.py +44 -21
  9. dao_ai/genie/cache/semantic.py +390 -109
  10. dao_ai/genie/core.py +35 -0
  11. dao_ai/graph.py +27 -253
  12. dao_ai/hooks/__init__.py +9 -6
  13. dao_ai/hooks/core.py +22 -190
  14. dao_ai/memory/__init__.py +10 -0
  15. dao_ai/memory/core.py +23 -5
  16. dao_ai/memory/databricks.py +389 -0
  17. dao_ai/memory/postgres.py +2 -2
  18. dao_ai/messages.py +6 -4
  19. dao_ai/middleware/__init__.py +125 -0
  20. dao_ai/middleware/assertions.py +778 -0
  21. dao_ai/middleware/base.py +50 -0
  22. dao_ai/middleware/core.py +61 -0
  23. dao_ai/middleware/guardrails.py +415 -0
  24. dao_ai/middleware/human_in_the_loop.py +228 -0
  25. dao_ai/middleware/message_validation.py +554 -0
  26. dao_ai/middleware/summarization.py +192 -0
  27. dao_ai/models.py +1177 -108
  28. dao_ai/nodes.py +118 -161
  29. dao_ai/optimization.py +664 -0
  30. dao_ai/orchestration/__init__.py +52 -0
  31. dao_ai/orchestration/core.py +287 -0
  32. dao_ai/orchestration/supervisor.py +264 -0
  33. dao_ai/orchestration/swarm.py +226 -0
  34. dao_ai/prompts.py +126 -29
  35. dao_ai/providers/databricks.py +126 -381
  36. dao_ai/state.py +139 -21
  37. dao_ai/tools/__init__.py +8 -5
  38. dao_ai/tools/core.py +57 -4
  39. dao_ai/tools/email.py +280 -0
  40. dao_ai/tools/genie.py +47 -24
  41. dao_ai/tools/mcp.py +4 -3
  42. dao_ai/tools/memory.py +50 -0
  43. dao_ai/tools/python.py +4 -12
  44. dao_ai/tools/search.py +14 -0
  45. dao_ai/tools/slack.py +1 -1
  46. dao_ai/tools/unity_catalog.py +8 -6
  47. dao_ai/tools/vector_search.py +16 -9
  48. dao_ai/utils.py +72 -8
  49. dao_ai-0.1.0.dist-info/METADATA +1878 -0
  50. dao_ai-0.1.0.dist-info/RECORD +62 -0
  51. dao_ai/chat_models.py +0 -204
  52. dao_ai/guardrails.py +0 -112
  53. dao_ai/tools/genie/__init__.py +0 -236
  54. dao_ai/tools/human_in_the_loop.py +0 -100
  55. dao_ai-0.0.36.dist-info/METADATA +0 -951
  56. dao_ai-0.0.36.dist-info/RECORD +0 -47
  57. {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/WHEEL +0 -0
  58. {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/entry_points.txt +0 -0
  59. {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/licenses/LICENSE +0 -0
dao_ai-0.1.0.dist-info/RECORD ADDED
@@ -0,0 +1,62 @@
+ dao_ai/__init__.py,sha256=18P98ExEgUaJ1Byw440Ct1ty59v6nxyWtc5S6Uq2m9Q,1062
+ dao_ai/agent_as_code.py,sha256=sviZQV7ZPxE5zkZ9jAbfegI681nra5i8yYxw05e3X7U,552
+ dao_ai/catalog.py,sha256=sPZpHTD3lPx4EZUtIWeQV7VQM89WJ6YH__wluk1v2lE,4947
+ dao_ai/cli.py,sha256=AoGw4erjF1C8_2XEppj9yU8agQKSRnJr-x-I8Y-oAx8,31121
+ dao_ai/config.py,sha256=VjFFoBNer7ylE6u3lKgqGVUN4VDlIyeUlTcXuMkNHVo,101027
+ dao_ai/graph.py,sha256=1-uQlo7iXZQTT3uU8aYu0N5rnhw5_g_2YLwVsAs6M-U,1119
+ dao_ai/messages.py,sha256=4ZBzO4iFdktGSLrmhHzFjzMIt2tpaL-aQLHOQJysGnY,6959
+ dao_ai/models.py,sha256=hOUEPyWZNOL0SmA0V9e0fyaO2ZtuAd-rr59M_3CoggQ,77734
+ dao_ai/nodes.py,sha256=2DeiR6WmLlfiWMIaT8Gj9FOC2Ewk5bammMKXmVb1VTc,7299
+ dao_ai/optimization.py,sha256=KwdKR9njYXe1aIlwGXJ8f5dE7G83aV1CxNqhvjBfGis,22780
+ dao_ai/prompts.py,sha256=5Mh-8OYMASMHid58iFAj-rr6K1B41gmNN91Ia8Ciq_I,4846
+ dao_ai/state.py,sha256=fNRZ8_M2VBqO8lXS3Egh53PH6LsD3bB8npOG44XWxGA,5641
+ dao_ai/types.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ dao_ai/utils.py,sha256=elkkhNe0L976oTdaEjxHsmW5Fc9VLLS4IjS7q5HgRdA,11401
+ dao_ai/vector_search.py,sha256=jlaFS_iizJ55wblgzZmswMM3UOL-qOp2BGJc0JqXYSg,2839
+ dao_ai/genie/__init__.py,sha256=vdEyGhrt6L8GlK75SyYvTnl8QpHKDCJC5hJKLg4DesQ,1063
+ dao_ai/genie/core.py,sha256=HPKbocvhnnw_PkQwfoq5bpgQmL9lZyyS6_goTJL8yiY,1073
+ dao_ai/genie/cache/__init__.py,sha256=JfgCJl1NYQ1aZvZ4kly4T6uQK6ZCJ6PX_htuq7nJF50,1203
+ dao_ai/genie/cache/base.py,sha256=_MhHqYrHejVGrJjSLX26TdHwvQZb-HgiantRYSB8fJY,1961
+ dao_ai/genie/cache/core.py,sha256=Oe7MSreefcaOri4eSpJJFQBBfNP0MoW7qQdbA0vq89U,2439
+ dao_ai/genie/cache/lru.py,sha256=4oXINJnTb-5FvKVd-U3J6gUI9o9jUEoBLOHsxlkvW34,11465
+ dao_ai/genie/cache/semantic.py,sha256=RZAzrniuW_VmeWZoHWfbPr2eP3m9gj4fi4vEDM_fvYA,36892
+ dao_ai/hooks/__init__.py,sha256=uA4DQdP9gDf4SyNjNx9mWPoI8UZOcTyFsCXV0NraFvQ,463
+ dao_ai/hooks/core.py,sha256=upTAI11RncWZ5uu2wBA6nJMaqRMVHJFuzci-iVL8MFw,1700
+ dao_ai/memory/__init__.py,sha256=Us3wFehvug_h83m-UJ7OXdq2qZ0e9nHBQE7m5RwoAd8,559
+ dao_ai/memory/base.py,sha256=99nfr2UZJ4jmfTL_KrqUlRSCoRxzkZyWyx5WqeUoMdQ,338
+ dao_ai/memory/core.py,sha256=vhN-cDZxhtykwXFnRK5HWBbMvlbYyafRoIgew-7ufqI,5233
+ dao_ai/memory/databricks.py,sha256=PJSNKD1C48-jKXNg3jPYEANKbP4bH-zWLpMZG5SHo0A,14136
+ dao_ai/memory/postgres.py,sha256=5T9qTSijxUl2zCEiW3REqgsdsPjgdxfQTGTf6gBY2M8,15362
+ dao_ai/middleware/__init__.py,sha256=epSCtCtttIogl21nVK768Ln35L0mOShVczyURtR6Ln8,3609
+ dao_ai/middleware/assertions.py,sha256=EgFVqxtD_UWBTQkXHmoduVYRuBYZtBbfcDCieMj3PXM,26651
+ dao_ai/middleware/base.py,sha256=uG2tpdnjL5xY5jCKvb_m3UTBtl4ZC6fJQUkDsQvV8S4,1279
+ dao_ai/middleware/core.py,sha256=Umq4wgxlsDHN_-LYMjlrQur2R6j9LBYozP7uCEFC3uI,1852
+ dao_ai/middleware/guardrails.py,sha256=U2U5bzWTJ5UdceXbEwSZS4brhpFZhFhRmx9RmbcikJE,13967
+ dao_ai/middleware/human_in_the_loop.py,sha256=hivU5CugD4D-9j_d5AHqu6atIc6s6GHel5VxHTbgcys,7435
+ dao_ai/middleware/message_validation.py,sha256=zBAAlI_th1SdxX06p4_4hCKKkM4v0Dkh8TEWGHeCDYc,18471
+ dao_ai/middleware/summarization.py,sha256=VmpFXtUmFA5UybGvrODzzerZ1aoVQEFvgTTU90clo4M,7152
+ dao_ai/orchestration/__init__.py,sha256=i85CLfRR335NcCFhaXABcMkn6WZfXnJ8cHH4YZsZN0s,1622
+ dao_ai/orchestration/core.py,sha256=4EdorNTs-NV3KkX_K3RtwRsKH3-K4sT_ORO-lsMYpYo,9418
+ dao_ai/orchestration/supervisor.py,sha256=tjZwwcwXzKHjwf1x7x0MT0QqywAWP6ODvQhrK9jJ3Jk,9433
+ dao_ai/orchestration/swarm.py,sha256=d7lq5KFFBX3qxSoTAIEKVZJtDJiEjSLviGUtas3-4pY,7603
+ dao_ai/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ dao_ai/providers/base.py,sha256=-fjKypCOk28h6vioPfMj9YZSw_3Kcbi2nMuAyY7vX9k,1383
+ dao_ai/providers/databricks.py,sha256=rzFNVLmZYO2Lm3U3kGGQsEdn4W8mm04NUdnSoiPD12o,54714
+ dao_ai/tools/__init__.py,sha256=WLb4mgC7WUbaSDOVpjlf9tJkm5Dr4vEmxG5hnyCgeSc,1568
+ dao_ai/tools/agent.py,sha256=WbQnyziiT12TLMrA7xK0VuOU029tdmUBXbUl-R1VZ0Q,1886
+ dao_ai/tools/core.py,sha256=f6unLEtYfLZ9eDhoqgaREFylrb9JZ6zlZBYV_45Albo,3785
+ dao_ai/tools/email.py,sha256=p7FHlA9gq0LR9-YLuJL1Glq3n465wFu2tqDPX-UW0aY,9662
+ dao_ai/tools/genie.py,sha256=zfHAosWMLl-us2vG1c7ziEe8oQehBrxZISWqmC7GZ5s,8942
+ dao_ai/tools/mcp.py,sha256=C1NAUKkgl3QUkMTkCZKO1uIMU0gViWOFSjiq-t_tNGk,8454
+ dao_ai/tools/memory.py,sha256=lwObKimAand22Nq3Y63tsv-AXQ5SXUigN9PqRjoWKes,1836
+ dao_ai/tools/python.py,sha256=D021e4kPlH-CkrQGpzM3n7DZJtTCv2LCW0VPVQGb5Sk,1734
+ dao_ai/tools/search.py,sha256=AdQf8UouvJjw5UdCmD1-jrX_PJ-cNZ1oY_jh01dHGxY,433
+ dao_ai/tools/slack.py,sha256=NWtzxpUJOxXbngThpb2KYfVvqSOi312crLCCA0xtv-w,4775
+ dao_ai/tools/time.py,sha256=Y-23qdnNHzwjvnfkWvYsE7PoWS1hfeKy44tA7sCnNac,8759
+ dao_ai/tools/unity_catalog.py,sha256=5Wf47t810HOLMue7qhixuy041B8ykRYwzovbLTL13No,14579
+ dao_ai/tools/vector_search.py,sha256=tiump9N6o558Y2tIsSox7-R2nQYmdJh0XTyxgRWDmlY,12881
+ dao_ai-0.1.0.dist-info/METADATA,sha256=SNrqRm_xRtHprL4WL5I6EcSUBH9Fzb85o6TmXNDV2oY,79539
+ dao_ai-0.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ dao_ai-0.1.0.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
+ dao_ai-0.1.0.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
+ dao_ai-0.1.0.dist-info/RECORD,,
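
Each RECORD entry above has the form path,sha256=<digest>,<size>: the digest is the urlsafe-base64-encoded SHA-256 of the file with trailing "=" padding stripped, and the size is in bytes. A minimal sketch for checking an entry against an extracted wheel (the local extraction path is an assumption):

import base64
import hashlib
from pathlib import Path

def verify_record_entry(root: Path, path: str, expected_digest: str, expected_size: int) -> bool:
    # Wheel RECORD hashes are urlsafe base64 of the raw sha256 digest, without '=' padding.
    data = (root / path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
    return digest == expected_digest and len(data) == expected_size

# e.g., against the first entry above (assuming the wheel is unpacked to ./wheel-contents):
# verify_record_entry(Path("wheel-contents"), "dao_ai/__init__.py",
#                     "18P98ExEgUaJ1Byw440Ct1ty59v6nxyWtc5S6Uq2m9Q", 1062)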
dao_ai/chat_models.py DELETED
@@ -1,204 +0,0 @@
- import json
- from typing import Any, Iterator, Optional, Sequence
-
- from databricks_langchain import ChatDatabricks
- from langchain_core.callbacks import CallbackManagerForLLMRun
- from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
- from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
- from loguru import logger
-
-
- class ChatDatabricksFiltered(ChatDatabricks):
-     def __init__(self, **kwargs):
-         super().__init__(**kwargs)
-
-     def _preprocess_messages(
-         self, messages: Sequence[BaseMessage]
-     ) -> Sequence[BaseMessage]:
-         logger.debug(f"Preprocessing {len(messages)} messages for filtering")
-
-         logger.trace(
-             f"Original messages:\n{json.dumps([msg.model_dump() for msg in messages], indent=2)}"
-         )
-
-         # Diagnostic logging to understand what types of messages we're getting
-         message_types = {}
-         remove_message_count = 0
-         empty_content_count = 0
-
-         for msg in messages:
-             msg_type = msg.__class__.__name__
-             message_types[msg_type] = message_types.get(msg_type, 0) + 1
-
-             if msg_type == "RemoveMessage":
-                 remove_message_count += 1
-             elif hasattr(msg, "content") and (msg.content == "" or msg.content is None):
-                 empty_content_count += 1
-
-         logger.debug(f"Message type breakdown: {message_types}")
-         logger.debug(
-             f"RemoveMessage count: {remove_message_count}, Empty content count: {empty_content_count}"
-         )
-
-         filtered_messages = []
-         for i, msg in enumerate(messages):
-             # First, filter out RemoveMessage objects completely - they're LangGraph-specific
-             # and should never be sent to an LLM
-             if hasattr(msg, "__class__") and msg.__class__.__name__ == "RemoveMessage":
-                 logger.debug(f"Filtering out RemoveMessage at index {i}")
-                 continue
-
-             # Be very conservative with filtering - only filter out messages that:
-             # 1. Have empty or None content AND
-             # 2. Are not tool-related messages AND
-             # 3. Don't break tool_use/tool_result pairing AND
-             # 4. Are not the only remaining message (to avoid filtering everything)
-             has_empty_content = hasattr(msg, "content") and (
-                 msg.content == "" or msg.content is None
-             )
-
-             # Check if this message has tool calls (non-empty list)
-             has_tool_calls = (
-                 hasattr(msg, "tool_calls")
-                 and msg.tool_calls
-                 and len(msg.tool_calls) > 0
-             )
-
-             # Check if this is a tool result message
-             is_tool_result = hasattr(msg, "tool_call_id") or isinstance(
-                 msg, ToolMessage
-             )
-
-             # Check if the previous message had tool calls (this message might be a tool result)
-             prev_had_tool_calls = False
-             if i > 0:
-                 prev_msg = messages[i - 1]
-                 prev_had_tool_calls = (
-                     hasattr(prev_msg, "tool_calls")
-                     and prev_msg.tool_calls
-                     and len(prev_msg.tool_calls) > 0
-                 )
-
-             # Check if the next message is a tool result (this message might be a tool use)
-             next_is_tool_result = False
-             if i < len(messages) - 1:
-                 next_msg = messages[i + 1]
-                 next_is_tool_result = hasattr(next_msg, "tool_call_id") or isinstance(
-                     next_msg, ToolMessage
-                 )
-
-             # Special handling for empty AIMessages - they might be placeholders or incomplete responses
-             # Don't filter them if they're the only AI response or seem important to the conversation flow
-             is_empty_ai_message = has_empty_content and isinstance(msg, AIMessage)
-
-             # Only filter out messages with empty content that are definitely not needed
-             should_filter = (
-                 has_empty_content
-                 and not has_tool_calls
-                 and not is_tool_result
-                 and not prev_had_tool_calls  # Don't filter if previous message had tool calls
-                 and not next_is_tool_result  # Don't filter if next message is a tool result
-                 and not (
-                     is_empty_ai_message and len(messages) <= 2
-                 )  # Don't filter empty AI messages in short conversations
-             )
-
-             if should_filter:
-                 logger.debug(f"Filtering out message at index {i}: {msg.model_dump()}")
-                 continue
-             else:
-                 filtered_messages.append(msg)
-
-         logger.debug(
-             f"Filtered {len(messages)} messages down to {len(filtered_messages)} messages"
-         )
-
-         # Log diagnostic information if all messages were filtered out
-         if len(filtered_messages) == 0:
-             logger.warning(
-                 f"All {len(messages)} messages were filtered out! This indicates a problem with the conversation state."
-             )
-             logger.debug(f"Original message types: {message_types}")
-
-             if remove_message_count == len(messages):
-                 logger.warning(
-                     "All messages were RemoveMessage objects - this suggests a bug in summarization logic"
-                 )
-             elif empty_content_count > 0:
-                 logger.debug(f"{empty_content_count} messages had empty content")
-
-         return filtered_messages
-
-     def _postprocess_message(self, message: BaseMessage) -> BaseMessage:
-         return message
-
-     def _generate(
-         self,
-         messages: Sequence[BaseMessage],
-         stop: Optional[Sequence[str]] = None,
-         run_manager: Optional[CallbackManagerForLLMRun] = None,
-         **kwargs: Any,
-     ) -> ChatResult:
-         """Override _generate to apply message preprocessing and postprocessing."""
-         # Apply message preprocessing
-         processed_messages: Sequence[BaseMessage] = self._preprocess_messages(messages)
-
-         if len(processed_messages) == 0:
-             logger.error(
-                 "All messages were filtered out during preprocessing. This indicates a serious issue with the conversation state."
-             )
-             empty_generation = ChatGeneration(
-                 message=AIMessage(content="", id="empty-response")
-             )
-             return ChatResult(generations=[empty_generation])
-
-         logger.trace(
-             f"Processed messages:\n{json.dumps([msg.model_dump() for msg in processed_messages], indent=2)}"
-         )
-
-         result: ChatResult = super()._generate(
-             processed_messages, stop, run_manager, **kwargs
-         )
-
-         if result.generations:
-             for generation in result.generations:
-                 if isinstance(generation, ChatGeneration) and generation.message:
-                     generation.message = self._postprocess_message(generation.message)
-
-         return result
-
-     def _stream(
-         self,
-         messages: Sequence[BaseMessage],
-         stop: Optional[Sequence[str]] = None,
-         run_manager: Optional[CallbackManagerForLLMRun] = None,
-         **kwargs: Any,
-     ) -> Iterator[ChatGeneration]:
-         """Override _stream to apply message preprocessing and postprocessing."""
-         # Apply message preprocessing
-         processed_messages: Sequence[BaseMessage] = self._preprocess_messages(messages)
-
-         # Handle the edge case where all messages were filtered out
-         if len(processed_messages) == 0:
-             logger.error(
-                 "All messages were filtered out during preprocessing. This indicates a serious issue with the conversation state."
-             )
-             # Return an empty streaming result without calling the underlying API
-             # This prevents API errors while making the issue visible through an empty response
-             empty_chunk = ChatGenerationChunk(
-                 message=AIMessage(content="", id="empty-response")
-             )
-             yield empty_chunk
-             return
-
-         logger.trace(
-             f"Processed messages:\n{json.dumps([msg.model_dump() for msg in processed_messages], indent=2)}"
-         )
-
-         # Call the parent ChatDatabricks implementation
-         for chunk in super()._stream(processed_messages, stop, run_manager, **kwargs):
-             chunk: ChatGenerationChunk
-             # Apply message postprocessing to each chunk
-             if isinstance(chunk, ChatGeneration) and chunk.message:
-                 chunk.message = self._postprocess_message(chunk.message)
-             yield chunk
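
For context, ChatDatabricksFiltered was a drop-in wrapper around ChatDatabricks that sanitized LangGraph message state before each model call, a concern that presumably moves into the new dao_ai/middleware/message_validation.py listed above. A minimal usage sketch of the removed wrapper; the endpoint name is a placeholder, not a value from this package:

from langchain_core.messages import HumanMessage

# "my-serving-endpoint" is a placeholder; the wrapper accepted the same kwargs
# as ChatDatabricks and only changed which messages reached the endpoint.
llm = ChatDatabricksFiltered(endpoint="my-serving-endpoint")

# RemoveMessage markers and empty non-tool messages are dropped by
# _preprocess_messages before the request is sent.
response = llm.invoke([HumanMessage(content="Hello")])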
dao_ai/guardrails.py DELETED
@@ -1,112 +0,0 @@
- from typing import Any, Literal, Optional, Type
-
- from langchain_core.language_models import LanguageModelLike
- from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
- from langchain_core.runnables import RunnableConfig
- from langchain_core.runnables.base import RunnableLike
- from langgraph.graph import END, START, MessagesState, StateGraph
- from langgraph.graph.state import CompiledStateGraph
- from langgraph.managed import RemainingSteps
- from loguru import logger
- from openevals.llm import create_llm_as_judge
-
- from dao_ai.config import GuardrailModel
- from dao_ai.messages import last_ai_message, last_human_message
- from dao_ai.state import SharedState
-
-
- class MessagesWithSteps(MessagesState):
-     guardrails_remaining_steps: RemainingSteps
-
-
- def end_or_reflect(state: MessagesWithSteps) -> Literal[END, "graph"]:
-     if state["guardrails_remaining_steps"] < 2:
-         return END
-     if len(state["messages"]) == 0:
-         return END
-     last_message = state["messages"][-1]
-     if isinstance(last_message, HumanMessage):
-         return "graph"
-     else:
-         return END
-
-
- def create_reflection_graph(
-     graph: CompiledStateGraph,
-     reflection: CompiledStateGraph,
-     state_schema: Optional[Type[Any]] = None,
-     config_schema: Optional[Type[Any]] = None,
- ) -> StateGraph:
-     logger.debug("Creating reflection graph")
-     _state_schema = state_schema or graph.builder.schema
-
-     if "guardrails_remaining_steps" in _state_schema.__annotations__:
-         raise ValueError(
-             "Has key 'guardrails_remaining_steps' in state_schema, this shadows a built in key"
-         )
-
-     if "messages" not in _state_schema.__annotations__:
-         raise ValueError("Missing required key 'messages' in state_schema")
-
-     class StateSchema(_state_schema):
-         guardrails_remaining_steps: RemainingSteps
-
-     rgraph = StateGraph(StateSchema, config_schema=config_schema)
-     rgraph.add_node("graph", graph)
-     rgraph.add_node("reflection", reflection)
-     rgraph.add_edge(START, "graph")
-     rgraph.add_edge("graph", "reflection")
-     rgraph.add_conditional_edges("reflection", end_or_reflect)
-     return rgraph
-
-
- def with_guardrails(
-     graph: CompiledStateGraph, guardrail: CompiledStateGraph
- ) -> CompiledStateGraph:
-     logger.debug("Creating graph with guardrails")
-     return create_reflection_graph(
-         graph, guardrail, state_schema=SharedState, config_schema=RunnableConfig
-     ).compile()
-
-
- def judge_node(guardrails: GuardrailModel) -> RunnableLike:
-     def judge(state: SharedState, config: RunnableConfig) -> dict[str, BaseMessage]:
-         llm: LanguageModelLike = guardrails.model.as_chat_model()
-
-         evaluator = create_llm_as_judge(
-             prompt=guardrails.prompt,
-             judge=llm,
-         )
-
-         ai_message: AIMessage = last_ai_message(state["messages"])
-         human_message: HumanMessage = last_human_message(state["messages"])
-
-         logger.debug(f"Evaluating response: {ai_message.content}")
-         eval_result = evaluator(
-             inputs=human_message.content, outputs=ai_message.content
-         )
-
-         if eval_result["score"]:
-             logger.debug("Response approved by judge")
-             logger.debug(f"Judge's comment: {eval_result['comment']}")
-             return
-         else:
-             # Otherwise, return the judge's critique as a new user message
-             logger.warning("Judge requested improvements")
-             comment: str = eval_result["comment"]
-             logger.warning(f"Judge's critique: {comment}")
-             content: str = "\n".join([human_message.content, comment])
-             return {"messages": [HumanMessage(content=content)]}
-
-     return judge
-
-
- def reflection_guardrail(guardrails: GuardrailModel) -> CompiledStateGraph:
-     judge: CompiledStateGraph = (
-         StateGraph(SharedState, config_schema=RunnableConfig)
-         .add_node("judge", judge_node(guardrails=guardrails))
-         .add_edge(START, "judge")
-         .add_edge("judge", END)
-         .compile()
-     )
-     return judge
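
For context, this module implemented an LLM-as-judge reflection loop around a compiled agent graph, apparently superseded in 0.1.0 by dao_ai/middleware/guardrails.py. A minimal wiring sketch under the removed API; the GuardrailModel field values and agent_graph are illustrative:

from langchain_core.messages import HumanMessage

from dao_ai.config import GuardrailModel

# Illustrative configuration: per judge_node above, GuardrailModel must provide
# a model exposing .as_chat_model() plus a judge prompt.
guardrails = GuardrailModel(
    model=judge_model,  # placeholder: any configured dao_ai model
    prompt="Approve only responses that are accurate and on-topic.",
)

judge = reflection_guardrail(guardrails)
guarded = with_guardrails(agent_graph, judge)  # agent_graph: any CompiledStateGraph

# Each cycle runs agent -> judge; an approving judge ends the run, while a
# critique is appended as a new HumanMessage and routed back to the agent until
# guardrails_remaining_steps is exhausted.
result = guarded.invoke({"messages": [HumanMessage(content="...")]})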
dao_ai/tools/genie/__init__.py DELETED
@@ -1,236 +0,0 @@
- """
- Genie tools for natural language queries to databases.
-
- This package provides tools for interacting with Databricks Genie to translate
- natural language questions into SQL queries.
-
- Main exports:
- - create_genie_tool: Factory function to create a Genie tool with optional caching
-
- Cache implementations are available in the genie cache package:
- - dao_ai.genie.cache.lru: LRU (Least Recently Used) cache
- - dao_ai.genie.cache.semantic: Semantic similarity cache using pg_vector
- """
-
- import json
- import os
- from textwrap import dedent
- from typing import Annotated, Any, Callable
-
- import pandas as pd
- from databricks_ai_bridge.genie import Genie, GenieResponse
- from langchain.tools import tool
- from langchain_core.messages import ToolMessage
- from langchain_core.tools import InjectedToolCallId
- from langgraph.prebuilt import InjectedState
- from langgraph.types import Command
- from loguru import logger
- from pydantic import BaseModel
-
- from dao_ai.config import (
-     AnyVariable,
-     CompositeVariableModel,
-     GenieLRUCacheParametersModel,
-     GenieRoomModel,
-     GenieSemanticCacheParametersModel,
-     value_of,
- )
- from dao_ai.genie import GenieService
- from dao_ai.genie.cache import (
-     CacheResult,
-     GenieServiceBase,
-     LRUCacheService,
-     SemanticCacheService,
-     SQLCacheEntry,
- )
-
-
- class GenieToolInput(BaseModel):
-     """Input schema for Genie tool - only includes user-facing parameters."""
-
-     question: str
-
-
- def _response_to_json(response: GenieResponse) -> str:
-     """Convert GenieResponse to JSON string, handling DataFrame results."""
-     # Convert result to string if it's a DataFrame
-     result: str | pd.DataFrame = response.result
-     if isinstance(result, pd.DataFrame):
-         result = result.to_markdown()
-
-     data: dict[str, Any] = {
-         "result": result,
-         "query": response.query,
-         "description": response.description,
-         "conversation_id": response.conversation_id,
-     }
-     return json.dumps(data)
-
-
- def create_genie_tool(
-     genie_room: GenieRoomModel | dict[str, Any],
-     name: str | None = None,
-     description: str | None = None,
-     persist_conversation: bool = True,
-     truncate_results: bool = False,
-     lru_cache_parameters: GenieLRUCacheParametersModel | dict[str, Any] | None = None,
-     semantic_cache_parameters: GenieSemanticCacheParametersModel
-     | dict[str, Any]
-     | None = None,
- ) -> Callable[..., Command]:
-     """
-     Create a tool for interacting with Databricks Genie for natural language queries to databases.
-
-     This factory function generates a tool that leverages Databricks Genie to translate natural
-     language questions into SQL queries and execute them against retail databases. This enables
-     answering questions about inventory, sales, and other structured retail data.
-
-     Args:
-         genie_room: GenieRoomModel or dict containing Genie configuration
-         name: Optional custom name for the tool. If None, uses default "genie_tool"
-         description: Optional custom description for the tool. If None, uses default description
-         persist_conversation: Whether to persist conversation IDs across tool calls for
-             multi-turn conversations within the same Genie space
-         truncate_results: Whether to truncate large query results to fit token limits
-         lru_cache_parameters: Optional LRU cache configuration for SQL query caching
-         semantic_cache_parameters: Optional semantic cache configuration using pg_vector
-             for similarity-based query matching
-
-     Returns:
-         A LangGraph tool that processes natural language queries through Genie
-     """
-     logger.debug("create_genie_tool")
-     logger.debug(f"genie_room type: {type(genie_room)}")
-     logger.debug(f"genie_room: {genie_room}")
-     logger.debug(f"persist_conversation: {persist_conversation}")
-     logger.debug(f"truncate_results: {truncate_results}")
-     logger.debug(f"name: {name}")
-     logger.debug(f"description: {description}")
-     logger.debug(f"genie_room: {genie_room}")
-     logger.debug(f"persist_conversation: {persist_conversation}")
-     logger.debug(f"truncate_results: {truncate_results}")
-     logger.debug(f"lru_cache_parameters: {lru_cache_parameters}")
-     logger.debug(f"semantic_cache_parameters: {semantic_cache_parameters}")
-
-     if isinstance(genie_room, dict):
-         genie_room = GenieRoomModel(**genie_room)
-
-     if isinstance(lru_cache_parameters, dict):
-         lru_cache_parameters = GenieLRUCacheParametersModel(**lru_cache_parameters)
-
-     if isinstance(semantic_cache_parameters, dict):
-         semantic_cache_parameters = GenieSemanticCacheParametersModel(
-             **semantic_cache_parameters
-         )
-
-     space_id: AnyVariable = genie_room.space_id or os.environ.get(
-         "DATABRICKS_GENIE_SPACE_ID"
-     )
-     if isinstance(space_id, dict):
-         space_id = CompositeVariableModel(**space_id)
-     space_id = value_of(space_id)
-
-     default_description: str = dedent("""
-         This tool lets you have a conversation and chat with tabular data about <topic>. You should ask
-         questions about the data and the tool will try to answer them.
-         Please ask simple, clear questions that can be answered by SQL queries. If you need to do statistics or other forms of testing, defer to another tool.
-         Try to ask for aggregations on the data and ask very simple questions.
-         Prefer to call this tool multiple times rather than asking a complex question.
-     """)
-
-     tool_description: str = (
-         description if description is not None else default_description
-     )
-     tool_name: str = name if name is not None else "genie_tool"
-
-     function_docs = """
-
-     Args:
-         question (str): The question to ask Genie about your data. Ask simple, clear questions about your tabular data. For complex analysis, ask multiple simple questions rather than one complex question.
-
-     Returns:
-         GenieResponse: A response object containing the conversation ID and result from Genie."""
-     tool_description = tool_description + function_docs
-
-     genie: Genie = Genie(
-         space_id=space_id,
-         client=genie_room.workspace_client,
-         truncate_results=truncate_results,
-     )
-
-     genie_service: GenieServiceBase = GenieService(genie)
-
-     # Wrap with semantic cache first (checked second due to decorator pattern)
-     if semantic_cache_parameters is not None:
-         genie_service = SemanticCacheService(
-             impl=genie_service,
-             parameters=semantic_cache_parameters,
-             genie_space_id=space_id,
-         )
-
-     # Wrap with LRU cache last (checked first - fast O(1) exact match)
-     if lru_cache_parameters is not None:
-         genie_service = LRUCacheService(
-             impl=genie_service,
-             parameters=lru_cache_parameters,
-         )
-
-     @tool(
-         name_or_callable=tool_name,
-         description=tool_description,
-     )
-     def genie_tool(
-         question: Annotated[str, "The question to ask Genie about your data"],
-         state: Annotated[dict, InjectedState],
-         tool_call_id: Annotated[str, InjectedToolCallId],
-     ) -> Command:
-         """Process a natural language question through Databricks Genie."""
-         # Get existing conversation mapping and retrieve conversation ID for this space
-         conversation_ids: dict[str, str] = state.get("genie_conversation_ids", {})
-         existing_conversation_id: str | None = conversation_ids.get(space_id)
-         logger.debug(
-             f"Existing conversation ID for space {space_id}: {existing_conversation_id}"
-         )
-
-         response: GenieResponse = genie_service.ask_question(
-             question, conversation_id=existing_conversation_id
-         )
-
-         current_conversation_id: str = response.conversation_id
-         logger.debug(
-             f"Current conversation ID for space {space_id}: {current_conversation_id}"
-         )
-
-         # Update the conversation mapping with the new conversation ID for this space
-
-         update: dict[str, Any] = {
-             "messages": [
-                 ToolMessage(_response_to_json(response), tool_call_id=tool_call_id)
-             ],
-         }
-
-         if persist_conversation:
-             updated_conversation_ids: dict[str, str] = conversation_ids.copy()
-             updated_conversation_ids[space_id] = current_conversation_id
-             update["genie_conversation_ids"] = updated_conversation_ids
-
-         return Command(update=update)
-
-     return genie_tool
-
-
- # Re-export cache types for convenience
- __all__ = [
-     # Main tool
-     "create_genie_tool",
-     # Input types
-     "GenieToolInput",
-     # Service base classes
-     "GenieService",
-     "GenieServiceBase",
-     # Cache types (from cache subpackage)
-     "CacheResult",
-     "LRUCacheService",
-     "SemanticCacheService",
-     "SQLCacheEntry",
- ]
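
A minimal sketch of how the removed factory was typically wired; the space ID and cache parameter fields below are placeholders, not values from this package:

# Placeholder space ID and cache settings; see GenieRoomModel and
# GenieLRUCacheParametersModel in dao_ai.config for the actual fields.
genie_tool = create_genie_tool(
    genie_room={"space_id": "your-genie-space-id"},
    name="inventory_genie",
    description="Ask natural-language questions about inventory data.",
    persist_conversation=True,
    lru_cache_parameters={"capacity": 128},  # illustrative field name
)

The returned callable is a LangGraph tool: it reads genie_conversation_ids from the injected state, routes the question through the LRU and/or semantic cache layers before reaching Genie itself, and emits a Command that appends a ToolMessage with the answer and, when persist_conversation is set, stores the new conversation ID for the space.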