dao-ai 0.0.28__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +2 -5
  3. dao_ai/cli.py +342 -58
  4. dao_ai/config.py +1610 -380
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -253
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +158 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/context_editing.py +230 -0
  25. dao_ai/middleware/core.py +67 -0
  26. dao_ai/middleware/guardrails.py +420 -0
  27. dao_ai/middleware/human_in_the_loop.py +233 -0
  28. dao_ai/middleware/message_validation.py +586 -0
  29. dao_ai/middleware/model_call_limit.py +77 -0
  30. dao_ai/middleware/model_retry.py +121 -0
  31. dao_ai/middleware/pii.py +157 -0
  32. dao_ai/middleware/summarization.py +197 -0
  33. dao_ai/middleware/tool_call_limit.py +210 -0
  34. dao_ai/middleware/tool_retry.py +174 -0
  35. dao_ai/models.py +1306 -114
  36. dao_ai/nodes.py +240 -161
  37. dao_ai/optimization.py +674 -0
  38. dao_ai/orchestration/__init__.py +52 -0
  39. dao_ai/orchestration/core.py +294 -0
  40. dao_ai/orchestration/supervisor.py +279 -0
  41. dao_ai/orchestration/swarm.py +271 -0
  42. dao_ai/prompts.py +128 -31
  43. dao_ai/providers/databricks.py +584 -601
  44. dao_ai/state.py +157 -21
  45. dao_ai/tools/__init__.py +13 -5
  46. dao_ai/tools/agent.py +1 -3
  47. dao_ai/tools/core.py +64 -11
  48. dao_ai/tools/email.py +232 -0
  49. dao_ai/tools/genie.py +144 -294
  50. dao_ai/tools/mcp.py +223 -155
  51. dao_ai/tools/memory.py +50 -0
  52. dao_ai/tools/python.py +9 -14
  53. dao_ai/tools/search.py +14 -0
  54. dao_ai/tools/slack.py +22 -10
  55. dao_ai/tools/sql.py +202 -0
  56. dao_ai/tools/time.py +30 -7
  57. dao_ai/tools/unity_catalog.py +165 -88
  58. dao_ai/tools/vector_search.py +331 -221
  59. dao_ai/utils.py +166 -20
  60. dao_ai/vector_search.py +37 -0
  61. dao_ai-0.1.5.dist-info/METADATA +489 -0
  62. dao_ai-0.1.5.dist-info/RECORD +70 -0
  63. dao_ai/chat_models.py +0 -204
  64. dao_ai/guardrails.py +0 -112
  65. dao_ai/tools/human_in_the_loop.py +0 -100
  66. dao_ai-0.0.28.dist-info/METADATA +0 -1168
  67. dao_ai-0.0.28.dist-info/RECORD +0 -41
  68. {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/WHEEL +0 -0
  69. {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/entry_points.txt +0 -0
  70. {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/licenses/LICENSE +0 -0
dao_ai/nodes.py CHANGED
@@ -1,101 +1,145 @@
1
- from typing import Any, Callable, Optional, Sequence
1
+ """
2
+ Node creation utilities for DAO AI agents.
2
3
 
3
- import mlflow
4
+ This module provides factory functions for creating LangGraph nodes
5
+ that implement agent logic using LangChain v1's create_agent pattern.
6
+ """
7
+
8
+ from typing import Any, Optional, Sequence
9
+
10
+ from langchain.agents import create_agent
11
+ from langchain.agents.middleware import AgentMiddleware
4
12
  from langchain_core.language_models import LanguageModelLike
5
- from langchain_core.messages import AIMessage, AnyMessage, BaseMessage
6
- from langchain_core.messages.utils import count_tokens_approximately
7
- from langchain_core.runnables import RunnableConfig
8
13
  from langchain_core.runnables.base import RunnableLike
9
14
  from langchain_core.tools import BaseTool
10
- from langgraph.graph import StateGraph
11
15
  from langgraph.graph.state import CompiledStateGraph
12
- from langgraph.prebuilt import create_react_agent
13
- from langgraph.runtime import Runtime
14
- from langmem import create_manage_memory_tool, create_search_memory_tool
15
- from langmem.short_term import SummarizationNode
16
- from langmem.short_term.summarization import TokenCounter
16
+ from langmem import create_manage_memory_tool
17
17
  from loguru import logger
18
18
 
19
19
  from dao_ai.config import (
20
20
  AgentModel,
21
- AppConfig,
22
21
  ChatHistoryModel,
23
- FunctionHook,
24
22
  MemoryModel,
23
+ PromptModel,
25
24
  ToolModel,
26
25
  )
27
- from dao_ai.guardrails import reflection_guardrail, with_guardrails
28
- from dao_ai.hooks.core import create_hooks
26
+ from dao_ai.middleware.core import create_factory_middleware
27
+ from dao_ai.middleware.guardrails import GuardrailMiddleware
28
+ from dao_ai.middleware.human_in_the_loop import (
29
+ create_hitl_middleware_from_tool_models,
30
+ )
31
+ from dao_ai.middleware.summarization import (
32
+ create_summarization_middleware,
33
+ )
29
34
  from dao_ai.prompts import make_prompt
30
- from dao_ai.state import Context, IncomingState, SharedState
35
+ from dao_ai.state import AgentState, Context
31
36
  from dao_ai.tools import create_tools
37
+ from dao_ai.tools.memory import create_search_memory_tool
32
38
 
33
39
 
34
- def summarization_node(chat_history: ChatHistoryModel) -> RunnableLike:
40
+ def _create_middleware_list(
41
+ agent: AgentModel,
42
+ tool_models: Sequence[ToolModel],
43
+ chat_history: Optional[ChatHistoryModel] = None,
44
+ ) -> list[Any]:
35
45
  """
36
- Create a summarization node for managing chat history.
46
+ Create a list of middleware instances from agent configuration.
37
47
 
38
48
  Args:
39
- chat_history: ChatHistoryModel configuration for summarization
49
+ agent: AgentModel configuration
50
+ tool_models: Tool model configurations (for HITL settings)
51
+ chat_history: Optional chat history configuration for summarization
40
52
 
41
53
  Returns:
42
- RunnableLike: A summarization node that processes messages
54
+ List of middleware instances (can include both AgentMiddleware and
55
+ LangChain built-in middleware)
43
56
  """
44
- if chat_history is None:
45
- raise ValueError("chat_history must be provided to use summarization")
46
-
47
- max_tokens: int = chat_history.max_tokens
48
- max_tokens_before_summary: int | None = chat_history.max_tokens_before_summary
49
- max_messages_before_summary: int | None = chat_history.max_messages_before_summary
50
- max_summary_tokens: int | None = chat_history.max_summary_tokens
51
- token_counter: TokenCounter = (
52
- count_tokens_approximately if max_tokens_before_summary else len
53
- )
54
-
55
- logger.debug(
56
- f"Creating summarization node with max_tokens: {max_tokens}, "
57
- f"max_tokens_before_summary: {max_tokens_before_summary}, "
58
- f"max_messages_before_summary: {max_messages_before_summary}, "
59
- f"max_summary_tokens: {max_summary_tokens}"
60
- )
61
-
62
- summarization_model: LanguageModelLike = chat_history.model.as_chat_model()
63
-
64
- node: RunnableLike = SummarizationNode(
65
- model=summarization_model,
66
- max_tokens=max_tokens,
67
- max_tokens_before_summary=max_tokens_before_summary
68
- or max_messages_before_summary,
69
- max_summary_tokens=max_summary_tokens,
70
- token_counter=token_counter,
71
- input_messages_key="messages",
72
- output_messages_key="summarized_messages",
73
- )
74
- return node
75
-
76
-
77
- def call_agent_with_summarized_messages(agent: CompiledStateGraph) -> RunnableLike:
78
- async def call_agent(state: SharedState, runtime: Runtime[Context]) -> SharedState:
79
- logger.debug(f"Calling agent {agent.name} with summarized messages")
80
-
81
- # Get the summarized messages from the summarization node
82
- messages: Sequence[AnyMessage] = state.get("summarized_messages", [])
83
- logger.debug(f"Found {len(messages)} summarized messages")
84
- logger.trace(f"Summarized messages: {[m.model_dump() for m in messages]}")
85
-
86
- input: dict[str, Any] = {
87
- "messages": messages,
88
- }
89
-
90
- response: dict[str, Any] = await agent.ainvoke(
91
- input=input, context=runtime.context
57
+ logger.debug("Building middleware list for agent", agent=agent.name)
58
+ middleware_list: list[Any] = []
59
+
60
+ # Add configured middleware using factory pattern
61
+ if agent.middleware:
62
+ middleware_names: list[str] = [mw.name for mw in agent.middleware]
63
+ logger.info(
64
+ "Middleware configuration",
65
+ agent=agent.name,
66
+ middleware_count=len(agent.middleware),
67
+ middleware_names=middleware_names,
92
68
  )
93
- response_messages = response.get("messages", [])
94
- logger.debug(f"Agent returned {len(response_messages)} messages")
95
-
96
- return {"messages": response_messages}
69
+ for middleware_config in agent.middleware:
70
+ logger.trace(
71
+ "Creating middleware for agent",
72
+ agent=agent.name,
73
+ middleware_name=middleware_config.name,
74
+ )
75
+ middleware: AgentMiddleware[AgentState, Context] = create_factory_middleware(
76
+ function_name=middleware_config.name,
77
+ args=middleware_config.args,
78
+ )
79
+ middleware_list.append(middleware)
80
+
81
+ # Add guardrails as middleware
82
+ if agent.guardrails:
83
+ guardrail_names: list[str] = [gr.name for gr in agent.guardrails]
84
+ logger.info(
85
+ "Guardrails configuration",
86
+ agent=agent.name,
87
+ guardrails_count=len(agent.guardrails),
88
+ guardrail_names=guardrail_names,
89
+ )
90
+ for guardrail in agent.guardrails:
91
+ # Extract template string from PromptModel if needed
92
+ prompt_str: str
93
+ if isinstance(guardrail.prompt, PromptModel):
94
+ prompt_str = guardrail.prompt.template
95
+ else:
96
+ prompt_str = guardrail.prompt
97
+
98
+ guardrail_middleware: GuardrailMiddleware = GuardrailMiddleware(
99
+ name=guardrail.name,
100
+ model=guardrail.model.as_chat_model(),
101
+ prompt=prompt_str,
102
+ num_retries=guardrail.num_retries or 3,
103
+ )
104
+ logger.trace(
105
+ "Created guardrail middleware", guardrail=guardrail.name, agent=agent.name
106
+ )
107
+ middleware_list.append(guardrail_middleware)
108
+
109
+ # Add summarization middleware if chat_history is configured
110
+ if chat_history is not None:
111
+ logger.info(
112
+ "Chat history configuration",
113
+ agent=agent.name,
114
+ max_tokens=chat_history.max_tokens,
115
+ summary_model=chat_history.model.name,
116
+ )
117
+ summarization_middleware = create_summarization_middleware(chat_history)
118
+ middleware_list.append(summarization_middleware)
119
+
120
+ # Add human-in-the-loop middleware if any tools require it
121
+ hitl_middlewares = create_hitl_middleware_from_tool_models(tool_models)
122
+ if hitl_middlewares:
123
+ # Log which tools require HITL
124
+ hitl_tool_names: list[str] = [
125
+ tool.name
126
+ for tool in tool_models
127
+ if hasattr(tool.function, "human_in_the_loop")
128
+ and tool.function.human_in_the_loop is not None
129
+ ]
130
+ logger.info(
131
+ "Human-in-the-Loop configuration",
132
+ agent=agent.name,
133
+ hitl_tools=hitl_tool_names,
134
+ )
135
+ middleware_list.append(hitl_middlewares)
97
136
 
98
- return call_agent
137
+ logger.info(
138
+ "Middleware summary",
139
+ agent=agent.name,
140
+ total_middleware_count=len(middleware_list),
141
+ )
142
+ return middleware_list
99
143
 
100
144
 
101
145
  def create_agent_node(
@@ -107,9 +151,9 @@ def create_agent_node(
107
151
  """
108
152
  Factory function that creates a LangGraph node for a specialized agent.
109
153
 
110
- This creates a node function that handles user requests using a specialized agent.
111
- The function configures the agent with the appropriate model, prompt, tools, and guardrails.
112
- If chat_history is provided, it creates a workflow with summarization node.
154
+ This creates an agent using LangChain v1's create_agent function with
155
+ middleware for customization. The function configures the agent with
156
+ the appropriate model, prompt, tools, and middleware.
113
157
 
114
158
  Args:
115
159
  agent: AgentModel configuration for the agent
@@ -120,119 +164,154 @@ def create_agent_node(
120
164
  Returns:
121
165
  RunnableLike: An agent node that processes state and returns responses
122
166
  """
123
- logger.debug(f"Creating agent node for {agent.name}")
124
-
125
- if agent.create_agent_hook:
126
- agent_hook = next(iter(create_hooks(agent.create_agent_hook)), None)
127
- return agent_hook
167
+ logger.info("Creating agent node", agent=agent.name)
168
+
169
+ # Log agent configuration details
170
+ logger.info(
171
+ "Agent configuration",
172
+ agent=agent.name,
173
+ model=agent.model.name,
174
+ description=agent.description or "No description",
175
+ )
128
176
 
129
177
  llm: LanguageModelLike = agent.model.as_chat_model()
130
178
 
131
179
  tool_models: Sequence[ToolModel] = agent.tools
132
180
  if not additional_tools:
133
181
  additional_tools = []
134
- tools: Sequence[BaseTool] = create_tools(tool_models) + additional_tools
182
+
183
+ # Log tools being created
184
+ tool_names: list[str] = [tool.name for tool in tool_models]
185
+ logger.info(
186
+ "Tools configuration",
187
+ agent=agent.name,
188
+ tools_count=len(tool_models),
189
+ tool_names=tool_names,
190
+ )
191
+
192
+ tools: list[BaseTool] = list(create_tools(tool_models)) + list(additional_tools)
193
+
194
+ if additional_tools:
195
+ logger.debug(
196
+ "Additional tools added",
197
+ agent=agent.name,
198
+ additional_count=len(additional_tools),
199
+ )
135
200
 
136
201
  if memory and memory.store:
137
202
  namespace: tuple[str, ...] = ("memory",)
138
203
  if memory.store.namespace:
139
204
  namespace = namespace + (memory.store.namespace,)
140
- logger.debug(f"Memory store namespace: {namespace}")
205
+ logger.info(
206
+ "Memory configuration",
207
+ agent=agent.name,
208
+ has_store=True,
209
+ has_checkpointer=memory.checkpointer is not None,
210
+ namespace=namespace,
211
+ )
212
+ elif memory:
213
+ logger.info(
214
+ "Memory configuration",
215
+ agent=agent.name,
216
+ has_store=False,
217
+ has_checkpointer=memory.checkpointer is not None,
218
+ )
141
219
 
220
+ # Add memory tools if store is configured
221
+ if memory and memory.store:
222
+ # Use Databricks-compatible search_memory tool (omits problematic filter field)
142
223
  tools += [
143
224
  create_manage_memory_tool(namespace=namespace),
144
225
  create_search_memory_tool(namespace=namespace),
145
226
  ]
227
+ logger.debug(
228
+ "Memory tools added",
229
+ agent=agent.name,
230
+ tools=["manage_memory", "search_memory"],
231
+ )
146
232
 
147
- pre_agent_hook: Callable[..., Any] = next(
148
- iter(create_hooks(agent.pre_agent_hook)), None
233
+ # Create middleware list from configuration
234
+ middleware_list = _create_middleware_list(
235
+ agent=agent,
236
+ tool_models=tool_models,
237
+ chat_history=chat_history,
149
238
  )
150
- logger.debug(f"pre_agent_hook: {pre_agent_hook}")
151
239
 
152
- post_agent_hook: Callable[..., Any] = next(
153
- iter(create_hooks(agent.post_agent_hook)), None
240
+ # Log prompt configuration
241
+ if agent.prompt:
242
+ if isinstance(agent.prompt, PromptModel):
243
+ logger.info(
244
+ "Prompt configuration",
245
+ agent=agent.name,
246
+ prompt_type="PromptModel",
247
+ prompt_name=agent.prompt.name,
248
+ )
249
+ else:
250
+ prompt_preview: str = (
251
+ agent.prompt[:100] + "..." if len(agent.prompt) > 100 else agent.prompt
252
+ )
253
+ logger.info(
254
+ "Prompt configuration",
255
+ agent=agent.name,
256
+ prompt_type="string",
257
+ prompt_preview=prompt_preview,
258
+ )
259
+ else:
260
+ logger.debug("No custom prompt configured", agent=agent.name)
261
+
262
+ checkpointer: bool = memory is not None and memory.checkpointer is not None
263
+
264
+ # Get the prompt as middleware (always returns AgentMiddleware or None)
265
+ prompt_middleware: AgentMiddleware | None = make_prompt(agent.prompt)
266
+
267
+ # Add prompt middleware at the beginning for priority
268
+ if prompt_middleware is not None:
269
+ middleware_list.insert(0, prompt_middleware)
270
+
271
+ # Configure structured output if response_format is specified
272
+ response_format: Any = None
273
+ if agent.response_format is not None:
274
+ try:
275
+ response_format = agent.response_format.as_strategy()
276
+ if response_format is not None:
277
+ logger.info(
278
+ "Response format configuration",
279
+ agent=agent.name,
280
+ format_type=type(response_format).__name__,
281
+ structured_output=True,
282
+ )
283
+ except ValueError as e:
284
+ logger.error(
285
+ "Failed to configure structured output for agent",
286
+ agent=agent.name,
287
+ error=str(e),
288
+ )
289
+ raise
290
+
291
+ # Use LangChain v1's create_agent with middleware
292
+ # AgentState extends MessagesState with additional DAO AI fields
293
+ # System prompt is provided via middleware (dynamic_prompt)
294
+ logger.info(
295
+ "Creating LangChain agent",
296
+ agent=agent.name,
297
+ tools_count=len(tools),
298
+ middleware_count=len(middleware_list),
299
+ has_checkpointer=checkpointer,
154
300
  )
155
- logger.debug(f"post_agent_hook: {post_agent_hook}")
156
301
 
157
- checkpointer: bool = memory and memory.checkpointer is not None
158
-
159
- compiled_agent: CompiledStateGraph = create_react_agent(
302
+ compiled_agent: CompiledStateGraph = create_agent(
160
303
  name=agent.name,
161
304
  model=llm,
162
- prompt=make_prompt(agent.prompt),
163
305
  tools=tools,
164
- store=True,
306
+ middleware=middleware_list,
165
307
  checkpointer=checkpointer,
166
- state_schema=SharedState,
308
+ state_schema=AgentState,
167
309
  context_schema=Context,
168
- pre_model_hook=pre_agent_hook,
169
- post_model_hook=post_agent_hook,
310
+ response_format=response_format, # Add structured output support
170
311
  )
171
312
 
172
- for guardrail_definition in agent.guardrails:
173
- guardrail: CompiledStateGraph = reflection_guardrail(guardrail_definition)
174
- compiled_agent = with_guardrails(compiled_agent, guardrail)
175
-
176
313
  compiled_agent.name = agent.name
177
314
 
178
- agent_node: CompiledStateGraph
315
+ logger.info("Agent node created successfully", agent=agent.name)
179
316
 
180
- if chat_history is None:
181
- logger.debug("No chat history configured, using compiled agent directly")
182
- agent_node = compiled_agent
183
- else:
184
- logger.debug("Creating agent node with chat history summarization")
185
- workflow: StateGraph = StateGraph(
186
- SharedState,
187
- config_schema=RunnableConfig,
188
- input=SharedState,
189
- output=SharedState,
190
- )
191
- workflow.add_node("summarization", summarization_node(chat_history))
192
- workflow.add_node(
193
- "agent",
194
- call_agent_with_summarized_messages(agent=compiled_agent),
195
- )
196
- workflow.add_edge("summarization", "agent")
197
- workflow.set_entry_point("summarization")
198
- agent_node = workflow.compile(name=agent.name)
199
-
200
- return agent_node
201
-
202
-
203
- def message_hook_node(config: AppConfig) -> RunnableLike:
204
- message_hooks: Sequence[Callable[..., Any]] = create_hooks(config.app.message_hooks)
205
-
206
- @mlflow.trace()
207
- async def message_hook(
208
- state: IncomingState, runtime: Runtime[Context]
209
- ) -> SharedState:
210
- logger.debug("Running message validation")
211
- response: dict[str, Any] = {"is_valid": True, "message_error": None}
212
-
213
- for message_hook in message_hooks:
214
- message_hook: FunctionHook
215
- if message_hook:
216
- try:
217
- hook_response: dict[str, Any] = message_hook(
218
- state=state,
219
- runtime=runtime,
220
- )
221
- response.update(hook_response)
222
- logger.debug(f"Hook response: {hook_response}")
223
- if not response.get("is_valid", True):
224
- break
225
- except Exception as e:
226
- logger.error(f"Message validation failed: {e}")
227
- response_messages: Sequence[BaseMessage] = [
228
- AIMessage(content=str(e))
229
- ]
230
- return {
231
- "is_valid": False,
232
- "message_error": str(e),
233
- "messages": response_messages,
234
- }
235
-
236
- return response
237
-
238
- return message_hook
317
+ return compiled_agent