dao-ai 0.0.25__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between versions as they appear in their respective public registries.
Files changed (63)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +5 -5
  3. dao_ai/cli.py +245 -40
  4. dao_ai/config.py +1863 -338
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -228
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +125 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/core.py +67 -0
  25. dao_ai/middleware/guardrails.py +420 -0
  26. dao_ai/middleware/human_in_the_loop.py +232 -0
  27. dao_ai/middleware/message_validation.py +586 -0
  28. dao_ai/middleware/summarization.py +197 -0
  29. dao_ai/models.py +1306 -114
  30. dao_ai/nodes.py +261 -166
  31. dao_ai/optimization.py +674 -0
  32. dao_ai/orchestration/__init__.py +52 -0
  33. dao_ai/orchestration/core.py +294 -0
  34. dao_ai/orchestration/supervisor.py +278 -0
  35. dao_ai/orchestration/swarm.py +271 -0
  36. dao_ai/prompts.py +128 -31
  37. dao_ai/providers/databricks.py +645 -172
  38. dao_ai/state.py +157 -21
  39. dao_ai/tools/__init__.py +13 -5
  40. dao_ai/tools/agent.py +1 -3
  41. dao_ai/tools/core.py +64 -11
  42. dao_ai/tools/email.py +232 -0
  43. dao_ai/tools/genie.py +144 -295
  44. dao_ai/tools/mcp.py +220 -133
  45. dao_ai/tools/memory.py +50 -0
  46. dao_ai/tools/python.py +9 -14
  47. dao_ai/tools/search.py +14 -0
  48. dao_ai/tools/slack.py +22 -10
  49. dao_ai/tools/sql.py +202 -0
  50. dao_ai/tools/time.py +30 -7
  51. dao_ai/tools/unity_catalog.py +165 -88
  52. dao_ai/tools/vector_search.py +360 -40
  53. dao_ai/utils.py +218 -16
  54. dao_ai-0.1.2.dist-info/METADATA +455 -0
  55. dao_ai-0.1.2.dist-info/RECORD +64 -0
  56. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +1 -1
  57. dao_ai/chat_models.py +0 -204
  58. dao_ai/guardrails.py +0 -112
  59. dao_ai/tools/human_in_the_loop.py +0 -100
  60. dao_ai-0.0.25.dist-info/METADATA +0 -1165
  61. dao_ai-0.0.25.dist-info/RECORD +0 -41
  62. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
  63. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/nodes.py CHANGED
@@ -1,229 +1,324 @@
- from typing import Any, Callable, Optional, Sequence
+ """
+ Node creation utilities for DAO AI agents.

- import mlflow
+ This module provides factory functions for creating LangGraph nodes
+ that implement agent logic using LangChain v1's create_agent pattern.
+ """
+
+ from typing import Any, Optional, Sequence
+
+ from langchain.agents import create_agent
+ from langchain.agents.middleware import AgentMiddleware
  from langchain_core.language_models import LanguageModelLike
- from langchain_core.messages import AIMessage, AnyMessage, BaseMessage
- from langchain_core.messages.utils import count_tokens_approximately
- from langchain_core.runnables import RunnableConfig
  from langchain_core.runnables.base import RunnableLike
  from langchain_core.tools import BaseTool
- from langgraph.graph import StateGraph
  from langgraph.graph.state import CompiledStateGraph
- from langgraph.prebuilt import create_react_agent
- from langgraph.runtime import Runtime
- from langmem import create_manage_memory_tool, create_search_memory_tool
- from langmem.short_term import SummarizationNode
- from langmem.short_term.summarization import TokenCounter
+ from langmem import create_manage_memory_tool
  from loguru import logger

  from dao_ai.config import (
      AgentModel,
-     AppConfig,
-     AppModel,
      ChatHistoryModel,
-     FunctionHook,
+     MemoryModel,
+     PromptModel,
      ToolModel,
  )
- from dao_ai.guardrails import reflection_guardrail, with_guardrails
- from dao_ai.hooks.core import create_hooks
+ from dao_ai.middleware.core import create_factory_middleware
+ from dao_ai.middleware.guardrails import GuardrailMiddleware
+ from dao_ai.middleware.human_in_the_loop import (
+     HumanInTheLoopMiddleware,
+     create_hitl_middleware_from_tool_models,
+ )
+ from dao_ai.middleware.summarization import (
+     LoggingSummarizationMiddleware,
+     create_summarization_middleware,
+ )
  from dao_ai.prompts import make_prompt
- from dao_ai.state import Context, IncomingState, SharedState
+ from dao_ai.state import AgentState, Context
  from dao_ai.tools import create_tools
+ from dao_ai.tools.memory import create_search_memory_tool


- def summarization_node(app_model: AppModel) -> RunnableLike:
-     chat_history: ChatHistoryModel | None = app_model.chat_history
-     if chat_history is None:
-         raise ValueError(
-             "AppModel must have chat_history configured to use summarization"
-         )
+ def _create_middleware_list(
+     agent: AgentModel,
+     tool_models: Sequence[ToolModel],
+     chat_history: Optional[ChatHistoryModel] = None,
+ ) -> list[Any]:
+     """
+     Create a list of middleware instances from agent configuration.

-     max_tokens: int = chat_history.max_tokens
-     max_tokens_before_summary: int | None = chat_history.max_tokens_before_summary
-     max_messages_before_summary: int | None = chat_history.max_messages_before_summary
-     max_summary_tokens: int | None = chat_history.max_summary_tokens
-     token_counter: TokenCounter = (
-         count_tokens_approximately if max_tokens_before_summary else len
-     )
+     Args:
+         agent: AgentModel configuration
+         tool_models: Tool model configurations (for HITL settings)
+         chat_history: Optional chat history configuration for summarization

-     logger.debug(
-         f"Creating summarization node with max_tokens: {max_tokens}, "
-         f"max_tokens_before_summary: {max_tokens_before_summary}, "
-         f"max_messages_before_summary: {max_messages_before_summary}, "
-         f"max_summary_tokens: {max_summary_tokens}"
-     )
+     Returns:
+         List of middleware instances (can include both AgentMiddleware and
+         LangChain built-in middleware)
+     """
+     logger.debug("Building middleware list for agent", agent=agent.name)
+     middleware_list: list[Any] = []
+
+     # Add configured middleware using factory pattern
+     if agent.middleware:
+         middleware_names: list[str] = [mw.name for mw in agent.middleware]
+         logger.info(
+             "Middleware configuration",
+             agent=agent.name,
+             middleware_count=len(agent.middleware),
+             middleware_names=middleware_names,
+         )
+         for middleware_config in agent.middleware:
+             logger.trace(
+                 "Creating middleware for agent",
+                 agent=agent.name,
+                 middleware_name=middleware_config.name,
+             )
+             middleware: AgentMiddleware[AgentState, Context] = create_factory_middleware(
+                 function_name=middleware_config.name,
+                 args=middleware_config.args,
+             )
+             if middleware is not None:
+                 middleware_list.append(middleware)
+
+     # Add guardrails as middleware
+     if agent.guardrails:
+         guardrail_names: list[str] = [gr.name for gr in agent.guardrails]
+         logger.info(
+             "Guardrails configuration",
+             agent=agent.name,
+             guardrails_count=len(agent.guardrails),
+             guardrail_names=guardrail_names,
+         )
+         for guardrail in agent.guardrails:
+             # Extract template string from PromptModel if needed
+             prompt_str: str
+             if isinstance(guardrail.prompt, PromptModel):
+                 prompt_str = guardrail.prompt.template
+             else:
+                 prompt_str = guardrail.prompt
+
+             guardrail_middleware: GuardrailMiddleware = GuardrailMiddleware(
+                 name=guardrail.name,
+                 model=guardrail.model.as_chat_model(),
+                 prompt=prompt_str,
+                 num_retries=guardrail.num_retries or 3,
+             )
+             logger.trace(
+                 "Created guardrail middleware", guardrail=guardrail.name, agent=agent.name
+             )
+             middleware_list.append(guardrail_middleware)
+
+     # Add summarization middleware if chat_history is configured
+     if chat_history is not None:
+         logger.info(
+             "Chat history configuration",
+             agent=agent.name,
+             max_tokens=chat_history.max_tokens,
+             summary_model=chat_history.model.name,
+         )
+         summarization_middleware: LoggingSummarizationMiddleware = (
+             create_summarization_middleware(chat_history)
+         )
+         middleware_list.append(summarization_middleware)

-     summarization_model: LanguageModelLike = chat_history.model.as_chat_model()
-
-     node: RunnableLike = SummarizationNode(
-         model=summarization_model,
-         max_tokens=max_tokens,
-         max_tokens_before_summary=max_tokens_before_summary
-         or max_messages_before_summary,
-         max_summary_tokens=max_summary_tokens,
-         token_counter=token_counter,
-         input_messages_key="messages",
-         output_messages_key="summarized_messages",
+     # Add human-in-the-loop middleware if any tools require it
+     hitl_middleware: HumanInTheLoopMiddleware | None = (
+         create_hitl_middleware_from_tool_models(tool_models)
      )
-     return node
-
-
- def call_agent_with_summarized_messages(agent: CompiledStateGraph) -> RunnableLike:
-     async def call_agent(state: SharedState, runtime: Runtime[Context]) -> SharedState:
-         logger.debug(f"Calling agent {agent.name} with summarized messages")
-
-         # Get the summarized messages from the summarization node
-         messages: Sequence[AnyMessage] = state.get("summarized_messages", [])
-         logger.debug(f"Found {len(messages)} summarized messages")
-         logger.trace(f"Summarized messages: {[m.model_dump() for m in messages]}")
-
-         input: dict[str, Any] = {
-             "messages": messages,
-         }
-
-         response: dict[str, Any] = await agent.ainvoke(
-             input=input, context=runtime.context
+     if hitl_middleware is not None:
+         # Log which tools require HITL
+         hitl_tool_names: list[str] = [
+             tool.name
+             for tool in tool_models
+             if hasattr(tool.function, "human_in_the_loop")
+             and tool.function.human_in_the_loop is not None
+         ]
+         logger.info(
+             "Human-in-the-Loop configuration",
+             agent=agent.name,
+             hitl_tools=hitl_tool_names,
          )
-         response_messages = response.get("messages", [])
-         logger.debug(f"Agent returned {len(response_messages)} messages")
-
-         return {"messages": response_messages}
+         middleware_list.append(hitl_middleware)

-     return call_agent
+     logger.info(
+         "Middleware summary",
+         agent=agent.name,
+         total_middleware_count=len(middleware_list),
+     )
+     return middleware_list


  def create_agent_node(
-     app: AppModel,
      agent: AgentModel,
+     memory: Optional[MemoryModel] = None,
+     chat_history: Optional[ChatHistoryModel] = None,
      additional_tools: Optional[Sequence[BaseTool]] = None,
  ) -> RunnableLike:
      """
      Factory function that creates a LangGraph node for a specialized agent.

-     This creates a node function that handles user requests using a specialized agent
-     based on the provided agent_type. The function configures the agent with the
-     appropriate model, prompt, tools, and guardrails from the model_config.
+     This creates an agent using LangChain v1's create_agent function with
+     middleware for customization. The function configures the agent with
+     the appropriate model, prompt, tools, and middleware.

      Args:
-         model_config: Configuration containing models, prompts, tools, and guardrails
-         agent_type: Type of agent to create (e.g., "general", "product", "inventory")
+         agent: AgentModel configuration for the agent
+         memory: Optional MemoryModel for memory store configuration
+         chat_history: Optional ChatHistoryModel for chat history summarization
+         additional_tools: Optional sequence of additional tools to add to the agent

      Returns:
-         An agent callable function that processes state and returns responses
+         RunnableLike: An agent node that processes state and returns responses
      """
-     logger.debug(f"Creating agent node for {agent.name}")
-
-     if agent.create_agent_hook:
-         agent_hook = next(iter(create_hooks(agent.create_agent_hook)), None)
-         return agent_hook
+     logger.info("Creating agent node", agent=agent.name)
+
+     # Log agent configuration details
+     logger.info(
+         "Agent configuration",
+         agent=agent.name,
+         model=agent.model.name,
+         description=agent.description or "No description",
+     )

      llm: LanguageModelLike = agent.model.as_chat_model()

      tool_models: Sequence[ToolModel] = agent.tools
      if not additional_tools:
          additional_tools = []
-     tools: Sequence[BaseTool] = create_tools(tool_models) + additional_tools

-     if app.orchestration.memory and app.orchestration.memory.store:
+     # Log tools being created
+     tool_names: list[str] = [tool.name for tool in tool_models]
+     logger.info(
+         "Tools configuration",
+         agent=agent.name,
+         tools_count=len(tool_models),
+         tool_names=tool_names,
+     )
+
+     tools: list[BaseTool] = list(create_tools(tool_models)) + list(additional_tools)
+
+     if additional_tools:
+         logger.debug(
+             "Additional tools added",
+             agent=agent.name,
+             additional_count=len(additional_tools),
+         )
+
+     if memory and memory.store:
          namespace: tuple[str, ...] = ("memory",)
-         if app.orchestration.memory.store.namespace:
-             namespace = namespace + (app.orchestration.memory.store.namespace,)
-         logger.debug(f"Memory store namespace: {namespace}")
+         if memory.store.namespace:
+             namespace = namespace + (memory.store.namespace,)
+         logger.info(
+             "Memory configuration",
+             agent=agent.name,
+             has_store=True,
+             has_checkpointer=memory.checkpointer is not None,
+             namespace=namespace,
+         )
+     elif memory:
+         logger.info(
+             "Memory configuration",
+             agent=agent.name,
+             has_store=False,
+             has_checkpointer=memory.checkpointer is not None,
+         )

+     # Add memory tools if store is configured
+     if memory and memory.store:
+         # Use Databricks-compatible search_memory tool (omits problematic filter field)
          tools += [
              create_manage_memory_tool(namespace=namespace),
              create_search_memory_tool(namespace=namespace),
          ]
+         logger.debug(
+             "Memory tools added",
+             agent=agent.name,
+             tools=["manage_memory", "search_memory"],
+         )

-     pre_agent_hook: Callable[..., Any] = next(
-         iter(create_hooks(agent.pre_agent_hook)), None
+     # Create middleware list from configuration
+     middleware_list = _create_middleware_list(
+         agent=agent,
+         tool_models=tool_models,
+         chat_history=chat_history,
      )
-     logger.debug(f"pre_agent_hook: {pre_agent_hook}")

-     post_agent_hook: Callable[..., Any] = next(
-         iter(create_hooks(agent.post_agent_hook)), None
+     # Log prompt configuration
+     if agent.prompt:
+         if isinstance(agent.prompt, PromptModel):
+             logger.info(
+                 "Prompt configuration",
+                 agent=agent.name,
+                 prompt_type="PromptModel",
+                 prompt_name=agent.prompt.name,
+             )
+         else:
+             prompt_preview: str = (
+                 agent.prompt[:100] + "..." if len(agent.prompt) > 100 else agent.prompt
+             )
+             logger.info(
+                 "Prompt configuration",
+                 agent=agent.name,
+                 prompt_type="string",
+                 prompt_preview=prompt_preview,
+             )
+     else:
+         logger.debug("No custom prompt configured", agent=agent.name)
+
+     checkpointer: bool = memory is not None and memory.checkpointer is not None
+
+     # Get the prompt as middleware (always returns AgentMiddleware or None)
+     prompt_middleware: AgentMiddleware | None = make_prompt(agent.prompt)
+
+     # Add prompt middleware at the beginning for priority
+     if prompt_middleware is not None:
+         middleware_list.insert(0, prompt_middleware)
+
+     # Configure structured output if response_format is specified
+     response_format: Any = None
+     if agent.response_format is not None:
+         try:
+             response_format = agent.response_format.as_strategy()
+             if response_format is not None:
+                 logger.info(
+                     "Response format configuration",
+                     agent=agent.name,
+                     format_type=type(response_format).__name__,
+                     structured_output=True,
+                 )
+         except ValueError as e:
+             logger.error(
+                 "Failed to configure structured output for agent",
+                 agent=agent.name,
+                 error=str(e),
+             )
+             raise
+
+     # Use LangChain v1's create_agent with middleware
+     # AgentState extends MessagesState with additional DAO AI fields
+     # System prompt is provided via middleware (dynamic_prompt)
+     logger.info(
+         "Creating LangChain agent",
+         agent=agent.name,
+         tools_count=len(tools),
+         middleware_count=len(middleware_list),
+         has_checkpointer=checkpointer,
      )
-     logger.debug(f"post_agent_hook: {post_agent_hook}")

-     compiled_agent: CompiledStateGraph = create_react_agent(
+     compiled_agent: CompiledStateGraph = create_agent(
          name=agent.name,
          model=llm,
-         prompt=make_prompt(agent.prompt),
          tools=tools,
-         store=True,
-         checkpointer=True,
-         state_schema=SharedState,
+         middleware=middleware_list,
+         checkpointer=checkpointer,
+         state_schema=AgentState,
          context_schema=Context,
-         pre_model_hook=pre_agent_hook,
-         post_model_hook=post_agent_hook,
+         response_format=response_format,  # Add structured output support
      )

-     for guardrail_definition in agent.guardrails:
-         guardrail: CompiledStateGraph = reflection_guardrail(guardrail_definition)
-         compiled_agent = with_guardrails(compiled_agent, guardrail)
-
      compiled_agent.name = agent.name

-     agent_node: CompiledStateGraph
-
-     chat_history: ChatHistoryModel = app.chat_history
+     logger.info("Agent node created successfully", agent=agent.name)

-     if chat_history is None:
-         logger.debug("No chat history configured, using compiled agent directly")
-         agent_node = compiled_agent
-     else:
-         logger.debug("Creating agent node with chat history summarization")
-         workflow: StateGraph = StateGraph(
-             SharedState,
-             config_schema=RunnableConfig,
-             input=SharedState,
-             output=SharedState,
-         )
-         workflow.add_node("summarization", summarization_node(app))
-         workflow.add_node(
-             "agent",
-             call_agent_with_summarized_messages(agent=compiled_agent),
-         )
-         workflow.add_edge("summarization", "agent")
-         workflow.set_entry_point("summarization")
-         agent_node = workflow.compile(name=agent.name)
-
-     return agent_node
-
-
- def message_hook_node(config: AppConfig) -> RunnableLike:
-     message_hooks: Sequence[Callable[..., Any]] = create_hooks(config.app.message_hooks)
-
-     @mlflow.trace()
-     async def message_hook(
-         state: IncomingState, runtime: Runtime[Context]
-     ) -> SharedState:
-         logger.debug("Running message validation")
-         response: dict[str, Any] = {"is_valid": True, "message_error": None}
-
-         for message_hook in message_hooks:
-             message_hook: FunctionHook
-             if message_hook:
-                 try:
-                     hook_response: dict[str, Any] = message_hook(
-                         state=state,
-                         runtime=runtime,
-                     )
-                     response.update(hook_response)
-                     logger.debug(f"Hook response: {hook_response}")
-                     if not response.get("is_valid", True):
-                         break
-                 except Exception as e:
-                     logger.error(f"Message validation failed: {e}")
-                     response_messages: Sequence[BaseMessage] = [
-                         AIMessage(content=str(e))
-                     ]
-                     return {
-                         "is_valid": False,
-                         "message_error": str(e),
-                         "messages": response_messages,
-                     }
-
-         return response
-
-     return message_hook
+     return compiled_agent
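
Migration note: per the diff above, create_agent_node no longer takes an app: AppModel; memory and chat-history settings are passed per agent, and summarization, guardrails, and human-in-the-loop are applied as middleware inside create_agent rather than as wrapper graphs and hooks. A minimal sketch of the new call pattern, based only on the signature shown above — the agent, memory, and chat_history objects are placeholders, and how they are loaded from application config is not shown in this diff:

    from dao_ai.config import AgentModel, ChatHistoryModel, MemoryModel
    from dao_ai.nodes import create_agent_node

    # Placeholder configuration objects; in practice these come from your
    # dao_ai application config (construction not shown in this diff).
    agent: AgentModel = ...
    memory: MemoryModel | None = ...
    chat_history: ChatHistoryModel | None = ...

    # 0.0.25 called create_agent_node(app=app, agent=agent); in 0.1.2 the
    # optional arguments below enable the memory tools and summarization
    # middleware that the old AppModel-driven wiring used to provide.
    agent_node = create_agent_node(
        agent=agent,
        memory=memory,              # optional: adds manage/search memory tools
        chat_history=chat_history,  # optional: adds summarization middleware
    )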