dao-ai 0.1.2__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/cli.py +104 -25
- dao_ai/config.py +149 -40
- dao_ai/middleware/__init__.py +33 -0
- dao_ai/middleware/assertions.py +3 -3
- dao_ai/middleware/context_editing.py +230 -0
- dao_ai/middleware/core.py +4 -4
- dao_ai/middleware/guardrails.py +3 -3
- dao_ai/middleware/human_in_the_loop.py +3 -2
- dao_ai/middleware/message_validation.py +4 -4
- dao_ai/middleware/model_call_limit.py +77 -0
- dao_ai/middleware/model_retry.py +121 -0
- dao_ai/middleware/pii.py +157 -0
- dao_ai/middleware/summarization.py +1 -1
- dao_ai/middleware/tool_call_limit.py +210 -0
- dao_ai/middleware/tool_retry.py +174 -0
- dao_ai/nodes.py +5 -12
- dao_ai/orchestration/supervisor.py +6 -5
- dao_ai/providers/databricks.py +11 -0
- dao_ai/vector_search.py +37 -0
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.5.dist-info}/METADATA +36 -2
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.5.dist-info}/RECORD +24 -18
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.5.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.5.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.5.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Context editing middleware for DAO AI agents.
|
|
3
|
+
|
|
4
|
+
Manages conversation context by clearing older tool call outputs when token limits
|
|
5
|
+
are reached, while preserving recent results.
|
|
6
|
+
|
|
7
|
+
Example:
|
|
8
|
+
from dao_ai.middleware import create_context_editing_middleware
|
|
9
|
+
|
|
10
|
+
# Clear old tool outputs when context exceeds 100k tokens
|
|
11
|
+
middleware = create_context_editing_middleware(
|
|
12
|
+
trigger=100000,
|
|
13
|
+
keep=3,
|
|
14
|
+
)
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from typing import Any, Literal
|
|
20
|
+
|
|
21
|
+
from langchain.agents.middleware import ClearToolUsesEdit, ContextEditingMiddleware
|
|
22
|
+
from langchain_core.tools import BaseTool
|
|
23
|
+
from loguru import logger
|
|
24
|
+
|
|
25
|
+
from dao_ai.config import BaseFunctionModel, ToolModel
|
|
26
|
+
|
|
27
|
+
__all__ = [
|
|
28
|
+
"ContextEditingMiddleware",
|
|
29
|
+
"ClearToolUsesEdit",
|
|
30
|
+
"create_context_editing_middleware",
|
|
31
|
+
"create_clear_tool_uses_edit",
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _resolve_tool_names(
|
|
36
|
+
tools: list[str | ToolModel | dict[str, Any]] | None,
|
|
37
|
+
) -> list[str]:
|
|
38
|
+
"""Resolve tool specs to a list of tool name strings."""
|
|
39
|
+
if tools is None:
|
|
40
|
+
return []
|
|
41
|
+
|
|
42
|
+
result: list[str] = []
|
|
43
|
+
for tool in tools:
|
|
44
|
+
if isinstance(tool, str):
|
|
45
|
+
result.append(tool)
|
|
46
|
+
elif isinstance(tool, dict):
|
|
47
|
+
try:
|
|
48
|
+
tool_model = ToolModel(**tool)
|
|
49
|
+
result.extend(_extract_tool_names(tool_model))
|
|
50
|
+
except Exception as e:
|
|
51
|
+
raise ValueError(f"Failed to construct ToolModel from dict: {e}") from e
|
|
52
|
+
elif isinstance(tool, ToolModel):
|
|
53
|
+
result.extend(_extract_tool_names(tool))
|
|
54
|
+
else:
|
|
55
|
+
raise TypeError(
|
|
56
|
+
f"Tool must be str, ToolModel, or dict, got {type(tool).__name__}"
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
return result
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _extract_tool_names(tool_model: ToolModel) -> list[str]:
    """Extract concrete tool names from a ToolModel, defaulting to its own name.

    When the model wraps a BaseFunctionModel, the names of the tools it
    materializes via ``as_tools()`` are returned; in every fallback case
    (non-function tool, extraction failure, or no named tools) the
    ToolModel's own name is used instead.
    """
    function = tool_model.function

    # Non-function tools expose no sub-tools; the model's own name is the answer.
    if not isinstance(function, BaseFunctionModel):
        return [tool_model.name]

    try:
        names = [
            candidate.name
            for candidate in function.as_tools()
            if isinstance(candidate, BaseTool) and candidate.name
        ]
        return names or [tool_model.name]
    except Exception:
        # Best-effort: any failure while materializing tools falls back
        # to the declared model name rather than propagating.
        return [tool_model.name]
|
|
79
|
+
|
|
80
|
+
def create_clear_tool_uses_edit(
    trigger: int = 100000,
    keep: int = 3,
    clear_at_least: int = 0,
    clear_tool_inputs: bool = False,
    exclude_tools: list[str | ToolModel | dict[str, Any]] | None = None,
    placeholder: str = "[cleared]",
) -> ClearToolUsesEdit:
    """
    Build a ClearToolUsesEdit strategy for use with ContextEditingMiddleware.

    Once the conversation grows past ``trigger`` tokens, older tool results
    are replaced with ``placeholder`` text while the ``keep`` most recent
    results stay intact.

    Args:
        trigger: Token count that triggers the edit. When conversation exceeds
            this, older tool outputs are cleared. Default 100000.
        keep: Number of most recent tool results to preserve. These are never
            cleared. Default 3.
        clear_at_least: Minimum tokens to reclaim when edit runs.
            0 means clear as much as needed. Default 0.
        clear_tool_inputs: Whether to clear tool call arguments on AI messages.
            When True, tool call arguments are replaced with empty objects.
            Default False.
        exclude_tools: Tools to never clear. Each entry may be a tool name
            (str), a ToolModel, or a dict convertible to a ToolModel.
            Default None (no exclusions).
        placeholder: Text inserted for cleared tool outputs.
            Default "[cleared]".

    Returns:
        ClearToolUsesEdit instance

    Example:
        edit = create_clear_tool_uses_edit(
            trigger=50000,
            keep=5,
            clear_tool_inputs=True,
            exclude_tools=["important_tool"],
        )
    """
    excluded: list[str] = _resolve_tool_names(exclude_tools) if exclude_tools else []

    logger.debug(
        "Creating ClearToolUsesEdit",
        trigger=trigger,
        keep=keep,
        clear_at_least=clear_at_least,
        clear_tool_inputs=clear_tool_inputs,
        exclude_tools=excluded or "none",
        placeholder=placeholder,
    )

    # Assemble the edit configuration once, then construct.
    edit_config: dict[str, Any] = dict(
        trigger=trigger,
        keep=keep,
        clear_at_least=clear_at_least,
        clear_tool_inputs=clear_tool_inputs,
        exclude_tools=excluded,
        placeholder=placeholder,
    )
    return ClearToolUsesEdit(**edit_config)
|
+
|
|
144
|
+
|
|
145
|
+
def create_context_editing_middleware(
    trigger: int = 100000,
    keep: int = 3,
    clear_at_least: int = 0,
    clear_tool_inputs: bool = False,
    exclude_tools: list[str | ToolModel | dict[str, Any]] | None = None,
    placeholder: str = "[cleared]",
    token_count_method: Literal["approximate", "model"] = "approximate",
) -> ContextEditingMiddleware:
    """
    Create a ContextEditingMiddleware with a single ClearToolUsesEdit.

    Manages conversation context by clearing older tool call outputs when token
    limits are reached. Useful for long conversations with many tool calls that
    exceed context window limits.

    Use cases:
        - Long conversations with many tool calls exceeding token limits
        - Reducing token costs by removing older irrelevant tool outputs
        - Maintaining only the most recent N tool results in context

    Args:
        trigger: Token count that triggers clearing. When conversation exceeds
            this threshold, older tool outputs are cleared. Default 100000.
        keep: Number of most recent tool results to always preserve.
            These are never cleared. Default 3.
        clear_at_least: Minimum tokens to reclaim when edit runs.
            0 means clear as much as needed. Default 0.
        clear_tool_inputs: Whether to also clear tool call arguments on AI
            messages. When True, replaces arguments with empty objects.
            Default False (preserves tool call context).
        exclude_tools: Tools to never clear outputs from. Can be:
            - list of str: Tool names
            - list of ToolModel: DAO AI tool models
            - list of dict: Tool config dicts
            Default None (no exclusions).
        placeholder: Text inserted for cleared tool outputs.
            Default "[cleared]".
        token_count_method: How to count tokens:
            - "approximate": Fast estimation (default)
            - "model": Accurate count using model tokenizer

    Returns:
        ContextEditingMiddleware instance configured with a single
        ClearToolUsesEdit.

    Example:
        # Basic usage - clear old tool outputs after 100k tokens
        middleware = create_context_editing_middleware(
            trigger=100000,
            keep=3,
        )

        # Aggressive clearing with exclusions
        middleware = create_context_editing_middleware(
            trigger=50000,
            keep=5,
            clear_tool_inputs=True,
            exclude_tools=["important_tool", "critical_search"],
            placeholder="[output cleared to save context]",
        )

        # Accurate token counting
        middleware = create_context_editing_middleware(
            trigger=100000,
            keep=3,
            token_count_method="model",
        )
    """
    # Delegate edit construction (including exclude_tools resolution) to the
    # dedicated factory so both entry points share identical semantics.
    edit = create_clear_tool_uses_edit(
        trigger=trigger,
        keep=keep,
        clear_at_least=clear_at_least,
        clear_tool_inputs=clear_tool_inputs,
        exclude_tools=exclude_tools,
        placeholder=placeholder,
    )

    logger.debug(
        "Creating ContextEditingMiddleware",
        token_count_method=token_count_method,
    )

    return ContextEditingMiddleware(
        edits=[edit],
        token_count_method=token_count_method,
    )
dao_ai/middleware/core.py
CHANGED
|
@@ -21,7 +21,6 @@ def create_factory_middleware(
|
|
|
21
21
|
"""
|
|
22
22
|
Create middleware from a factory function.
|
|
23
23
|
|
|
24
|
-
|
|
25
24
|
This factory function dynamically loads a Python function and calls it
|
|
26
25
|
with the provided arguments to create a middleware instance.
|
|
27
26
|
|
|
@@ -35,7 +34,7 @@ def create_factory_middleware(
|
|
|
35
34
|
args: Arguments to pass to the factory function
|
|
36
35
|
|
|
37
36
|
Returns:
|
|
38
|
-
|
|
37
|
+
The AgentMiddleware instance returned by the factory function.
|
|
39
38
|
|
|
40
39
|
Raises:
|
|
41
40
|
ImportError: If the function cannot be loaded
|
|
@@ -59,9 +58,10 @@ def create_factory_middleware(
|
|
|
59
58
|
factory: Callable[..., AgentMiddleware[AgentState, Context]] = load_function(
|
|
60
59
|
function_name=function_name
|
|
61
60
|
)
|
|
62
|
-
middleware
|
|
61
|
+
middleware = factory(**args)
|
|
63
62
|
|
|
64
63
|
logger.trace(
|
|
65
|
-
"Created middleware from factory",
|
|
64
|
+
"Created middleware from factory",
|
|
65
|
+
middleware_type=type(middleware).__name__,
|
|
66
66
|
)
|
|
67
67
|
return middleware
|
dao_ai/middleware/guardrails.py
CHANGED
|
@@ -342,7 +342,7 @@ def create_guardrail_middleware(
|
|
|
342
342
|
num_retries: Maximum number of retry attempts (default: 3)
|
|
343
343
|
|
|
344
344
|
Returns:
|
|
345
|
-
GuardrailMiddleware configured with the specified parameters
|
|
345
|
+
List containing GuardrailMiddleware configured with the specified parameters
|
|
346
346
|
|
|
347
347
|
Example:
|
|
348
348
|
middleware = create_guardrail_middleware(
|
|
@@ -376,7 +376,7 @@ def create_content_filter_middleware(
|
|
|
376
376
|
block_message: Message to return when content is blocked
|
|
377
377
|
|
|
378
378
|
Returns:
|
|
379
|
-
ContentFilterMiddleware configured with the specified parameters
|
|
379
|
+
List containing ContentFilterMiddleware configured with the specified parameters
|
|
380
380
|
|
|
381
381
|
Example:
|
|
382
382
|
middleware = create_content_filter_middleware(
|
|
@@ -407,7 +407,7 @@ def create_safety_guardrail_middleware(
|
|
|
407
407
|
defaults to gpt-4o-mini.
|
|
408
408
|
|
|
409
409
|
Returns:
|
|
410
|
-
SafetyGuardrailMiddleware configured with the specified model
|
|
410
|
+
List containing SafetyGuardrailMiddleware configured with the specified model
|
|
411
411
|
|
|
412
412
|
Example:
|
|
413
413
|
from databricks_langchain import ChatDatabricks
|
|
@@ -132,7 +132,7 @@ def create_human_in_the_loop_middleware(
|
|
|
132
132
|
description_prefix: Message prefix shown when pausing for review
|
|
133
133
|
|
|
134
134
|
Returns:
|
|
135
|
-
HumanInTheLoopMiddleware configured with the specified parameters
|
|
135
|
+
List containing HumanInTheLoopMiddleware configured with the specified parameters
|
|
136
136
|
|
|
137
137
|
Example:
|
|
138
138
|
from dao_ai.config import HumanInTheLoopModel
|
|
@@ -182,7 +182,8 @@ def create_hitl_middleware_from_tool_models(
|
|
|
182
182
|
description_prefix: Message prefix shown when pausing for review
|
|
183
183
|
|
|
184
184
|
Returns:
|
|
185
|
-
HumanInTheLoopMiddleware if any tools require approval,
|
|
185
|
+
List containing HumanInTheLoopMiddleware if any tools require approval,
|
|
186
|
+
empty list otherwise
|
|
186
187
|
|
|
187
188
|
Example:
|
|
188
189
|
from dao_ai.config import ToolModel, PythonFunctionModel, HumanInTheLoopModel
|
|
@@ -501,7 +501,7 @@ def create_user_id_validation_middleware() -> UserIdValidationMiddleware:
|
|
|
501
501
|
and format of user_id in the runtime context.
|
|
502
502
|
|
|
503
503
|
Returns:
|
|
504
|
-
UserIdValidationMiddleware instance
|
|
504
|
+
List containing UserIdValidationMiddleware instance
|
|
505
505
|
|
|
506
506
|
Example:
|
|
507
507
|
middleware = create_user_id_validation_middleware()
|
|
@@ -518,7 +518,7 @@ def create_thread_id_validation_middleware() -> ThreadIdValidationMiddleware:
|
|
|
518
518
|
of thread_id in the runtime context.
|
|
519
519
|
|
|
520
520
|
Returns:
|
|
521
|
-
ThreadIdValidationMiddleware instance
|
|
521
|
+
List containing ThreadIdValidationMiddleware instance
|
|
522
522
|
|
|
523
523
|
Example:
|
|
524
524
|
middleware = create_thread_id_validation_middleware()
|
|
@@ -550,7 +550,7 @@ def create_custom_field_validation_middleware(
|
|
|
550
550
|
optionally 'description', 'required', and 'example_value' keys.
|
|
551
551
|
|
|
552
552
|
Returns:
|
|
553
|
-
CustomFieldValidationMiddleware configured with the specified fields
|
|
553
|
+
List containing CustomFieldValidationMiddleware configured with the specified fields
|
|
554
554
|
|
|
555
555
|
Example:
|
|
556
556
|
middleware = create_custom_field_validation_middleware(
|
|
@@ -577,7 +577,7 @@ def create_filter_last_human_message_middleware() -> FilterLastHumanMessageMiddl
|
|
|
577
577
|
process only the latest user input without conversation history.
|
|
578
578
|
|
|
579
579
|
Returns:
|
|
580
|
-
FilterLastHumanMessageMiddleware instance
|
|
580
|
+
List containing FilterLastHumanMessageMiddleware instance
|
|
581
581
|
|
|
582
582
|
Example:
|
|
583
583
|
middleware = create_filter_last_human_message_middleware()
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Model call limit middleware for DAO AI agents.
|
|
3
|
+
|
|
4
|
+
Limits the number of model (LLM) calls to prevent infinite loops or excessive costs.
|
|
5
|
+
|
|
6
|
+
Example:
|
|
7
|
+
from dao_ai.middleware import create_model_call_limit_middleware
|
|
8
|
+
|
|
9
|
+
# Limit model calls per run and thread
|
|
10
|
+
middleware = create_model_call_limit_middleware(
|
|
11
|
+
thread_limit=10,
|
|
12
|
+
run_limit=5,
|
|
13
|
+
)
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
from typing import Literal
|
|
19
|
+
|
|
20
|
+
from langchain.agents.middleware import ModelCallLimitMiddleware
|
|
21
|
+
from loguru import logger
|
|
22
|
+
|
|
23
|
+
__all__ = [
|
|
24
|
+
"ModelCallLimitMiddleware",
|
|
25
|
+
"create_model_call_limit_middleware",
|
|
26
|
+
]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def create_model_call_limit_middleware(
    thread_limit: int | None = None,
    run_limit: int | None = None,
    exit_behavior: Literal["error", "end"] = "end",
) -> ModelCallLimitMiddleware:
    """
    Create a ModelCallLimitMiddleware to limit LLM API calls.

    Prevents runaway agents from making too many API calls and helps
    enforce cost controls on production deployments.

    Args:
        thread_limit: Max model calls per thread (conversation).
            Requires checkpointer. None = no limit.
        run_limit: Max model calls per run (single invocation).
            None = no limit.
        exit_behavior: What to do when limit hit:
            - "end": Stop execution gracefully (default)
            - "error": Raise ModelCallLimitExceededError immediately

    Returns:
        ModelCallLimitMiddleware instance configured with the given limits.

    Raises:
        ValueError: If neither thread_limit nor run_limit is specified.

    Example:
        # Limit to 5 model calls per run, 10 per thread
        limiter = create_model_call_limit_middleware(
            run_limit=5,
            thread_limit=10,
            exit_behavior="end",
        )
    """
    # A limiter with no limits would silently do nothing; fail fast instead.
    if thread_limit is None and run_limit is None:
        raise ValueError("At least one of thread_limit or run_limit must be specified.")

    logger.debug(
        "Creating model call limit middleware",
        thread_limit=thread_limit,
        run_limit=run_limit,
        exit_behavior=exit_behavior,
    )

    return ModelCallLimitMiddleware(
        thread_limit=thread_limit,
        run_limit=run_limit,
        exit_behavior=exit_behavior,
    )
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Model retry middleware for DAO AI agents.
|
|
3
|
+
|
|
4
|
+
Automatically retries failed model (LLM) calls with configurable exponential backoff.
|
|
5
|
+
|
|
6
|
+
Example:
|
|
7
|
+
from dao_ai.middleware import create_model_retry_middleware
|
|
8
|
+
|
|
9
|
+
# Retry failed model calls with exponential backoff
|
|
10
|
+
middleware = create_model_retry_middleware(
|
|
11
|
+
max_retries=3,
|
|
12
|
+
backoff_factor=2.0,
|
|
13
|
+
initial_delay=1.0,
|
|
14
|
+
)
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from typing import Any, Callable, Literal
|
|
20
|
+
|
|
21
|
+
from langchain.agents.middleware import ModelRetryMiddleware
|
|
22
|
+
from loguru import logger
|
|
23
|
+
|
|
24
|
+
__all__ = [
|
|
25
|
+
"ModelRetryMiddleware",
|
|
26
|
+
"create_model_retry_middleware",
|
|
27
|
+
]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def create_model_retry_middleware(
    max_retries: int = 3,
    backoff_factor: float = 2.0,
    initial_delay: float = 1.0,
    max_delay: float | None = None,
    jitter: bool = False,
    retry_on: tuple[type[Exception], ...] | Callable[[Exception], bool] | None = None,
    on_failure: Literal["continue", "error"] | Callable[[Exception], str] = "continue",
) -> ModelRetryMiddleware:
    """
    Create a ModelRetryMiddleware for automatic model call retries.

    Handles transient failures in model API calls with exponential backoff.
    Useful for handling rate limits, network issues, and temporary outages.

    Args:
        max_retries: Max retry attempts after initial call. Default 3.
        backoff_factor: Multiplier for exponential backoff. Default 2.0.
            Delay = initial_delay * (backoff_factor ** retry_number)
            Set to 0.0 for constant delay.
        initial_delay: Initial delay in seconds before first retry. Default 1.0.
        max_delay: Max delay in seconds (caps exponential growth). None = no cap.
        jitter: Add ±25% random jitter to avoid thundering herd. Default False.
        retry_on: When to retry:
            - None: Retry on all errors (default)
            - tuple of Exception types: Retry only on these
            - callable: Function(exception) -> bool for custom logic
        on_failure: Behavior when all retries exhausted:
            - "continue": Return AIMessage with error, let agent continue (default)
            - "error": Re-raise exception, stop execution
            - callable: Function(exception) -> str for custom error message

    Returns:
        ModelRetryMiddleware instance configured with the given retry policy.

    Example:
        # Basic retry with defaults
        retry = create_model_retry_middleware()

        # Custom backoff for rate limits
        retry = create_model_retry_middleware(
            max_retries=5,
            backoff_factor=2.0,
            initial_delay=1.0,
            max_delay=60.0,
            jitter=True,
        )

        # Retry only on specific exceptions, fail hard
        retry = create_model_retry_middleware(
            max_retries=3,
            retry_on=(RateLimitError, TimeoutError),
            on_failure="error",
        )

        # Custom retry logic
        def should_retry(error: Exception) -> bool:
            return "rate_limit" in str(error).lower()

        retry = create_model_retry_middleware(
            max_retries=5,
            retry_on=should_retry,
        )
    """
    logger.debug(
        "Creating model retry middleware",
        max_retries=max_retries,
        backoff_factor=backoff_factor,
        initial_delay=initial_delay,
        max_delay=max_delay,
        jitter=jitter,
        # Callables have no stable string form; log a marker instead.
        on_failure=on_failure if isinstance(on_failure, str) else "custom",
    )

    # Only forward optional settings when explicitly provided, so the
    # middleware's own defaults apply otherwise.
    kwargs: dict[str, Any] = {
        "max_retries": max_retries,
        "backoff_factor": backoff_factor,
        "initial_delay": initial_delay,
        "on_failure": on_failure,
    }

    if max_delay is not None:
        kwargs["max_delay"] = max_delay

    if jitter:
        kwargs["jitter"] = jitter

    if retry_on is not None:
        kwargs["retry_on"] = retry_on

    return ModelRetryMiddleware(**kwargs)
dao_ai/middleware/pii.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
"""
|
|
2
|
+
PII detection middleware for DAO AI agents.
|
|
3
|
+
|
|
4
|
+
Detects and handles Personally Identifiable Information (PII) in conversations
|
|
5
|
+
using configurable strategies (redact, mask, hash, block).
|
|
6
|
+
|
|
7
|
+
Example:
|
|
8
|
+
from dao_ai.middleware import create_pii_middleware
|
|
9
|
+
|
|
10
|
+
# Redact emails in user input
|
|
11
|
+
middleware = create_pii_middleware(
|
|
12
|
+
pii_type="email",
|
|
13
|
+
strategy="redact",
|
|
14
|
+
apply_to_input=True,
|
|
15
|
+
)
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
from typing import Any, Callable, Literal, Pattern
|
|
21
|
+
|
|
22
|
+
from langchain.agents.middleware import PIIMiddleware
|
|
23
|
+
from loguru import logger
|
|
24
|
+
|
|
25
|
+
__all__ = [
|
|
26
|
+
"PIIMiddleware",
|
|
27
|
+
"create_pii_middleware",
|
|
28
|
+
]
|
|
29
|
+
|
|
30
|
+
# Type alias for PII detector
|
|
31
|
+
PIIDetector = str | Pattern[str] | Callable[[str], list[dict[str, str | int]]]
|
|
32
|
+
|
|
33
|
+
# Built-in PII types
|
|
34
|
+
BUILTIN_PII_TYPES = frozenset({"email", "credit_card", "ip", "mac_address", "url"})
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def create_pii_middleware(
|
|
38
|
+
pii_type: str,
|
|
39
|
+
strategy: Literal["redact", "mask", "hash", "block"] = "redact",
|
|
40
|
+
detector: PIIDetector | None = None,
|
|
41
|
+
apply_to_input: bool = True,
|
|
42
|
+
apply_to_output: bool = False,
|
|
43
|
+
apply_to_tool_results: bool = False,
|
|
44
|
+
) -> PIIMiddleware:
|
|
45
|
+
"""
|
|
46
|
+
Create a PIIMiddleware for detecting and handling PII.
|
|
47
|
+
|
|
48
|
+
Detects Personally Identifiable Information in conversations and handles
|
|
49
|
+
it according to the specified strategy. Useful for compliance, privacy,
|
|
50
|
+
and sanitizing logs.
|
|
51
|
+
|
|
52
|
+
Built-in PII types:
|
|
53
|
+
- email: Email addresses
|
|
54
|
+
- credit_card: Credit card numbers (Luhn validated)
|
|
55
|
+
- ip: IP addresses
|
|
56
|
+
- mac_address: MAC addresses
|
|
57
|
+
- url: URLs
|
|
58
|
+
|
|
59
|
+
Args:
|
|
60
|
+
pii_type: Type of PII to detect. Use built-in types (email, credit_card,
|
|
61
|
+
ip, mac_address, url) or custom type names with a detector.
|
|
62
|
+
strategy: How to handle detected PII:
|
|
63
|
+
- "redact": Replace with [REDACTED_{TYPE}] (default)
|
|
64
|
+
- "mask": Partially obscure (e.g., ****-****-****-1234)
|
|
65
|
+
- "hash": Replace with deterministic hash
|
|
66
|
+
- "block": Raise exception when detected
|
|
67
|
+
detector: Custom detector for non-built-in types. Can be:
|
|
68
|
+
- str: Regex pattern string
|
|
69
|
+
- re.Pattern: Compiled regex pattern
|
|
70
|
+
- Callable: Function(content: str) -> list[dict] with keys:
|
|
71
|
+
- text: The matched text
|
|
72
|
+
- start: Start index
|
|
73
|
+
- end: End index
|
|
74
|
+
Default None (uses built-in detector for built-in types).
|
|
75
|
+
apply_to_input: Check user messages before model call. Default True.
|
|
76
|
+
apply_to_output: Check AI messages after model call. Default False.
|
|
77
|
+
apply_to_tool_results: Check tool results after execution. Default False.
|
|
78
|
+
|
|
79
|
+
Returns:
|
|
80
|
+
List containing PIIMiddleware instance
|
|
81
|
+
|
|
82
|
+
Raises:
|
|
83
|
+
ValueError: If custom pii_type without detector, or invalid strategy
|
|
84
|
+
|
|
85
|
+
Example:
|
|
86
|
+
# Redact emails in input
|
|
87
|
+
email_redactor = create_pii_middleware(
|
|
88
|
+
pii_type="email",
|
|
89
|
+
strategy="redact",
|
|
90
|
+
apply_to_input=True,
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
# Mask credit cards
|
|
94
|
+
card_masker = create_pii_middleware(
|
|
95
|
+
pii_type="credit_card",
|
|
96
|
+
strategy="mask",
|
|
97
|
+
apply_to_input=True,
|
|
98
|
+
apply_to_output=True,
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
# Block API keys with custom regex
|
|
102
|
+
api_key_blocker = create_pii_middleware(
|
|
103
|
+
pii_type="api_key",
|
|
104
|
+
detector=r"sk-[a-zA-Z0-9]{32}",
|
|
105
|
+
strategy="block",
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
# Custom SSN detector with validation
|
|
109
|
+
def detect_ssn(content: str) -> list[dict]:
|
|
110
|
+
matches = []
|
|
111
|
+
pattern = r"\\d{3}-\\d{2}-\\d{4}"
|
|
112
|
+
for match in re.finditer(pattern, content):
|
|
113
|
+
ssn = match.group(0)
|
|
114
|
+
first_three = int(ssn[:3])
|
|
115
|
+
if first_three not in [0, 666] and not (900 <= first_three <= 999):
|
|
116
|
+
matches.append({
|
|
117
|
+
"text": ssn,
|
|
118
|
+
"start": match.start(),
|
|
119
|
+
"end": match.end(),
|
|
120
|
+
})
|
|
121
|
+
return matches
|
|
122
|
+
|
|
123
|
+
ssn_hasher = create_pii_middleware(
|
|
124
|
+
pii_type="ssn",
|
|
125
|
+
detector=detect_ssn,
|
|
126
|
+
strategy="hash",
|
|
127
|
+
)
|
|
128
|
+
"""
|
|
129
|
+
# Validate: custom types require detector
|
|
130
|
+
if pii_type not in BUILTIN_PII_TYPES and detector is None:
|
|
131
|
+
raise ValueError(
|
|
132
|
+
f"Custom PII type '{pii_type}' requires a detector. "
|
|
133
|
+
f"Built-in types are: {', '.join(sorted(BUILTIN_PII_TYPES))}"
|
|
134
|
+
)
|
|
135
|
+
|
|
136
|
+
logger.debug(
|
|
137
|
+
"Creating PII middleware",
|
|
138
|
+
pii_type=pii_type,
|
|
139
|
+
strategy=strategy,
|
|
140
|
+
has_custom_detector=detector is not None,
|
|
141
|
+
apply_to_input=apply_to_input,
|
|
142
|
+
apply_to_output=apply_to_output,
|
|
143
|
+
apply_to_tool_results=apply_to_tool_results,
|
|
144
|
+
)
|
|
145
|
+
|
|
146
|
+
# Build kwargs
|
|
147
|
+
kwargs: dict[str, Any] = {
|
|
148
|
+
"strategy": strategy,
|
|
149
|
+
"apply_to_input": apply_to_input,
|
|
150
|
+
"apply_to_output": apply_to_output,
|
|
151
|
+
"apply_to_tool_results": apply_to_tool_results,
|
|
152
|
+
}
|
|
153
|
+
|
|
154
|
+
if detector is not None:
|
|
155
|
+
kwargs["detector"] = detector
|
|
156
|
+
|
|
157
|
+
return PIIMiddleware(pii_type, **kwargs)
|