dao-ai 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dao_ai/genie/cache/lru.py CHANGED
@@ -124,9 +124,7 @@ class LRUCacheService(GenieServiceBase):
         if self._cache:
             oldest_key: str = next(iter(self._cache))
             del self._cache[oldest_key]
-            logger.trace(
-                "Evicted cache entry", layer=self.name, key_prefix=oldest_key[:50]
-            )
+            logger.trace("Evicted cache entry", layer=self.name, key=oldest_key[:50])

     def _get(self, key: str) -> SQLCacheEntry | None:
         """Get from cache, returning None if not found or expired."""
@@ -137,7 +135,7 @@ class LRUCacheService(GenieServiceBase):

         if self._is_expired(entry):
             del self._cache[key]
-            logger.trace("Expired cache entry", layer=self.name, key_prefix=key[:50])
+            logger.trace("Expired cache entry", layer=self.name, key=key[:50])
             return None

         self._cache.move_to_end(key)
@@ -157,11 +155,11 @@ class LRUCacheService(GenieServiceBase):
             conversation_id=response.conversation_id,
             created_at=datetime.now(),
         )
-        logger.info(
+        logger.debug(
             "Stored cache entry",
             layer=self.name,
-            key_prefix=key[:50],
-            sql_prefix=response.query[:50] if response.query else None,
+            key=key[:50],
+            sql=response.query[:50] if response.query else None,
             cache_size=len(self._cache),
             capacity=self.capacity,
         )
@@ -180,7 +178,7 @@ class LRUCacheService(GenieServiceBase):
         w: WorkspaceClient = self.warehouse.workspace_client
         warehouse_id: str = str(self.warehouse.warehouse_id)

-        logger.trace("Executing cached SQL", layer=self.name, sql_prefix=sql[:100])
+        logger.trace("Executing cached SQL", layer=self.name, sql=sql[:100])

         statement_response: StatementResponse = w.statement_execution.execute_statement(
             statement=sql,
@@ -258,13 +256,17 @@ class LRUCacheService(GenieServiceBase):
         cached: SQLCacheEntry | None = self._get(key)

         if cached is not None:
+            cache_age_seconds = (datetime.now() - cached.created_at).total_seconds()
             logger.info(
                 "Cache HIT",
                 layer=self.name,
-                question_prefix=question[:50],
+                question=question[:80],
                 conversation_id=conversation_id,
+                cached_sql=cached.query[:80] if cached.query else None,
+                cache_age_seconds=round(cache_age_seconds, 1),
                 cache_size=self.size,
                 capacity=self.capacity,
+                ttl_seconds=self.parameters.time_to_live_seconds,
             )

             # Re-execute the cached SQL to get fresh data
@@ -286,17 +288,19 @@ class LRUCacheService(GenieServiceBase):
         logger.info(
             "Cache MISS",
             layer=self.name,
-            question_prefix=question[:50],
+            question=question[:80],
             conversation_id=conversation_id,
             cache_size=self.size,
             capacity=self.capacity,
+            ttl_seconds=self.parameters.time_to_live_seconds,
             delegating_to=type(self.impl).__name__,
         )

         result: CacheResult = self.impl.ask_question(question, conversation_id)
         with self._lock:
             self._put(key, result.response)
-        return CacheResult(response=result.response, cache_hit=False, served_by=None)
+        # Propagate the inner cache's result - if it was a hit there, preserve that info
+        return result

     @property
     def space_id(self) -> str:
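These hunks sit on top of a classic OrderedDict LRU: reads call move_to_end to refresh recency, and eviction deletes the first key in iteration order. A minimal sketch of that pattern, assuming a plain string-valued cache (the TinyLRU class is illustrative, not the package's API):

    from collections import OrderedDict

    class TinyLRU:
        """Illustrative sketch of the OrderedDict LRU pattern used above."""

        def __init__(self, capacity: int) -> None:
            self.capacity = capacity
            self._cache: OrderedDict[str, str] = OrderedDict()

        def get(self, key: str) -> str | None:
            if key not in self._cache:
                return None
            self._cache.move_to_end(key)  # reads refresh recency
            return self._cache[key]

        def put(self, key: str, value: str) -> None:
            self._cache[key] = value
            self._cache.move_to_end(key)
            if len(self._cache) > self.capacity:
                oldest_key = next(iter(self._cache))  # least recently used
                del self._cache[oldest_key]

Note also the behavioral fix in the Cache MISS hunk above: returning result rather than constructing a fresh CacheResult preserves the cache_hit and served_by values reported by a wrapped inner layer.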
@@ -497,6 +497,7 @@ class SemanticCacheService(GenieServiceBase):
         conversation_context: str,
         question_embedding: list[float],
         context_embedding: list[float],
+        conversation_id: str | None = None,
     ) -> tuple[SQLCacheEntry, float] | None:
         """
         Find a semantically similar cached entry using dual embedding matching.
@@ -509,6 +510,7 @@ class SemanticCacheService(GenieServiceBase):
             conversation_context: The conversation context string
             question_embedding: The embedding vector of just the question
             context_embedding: The embedding vector of the conversation context
+            conversation_id: Optional conversation ID (for logging)

         Returns:
             Tuple of (SQLCacheEntry, combined_similarity_score) if found, None otherwise
@@ -576,8 +578,9 @@ class SemanticCacheService(GenieServiceBase):
             logger.info(
                 "Cache MISS (no entries)",
                 layer=self.name,
-                question_prefix=question[:50],
+                question=question[:50],
                 space=self.space_id,
+                delegating_to=type(self.impl).__name__,
             )
             return None

@@ -602,8 +605,8 @@ class SemanticCacheService(GenieServiceBase):
                 context_sim=f"{context_similarity:.4f}",
                 combined_sim=f"{combined_similarity:.4f}",
                 is_valid=is_valid,
-                cached_question_prefix=cached_question[:50],
-                cached_context_prefix=cached_context[:80],
+                cached_question=cached_question[:50],
+                cached_context=cached_context[:80],
             )

             # Check BOTH similarity thresholds (dual embedding precision check)
@@ -613,6 +616,7 @@ class SemanticCacheService(GenieServiceBase):
                 layer=self.name,
                 question_sim=f"{question_similarity:.4f}",
                 threshold=self.parameters.similarity_threshold,
+                delegating_to=type(self.impl).__name__,
             )
             return None

@@ -622,6 +626,7 @@ class SemanticCacheService(GenieServiceBase):
                 layer=self.name,
                 context_sim=f"{context_similarity:.4f}",
                 threshold=self.parameters.context_similarity_threshold,
+                delegating_to=type(self.impl).__name__,
             )
             return None

@@ -635,17 +640,32 @@ class SemanticCacheService(GenieServiceBase):
                 layer=self.name,
                 combined_sim=f"{combined_similarity:.4f}",
                 ttl_seconds=ttl_seconds,
-                cached_question_prefix=cached_question[:50],
+                cached_question=cached_question[:50],
+                delegating_to=type(self.impl).__name__,
             )
             return None

+        from datetime import datetime as dt
+
+        cache_age_seconds = (
+            (dt.now(created_at.tzinfo) - created_at).total_seconds()
+            if created_at
+            else None
+        )
         logger.info(
             "Cache HIT",
             layer=self.name,
-            question_sim=f"{question_similarity:.4f}",
-            context_sim=f"{context_similarity:.4f}",
-            combined_sim=f"{combined_similarity:.4f}",
-            cached_question_prefix=cached_question[:50],
+            question=question[:80],
+            conversation_id=conversation_id,
+            matched_question=cached_question[:80],
+            cache_age_seconds=round(cache_age_seconds, 1)
+            if cache_age_seconds
+            else None,
+            question_similarity=f"{question_similarity:.4f}",
+            context_similarity=f"{context_similarity:.4f}",
+            combined_similarity=f"{combined_similarity:.4f}",
+            cached_sql=sql_query[:80] if sql_query else None,
+            ttl_seconds=self.parameters.time_to_live_seconds,
         )

         entry = SQLCacheEntry(
@@ -696,12 +716,12 @@ class SemanticCacheService(GenieServiceBase):
                 response.conversation_id,
             ),
         )
-        logger.info(
+        logger.debug(
             "Stored cache entry",
             layer=self.name,
-            question_prefix=question[:50],
-            context_prefix=conversation_context[:80],
-            sql_prefix=response.query[:50] if response.query else None,
+            question=question[:50],
+            context=conversation_context[:80],
+            sql=response.query[:50] if response.query else None,
             space=self.space_id,
             table=self.table_name,
         )
@@ -796,7 +816,11 @@ class SemanticCacheService(GenieServiceBase):

         # Check cache using dual embedding similarity
         cache_result: tuple[SQLCacheEntry, float] | None = self._find_similar(
-            question, conversation_context, question_embedding, context_embedding
+            question,
+            conversation_context,
+            question_embedding,
+            context_embedding,
+            conversation_id,
         )

         if cache_result is not None:
@@ -805,7 +829,8 @@ class SemanticCacheService(GenieServiceBase):
                 "Semantic cache hit",
                 layer=self.name,
                 combined_similarity=f"{combined_similarity:.3f}",
-                question_prefix=question[:50],
+                question=question[:50],
+                conversation_id=conversation_id,
             )

             # Re-execute the cached SQL to get fresh data
@@ -825,16 +850,25 @@ class SemanticCacheService(GenieServiceBase):
             return CacheResult(response=response, cache_hit=True, served_by=self.name)

         # Cache miss - delegate to wrapped service
-        logger.trace("Cache miss", layer=self.name, question_prefix=question[:50])
+        logger.info(
+            "Cache MISS",
+            layer=self.name,
+            question=question[:80],
+            conversation_id=conversation_id,
+            space_id=self.space_id,
+            similarity_threshold=self.similarity_threshold,
+            delegating_to=type(self.impl).__name__,
+        )

         result: CacheResult = self.impl.ask_question(question, conversation_id)

         # Store in cache if we got a SQL query
         if result.response.query:
-            logger.info(
+            logger.debug(
                 "Storing new cache entry",
                 layer=self.name,
-                question_prefix=question[:50],
+                question=question[:50],
+                conversation_id=conversation_id,
                 space=self.space_id,
             )
             self._store_entry(
@@ -848,7 +882,7 @@ class SemanticCacheService(GenieServiceBase):
             logger.warning(
                 "Not caching: response has no SQL query",
                 layer=self.name,
-                question_prefix=question[:50],
+                question=question[:50],
             )

         return CacheResult(response=result.response, cache_hit=False, served_by=None)
dao_ai/memory/postgres.py CHANGED
@@ -3,6 +3,7 @@ import atexit
 import threading
 from typing import Any, Optional

+from databricks_ai_bridge.lakebase import AsyncLakebasePool, LakebasePool
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.checkpoint.postgres import ShallowPostgresSaver
 from langgraph.checkpoint.postgres.aio import AsyncShallowPostgresSaver
@@ -86,13 +87,22 @@


 class AsyncPostgresPoolManager:
+    """
+    Asynchronous PostgreSQL connection pool manager that shares pools
+    based on database configuration.
+
+    For Lakebase connections (when instance_name is provided), uses AsyncLakebasePool
+    from databricks_ai_bridge which handles automatic token rotation and host resolution.
+    For standard PostgreSQL connections, uses psycopg_pool.AsyncConnectionPool.
+    """
+
     _pools: dict[str, AsyncConnectionPool] = {}
+    _lakebase_pools: dict[str, AsyncLakebasePool] = {}
     _lock: asyncio.Lock = asyncio.Lock()

     @classmethod
     async def get_pool(cls, database: DatabaseModel) -> AsyncConnectionPool:
         connection_key: str = database.name
-        connection_params: dict[str, Any] = database.connection_params

         async with cls._lock:
             if connection_key in cls._pools:
@@ -103,19 +113,43 @@ class AsyncPostgresPoolManager:

             logger.debug("Creating new async PostgreSQL pool", database=database.name)

-            kwargs: dict[str, Any] = {
-                "row_factory": dict_row,
-                "autocommit": True,
-            } | database.connection_kwargs or {}
-
-            # Create connection pool
-            pool: AsyncConnectionPool = await _create_async_pool(
-                connection_params=connection_params,
-                database_name=database.name,
-                max_pool_size=database.max_pool_size,
-                timeout_seconds=database.timeout_seconds,
-                kwargs=kwargs,
-            )
+            if database.is_lakebase:
+                # Use AsyncLakebasePool for Lakebase connections
+                # AsyncLakebasePool handles automatic token rotation and host resolution
+                lakebase_pool = AsyncLakebasePool(
+                    instance_name=database.instance_name,
+                    workspace_client=database.workspace_client,
+                    min_size=1,
+                    max_size=database.max_pool_size,
+                    timeout=float(database.timeout_seconds),
+                )
+                # Open the async pool
+                await lakebase_pool.open()
+                # Store the AsyncLakebasePool for proper cleanup
+                cls._lakebase_pools[connection_key] = lakebase_pool
+                # Get the underlying AsyncConnectionPool
+                pool = lakebase_pool.pool
+                logger.success(
+                    "Async Lakebase connection pool created",
+                    database=database.name,
+                    instance_name=database.instance_name,
+                    pool_size=database.max_pool_size,
+                )
+            else:
+                # Use standard async PostgreSQL pool for non-Lakebase connections
+                connection_params: dict[str, Any] = database.connection_params
+                kwargs: dict[str, Any] = {
+                    "row_factory": dict_row,
+                    "autocommit": True,
+                } | database.connection_kwargs or {}
+
+                pool = await _create_async_pool(
+                    connection_params=connection_params,
+                    database_name=database.name,
+                    max_pool_size=database.max_pool_size,
+                    timeout_seconds=database.timeout_seconds,
+                    kwargs=kwargs,
+                )

             cls._pools[connection_key] = pool
             return pool
@@ -125,7 +159,13 @@ class AsyncPostgresPoolManager:
         connection_key: str = database.name

         async with cls._lock:
-            if connection_key in cls._pools:
+            # Close AsyncLakebasePool if it exists (handles underlying pool cleanup)
+            if connection_key in cls._lakebase_pools:
+                lakebase_pool = cls._lakebase_pools.pop(connection_key)
+                await lakebase_pool.close()
+                cls._pools.pop(connection_key, None)
+                logger.debug("Async Lakebase pool closed", database=database.name)
+            elif connection_key in cls._pools:
                 pool = cls._pools.pop(connection_key)
                 await pool.close()
                 logger.debug("Async PostgreSQL pool closed", database=database.name)
@@ -133,9 +173,32 @@ class AsyncPostgresPoolManager:
     @classmethod
     async def close_all_pools(cls):
         async with cls._lock:
+            # Close all AsyncLakebasePool instances first
+            for connection_key, lakebase_pool in cls._lakebase_pools.items():
+                try:
+                    await asyncio.wait_for(lakebase_pool.close(), timeout=2.0)
+                    logger.debug("Async Lakebase pool closed", pool=connection_key)
+                except asyncio.TimeoutError:
+                    logger.warning(
+                        "Timeout closing async Lakebase pool, forcing closure",
+                        pool=connection_key,
+                    )
+                except asyncio.CancelledError:
+                    logger.warning(
+                        "Async Lakebase pool closure cancelled (shutdown in progress)",
+                        pool=connection_key,
+                    )
+                except Exception as e:
+                    logger.error(
+                        "Error closing async Lakebase pool",
+                        pool=connection_key,
+                        error=str(e),
+                    )
+            cls._lakebase_pools.clear()
+
+            # Close any remaining standard async PostgreSQL pools
             for connection_key, pool in cls._pools.items():
                 try:
-                    # Use a short timeout to avoid blocking on pool closure
                     await asyncio.wait_for(pool.close(), timeout=2.0)
                     logger.debug("Async PostgreSQL pool closed", pool=connection_key)
                 except asyncio.TimeoutError:
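The key shape in this change: callers always receive the inner psycopg AsyncConnectionPool, while the AsyncLakebasePool wrapper, which owns token rotation, is retained separately so shutdown can go through it. A hedged usage sketch, assuming a configured DatabaseModel named db; consult databricks_ai_bridge for the authoritative API:

    async def roundtrip(db) -> None:
        # get_pool() hands back the shared inner AsyncConnectionPool, whether
        # it was created directly or unwrapped from an AsyncLakebasePool
        pool = await AsyncPostgresPoolManager.get_pool(db)
        async with pool.connection() as conn:  # psycopg_pool API
            cur = await conn.execute("SELECT 1")
            print(await cur.fetchone())
        # At shutdown, close through the manager so Lakebase wrappers are
        # drained before any remaining plain pools
        await AsyncPostgresPoolManager.close_all_pools()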
@@ -309,15 +372,19 @@ class PostgresPoolManager:
     """
     Synchronous PostgreSQL connection pool manager that shares pools
     based on database configuration.
+
+    For Lakebase connections (when instance_name is provided), uses LakebasePool
+    from databricks_ai_bridge which handles automatic token rotation and host resolution.
+    For standard PostgreSQL connections, uses psycopg_pool.ConnectionPool.
     """

     _pools: dict[str, ConnectionPool] = {}
+    _lakebase_pools: dict[str, LakebasePool] = {}
     _lock: threading.Lock = threading.Lock()

     @classmethod
     def get_pool(cls, database: DatabaseModel) -> ConnectionPool:
         connection_key: str = str(database.name)
-        connection_params: dict[str, Any] = database.connection_params

         with cls._lock:
             if connection_key in cls._pools:
@@ -326,19 +393,41 @@ class PostgresPoolManager:

             logger.debug("Creating new PostgreSQL pool", database=database.name)

-            kwargs: dict[str, Any] = {
-                "row_factory": dict_row,
-                "autocommit": True,
-            } | database.connection_kwargs or {}
-
-            # Create connection pool
-            pool: ConnectionPool = _create_pool(
-                connection_params=connection_params,
-                database_name=database.name,
-                max_pool_size=database.max_pool_size,
-                timeout_seconds=database.timeout_seconds,
-                kwargs=kwargs,
-            )
+            if database.is_lakebase:
+                # Use LakebasePool for Lakebase connections
+                # LakebasePool handles automatic token rotation and host resolution
+                lakebase_pool = LakebasePool(
+                    instance_name=database.instance_name,
+                    workspace_client=database.workspace_client,
+                    min_size=1,
+                    max_size=database.max_pool_size,
+                    timeout=float(database.timeout_seconds),
+                )
+                # Store the LakebasePool for proper cleanup
+                cls._lakebase_pools[connection_key] = lakebase_pool
+                # Get the underlying ConnectionPool
+                pool = lakebase_pool.pool
+                logger.success(
+                    "Lakebase connection pool created",
+                    database=database.name,
+                    instance_name=database.instance_name,
+                    pool_size=database.max_pool_size,
+                )
+            else:
+                # Use standard PostgreSQL pool for non-Lakebase connections
+                connection_params: dict[str, Any] = database.connection_params
+                kwargs: dict[str, Any] = {
+                    "row_factory": dict_row,
+                    "autocommit": True,
+                } | database.connection_kwargs or {}
+
+                pool = _create_pool(
+                    connection_params=connection_params,
+                    database_name=database.name,
+                    max_pool_size=database.max_pool_size,
+                    timeout_seconds=database.timeout_seconds,
+                    kwargs=kwargs,
+                )

             cls._pools[connection_key] = pool
             return pool
@@ -348,7 +437,13 @@ class PostgresPoolManager:
         connection_key: str = database.name

         with cls._lock:
-            if connection_key in cls._pools:
+            # Close LakebasePool if it exists (handles underlying pool cleanup)
+            if connection_key in cls._lakebase_pools:
+                lakebase_pool = cls._lakebase_pools.pop(connection_key)
+                lakebase_pool.close()
+                cls._pools.pop(connection_key, None)
+                logger.debug("Lakebase pool closed", database=database.name)
+            elif connection_key in cls._pools:
                 pool = cls._pools.pop(connection_key)
                 pool.close()
                 logger.debug("PostgreSQL pool closed", database=database.name)
@@ -356,16 +451,32 @@ class PostgresPoolManager:
     @classmethod
     def close_all_pools(cls):
         with cls._lock:
-            for connection_key, pool in cls._pools.items():
+            # Close all LakebasePool instances first
+            for connection_key, lakebase_pool in cls._lakebase_pools.items():
                 try:
-                    pool.close()
-                    logger.debug("PostgreSQL pool closed", pool=connection_key)
+                    lakebase_pool.close()
+                    logger.debug("Lakebase pool closed", pool=connection_key)
                 except Exception as e:
                     logger.error(
-                        "Error closing PostgreSQL pool",
+                        "Error closing Lakebase pool",
                         pool=connection_key,
                         error=str(e),
                     )
+            cls._lakebase_pools.clear()
+
+            # Close any remaining standard PostgreSQL pools
+            for connection_key, pool in cls._pools.items():
+                # Skip if already closed via LakebasePool
+                if connection_key not in cls._lakebase_pools:
+                    try:
+                        pool.close()
+                        logger.debug("PostgreSQL pool closed", pool=connection_key)
+                    except Exception as e:
+                        logger.error(
+                            "Error closing PostgreSQL pool",
+                            pool=connection_key,
+                            error=str(e),
+                        )
             cls._pools.clear()

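The synchronous manager mirrors the async one. Since the module already imports atexit, one plausible wiring, shown here as an assumption rather than the package's actual registration, is to drain both registries at interpreter exit:

    import atexit

    def use_pool(db) -> None:  # db: a configured DatabaseModel (illustrative)
        pool = PostgresPoolManager.get_pool(db)  # shared per database.name
        with pool.connection() as conn:  # psycopg_pool.ConnectionPool API
            row = conn.execute("SELECT 1 AS one").fetchone()
            print(row)  # a dict on the standard path, via row_factory=dict_row

    # Hypothetical: not necessarily how dao-ai registers cleanup
    atexit.register(PostgresPoolManager.close_all_pools)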
@@ -9,7 +9,7 @@ This module provides the foundational utilities for multi-agent orchestration:
 - Main orchestration graph factory
 """

-from typing import Awaitable, Callable, Literal
+from typing import Any, Awaitable, Callable, Literal

 from langchain.tools import ToolRuntime, tool
 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
@@ -179,8 +179,16 @@ def create_agent_node_handler(
             "messages": filtered_messages,
         }

-        # Invoke the agent
-        result: AgentState = await agent.ainvoke(agent_state, context=runtime.context)
+        # Build config with configurable from context for langmem compatibility
+        # langmem tools expect user_id to be in config.configurable
+        config: dict[str, Any] = {}
+        if runtime.context:
+            config = {"configurable": runtime.context.model_dump()}
+
+        # Invoke the agent with both context and config
+        result: AgentState = await agent.ainvoke(
+            agent_state, context=runtime.context, config=config
+        )

         # Extract agent response based on output mode
         result_messages = result.get("messages", [])
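The motivation for threading config alongside context: langmem-style tools look identity up in config["configurable"], not in LangGraph's typed runtime context. A minimal sketch of the dict being built, assuming runtime.context is a Pydantic model with a user_id field (field names illustrative):

    from pydantic import BaseModel

    class Context(BaseModel):  # stand-in for dao_ai.state.Context
        user_id: str
        thread_id: str | None = None

    ctx = Context(user_id="u-123", thread_id="t-9")
    config = {"configurable": ctx.model_dump()}
    # {"configurable": {"user_id": "u-123", "thread_id": "t-9"}}
    # a langmem tool can then read config["configurable"]["user_id"]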
@@ -227,15 +235,31 @@
         tool_call_id: str = runtime.tool_call_id
         logger.debug("Handoff to agent", target_agent=target_agent_name)

+        # Get the AIMessage that triggered this handoff (required for tool_use/tool_result pairing)
+        # LLMs expect tool calls to be paired with their responses, so we must include both
+        # the AIMessage containing the tool call and the ToolMessage acknowledging it.
+        messages: list[BaseMessage] = runtime.state.get("messages", [])
+        last_ai_message: AIMessage | None = None
+        for msg in reversed(messages):
+            if isinstance(msg, AIMessage) and msg.tool_calls:
+                last_ai_message = msg
+                break
+
+        # Build message list with proper pairing
+        update_messages: list[BaseMessage] = []
+        if last_ai_message:
+            update_messages.append(last_ai_message)
+        update_messages.append(
+            ToolMessage(
+                content=f"Transferred to {target_agent_name}",
+                tool_call_id=tool_call_id,
+            )
+        )
+
         return Command(
             update={
                 "active_agent": target_agent_name,
-                "messages": [
-                    ToolMessage(
-                        content=f"Transferred to {target_agent_name}",
-                        tool_call_id=tool_call_id,
-                    )
-                ],
+                "messages": update_messages,
             },
             goto=target_agent_name,
             graph=Command.PARENT,
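In miniature, the update now carries the tool call and its result together, since chat providers reject a tool result whose call id has no preceding tool call in the visible history. A sketch with illustrative values:

    from langchain_core.messages import AIMessage, ToolMessage

    ai = AIMessage(
        content="",
        tool_calls=[{"name": "transfer_to_research", "args": {}, "id": "call_1"}],
    )
    tool = ToolMessage(content="Transferred to research", tool_call_id="call_1")
    update_messages = [ai, tool]  # tool call immediately followed by its result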
@@ -13,7 +13,7 @@ from langchain.agents import create_agent
 from langchain.agents.middleware import AgentMiddleware as LangchainAgentMiddleware
 from langchain.tools import ToolRuntime, tool
 from langchain_core.language_models import LanguageModelLike
-from langchain_core.messages import ToolMessage
+from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
 from langchain_core.tools import BaseTool
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.graph import StateGraph
@@ -75,15 +75,30 @@ def _create_handoff_back_to_supervisor_tool() -> BaseTool:
         tool_call_id: str = runtime.tool_call_id
         logger.debug("Agent handing back to supervisor", summary_preview=summary[:100])

+        # Get the AIMessage that triggered this handoff (required for tool_use/tool_result pairing)
+        # LLMs expect tool calls to be paired with their responses, so we must include both
+        # the AIMessage containing the tool call and the ToolMessage acknowledging it.
+        messages: list[BaseMessage] = runtime.state.get("messages", [])
+        last_ai_message: AIMessage | None = None
+        for msg in reversed(messages):
+            if isinstance(msg, AIMessage) and msg.tool_calls:
+                last_ai_message = msg
+                break
+
+        # Build message list with proper pairing
+        update_messages: list[BaseMessage] = []
+        if last_ai_message:
+            update_messages.append(last_ai_message)
+        update_messages.append(
+            ToolMessage(
+                content=f"Task completed: {summary}",
+                tool_call_id=tool_call_id,
+            )
+        )
+
         return Command(
             update={
-                "active_agent": None,
-                "messages": [
-                    ToolMessage(
-                        content=f"Task completed: {summary}",
-                        tool_call_id=tool_call_id,
-                    )
-                ],
+                "messages": update_messages,
             },
             goto=SUPERVISOR_NODE,
             graph=Command.PARENT,
@@ -2,9 +2,11 @@
 Prompt utilities for DAO AI agents.

 This module provides utilities for creating dynamic prompts using
-LangChain v1's @dynamic_prompt middleware decorator pattern.
+LangChain v1's @dynamic_prompt middleware decorator pattern, as well as
+paths to prompt template files.
 """

+from pathlib import Path
 from typing import Any, Optional

 from langchain.agents.middleware import (
@@ -18,6 +20,13 @@ from loguru import logger

 from dao_ai.config import PromptModel
 from dao_ai.state import Context
+PROMPTS_DIR = Path(__file__).parent
+
+
+def get_prompt_path(name: str) -> Path:
+    """Get the path to a prompt template file."""
+    return PROMPTS_DIR / name
+

 def make_prompt(
     base_system_prompt: Optional[str | PromptModel],
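A usage sketch for the new helper; the import path dao_ai.prompts and the template file name are assumptions for illustration:

    from dao_ai.prompts import get_prompt_path

    # Paths resolve relative to the module's own directory (PROMPTS_DIR)
    template_text = get_prompt_path("example_prompt.md").read_text()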