dao-ai 0.0.35__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/cli.py +195 -30
  3. dao_ai/config.py +797 -242
  4. dao_ai/genie/__init__.py +38 -0
  5. dao_ai/genie/cache/__init__.py +43 -0
  6. dao_ai/genie/cache/base.py +72 -0
  7. dao_ai/genie/cache/core.py +75 -0
  8. dao_ai/genie/cache/lru.py +329 -0
  9. dao_ai/genie/cache/semantic.py +919 -0
  10. dao_ai/genie/core.py +35 -0
  11. dao_ai/graph.py +27 -253
  12. dao_ai/hooks/__init__.py +9 -6
  13. dao_ai/hooks/core.py +22 -190
  14. dao_ai/memory/__init__.py +10 -0
  15. dao_ai/memory/core.py +23 -5
  16. dao_ai/memory/databricks.py +389 -0
  17. dao_ai/memory/postgres.py +2 -2
  18. dao_ai/messages.py +6 -4
  19. dao_ai/middleware/__init__.py +125 -0
  20. dao_ai/middleware/assertions.py +778 -0
  21. dao_ai/middleware/base.py +50 -0
  22. dao_ai/middleware/core.py +61 -0
  23. dao_ai/middleware/guardrails.py +415 -0
  24. dao_ai/middleware/human_in_the_loop.py +228 -0
  25. dao_ai/middleware/message_validation.py +554 -0
  26. dao_ai/middleware/summarization.py +192 -0
  27. dao_ai/models.py +1177 -108
  28. dao_ai/nodes.py +118 -161
  29. dao_ai/optimization.py +664 -0
  30. dao_ai/orchestration/__init__.py +52 -0
  31. dao_ai/orchestration/core.py +287 -0
  32. dao_ai/orchestration/supervisor.py +264 -0
  33. dao_ai/orchestration/swarm.py +226 -0
  34. dao_ai/prompts.py +126 -29
  35. dao_ai/providers/databricks.py +126 -381
  36. dao_ai/state.py +139 -21
  37. dao_ai/tools/__init__.py +11 -5
  38. dao_ai/tools/core.py +57 -4
  39. dao_ai/tools/email.py +280 -0
  40. dao_ai/tools/genie.py +108 -35
  41. dao_ai/tools/mcp.py +4 -3
  42. dao_ai/tools/memory.py +50 -0
  43. dao_ai/tools/python.py +4 -12
  44. dao_ai/tools/search.py +14 -0
  45. dao_ai/tools/slack.py +1 -1
  46. dao_ai/tools/unity_catalog.py +8 -6
  47. dao_ai/tools/vector_search.py +16 -9
  48. dao_ai/utils.py +72 -8
  49. dao_ai-0.1.0.dist-info/METADATA +1878 -0
  50. dao_ai-0.1.0.dist-info/RECORD +62 -0
  51. dao_ai/chat_models.py +0 -204
  52. dao_ai/guardrails.py +0 -112
  53. dao_ai/tools/human_in_the_loop.py +0 -100
  54. dao_ai-0.0.35.dist-info/METADATA +0 -1169
  55. dao_ai-0.0.35.dist-info/RECORD +0 -41
  56. {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/WHEEL +0 -0
  57. {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/entry_points.txt +0 -0
  58. {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,50 @@
+ """
+ Base classes and types for DAO AI middleware.
+
+ This module re-exports LangChain's middleware types for convenience.
+ Use LangChain's AgentMiddleware directly with DAO AI's state and context types.
+
+ Example:
+     from langchain.agents.middleware import AgentMiddleware
+     from dao_ai.state import AgentState, Context
+     from langgraph.runtime import Runtime
+
+     class MyMiddleware(AgentMiddleware[AgentState, Context]):
+         def before_model(
+             self,
+             state: AgentState,
+             runtime: Runtime[Context]
+         ) -> dict[str, Any] | None:
+             print(f"About to call model with {len(state['messages'])} messages")
+             return None
+ """
+
+ from langchain.agents.middleware import (
+     AgentMiddleware,
+     ModelRequest,
+     after_agent,
+     after_model,
+     before_agent,
+     before_model,
+     dynamic_prompt,
+     wrap_model_call,
+     wrap_tool_call,
+ )
+ from langchain.agents.middleware.types import ModelResponse
+
+ # Re-export LangChain types for convenience
+ __all__ = [
+     # Base middleware class
+     "AgentMiddleware",
+     # Types
+     "ModelRequest",
+     "ModelResponse",
+     # Decorators
+     "before_agent",
+     "before_model",
+     "after_agent",
+     "after_model",
+     "wrap_model_call",
+     "wrap_tool_call",
+     "dynamic_prompt",
+ ]
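
The hunk above adds dao_ai/middleware/base.py, which only re-exports LangChain's middleware primitives. Custom hooks can therefore also be written with the re-exported decorators rather than subclassing AgentMiddleware. A minimal sketch, assuming the before_model decorator accepts a bare (state, runtime) callable as in LangChain's middleware API; log_message_count is illustrative and not part of the package:

    from typing import Any

    from langgraph.runtime import Runtime

    from dao_ai.middleware.base import before_model
    from dao_ai.state import AgentState, Context


    @before_model
    def log_message_count(
        state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        # Runs before each model call; returning None leaves the state unchanged.
        print(f"About to call model with {len(state['messages'])} messages")
        return None

The resulting object can then be passed anywhere an AgentMiddleware instance is accepted.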
@@ -0,0 +1,61 @@
+ """
+ Core middleware utilities for DAO AI.
+
+ This module provides the factory function for creating middleware instances
+ from fully qualified function names.
+ """
+
+ from typing import Any, Callable
+
+ from loguru import logger
+
+ from dao_ai.utils import load_function
+
+
+ def create_factory_middleware(
+     function_name: str,
+     args: dict[str, Any] | None = None,
+ ) -> Any:
+     """
+     Create middleware from a factory function.
+
+
+     This factory function dynamically loads a Python function and calls it
+     with the provided arguments to create a middleware instance.
+
+     The factory function should return a middleware object compatible with
+     LangChain's create_agent middleware parameter (AgentMiddleware or any
+     callable/object that implements the middleware interface).
+
+     Args:
+         function_name: Fully qualified name of the factory function
+             (e.g., 'my_module.create_custom_middleware')
+         args: Arguments to pass to the factory function
+
+     Returns:
+         A middleware instance returned by the factory function
+
+     Raises:
+         ImportError: If the function cannot be loaded
+
+     Example:
+         # Factory function in my_module.py:
+         def create_custom_middleware(threshold: float = 0.5) -> AgentMiddleware:
+             return MyCustomMiddleware(threshold=threshold)
+
+         # Usage:
+         middleware = create_factory_middleware(
+             function_name="my_module.create_custom_middleware",
+             args={"threshold": 0.8}
+         )
+     """
+     if args is None:
+         args = {}
+
+     logger.debug(f"Creating factory middleware: {function_name} with args: {args}")
+
+     factory: Callable[..., Any] = load_function(function_name=function_name)
+     middleware: Any = factory(**args)
+
+     logger.debug(f"Created middleware from factory: {type(middleware).__name__}")
+     return middleware
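
The hunk above adds dao_ai/middleware/core.py. Putting the factory pattern end to end: the my_project.middleware module and RateLimitMiddleware class below are hypothetical placeholders used only to show how a fully qualified function name and its arguments flow through create_factory_middleware; only that function and the AgentMiddleware re-export come from this release.

    # my_project/middleware.py (hypothetical module)
    from dao_ai.middleware.base import AgentMiddleware


    class RateLimitMiddleware(AgentMiddleware):
        def __init__(self, max_calls: int) -> None:
            super().__init__()
            self.max_calls = max_calls


    def create_rate_limit_middleware(max_calls: int = 10) -> AgentMiddleware:
        # Factory referenced by its dotted path in configuration.
        return RateLimitMiddleware(max_calls=max_calls)


    # Elsewhere, dao_ai resolves the dotted path and forwards the keyword arguments:
    from dao_ai.middleware.core import create_factory_middleware

    middleware = create_factory_middleware(
        function_name="my_project.middleware.create_rate_limit_middleware",
        args={"max_calls": 5},
    )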
@@ -0,0 +1,415 @@
+ """
+ Guardrail middleware for DAO AI agents.
+
+ This module provides middleware implementations for applying guardrails
+ to agent responses, including LLM-based judging and content validation.
+
+ Factory functions are provided for consistent configuration via the
+ DAO AI middleware factory pattern.
+ """
+
+ from typing import Any, Optional
+
+ from langchain.agents.middleware import hook_config
+ from langchain_core.language_models import LanguageModelLike
+ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
+ from langgraph.runtime import Runtime
+ from loguru import logger
+ from openevals.llm import create_llm_as_judge
+
+ from dao_ai.messages import last_ai_message, last_human_message
+ from dao_ai.middleware.base import AgentMiddleware
+ from dao_ai.state import AgentState, Context
+
+
+ def _extract_text_content(message: BaseMessage) -> str:
+     """
+     Extract text content from a message, handling both string and list formats.
+
+     Args:
+         message: The message to extract text from
+
+     Returns:
+         The extracted text content as a string
+     """
+     content = message.content
+
+     if isinstance(content, str):
+         return content
+     elif isinstance(content, list):
+         # Extract text from content blocks (e.g., Claude's structured content)
+         text_parts = []
+         for block in content:
+             if isinstance(block, dict) and block.get("type") == "text":
+                 text_parts.append(block.get("text", ""))
+             elif isinstance(block, str):
+                 text_parts.append(block)
+         return " ".join(text_parts)
+     else:
+         return str(content)
+
+
+ __all__ = [
+     "GuardrailMiddleware",
+     "ContentFilterMiddleware",
+     "SafetyGuardrailMiddleware",
+     "create_guardrail_middleware",
+     "create_content_filter_middleware",
+     "create_safety_guardrail_middleware",
+ ]
+
+
+ class GuardrailMiddleware(AgentMiddleware[AgentState, Context]):
+     """
+     Middleware that applies LLM-based guardrails to agent responses.
+
+     Uses an LLM judge to evaluate responses against a prompt/criteria and
+     can request improvements if the response doesn't meet the criteria.
+
+     This is equivalent to the previous reflection_guardrail pattern but
+     implemented as middleware for better composability.
+
+     Args:
+         name: Name identifying this guardrail
+         model: The LLM to use for evaluation
+         prompt: The evaluation prompt/criteria
+         num_retries: Maximum number of retry attempts (default: 3)
+     """
+
+     def __init__(
+         self,
+         name: str,
+         model: LanguageModelLike,
+         prompt: str,
+         num_retries: int = 3,
+     ):
+         super().__init__()
+         self.guardrail_name = name
+         self.model = model
+         self.prompt = prompt
+         self.num_retries = num_retries
+         self._retry_count = 0
+
+     @property
+     def name(self) -> str:
+         """Return the guardrail name for middleware identification."""
+         return self.guardrail_name
+
+     def after_model(
+         self, state: AgentState, runtime: Runtime[Context]
+     ) -> dict[str, Any] | None:
+         """
+         Evaluate the model's response using an LLM judge.
+
+         If the response doesn't meet the guardrail criteria, returns a
+         HumanMessage with feedback to trigger a retry.
+         """
+         messages: list[BaseMessage] = state.get("messages", [])
+
+         if not messages:
+             return None
+
+         ai_message: AIMessage | None = last_ai_message(messages)
+         human_message: HumanMessage | None = last_human_message(messages)
+
+         if not ai_message or not human_message:
+             return None
+
+         # Skip evaluation if the AI message has tool calls (not the final response yet)
+         if ai_message.tool_calls:
+             logger.debug(
+                 f"Guardrail '{self.guardrail_name}' skipping evaluation - "
+                 "AI message contains tool calls, waiting for final response"
+             )
+             return None
+
+         # Skip evaluation if the AI message has no content to evaluate
+         if not ai_message.content:
+             logger.debug(
+                 f"Guardrail '{self.guardrail_name}' skipping evaluation - "
+                 "AI message has no content"
+             )
+             return None
+
+         logger.debug(f"Evaluating response with guardrail '{self.guardrail_name}'")
+
+         # Extract text content from messages (handles both string and structured content)
+         human_content = _extract_text_content(human_message)
+         ai_content = _extract_text_content(ai_message)
+
+         logger.debug(
+             f"Guardrail '{self.guardrail_name}' evaluating: "
+             f"input_length={len(human_content)}, output_length={len(ai_content)}"
+         )
+
+         evaluator = create_llm_as_judge(
+             prompt=self.prompt,
+             judge=self.model,
+         )
+
+         eval_result = evaluator(inputs=human_content, outputs=ai_content)
+
+         if eval_result["score"]:
+             logger.debug(f"Response approved by guardrail '{self.guardrail_name}'")
+             logger.debug(f"Judge's comment: {eval_result['comment']}")
+             self._retry_count = 0
+             return None
+         else:
+             self._retry_count += 1
+             comment: str = eval_result["comment"]
+
+             if self._retry_count >= self.num_retries:
+                 logger.warning(
+                     f"Guardrail '{self.guardrail_name}' failed - max retries reached "
+                     f"({self._retry_count}/{self.num_retries})"
+                 )
+                 logger.warning(f"Final judge's critique: {comment}")
+                 self._retry_count = 0
+
+                 # Add a final AI message to inform the user of the guardrail failure
+                 failure_message = (
+                     f"⚠️ **Quality Check Failed**\n\n"
+                     f"The response did not meet the '{self.guardrail_name}' quality standards "
+                     f"after {self.num_retries} attempts.\n\n"
+                     f"**Issue:** {comment}\n\n"
+                     f"The best available response has been provided, but please be aware it may not fully meet quality expectations."
+                 )
+                 return {"messages": [AIMessage(content=failure_message)]}
+
+             logger.warning(
+                 f"Guardrail '{self.guardrail_name}' requested improvements "
+                 f"(retry {self._retry_count}/{self.num_retries})"
+             )
+             logger.warning(f"Judge's critique: {comment}")
+
+             content: str = "\n".join([str(human_message.content), comment])
+             return {"messages": [HumanMessage(content=content)]}
+
+
+ class ContentFilterMiddleware(AgentMiddleware[AgentState, Context]):
+     """
+     Middleware that filters responses containing banned keywords.
+
+     This is a deterministic guardrail that blocks responses containing
+     specified keywords.
+
+     Args:
+         banned_keywords: List of keywords to block
+         block_message: Message to return when content is blocked
+     """
+
+     def __init__(
+         self,
+         banned_keywords: list[str],
+         block_message: str = "I cannot provide that response. Please rephrase your request.",
+     ):
+         super().__init__()
+         self.banned_keywords = [kw.lower() for kw in banned_keywords]
+         self.block_message = block_message
+
+     @hook_config(can_jump_to=["end"])
+     def before_agent(
+         self, state: AgentState, runtime: Runtime[Context]
+     ) -> dict[str, Any] | None:
+         """Block requests containing banned keywords."""
+         messages: list[BaseMessage] = state.get("messages", [])
+
+         if not messages:
+             return None
+
+         first_message = messages[0]
+         if not isinstance(first_message, HumanMessage):
+             return None
+
+         content = str(first_message.content).lower()
+
+         for keyword in self.banned_keywords:
+             if keyword in content:
+                 logger.warning(f"Content filter blocked request containing '{keyword}'")
+                 return {
+                     "messages": [AIMessage(content=self.block_message)],
+                     "jump_to": "end",
+                 }
+
+         return None
+
+     def after_model(
+         self, state: AgentState, runtime: Runtime[Context]
+     ) -> dict[str, Any] | None:
+         """Block responses containing banned keywords."""
+         messages: list[BaseMessage] = state.get("messages", [])
+
+         if not messages:
+             return None
+
+         last_message: AIMessage | None = last_ai_message(messages)
+         if not last_message:
+             return None
+
+         content = str(last_message.content).lower()
+
+         for keyword in self.banned_keywords:
+             if keyword in content:
+                 logger.warning(
+                     f"Content filter blocked response containing '{keyword}'"
+                 )
+                 # Modify the last message content
+                 last_message.content = self.block_message
+                 return None
+
+         return None
+
+
+ class SafetyGuardrailMiddleware(AgentMiddleware[AgentState, Context]):
+     """
+     Model-based safety guardrail middleware.
+
+     Uses an LLM to evaluate whether responses are safe and appropriate.
+     This is applied after the agent completes to validate the final output.
+
+     Args:
+         safety_model: The LLM to use for safety evaluation
+     """
+
+     def __init__(self, safety_model: Optional[LanguageModelLike] = None):
+         super().__init__()
+         self._safety_model = safety_model
+
+     @property
+     def safety_model(self) -> LanguageModelLike:
+         """Lazily initialize the safety model."""
+         if self._safety_model is None:
+             from langchain.chat_models import init_chat_model
+
+             self._safety_model = init_chat_model("gpt-4o-mini")
+         return self._safety_model
+
+     def after_agent(
+         self, state: AgentState, runtime: Runtime[Context]
+     ) -> dict[str, Any] | None:
+         """Evaluate response safety before returning to user."""
+         messages: list[BaseMessage] = state.get("messages", [])
+
+         if not messages:
+             return None
+
+         last_message: AIMessage | None = last_ai_message(messages)
+         if not last_message:
+             return None
+
+         safety_prompt = f"""Evaluate if this response is safe and appropriate.
+         Respond with only 'SAFE' or 'UNSAFE'.
+
+         Response: {last_message.content}"""
+
+         result = self.safety_model.invoke([{"role": "user", "content": safety_prompt}])
+
+         if "UNSAFE" in str(result.content):
+             logger.warning("Safety guardrail blocked unsafe response")
+             last_message.content = (
+                 "I cannot provide that response. Please rephrase your request."
+             )
+
+         return None
+
+
+ # =============================================================================
+ # Factory Functions
+ # =============================================================================
+
+
+ def create_guardrail_middleware(
+     name: str,
+     model: LanguageModelLike,
+     prompt: str,
+     num_retries: int = 3,
+ ) -> GuardrailMiddleware:
+     """
+     Create a GuardrailMiddleware instance.
+
+     Factory function for creating LLM-based guardrail middleware that evaluates
+     agent responses against specified criteria using an LLM judge.
+
+     Args:
+         name: Name identifying this guardrail
+         model: The LLM to use for evaluation
+         prompt: The evaluation prompt/criteria
+         num_retries: Maximum number of retry attempts (default: 3)
+
+     Returns:
+         GuardrailMiddleware configured with the specified parameters
+
+     Example:
+         middleware = create_guardrail_middleware(
+             name="tone_check",
+             model=ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct"),
+             prompt="Evaluate if the response is professional and helpful.",
+             num_retries=2,
+         )
+     """
+     logger.debug(f"Creating guardrail middleware: {name}")
+     return GuardrailMiddleware(
+         name=name,
+         model=model,
+         prompt=prompt,
+         num_retries=num_retries,
+     )
+
+
+ def create_content_filter_middleware(
+     banned_keywords: list[str],
+     block_message: str = "I cannot provide that response. Please rephrase your request.",
+ ) -> ContentFilterMiddleware:
+     """
+     Create a ContentFilterMiddleware instance.
+
+     Factory function for creating deterministic content filter middleware
+     that blocks requests/responses containing banned keywords.
+
+     Args:
+         banned_keywords: List of keywords to block
+         block_message: Message to return when content is blocked
+
+     Returns:
+         ContentFilterMiddleware configured with the specified parameters
+
+     Example:
+         middleware = create_content_filter_middleware(
+             banned_keywords=["password", "secret", "api_key"],
+             block_message="I cannot discuss sensitive credentials.",
+         )
+     """
+     logger.debug(
+         f"Creating content filter middleware with {len(banned_keywords)} keywords"
+     )
+     return ContentFilterMiddleware(
+         banned_keywords=banned_keywords,
+         block_message=block_message,
+     )
+
+
+ def create_safety_guardrail_middleware(
+     safety_model: Optional[LanguageModelLike] = None,
+ ) -> SafetyGuardrailMiddleware:
+     """
+     Create a SafetyGuardrailMiddleware instance.
+
+     Factory function for creating model-based safety guardrail middleware
+     that evaluates whether responses are safe and appropriate.
+
+     Args:
+         safety_model: The LLM to use for safety evaluation. If not provided,
+             defaults to gpt-4o-mini.
+
+     Returns:
+         SafetyGuardrailMiddleware configured with the specified model
+
+     Example:
+         from databricks_langchain import ChatDatabricks
+
+         middleware = create_safety_guardrail_middleware(
+             safety_model=ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct"),
+         )
+     """
+     logger.debug("Creating safety guardrail middleware")
+     return SafetyGuardrailMiddleware(safety_model=safety_model)
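
The hunk above adds dao_ai/middleware/guardrails.py. These guardrails are meant to be passed through LangChain's create_agent middleware parameter, as the core.py docstring notes. A hedged composition sketch: the endpoint name mirrors the placeholder used in the docstring examples, the empty tools list is illustrative, and the exact create_agent signature should be checked against the installed langchain version.

    from databricks_langchain import ChatDatabricks
    from langchain.agents import create_agent

    from dao_ai.middleware.guardrails import (
        create_content_filter_middleware,
        create_guardrail_middleware,
    )

    llm = ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct")

    agent = create_agent(
        model=llm,
        tools=[],  # placeholder: register real tools here
        middleware=[
            # Deterministic keyword filter runs on the incoming request and the final response.
            create_content_filter_middleware(banned_keywords=["password", "secret"]),
            # LLM-as-judge guardrail re-prompts the agent when the reply misses the criteria.
            create_guardrail_middleware(
                name="tone_check",
                model=llm,
                prompt="Evaluate if the response is professional and helpful.",
            ),
        ],
    )

Note that GuardrailMiddleware keeps its retry counter on the instance (self._retry_count), so each agent should be given its own instance rather than sharing one across agents.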