dao-ai 0.0.28__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +2 -5
  3. dao_ai/cli.py +245 -40
  4. dao_ai/config.py +1491 -370
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -253
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +125 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/core.py +67 -0
  25. dao_ai/middleware/guardrails.py +420 -0
  26. dao_ai/middleware/human_in_the_loop.py +232 -0
  27. dao_ai/middleware/message_validation.py +586 -0
  28. dao_ai/middleware/summarization.py +197 -0
  29. dao_ai/models.py +1306 -114
  30. dao_ai/nodes.py +245 -159
  31. dao_ai/optimization.py +674 -0
  32. dao_ai/orchestration/__init__.py +52 -0
  33. dao_ai/orchestration/core.py +294 -0
  34. dao_ai/orchestration/supervisor.py +278 -0
  35. dao_ai/orchestration/swarm.py +271 -0
  36. dao_ai/prompts.py +128 -31
  37. dao_ai/providers/databricks.py +573 -601
  38. dao_ai/state.py +157 -21
  39. dao_ai/tools/__init__.py +13 -5
  40. dao_ai/tools/agent.py +1 -3
  41. dao_ai/tools/core.py +64 -11
  42. dao_ai/tools/email.py +232 -0
  43. dao_ai/tools/genie.py +144 -294
  44. dao_ai/tools/mcp.py +223 -155
  45. dao_ai/tools/memory.py +50 -0
  46. dao_ai/tools/python.py +9 -14
  47. dao_ai/tools/search.py +14 -0
  48. dao_ai/tools/slack.py +22 -10
  49. dao_ai/tools/sql.py +202 -0
  50. dao_ai/tools/time.py +30 -7
  51. dao_ai/tools/unity_catalog.py +165 -88
  52. dao_ai/tools/vector_search.py +331 -221
  53. dao_ai/utils.py +166 -20
  54. dao_ai-0.1.2.dist-info/METADATA +455 -0
  55. dao_ai-0.1.2.dist-info/RECORD +64 -0
  56. dao_ai/chat_models.py +0 -204
  57. dao_ai/guardrails.py +0 -112
  58. dao_ai/tools/human_in_the_loop.py +0 -100
  59. dao_ai-0.0.28.dist-info/METADATA +0 -1168
  60. dao_ai-0.0.28.dist-info/RECORD +0 -41
  61. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +0 -0
  62. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
  63. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/nodes.py CHANGED
@@ -1,101 +1,152 @@
1
- from typing import Any, Callable, Optional, Sequence
1
+ """
2
+ Node creation utilities for DAO AI agents.
2
3
 
3
- import mlflow
4
+ This module provides factory functions for creating LangGraph nodes
5
+ that implement agent logic using LangChain v1's create_agent pattern.
6
+ """
7
+
8
+ from typing import Any, Optional, Sequence
9
+
10
+ from langchain.agents import create_agent
11
+ from langchain.agents.middleware import AgentMiddleware
4
12
  from langchain_core.language_models import LanguageModelLike
5
- from langchain_core.messages import AIMessage, AnyMessage, BaseMessage
6
- from langchain_core.messages.utils import count_tokens_approximately
7
- from langchain_core.runnables import RunnableConfig
8
13
  from langchain_core.runnables.base import RunnableLike
9
14
  from langchain_core.tools import BaseTool
10
- from langgraph.graph import StateGraph
11
15
  from langgraph.graph.state import CompiledStateGraph
12
- from langgraph.prebuilt import create_react_agent
13
- from langgraph.runtime import Runtime
14
- from langmem import create_manage_memory_tool, create_search_memory_tool
15
- from langmem.short_term import SummarizationNode
16
- from langmem.short_term.summarization import TokenCounter
16
+ from langmem import create_manage_memory_tool
17
17
  from loguru import logger
18
18
 
19
19
  from dao_ai.config import (
20
20
  AgentModel,
21
- AppConfig,
22
21
  ChatHistoryModel,
23
- FunctionHook,
24
22
  MemoryModel,
23
+ PromptModel,
25
24
  ToolModel,
26
25
  )
27
- from dao_ai.guardrails import reflection_guardrail, with_guardrails
28
- from dao_ai.hooks.core import create_hooks
26
+ from dao_ai.middleware.core import create_factory_middleware
27
+ from dao_ai.middleware.guardrails import GuardrailMiddleware
28
+ from dao_ai.middleware.human_in_the_loop import (
29
+ HumanInTheLoopMiddleware,
30
+ create_hitl_middleware_from_tool_models,
31
+ )
32
+ from dao_ai.middleware.summarization import (
33
+ LoggingSummarizationMiddleware,
34
+ create_summarization_middleware,
35
+ )
29
36
  from dao_ai.prompts import make_prompt
30
- from dao_ai.state import Context, IncomingState, SharedState
37
+ from dao_ai.state import AgentState, Context
31
38
  from dao_ai.tools import create_tools
39
+ from dao_ai.tools.memory import create_search_memory_tool
32
40
 
33
41
 
34
- def summarization_node(chat_history: ChatHistoryModel) -> RunnableLike:
42
+ def _create_middleware_list(
43
+ agent: AgentModel,
44
+ tool_models: Sequence[ToolModel],
45
+ chat_history: Optional[ChatHistoryModel] = None,
46
+ ) -> list[Any]:
35
47
  """
36
- Create a summarization node for managing chat history.
48
+ Create a list of middleware instances from agent configuration.
37
49
 
38
50
  Args:
39
- chat_history: ChatHistoryModel configuration for summarization
51
+ agent: AgentModel configuration
52
+ tool_models: Tool model configurations (for HITL settings)
53
+ chat_history: Optional chat history configuration for summarization
40
54
 
41
55
  Returns:
42
- RunnableLike: A summarization node that processes messages
56
+ List of middleware instances (can include both AgentMiddleware and
57
+ LangChain built-in middleware)
43
58
  """
44
- if chat_history is None:
45
- raise ValueError("chat_history must be provided to use summarization")
46
-
47
- max_tokens: int = chat_history.max_tokens
48
- max_tokens_before_summary: int | None = chat_history.max_tokens_before_summary
49
- max_messages_before_summary: int | None = chat_history.max_messages_before_summary
50
- max_summary_tokens: int | None = chat_history.max_summary_tokens
51
- token_counter: TokenCounter = (
52
- count_tokens_approximately if max_tokens_before_summary else len
53
- )
54
-
55
- logger.debug(
56
- f"Creating summarization node with max_tokens: {max_tokens}, "
57
- f"max_tokens_before_summary: {max_tokens_before_summary}, "
58
- f"max_messages_before_summary: {max_messages_before_summary}, "
59
- f"max_summary_tokens: {max_summary_tokens}"
60
- )
59
+ logger.debug("Building middleware list for agent", agent=agent.name)
60
+ middleware_list: list[Any] = []
61
+
62
+ # Add configured middleware using factory pattern
63
+ if agent.middleware:
64
+ middleware_names: list[str] = [mw.name for mw in agent.middleware]
65
+ logger.info(
66
+ "Middleware configuration",
67
+ agent=agent.name,
68
+ middleware_count=len(agent.middleware),
69
+ middleware_names=middleware_names,
70
+ )
71
+ for middleware_config in agent.middleware:
72
+ logger.trace(
73
+ "Creating middleware for agent",
74
+ agent=agent.name,
75
+ middleware_name=middleware_config.name,
76
+ )
77
+ middleware: AgentMiddleware[AgentState, Context] = create_factory_middleware(
78
+ function_name=middleware_config.name,
79
+ args=middleware_config.args,
80
+ )
81
+ if middleware is not None:
82
+ middleware_list.append(middleware)
83
+
84
+ # Add guardrails as middleware
85
+ if agent.guardrails:
86
+ guardrail_names: list[str] = [gr.name for gr in agent.guardrails]
87
+ logger.info(
88
+ "Guardrails configuration",
89
+ agent=agent.name,
90
+ guardrails_count=len(agent.guardrails),
91
+ guardrail_names=guardrail_names,
92
+ )
93
+ for guardrail in agent.guardrails:
94
+ # Extract template string from PromptModel if needed
95
+ prompt_str: str
96
+ if isinstance(guardrail.prompt, PromptModel):
97
+ prompt_str = guardrail.prompt.template
98
+ else:
99
+ prompt_str = guardrail.prompt
100
+
101
+ guardrail_middleware: GuardrailMiddleware = GuardrailMiddleware(
102
+ name=guardrail.name,
103
+ model=guardrail.model.as_chat_model(),
104
+ prompt=prompt_str,
105
+ num_retries=guardrail.num_retries or 3,
106
+ )
107
+ logger.trace(
108
+ "Created guardrail middleware", guardrail=guardrail.name, agent=agent.name
109
+ )
110
+ middleware_list.append(guardrail_middleware)
111
+
112
+ # Add summarization middleware if chat_history is configured
113
+ if chat_history is not None:
114
+ logger.info(
115
+ "Chat history configuration",
116
+ agent=agent.name,
117
+ max_tokens=chat_history.max_tokens,
118
+ summary_model=chat_history.model.name,
119
+ )
120
+ summarization_middleware: LoggingSummarizationMiddleware = (
121
+ create_summarization_middleware(chat_history)
122
+ )
123
+ middleware_list.append(summarization_middleware)
61
124
 
62
- summarization_model: LanguageModelLike = chat_history.model.as_chat_model()
63
-
64
- node: RunnableLike = SummarizationNode(
65
- model=summarization_model,
66
- max_tokens=max_tokens,
67
- max_tokens_before_summary=max_tokens_before_summary
68
- or max_messages_before_summary,
69
- max_summary_tokens=max_summary_tokens,
70
- token_counter=token_counter,
71
- input_messages_key="messages",
72
- output_messages_key="summarized_messages",
125
+ # Add human-in-the-loop middleware if any tools require it
126
+ hitl_middleware: HumanInTheLoopMiddleware | None = (
127
+ create_hitl_middleware_from_tool_models(tool_models)
73
128
  )
74
- return node
75
-
76
-
77
- def call_agent_with_summarized_messages(agent: CompiledStateGraph) -> RunnableLike:
78
- async def call_agent(state: SharedState, runtime: Runtime[Context]) -> SharedState:
79
- logger.debug(f"Calling agent {agent.name} with summarized messages")
80
-
81
- # Get the summarized messages from the summarization node
82
- messages: Sequence[AnyMessage] = state.get("summarized_messages", [])
83
- logger.debug(f"Found {len(messages)} summarized messages")
84
- logger.trace(f"Summarized messages: {[m.model_dump() for m in messages]}")
85
-
86
- input: dict[str, Any] = {
87
- "messages": messages,
88
- }
89
-
90
- response: dict[str, Any] = await agent.ainvoke(
91
- input=input, context=runtime.context
129
+ if hitl_middleware is not None:
130
+ # Log which tools require HITL
131
+ hitl_tool_names: list[str] = [
132
+ tool.name
133
+ for tool in tool_models
134
+ if hasattr(tool.function, "human_in_the_loop")
135
+ and tool.function.human_in_the_loop is not None
136
+ ]
137
+ logger.info(
138
+ "Human-in-the-Loop configuration",
139
+ agent=agent.name,
140
+ hitl_tools=hitl_tool_names,
92
141
  )
93
- response_messages = response.get("messages", [])
94
- logger.debug(f"Agent returned {len(response_messages)} messages")
95
-
96
- return {"messages": response_messages}
142
+ middleware_list.append(hitl_middleware)
97
143
 
98
- return call_agent
144
+ logger.info(
145
+ "Middleware summary",
146
+ agent=agent.name,
147
+ total_middleware_count=len(middleware_list),
148
+ )
149
+ return middleware_list
99
150
 
100
151
 
101
152
  def create_agent_node(
@@ -107,9 +158,9 @@ def create_agent_node(
107
158
  """
108
159
  Factory function that creates a LangGraph node for a specialized agent.
109
160
 
110
- This creates a node function that handles user requests using a specialized agent.
111
- The function configures the agent with the appropriate model, prompt, tools, and guardrails.
112
- If chat_history is provided, it creates a workflow with summarization node.
161
+ This creates an agent using LangChain v1's create_agent function with
162
+ middleware for customization. The function configures the agent with
163
+ the appropriate model, prompt, tools, and middleware.
113
164
 
114
165
  Args:
115
166
  agent: AgentModel configuration for the agent
@@ -120,119 +171,154 @@ def create_agent_node(
120
171
  Returns:
121
172
  RunnableLike: An agent node that processes state and returns responses
122
173
  """
123
- logger.debug(f"Creating agent node for {agent.name}")
124
-
125
- if agent.create_agent_hook:
126
- agent_hook = next(iter(create_hooks(agent.create_agent_hook)), None)
127
- return agent_hook
174
+ logger.info("Creating agent node", agent=agent.name)
175
+
176
+ # Log agent configuration details
177
+ logger.info(
178
+ "Agent configuration",
179
+ agent=agent.name,
180
+ model=agent.model.name,
181
+ description=agent.description or "No description",
182
+ )
128
183
 
129
184
  llm: LanguageModelLike = agent.model.as_chat_model()
130
185
 
131
186
  tool_models: Sequence[ToolModel] = agent.tools
132
187
  if not additional_tools:
133
188
  additional_tools = []
134
- tools: Sequence[BaseTool] = create_tools(tool_models) + additional_tools
189
+
190
+ # Log tools being created
191
+ tool_names: list[str] = [tool.name for tool in tool_models]
192
+ logger.info(
193
+ "Tools configuration",
194
+ agent=agent.name,
195
+ tools_count=len(tool_models),
196
+ tool_names=tool_names,
197
+ )
198
+
199
+ tools: list[BaseTool] = list(create_tools(tool_models)) + list(additional_tools)
200
+
201
+ if additional_tools:
202
+ logger.debug(
203
+ "Additional tools added",
204
+ agent=agent.name,
205
+ additional_count=len(additional_tools),
206
+ )
135
207
 
136
208
  if memory and memory.store:
137
209
  namespace: tuple[str, ...] = ("memory",)
138
210
  if memory.store.namespace:
139
211
  namespace = namespace + (memory.store.namespace,)
140
- logger.debug(f"Memory store namespace: {namespace}")
212
+ logger.info(
213
+ "Memory configuration",
214
+ agent=agent.name,
215
+ has_store=True,
216
+ has_checkpointer=memory.checkpointer is not None,
217
+ namespace=namespace,
218
+ )
219
+ elif memory:
220
+ logger.info(
221
+ "Memory configuration",
222
+ agent=agent.name,
223
+ has_store=False,
224
+ has_checkpointer=memory.checkpointer is not None,
225
+ )
141
226
 
227
+ # Add memory tools if store is configured
228
+ if memory and memory.store:
229
+ # Use Databricks-compatible search_memory tool (omits problematic filter field)
142
230
  tools += [
143
231
  create_manage_memory_tool(namespace=namespace),
144
232
  create_search_memory_tool(namespace=namespace),
145
233
  ]
234
+ logger.debug(
235
+ "Memory tools added",
236
+ agent=agent.name,
237
+ tools=["manage_memory", "search_memory"],
238
+ )
146
239
 
147
- pre_agent_hook: Callable[..., Any] = next(
148
- iter(create_hooks(agent.pre_agent_hook)), None
240
+ # Create middleware list from configuration
241
+ middleware_list = _create_middleware_list(
242
+ agent=agent,
243
+ tool_models=tool_models,
244
+ chat_history=chat_history,
149
245
  )
150
- logger.debug(f"pre_agent_hook: {pre_agent_hook}")
151
246
 
152
- post_agent_hook: Callable[..., Any] = next(
153
- iter(create_hooks(agent.post_agent_hook)), None
247
+ # Log prompt configuration
248
+ if agent.prompt:
249
+ if isinstance(agent.prompt, PromptModel):
250
+ logger.info(
251
+ "Prompt configuration",
252
+ agent=agent.name,
253
+ prompt_type="PromptModel",
254
+ prompt_name=agent.prompt.name,
255
+ )
256
+ else:
257
+ prompt_preview: str = (
258
+ agent.prompt[:100] + "..." if len(agent.prompt) > 100 else agent.prompt
259
+ )
260
+ logger.info(
261
+ "Prompt configuration",
262
+ agent=agent.name,
263
+ prompt_type="string",
264
+ prompt_preview=prompt_preview,
265
+ )
266
+ else:
267
+ logger.debug("No custom prompt configured", agent=agent.name)
268
+
269
+ checkpointer: bool = memory is not None and memory.checkpointer is not None
270
+
271
+ # Get the prompt as middleware (always returns AgentMiddleware or None)
272
+ prompt_middleware: AgentMiddleware | None = make_prompt(agent.prompt)
273
+
274
+ # Add prompt middleware at the beginning for priority
275
+ if prompt_middleware is not None:
276
+ middleware_list.insert(0, prompt_middleware)
277
+
278
+ # Configure structured output if response_format is specified
279
+ response_format: Any = None
280
+ if agent.response_format is not None:
281
+ try:
282
+ response_format = agent.response_format.as_strategy()
283
+ if response_format is not None:
284
+ logger.info(
285
+ "Response format configuration",
286
+ agent=agent.name,
287
+ format_type=type(response_format).__name__,
288
+ structured_output=True,
289
+ )
290
+ except ValueError as e:
291
+ logger.error(
292
+ "Failed to configure structured output for agent",
293
+ agent=agent.name,
294
+ error=str(e),
295
+ )
296
+ raise
297
+
298
+ # Use LangChain v1's create_agent with middleware
299
+ # AgentState extends MessagesState with additional DAO AI fields
300
+ # System prompt is provided via middleware (dynamic_prompt)
301
+ logger.info(
302
+ "Creating LangChain agent",
303
+ agent=agent.name,
304
+ tools_count=len(tools),
305
+ middleware_count=len(middleware_list),
306
+ has_checkpointer=checkpointer,
154
307
  )
155
- logger.debug(f"post_agent_hook: {post_agent_hook}")
156
308
 
157
- checkpointer: bool = memory and memory.checkpointer is not None
158
-
159
- compiled_agent: CompiledStateGraph = create_react_agent(
309
+ compiled_agent: CompiledStateGraph = create_agent(
160
310
  name=agent.name,
161
311
  model=llm,
162
- prompt=make_prompt(agent.prompt),
163
312
  tools=tools,
164
- store=True,
313
+ middleware=middleware_list,
165
314
  checkpointer=checkpointer,
166
- state_schema=SharedState,
315
+ state_schema=AgentState,
167
316
  context_schema=Context,
168
- pre_model_hook=pre_agent_hook,
169
- post_model_hook=post_agent_hook,
317
+ response_format=response_format, # Add structured output support
170
318
  )
171
319
 
172
- for guardrail_definition in agent.guardrails:
173
- guardrail: CompiledStateGraph = reflection_guardrail(guardrail_definition)
174
- compiled_agent = with_guardrails(compiled_agent, guardrail)
175
-
176
320
  compiled_agent.name = agent.name
177
321
 
178
- agent_node: CompiledStateGraph
322
+ logger.info("Agent node created successfully", agent=agent.name)
179
323
 
180
- if chat_history is None:
181
- logger.debug("No chat history configured, using compiled agent directly")
182
- agent_node = compiled_agent
183
- else:
184
- logger.debug("Creating agent node with chat history summarization")
185
- workflow: StateGraph = StateGraph(
186
- SharedState,
187
- config_schema=RunnableConfig,
188
- input=SharedState,
189
- output=SharedState,
190
- )
191
- workflow.add_node("summarization", summarization_node(chat_history))
192
- workflow.add_node(
193
- "agent",
194
- call_agent_with_summarized_messages(agent=compiled_agent),
195
- )
196
- workflow.add_edge("summarization", "agent")
197
- workflow.set_entry_point("summarization")
198
- agent_node = workflow.compile(name=agent.name)
199
-
200
- return agent_node
201
-
202
-
203
- def message_hook_node(config: AppConfig) -> RunnableLike:
204
- message_hooks: Sequence[Callable[..., Any]] = create_hooks(config.app.message_hooks)
205
-
206
- @mlflow.trace()
207
- async def message_hook(
208
- state: IncomingState, runtime: Runtime[Context]
209
- ) -> SharedState:
210
- logger.debug("Running message validation")
211
- response: dict[str, Any] = {"is_valid": True, "message_error": None}
212
-
213
- for message_hook in message_hooks:
214
- message_hook: FunctionHook
215
- if message_hook:
216
- try:
217
- hook_response: dict[str, Any] = message_hook(
218
- state=state,
219
- runtime=runtime,
220
- )
221
- response.update(hook_response)
222
- logger.debug(f"Hook response: {hook_response}")
223
- if not response.get("is_valid", True):
224
- break
225
- except Exception as e:
226
- logger.error(f"Message validation failed: {e}")
227
- response_messages: Sequence[BaseMessage] = [
228
- AIMessage(content=str(e))
229
- ]
230
- return {
231
- "is_valid": False,
232
- "message_error": str(e),
233
- "messages": response_messages,
234
- }
235
-
236
- return response
237
-
238
- return message_hook
324
+ return compiled_agent