dao-ai 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. dao_ai/agent_as_code.py +2 -5
  2. dao_ai/cli.py +65 -15
  3. dao_ai/config.py +672 -218
  4. dao_ai/genie/cache/core.py +6 -2
  5. dao_ai/genie/cache/lru.py +29 -11
  6. dao_ai/genie/cache/semantic.py +95 -44
  7. dao_ai/hooks/core.py +5 -5
  8. dao_ai/logging.py +56 -0
  9. dao_ai/memory/core.py +61 -44
  10. dao_ai/memory/databricks.py +54 -41
  11. dao_ai/memory/postgres.py +77 -36
  12. dao_ai/middleware/assertions.py +45 -17
  13. dao_ai/middleware/core.py +13 -7
  14. dao_ai/middleware/guardrails.py +30 -25
  15. dao_ai/middleware/human_in_the_loop.py +9 -5
  16. dao_ai/middleware/message_validation.py +61 -29
  17. dao_ai/middleware/summarization.py +16 -11
  18. dao_ai/models.py +172 -69
  19. dao_ai/nodes.py +148 -19
  20. dao_ai/optimization.py +26 -16
  21. dao_ai/orchestration/core.py +15 -8
  22. dao_ai/orchestration/supervisor.py +22 -8
  23. dao_ai/orchestration/swarm.py +57 -12
  24. dao_ai/prompts.py +17 -17
  25. dao_ai/providers/databricks.py +365 -155
  26. dao_ai/state.py +24 -6
  27. dao_ai/tools/__init__.py +2 -0
  28. dao_ai/tools/agent.py +1 -3
  29. dao_ai/tools/core.py +7 -7
  30. dao_ai/tools/email.py +29 -77
  31. dao_ai/tools/genie.py +18 -13
  32. dao_ai/tools/mcp.py +223 -156
  33. dao_ai/tools/python.py +5 -2
  34. dao_ai/tools/search.py +1 -1
  35. dao_ai/tools/slack.py +21 -9
  36. dao_ai/tools/sql.py +202 -0
  37. dao_ai/tools/time.py +30 -7
  38. dao_ai/tools/unity_catalog.py +129 -86
  39. dao_ai/tools/vector_search.py +318 -244
  40. dao_ai/utils.py +15 -10
  41. dao_ai-0.1.3.dist-info/METADATA +455 -0
  42. dao_ai-0.1.3.dist-info/RECORD +64 -0
  43. dao_ai-0.1.1.dist-info/METADATA +0 -1878
  44. dao_ai-0.1.1.dist-info/RECORD +0 -62
  45. {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/WHEEL +0 -0
  46. {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/entry_points.txt +0 -0
  47. {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/licenses/LICENSE +0 -0
dao_ai/middleware/core.py CHANGED
@@ -7,15 +7,17 @@ from fully qualified function names.
 
 from typing import Any, Callable
 
+from langchain.agents.middleware import AgentMiddleware
 from loguru import logger
 
+from dao_ai.state import AgentState, Context
 from dao_ai.utils import load_function
 
 
 def create_factory_middleware(
     function_name: str,
     args: dict[str, Any] | None = None,
-) -> Any:
+) -> AgentMiddleware[AgentState, Context]:
     """
     Create middleware from a factory function.
 
@@ -33,14 +35,14 @@ def create_factory_middleware(
         args: Arguments to pass to the factory function
 
     Returns:
-        A middleware instance returned by the factory function
+        An AgentMiddleware instance returned by the factory function
 
     Raises:
         ImportError: If the function cannot be loaded
 
     Example:
         # Factory function in my_module.py:
-        def create_custom_middleware(threshold: float = 0.5) -> AgentMiddleware:
+        def create_custom_middleware(threshold: float = 0.5) -> AgentMiddleware[AgentState, Context]:
            return MyCustomMiddleware(threshold=threshold)
 
         # Usage:
@@ -52,10 +54,14 @@ def create_factory_middleware(
     if args is None:
         args = {}
 
-    logger.debug(f"Creating factory middleware: {function_name} with args: {args}")
+    logger.trace("Creating factory middleware", function_name=function_name, args=args)
 
-    factory: Callable[..., Any] = load_function(function_name=function_name)
-    middleware: Any = factory(**args)
+    factory: Callable[..., AgentMiddleware[AgentState, Context]] = load_function(
+        function_name=function_name
+    )
+    middleware: AgentMiddleware[AgentState, Context] = factory(**args)
 
-    logger.debug(f"Created middleware from factory: {type(middleware).__name__}")
+    logger.trace(
+        "Created middleware from factory", middleware_type=type(middleware).__name__
+    )
     return middleware
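For context, create_factory_middleware resolves a dotted function path with load_function and calls it with the supplied keyword arguments. A minimal sketch of a user-defined factory it could load, following the docstring above; the module name, MyCustomMiddleware, and the threshold parameter are illustrative, not part of the package:

# my_module.py (hypothetical): a factory resolvable by its fully qualified name
from langchain.agents.middleware import AgentMiddleware

from dao_ai.state import AgentState, Context
from my_project.middleware import MyCustomMiddleware  # hypothetical AgentMiddleware subclass


def create_custom_middleware(threshold: float = 0.5) -> AgentMiddleware[AgentState, Context]:
    # Keyword arguments here come from the args mapping passed to create_factory_middleware.
    return MyCustomMiddleware(threshold=threshold)


# One plausible wiring through the factory loader:
# from dao_ai.middleware.core import create_factory_middleware
# middleware = create_factory_middleware(
#     function_name="my_module.create_custom_middleware",
#     args={"threshold": 0.8},
# )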
dao_ai/middleware/guardrails.py CHANGED
@@ -117,29 +117,29 @@ class GuardrailMiddleware(AgentMiddleware[AgentState, Context]):
 
         # Skip evaluation if the AI message has tool calls (not the final response yet)
         if ai_message.tool_calls:
-            logger.debug(
-                f"Guardrail '{self.guardrail_name}' skipping evaluation - "
-                "AI message contains tool calls, waiting for final response"
+            logger.trace(
+                "Guardrail skipping evaluation - AI message contains tool calls",
+                guardrail_name=self.guardrail_name,
             )
             return None
 
         # Skip evaluation if the AI message has no content to evaluate
         if not ai_message.content:
-            logger.debug(
-                f"Guardrail '{self.guardrail_name}' skipping evaluation - "
-                "AI message has no content"
+            logger.trace(
+                "Guardrail skipping evaluation - AI message has no content",
+                guardrail_name=self.guardrail_name,
             )
             return None
 
-        logger.debug(f"Evaluating response with guardrail '{self.guardrail_name}'")
-
         # Extract text content from messages (handles both string and structured content)
         human_content = _extract_text_content(human_message)
         ai_content = _extract_text_content(ai_message)
 
         logger.debug(
-            f"Guardrail '{self.guardrail_name}' evaluating: "
-            f"input_length={len(human_content)}, output_length={len(ai_content)}"
+            "Evaluating response with guardrail",
+            guardrail_name=self.guardrail_name,
+            input_length=len(human_content),
+            output_length=len(ai_content),
         )
 
         evaluator = create_llm_as_judge(
@@ -150,8 +150,11 @@ class GuardrailMiddleware(AgentMiddleware[AgentState, Context]):
         eval_result = evaluator(inputs=human_content, outputs=ai_content)
 
         if eval_result["score"]:
-            logger.debug(f"Response approved by guardrail '{self.guardrail_name}'")
-            logger.debug(f"Judge's comment: {eval_result['comment']}")
+            logger.debug(
+                "Response approved by guardrail",
+                guardrail_name=self.guardrail_name,
+                comment=eval_result["comment"],
+            )
             self._retry_count = 0
             return None
         else:
@@ -160,10 +163,12 @@ class GuardrailMiddleware(AgentMiddleware[AgentState, Context]):
 
             if self._retry_count >= self.num_retries:
                 logger.warning(
-                    f"Guardrail '{self.guardrail_name}' failed - max retries reached "
-                    f"({self._retry_count}/{self.num_retries})"
+                    "Guardrail failed - max retries reached",
+                    guardrail_name=self.guardrail_name,
+                    retry_count=self._retry_count,
+                    max_retries=self.num_retries,
+                    critique=comment,
                 )
-                logger.warning(f"Final judge's critique: {comment}")
                 self._retry_count = 0
 
                 # Add system message to inform user of guardrail failure
@@ -177,10 +182,12 @@ class GuardrailMiddleware(AgentMiddleware[AgentState, Context]):
                 return {"messages": [AIMessage(content=failure_message)]}
 
             logger.warning(
-                f"Guardrail '{self.guardrail_name}' requested improvements "
-                f"(retry {self._retry_count}/{self.num_retries})"
+                "Guardrail requested improvements",
+                guardrail_name=self.guardrail_name,
+                retry=self._retry_count,
+                max_retries=self.num_retries,
+                critique=comment,
             )
-            logger.warning(f"Judge's critique: {comment}")
 
             content: str = "\n".join([str(human_message.content), comment])
             return {"messages": [HumanMessage(content=content)]}
@@ -250,9 +257,7 @@ class ContentFilterMiddleware(AgentMiddleware[AgentState, Context]):
 
         for keyword in self.banned_keywords:
             if keyword in content:
-                logger.warning(
-                    f"Content filter blocked response containing '{keyword}'"
-                )
+                logger.warning("Content filter blocked response", keyword=keyword)
                 # Modify the last message content
                 last_message.content = self.block_message
                 return None
@@ -347,7 +352,7 @@ def create_guardrail_middleware(
            num_retries=2,
        )
    """
-    logger.debug(f"Creating guardrail middleware: {name}")
+    logger.trace("Creating guardrail middleware", guardrail_name=name)
    return GuardrailMiddleware(
        name=name,
        model=model,
@@ -379,8 +384,8 @@ def create_content_filter_middleware(
            block_message="I cannot discuss sensitive credentials.",
        )
    """
-    logger.debug(
-        f"Creating content filter middleware with {len(banned_keywords)} keywords"
+    logger.trace(
+        "Creating content filter middleware", keywords_count=len(banned_keywords)
    )
    return ContentFilterMiddleware(
        banned_keywords=banned_keywords,
@@ -411,5 +416,5 @@ def create_safety_guardrail_middleware(
            safety_model=ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct"),
        )
    """
-    logger.debug("Creating safety guardrail middleware")
+    logger.trace("Creating safety guardrail middleware")
    return SafetyGuardrailMiddleware(safety_model=safety_model)
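A small usage sketch of the content-filter factory shown in this file; the keyword list is made up, and the block message is the one from the docstring example:

# Illustrative wiring of create_content_filter_middleware; keywords are hypothetical.
from dao_ai.middleware.guardrails import create_content_filter_middleware

content_filter = create_content_filter_middleware(
    banned_keywords=["password", "api_key", "client_secret"],
    block_message="I cannot discuss sensitive credentials.",
)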
dao_ai/middleware/human_in_the_loop.py CHANGED
@@ -103,7 +103,10 @@ def _config_to_interrupt_on_entry(
        interrupt_entry["description"] = config.review_prompt
        return interrupt_entry
 
-    logger.warning(f"Unknown HITL config type: {type(config)}, defaulting to True")
+    logger.warning(
+        "Unknown HITL config type, defaulting to True",
+        config_type=type(config).__name__,
+    )
    return True
 
 
@@ -152,8 +155,9 @@ def create_human_in_the_loop_middleware(
            normalized_interrupt_on[tool_name] = config
 
    logger.debug(
-        f"Creating HITL middleware for {len(normalized_interrupt_on)} tools: "
-        f"{list(normalized_interrupt_on.keys())}"
+        "Creating HITL middleware",
+        tools_count=len(normalized_interrupt_on),
+        tools=list(normalized_interrupt_on.keys()),
    )
 
    return HumanInTheLoopMiddleware(
@@ -216,10 +220,10 @@ def create_hitl_middleware_from_tool_models(
        tool_name: str | None = getattr(func_tool, "name", None)
        if tool_name:
            interrupt_on[tool_name] = hitl_config
-            logger.debug(f"Tool '{tool_name}' configured for HITL")
+            logger.trace("Tool configured for HITL", tool_name=tool_name)
 
    if not interrupt_on:
-        logger.debug("No tools require HITL - returning None")
+        logger.trace("No tools require HITL - returning None")
        return None
 
    return create_human_in_the_loop_middleware(
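A hedged sketch of the per-tool mapping that _config_to_interrupt_on_entry produces: each tool name maps either to True (interrupt with default behaviour) or to a dict whose "description" carries the review prompt. The tool names here are hypothetical, and the keyword passed to the factory is an assumption inferred from the surrounding variable names:

# Sketch only: the shape of the normalized HITL mapping, with made-up tool names.
interrupt_on: dict[str, bool | dict[str, str]] = {
    "send_email": {"description": "Review the drafted email before it is sent."},
    "execute_sql": True,  # unrecognised configs also fall back to True, per the warning above
}

# Assumed keyword name, based on the local variable names in this file:
# middleware = create_human_in_the_loop_middleware(interrupt_on=interrupt_on)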
dao_ai/middleware/message_validation.py CHANGED
@@ -50,7 +50,7 @@ class MessageValidationMiddleware(AgentMiddleware[AgentState, Context]):
        try:
            return self.validate(state, runtime)
        except ValueError as e:
-            logger.error(f"Message validation failed: {e}")
+            logger.error("Message validation failed", error=str(e))
            return {
                "is_valid": False,
                "message_error": str(e),
@@ -93,20 +93,28 @@ class UserIdValidationMiddleware(MessageValidationMiddleware):
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """Validate user_id is present and properly formatted."""
-        logger.debug("Executing user_id validation")
+        logger.trace("Executing user_id validation")
 
        context: Context = runtime.context or Context()
        user_id: str | None = context.user_id
 
        if not user_id:
-            logger.error("User ID is required but not provided in the configuration.")
+            logger.error("User ID is required but not provided in configuration")
 
            thread_val = context.thread_id or "<your_thread_id>"
+            # Get extra fields from context (excluding user_id and thread_id)
+            context_dict = context.model_dump()
+            extra_fields = {
+                k: v
+                for k, v in context_dict.items()
+                if k not in {"user_id", "thread_id"} and v is not None
+            }
+
            corrected_config: dict[str, Any] = {
                "configurable": {
                    "thread_id": thread_val,
                    "user_id": "<your_user_id>",
-                    **context.custom,
+                    **extra_fields,
                },
                "session": {
                    "conversation_id": thread_val,
@@ -138,15 +146,23 @@ Please update your configuration and try again.
            raise ValueError(error_message)
 
        if "." in user_id:
-            logger.error(f"User ID '{user_id}' contains invalid character '.'")
+            logger.error("User ID contains invalid character '.'", user_id=user_id)
 
            corrected_user_id = user_id.replace(".", "_")
            thread_val = context.thread_id or "<your_thread_id>"
+            # Get extra fields from context (excluding user_id and thread_id)
+            context_dict = context.model_dump()
+            extra_fields = {
+                k: v
+                for k, v in context_dict.items()
+                if k not in {"user_id", "thread_id"} and v is not None
+            }
+
            corrected_config: dict[str, Any] = {
                "configurable": {
                    "thread_id": thread_val,
                    "user_id": corrected_user_id,
-                    **context.custom,
+                    **extra_fields,
                },
                "session": {
                    "conversation_id": thread_val,
@@ -183,19 +199,27 @@ class ThreadIdValidationMiddleware(MessageValidationMiddleware):
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """Validate thread_id/conversation_id is present."""
-        logger.debug("Executing thread_id/conversation_id validation")
+        logger.trace("Executing thread_id/conversation_id validation")
 
        context: Context = runtime.context or Context()
        thread_id: str | None = context.thread_id
 
        if not thread_id:
-            logger.error("Thread ID / Conversation ID is required but not provided.")
+            logger.error("Thread ID / Conversation ID is required but not provided")
+
+            # Get extra fields from context (excluding user_id and thread_id)
+            context_dict = context.model_dump()
+            extra_fields = {
+                k: v
+                for k, v in context_dict.items()
+                if k not in {"user_id", "thread_id"} and v is not None
+            }
 
            corrected_config: dict[str, Any] = {
                "configurable": {
                    "thread_id": "<your_thread_id>",
                    "user_id": context.user_id or "<your_user_id>",
-                    **context.custom,
+                    **extra_fields,
                },
                "session": {
                    "conversation_id": "<your_thread_id>",
@@ -269,7 +293,7 @@ class CustomFieldValidationMiddleware(MessageValidationMiddleware):
    Middleware that validates the presence of required custom fields.
 
    This is a generic validation middleware that can check for multiple
-    required fields in context.custom.
+    required fields in the context object.
 
    Fields are defined in the `fields` list. Each field can have:
    - name: The field name (required)
@@ -312,7 +336,7 @@
                <field_name>: <example_value>
            session: {}
        """
-        logger.debug("Executing custom field validation")
+        logger.trace("Executing custom field validation")
 
        context: Context = runtime.context or Context()
 
@@ -320,7 +344,7 @@
        missing_fields: list[RequiredField] = []
        for field in self.fields:
            if field.is_required:
-                field_value: Any = context.custom.get(field.name)
+                field_value: Any = getattr(context, field.name, None)
                if field_value is None:
                    missing_fields.append(field)
 
@@ -329,7 +353,7 @@
 
            # Log the missing fields
            missing_names = [f.name for f in missing_fields]
-            logger.error(f"Required fields missing: {', '.join(missing_names)}")
+            logger.error("Required fields missing", fields=missing_names)
 
            # Build the configurable dict preserving provided values
            # and using example_value for missing required fields
@@ -344,9 +368,11 @@
            else:
                configurable["user_id"] = "<your_user_id>"
 
-            # Add all values the user already provided in custom
-            for k, v in context.custom.items():
-                configurable[k] = v
+            # Add all extra values the user already provided
+            context_dict = context.model_dump()
+            for k, v in context_dict.items():
+                if k not in {"user_id", "thread_id"} and v is not None:
+                    configurable[k] = v
 
            # Then add our defined fields (provided values take precedence)
            for field in self.fields:
@@ -380,8 +406,16 @@
            field_descriptions: list[str] = [
                "- **thread_id**: Thread identifier (required in configurable)",
                "- **conversation_id**: Alias of thread_id (in session)",
-                "- **user_id**: Your unique user identifier (required)",
            ]
+
+            # Add user_id if not in custom fields
+            has_user_id_field = any(f.name == "user_id" for f in self.fields)
+            if not has_user_id_field:
+                field_descriptions.append(
+                    "- **user_id**: Your unique user identifier (required)"
+                )
+
+            # Add custom field descriptions
            for field in self.fields:
                required_text = "(required)" if field.is_required else "(optional)"
                field_descriptions.append(
@@ -427,22 +461,22 @@ class FilterLastHumanMessageMiddleware(AgentMiddleware[AgentState, Context]):
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """Filter messages to keep only the last human message."""
-        logger.debug("Executing filter_last_human_message middleware")
+        logger.trace("Executing filter_last_human_message middleware")
 
        messages: list[BaseMessage] = state.get("messages", [])
 
        if not messages:
-            logger.debug("No messages found in state")
+            logger.trace("No messages found in state")
            return None
 
        last_message: HumanMessage | None = last_human_message(messages)
 
        if last_message is None:
-            logger.debug("No human messages found in state")
+            logger.trace("No human messages found in state")
            return {"messages": []}
 
-        logger.debug(
-            f"Filtered {len(messages)} messages down to 1 (last human message)"
+        logger.trace(
+            "Filtered messages to last human message", original_count=len(messages)
        )
 
        removed_messages = [
@@ -472,7 +506,7 @@ def create_user_id_validation_middleware() -> UserIdValidationMiddleware:
    Example:
        middleware = create_user_id_validation_middleware()
    """
-    logger.debug("Creating user_id validation middleware")
+    logger.trace("Creating user_id validation middleware")
    return UserIdValidationMiddleware()
 
 
@@ -489,7 +523,7 @@ def create_thread_id_validation_middleware() -> ThreadIdValidationMiddleware:
    Example:
        middleware = create_thread_id_validation_middleware()
    """
-    logger.debug("Creating thread_id validation middleware")
+    logger.trace("Creating thread_id validation middleware")
    return ThreadIdValidationMiddleware()
 
 
@@ -500,7 +534,7 @@ def create_custom_field_validation_middleware(
    Create a CustomFieldValidationMiddleware instance.
 
    Factory function for creating middleware that validates the presence
-    of required custom fields in context.custom.
+    of required custom fields in the context object.
 
    Each field in the list should have:
    - name: The field name (required)
@@ -530,9 +564,7 @@
        )
    """
    field_names = [f.get("name", "unknown") for f in fields]
-    logger.debug(
-        f"Creating custom field validation middleware for fields: {field_names}"
-    )
+    logger.trace("Creating custom field validation middleware", fields=field_names)
    return CustomFieldValidationMiddleware(fields=fields)
 
 
@@ -550,5 +582,5 @@ def create_filter_last_human_message_middleware() -> FilterLastHumanMessageMiddl
    Example:
        middleware = create_filter_last_human_message_middleware()
    """
-    logger.debug("Creating filter_last_human_message middleware")
+    logger.trace("Creating filter_last_human_message middleware")
    return FilterLastHumanMessageMiddleware()
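The validators above repeatedly surface "extra" context fields by dumping the pydantic Context model and dropping user_id, thread_id, and None values. A standalone sketch of that filter; the Context stand-in and its store_num field are hypothetical stand-ins for the real dao_ai.state.Context:

from typing import Any

from pydantic import BaseModel


class Context(BaseModel):
    # Minimal stand-in for dao_ai.state.Context; store_num is a hypothetical extra field.
    user_id: str | None = None
    thread_id: str | None = None
    store_num: str | None = None


def extra_context_fields(context: Context) -> dict[str, Any]:
    # Keep every non-None field other than user_id and thread_id, as in the validators above.
    return {
        k: v
        for k, v in context.model_dump().items()
        if k not in {"user_id", "thread_id"} and v is not None
    }


print(extra_context_fields(Context(user_id="u_123", store_num="101")))
# -> {'store_num': '101'}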
dao_ai/middleware/summarization.py CHANGED
@@ -76,15 +76,19 @@ class LoggingSummarizationMiddleware(SummarizationMiddleware):
        summarized_count = original_message_count - preserved_count
 
        logger.info(
-            f"Conversation summarized: "
-            f"BEFORE: {original_message_count} messages (~{original_token_count:,} tokens) → "
-            f"AFTER: {new_message_count} messages (~{new_token_count:,} tokens) | "
-            f"{summarized_count} messages condensed into 1 summary"
+            "Conversation summarized",
+            before_messages=original_message_count,
+            before_tokens=original_token_count,
+            after_messages=new_message_count,
+            after_tokens=new_token_count,
+            summarized_messages=summarized_count,
        )
        logger.debug(
-            f"Summarization details: trigger={self.trigger}, keep={self.keep}, "
-            f"preserved_messages={preserved_count}, "
-            f"token_reduction={original_token_count - new_token_count:,}"
+            "Summarization details",
+            trigger=self.trigger,
+            keep=self.keep,
+            preserved_messages=preserved_count,
+            token_reduction=original_token_count - new_token_count,
        )
 
    def _is_remove_message(self, msg: Any) -> bool:
@@ -160,9 +164,10 @@
        middleware = create_summarization_middleware(chat_history)
    """
    logger.debug(
-        f"Creating summarization middleware with max_tokens: {chat_history.max_tokens}, "
-        f"max_tokens_before_summary: {chat_history.max_tokens_before_summary}, "
-        f"max_messages_before_summary: {chat_history.max_messages_before_summary}"
+        "Creating summarization middleware",
+        max_tokens=chat_history.max_tokens,
+        max_tokens_before_summary=chat_history.max_tokens_before_summary,
+        max_messages_before_summary=chat_history.max_messages_before_summary,
    )
 
    # Get the LLM model
@@ -183,7 +188,7 @@
        # Default to keeping enough for context
        keep: Tuple[str, int] = ("tokens", chat_history.max_tokens)
 
-    logger.info(f"Summarization middleware configured: trigger={trigger}, keep={keep}")
+    logger.info("Summarization middleware configured", trigger=trigger, keep=keep)
 
    return LoggingSummarizationMiddleware(
        model=model,
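A hedged sketch of how the trigger/keep tuples logged by create_summarization_middleware might be derived from a chat-history config. The ChatHistory stand-in, its default values, and the selection order are assumptions; only the token-based keep fallback is shown in the diff above:

from typing import Tuple

from pydantic import BaseModel


class ChatHistory(BaseModel):
    # Stand-in config; field names follow the values logged above, defaults are made up.
    max_tokens: int = 4096
    max_tokens_before_summary: int | None = 8192
    max_messages_before_summary: int | None = None


def summarization_bounds(chat_history: ChatHistory) -> Tuple[Tuple[str, int], Tuple[str, int]]:
    # Guessed selection order: prefer a token threshold, then a message threshold,
    # then fall back to max_tokens.
    if chat_history.max_tokens_before_summary is not None:
        trigger: Tuple[str, int] = ("tokens", chat_history.max_tokens_before_summary)
    elif chat_history.max_messages_before_summary is not None:
        trigger = ("messages", chat_history.max_messages_before_summary)
    else:
        trigger = ("tokens", chat_history.max_tokens)
    # Default to keeping enough tokens for context, mirroring the fallback in the diff.
    keep: Tuple[str, int] = ("tokens", chat_history.max_tokens)
    return trigger, keep


print(summarization_bounds(ChatHistory()))
# -> (('tokens', 8192), ('tokens', 4096))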