dao-ai 0.0.28__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +2 -5
- dao_ai/cli.py +342 -58
- dao_ai/config.py +1610 -380
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +158 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/context_editing.py +230 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +233 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/model_call_limit.py +77 -0
- dao_ai/middleware/model_retry.py +121 -0
- dao_ai/middleware/pii.py +157 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/middleware/tool_call_limit.py +210 -0
- dao_ai/middleware/tool_retry.py +174 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +240 -161
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +279 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +584 -601
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -294
- dao_ai/tools/mcp.py +223 -155
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +331 -221
- dao_ai/utils.py +166 -20
- dao_ai/vector_search.py +37 -0
- dao_ai-0.1.5.dist-info/METADATA +489 -0
- dao_ai-0.1.5.dist-info/RECORD +70 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.28.dist-info/METADATA +0 -1168
- dao_ai-0.0.28.dist-info/RECORD +0 -41
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,586 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Message validation middleware for DAO AI agents.
|
|
3
|
+
|
|
4
|
+
These middleware implementations validate incoming messages and context
|
|
5
|
+
before agent processing begins.
|
|
6
|
+
|
|
7
|
+
Factory functions are provided for consistent configuration via the
|
|
8
|
+
DAO AI middleware factory pattern.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import json
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
from langchain.agents.middleware import hook_config
|
|
15
|
+
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, RemoveMessage
|
|
16
|
+
from langgraph.runtime import Runtime
|
|
17
|
+
from loguru import logger
|
|
18
|
+
|
|
19
|
+
from dao_ai.messages import last_human_message
|
|
20
|
+
from dao_ai.middleware.base import AgentMiddleware
|
|
21
|
+
from dao_ai.state import AgentState, Context
|
|
22
|
+
|
|
23
|
+
# Public API of this module: the validation middleware classes, the field
# definition helper, and their factory functions.
__all__ = [
    # Middleware classes and supporting types
    "MessageValidationMiddleware",
    "UserIdValidationMiddleware",
    "ThreadIdValidationMiddleware",
    "CustomFieldValidationMiddleware",
    "RequiredField",
    "FilterLastHumanMessageMiddleware",
    # Factory functions
    "create_user_id_validation_middleware",
    "create_thread_id_validation_middleware",
    "create_custom_field_validation_middleware",
    "create_filter_last_human_message_middleware",
]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class MessageValidationMiddleware(AgentMiddleware[AgentState, Context]):
    """
    Common base class for middleware that checks a request before the agent runs.

    The ``before_agent`` hook delegates to :meth:`validate`. A subclass signals
    a failed check by raising ``ValueError``; the hook then returns a state
    update carrying the error text as an ``AIMessage`` plus ``jump_to="end"``.
    """

    @hook_config(can_jump_to=["end"])
    def before_agent(
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """Run :meth:`validate`, converting a ``ValueError`` into an early exit."""
        try:
            update = self.validate(state, runtime)
        except ValueError as exc:
            reason = str(exc)
            logger.error("Message validation failed", error=reason)
            # Surface the reason to the caller as an assistant message and
            # request a jump to the end of the graph.
            return {
                "is_valid": False,
                "message_error": reason,
                "messages": [AIMessage(content=reason)],
                "jump_to": "end",
            }
        return update

    def validate(
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """
        Validation hook for subclasses; the base implementation accepts everything.

        Args:
            state: The current agent state.
            runtime: The LangGraph runtime; ``runtime.context`` holds request data.

        Returns:
            An optional dict of state updates, or ``None`` for no changes.

        Raises:
            ValueError: Raised by overriding implementations to reject the
                request; the message becomes the user-facing error.
        """
        return None
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class UserIdValidationMiddleware(MessageValidationMiddleware):
    """
    Middleware that validates the presence and format of user_id.

    Ensures that:
    - user_id is provided (and non-empty) in the runtime context
    - user_id does not contain a dot character ('.')

    On failure ``validate`` raises ``ValueError`` with a Markdown message that
    embeds a copy-pasteable JSON configuration; the inherited ``before_agent``
    turns that into an early exit.
    """

    @staticmethod
    def _corrected_config_json(context: Context, user_id_value: str) -> str:
        """Render the example request configuration shown in error messages.

        Args:
            context: The runtime context; its thread_id and any other non-None
                fields are carried over into the example.
            user_id_value: The user_id value (placeholder or suggestion) to show.

        Returns:
            The example configuration as indented JSON text.
        """
        thread_val = context.thread_id or "<your_thread_id>"
        # Keep extra fields the caller already provided (excluding user_id and
        # thread_id, which are set explicitly below).
        extra_fields = {
            k: v
            for k, v in context.model_dump().items()
            if k not in {"user_id", "thread_id"} and v is not None
        }
        corrected_config: dict[str, Any] = {
            "configurable": {
                "thread_id": thread_val,
                "user_id": user_id_value,
                **extra_fields,
            },
            "session": {
                "conversation_id": thread_val,
            },
        }
        # This JSON is display-only. default=str keeps a context value that json
        # cannot encode from raising TypeError, which before_agent does not catch.
        return json.dumps(corrected_config, indent=2, default=str)

    def validate(
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """Validate that user_id is present and contains no '.' characters.

        Args:
            state: The current agent state (not used by this check).
            runtime: The LangGraph runtime; ``runtime.context`` supplies user_id.

        Returns:
            None when the user_id is acceptable.

        Raises:
            ValueError: If user_id is missing/empty or contains a '.'.
        """
        logger.trace("Executing user_id validation")

        context: Context = runtime.context or Context()
        user_id: str | None = context.user_id

        if not user_id:
            logger.error("User ID is required but not provided in configuration")

            corrected_config_json = self._corrected_config_json(
                context, "<your_user_id>"
            )

            error_message = f"""
## Authentication Required

A **user_id** is required to process your request. Please provide your user ID in the configuration.

### Required Configuration Format

Please include the following JSON in your request configuration:

```json
{corrected_config_json}
```

### Field Descriptions
- **thread_id**: Thread identifier (required in configurable)
- **conversation_id**: Alias of thread_id (in session)
- **user_id**: Your unique user identifier (required)

Please update your configuration and try again.
""".strip()

            raise ValueError(error_message)

        if "." in user_id:
            logger.error("User ID contains invalid character '.'", user_id=user_id)

            # Suggest the same id with dots replaced so it can be pasted back.
            corrected_config_json = self._corrected_config_json(
                context, user_id.replace(".", "_")
            )

            error_message = f"""
## Invalid User ID Format

The **user_id** cannot contain a dot character ('.'). Please provide a valid user ID without dots.

### Corrected Configuration (Copy & Paste This)
```json
{corrected_config_json}
```

Please update your user_id and try again.
""".strip()

            raise ValueError(error_message)

        return None
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class ThreadIdValidationMiddleware(MessageValidationMiddleware):
    """
    Middleware that validates the presence of thread_id.

    Only ``context.thread_id`` is checked. The error message documents
    ``conversation_id`` (under ``session``) as an alias of ``thread_id``.
    """

    def validate(
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """Validate that thread_id is present in the runtime context.

        Args:
            state: The current agent state (not used by this check).
            runtime: The LangGraph runtime; ``runtime.context`` supplies thread_id.

        Returns:
            None when a non-empty thread_id is present.

        Raises:
            ValueError: If thread_id is missing or empty.
        """
        logger.trace("Executing thread_id/conversation_id validation")

        context: Context = runtime.context or Context()
        thread_id: str | None = context.thread_id

        if thread_id:
            return None

        logger.error("Thread ID / Conversation ID is required but not provided")

        # Keep extra fields the caller already provided (excluding user_id and
        # thread_id) so the example can be pasted back verbatim.
        extra_fields = {
            k: v
            for k, v in context.model_dump().items()
            if k not in {"user_id", "thread_id"} and v is not None
        }

        corrected_config: dict[str, Any] = {
            "configurable": {
                "thread_id": "<your_thread_id>",
                "user_id": context.user_id or "<your_user_id>",
                **extra_fields,
            },
            "session": {
                "conversation_id": "<your_thread_id>",
            },
        }
        # Display-only JSON: default=str keeps a value json cannot encode from
        # raising TypeError, which before_agent does not catch.
        corrected_config_json = json.dumps(corrected_config, indent=2, default=str)

        error_message = f"""
## Configuration Required

A **thread_id** is required to process your request (or **conversation_id** as an alias).

### Required Configuration Format

Please include the following JSON in your request configuration:

```json
{corrected_config_json}
```

### Field Descriptions
- **thread_id**: Thread identifier (required in configurable)
- **conversation_id**: Alias of thread_id (in session)
- **user_id**: Your unique user identifier (required)

Please update your configuration and try again.
""".strip()

        raise ValueError(error_message)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
class RequiredField:
|
|
257
|
+
"""Definition of a field for validation.
|
|
258
|
+
|
|
259
|
+
Fields are marked as required or optional via the `required` flag:
|
|
260
|
+
- required=True (default): Field must be provided, validated
|
|
261
|
+
- required=False: Field is optional, not validated
|
|
262
|
+
|
|
263
|
+
For required fields, an `example_value` can be provided to show in error
|
|
264
|
+
messages, making it easy for users to copy-paste the configuration.
|
|
265
|
+
|
|
266
|
+
Args:
|
|
267
|
+
name: The field name (e.g., "store_num", "user_id")
|
|
268
|
+
description: Human-readable description for error messages
|
|
269
|
+
required: Whether this field is required (default: True)
|
|
270
|
+
example_value: Example value to show in error messages for missing fields
|
|
271
|
+
"""
|
|
272
|
+
|
|
273
|
+
def __init__(
|
|
274
|
+
self,
|
|
275
|
+
name: str,
|
|
276
|
+
description: str | None = None,
|
|
277
|
+
required: bool = True,
|
|
278
|
+
example_value: Any = None,
|
|
279
|
+
):
|
|
280
|
+
self.name = name
|
|
281
|
+
self.description = description or f"Your {name}"
|
|
282
|
+
self.required = required
|
|
283
|
+
self.example_value = example_value
|
|
284
|
+
|
|
285
|
+
@property
|
|
286
|
+
def is_required(self) -> bool:
|
|
287
|
+
"""A field is required based on the required flag."""
|
|
288
|
+
return self.required
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
class CustomFieldValidationMiddleware(MessageValidationMiddleware):
    """
    Middleware that validates the presence of required custom fields.

    This is a generic validation middleware that can check for multiple
    required fields in the context object.

    Fields are defined in the `fields` list. Each field can have:
    - name: The field name (required)
    - description: Human-readable description for error messages
    - required: Whether field is required (default: True)
    - example_value: Example value to show in error messages

    Required fields (required=True) will be validated.
    The example_value is used in error messages to help users copy-paste
    the correct configuration format.

    Args:
        fields: List of fields to validate/show. Each can be a RequiredField
            or a dict with 'name', 'description', 'required', and 'example_value' keys.

    Raises:
        TypeError: If a field definition is neither a RequiredField nor a dict.
    """

    def __init__(
        self,
        fields: list[RequiredField | dict[str, Any]],
    ):
        super().__init__()

        # Normalize every definition to a RequiredField.
        self.fields: list[RequiredField] = []
        for field in fields:
            if isinstance(field, RequiredField):
                self.fields.append(field)
            elif isinstance(field, dict):
                self.fields.append(RequiredField(**field))
            else:
                # Dropping an unknown definition would silently disable its
                # validation; fail loudly at construction time instead.
                raise TypeError(
                    "Field definitions must be RequiredField instances or dicts, "
                    f"got {type(field).__name__}"
                )

    def _build_corrected_config(self, context: Context) -> dict[str, Any]:
        """Build the example request configuration shown in the error message.

        Values already present in the context are preserved. For every defined
        field that is absent from the context, its ``example_value`` is used
        when set; otherwise a ``<your_<name>>`` placeholder is used if the field
        is required. This includes fields named ``user_id`` or ``thread_id``.
        """
        configurable: dict[str, Any] = {
            "thread_id": context.thread_id or "<your_thread_id>",
            "user_id": context.user_id or "<your_user_id>",
        }

        # Carry over every other value the caller already provided.
        for key, value in context.model_dump().items():
            if key not in {"user_id", "thread_id"} and value is not None:
                configurable[key] = value

        for field in self.fields:
            if getattr(context, field.name, None) is not None:
                # Provided by the caller - keep their value.
                continue
            if field.example_value is not None:
                configurable[field.name] = field.example_value
            elif field.is_required:
                configurable[field.name] = f"<your_{field.name}>"

        return {
            "configurable": configurable,
            # conversation_id is an alias of thread_id, so show the same value.
            "session": {
                "conversation_id": configurable["thread_id"],
            },
        }

    def _build_field_descriptions(self) -> list[str]:
        """Build the Markdown bullet list describing each configuration field."""
        custom_names = {field.name for field in self.fields}
        descriptions: list[str] = []

        # Built-in descriptions are skipped when a custom field of the same
        # name supplies its own, so no field is listed twice.
        if "thread_id" not in custom_names:
            descriptions.append(
                "- **thread_id**: Thread identifier (required in configurable)"
            )
        if "conversation_id" not in custom_names:
            descriptions.append(
                "- **conversation_id**: Alias of thread_id (in session)"
            )
        if "user_id" not in custom_names:
            descriptions.append(
                "- **user_id**: Your unique user identifier (required)"
            )

        for field in self.fields:
            required_text = "(required)" if field.is_required else "(optional)"
            descriptions.append(
                f"- **{field.name}**: {field.description} {required_text}"
            )
        return descriptions

    def validate(
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """Validate that all required fields are present in the runtime context.

        A field counts as missing when ``getattr(context, name, None)`` is None.
        On failure, the error message contains a copy-pasteable configuration of
        the form::

            {
              "configurable": {
                "thread_id": "<your_thread_id>",
                "user_id": "<your_user_id>",
                "<field_name>": <example_value>
              },
              "session": {
                "conversation_id": "<your_thread_id>"
              }
            }

        Args:
            state: The current agent state (not used by this check).
            runtime: The LangGraph runtime; ``runtime.context`` holds the fields.

        Returns:
            None when every required field is present.

        Raises:
            ValueError: If one or more required fields are missing.
        """
        logger.trace("Executing custom field validation")

        context: Context = runtime.context or Context()

        missing_fields: list[RequiredField] = [
            field
            for field in self.fields
            if field.is_required and getattr(context, field.name, None) is None
        ]
        if not missing_fields:
            return None

        missing_names = [f.name for f in missing_fields]
        logger.error("Required fields missing", fields=missing_names)

        # Display-only JSON: default=str keeps a value json cannot encode from
        # raising TypeError, which before_agent does not catch.
        corrected_config_json = json.dumps(
            self._build_corrected_config(context), indent=2, default=str
        )
        field_descriptions_text = "\n".join(self._build_field_descriptions())
        missing_names_formatted = ", ".join(f"**{name}**" for name in missing_names)

        error_message = f"""
## Configuration Required

The following required fields are missing: {missing_names_formatted}

### Required Configuration Format

Please include the following JSON in your request configuration:

```json
{corrected_config_json}
```

### Field Descriptions
{field_descriptions_text}

Please update your configuration and try again.
""".strip()

        raise ValueError(error_message)
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
class FilterLastHumanMessageMiddleware(AgentMiddleware[AgentState, Context]):
    """
    Middleware that trims the conversation to its most recent human message.

    Before each model call it returns a ``messages`` update containing a
    ``RemoveMessage`` for every message whose id differs from the id of the
    last human message. This is useful when only the latest user input should
    be processed, without the earlier conversation history.
    """

    def before_model(
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """Return removals for every message except the last human message."""
        logger.trace("Executing filter_last_human_message middleware")

        history: list[BaseMessage] = state.get("messages", [])
        if not history:
            logger.trace("No messages found in state")
            return None

        keep: HumanMessage | None = last_human_message(history)
        if keep is None:
            logger.trace("No human messages found in state")
            return {"messages": []}

        logger.trace(
            "Filtered messages to last human message", original_count=len(history)
        )

        keep_id = keep.id
        removals = [RemoveMessage(id=msg.id) for msg in history if msg.id != keep_id]
        return {"messages": removals}
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
# =============================================================================
|
|
492
|
+
# Factory Functions
|
|
493
|
+
# =============================================================================
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def create_user_id_validation_middleware() -> UserIdValidationMiddleware:
    """
    Create a UserIdValidationMiddleware instance.

    Factory function for creating middleware that validates the presence
    and format of user_id in the runtime context.

    Returns:
        A new UserIdValidationMiddleware instance (a single object, not a list).

    Example:
        middleware = create_user_id_validation_middleware()
    """
    logger.trace("Creating user_id validation middleware")
    return UserIdValidationMiddleware()
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
def create_thread_id_validation_middleware() -> ThreadIdValidationMiddleware:
    """
    Create a ThreadIdValidationMiddleware instance.

    Factory function for creating middleware that validates the presence
    of thread_id in the runtime context.

    Returns:
        A new ThreadIdValidationMiddleware instance (a single object, not a list).

    Example:
        middleware = create_thread_id_validation_middleware()
    """
    logger.trace("Creating thread_id validation middleware")
    return ThreadIdValidationMiddleware()
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
def create_custom_field_validation_middleware(
    fields: list[RequiredField | dict[str, Any]],
) -> CustomFieldValidationMiddleware:
    """
    Create a CustomFieldValidationMiddleware instance.

    Factory function for creating middleware that validates the presence
    of required custom fields in the context object.

    Each field definition (a dict or a RequiredField) has:
    - name: The field name (required)
    - description: Human-readable description for error messages (optional)
    - required: Whether field is required (default: True)
    - example_value: Example value to show in error messages (optional)

    Required fields (required=True or not specified) will be validated.
    The example_value is used in error messages to help users copy-paste.

    Args:
        fields: List of field definitions. Dicts must have 'name' and may have
            'description', 'required', and 'example_value' keys; RequiredField
            instances are also accepted.

    Returns:
        A CustomFieldValidationMiddleware configured with the specified fields
        (a single object, not a list).

    Example:
        middleware = create_custom_field_validation_middleware(
            fields=[
                # Required field with example value for easy copy-paste
                {"name": "store_num", "description": "Your store number",
                 "example_value": "12345"},
                # Optional fields (required=False)
                {"name": "thread_id", "description": "Thread ID",
                 "required": False, "example_value": "1"},
                {"name": "user_id", "description": "User ID",
                 "required": False, "example_value": "my_user_id"},
            ],
        )
    """
    # Collect names for logging only. This must not crash on RequiredField
    # objects, which have no .get(); anything unrecognized is logged as
    # "unknown" and left for the middleware to handle.
    field_names: list[str] = []
    for field in fields:
        if isinstance(field, RequiredField):
            field_names.append(field.name)
        elif isinstance(field, dict):
            field_names.append(field.get("name", "unknown"))
        else:
            field_names.append("unknown")

    logger.trace("Creating custom field validation middleware", fields=field_names)
    return CustomFieldValidationMiddleware(fields=fields)
|
|
569
|
+
|
|
570
|
+
|
|
571
|
+
def create_filter_last_human_message_middleware() -> FilterLastHumanMessageMiddleware:
    """
    Create a FilterLastHumanMessageMiddleware instance.

    Factory function for creating middleware that filters messages to keep
    only the last human message. This is useful when only the latest user
    input should be processed, without the conversation history.

    Returns:
        A new FilterLastHumanMessageMiddleware instance (a single object, not a list).

    Example:
        middleware = create_filter_last_human_message_middleware()
    """
    logger.trace("Creating filter_last_human_message middleware")
    return FilterLastHumanMessageMiddleware()
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Model call limit middleware for DAO AI agents.
|
|
3
|
+
|
|
4
|
+
Limits the number of model (LLM) calls to prevent infinite loops or excessive costs.
|
|
5
|
+
|
|
6
|
+
Example:
|
|
7
|
+
from dao_ai.middleware import create_model_call_limit_middleware
|
|
8
|
+
|
|
9
|
+
# Limit model calls per run and thread
|
|
10
|
+
middleware = create_model_call_limit_middleware(
|
|
11
|
+
thread_limit=10,
|
|
12
|
+
run_limit=5,
|
|
13
|
+
)
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
from typing import Literal
|
|
19
|
+
|
|
20
|
+
from langchain.agents.middleware import ModelCallLimitMiddleware
|
|
21
|
+
from loguru import logger
|
|
22
|
+
|
|
23
|
+
# Public API: the upstream LangChain middleware class (re-exported for
# convenience) and the DAO AI factory that configures it.
__all__ = [
    "ModelCallLimitMiddleware",
    "create_model_call_limit_middleware",
]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def create_model_call_limit_middleware(
    thread_limit: int | None = None,
    run_limit: int | None = None,
    exit_behavior: Literal["error", "end"] = "end",
) -> ModelCallLimitMiddleware:
    """
    Create a ModelCallLimitMiddleware to limit LLM API calls.

    Prevents runaway agents from making too many API calls and helps
    enforce cost controls on production deployments.

    Args:
        thread_limit: Max model calls per thread (conversation).
            Requires checkpointer. None = no limit.
        run_limit: Max model calls per run (single invocation).
            None = no limit.
        exit_behavior: What to do when limit hit:
            - "end": Stop execution gracefully (default)
            - "error": Raise ModelCallLimitExceededError immediately

    Returns:
        A configured ModelCallLimitMiddleware instance (a single object, not a list).

    Raises:
        ValueError: If no limit is specified, or if a specified limit is negative.

    Example:
        # Limit to 5 model calls per run, 10 per thread
        limiter = create_model_call_limit_middleware(
            run_limit=5,
            thread_limit=10,
            exit_behavior="end",
        )
    """
    if thread_limit is None and run_limit is None:
        raise ValueError("At least one of thread_limit or run_limit must be specified.")

    # A negative call budget is a configuration error; reject it here rather
    # than passing a nonsensical limit to the underlying middleware.
    for limit_name, limit in (("thread_limit", thread_limit), ("run_limit", run_limit)):
        if limit is not None and limit < 0:
            raise ValueError(f"{limit_name} must be non-negative, got {limit}.")

    logger.debug(
        "Creating model call limit middleware",
        thread_limit=thread_limit,
        run_limit=run_limit,
        exit_behavior=exit_behavior,
    )

    return ModelCallLimitMiddleware(
        thread_limit=thread_limit,
        run_limit=run_limit,
        exit_behavior=exit_behavior,
    )
|