dao-ai 0.0.35__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58) hide show
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/cli.py +195 -30
  3. dao_ai/config.py +797 -242
  4. dao_ai/genie/__init__.py +38 -0
  5. dao_ai/genie/cache/__init__.py +43 -0
  6. dao_ai/genie/cache/base.py +72 -0
  7. dao_ai/genie/cache/core.py +75 -0
  8. dao_ai/genie/cache/lru.py +329 -0
  9. dao_ai/genie/cache/semantic.py +919 -0
  10. dao_ai/genie/core.py +35 -0
  11. dao_ai/graph.py +27 -253
  12. dao_ai/hooks/__init__.py +9 -6
  13. dao_ai/hooks/core.py +22 -190
  14. dao_ai/memory/__init__.py +10 -0
  15. dao_ai/memory/core.py +23 -5
  16. dao_ai/memory/databricks.py +389 -0
  17. dao_ai/memory/postgres.py +2 -2
  18. dao_ai/messages.py +6 -4
  19. dao_ai/middleware/__init__.py +125 -0
  20. dao_ai/middleware/assertions.py +778 -0
  21. dao_ai/middleware/base.py +50 -0
  22. dao_ai/middleware/core.py +61 -0
  23. dao_ai/middleware/guardrails.py +415 -0
  24. dao_ai/middleware/human_in_the_loop.py +228 -0
  25. dao_ai/middleware/message_validation.py +554 -0
  26. dao_ai/middleware/summarization.py +192 -0
  27. dao_ai/models.py +1177 -108
  28. dao_ai/nodes.py +118 -161
  29. dao_ai/optimization.py +664 -0
  30. dao_ai/orchestration/__init__.py +52 -0
  31. dao_ai/orchestration/core.py +287 -0
  32. dao_ai/orchestration/supervisor.py +264 -0
  33. dao_ai/orchestration/swarm.py +226 -0
  34. dao_ai/prompts.py +126 -29
  35. dao_ai/providers/databricks.py +126 -381
  36. dao_ai/state.py +139 -21
  37. dao_ai/tools/__init__.py +11 -5
  38. dao_ai/tools/core.py +57 -4
  39. dao_ai/tools/email.py +280 -0
  40. dao_ai/tools/genie.py +108 -35
  41. dao_ai/tools/mcp.py +4 -3
  42. dao_ai/tools/memory.py +50 -0
  43. dao_ai/tools/python.py +4 -12
  44. dao_ai/tools/search.py +14 -0
  45. dao_ai/tools/slack.py +1 -1
  46. dao_ai/tools/unity_catalog.py +8 -6
  47. dao_ai/tools/vector_search.py +16 -9
  48. dao_ai/utils.py +72 -8
  49. dao_ai-0.1.0.dist-info/METADATA +1878 -0
  50. dao_ai-0.1.0.dist-info/RECORD +62 -0
  51. dao_ai/chat_models.py +0 -204
  52. dao_ai/guardrails.py +0 -112
  53. dao_ai/tools/human_in_the_loop.py +0 -100
  54. dao_ai-0.0.35.dist-info/METADATA +0 -1169
  55. dao_ai-0.0.35.dist-info/RECORD +0 -41
  56. {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/WHEEL +0 -0
  57. {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/entry_points.txt +0 -0
  58. {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/licenses/LICENSE +0 -0
dao_ai/nodes.py CHANGED
@@ -1,101 +1,103 @@
1
- from typing import Any, Callable, Optional, Sequence
1
+ """
2
+ Node creation utilities for DAO AI agents.
2
3
 
3
- import mlflow
4
+ This module provides factory functions for creating LangGraph nodes
5
+ that implement agent logic using LangChain v1's create_agent pattern.
6
+ """
7
+
8
+ from typing import Any, Optional, Sequence
9
+
10
+ from langchain.agents import create_agent
11
+ from langchain.agents.middleware import AgentMiddleware
4
12
  from langchain_core.language_models import LanguageModelLike
5
- from langchain_core.messages import AIMessage, AnyMessage, BaseMessage
6
- from langchain_core.messages.utils import count_tokens_approximately
7
- from langchain_core.runnables import RunnableConfig
8
13
  from langchain_core.runnables.base import RunnableLike
9
14
  from langchain_core.tools import BaseTool
10
- from langgraph.graph import StateGraph
11
15
  from langgraph.graph.state import CompiledStateGraph
12
- from langgraph.prebuilt import create_react_agent
13
- from langgraph.runtime import Runtime
14
- from langmem import create_manage_memory_tool, create_search_memory_tool
15
- from langmem.short_term import SummarizationNode
16
- from langmem.short_term.summarization import TokenCounter
16
+ from langmem import create_manage_memory_tool
17
17
  from loguru import logger
18
18
 
19
19
  from dao_ai.config import (
20
20
  AgentModel,
21
- AppConfig,
22
21
  ChatHistoryModel,
23
- FunctionHook,
24
22
  MemoryModel,
23
+ PromptModel,
25
24
  ToolModel,
26
25
  )
27
- from dao_ai.guardrails import reflection_guardrail, with_guardrails
28
- from dao_ai.hooks.core import create_hooks
26
+ from dao_ai.middleware.core import create_factory_middleware
27
+ from dao_ai.middleware.guardrails import GuardrailMiddleware
28
+ from dao_ai.middleware.human_in_the_loop import create_hitl_middleware_from_tool_models
29
+ from dao_ai.middleware.summarization import create_summarization_middleware
29
30
  from dao_ai.prompts import make_prompt
30
- from dao_ai.state import Context, IncomingState, SharedState
31
+ from dao_ai.state import AgentState, Context
31
32
  from dao_ai.tools import create_tools
33
+ from dao_ai.tools.memory import create_search_memory_tool
32
34
 
33
35
 
34
- def summarization_node(chat_history: ChatHistoryModel) -> RunnableLike:
36
+ def _create_middleware_list(
37
+ agent: AgentModel,
38
+ tool_models: Sequence[ToolModel],
39
+ chat_history: Optional[ChatHistoryModel] = None,
40
+ ) -> list[Any]:
35
41
  """
36
- Create a summarization node for managing chat history.
42
+ Create a list of middleware instances from agent configuration.
37
43
 
38
44
  Args:
39
- chat_history: ChatHistoryModel configuration for summarization
45
+ agent: AgentModel configuration
46
+ tool_models: Tool model configurations (for HITL settings)
47
+ chat_history: Optional chat history configuration for summarization
40
48
 
41
49
  Returns:
42
- RunnableLike: A summarization node that processes messages
50
+ List of middleware instances (can include both AgentMiddleware and
51
+ LangChain built-in middleware)
43
52
  """
44
- if chat_history is None:
45
- raise ValueError("chat_history must be provided to use summarization")
46
-
47
- max_tokens: int = chat_history.max_tokens
48
- max_tokens_before_summary: int | None = chat_history.max_tokens_before_summary
49
- max_messages_before_summary: int | None = chat_history.max_messages_before_summary
50
- max_summary_tokens: int | None = chat_history.max_summary_tokens
51
- token_counter: TokenCounter = (
52
- count_tokens_approximately if max_tokens_before_summary else len
53
- )
54
-
55
- logger.debug(
56
- f"Creating summarization node with max_tokens: {max_tokens}, "
57
- f"max_tokens_before_summary: {max_tokens_before_summary}, "
58
- f"max_messages_before_summary: {max_messages_before_summary}, "
59
- f"max_summary_tokens: {max_summary_tokens}"
60
- )
61
-
62
- summarization_model: LanguageModelLike = chat_history.model.as_chat_model()
63
-
64
- node: RunnableLike = SummarizationNode(
65
- model=summarization_model,
66
- max_tokens=max_tokens,
67
- max_tokens_before_summary=max_tokens_before_summary
68
- or max_messages_before_summary,
69
- max_summary_tokens=max_summary_tokens,
70
- token_counter=token_counter,
71
- input_messages_key="messages",
72
- output_messages_key="summarized_messages",
73
- )
74
- return node
75
-
76
-
77
- def call_agent_with_summarized_messages(agent: CompiledStateGraph) -> RunnableLike:
78
- async def call_agent(state: SharedState, runtime: Runtime[Context]) -> SharedState:
79
- logger.debug(f"Calling agent {agent.name} with summarized messages")
80
-
81
- # Get the summarized messages from the summarization node
82
- messages: Sequence[AnyMessage] = state.get("summarized_messages", [])
83
- logger.debug(f"Found {len(messages)} summarized messages")
84
- logger.trace(f"Summarized messages: {[m.model_dump() for m in messages]}")
85
-
86
- input: dict[str, Any] = {
87
- "messages": messages,
88
- }
89
-
90
- response: dict[str, Any] = await agent.ainvoke(
91
- input=input, context=runtime.context
53
+ logger.debug(f"Building middleware list for agent '{agent.name}'")
54
+ middleware_list: list[Any] = []
55
+
56
+ # Add configured middleware using factory pattern
57
+ if agent.middleware:
58
+ logger.debug(f"Processing {len(agent.middleware)} configured middleware")
59
+ for middleware_config in agent.middleware:
60
+ middleware = create_factory_middleware(
61
+ function_name=middleware_config.name,
62
+ args=middleware_config.args,
63
+ )
64
+ if middleware is not None:
65
+ middleware_list.append(middleware)
66
+
67
+ # Add guardrails as middleware
68
+ if agent.guardrails:
69
+ logger.debug(f"Adding {len(agent.guardrails)} guardrail middleware")
70
+ for guardrail in agent.guardrails:
71
+ # Extract template string from PromptModel if needed
72
+ prompt_str: str
73
+ if isinstance(guardrail.prompt, PromptModel):
74
+ prompt_str = guardrail.prompt.template
75
+ else:
76
+ prompt_str = guardrail.prompt
77
+
78
+ guardrail_middleware = GuardrailMiddleware(
79
+ name=guardrail.name,
80
+ model=guardrail.model.as_chat_model(),
81
+ prompt=prompt_str,
82
+ num_retries=guardrail.num_retries or 3,
92
83
  )
93
- response_messages = response.get("messages", [])
94
- logger.debug(f"Agent returned {len(response_messages)} messages")
84
+ logger.debug(f"Created guardrail middleware: {guardrail.name}")
85
+ middleware_list.append(guardrail_middleware)
95
86
 
96
- return {"messages": response_messages}
87
+ # Add summarization middleware if chat_history is configured
88
+ if chat_history is not None:
89
+ logger.debug("Adding summarization middleware")
90
+ summarization_middleware = create_summarization_middleware(chat_history)
91
+ middleware_list.append(summarization_middleware)
97
92
 
98
- return call_agent
93
+ # Add human-in-the-loop middleware if any tools require it
94
+ hitl_middleware = create_hitl_middleware_from_tool_models(tool_models)
95
+ if hitl_middleware is not None:
96
+ logger.debug("Added human-in-the-loop middleware")
97
+ middleware_list.append(hitl_middleware)
98
+
99
+ logger.debug(f"Total middleware count: {len(middleware_list)}")
100
+ return middleware_list
99
101
 
100
102
 
101
103
  def create_agent_node(
@@ -107,9 +109,9 @@ def create_agent_node(
107
109
  """
108
110
  Factory function that creates a LangGraph node for a specialized agent.
109
111
 
110
- This creates a node function that handles user requests using a specialized agent.
111
- The function configures the agent with the appropriate model, prompt, tools, and guardrails.
112
- If chat_history is provided, it creates a workflow with summarization node.
112
+ This creates an agent using LangChain v1's create_agent function with
113
+ middleware for customization. The function configures the agent with
114
+ the appropriate model, prompt, tools, and middleware.
113
115
 
114
116
  Args:
115
117
  agent: AgentModel configuration for the agent
@@ -122,16 +124,12 @@ def create_agent_node(
122
124
  """
123
125
  logger.debug(f"Creating agent node for {agent.name}")
124
126
 
125
- if agent.create_agent_hook:
126
- agent_hook = next(iter(create_hooks(agent.create_agent_hook)), None)
127
- return agent_hook
128
-
129
127
  llm: LanguageModelLike = agent.model.as_chat_model()
130
128
 
131
129
  tool_models: Sequence[ToolModel] = agent.tools
132
130
  if not additional_tools:
133
131
  additional_tools = []
134
- tools: Sequence[BaseTool] = create_tools(tool_models) + additional_tools
132
+ tools: list[BaseTool] = list(create_tools(tool_models)) + list(additional_tools)
135
133
 
136
134
  if memory and memory.store:
137
135
  namespace: tuple[str, ...] = ("memory",)
@@ -139,100 +137,59 @@ def create_agent_node(
139
137
  namespace = namespace + (memory.store.namespace,)
140
138
  logger.debug(f"Memory store namespace: {namespace}")
141
139
 
140
+ # Use Databricks-compatible search_memory tool (omits problematic filter field)
142
141
  tools += [
143
142
  create_manage_memory_tool(namespace=namespace),
144
143
  create_search_memory_tool(namespace=namespace),
145
144
  ]
146
145
 
147
- pre_agent_hook: Callable[..., Any] = next(
148
- iter(create_hooks(agent.pre_agent_hook)), None
146
+ # Create middleware list from configuration
147
+ middleware_list = _create_middleware_list(
148
+ agent=agent,
149
+ tool_models=tool_models,
150
+ chat_history=chat_history,
149
151
  )
150
- logger.debug(f"pre_agent_hook: {pre_agent_hook}")
151
152
 
152
- post_agent_hook: Callable[..., Any] = next(
153
- iter(create_hooks(agent.post_agent_hook)), None
154
- )
155
- logger.debug(f"post_agent_hook: {post_agent_hook}")
156
-
157
- checkpointer: bool = memory and memory.checkpointer is not None
158
-
159
- compiled_agent: CompiledStateGraph = create_react_agent(
153
+ logger.debug(f"Created {len(middleware_list)} middleware for agent {agent.name}")
154
+
155
+ checkpointer: bool = memory is not None and memory.checkpointer is not None
156
+
157
+ # Get the prompt as middleware (always returns AgentMiddleware or None)
158
+ prompt_middleware: AgentMiddleware | None = make_prompt(agent.prompt)
159
+
160
+ # Add prompt middleware at the beginning for priority
161
+ if prompt_middleware is not None:
162
+ middleware_list.insert(0, prompt_middleware)
163
+
164
+ # Configure structured output if response_format is specified
165
+ response_format: Any = None
166
+ if agent.response_format is not None:
167
+ try:
168
+ response_format = agent.response_format.as_strategy()
169
+ if response_format is not None:
170
+ logger.debug(
171
+ f"Agent '{agent.name}' using structured output: {type(response_format).__name__}"
172
+ )
173
+ except ValueError as e:
174
+ logger.error(
175
+ f"Failed to configure structured output for agent {agent.name}: {e}"
176
+ )
177
+ raise
178
+
179
+ # Use LangChain v1's create_agent with middleware
180
+ # AgentState extends MessagesState with additional DAO AI fields
181
+ # System prompt is provided via middleware (dynamic_prompt)
182
+ compiled_agent: CompiledStateGraph = create_agent(
160
183
  name=agent.name,
161
184
  model=llm,
162
- prompt=make_prompt(agent.prompt),
163
185
  tools=tools,
164
- store=True,
186
+ middleware=middleware_list,
165
187
  checkpointer=checkpointer,
166
- state_schema=SharedState,
188
+ state_schema=AgentState,
167
189
  context_schema=Context,
168
- pre_model_hook=pre_agent_hook,
169
- post_model_hook=post_agent_hook,
190
+ response_format=response_format, # Add structured output support
170
191
  )
171
192
 
172
- for guardrail_definition in agent.guardrails:
173
- guardrail: CompiledStateGraph = reflection_guardrail(guardrail_definition)
174
- compiled_agent = with_guardrails(compiled_agent, guardrail)
175
-
176
193
  compiled_agent.name = agent.name
177
194
 
178
- agent_node: CompiledStateGraph
179
-
180
- if chat_history is None:
181
- logger.debug("No chat history configured, using compiled agent directly")
182
- agent_node = compiled_agent
183
- else:
184
- logger.debug("Creating agent node with chat history summarization")
185
- workflow: StateGraph = StateGraph(
186
- SharedState,
187
- config_schema=RunnableConfig,
188
- input=SharedState,
189
- output=SharedState,
190
- )
191
- workflow.add_node("summarization", summarization_node(chat_history))
192
- workflow.add_node(
193
- "agent",
194
- call_agent_with_summarized_messages(agent=compiled_agent),
195
- )
196
- workflow.add_edge("summarization", "agent")
197
- workflow.set_entry_point("summarization")
198
- agent_node = workflow.compile(name=agent.name)
199
-
200
- return agent_node
201
-
202
-
203
- def message_hook_node(config: AppConfig) -> RunnableLike:
204
- message_hooks: Sequence[Callable[..., Any]] = create_hooks(config.app.message_hooks)
205
-
206
- @mlflow.trace()
207
- async def message_hook(
208
- state: IncomingState, runtime: Runtime[Context]
209
- ) -> SharedState:
210
- logger.debug("Running message validation")
211
- response: dict[str, Any] = {"is_valid": True, "message_error": None}
212
-
213
- for message_hook in message_hooks:
214
- message_hook: FunctionHook
215
- if message_hook:
216
- try:
217
- hook_response: dict[str, Any] = message_hook(
218
- state=state,
219
- runtime=runtime,
220
- )
221
- response.update(hook_response)
222
- logger.debug(f"Hook response: {hook_response}")
223
- if not response.get("is_valid", True):
224
- break
225
- except Exception as e:
226
- logger.error(f"Message validation failed: {e}")
227
- response_messages: Sequence[BaseMessage] = [
228
- AIMessage(content=str(e))
229
- ]
230
- return {
231
- "is_valid": False,
232
- "message_error": str(e),
233
- "messages": response_messages,
234
- }
235
-
236
- return response
237
-
238
- return message_hook
195
+ return compiled_agent