dao-ai 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/cli.py +12 -4
- dao_ai/config.py +471 -44
- dao_ai/evaluation.py +543 -0
- dao_ai/memory/postgres.py +146 -35
- dao_ai/orchestration/core.py +33 -9
- dao_ai/orchestration/supervisor.py +23 -8
- dao_ai/orchestration/swarm.py +6 -1
- dao_ai/{prompts.py → prompts/__init__.py} +10 -1
- dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
- dao_ai/prompts/instruction_reranker.yaml +14 -0
- dao_ai/prompts/router.yaml +37 -0
- dao_ai/prompts/verifier.yaml +46 -0
- dao_ai/providers/databricks.py +33 -12
- dao_ai/tools/instructed_retriever.py +366 -0
- dao_ai/tools/instruction_reranker.py +202 -0
- dao_ai/tools/mcp.py +21 -10
- dao_ai/tools/router.py +89 -0
- dao_ai/tools/vector_search.py +441 -134
- dao_ai/tools/verifier.py +159 -0
- dao_ai/utils.py +182 -2
- dao_ai/vector_search.py +9 -1
- {dao_ai-0.1.16.dist-info → dao_ai-0.1.18.dist-info}/METADATA +3 -3
- {dao_ai-0.1.16.dist-info → dao_ai-0.1.18.dist-info}/RECORD +26 -17
- {dao_ai-0.1.16.dist-info → dao_ai-0.1.18.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.16.dist-info → dao_ai-0.1.18.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.16.dist-info → dao_ai-0.1.18.dist-info}/licenses/LICENSE +0 -0
dao_ai/memory/postgres.py
CHANGED
@@ -3,6 +3,7 @@ import atexit
 import threading
 from typing import Any, Optional
 
+from databricks_ai_bridge.lakebase import AsyncLakebasePool, LakebasePool
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.checkpoint.postgres import ShallowPostgresSaver
 from langgraph.checkpoint.postgres.aio import AsyncShallowPostgresSaver

@@ -86,13 +87,22 @@ async def _create_async_pool(


 class AsyncPostgresPoolManager:
+    """
+    Asynchronous PostgreSQL connection pool manager that shares pools
+    based on database configuration.
+
+    For Lakebase connections (when instance_name is provided), uses AsyncLakebasePool
+    from databricks_ai_bridge which handles automatic token rotation and host resolution.
+    For standard PostgreSQL connections, uses psycopg_pool.AsyncConnectionPool.
+    """
+
     _pools: dict[str, AsyncConnectionPool] = {}
+    _lakebase_pools: dict[str, AsyncLakebasePool] = {}
     _lock: asyncio.Lock = asyncio.Lock()

     @classmethod
     async def get_pool(cls, database: DatabaseModel) -> AsyncConnectionPool:
         connection_key: str = database.name
-        connection_params: dict[str, Any] = database.connection_params

         async with cls._lock:
             if connection_key in cls._pools:

@@ -103,19 +113,43 @@ class AsyncPostgresPoolManager:

             logger.debug("Creating new async PostgreSQL pool", database=database.name)

-            … (13 removed lines not captured in this view)
+            if database.is_lakebase:
+                # Use AsyncLakebasePool for Lakebase connections
+                # AsyncLakebasePool handles automatic token rotation and host resolution
+                lakebase_pool = AsyncLakebasePool(
+                    instance_name=database.instance_name,
+                    workspace_client=database.workspace_client,
+                    min_size=1,
+                    max_size=database.max_pool_size,
+                    timeout=float(database.timeout_seconds),
+                )
+                # Open the async pool
+                await lakebase_pool.open()
+                # Store the AsyncLakebasePool for proper cleanup
+                cls._lakebase_pools[connection_key] = lakebase_pool
+                # Get the underlying AsyncConnectionPool
+                pool = lakebase_pool.pool
+                logger.success(
+                    "Async Lakebase connection pool created",
+                    database=database.name,
+                    instance_name=database.instance_name,
+                    pool_size=database.max_pool_size,
+                )
+            else:
+                # Use standard async PostgreSQL pool for non-Lakebase connections
+                connection_params: dict[str, Any] = database.connection_params
+                kwargs: dict[str, Any] = {
+                    "row_factory": dict_row,
+                    "autocommit": True,
+                } | database.connection_kwargs or {}
+
+                pool = await _create_async_pool(
+                    connection_params=connection_params,
+                    database_name=database.name,
+                    max_pool_size=database.max_pool_size,
+                    timeout_seconds=database.timeout_seconds,
+                    kwargs=kwargs,
+                )

             cls._pools[connection_key] = pool
             return pool

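The branch above means a single DatabaseModel can now be backed either by an AsyncLakebasePool or by a plain psycopg_pool.AsyncConnectionPool without callers changing. A minimal sketch of acquiring and exercising the shared pool; the check_database helper and the db object are illustrative, not code from the package:

    # Sketch only: `db` is assumed to be a DatabaseModel built from application config.
    from dao_ai.memory.postgres import AsyncPostgresPoolManager

    async def check_database(db) -> None:
        # The manager returns one shared AsyncConnectionPool per database name and
        # transparently routes to AsyncLakebasePool when db.is_lakebase is True.
        pool = await AsyncPostgresPoolManager.get_pool(db)
        async with pool.connection() as conn:
            await conn.execute("SELECT 1")  # cheap liveness check
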
@@ -125,7 +159,13 @@ class AsyncPostgresPoolManager:
         connection_key: str = database.name

         async with cls._lock:
-            if connection_key in cls._pools:
+            # Close AsyncLakebasePool if it exists (handles underlying pool cleanup)
+            if connection_key in cls._lakebase_pools:
+                lakebase_pool = cls._lakebase_pools.pop(connection_key)
+                await lakebase_pool.close()
+                cls._pools.pop(connection_key, None)
+                logger.debug("Async Lakebase pool closed", database=database.name)
+            elif connection_key in cls._pools:
                 pool = cls._pools.pop(connection_key)
                 await pool.close()
                 logger.debug("Async PostgreSQL pool closed", database=database.name)

@@ -133,9 +173,32 @@ class AsyncPostgresPoolManager:
     @classmethod
     async def close_all_pools(cls):
         async with cls._lock:
+            # Close all AsyncLakebasePool instances first
+            for connection_key, lakebase_pool in cls._lakebase_pools.items():
+                try:
+                    await asyncio.wait_for(lakebase_pool.close(), timeout=2.0)
+                    logger.debug("Async Lakebase pool closed", pool=connection_key)
+                except asyncio.TimeoutError:
+                    logger.warning(
+                        "Timeout closing async Lakebase pool, forcing closure",
+                        pool=connection_key,
+                    )
+                except asyncio.CancelledError:
+                    logger.warning(
+                        "Async Lakebase pool closure cancelled (shutdown in progress)",
+                        pool=connection_key,
+                    )
+                except Exception as e:
+                    logger.error(
+                        "Error closing async Lakebase pool",
+                        pool=connection_key,
+                        error=str(e),
+                    )
+            cls._lakebase_pools.clear()
+
+            # Close any remaining standard async PostgreSQL pools
             for connection_key, pool in cls._pools.items():
                 try:
-                    # Use a short timeout to avoid blocking on pool closure
                     await asyncio.wait_for(pool.close(), timeout=2.0)
                     logger.debug("Async PostgreSQL pool closed", pool=connection_key)
                 except asyncio.TimeoutError:

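Because close_all_pools now has to unwind both pool dictionaries, a process-exit hook is a natural caller; the module already imports atexit at the top. A hedged sketch of how such a hook could be wired (the actual registration inside dao_ai may differ):

    import asyncio
    import atexit

    from dao_ai.memory.postgres import AsyncPostgresPoolManager, PostgresPoolManager

    def _shutdown_pools() -> None:
        # Synchronous pools close directly; the async manager needs an event loop.
        PostgresPoolManager.close_all_pools()
        try:
            asyncio.run(AsyncPostgresPoolManager.close_all_pools())
        except RuntimeError:
            # An event loop is already running (e.g. inside a server); let its own
            # shutdown path close the async pools instead of blocking here.
            pass

    atexit.register(_shutdown_pools)
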
@@ -309,15 +372,19 @@ class PostgresPoolManager:
     """
     Synchronous PostgreSQL connection pool manager that shares pools
     based on database configuration.
+
+    For Lakebase connections (when instance_name is provided), uses LakebasePool
+    from databricks_ai_bridge which handles automatic token rotation and host resolution.
+    For standard PostgreSQL connections, uses psycopg_pool.ConnectionPool.
     """

     _pools: dict[str, ConnectionPool] = {}
+    _lakebase_pools: dict[str, LakebasePool] = {}
     _lock: threading.Lock = threading.Lock()

     @classmethod
     def get_pool(cls, database: DatabaseModel) -> ConnectionPool:
         connection_key: str = str(database.name)
-        connection_params: dict[str, Any] = database.connection_params

         with cls._lock:
             if connection_key in cls._pools:

@@ -326,19 +393,41 @@ class PostgresPoolManager:

             logger.debug("Creating new PostgreSQL pool", database=database.name)

-            … (13 removed lines not captured in this view)
+            if database.is_lakebase:
+                # Use LakebasePool for Lakebase connections
+                # LakebasePool handles automatic token rotation and host resolution
+                lakebase_pool = LakebasePool(
+                    instance_name=database.instance_name,
+                    workspace_client=database.workspace_client,
+                    min_size=1,
+                    max_size=database.max_pool_size,
+                    timeout=float(database.timeout_seconds),
+                )
+                # Store the LakebasePool for proper cleanup
+                cls._lakebase_pools[connection_key] = lakebase_pool
+                # Get the underlying ConnectionPool
+                pool = lakebase_pool.pool
+                logger.success(
+                    "Lakebase connection pool created",
+                    database=database.name,
+                    instance_name=database.instance_name,
+                    pool_size=database.max_pool_size,
+                )
+            else:
+                # Use standard PostgreSQL pool for non-Lakebase connections
+                connection_params: dict[str, Any] = database.connection_params
+                kwargs: dict[str, Any] = {
+                    "row_factory": dict_row,
+                    "autocommit": True,
+                } | database.connection_kwargs or {}
+
+                pool = _create_pool(
+                    connection_params=connection_params,
+                    database_name=database.name,
+                    max_pool_size=database.max_pool_size,
+                    timeout_seconds=database.timeout_seconds,
+                    kwargs=kwargs,
+                )

             cls._pools[connection_key] = pool
             return pool

@@ -348,7 +437,13 @@ class PostgresPoolManager:
         connection_key: str = database.name

         with cls._lock:
-            if connection_key in cls._pools:
+            # Close LakebasePool if it exists (handles underlying pool cleanup)
+            if connection_key in cls._lakebase_pools:
+                lakebase_pool = cls._lakebase_pools.pop(connection_key)
+                lakebase_pool.close()
+                cls._pools.pop(connection_key, None)
+                logger.debug("Lakebase pool closed", database=database.name)
+            elif connection_key in cls._pools:
                 pool = cls._pools.pop(connection_key)
                 pool.close()
                 logger.debug("PostgreSQL pool closed", database=database.name)

@@ -356,16 +451,32 @@ class PostgresPoolManager:
     @classmethod
     def close_all_pools(cls):
         with cls._lock:
-            for connection_key, pool in cls._pools.items():
+            # Close all LakebasePool instances first
+            for connection_key, lakebase_pool in cls._lakebase_pools.items():
                 try:
-                    pool.close()
-                    logger.debug("PostgreSQL pool closed", pool=connection_key)
+                    lakebase_pool.close()
+                    logger.debug("Lakebase pool closed", pool=connection_key)
                 except Exception as e:
                     logger.error(
-                        "Error closing PostgreSQL pool",
+                        "Error closing Lakebase pool",
                         pool=connection_key,
                         error=str(e),
                     )
+            cls._lakebase_pools.clear()
+
+            # Close any remaining standard PostgreSQL pools
+            for connection_key, pool in cls._pools.items():
+                # Skip if already closed via LakebasePool
+                if connection_key not in cls._lakebase_pools:
+                    try:
+                        pool.close()
+                        logger.debug("PostgreSQL pool closed", pool=connection_key)
+                    except Exception as e:
+                        logger.error(
+                            "Error closing PostgreSQL pool",
+                            pool=connection_key,
+                            error=str(e),
+                        )
             cls._pools.clear()


dao_ai/orchestration/core.py
CHANGED
@@ -9,7 +9,7 @@ This module provides the foundational utilities for multi-agent orchestration:
 - Main orchestration graph factory
 """

-from typing import Awaitable, Callable, Literal
+from typing import Any, Awaitable, Callable, Literal

 from langchain.tools import ToolRuntime, tool
 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage

@@ -179,8 +179,16 @@ def create_agent_node_handler(
             "messages": filtered_messages,
         }

-        # …
-        … (removed line not captured in this view)
+        # Build config with configurable from context for langmem compatibility
+        # langmem tools expect user_id to be in config.configurable
+        config: dict[str, Any] = {}
+        if runtime.context:
+            config = {"configurable": runtime.context.model_dump()}
+
+        # Invoke the agent with both context and config
+        result: AgentState = await agent.ainvoke(
+            agent_state, context=runtime.context, config=config
+        )

         # Extract agent response based on output mode
         result_messages = result.get("messages", [])

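The new config argument exists because langmem's memory tools read identifiers such as user_id from config["configurable"] rather than from the runtime context object itself. A small sketch of the mirroring step; build_invoke_config is a hypothetical helper, not a dao_ai function:

    from typing import Any

    def build_invoke_config(context: Any) -> dict[str, Any]:
        # Mirror the runtime context into config["configurable"] so langmem tools
        # can find fields like user_id; an empty dict is fine when there is no context.
        if context is None:
            return {}
        return {"configurable": context.model_dump()}

    # Usage (illustrative): await agent.ainvoke(state, context=ctx, config=build_invoke_config(ctx))
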
@@ -227,15 +235,31 @@ def create_handoff_tool(
        tool_call_id: str = runtime.tool_call_id
        logger.debug("Handoff to agent", target_agent=target_agent_name)

+        # Get the AIMessage that triggered this handoff (required for tool_use/tool_result pairing)
+        # LLMs expect tool calls to be paired with their responses, so we must include both
+        # the AIMessage containing the tool call and the ToolMessage acknowledging it.
+        messages: list[BaseMessage] = runtime.state.get("messages", [])
+        last_ai_message: AIMessage | None = None
+        for msg in reversed(messages):
+            if isinstance(msg, AIMessage) and msg.tool_calls:
+                last_ai_message = msg
+                break
+
+        # Build message list with proper pairing
+        update_messages: list[BaseMessage] = []
+        if last_ai_message:
+            update_messages.append(last_ai_message)
+        update_messages.append(
+            ToolMessage(
+                content=f"Transferred to {target_agent_name}",
+                tool_call_id=tool_call_id,
+            )
+        )
+
        return Command(
            update={
                "active_agent": target_agent_name,
-                "messages": [
-                    ToolMessage(
-                        content=f"Transferred to {target_agent_name}",
-                        tool_call_id=tool_call_id,
-                    )
-                ],
+                "messages": update_messages,
            },
            goto=target_agent_name,
            graph=Command.PARENT,

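The pairing rule the new comments describe can be shown in isolation: a ToolMessage is only valid when the transcript also contains the AIMessage whose tool call it answers. A self-contained sketch with invented message contents:

    from langchain_core.messages import AIMessage, ToolMessage

    # The AIMessage that requested the handoff carries the tool call...
    ai_msg = AIMessage(
        content="",
        tool_calls=[{"name": "transfer_to_research_agent", "args": {}, "id": "call_1"}],
    )
    # ...and the ToolMessage must reference the same tool_call_id, otherwise most
    # chat models reject the history as an unpaired tool_use/tool_result.
    tool_msg = ToolMessage(content="Transferred to research_agent", tool_call_id="call_1")

    update_messages = [ai_msg, tool_msg]  # the pair the handoff tool now emits
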

dao_ai/orchestration/supervisor.py
CHANGED
@@ -13,7 +13,7 @@ from langchain.agents import create_agent
 from langchain.agents.middleware import AgentMiddleware as LangchainAgentMiddleware
 from langchain.tools import ToolRuntime, tool
 from langchain_core.language_models import LanguageModelLike
-from langchain_core.messages import ToolMessage
+from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
 from langchain_core.tools import BaseTool
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.graph import StateGraph

@@ -75,15 +75,30 @@ def _create_handoff_back_to_supervisor_tool() -> BaseTool:
        tool_call_id: str = runtime.tool_call_id
        logger.debug("Agent handing back to supervisor", summary_preview=summary[:100])

+        # Get the AIMessage that triggered this handoff (required for tool_use/tool_result pairing)
+        # LLMs expect tool calls to be paired with their responses, so we must include both
+        # the AIMessage containing the tool call and the ToolMessage acknowledging it.
+        messages: list[BaseMessage] = runtime.state.get("messages", [])
+        last_ai_message: AIMessage | None = None
+        for msg in reversed(messages):
+            if isinstance(msg, AIMessage) and msg.tool_calls:
+                last_ai_message = msg
+                break
+
+        # Build message list with proper pairing
+        update_messages: list[BaseMessage] = []
+        if last_ai_message:
+            update_messages.append(last_ai_message)
+        update_messages.append(
+            ToolMessage(
+                content=f"Task completed: {summary}",
+                tool_call_id=tool_call_id,
+            )
+        )
+
        return Command(
            update={
-                "… (line truncated in this view)
-                "messages": [
-                    ToolMessage(
-                        content=f"Task completed: {summary}",
-                        tool_call_id=tool_call_id,
-                    )
-                ],
+                "messages": update_messages,
            },
            goto=SUPERVISOR_NODE,
            graph=Command.PARENT,

dao_ai/orchestration/swarm.py
CHANGED
@@ -167,8 +167,13 @@ def create_swarm_graph(config: AppConfig) -> CompiledStateGraph:
     default_agent: str
     if isinstance(swarm.default_agent, AgentModel):
         default_agent = swarm.default_agent.name
-    else:
+    elif swarm.default_agent is not None:
         default_agent = swarm.default_agent
+    elif len(config.app.agents) > 0:
+        # Fallback to first agent if no default specified
+        default_agent = config.app.agents[0].name
+    else:
+        raise ValueError("Swarm requires at least one agent and a default_agent")

     logger.info(
         "Creating swarm graph",


dao_ai/{prompts.py → prompts/__init__.py}
CHANGED
@@ -2,9 +2,11 @@
 Prompt utilities for DAO AI agents.

 This module provides utilities for creating dynamic prompts using
-LangChain v1's @dynamic_prompt middleware decorator pattern
+LangChain v1's @dynamic_prompt middleware decorator pattern, as well as
+paths to prompt template files.
 """

+from pathlib import Path
 from typing import Any, Optional

 from langchain.agents.middleware import (

@@ -18,6 +20,13 @@ from loguru import logger
 from dao_ai.config import PromptModel
 from dao_ai.state import Context

+PROMPTS_DIR = Path(__file__).parent
+
+
+def get_prompt_path(name: str) -> Path:
+    """Get the path to a prompt template file."""
+    return PROMPTS_DIR / name
+

 def make_prompt(
     base_system_prompt: Optional[str | PromptModel],

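The YAML prompt files added below are resolved through get_prompt_path. A minimal sketch of loading one of them, assuming the files keep a top-level template key and that PyYAML is available; dao_ai's actual loading code may differ:

    import yaml  # assumption: PyYAML is installed

    from dao_ai.prompts import get_prompt_path

    def load_prompt_template(name: str) -> str:
        # Read a prompt YAML shipped inside dao_ai/prompts and return its template text.
        with open(get_prompt_path(name), encoding="utf-8") as f:
            spec = yaml.safe_load(f)
        return spec["template"]

    router_template = load_prompt_template("router.yaml")
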

dao_ai/prompts/instructed_retriever_decomposition.yaml
ADDED
@@ -0,0 +1,58 @@
+name: instructed_retriever_decomposition
+description: Decomposes user queries into multiple search queries with metadata filters
+
+template: |
+  You are a search query decomposition expert. Your task is to break down a user query into one or more focused search queries with appropriate metadata filters. Respond with a JSON object.
+
+  ## Current Time
+  {current_time}
+
+  ## Database Schema
+  {schema_description}
+
+  ## Constraints
+  {constraints}
+
+  ## Few-Shot Examples
+  {examples}
+
+  ## Instructions
+  1. Analyze the user query and identify distinct search intents
+  2. For each intent, create a focused search query text
+  3. Extract metadata filters from the query using the exact filter syntax above
+  4. Resolve relative time references (e.g., "last month", "past year") using the current time
+  5. Generate at most {max_subqueries} search queries
+  6. If no filters apply, set filters to null
+
+  ## User Query
+  {query}
+
+  Generate search queries that together capture all aspects of the user's information need.
+
+variables:
+  - current_time
+  - schema_description
+  - constraints
+  - examples
+  - max_subqueries
+  - query
+
+output_format: |
+  The output must be a JSON object with a "queries" field containing an array of search query objects.
+  Each search query object has:
+  - "text": The search query string
+  - "filters": An array of filter objects, each with "key" (column + optional operator) and "value", or null if no filters
+
+  Supported filter operators (append to column name):
+  - Equality: {"key": "column", "value": "val"} or {"key": "column", "value": ["val1", "val2"]}
+  - Exclusion: {"key": "column NOT", "value": "val"}
+  - Comparison: {"key": "column <", "value": 100}, also <=, >, >=
+  - Token match: {"key": "column LIKE", "value": "word"}
+  - Exclude token: {"key": "column NOT LIKE", "value": "word"}
+
+  Examples:
+  - [{"key": "brand_name", "value": "MILWAUKEE"}]
+  - [{"key": "price <", "value": 100}]
+  - [{"key": "brand_name NOT", "value": "DEWALT"}]
+  - [{"key": "brand_name", "value": ["MILWAUKEE", "DEWALT"]}]
+  - [{"key": "description LIKE", "value": "cordless"}]

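For orientation, a decomposition of a query such as "Milwaukee drills under $100 and any cordless sander" might parse into the following shape; this is an invented example that simply follows the filter syntax above, not real output from the package:

    # Hypothetical parsed LLM response; values are illustrative only.
    decomposition = {
        "queries": [
            {
                "text": "drill",
                "filters": [
                    {"key": "brand_name", "value": "MILWAUKEE"},
                    {"key": "price <", "value": 100},
                ],
            },
            {
                "text": "cordless sander",
                "filters": None,  # no constraints apply to this subquery
            },
        ]
    }
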

dao_ai/prompts/instruction_reranker.yaml
ADDED
@@ -0,0 +1,14 @@
+name: instruction_aware_reranking
+version: "1.1"
+description: Rerank documents based on user instructions and constraints
+
+template: |
+  Rerank these search results for the query "{query}".
+
+  {instructions}
+
+  ## Documents
+
+  {documents}
+
+  Score each document 0.0-1.0 based on relevance to the query and instructions. Return results sorted by score (highest first). Only include documents scoring > 0.1.


dao_ai/prompts/router.yaml
ADDED
@@ -0,0 +1,37 @@
+name: router_query_classification
+version: "1.0"
+description: Classify query to determine execution mode (standard vs instructed)
+
+template: |
+  You are a query classification system. Your task is to determine the best execution mode for a search query.
+
+  ## Execution Modes
+
+  **standard**: Use for simple keyword or product searches without specific constraints.
+  - General questions about products
+  - Simple keyword searches
+  - Broad category browsing
+
+  **instructed**: Use for queries with explicit constraints that require metadata filtering.
+  - Price constraints ("under $100", "between $50 and $200")
+  - Brand preferences ("Milwaukee", "not DeWalt", "excluding Makita")
+  - Category filters ("power tools", "paint supplies")
+  - Time/recency constraints ("recent", "from last month", "updated this year")
+  - Comparison queries ("compare X and Y")
+  - Multiple combined constraints
+
+  ## Available Schema for Filtering
+
+  {schema_description}
+
+  ## Query to Classify
+
+  "{query}"
+
+  ## Instructions
+
+  Analyze the query and determine:
+  1. Does it contain explicit constraints that can be translated to metadata filters?
+  2. Would the query benefit from being decomposed into subqueries?
+
+  Return your classification as a JSON object with a single field "mode" set to either "standard" or "instructed".

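One plausible way to consume this classification prompt is to bind a tiny schema to the chat model so the mode field comes back typed. A sketch under that assumption; RouterDecision and classify_query are illustrative names, and llm stands for any LangChain chat model:

    from typing import Literal

    from pydantic import BaseModel

    class RouterDecision(BaseModel):
        mode: Literal["standard", "instructed"]

    def classify_query(llm, template: str, query: str, schema_description: str) -> str:
        # Fill the router template and ask the model for a structured {"mode": ...} answer.
        prompt = template.format(query=query, schema_description=schema_description)
        decision = llm.with_structured_output(RouterDecision).invoke(prompt)
        return decision.mode
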

dao_ai/prompts/verifier.yaml
ADDED
@@ -0,0 +1,46 @@
+name: result_verification
+version: "1.0"
+description: Verify search results satisfy user constraints
+
+template: |
+  You are a result verification system. Your task is to determine whether search results satisfy the user's query constraints.
+
+  ## User Query
+
+  "{query}"
+
+  ## Schema Information
+
+  {schema_description}
+
+  ## Constraints to Verify
+
+  {constraints}
+
+  ## Retrieved Results (Top {num_results})
+
+  {results_summary}
+
+  ## Previous Attempt Feedback (if retry)
+
+  {previous_feedback}
+
+  ## Instructions
+
+  Analyze whether the results satisfy the user's explicit and implicit constraints:
+
+  1. **Intent Match**: Do the results address what the user is looking for?
+  2. **Explicit Constraints**: Are price, brand, category, date constraints met?
+  3. **Relevance**: Are the results actually useful for the user's needs?
+
+  If results do NOT satisfy constraints, suggest specific filter relaxations:
+  - Use "REMOVE" to drop a filter entirely
+  - Use "BROADEN" to widen a range (e.g., price < 100 -> price < 150)
+  - Use specific values to change a filter
+
+  Return a JSON object with:
+  - passed: boolean (true if results are satisfactory)
+  - confidence: float (0.0-1.0, your confidence in the assessment)
+  - feedback: string (brief explanation of issues, if any)
+  - suggested_filter_relaxation: object (filter changes for retry, e.g., {{"brand_name": "REMOVE"}})
+  - unmet_constraints: array of strings (list of constraints not satisfied)

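As with the router, the verifier's JSON contract is easiest to see with a concrete invented response, here for a search whose results all exceeded a $100 price limit:

    # Hypothetical parsed verifier output; field names mirror the schema above.
    verification = {
        "passed": False,
        "confidence": 0.8,
        "feedback": "All returned products are priced above the requested $100 limit.",
        "suggested_filter_relaxation": {"price": "BROADEN"},
        "unmet_constraints": ["price under $100"],
    }

    if not verification["passed"]:
        # A retry loop could apply suggested_filter_relaxation and search again.
        print(verification["feedback"])
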
dao_ai/providers/databricks.py
CHANGED
@@ -397,6 +397,8 @@ class DatabricksProvider(ServiceProvider):

         pip_requirements += get_installed_packages()

+        code_paths = list(dict.fromkeys(code_paths))
+
         logger.trace("Pip requirements prepared", count=len(pip_requirements))
         logger.trace("Code paths prepared", count=len(code_paths))

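The dict.fromkeys call deserves a note: unlike converting to a set, it drops duplicate code paths while preserving first-seen order, which keeps MLflow's code_paths deterministic. A tiny illustration with invented paths:

    code_paths = ["src/dao_ai", "src/shared", "src/dao_ai"]  # duplicate entry
    deduped = list(dict.fromkeys(code_paths))
    print(deduped)  # ['src/dao_ai', 'src/shared'] -> order kept, duplicate removed
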
@@ -434,19 +436,38 @@ class DatabricksProvider(ServiceProvider):
             pip_packages_count=len(pip_requirements),
         )

-        … (12 removed lines not captured in this view)
+        # End any stale runs before starting to ensure clean state on retry
+        if mlflow.active_run():
+            logger.warning(
+                "Ending stale MLflow run before creating new agent",
+                run_id=mlflow.active_run().info.run_id,
+            )
+            mlflow.end_run()
+
+        try:
+            with mlflow.start_run(run_name=run_name):
+                mlflow.set_tag("type", "agent")
+                mlflow.set_tag("dao_ai", dao_ai_version())
+                logged_agent_info: ModelInfo = mlflow.pyfunc.log_model(
+                    python_model=model_path.as_posix(),
+                    code_paths=code_paths,
+                    model_config=config.model_dump(mode="json", by_alias=True),
+                    name="agent",
+                    conda_env=conda_env,
+                    input_example=input_example,
+                    # resources=all_resources,
+                    auth_policy=auth_policy,
+                )
+        except Exception as e:
+            # Ensure run is ended on failure to prevent stale state on retry
+            if mlflow.active_run():
+                mlflow.end_run(status="FAILED")
+            logger.error(
+                "Failed to log model",
+                run_name=run_name,
+                error=str(e),
             )
+            raise

         registered_model_name: str = config.app.registered_model.full_name
