dao-ai 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/agent_as_code.py +2 -5
- dao_ai/cli.py +65 -15
- dao_ai/config.py +672 -218
- dao_ai/genie/cache/core.py +6 -2
- dao_ai/genie/cache/lru.py +29 -11
- dao_ai/genie/cache/semantic.py +95 -44
- dao_ai/hooks/core.py +5 -5
- dao_ai/logging.py +56 -0
- dao_ai/memory/core.py +61 -44
- dao_ai/memory/databricks.py +54 -41
- dao_ai/memory/postgres.py +77 -36
- dao_ai/middleware/assertions.py +45 -17
- dao_ai/middleware/core.py +13 -7
- dao_ai/middleware/guardrails.py +30 -25
- dao_ai/middleware/human_in_the_loop.py +9 -5
- dao_ai/middleware/message_validation.py +61 -29
- dao_ai/middleware/summarization.py +16 -11
- dao_ai/models.py +172 -69
- dao_ai/nodes.py +148 -19
- dao_ai/optimization.py +26 -16
- dao_ai/orchestration/core.py +15 -8
- dao_ai/orchestration/supervisor.py +22 -8
- dao_ai/orchestration/swarm.py +57 -12
- dao_ai/prompts.py +17 -17
- dao_ai/providers/databricks.py +365 -155
- dao_ai/state.py +24 -6
- dao_ai/tools/__init__.py +2 -0
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +7 -7
- dao_ai/tools/email.py +29 -77
- dao_ai/tools/genie.py +18 -13
- dao_ai/tools/mcp.py +223 -156
- dao_ai/tools/python.py +5 -2
- dao_ai/tools/search.py +1 -1
- dao_ai/tools/slack.py +21 -9
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +129 -86
- dao_ai/tools/vector_search.py +318 -244
- dao_ai/utils.py +15 -10
- dao_ai-0.1.3.dist-info/METADATA +455 -0
- dao_ai-0.1.3.dist-info/RECORD +64 -0
- dao_ai-0.1.1.dist-info/METADATA +0 -1878
- dao_ai-0.1.1.dist-info/RECORD +0 -62
- {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/licenses/LICENSE +0 -0
dao_ai/middleware/core.py
CHANGED
|
@@ -7,15 +7,17 @@ from fully qualified function names.
|
|
|
7
7
|
|
|
8
8
|
from typing import Any, Callable
|
|
9
9
|
|
|
10
|
+
from langchain.agents.middleware import AgentMiddleware
|
|
10
11
|
from loguru import logger
|
|
11
12
|
|
|
13
|
+
from dao_ai.state import AgentState, Context
|
|
12
14
|
from dao_ai.utils import load_function
|
|
13
15
|
|
|
14
16
|
|
|
15
17
|
def create_factory_middleware(
|
|
16
18
|
function_name: str,
|
|
17
19
|
args: dict[str, Any] | None = None,
|
|
18
|
-
) ->
|
|
20
|
+
) -> AgentMiddleware[AgentState, Context]:
|
|
19
21
|
"""
|
|
20
22
|
Create middleware from a factory function.
|
|
21
23
|
|
|
@@ -33,14 +35,14 @@ def create_factory_middleware(
|
|
|
33
35
|
args: Arguments to pass to the factory function
|
|
34
36
|
|
|
35
37
|
Returns:
|
|
36
|
-
|
|
38
|
+
An AgentMiddleware instance returned by the factory function
|
|
37
39
|
|
|
38
40
|
Raises:
|
|
39
41
|
ImportError: If the function cannot be loaded
|
|
40
42
|
|
|
41
43
|
Example:
|
|
42
44
|
# Factory function in my_module.py:
|
|
43
|
-
def create_custom_middleware(threshold: float = 0.5) -> AgentMiddleware:
|
|
45
|
+
def create_custom_middleware(threshold: float = 0.5) -> AgentMiddleware[AgentState, Context]:
|
|
44
46
|
return MyCustomMiddleware(threshold=threshold)
|
|
45
47
|
|
|
46
48
|
# Usage:
|
|
@@ -52,10 +54,14 @@ def create_factory_middleware(
|
|
|
52
54
|
if args is None:
|
|
53
55
|
args = {}
|
|
54
56
|
|
|
55
|
-
logger.
|
|
57
|
+
logger.trace("Creating factory middleware", function_name=function_name, args=args)
|
|
56
58
|
|
|
57
|
-
factory: Callable[...,
|
|
58
|
-
|
|
59
|
+
factory: Callable[..., AgentMiddleware[AgentState, Context]] = load_function(
|
|
60
|
+
function_name=function_name
|
|
61
|
+
)
|
|
62
|
+
middleware: AgentMiddleware[AgentState, Context] = factory(**args)
|
|
59
63
|
|
|
60
|
-
logger.
|
|
64
|
+
logger.trace(
|
|
65
|
+
"Created middleware from factory", middleware_type=type(middleware).__name__
|
|
66
|
+
)
|
|
61
67
|
return middleware
|
dao_ai/middleware/guardrails.py
CHANGED
|
@@ -117,29 +117,29 @@ class GuardrailMiddleware(AgentMiddleware[AgentState, Context]):
|
|
|
117
117
|
|
|
118
118
|
# Skip evaluation if the AI message has tool calls (not the final response yet)
|
|
119
119
|
if ai_message.tool_calls:
|
|
120
|
-
logger.
|
|
121
|
-
|
|
122
|
-
|
|
120
|
+
logger.trace(
|
|
121
|
+
"Guardrail skipping evaluation - AI message contains tool calls",
|
|
122
|
+
guardrail_name=self.guardrail_name,
|
|
123
123
|
)
|
|
124
124
|
return None
|
|
125
125
|
|
|
126
126
|
# Skip evaluation if the AI message has no content to evaluate
|
|
127
127
|
if not ai_message.content:
|
|
128
|
-
logger.
|
|
129
|
-
|
|
130
|
-
|
|
128
|
+
logger.trace(
|
|
129
|
+
"Guardrail skipping evaluation - AI message has no content",
|
|
130
|
+
guardrail_name=self.guardrail_name,
|
|
131
131
|
)
|
|
132
132
|
return None
|
|
133
133
|
|
|
134
|
-
logger.debug(f"Evaluating response with guardrail '{self.guardrail_name}'")
|
|
135
|
-
|
|
136
134
|
# Extract text content from messages (handles both string and structured content)
|
|
137
135
|
human_content = _extract_text_content(human_message)
|
|
138
136
|
ai_content = _extract_text_content(ai_message)
|
|
139
137
|
|
|
140
138
|
logger.debug(
|
|
141
|
-
|
|
142
|
-
|
|
139
|
+
"Evaluating response with guardrail",
|
|
140
|
+
guardrail_name=self.guardrail_name,
|
|
141
|
+
input_length=len(human_content),
|
|
142
|
+
output_length=len(ai_content),
|
|
143
143
|
)
|
|
144
144
|
|
|
145
145
|
evaluator = create_llm_as_judge(
|
|
@@ -150,8 +150,11 @@ class GuardrailMiddleware(AgentMiddleware[AgentState, Context]):
|
|
|
150
150
|
eval_result = evaluator(inputs=human_content, outputs=ai_content)
|
|
151
151
|
|
|
152
152
|
if eval_result["score"]:
|
|
153
|
-
logger.debug(
|
|
154
|
-
|
|
153
|
+
logger.debug(
|
|
154
|
+
"Response approved by guardrail",
|
|
155
|
+
guardrail_name=self.guardrail_name,
|
|
156
|
+
comment=eval_result["comment"],
|
|
157
|
+
)
|
|
155
158
|
self._retry_count = 0
|
|
156
159
|
return None
|
|
157
160
|
else:
|
|
@@ -160,10 +163,12 @@ class GuardrailMiddleware(AgentMiddleware[AgentState, Context]):
|
|
|
160
163
|
|
|
161
164
|
if self._retry_count >= self.num_retries:
|
|
162
165
|
logger.warning(
|
|
163
|
-
|
|
164
|
-
|
|
166
|
+
"Guardrail failed - max retries reached",
|
|
167
|
+
guardrail_name=self.guardrail_name,
|
|
168
|
+
retry_count=self._retry_count,
|
|
169
|
+
max_retries=self.num_retries,
|
|
170
|
+
critique=comment,
|
|
165
171
|
)
|
|
166
|
-
logger.warning(f"Final judge's critique: {comment}")
|
|
167
172
|
self._retry_count = 0
|
|
168
173
|
|
|
169
174
|
# Add system message to inform user of guardrail failure
|
|
@@ -177,10 +182,12 @@ class GuardrailMiddleware(AgentMiddleware[AgentState, Context]):
|
|
|
177
182
|
return {"messages": [AIMessage(content=failure_message)]}
|
|
178
183
|
|
|
179
184
|
logger.warning(
|
|
180
|
-
|
|
181
|
-
|
|
185
|
+
"Guardrail requested improvements",
|
|
186
|
+
guardrail_name=self.guardrail_name,
|
|
187
|
+
retry=self._retry_count,
|
|
188
|
+
max_retries=self.num_retries,
|
|
189
|
+
critique=comment,
|
|
182
190
|
)
|
|
183
|
-
logger.warning(f"Judge's critique: {comment}")
|
|
184
191
|
|
|
185
192
|
content: str = "\n".join([str(human_message.content), comment])
|
|
186
193
|
return {"messages": [HumanMessage(content=content)]}
|
|
@@ -250,9 +257,7 @@ class ContentFilterMiddleware(AgentMiddleware[AgentState, Context]):
|
|
|
250
257
|
|
|
251
258
|
for keyword in self.banned_keywords:
|
|
252
259
|
if keyword in content:
|
|
253
|
-
logger.warning(
|
|
254
|
-
f"Content filter blocked response containing '{keyword}'"
|
|
255
|
-
)
|
|
260
|
+
logger.warning("Content filter blocked response", keyword=keyword)
|
|
256
261
|
# Modify the last message content
|
|
257
262
|
last_message.content = self.block_message
|
|
258
263
|
return None
|
|
@@ -347,7 +352,7 @@ def create_guardrail_middleware(
|
|
|
347
352
|
num_retries=2,
|
|
348
353
|
)
|
|
349
354
|
"""
|
|
350
|
-
logger.
|
|
355
|
+
logger.trace("Creating guardrail middleware", guardrail_name=name)
|
|
351
356
|
return GuardrailMiddleware(
|
|
352
357
|
name=name,
|
|
353
358
|
model=model,
|
|
@@ -379,8 +384,8 @@ def create_content_filter_middleware(
|
|
|
379
384
|
block_message="I cannot discuss sensitive credentials.",
|
|
380
385
|
)
|
|
381
386
|
"""
|
|
382
|
-
logger.
|
|
383
|
-
|
|
387
|
+
logger.trace(
|
|
388
|
+
"Creating content filter middleware", keywords_count=len(banned_keywords)
|
|
384
389
|
)
|
|
385
390
|
return ContentFilterMiddleware(
|
|
386
391
|
banned_keywords=banned_keywords,
|
|
@@ -411,5 +416,5 @@ def create_safety_guardrail_middleware(
|
|
|
411
416
|
safety_model=ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct"),
|
|
412
417
|
)
|
|
413
418
|
"""
|
|
414
|
-
logger.
|
|
419
|
+
logger.trace("Creating safety guardrail middleware")
|
|
415
420
|
return SafetyGuardrailMiddleware(safety_model=safety_model)
|
dao_ai/middleware/human_in_the_loop.py
CHANGED
|
@@ -103,7 +103,10 @@ def _config_to_interrupt_on_entry(
|
|
|
103
103
|
interrupt_entry["description"] = config.review_prompt
|
|
104
104
|
return interrupt_entry
|
|
105
105
|
|
|
106
|
-
logger.warning(
|
|
106
|
+
logger.warning(
|
|
107
|
+
"Unknown HITL config type, defaulting to True",
|
|
108
|
+
config_type=type(config).__name__,
|
|
109
|
+
)
|
|
107
110
|
return True
|
|
108
111
|
|
|
109
112
|
|
|
@@ -152,8 +155,9 @@ def create_human_in_the_loop_middleware(
|
|
|
152
155
|
normalized_interrupt_on[tool_name] = config
|
|
153
156
|
|
|
154
157
|
logger.debug(
|
|
155
|
-
|
|
156
|
-
|
|
158
|
+
"Creating HITL middleware",
|
|
159
|
+
tools_count=len(normalized_interrupt_on),
|
|
160
|
+
tools=list(normalized_interrupt_on.keys()),
|
|
157
161
|
)
|
|
158
162
|
|
|
159
163
|
return HumanInTheLoopMiddleware(
|
|
@@ -216,10 +220,10 @@ def create_hitl_middleware_from_tool_models(
|
|
|
216
220
|
tool_name: str | None = getattr(func_tool, "name", None)
|
|
217
221
|
if tool_name:
|
|
218
222
|
interrupt_on[tool_name] = hitl_config
|
|
219
|
-
logger.
|
|
223
|
+
logger.trace("Tool configured for HITL", tool_name=tool_name)
|
|
220
224
|
|
|
221
225
|
if not interrupt_on:
|
|
222
|
-
logger.
|
|
226
|
+
logger.trace("No tools require HITL - returning None")
|
|
223
227
|
return None
|
|
224
228
|
|
|
225
229
|
return create_human_in_the_loop_middleware(
dao_ai/middleware/message_validation.py
CHANGED
|
|
@@ -50,7 +50,7 @@ class MessageValidationMiddleware(AgentMiddleware[AgentState, Context]):
|
|
|
50
50
|
try:
|
|
51
51
|
return self.validate(state, runtime)
|
|
52
52
|
except ValueError as e:
|
|
53
|
-
logger.error(
|
|
53
|
+
logger.error("Message validation failed", error=str(e))
|
|
54
54
|
return {
|
|
55
55
|
"is_valid": False,
|
|
56
56
|
"message_error": str(e),
|
|
@@ -93,20 +93,28 @@ class UserIdValidationMiddleware(MessageValidationMiddleware):
|
|
|
93
93
|
self, state: AgentState, runtime: Runtime[Context]
|
|
94
94
|
) -> dict[str, Any] | None:
|
|
95
95
|
"""Validate user_id is present and properly formatted."""
|
|
96
|
-
logger.
|
|
96
|
+
logger.trace("Executing user_id validation")
|
|
97
97
|
|
|
98
98
|
context: Context = runtime.context or Context()
|
|
99
99
|
user_id: str | None = context.user_id
|
|
100
100
|
|
|
101
101
|
if not user_id:
|
|
102
|
-
logger.error("User ID is required but not provided in
|
|
102
|
+
logger.error("User ID is required but not provided in configuration")
|
|
103
103
|
|
|
104
104
|
thread_val = context.thread_id or "<your_thread_id>"
|
|
105
|
+
# Get extra fields from context (excluding user_id and thread_id)
|
|
106
|
+
context_dict = context.model_dump()
|
|
107
|
+
extra_fields = {
|
|
108
|
+
k: v
|
|
109
|
+
for k, v in context_dict.items()
|
|
110
|
+
if k not in {"user_id", "thread_id"} and v is not None
|
|
111
|
+
}
|
|
112
|
+
|
|
105
113
|
corrected_config: dict[str, Any] = {
|
|
106
114
|
"configurable": {
|
|
107
115
|
"thread_id": thread_val,
|
|
108
116
|
"user_id": "<your_user_id>",
|
|
109
|
-
**
|
|
117
|
+
**extra_fields,
|
|
110
118
|
},
|
|
111
119
|
"session": {
|
|
112
120
|
"conversation_id": thread_val,
|
|
@@ -138,15 +146,23 @@ Please update your configuration and try again.
|
|
|
138
146
|
raise ValueError(error_message)
|
|
139
147
|
|
|
140
148
|
if "." in user_id:
|
|
141
|
-
logger.error(
|
|
149
|
+
logger.error("User ID contains invalid character '.'", user_id=user_id)
|
|
142
150
|
|
|
143
151
|
corrected_user_id = user_id.replace(".", "_")
|
|
144
152
|
thread_val = context.thread_id or "<your_thread_id>"
|
|
153
|
+
# Get extra fields from context (excluding user_id and thread_id)
|
|
154
|
+
context_dict = context.model_dump()
|
|
155
|
+
extra_fields = {
|
|
156
|
+
k: v
|
|
157
|
+
for k, v in context_dict.items()
|
|
158
|
+
if k not in {"user_id", "thread_id"} and v is not None
|
|
159
|
+
}
|
|
160
|
+
|
|
145
161
|
corrected_config: dict[str, Any] = {
|
|
146
162
|
"configurable": {
|
|
147
163
|
"thread_id": thread_val,
|
|
148
164
|
"user_id": corrected_user_id,
|
|
149
|
-
**
|
|
165
|
+
**extra_fields,
|
|
150
166
|
},
|
|
151
167
|
"session": {
|
|
152
168
|
"conversation_id": thread_val,
|
|
@@ -183,19 +199,27 @@ class ThreadIdValidationMiddleware(MessageValidationMiddleware):
|
|
|
183
199
|
self, state: AgentState, runtime: Runtime[Context]
|
|
184
200
|
) -> dict[str, Any] | None:
|
|
185
201
|
"""Validate thread_id/conversation_id is present."""
|
|
186
|
-
logger.
|
|
202
|
+
logger.trace("Executing thread_id/conversation_id validation")
|
|
187
203
|
|
|
188
204
|
context: Context = runtime.context or Context()
|
|
189
205
|
thread_id: str | None = context.thread_id
|
|
190
206
|
|
|
191
207
|
if not thread_id:
|
|
192
|
-
logger.error("Thread ID / Conversation ID is required but not provided
|
|
208
|
+
logger.error("Thread ID / Conversation ID is required but not provided")
|
|
209
|
+
|
|
210
|
+
# Get extra fields from context (excluding user_id and thread_id)
|
|
211
|
+
context_dict = context.model_dump()
|
|
212
|
+
extra_fields = {
|
|
213
|
+
k: v
|
|
214
|
+
for k, v in context_dict.items()
|
|
215
|
+
if k not in {"user_id", "thread_id"} and v is not None
|
|
216
|
+
}
|
|
193
217
|
|
|
194
218
|
corrected_config: dict[str, Any] = {
|
|
195
219
|
"configurable": {
|
|
196
220
|
"thread_id": "<your_thread_id>",
|
|
197
221
|
"user_id": context.user_id or "<your_user_id>",
|
|
198
|
-
**
|
|
222
|
+
**extra_fields,
|
|
199
223
|
},
|
|
200
224
|
"session": {
|
|
201
225
|
"conversation_id": "<your_thread_id>",
|
|
@@ -269,7 +293,7 @@ class CustomFieldValidationMiddleware(MessageValidationMiddleware):
|
|
|
269
293
|
Middleware that validates the presence of required custom fields.
|
|
270
294
|
|
|
271
295
|
This is a generic validation middleware that can check for multiple
|
|
272
|
-
required fields in context.
|
|
296
|
+
required fields in the context object.
|
|
273
297
|
|
|
274
298
|
Fields are defined in the `fields` list. Each field can have:
|
|
275
299
|
- name: The field name (required)
|
|
@@ -312,7 +336,7 @@ class CustomFieldValidationMiddleware(MessageValidationMiddleware):
|
|
|
312
336
|
<field_name>: <example_value>
|
|
313
337
|
session: {}
|
|
314
338
|
"""
|
|
315
|
-
logger.
|
|
339
|
+
logger.trace("Executing custom field validation")
|
|
316
340
|
|
|
317
341
|
context: Context = runtime.context or Context()
|
|
318
342
|
|
|
@@ -320,7 +344,7 @@ class CustomFieldValidationMiddleware(MessageValidationMiddleware):
|
|
|
320
344
|
missing_fields: list[RequiredField] = []
|
|
321
345
|
for field in self.fields:
|
|
322
346
|
if field.is_required:
|
|
323
|
-
field_value: Any = context
|
|
347
|
+
field_value: Any = getattr(context, field.name, None)
|
|
324
348
|
if field_value is None:
|
|
325
349
|
missing_fields.append(field)
|
|
326
350
|
|
|
@@ -329,7 +353,7 @@ class CustomFieldValidationMiddleware(MessageValidationMiddleware):
|
|
|
329
353
|
|
|
330
354
|
# Log the missing fields
|
|
331
355
|
missing_names = [f.name for f in missing_fields]
|
|
332
|
-
logger.error(
|
|
356
|
+
logger.error("Required fields missing", fields=missing_names)
|
|
333
357
|
|
|
334
358
|
# Build the configurable dict preserving provided values
|
|
335
359
|
# and using example_value for missing required fields
|
|
@@ -344,9 +368,11 @@ class CustomFieldValidationMiddleware(MessageValidationMiddleware):
|
|
|
344
368
|
else:
|
|
345
369
|
configurable["user_id"] = "<your_user_id>"
|
|
346
370
|
|
|
347
|
-
# Add all values the user already provided
|
|
348
|
-
|
|
349
|
-
|
|
371
|
+
# Add all extra values the user already provided
|
|
372
|
+
context_dict = context.model_dump()
|
|
373
|
+
for k, v in context_dict.items():
|
|
374
|
+
if k not in {"user_id", "thread_id"} and v is not None:
|
|
375
|
+
configurable[k] = v
|
|
350
376
|
|
|
351
377
|
# Then add our defined fields (provided values take precedence)
|
|
352
378
|
for field in self.fields:
|
|
@@ -380,8 +406,16 @@ class CustomFieldValidationMiddleware(MessageValidationMiddleware):
|
|
|
380
406
|
field_descriptions: list[str] = [
|
|
381
407
|
"- **thread_id**: Thread identifier (required in configurable)",
|
|
382
408
|
"- **conversation_id**: Alias of thread_id (in session)",
|
|
383
|
-
"- **user_id**: Your unique user identifier (required)",
|
|
384
409
|
]
|
|
410
|
+
|
|
411
|
+
# Add user_id if not in custom fields
|
|
412
|
+
has_user_id_field = any(f.name == "user_id" for f in self.fields)
|
|
413
|
+
if not has_user_id_field:
|
|
414
|
+
field_descriptions.append(
|
|
415
|
+
"- **user_id**: Your unique user identifier (required)"
|
|
416
|
+
)
|
|
417
|
+
|
|
418
|
+
# Add custom field descriptions
|
|
385
419
|
for field in self.fields:
|
|
386
420
|
required_text = "(required)" if field.is_required else "(optional)"
|
|
387
421
|
field_descriptions.append(
|
|
@@ -427,22 +461,22 @@ class FilterLastHumanMessageMiddleware(AgentMiddleware[AgentState, Context]):
|
|
|
427
461
|
self, state: AgentState, runtime: Runtime[Context]
|
|
428
462
|
) -> dict[str, Any] | None:
|
|
429
463
|
"""Filter messages to keep only the last human message."""
|
|
430
|
-
logger.
|
|
464
|
+
logger.trace("Executing filter_last_human_message middleware")
|
|
431
465
|
|
|
432
466
|
messages: list[BaseMessage] = state.get("messages", [])
|
|
433
467
|
|
|
434
468
|
if not messages:
|
|
435
|
-
logger.
|
|
469
|
+
logger.trace("No messages found in state")
|
|
436
470
|
return None
|
|
437
471
|
|
|
438
472
|
last_message: HumanMessage | None = last_human_message(messages)
|
|
439
473
|
|
|
440
474
|
if last_message is None:
|
|
441
|
-
logger.
|
|
475
|
+
logger.trace("No human messages found in state")
|
|
442
476
|
return {"messages": []}
|
|
443
477
|
|
|
444
|
-
logger.
|
|
445
|
-
|
|
478
|
+
logger.trace(
|
|
479
|
+
"Filtered messages to last human message", original_count=len(messages)
|
|
446
480
|
)
|
|
447
481
|
|
|
448
482
|
removed_messages = [
|
|
@@ -472,7 +506,7 @@ def create_user_id_validation_middleware() -> UserIdValidationMiddleware:
|
|
|
472
506
|
Example:
|
|
473
507
|
middleware = create_user_id_validation_middleware()
|
|
474
508
|
"""
|
|
475
|
-
logger.
|
|
509
|
+
logger.trace("Creating user_id validation middleware")
|
|
476
510
|
return UserIdValidationMiddleware()
|
|
477
511
|
|
|
478
512
|
|
|
@@ -489,7 +523,7 @@ def create_thread_id_validation_middleware() -> ThreadIdValidationMiddleware:
|
|
|
489
523
|
Example:
|
|
490
524
|
middleware = create_thread_id_validation_middleware()
|
|
491
525
|
"""
|
|
492
|
-
logger.
|
|
526
|
+
logger.trace("Creating thread_id validation middleware")
|
|
493
527
|
return ThreadIdValidationMiddleware()
|
|
494
528
|
|
|
495
529
|
|
|
@@ -500,7 +534,7 @@ def create_custom_field_validation_middleware(
|
|
|
500
534
|
Create a CustomFieldValidationMiddleware instance.
|
|
501
535
|
|
|
502
536
|
Factory function for creating middleware that validates the presence
|
|
503
|
-
of required custom fields in context.
|
|
537
|
+
of required custom fields in the context object.
|
|
504
538
|
|
|
505
539
|
Each field in the list should have:
|
|
506
540
|
- name: The field name (required)
|
|
@@ -530,9 +564,7 @@ def create_custom_field_validation_middleware(
|
|
|
530
564
|
)
|
|
531
565
|
"""
|
|
532
566
|
field_names = [f.get("name", "unknown") for f in fields]
|
|
533
|
-
logger.
|
|
534
|
-
f"Creating custom field validation middleware for fields: {field_names}"
|
|
535
|
-
)
|
|
567
|
+
logger.trace("Creating custom field validation middleware", fields=field_names)
|
|
536
568
|
return CustomFieldValidationMiddleware(fields=fields)
|
|
537
569
|
|
|
538
570
|
|
|
@@ -550,5 +582,5 @@ def create_filter_last_human_message_middleware() -> FilterLastHumanMessageMiddl
|
|
|
550
582
|
Example:
|
|
551
583
|
middleware = create_filter_last_human_message_middleware()
|
|
552
584
|
"""
|
|
553
|
-
logger.
|
|
585
|
+
logger.trace("Creating filter_last_human_message middleware")
|
|
554
586
|
return FilterLastHumanMessageMiddleware()
|
dao_ai/middleware/summarization.py
CHANGED
|
@@ -76,15 +76,19 @@ class LoggingSummarizationMiddleware(SummarizationMiddleware):
|
|
|
76
76
|
summarized_count = original_message_count - preserved_count
|
|
77
77
|
|
|
78
78
|
logger.info(
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
79
|
+
"Conversation summarized",
|
|
80
|
+
before_messages=original_message_count,
|
|
81
|
+
before_tokens=original_token_count,
|
|
82
|
+
after_messages=new_message_count,
|
|
83
|
+
after_tokens=new_token_count,
|
|
84
|
+
summarized_messages=summarized_count,
|
|
83
85
|
)
|
|
84
86
|
logger.debug(
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
87
|
+
"Summarization details",
|
|
88
|
+
trigger=self.trigger,
|
|
89
|
+
keep=self.keep,
|
|
90
|
+
preserved_messages=preserved_count,
|
|
91
|
+
token_reduction=original_token_count - new_token_count,
|
|
88
92
|
)
|
|
89
93
|
|
|
90
94
|
def _is_remove_message(self, msg: Any) -> bool:
|
|
@@ -160,9 +164,10 @@ def create_summarization_middleware(
|
|
|
160
164
|
middleware = create_summarization_middleware(chat_history)
|
|
161
165
|
"""
|
|
162
166
|
logger.debug(
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
167
|
+
"Creating summarization middleware",
|
|
168
|
+
max_tokens=chat_history.max_tokens,
|
|
169
|
+
max_tokens_before_summary=chat_history.max_tokens_before_summary,
|
|
170
|
+
max_messages_before_summary=chat_history.max_messages_before_summary,
|
|
166
171
|
)
|
|
167
172
|
|
|
168
173
|
# Get the LLM model
|
|
@@ -183,7 +188,7 @@ def create_summarization_middleware(
|
|
|
183
188
|
# Default to keeping enough for context
|
|
184
189
|
keep: Tuple[str, int] = ("tokens", chat_history.max_tokens)
|
|
185
190
|
|
|
186
|
-
logger.info(
|
|
191
|
+
logger.info("Summarization middleware configured", trigger=trigger, keep=keep)
|
|
187
192
|
|
|
188
193
|
return LoggingSummarizationMiddleware(
|
|
189
194
|
model=model,
|