dao-ai 0.0.28__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +2 -5
  3. dao_ai/cli.py +342 -58
  4. dao_ai/config.py +1610 -380
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -253
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +158 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/context_editing.py +230 -0
  25. dao_ai/middleware/core.py +67 -0
  26. dao_ai/middleware/guardrails.py +420 -0
  27. dao_ai/middleware/human_in_the_loop.py +233 -0
  28. dao_ai/middleware/message_validation.py +586 -0
  29. dao_ai/middleware/model_call_limit.py +77 -0
  30. dao_ai/middleware/model_retry.py +121 -0
  31. dao_ai/middleware/pii.py +157 -0
  32. dao_ai/middleware/summarization.py +197 -0
  33. dao_ai/middleware/tool_call_limit.py +210 -0
  34. dao_ai/middleware/tool_retry.py +174 -0
  35. dao_ai/models.py +1306 -114
  36. dao_ai/nodes.py +240 -161
  37. dao_ai/optimization.py +674 -0
  38. dao_ai/orchestration/__init__.py +52 -0
  39. dao_ai/orchestration/core.py +294 -0
  40. dao_ai/orchestration/supervisor.py +279 -0
  41. dao_ai/orchestration/swarm.py +271 -0
  42. dao_ai/prompts.py +128 -31
  43. dao_ai/providers/databricks.py +584 -601
  44. dao_ai/state.py +157 -21
  45. dao_ai/tools/__init__.py +13 -5
  46. dao_ai/tools/agent.py +1 -3
  47. dao_ai/tools/core.py +64 -11
  48. dao_ai/tools/email.py +232 -0
  49. dao_ai/tools/genie.py +144 -294
  50. dao_ai/tools/mcp.py +223 -155
  51. dao_ai/tools/memory.py +50 -0
  52. dao_ai/tools/python.py +9 -14
  53. dao_ai/tools/search.py +14 -0
  54. dao_ai/tools/slack.py +22 -10
  55. dao_ai/tools/sql.py +202 -0
  56. dao_ai/tools/time.py +30 -7
  57. dao_ai/tools/unity_catalog.py +165 -88
  58. dao_ai/tools/vector_search.py +331 -221
  59. dao_ai/utils.py +166 -20
  60. dao_ai/vector_search.py +37 -0
  61. dao_ai-0.1.5.dist-info/METADATA +489 -0
  62. dao_ai-0.1.5.dist-info/RECORD +70 -0
  63. dao_ai/chat_models.py +0 -204
  64. dao_ai/guardrails.py +0 -112
  65. dao_ai/tools/human_in_the_loop.py +0 -100
  66. dao_ai-0.0.28.dist-info/METADATA +0 -1168
  67. dao_ai-0.0.28.dist-info/RECORD +0 -41
  68. {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/WHEEL +0 -0
  69. {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/entry_points.txt +0 -0
  70. {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,121 @@
1
+ """
2
+ Model retry middleware for DAO AI agents.
3
+
4
+ Automatically retries failed model (LLM) calls with configurable exponential backoff.
5
+
6
+ Example:
7
+ from dao_ai.middleware import create_model_retry_middleware
8
+
9
+ # Retry failed model calls with exponential backoff
10
+ middleware = create_model_retry_middleware(
11
+ max_retries=3,
12
+ backoff_factor=2.0,
13
+ initial_delay=1.0,
14
+ )
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ from typing import Any, Callable, Literal
20
+
21
+ from langchain.agents.middleware import ModelRetryMiddleware
22
+ from loguru import logger
23
+
24
__all__ = [
    "ModelRetryMiddleware",
    "create_model_retry_middleware",
]


def create_model_retry_middleware(
    max_retries: int = 3,
    backoff_factor: float = 2.0,
    initial_delay: float = 1.0,
    max_delay: float | None = None,
    jitter: bool = False,
    retry_on: tuple[type[Exception], ...] | Callable[[Exception], bool] | None = None,
    on_failure: Literal["continue", "error"] | Callable[[Exception], str] = "continue",
) -> ModelRetryMiddleware:
    """
    Create a ModelRetryMiddleware for automatic model call retries.

    Handles transient failures in model API calls with exponential backoff.
    Useful for handling rate limits, network issues, and temporary outages.

    Args:
        max_retries: Max retry attempts after initial call. Default 3.
        backoff_factor: Multiplier for exponential backoff. Default 2.0.
            Delay = initial_delay * (backoff_factor ** retry_number)
            Set to 0.0 for constant delay.
        initial_delay: Initial delay in seconds before first retry. Default 1.0.
        max_delay: Max delay in seconds (caps exponential growth). When None,
            no value is passed and ModelRetryMiddleware's own default cap
            applies.
        jitter: Add ±25% random jitter to avoid thundering herd. Default False.
            Always forwarded, so False explicitly disables jitter regardless of
            ModelRetryMiddleware's own default.
        retry_on: When to retry:
            - None: Retry on all errors (default; the library default is used)
            - tuple of Exception types: Retry only on these
            - callable: Function(exception) -> bool for custom logic
        on_failure: Behavior when all retries exhausted:
            - "continue": Return AIMessage with error, let agent continue (default)
            - "error": Re-raise exception, stop execution
            - callable: Function(exception) -> str for custom error message

    Returns:
        A configured ModelRetryMiddleware instance.

    Example:
        # Basic retry with defaults
        retry = create_model_retry_middleware()

        # Custom backoff for rate limits
        retry = create_model_retry_middleware(
            max_retries=5,
            backoff_factor=2.0,
            initial_delay=1.0,
            max_delay=60.0,
            jitter=True,
        )

        # Retry only on specific exceptions, fail hard
        retry = create_model_retry_middleware(
            max_retries=3,
            retry_on=(RateLimitError, TimeoutError),
            on_failure="error",
        )

        # Custom retry logic
        def should_retry(error: Exception) -> bool:
            return "rate_limit" in str(error).lower()

        retry = create_model_retry_middleware(
            max_retries=5,
            retry_on=should_retry,
        )
    """
    logger.debug(
        "Creating model retry middleware",
        max_retries=max_retries,
        backoff_factor=backoff_factor,
        initial_delay=initial_delay,
        max_delay=max_delay,
        jitter=jitter,
        on_failure=on_failure if isinstance(on_failure, str) else "custom",
    )

    # Build kwargs. `jitter` is always forwarded: forwarding it only when truthy
    # made `jitter=False` a no-op, leaving the library default in effect.
    kwargs: dict[str, Any] = {
        "max_retries": max_retries,
        "backoff_factor": backoff_factor,
        "initial_delay": initial_delay,
        "jitter": jitter,
        "on_failure": on_failure,
    }

    # Optional settings are omitted when None so the library defaults apply.
    if max_delay is not None:
        kwargs["max_delay"] = max_delay

    if retry_on is not None:
        kwargs["retry_on"] = retry_on

    return ModelRetryMiddleware(**kwargs)
@@ -0,0 +1,157 @@
1
+ """
2
+ PII detection middleware for DAO AI agents.
3
+
4
+ Detects and handles Personally Identifiable Information (PII) in conversations
5
+ using configurable strategies (redact, mask, hash, block).
6
+
7
+ Example:
8
+ from dao_ai.middleware import create_pii_middleware
9
+
10
+ # Redact emails in user input
11
+ middleware = create_pii_middleware(
12
+ pii_type="email",
13
+ strategy="redact",
14
+ apply_to_input=True,
15
+ )
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ from typing import Any, Callable, Literal, Pattern
21
+
22
+ from langchain.agents.middleware import PIIMiddleware
23
+ from loguru import logger
24
+
25
+ __all__ = [
26
+ "PIIMiddleware",
27
+ "create_pii_middleware",
28
+ ]
29
+
30
+ # Type alias for PII detector
31
+ PIIDetector = str | Pattern[str] | Callable[[str], list[dict[str, str | int]]]
32
+
33
+ # Built-in PII types
34
+ BUILTIN_PII_TYPES = frozenset({"email", "credit_card", "ip", "mac_address", "url"})
35
+
36
+
37
+ def create_pii_middleware(
38
+ pii_type: str,
39
+ strategy: Literal["redact", "mask", "hash", "block"] = "redact",
40
+ detector: PIIDetector | None = None,
41
+ apply_to_input: bool = True,
42
+ apply_to_output: bool = False,
43
+ apply_to_tool_results: bool = False,
44
+ ) -> PIIMiddleware:
45
+ """
46
+ Create a PIIMiddleware for detecting and handling PII.
47
+
48
+ Detects Personally Identifiable Information in conversations and handles
49
+ it according to the specified strategy. Useful for compliance, privacy,
50
+ and sanitizing logs.
51
+
52
+ Built-in PII types:
53
+ - email: Email addresses
54
+ - credit_card: Credit card numbers (Luhn validated)
55
+ - ip: IP addresses
56
+ - mac_address: MAC addresses
57
+ - url: URLs
58
+
59
+ Args:
60
+ pii_type: Type of PII to detect. Use built-in types (email, credit_card,
61
+ ip, mac_address, url) or custom type names with a detector.
62
+ strategy: How to handle detected PII:
63
+ - "redact": Replace with [REDACTED_{TYPE}] (default)
64
+ - "mask": Partially obscure (e.g., ****-****-****-1234)
65
+ - "hash": Replace with deterministic hash
66
+ - "block": Raise exception when detected
67
+ detector: Custom detector for non-built-in types. Can be:
68
+ - str: Regex pattern string
69
+ - re.Pattern: Compiled regex pattern
70
+ - Callable: Function(content: str) -> list[dict] with keys:
71
+ - text: The matched text
72
+ - start: Start index
73
+ - end: End index
74
+ Default None (uses built-in detector for built-in types).
75
+ apply_to_input: Check user messages before model call. Default True.
76
+ apply_to_output: Check AI messages after model call. Default False.
77
+ apply_to_tool_results: Check tool results after execution. Default False.
78
+
79
+ Returns:
80
+ List containing PIIMiddleware instance
81
+
82
+ Raises:
83
+ ValueError: If custom pii_type without detector, or invalid strategy
84
+
85
+ Example:
86
+ # Redact emails in input
87
+ email_redactor = create_pii_middleware(
88
+ pii_type="email",
89
+ strategy="redact",
90
+ apply_to_input=True,
91
+ )
92
+
93
+ # Mask credit cards
94
+ card_masker = create_pii_middleware(
95
+ pii_type="credit_card",
96
+ strategy="mask",
97
+ apply_to_input=True,
98
+ apply_to_output=True,
99
+ )
100
+
101
+ # Block API keys with custom regex
102
+ api_key_blocker = create_pii_middleware(
103
+ pii_type="api_key",
104
+ detector=r"sk-[a-zA-Z0-9]{32}",
105
+ strategy="block",
106
+ )
107
+
108
+ # Custom SSN detector with validation
109
+ def detect_ssn(content: str) -> list[dict]:
110
+ matches = []
111
+ pattern = r"\\d{3}-\\d{2}-\\d{4}"
112
+ for match in re.finditer(pattern, content):
113
+ ssn = match.group(0)
114
+ first_three = int(ssn[:3])
115
+ if first_three not in [0, 666] and not (900 <= first_three <= 999):
116
+ matches.append({
117
+ "text": ssn,
118
+ "start": match.start(),
119
+ "end": match.end(),
120
+ })
121
+ return matches
122
+
123
+ ssn_hasher = create_pii_middleware(
124
+ pii_type="ssn",
125
+ detector=detect_ssn,
126
+ strategy="hash",
127
+ )
128
+ """
129
+ # Validate: custom types require detector
130
+ if pii_type not in BUILTIN_PII_TYPES and detector is None:
131
+ raise ValueError(
132
+ f"Custom PII type '{pii_type}' requires a detector. "
133
+ f"Built-in types are: {', '.join(sorted(BUILTIN_PII_TYPES))}"
134
+ )
135
+
136
+ logger.debug(
137
+ "Creating PII middleware",
138
+ pii_type=pii_type,
139
+ strategy=strategy,
140
+ has_custom_detector=detector is not None,
141
+ apply_to_input=apply_to_input,
142
+ apply_to_output=apply_to_output,
143
+ apply_to_tool_results=apply_to_tool_results,
144
+ )
145
+
146
+ # Build kwargs
147
+ kwargs: dict[str, Any] = {
148
+ "strategy": strategy,
149
+ "apply_to_input": apply_to_input,
150
+ "apply_to_output": apply_to_output,
151
+ "apply_to_tool_results": apply_to_tool_results,
152
+ }
153
+
154
+ if detector is not None:
155
+ kwargs["detector"] = detector
156
+
157
+ return PIIMiddleware(pii_type, **kwargs)
@@ -0,0 +1,197 @@
1
+ """
2
+ Summarization middleware for DAO AI agents.
3
+
4
+ This module provides a LoggingSummarizationMiddleware that extends LangChain's
5
+ built-in SummarizationMiddleware with logging capabilities, and provides
6
+ helper utilities for creating summarization middleware from DAO AI configuration.
7
+
8
+ The middleware automatically:
9
+ - Summarizes older messages using a separate LLM call when thresholds are exceeded
10
+ - Replaces them with a summary message in State (permanently)
11
+ - Keeps recent messages intact for context
12
+ - Logs when summarization is triggered and completed
13
+
14
+ Example:
15
+ from dao_ai.middleware import create_summarization_middleware
16
+ from dao_ai.config import ChatHistoryModel, LLMModel
17
+
18
+ chat_history = ChatHistoryModel(
19
+ model=LLMModel(name="gpt-4o-mini"),
20
+ max_tokens=256,
21
+ max_tokens_before_summary=4000,
22
+ )
23
+
24
+ middleware = create_summarization_middleware(chat_history)
25
+ """
26
+
27
+ from typing import Any, Tuple
28
+
29
+ from langchain.agents.middleware import SummarizationMiddleware
30
+ from langchain_core.language_models import LanguageModelLike
31
+ from langchain_core.messages import BaseMessage
32
+ from langgraph.runtime import Runtime
33
+ from loguru import logger
34
+
35
+ from dao_ai.config import ChatHistoryModel
36
+
37
__all__ = [
    "SummarizationMiddleware",
    "LoggingSummarizationMiddleware",
    "create_summarization_middleware",
]


class LoggingSummarizationMiddleware(SummarizationMiddleware):
    """
    SummarizationMiddleware that logs each summarization pass.

    Wraps both the sync (`before_model`) and async (`abefore_model`) hooks of
    LangChain's SummarizationMiddleware. Before delegating, it measures the
    incoming message list. If the parent hook returns a non-None update, that
    update is treated as a summarization and logged:

    - INFO: message/approximate-token counts before and after, plus how many
      messages were folded into the summary.
    - DEBUG: the configured `trigger` and `keep`, the preserved-message count,
      and the token reduction.
    """

    def _measure(self, messages: list[Any]) -> tuple[int, int]:
        """Return `(message_count, token_count)`; tokens are 0 for an empty list."""
        count = len(messages)
        tokens = self.token_counter(messages) if messages else 0
        return count, tokens

    def _log_summarization(
        self,
        original_message_count: int,
        original_token_count: int,
        result_messages: list[Any],
    ) -> None:
        """Emit before/after metrics for a summarization update.

        `result_messages` is expected to be laid out as
        `[RemoveMessage, summary_message, *preserved_messages]`. RemoveMessage
        entries are excluded before counting, and every remaining message except
        the single summary message counts as preserved.
        """
        surviving = [
            entry for entry in result_messages if not self._is_remove_message(entry)
        ]
        after_count, after_tokens = self._measure(surviving)

        # All survivors except the one summary message were kept verbatim.
        kept_verbatim = max(0, after_count - 1)

        logger.info(
            "Conversation summarized",
            before_messages=original_message_count,
            before_tokens=original_token_count,
            after_messages=after_count,
            after_tokens=after_tokens,
            summarized_messages=original_message_count - kept_verbatim,
        )
        logger.debug(
            "Summarization details",
            trigger=self.trigger,
            keep=self.keep,
            preserved_messages=kept_verbatim,
            token_reduction=original_token_count - after_tokens,
        )

    def _is_remove_message(self, msg: Any) -> bool:
        """Tell whether `msg` is a RemoveMessage (matched by exact class name)."""
        return type(msg).__name__ == "RemoveMessage"

    def before_model(
        self, state: dict[str, Any], runtime: Runtime
    ) -> dict[str, Any] | None:
        """Delegate to the parent hook and log if it produced an update."""
        incoming: list[BaseMessage] = state.get("messages", [])
        before_count, before_tokens = self._measure(incoming)

        update = super().before_model(state, runtime)

        if update is not None:
            self._log_summarization(
                before_count, before_tokens, update.get("messages", [])
            )
        return update

    async def abefore_model(
        self, state: dict[str, Any], runtime: Runtime
    ) -> dict[str, Any] | None:
        """Async counterpart of `before_model` with identical logging."""
        incoming: list[BaseMessage] = state.get("messages", [])
        before_count, before_tokens = self._measure(incoming)

        update = await super().abefore_model(state, runtime)

        if update is not None:
            self._log_summarization(
                before_count, before_tokens, update.get("messages", [])
            )
        return update
137
+
138
+
139
def create_summarization_middleware(
    chat_history: ChatHistoryModel,
) -> LoggingSummarizationMiddleware:
    """
    Create a LoggingSummarizationMiddleware from DAO AI ChatHistoryModel configuration.

    The summarization trigger is chosen from the first *truthy* setting below,
    so a value of 0 or None is treated as "not set":

    1. `max_tokens_before_summary` -> `("tokens", max_tokens_before_summary)`
    2. `max_messages_before_summary` -> `("messages", max_messages_before_summary)`
    3. otherwise -> `("tokens", max_tokens * 10)`

    The `keep` setting (how much recent history to preserve) is always
    `("tokens", max_tokens)`. The returned middleware logs at INFO level
    whenever summarization is triggered.

    Args:
        chat_history: ChatHistoryModel configuration for summarization. Its
            `model` is converted with `as_chat_model()` to obtain the LLM used
            to write summaries.

    Returns:
        A LoggingSummarizationMiddleware instance configured from `chat_history`.

    Example:
        from dao_ai.config import ChatHistoryModel, LLMModel

        chat_history = ChatHistoryModel(
            model=LLMModel(name="gpt-4o-mini"),
            max_tokens=256,
            max_tokens_before_summary=4000,
        )

        middleware = create_summarization_middleware(chat_history)
    """
    logger.debug(
        "Creating summarization middleware",
        max_tokens=chat_history.max_tokens,
        max_tokens_before_summary=chat_history.max_tokens_before_summary,
        max_messages_before_summary=chat_history.max_messages_before_summary,
    )

    # Get the LLM model
    model: LanguageModelLike = chat_history.model.as_chat_model()

    # Determine trigger condition.
    # LangChain uses ("tokens", value) or ("messages", value) tuples.
    # Truthiness checks are deliberate: 0/None means "not configured".
    trigger: Tuple[str, int]
    if chat_history.max_tokens_before_summary:
        trigger = ("tokens", chat_history.max_tokens_before_summary)
    elif chat_history.max_messages_before_summary:
        trigger = ("messages", chat_history.max_messages_before_summary)
    else:
        # Fallback threshold when neither explicit trigger is configured.
        trigger = ("tokens", chat_history.max_tokens * 10)

    # Determine keep condition - how many recent tokens to preserve verbatim.
    keep: Tuple[str, int] = ("tokens", chat_history.max_tokens)

    logger.info("Summarization middleware configured", trigger=trigger, keep=keep)

    return LoggingSummarizationMiddleware(
        model=model,
        trigger=trigger,
        keep=keep,
    )
@@ -0,0 +1,210 @@
1
+ """
2
+ Tool call limit middleware for DAO AI agents.
3
+
4
+ This module provides a factory for creating LangChain's ToolCallLimitMiddleware
5
+ from DAO AI configuration.
6
+
7
+ Example:
8
+ from dao_ai.middleware import create_tool_call_limit_middleware
9
+
10
+ # Global limit across all tools
11
+ middleware = create_tool_call_limit_middleware(
12
+ thread_limit=20,
13
+ run_limit=10,
14
+ )
15
+
16
+ # Limit specific tool by name
17
+ search_limiter = create_tool_call_limit_middleware(
18
+ tool="search_web",
19
+ run_limit=3,
20
+ exit_behavior="continue",
21
+ )
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ from typing import Any, Literal
27
+
28
+ from langchain.agents.middleware import ToolCallLimitMiddleware
29
+ from langchain_core.tools import BaseTool
30
+ from loguru import logger
31
+
32
+ from dao_ai.config import BaseFunctionModel, ToolModel
33
+
34
+ __all__ = [
35
+ "ToolCallLimitMiddleware",
36
+ "create_tool_call_limit_middleware",
37
+ ]
38
+
39
+
40
+ def _resolve_tool(tool: str | ToolModel | dict[str, Any]) -> list[str]:
41
+ """
42
+ Resolve tool argument to a list of actual tool names.
43
+
44
+ Args:
45
+ tool: String name, ToolModel, or dict to resolve
46
+
47
+ Returns:
48
+ List of tool name strings
49
+
50
+ Raises:
51
+ ValueError: If dict cannot be converted to ToolModel
52
+ TypeError: If tool is not a supported type
53
+ """
54
+ # String: return as single-item list
55
+ if isinstance(tool, str):
56
+ return [tool]
57
+
58
+ # Dict: convert to ToolModel first
59
+ if isinstance(tool, dict):
60
+ try:
61
+ tool_model = ToolModel(**tool)
62
+ except Exception as e:
63
+ raise ValueError(
64
+ f"Failed to construct ToolModel from dict: {e}\n"
65
+ f"Dict must have 'name' and 'function' keys."
66
+ ) from e
67
+ elif isinstance(tool, ToolModel):
68
+ tool_model = tool
69
+ else:
70
+ raise TypeError(
71
+ f"tool must be str, ToolModel, or dict, got {type(tool).__name__}"
72
+ )
73
+
74
+ # Extract tool names from ToolModel
75
+ return _extract_tool_names(tool_model)
76
+
77
+
78
def _extract_tool_names(tool_model: ToolModel) -> list[str]:
    """
    Collect the concrete tool names produced by a ToolModel.

    One ToolModel may expand into several tools (e.g. UC functions), so this
    builds them via `function.as_tools()` and keeps every BaseTool with a
    non-empty name. It returns `[tool_model.name]` instead when:

    - `function` is not a BaseFunctionModel (e.g. a string reference),
    - building the tools raises (a warning is logged), or
    - no named BaseTool was produced.
    """
    function = tool_model.function

    if isinstance(function, BaseFunctionModel):
        try:
            discovered: list[str] = []
            for candidate in function.as_tools():
                if isinstance(candidate, BaseTool) and candidate.name:
                    discovered.append(candidate.name)
            if discovered:
                logger.trace(
                    "Extracted tool names",
                    tool_model_name=tool_model.name,
                    tool_names=discovered,
                )
                return discovered
        except Exception as exc:
            # Best-effort: a failure here falls back to ToolModel.name below.
            logger.warning(
                "Error extracting tool names from ToolModel",
                tool_model_name=tool_model.name,
                error=str(exc),
            )

        logger.debug(
            "Falling back to ToolModel.name",
            tool_model_name=tool_model.name,
        )
        return [tool_model.name]

    # Non-BaseFunctionModel functions (e.g. string references) can't be expanded.
    logger.debug(
        "Cannot extract names from string function, using ToolModel.name",
        tool_model_name=tool_model.name,
    )
    return [tool_model.name]
122
+
123
+
124
def create_tool_call_limit_middleware(
    tool: str | ToolModel | dict[str, Any] | None = None,
    thread_limit: int | None = None,
    run_limit: int | None = None,
    exit_behavior: Literal["continue", "error", "end"] = "continue",
) -> ToolCallLimitMiddleware:
    """
    Build a LangChain ToolCallLimitMiddleware from DAO AI-friendly arguments.

    With no `tool`, the limit applies to all tools. Otherwise `tool` is resolved
    to tool names and the limit is applied to the first resolved name only. A
    warning is logged when more than one name was resolved.

    Args:
        tool: What to limit:
            - None: Global limit on all tools
            - str: Limit specific tool by name
            - ToolModel: Limit tool(s) from DAO AI config
            - dict: Tool config dict (converted to ToolModel)
        thread_limit: Max calls per thread (conversation). Requires checkpointer.
        run_limit: Max calls per run (single invocation).
        exit_behavior: What to do when limit hit:
            - "continue": Block tool with error message, let agent continue
            - "error": Raise ToolCallLimitExceededError immediately
            - "end": Stop execution gracefully (single-tool only)

    Returns:
        A ToolCallLimitMiddleware instance.

    Raises:
        ValueError: If neither limit is given, or a dict `tool` is invalid.
        TypeError: If `tool` is of an unsupported type.

    Example:
        # Global limit
        limiter = create_tool_call_limit_middleware(run_limit=10)

        # Tool-specific limit
        limiter = create_tool_call_limit_middleware(
            tool="search_web",
            run_limit=3,
            exit_behavior="continue",
        )
    """
    if all(limit is None for limit in (thread_limit, run_limit)):
        raise ValueError("At least one of thread_limit or run_limit must be specified.")

    # Shared settings, in the same order for both the log records and the
    # middleware constructor.
    limits: dict[str, Any] = {
        "thread_limit": thread_limit,
        "run_limit": run_limit,
        "exit_behavior": exit_behavior,
    }

    if tool is None:
        logger.debug("Creating global tool call limit", **limits)
        return ToolCallLimitMiddleware(**limits)

    resolved = _resolve_tool(tool)
    chosen = resolved[0]
    if len(resolved) > 1:
        # ToolCallLimitMiddleware accepts a single tool name.
        logger.warning(
            "ToolModel resolved to multiple tool names, using first only",
            tool_names=resolved,
            using=chosen,
        )

    logger.debug("Creating tool call limit middleware", tool_name=chosen, **limits)
    return ToolCallLimitMiddleware(tool_name=chosen, **limits)
+ )