dao-ai 0.0.28__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +2 -5
  3. dao_ai/cli.py +245 -40
  4. dao_ai/config.py +1491 -370
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -253
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +125 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/core.py +67 -0
  25. dao_ai/middleware/guardrails.py +420 -0
  26. dao_ai/middleware/human_in_the_loop.py +232 -0
  27. dao_ai/middleware/message_validation.py +586 -0
  28. dao_ai/middleware/summarization.py +197 -0
  29. dao_ai/models.py +1306 -114
  30. dao_ai/nodes.py +245 -159
  31. dao_ai/optimization.py +674 -0
  32. dao_ai/orchestration/__init__.py +52 -0
  33. dao_ai/orchestration/core.py +294 -0
  34. dao_ai/orchestration/supervisor.py +278 -0
  35. dao_ai/orchestration/swarm.py +271 -0
  36. dao_ai/prompts.py +128 -31
  37. dao_ai/providers/databricks.py +573 -601
  38. dao_ai/state.py +157 -21
  39. dao_ai/tools/__init__.py +13 -5
  40. dao_ai/tools/agent.py +1 -3
  41. dao_ai/tools/core.py +64 -11
  42. dao_ai/tools/email.py +232 -0
  43. dao_ai/tools/genie.py +144 -294
  44. dao_ai/tools/mcp.py +223 -155
  45. dao_ai/tools/memory.py +50 -0
  46. dao_ai/tools/python.py +9 -14
  47. dao_ai/tools/search.py +14 -0
  48. dao_ai/tools/slack.py +22 -10
  49. dao_ai/tools/sql.py +202 -0
  50. dao_ai/tools/time.py +30 -7
  51. dao_ai/tools/unity_catalog.py +165 -88
  52. dao_ai/tools/vector_search.py +331 -221
  53. dao_ai/utils.py +166 -20
  54. dao_ai-0.1.2.dist-info/METADATA +455 -0
  55. dao_ai-0.1.2.dist-info/RECORD +64 -0
  56. dao_ai/chat_models.py +0 -204
  57. dao_ai/guardrails.py +0 -112
  58. dao_ai/tools/human_in_the_loop.py +0 -100
  59. dao_ai-0.0.28.dist-info/METADATA +0 -1168
  60. dao_ai-0.0.28.dist-info/RECORD +0 -41
  61. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +0 -0
  62. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
  63. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
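For readers who want to reproduce the file-level summary above locally, here is a minimal sketch using only the Python standard library. It assumes both wheels have already been downloaded (for example with `pip download dao-ai==0.0.28 --no-deps` and `pip download dao-ai==0.1.2 --no-deps`); the filenames below are the standard PyPI artifact names, not part of the package itself.

    # Generic sketch, not part of dao_ai: list files added/removed between the two wheels.
    # Assumes both wheel files are present in the working directory.
    import zipfile

    old = zipfile.ZipFile("dao_ai-0.0.28-py3-none-any.whl")
    new = zipfile.ZipFile("dao_ai-0.1.2-py3-none-any.whl")

    old_files = {info.filename for info in old.infolist()}
    new_files = {info.filename for info in new.infolist()}

    print("added:  ", sorted(new_files - old_files))
    print("removed:", sorted(old_files - new_files))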
dao_ai/chat_models.py DELETED
@@ -1,204 +0,0 @@
- import json
- from typing import Any, Iterator, Optional, Sequence
-
- from databricks_langchain import ChatDatabricks
- from langchain_core.callbacks import CallbackManagerForLLMRun
- from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
- from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
- from loguru import logger
-
-
- class ChatDatabricksFiltered(ChatDatabricks):
-     def __init__(self, **kwargs):
-         super().__init__(**kwargs)
-
-     def _preprocess_messages(
-         self, messages: Sequence[BaseMessage]
-     ) -> Sequence[BaseMessage]:
-         logger.debug(f"Preprocessing {len(messages)} messages for filtering")
-
-         logger.trace(
-             f"Original messages:\n{json.dumps([msg.model_dump() for msg in messages], indent=2)}"
-         )
-
-         # Diagnostic logging to understand what types of messages we're getting
-         message_types = {}
-         remove_message_count = 0
-         empty_content_count = 0
-
-         for msg in messages:
-             msg_type = msg.__class__.__name__
-             message_types[msg_type] = message_types.get(msg_type, 0) + 1
-
-             if msg_type == "RemoveMessage":
-                 remove_message_count += 1
-             elif hasattr(msg, "content") and (msg.content == "" or msg.content is None):
-                 empty_content_count += 1
-
-         logger.debug(f"Message type breakdown: {message_types}")
-         logger.debug(
-             f"RemoveMessage count: {remove_message_count}, Empty content count: {empty_content_count}"
-         )
-
-         filtered_messages = []
-         for i, msg in enumerate(messages):
-             # First, filter out RemoveMessage objects completely - they're LangGraph-specific
-             # and should never be sent to an LLM
-             if hasattr(msg, "__class__") and msg.__class__.__name__ == "RemoveMessage":
-                 logger.debug(f"Filtering out RemoveMessage at index {i}")
-                 continue
-
-             # Be very conservative with filtering - only filter out messages that are:
-             # 1. Have empty or None content AND
-             # 2. Are not tool-related messages AND
-             # 3. Don't break tool_use/tool_result pairing
-             # 4. Are not the only remaining message (to avoid filtering everything)
-             has_empty_content = hasattr(msg, "content") and (
-                 msg.content == "" or msg.content is None
-             )
-
-             # Check if this message has tool calls (non-empty list)
-             has_tool_calls = (
-                 hasattr(msg, "tool_calls")
-                 and msg.tool_calls
-                 and len(msg.tool_calls) > 0
-             )
-
-             # Check if this is a tool result message
-             is_tool_result = hasattr(msg, "tool_call_id") or isinstance(
-                 msg, ToolMessage
-             )
-
-             # Check if the previous message had tool calls (this message might be a tool result)
-             prev_had_tool_calls = False
-             if i > 0:
-                 prev_msg = messages[i - 1]
-                 prev_had_tool_calls = (
-                     hasattr(prev_msg, "tool_calls")
-                     and prev_msg.tool_calls
-                     and len(prev_msg.tool_calls) > 0
-                 )
-
-             # Check if the next message is a tool result (this message might be a tool use)
-             next_is_tool_result = False
-             if i < len(messages) - 1:
-                 next_msg = messages[i + 1]
-                 next_is_tool_result = hasattr(next_msg, "tool_call_id") or isinstance(
-                     next_msg, ToolMessage
-                 )
-
-             # Special handling for empty AIMessages - they might be placeholders or incomplete responses
-             # Don't filter them if they're the only AI response or seem important to the conversation flow
-             is_empty_ai_message = has_empty_content and isinstance(msg, AIMessage)
-
-             # Only filter out messages with empty content that are definitely not needed
-             should_filter = (
-                 has_empty_content
-                 and not has_tool_calls
-                 and not is_tool_result
-                 and not prev_had_tool_calls  # Don't filter if previous message had tool calls
-                 and not next_is_tool_result  # Don't filter if next message is a tool result
-                 and not (
-                     is_empty_ai_message and len(messages) <= 2
-                 )  # Don't filter empty AI messages in short conversations
-             )
-
-             if should_filter:
-                 logger.debug(f"Filtering out message at index {i}: {msg.model_dump()}")
-                 continue
-             else:
-                 filtered_messages.append(msg)
-
-         logger.debug(
-             f"Filtered {len(messages)} messages down to {len(filtered_messages)} messages"
-         )
-
-         # Log diagnostic information if all messages were filtered out
-         if len(filtered_messages) == 0:
-             logger.warning(
-                 f"All {len(messages)} messages were filtered out! This indicates a problem with the conversation state."
-             )
-             logger.debug(f"Original message types: {message_types}")
-
-             if remove_message_count == len(messages):
-                 logger.warning(
-                     "All messages were RemoveMessage objects - this suggests a bug in summarization logic"
-                 )
-             elif empty_content_count > 0:
-                 logger.debug(f"{empty_content_count} messages had empty content")
-
-         return filtered_messages
-
-     def _postprocess_message(self, message: BaseMessage) -> BaseMessage:
-         return message
-
-     def _generate(
-         self,
-         messages: Sequence[BaseMessage],
-         stop: Optional[Sequence[str]] = None,
-         run_manager: Optional[CallbackManagerForLLMRun] = None,
-         **kwargs: Any,
-     ) -> ChatResult:
-         """Override _generate to apply message preprocessing and postprocessing."""
-         # Apply message preprocessing
-         processed_messages: Sequence[BaseMessage] = self._preprocess_messages(messages)
-
-         if len(processed_messages) == 0:
-             logger.error(
-                 "All messages were filtered out during preprocessing. This indicates a serious issue with the conversation state."
-             )
-             empty_generation = ChatGeneration(
-                 message=AIMessage(content="", id="empty-response")
-             )
-             return ChatResult(generations=[empty_generation])
-
-         logger.trace(
-             f"Processed messages:\n{json.dumps([msg.model_dump() for msg in processed_messages], indent=2)}"
-         )
-
-         result: ChatResult = super()._generate(
-             processed_messages, stop, run_manager, **kwargs
-         )
-
-         if result.generations:
-             for generation in result.generations:
-                 if isinstance(generation, ChatGeneration) and generation.message:
-                     generation.message = self._postprocess_message(generation.message)
-
-         return result
-
-     def _stream(
-         self,
-         messages: Sequence[BaseMessage],
-         stop: Optional[Sequence[str]] = None,
-         run_manager: Optional[CallbackManagerForLLMRun] = None,
-         **kwargs: Any,
-     ) -> Iterator[ChatGeneration]:
-         """Override _stream to apply message preprocessing and postprocessing."""
-         # Apply message preprocessing
-         processed_messages: Sequence[BaseMessage] = self._preprocess_messages(messages)
-
-         # Handle the edge case where all messages were filtered out
-         if len(processed_messages) == 0:
-             logger.error(
-                 "All messages were filtered out during preprocessing. This indicates a serious issue with the conversation state."
-             )
-             # Return an empty streaming result without calling the underlying API
-             # This prevents API errors while making the issue visible through an empty response
-             empty_chunk = ChatGenerationChunk(
-                 message=AIMessage(content="", id="empty-response")
-             )
-             yield empty_chunk
-             return
-
-         logger.trace(
-             f"Processed messages:\n{json.dumps([msg.model_dump() for msg in processed_messages], indent=2)}"
-         )
-
-         # Call the parent ChatDatabricks implementation
-         for chunk in super()._stream(processed_messages, stop, run_manager, **kwargs):
-             chunk: ChatGenerationChunk
-             # Apply message postprocessing to each chunk
-             if isinstance(chunk, ChatGeneration) and chunk.message:
-                 chunk.message = self._postprocess_message(chunk.message)
-             yield chunk
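The removed ChatDatabricksFiltered class wrapped ChatDatabricks so that RemoveMessage objects and stray empty-content messages were stripped before every generate or stream call. A minimal usage sketch follows; the endpoint name and conversation are illustrative placeholders, not taken from the package.

    # Hypothetical usage sketch for the removed ChatDatabricksFiltered wrapper.
    from langchain_core.messages import AIMessage, HumanMessage

    # Endpoint name is an illustrative placeholder.
    llm = ChatDatabricksFiltered(endpoint="databricks-meta-llama-3-3-70b-instruct")

    messages = [
        HumanMessage(content="Summarize yesterday's sales."),
        AIMessage(content=""),  # empty placeholder the wrapper would filter out
        HumanMessage(content="Keep it to two sentences."),
    ]

    # invoke() routes through _generate(), which runs _preprocess_messages()
    # before delegating to the underlying ChatDatabricks endpoint.
    result = llm.invoke(messages)
    print(result.content)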
dao_ai/guardrails.py DELETED
@@ -1,112 +0,0 @@
- from typing import Any, Literal, Optional, Type
-
- from langchain_core.language_models import LanguageModelLike
- from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
- from langchain_core.runnables import RunnableConfig
- from langchain_core.runnables.base import RunnableLike
- from langgraph.graph import END, START, MessagesState, StateGraph
- from langgraph.graph.state import CompiledStateGraph
- from langgraph.managed import RemainingSteps
- from loguru import logger
- from openevals.llm import create_llm_as_judge
-
- from dao_ai.config import GuardrailModel
- from dao_ai.messages import last_ai_message, last_human_message
- from dao_ai.state import SharedState
-
-
- class MessagesWithSteps(MessagesState):
-     guardrails_remaining_steps: RemainingSteps
-
-
- def end_or_reflect(state: MessagesWithSteps) -> Literal[END, "graph"]:
-     if state["guardrails_remaining_steps"] < 2:
-         return END
-     if len(state["messages"]) == 0:
-         return END
-     last_message = state["messages"][-1]
-     if isinstance(last_message, HumanMessage):
-         return "graph"
-     else:
-         return END
-
-
- def create_reflection_graph(
-     graph: CompiledStateGraph,
-     reflection: CompiledStateGraph,
-     state_schema: Optional[Type[Any]] = None,
-     config_schema: Optional[Type[Any]] = None,
- ) -> StateGraph:
-     logger.debug("Creating reflection graph")
-     _state_schema = state_schema or graph.builder.schema
-
-     if "guardrails_remaining_steps" in _state_schema.__annotations__:
-         raise ValueError(
-             "Has key 'guardrails_remaining_steps' in state_schema, this shadows a built in key"
-         )
-
-     if "messages" not in _state_schema.__annotations__:
-         raise ValueError("Missing required key 'messages' in state_schema")
-
-     class StateSchema(_state_schema):
-         guardrails_remaining_steps: RemainingSteps
-
-     rgraph = StateGraph(StateSchema, config_schema=config_schema)
-     rgraph.add_node("graph", graph)
-     rgraph.add_node("reflection", reflection)
-     rgraph.add_edge(START, "graph")
-     rgraph.add_edge("graph", "reflection")
-     rgraph.add_conditional_edges("reflection", end_or_reflect)
-     return rgraph
-
-
- def with_guardrails(
-     graph: CompiledStateGraph, guardrail: CompiledStateGraph
- ) -> CompiledStateGraph:
-     logger.debug("Creating graph with guardrails")
-     return create_reflection_graph(
-         graph, guardrail, state_schema=SharedState, config_schema=RunnableConfig
-     ).compile()
-
-
- def judge_node(guardrails: GuardrailModel) -> RunnableLike:
-     def judge(state: SharedState, config: RunnableConfig) -> dict[str, BaseMessage]:
-         llm: LanguageModelLike = guardrails.model.as_chat_model()
-
-         evaluator = create_llm_as_judge(
-             prompt=guardrails.prompt,
-             judge=llm,
-         )
-
-         ai_message: AIMessage = last_ai_message(state["messages"])
-         human_message: HumanMessage = last_human_message(state["messages"])
-
-         logger.debug(f"Evaluating response: {ai_message.content}")
-         eval_result = evaluator(
-             inputs=human_message.content, outputs=ai_message.content
-         )
-
-         if eval_result["score"]:
-             logger.debug("Response approved by judge")
-             logger.debug(f"Judge's comment: {eval_result['comment']}")
-             return
-         else:
-             # Otherwise, return the judge's critique as a new user message
-             logger.warning("Judge requested improvements")
-             comment: str = eval_result["comment"]
-             logger.warning(f"Judge's critique: {comment}")
-             content: str = "\n".join([human_message.content, comment])
-             return {"messages": [HumanMessage(content=content)]}
-
-     return judge
-
-
- def reflection_guardrail(guardrails: GuardrailModel) -> CompiledStateGraph:
-     judge: CompiledStateGraph = (
-         StateGraph(SharedState, config_schema=RunnableConfig)
-         .add_node("judge", judge_node(guardrails=guardrails))
-         .add_edge(START, "judge")
-         .add_edge("judge", END)
-         .compile()
-     )
-     return judge
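The removed guardrails module implemented an LLM-as-judge reflection loop: the agent graph runs, the judge scores the last AI answer against the last human message, and a rejected answer is fed back as a new HumanMessage until the step budget runs out. A composition sketch using the helpers above follows; `agent_graph` and `guardrail_config` are placeholders, and GuardrailModel is only assumed to expose the `.model` and `.prompt` fields accessed in the code above.

    # Hypothetical composition sketch for the removed guardrails helpers.
    from langchain_core.messages import HumanMessage

    # `guardrail_config` stands in for a GuardrailModel; `agent_graph` for any
    # compiled LangGraph agent whose state schema includes "messages".
    judge_graph = reflection_guardrail(guardrails=guardrail_config)
    guarded_graph = with_guardrails(agent_graph, judge_graph)

    # Each turn runs graph -> reflection; if the judge rejects the answer, its
    # critique is appended as a new HumanMessage and end_or_reflect routes back
    # to the agent until guardrails_remaining_steps drops below 2.
    result = guarded_graph.invoke(
        {"messages": [HumanMessage(content="How do I reset my password?")]}
    )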
dao_ai/tools/human_in_the_loop.py DELETED
@@ -1,100 +0,0 @@
- from typing import Any, Optional
-
- from langchain_core.runnables import RunnableConfig
- from langchain_core.runnables.base import RunnableLike
- from langchain_core.tools import BaseTool
- from langchain_core.tools import tool as create_tool
- from langgraph.prebuilt.interrupt import HumanInterrupt, HumanInterruptConfig
- from langgraph.types import interrupt
- from loguru import logger
-
- from dao_ai.config import (
-     BaseFunctionModel,
-     HumanInTheLoopModel,
- )
-
-
- def add_human_in_the_loop(
-     tool: RunnableLike,
-     *,
-     interrupt_config: HumanInterruptConfig | None = None,
-     review_prompt: Optional[str] = "Please review the tool call",
- ) -> BaseTool:
-     """
-     Wrap a tool with human-in-the-loop functionality.
-     This function takes a tool (either a callable or a BaseTool instance) and wraps it
-     with a human-in-the-loop mechanism. When the tool is invoked, it will first
-     request human review before executing the tool's logic. The human can choose to
-     accept, edit the input, or provide a custom response.
-
-     Args:
-         tool (Callable[..., Any] | BaseTool): _description_
-         interrupt_config (HumanInterruptConfig | None, optional): _description_. Defaults to None.
-
-     Raises:
-         ValueError: _description_
-
-     Returns:
-         BaseTool: _description_
-     """
-     if not isinstance(tool, BaseTool):
-         tool = create_tool(tool)
-
-     if interrupt_config is None:
-         interrupt_config = {
-             "allow_accept": True,
-             "allow_edit": True,
-             "allow_respond": True,
-         }
-
-     logger.debug(f"Wrapping tool {tool} with human-in-the-loop functionality")
-
-     @create_tool(tool.name, description=tool.description, args_schema=tool.args_schema)
-     async def call_tool_with_interrupt(config: RunnableConfig, **tool_input) -> Any:
-         logger.debug(f"call_tool_with_interrupt: {tool.name} with input: {tool_input}")
-         request: HumanInterrupt = {
-             "action_request": {
-                 "action": tool.name,
-                 "args": tool_input,
-             },
-             "config": interrupt_config,
-             "description": review_prompt,
-         }
-
-         logger.debug(f"Human interrupt request: {request}")
-         response: dict[str, Any] = interrupt([request])[0]
-         logger.debug(f"Human interrupt response: {response}")
-
-         if response["type"] == "accept":
-             tool_response = await tool.ainvoke(tool_input, config=config)
-         elif response["type"] == "edit":
-             tool_input = response["args"]["args"]
-             tool_response = await tool.ainvoke(tool_input, config=config)
-         elif response["type"] == "response":
-             user_feedback = response["args"]
-             tool_response = user_feedback
-         else:
-             raise ValueError(f"Unknown interrupt response type: {response['type']}")
-
-         return tool_response
-
-     return call_tool_with_interrupt
-
-
- def as_human_in_the_loop(
-     tool: RunnableLike, function: BaseFunctionModel | str
- ) -> RunnableLike:
-     if isinstance(function, BaseFunctionModel):
-         human_in_the_loop: HumanInTheLoopModel | None = function.human_in_the_loop
-         if human_in_the_loop:
-             # Get tool name safely - handle RunnableBinding objects
-             tool_name = getattr(tool, "name", None) or getattr(
-                 getattr(tool, "bound", None), "name", "unknown_tool"
-             )
-             logger.debug(f"Adding human-in-the-loop to tool: {tool_name}")
-             tool = add_human_in_the_loop(
-                 tool=tool,
-                 interrupt_config=human_in_the_loop.interupt_config,
-                 review_prompt=human_in_the_loop.review_prompt,
-             )
-     return tool
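The removed human-in-the-loop helper paused each wrapped tool call on a LangGraph interrupt() so a reviewer could accept the call, edit its arguments, or answer in place of the tool. A minimal sketch wrapping a placeholder tool follows; `send_refund` is illustrative, and only `add_human_in_the_loop` itself comes from the code above.

    # Hypothetical usage sketch for the removed add_human_in_the_loop wrapper.
    from langchain_core.tools import tool

    @tool
    def send_refund(order_id: str, amount: float) -> str:
        """Issue a refund for an order."""
        return f"Refunded {amount} for order {order_id}"

    # Every call to the wrapped tool now raises a HumanInterrupt via interrupt();
    # the reviewer's response decides whether the underlying tool runs at all.
    reviewed_refund = add_human_in_the_loop(
        send_refund,
        review_prompt="Approve this refund before it is sent",
    )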