dao-ai 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dao_ai/genie/cache/lru.py CHANGED
@@ -6,15 +6,16 @@ by Genie. On cache hit, the cached SQL is re-executed against the warehouse
6
6
  to return fresh data while avoiding the Genie NL-to-SQL translation cost.
7
7
  """
8
8
 
9
+ from __future__ import annotations
10
+
9
11
  from collections import OrderedDict
10
12
  from datetime import datetime, timedelta
11
13
  from threading import Lock
12
- from typing import Any
13
14
 
14
15
  import mlflow
15
16
  import pandas as pd
16
17
  from databricks.sdk import WorkspaceClient
17
- from databricks.sdk.service.sql import StatementResponse, StatementState
18
+ from databricks.sdk.service.dashboards import GenieFeedbackRating
18
19
  from databricks_ai_bridge.genie import GenieResponse
19
20
  from loguru import logger
20
21
 
@@ -24,6 +25,7 @@ from dao_ai.genie.cache.base import (
24
25
  GenieServiceBase,
25
26
  SQLCacheEntry,
26
27
  )
28
+ from dao_ai.genie.cache.core import execute_sql_via_warehouse
27
29
 
28
30
 
29
31
  class LRUCacheService(GenieServiceBase):
@@ -141,8 +143,20 @@ class LRUCacheService(GenieServiceBase):
141
143
  self._cache.move_to_end(key)
142
144
  return entry
143
145
 
144
- def _put(self, key: str, response: GenieResponse) -> None:
146
+ def _put(
147
+ self, key: str, response: GenieResponse, message_id: str | None = None
148
+ ) -> None:
145
149
  """Store SQL query in cache, evicting if at capacity."""
150
+ # Skip caching if query is empty or whitespace
151
+ if not response.query or not response.query.strip():
152
+ logger.warning(
153
+ "Not caching: response has no SQL query",
154
+ layer=self.name,
155
+ key=key[:50],
156
+ description=response.description[:80] if response.description else None,
157
+ )
158
+ return
159
+
146
160
  if key in self._cache:
147
161
  del self._cache[key]
148
162
 
@@ -154,6 +168,9 @@ class LRUCacheService(GenieServiceBase):
154
168
  description=response.description,
155
169
  conversation_id=response.conversation_id,
156
170
  created_at=datetime.now(),
171
+ message_id=message_id,
172
+ # LRU cache is in-memory only, no database row ID
173
+ cache_entry_id=None,
157
174
  )
158
175
  logger.debug(
159
176
  "Stored cache entry",
@@ -162,6 +179,7 @@ class LRUCacheService(GenieServiceBase):
162
179
  sql=response.query[:50] if response.query else None,
163
180
  cache_size=len(self._cache),
164
181
  capacity=self.capacity,
182
+ message_id=message_id,
165
183
  )
166
184
 
167
185
  @mlflow.trace(name="execute_cached_sql")
@@ -175,50 +193,22 @@ class LRUCacheService(GenieServiceBase):
175
193
  Returns:
176
194
  DataFrame with results, or error message string
177
195
  """
178
- w: WorkspaceClient = self.warehouse.workspace_client
179
- warehouse_id: str = str(self.warehouse.warehouse_id)
180
-
181
- logger.trace("Executing cached SQL", layer=self.name, sql=sql[:100])
182
-
183
- statement_response: StatementResponse = w.statement_execution.execute_statement(
184
- statement=sql,
185
- warehouse_id=warehouse_id,
186
- wait_timeout="30s",
187
- )
188
-
189
- # Poll for completion if still running
190
- while statement_response.status.state in [
191
- StatementState.PENDING,
192
- StatementState.RUNNING,
193
- ]:
194
- statement_response = w.statement_execution.get_statement(
195
- statement_response.statement_id
196
- )
197
-
198
- if statement_response.status.state != StatementState.SUCCEEDED:
199
- error_msg: str = f"SQL execution failed: {statement_response.status}"
196
+ # Validate SQL is not empty
197
+ if not sql or not sql.strip():
198
+ error_msg: str = "Cannot execute empty SQL query"
200
199
  logger.error(
201
- "SQL execution failed",
200
+ "SQL execution failed: empty query",
202
201
  layer=self.name,
203
- status=str(statement_response.status),
202
+ sql=repr(sql),
204
203
  )
205
204
  return error_msg
206
205
 
207
- # Convert to DataFrame
208
- if statement_response.result and statement_response.result.data_array:
209
- columns: list[str] = []
210
- if statement_response.manifest and statement_response.manifest.schema:
211
- columns = [
212
- col.name for col in statement_response.manifest.schema.columns
213
- ]
214
-
215
- data: list[list[Any]] = statement_response.result.data_array
216
- if columns:
217
- return pd.DataFrame(data, columns=columns)
218
- else:
219
- return pd.DataFrame(data)
220
-
221
- return pd.DataFrame()
206
+ # Use shared utility function for SQL execution
207
+ return execute_sql_via_warehouse(
208
+ warehouse=self.warehouse,
209
+ sql=sql,
210
+ layer_name=self.name,
211
+ )
222
212
 
223
213
  def ask_question(
224
214
  self, question: str, conversation_id: str | None = None
@@ -256,33 +246,131 @@ class LRUCacheService(GenieServiceBase):
256
246
  cached: SQLCacheEntry | None = self._get(key)
257
247
 
258
248
  if cached is not None:
259
- cache_age_seconds = (datetime.now() - cached.created_at).total_seconds()
260
- logger.info(
261
- "Cache HIT",
262
- layer=self.name,
263
- question=question[:80],
264
- conversation_id=conversation_id,
265
- cached_sql=cached.query[:80] if cached.query else None,
266
- cache_age_seconds=round(cache_age_seconds, 1),
267
- cache_size=self.size,
268
- capacity=self.capacity,
269
- ttl_seconds=self.parameters.time_to_live_seconds,
270
- )
271
-
272
- # Re-execute the cached SQL to get fresh data
273
- result: pd.DataFrame | str = self._execute_sql(cached.query)
274
-
275
- # Use current conversation_id, not the cached one
276
- response: GenieResponse = GenieResponse(
277
- result=result,
278
- query=cached.query,
279
- description=cached.description,
280
- conversation_id=conversation_id
281
- if conversation_id
282
- else cached.conversation_id,
283
- )
284
-
285
- return CacheResult(response=response, cache_hit=True, served_by=self.name)
249
+ # Defensive check: if cached query is empty, treat as cache miss
250
+ if not cached.query or not cached.query.strip():
251
+ logger.warning(
252
+ "Cache HIT but query is empty, treating as MISS",
253
+ layer=self.name,
254
+ question=question[:80],
255
+ conversation_id=conversation_id,
256
+ key=key[:50],
257
+ )
258
+ # Invalidate this bad cache entry
259
+ with self._lock:
260
+ if key in self._cache:
261
+ del self._cache[key]
262
+ # Fall through to cache miss logic below
263
+ else:
264
+ cache_age_seconds = (datetime.now() - cached.created_at).total_seconds()
265
+ logger.info(
266
+ "Cache HIT",
267
+ layer=self.name,
268
+ question=question[:80],
269
+ conversation_id=conversation_id,
270
+ cached_sql=cached.query[:80] if cached.query else None,
271
+ cache_age_seconds=round(cache_age_seconds, 1),
272
+ cache_size=self.size,
273
+ capacity=self.capacity,
274
+ ttl_seconds=self.parameters.time_to_live_seconds,
275
+ )
276
+
277
+ # Re-execute the cached SQL to get fresh data
278
+ result: pd.DataFrame | str = self._execute_sql(cached.query)
279
+
280
+ # Check if SQL execution failed (returns error string instead of DataFrame)
281
+ if isinstance(result, str):
282
+ logger.warning(
283
+ "Cached SQL execution failed, falling back to Genie",
284
+ layer=self.name,
285
+ question=question[:80],
286
+ conversation_id=conversation_id,
287
+ cached_sql=cached.query[:80],
288
+ error=result[:200],
289
+ cache_key=key[:50],
290
+ )
291
+
292
+ # Invalidate the bad cache entry
293
+ with self._lock:
294
+ if key in self._cache:
295
+ del self._cache[key]
296
+ logger.info(
297
+ "Invalidated stale cache entry",
298
+ layer=self.name,
299
+ cache_key=key[:50],
300
+ cache_size=len(self._cache),
301
+ capacity=self.capacity,
302
+ )
303
+
304
+ # Fall back to Genie to get fresh SQL
305
+ logger.info(
306
+ "Delegating to Genie for fresh SQL",
307
+ layer=self.name,
308
+ question=question[:80],
309
+ delegating_to=type(self.impl).__name__,
310
+ )
311
+ fallback_result: CacheResult = self.impl.ask_question(
312
+ question, conversation_id
313
+ )
314
+
315
+ # Store the fresh SQL in cache (including message_id for feedback)
316
+ if fallback_result.response.query:
317
+ with self._lock:
318
+ self._put(
319
+ key,
320
+ fallback_result.response,
321
+ message_id=fallback_result.message_id,
322
+ )
323
+ logger.info(
324
+ "Stored fresh SQL from fallback",
325
+ layer=self.name,
326
+ fresh_sql=fallback_result.response.query[:80],
327
+ cache_size=len(self._cache),
328
+ capacity=self.capacity,
329
+ message_id=fallback_result.message_id,
330
+ )
331
+ else:
332
+ logger.warning(
333
+ "Fallback response has no SQL query to cache",
334
+ layer=self.name,
335
+ question=question[:80],
336
+ )
337
+
338
+ logger.info(
339
+ "Fallback completed successfully",
340
+ layer=self.name,
341
+ question=question[:80],
342
+ fallback_from="stale_cache",
343
+ has_result=fallback_result.response.result is not None,
344
+ )
345
+
346
+ # Return as cache miss (fallback scenario)
347
+ # Propagate message_id from fallback result
348
+ return CacheResult(
349
+ response=fallback_result.response,
350
+ cache_hit=False,
351
+ served_by=None,
352
+ message_id=fallback_result.message_id,
353
+ )
354
+
355
+ # Use current conversation_id, not the cached one
356
+ response: GenieResponse = GenieResponse(
357
+ result=result,
358
+ query=cached.query,
359
+ description=cached.description,
360
+ conversation_id=conversation_id
361
+ if conversation_id
362
+ else cached.conversation_id,
363
+ )
364
+
365
+ # Cache hit - include message_id from original response for feedback support
366
+ return CacheResult(
367
+ response=response,
368
+ cache_hit=True,
369
+ served_by=self.name,
370
+ message_id=cached.message_id,
371
+ # LRU cache is in-memory only, no cache_entry_id for traceability
372
+ cache_entry_id=None,
373
+ )
286
374
 
287
375
  # Cache miss - delegate to wrapped service
288
376
  logger.info(
@@ -298,7 +386,7 @@ class LRUCacheService(GenieServiceBase):
298
386
 
299
387
  result: CacheResult = self.impl.ask_question(question, conversation_id)
300
388
  with self._lock:
301
- self._put(key, result.response)
389
+ self._put(key, result.response, message_id=result.message_id)
302
390
  # Propagate the inner cache's result - if it was a hit there, preserve that info
303
391
  return result
304
392
 
@@ -306,6 +394,11 @@ class LRUCacheService(GenieServiceBase):
306
394
  def space_id(self) -> str:
307
395
  return self.impl.space_id
308
396
 
397
+ @property
398
+ def workspace_client(self) -> WorkspaceClient | None:
399
+ """Get workspace client by delegating to impl."""
400
+ return self.impl.workspace_client
401
+
309
402
  def invalidate(self, question: str, conversation_id: str | None = None) -> bool:
310
403
  """
311
404
  Remove a specific entry from the cache.
@@ -349,3 +442,88 @@ class LRUCacheService(GenieServiceBase):
349
442
  "expired_entries": expired,
350
443
  "valid_entries": len(self._cache) - expired,
351
444
  }
445
+
446
+ @mlflow.trace(name="genie_lru_cache_send_feedback")
447
+ def send_feedback(
448
+ self,
449
+ conversation_id: str,
450
+ rating: GenieFeedbackRating,
451
+ message_id: str | None = None,
452
+ was_cache_hit: bool = False,
453
+ ) -> None:
454
+ """
455
+ Send feedback for a Genie message with cache invalidation.
456
+
457
+ For LRU cache, this method:
458
+ 1. If was_cache_hit is False: forwards feedback to the underlying service
459
+ 2. If rating is NEGATIVE: invalidates any matching cache entries
460
+
461
+ Args:
462
+ conversation_id: The conversation containing the message
463
+ rating: The feedback rating (POSITIVE, NEGATIVE, or NONE)
464
+ message_id: Optional message ID. If None, looks up the most recent message.
465
+ was_cache_hit: Whether the response being rated was served from cache.
466
+
467
+ Note:
468
+ For cached responses (was_cache_hit=True), only cache invalidation is
469
+ performed. No feedback is sent to the Genie API because cached responses
470
+ don't have a corresponding Genie message.
471
+
472
+ Future Enhancement: To enable full Genie feedback for cached responses,
473
+ the cache would need to store the original message_id. See GenieServiceBase
474
+ docstring for details on required changes.
475
+ """
476
+ # Handle cache invalidation on negative feedback
477
+ invalidated = False
478
+ if rating == GenieFeedbackRating.NEGATIVE:
479
+ # For LRU cache, we invalidate by conversation_id since that's part of the key
480
+ # Iterate through cache and remove entries matching the conversation_id
481
+ with self._lock:
482
+ keys_to_remove: list[str] = []
483
+ for key, entry in self._cache.items():
484
+ if entry.conversation_id == conversation_id:
485
+ keys_to_remove.append(key)
486
+
487
+ for key in keys_to_remove:
488
+ del self._cache[key]
489
+ invalidated = True
490
+ logger.info(
491
+ "Invalidated cache entry due to negative feedback",
492
+ layer=self.name,
493
+ cache_key=key[:50],
494
+ conversation_id=conversation_id,
495
+ )
496
+
497
+ if not keys_to_remove:
498
+ logger.debug(
499
+ "No cache entries found to invalidate for negative feedback",
500
+ layer=self.name,
501
+ conversation_id=conversation_id,
502
+ )
503
+
504
+ # Forward feedback to underlying service if not a cache hit
505
+ # For cache hits, there's no Genie message to provide feedback on
506
+ if was_cache_hit:
507
+ logger.info(
508
+ "Skipping Genie API feedback - response was served from cache",
509
+ layer=self.name,
510
+ conversation_id=conversation_id,
511
+ rating=rating.value if rating else None,
512
+ cache_invalidated=invalidated,
513
+ )
514
+ return
515
+
516
+ # Forward to underlying service
517
+ logger.debug(
518
+ "Forwarding feedback to underlying service",
519
+ layer=self.name,
520
+ conversation_id=conversation_id,
521
+ rating=rating.value if rating else None,
522
+ delegating_to=type(self.impl).__name__,
523
+ )
524
+ self.impl.send_feedback(
525
+ conversation_id=conversation_id,
526
+ rating=rating,
527
+ message_id=message_id,
528
+ was_cache_hit=False, # Already handled, so pass False
529
+ )
dao_ai/genie/core.py CHANGED
@@ -1,35 +1,259 @@
1
1
  """
2
2
  Core Genie service implementation.
3
3
 
4
- This module provides the concrete implementation of GenieServiceBase
5
- that wraps the Databricks Genie SDK.
4
+ This module provides:
5
+ - Extended Genie and GenieResponse classes that capture message_id
6
+ - GenieService: Concrete implementation of GenieServiceBase
7
+
8
+ The extended classes wrap the databricks_ai_bridge versions to add message_id
9
+ support, which is needed for sending feedback to the Genie API.
6
10
  """
7
11
 
12
+ from __future__ import annotations
13
+
14
+ from dataclasses import dataclass
15
+ from typing import TYPE_CHECKING, Union
16
+
8
17
  import mlflow
9
- from databricks_ai_bridge.genie import Genie, GenieResponse
18
+ import pandas as pd
19
+ from databricks.sdk import WorkspaceClient
20
+ from databricks.sdk.service.dashboards import GenieFeedbackRating
21
+ from databricks_ai_bridge.genie import Genie as DatabricksGenie
22
+ from databricks_ai_bridge.genie import GenieResponse as DatabricksGenieResponse
23
+ from loguru import logger
10
24
 
11
25
  from dao_ai.genie.cache import CacheResult, GenieServiceBase
26
+ from dao_ai.genie.cache.base import get_latest_message_id
27
+
28
+ if TYPE_CHECKING:
29
+ from typing import Optional
30
+
31
+
32
+ # =============================================================================
33
+ # Extended Genie Classes with message_id Support
34
+ # =============================================================================
35
+
36
+
37
+ @dataclass
38
+ class GenieResponse(DatabricksGenieResponse):
39
+ """
40
+ Extended GenieResponse that includes message_id.
41
+
42
+ This extends the databricks_ai_bridge GenieResponse to capture the message_id
43
+ from API responses, which is required for sending feedback to the Genie API.
44
+
45
+ Attributes:
46
+ result: The query result as string or DataFrame
47
+ query: The generated SQL query
48
+ description: Description of the query
49
+ conversation_id: The conversation ID
50
+ message_id: The message ID (NEW - enables feedback without extra API call)
51
+ """
52
+
53
+ result: Union[str, pd.DataFrame] = ""
54
+ query: Optional[str] = ""
55
+ description: Optional[str] = ""
56
+ conversation_id: Optional[str] = None
57
+ message_id: Optional[str] = None
58
+
59
+
60
+ class Genie(DatabricksGenie):
61
+ """
62
+ Extended Genie that captures message_id in responses.
63
+
64
+ This extends the databricks_ai_bridge Genie to return GenieResponse objects
65
+ that include the message_id from the API response. This enables sending
66
+ feedback without requiring an additional API call to look up the message ID.
67
+
68
+ Usage:
69
+ genie = Genie(space_id="my-space")
70
+ response = genie.ask_question("What are total sales?")
71
+ print(response.message_id) # Now available!
72
+
73
+ The original databricks_ai_bridge classes are available as:
74
+ - DatabricksGenie
75
+ - DatabricksGenieResponse
76
+ """
77
+
78
+ def ask_question(
79
+ self, question: str, conversation_id: str | None = None
80
+ ) -> GenieResponse:
81
+ """
82
+ Ask a question and return response with message_id.
83
+
84
+ This overrides the parent method to capture the message_id from the
85
+ API response and include it in the returned GenieResponse.
86
+
87
+ Args:
88
+ question: The question to ask
89
+ conversation_id: Optional conversation ID for follow-up questions
90
+
91
+ Returns:
92
+ GenieResponse with message_id populated
93
+ """
94
+ with mlflow.start_span(name="ask_question"):
95
+ # Start or continue conversation
96
+ if not conversation_id:
97
+ resp = self.start_conversation(question)
98
+ else:
99
+ resp = self.create_message(conversation_id, question)
100
+
101
+ # Capture message_id from the API response
102
+ message_id = resp.get("message_id")
103
+
104
+ # Poll for the result using parent's method
105
+ genie_response = self.poll_for_result(resp["conversation_id"], message_id)
106
+
107
+ # Ensure conversation_id is set
108
+ if not genie_response.conversation_id:
109
+ genie_response.conversation_id = resp["conversation_id"]
110
+
111
+ # Return our extended response with message_id
112
+ return GenieResponse(
113
+ result=genie_response.result,
114
+ query=genie_response.query,
115
+ description=genie_response.description,
116
+ conversation_id=genie_response.conversation_id,
117
+ message_id=message_id,
118
+ )
119
+
120
+
121
+ # =============================================================================
122
+ # GenieService Implementation
123
+ # =============================================================================
12
124
 
13
125
 
14
126
  class GenieService(GenieServiceBase):
15
- """Concrete implementation of GenieServiceBase using the Genie SDK."""
127
+ """
128
+ Concrete implementation of GenieServiceBase using the extended Genie.
129
+
130
+ This service wraps the extended Genie class and provides the GenieServiceBase
131
+ interface for use with cache layers.
132
+ """
16
133
 
17
134
  genie: Genie
135
+ _workspace_client: WorkspaceClient | None
136
+
137
+ def __init__(
138
+ self,
139
+ genie: Genie | DatabricksGenie,
140
+ workspace_client: WorkspaceClient | None = None,
141
+ ) -> None:
142
+ """
143
+ Initialize the GenieService.
144
+
145
+ Args:
146
+ genie: The Genie instance for asking questions. Can be either our
147
+ extended Genie or the original DatabricksGenie.
148
+ workspace_client: Optional WorkspaceClient for feedback API.
149
+ If not provided, one will be created lazily when needed.
150
+ """
151
+ self.genie = genie # type: ignore[assignment]
152
+ self._workspace_client = workspace_client
153
+
154
+ @property
155
+ def workspace_client(self) -> WorkspaceClient:
156
+ """
157
+ Get or create a WorkspaceClient for API calls.
18
158
 
19
- def __init__(self, genie: Genie) -> None:
20
- self.genie = genie
159
+ Lazily creates a WorkspaceClient using default credentials if not provided.
160
+ """
161
+ if self._workspace_client is None:
162
+ self._workspace_client = WorkspaceClient()
163
+ return self._workspace_client
21
164
 
22
165
  @mlflow.trace(name="genie_ask_question")
23
166
  def ask_question(
24
167
  self, question: str, conversation_id: str | None = None
25
168
  ) -> CacheResult:
26
- """Ask question to Genie and return CacheResult (no caching at this level)."""
27
- response: GenieResponse = self.genie.ask_question(
28
- question, conversation_id=conversation_id
29
- )
169
+ """
170
+ Ask question to Genie and return CacheResult.
171
+
172
+ No caching at this level - returns cache miss with fresh response.
173
+ If using our extended Genie, the message_id will be captured in the response.
174
+ """
175
+ response = self.genie.ask_question(question, conversation_id=conversation_id)
176
+
177
+ # Extract message_id if available (from our extended GenieResponse)
178
+ message_id = getattr(response, "message_id", None)
179
+
30
180
  # No caching at this level - return cache miss
31
- return CacheResult(response=response, cache_hit=False, served_by=None)
181
+ return CacheResult(
182
+ response=response,
183
+ cache_hit=False,
184
+ served_by=None,
185
+ message_id=message_id,
186
+ )
32
187
 
33
188
  @property
34
189
  def space_id(self) -> str:
35
190
  return self.genie.space_id
191
+
192
+ @mlflow.trace(name="genie_send_feedback")
193
+ def send_feedback(
194
+ self,
195
+ conversation_id: str,
196
+ rating: GenieFeedbackRating,
197
+ message_id: str | None = None,
198
+ was_cache_hit: bool = False,
199
+ ) -> None:
200
+ """
201
+ Send feedback for a Genie message.
202
+
203
+ For the core GenieService, this always sends feedback to the Genie API
204
+ (the was_cache_hit parameter is ignored here - it's used by cache wrappers).
205
+
206
+ Args:
207
+ conversation_id: The conversation containing the message
208
+ rating: The feedback rating (POSITIVE, NEGATIVE, or NONE)
209
+ message_id: Optional message ID. If None, looks up the most recent message.
210
+ was_cache_hit: Ignored by GenieService. Cache wrappers use this to decide
211
+ whether to forward feedback to the underlying service.
212
+ """
213
+ # Look up message_id if not provided
214
+ if message_id is None:
215
+ message_id = get_latest_message_id(
216
+ workspace_client=self.workspace_client,
217
+ space_id=self.space_id,
218
+ conversation_id=conversation_id,
219
+ )
220
+ if message_id is None:
221
+ logger.warning(
222
+ "Could not find message_id for feedback, skipping",
223
+ space_id=self.space_id,
224
+ conversation_id=conversation_id,
225
+ rating=rating.value if rating else None,
226
+ )
227
+ return
228
+
229
+ logger.info(
230
+ "Sending feedback to Genie",
231
+ space_id=self.space_id,
232
+ conversation_id=conversation_id,
233
+ message_id=message_id,
234
+ rating=rating.value if rating else None,
235
+ )
236
+
237
+ try:
238
+ self.workspace_client.genie.send_message_feedback(
239
+ space_id=self.space_id,
240
+ conversation_id=conversation_id,
241
+ message_id=message_id,
242
+ rating=rating,
243
+ )
244
+ logger.debug(
245
+ "Feedback sent successfully",
246
+ space_id=self.space_id,
247
+ conversation_id=conversation_id,
248
+ message_id=message_id,
249
+ )
250
+ except Exception as e:
251
+ logger.error(
252
+ "Failed to send feedback to Genie",
253
+ space_id=self.space_id,
254
+ conversation_id=conversation_id,
255
+ message_id=message_id,
256
+ rating=rating.value if rating else None,
257
+ error=str(e),
258
+ exc_info=True,
259
+ )
@@ -1,5 +1,5 @@
1
1
  # DAO AI Middleware Module
2
- # This module provides middleware implementations compatible with LangChain v1's create_agent
2
+ # Middleware implementations compatible with LangChain v1's create_agent
3
3
 
4
4
  # Re-export LangChain built-in middleware
5
5
  from langchain.agents.middleware import (
@@ -82,6 +82,10 @@ from dao_ai.middleware.summarization import (
82
82
  create_summarization_middleware,
83
83
  )
84
84
  from dao_ai.middleware.tool_call_limit import create_tool_call_limit_middleware
85
+ from dao_ai.middleware.tool_call_observability import (
86
+ ToolCallObservabilityMiddleware,
87
+ create_tool_call_observability_middleware,
88
+ )
85
89
  from dao_ai.middleware.tool_retry import create_tool_retry_middleware
86
90
  from dao_ai.middleware.tool_selector import create_llm_tool_selector_middleware
87
91
 
@@ -160,4 +164,7 @@ __all__ = [
160
164
  "create_clear_tool_uses_edit",
161
165
  # PII middleware factory functions
162
166
  "create_pii_middleware",
167
+ # Tool call observability middleware
168
+ "ToolCallObservabilityMiddleware",
169
+ "create_tool_call_observability_middleware",
163
170
  ]