dao-ai 0.1.2__py3-none-any.whl → 0.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/apps/__init__.py +24 -0
- dao_ai/apps/handlers.py +105 -0
- dao_ai/apps/model_serving.py +29 -0
- dao_ai/apps/resources.py +1122 -0
- dao_ai/apps/server.py +39 -0
- dao_ai/cli.py +546 -37
- dao_ai/config.py +1179 -139
- dao_ai/evaluation.py +543 -0
- dao_ai/genie/__init__.py +55 -7
- dao_ai/genie/cache/__init__.py +34 -7
- dao_ai/genie/cache/base.py +143 -2
- dao_ai/genie/cache/context_aware/__init__.py +31 -0
- dao_ai/genie/cache/context_aware/base.py +1151 -0
- dao_ai/genie/cache/context_aware/in_memory.py +609 -0
- dao_ai/genie/cache/context_aware/persistent.py +802 -0
- dao_ai/genie/cache/context_aware/postgres.py +1166 -0
- dao_ai/genie/cache/core.py +1 -1
- dao_ai/genie/cache/lru.py +257 -75
- dao_ai/genie/cache/optimization.py +890 -0
- dao_ai/genie/core.py +235 -11
- dao_ai/memory/postgres.py +175 -39
- dao_ai/middleware/__init__.py +38 -0
- dao_ai/middleware/assertions.py +3 -3
- dao_ai/middleware/context_editing.py +230 -0
- dao_ai/middleware/core.py +4 -4
- dao_ai/middleware/guardrails.py +3 -3
- dao_ai/middleware/human_in_the_loop.py +3 -2
- dao_ai/middleware/message_validation.py +4 -4
- dao_ai/middleware/model_call_limit.py +77 -0
- dao_ai/middleware/model_retry.py +121 -0
- dao_ai/middleware/pii.py +157 -0
- dao_ai/middleware/summarization.py +1 -1
- dao_ai/middleware/tool_call_limit.py +210 -0
- dao_ai/middleware/tool_retry.py +174 -0
- dao_ai/middleware/tool_selector.py +129 -0
- dao_ai/models.py +327 -370
- dao_ai/nodes.py +9 -16
- dao_ai/orchestration/core.py +33 -9
- dao_ai/orchestration/supervisor.py +29 -13
- dao_ai/orchestration/swarm.py +6 -1
- dao_ai/{prompts.py → prompts/__init__.py} +12 -61
- dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
- dao_ai/prompts/instruction_reranker.yaml +14 -0
- dao_ai/prompts/router.yaml +37 -0
- dao_ai/prompts/verifier.yaml +46 -0
- dao_ai/providers/base.py +28 -2
- dao_ai/providers/databricks.py +363 -33
- dao_ai/state.py +1 -0
- dao_ai/tools/__init__.py +5 -3
- dao_ai/tools/genie.py +103 -26
- dao_ai/tools/instructed_retriever.py +366 -0
- dao_ai/tools/instruction_reranker.py +202 -0
- dao_ai/tools/mcp.py +539 -97
- dao_ai/tools/router.py +89 -0
- dao_ai/tools/slack.py +13 -2
- dao_ai/tools/sql.py +7 -3
- dao_ai/tools/unity_catalog.py +32 -10
- dao_ai/tools/vector_search.py +493 -160
- dao_ai/tools/verifier.py +159 -0
- dao_ai/utils.py +182 -2
- dao_ai/vector_search.py +46 -1
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/METADATA +45 -9
- dao_ai-0.1.20.dist-info/RECORD +89 -0
- dao_ai/agent_as_code.py +0 -22
- dao_ai/genie/cache/semantic.py +0 -970
- dao_ai-0.1.2.dist-info/RECORD +0 -64
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/licenses/LICENSE +0 -0
dao_ai/tools/vector_search.py
CHANGED
@@ -2,126 +2,54 @@
 Vector search tool for retrieving documents from Databricks Vector Search.
 
 This module provides a tool factory for creating semantic search tools
-with dynamic filter schemas based on table columns
+with dynamic filter schemas based on table columns, FlashRank reranking support,
+instructed retrieval with query decomposition and RRF merging, and optional
+query routing, result verification, and instruction-aware reranking.
 """
 
 import json
 import os
-from
+from concurrent.futures import ThreadPoolExecutor
+from typing import Annotated, Any, Literal, Optional
 
 import mlflow
 from databricks.sdk import WorkspaceClient
 from databricks.vector_search.reranker import DatabricksReranker
 from databricks_langchain import DatabricksVectorSearch
 from flashrank import Ranker, RerankRequest
+from langchain.tools import ToolRuntime, tool
 from langchain_core.documents import Document
 from langchain_core.tools import StructuredTool
 from loguru import logger
 from mlflow.entities import SpanType
-from pydantic import BaseModel, ConfigDict, Field, create_model
 
 from dao_ai.config import (
+    ColumnInfo,
+    FilterItem,
+    InstructedRetrieverModel,
     RerankParametersModel,
     RetrieverModel,
+    RouterModel,
     SearchParametersModel,
+    SearchQuery,
     VectorStoreModel,
+    VerificationResult,
+    VerifierModel,
     value_of,
 )
-from dao_ai.
-
-
-
-
-    key=(
-        str,
-        Field(
-            description="The filter key, which includes the column name and can include operators like 'NOT', '<', '>=', 'LIKE', 'OR'"
-        ),
-    ),
-    value=(
-        Any,
-        Field(
-            description="The filter value, which can be a single value or an array of values"
-        ),
-    ),
-    __config__=ConfigDict(extra="forbid"),
+from dao_ai.state import Context
+from dao_ai.tools.instructed_retriever import (
+    _get_cached_llm,
+    decompose_query,
+    rrf_merge,
 )
+from dao_ai.tools.instruction_reranker import instruction_aware_rerank
+from dao_ai.tools.router import route_query
+from dao_ai.tools.verifier import add_verification_metadata, verify_results
+from dao_ai.utils import is_in_model_serving, normalize_host
 
 
-
-    index_name: str, workspace_client: WorkspaceClient
-) -> type[BaseModel]:
-    """
-    Create dynamic input schema with column information from the table.
-
-    Args:
-        index_name: Full name of the vector search index
-        workspace_client: Workspace client to query table metadata
-
-    Returns:
-        Pydantic model class for tool input
-    """
-
-    # Try to get column information
-    column_descriptions = []
-    try:
-        table_info = workspace_client.tables.get(full_name=index_name)
-        for column_info in table_info.columns:
-            name = column_info.name
-            col_type = column_info.type_name.name
-            if not name.startswith("__"):
-                column_descriptions.append(f"{name} ({col_type})")
-    except Exception:
-        logger.debug(
-            "Could not retrieve column information for dynamic schema",
-            index=index_name,
-        )
-
-    # Build filter description matching VectorSearchRetrieverTool format
-    filter_description = (
-        "Optional filters to refine vector search results as an array of key-value pairs. "
-        "IMPORTANT: If unsure about filter values, try searching WITHOUT filters first to get broad results, "
-        "then optionally add filters to narrow down if needed. This ensures you don't miss relevant results due to incorrect filter values. "
-    )
-
-    if column_descriptions:
-        filter_description += (
-            f"Available columns for filtering: {', '.join(column_descriptions)}. "
-        )
-
-    filter_description += (
-        "Supports the following operators:\n\n"
-        '- Inclusion: [{"key": "column", "value": value}] or [{"key": "column", "value": [value1, value2]}] (matches if the column equals any of the provided values)\n'
-        '- Exclusion: [{"key": "column NOT", "value": value}]\n'
-        '- Comparisons: [{"key": "column <", "value": value}], [{"key": "column >=", "value": value}], etc.\n'
-        '- Pattern match: [{"key": "column LIKE", "value": "word"}] (matches full tokens separated by whitespace)\n'
-        '- OR logic: [{"key": "column1 OR column2", "value": [value1, value2]}] '
-        "(matches if column1 equals value1 or column2 equals value2; matches are position-specific)\n\n"
-        "Examples:\n"
-        '- Filter by category: [{"key": "category", "value": "electronics"}]\n'
-        '- Filter by price range: [{"key": "price >=", "value": 100}, {"key": "price <", "value": 500}]\n'
-        '- Exclude specific status: [{"key": "status NOT", "value": "archived"}]\n'
-        '- Pattern matching: [{"key": "description LIKE", "value": "wireless"}]'
-    )
-
-    # Create the input model
-    VectorSearchInput = create_model(
-        "VectorSearchInput",
-        query=(
-            str,
-            Field(description="The search query string to find relevant documents"),
-        ),
-        filters=(
-            Optional[list[FilterItem]],
-            Field(default=None, description=filter_description),
-        ),
-        __config__=ConfigDict(extra="forbid"),
-    )
-
-    return VectorSearchInput
-
-
-@mlflow.trace(name="rerank_documents", span_type=SpanType.RETRIEVER)
+@mlflow.trace(name="rerank_documents", span_type=SpanType.RERANKER)
 def _rerank_documents(
     query: str,
     documents: list[Document],
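The new imports above pull `decompose_query` and `rrf_merge` from `dao_ai.tools.instructed_retriever` (added in this release as its own file, +366 lines). `rrf_merge`'s body is not shown in this file's diff; as a reading aid, here is a minimal reciprocal-rank-fusion sketch consistent with the call site later in this file, `rrf_merge(all_results, k=instructed_config.rrf_k, primary_key=vector_store.primary_key)`. The signature and names come from the diff; the body and the `k=60` default are assumptions.

```python
# Hypothetical sketch of rrf_merge; the real implementation lives in
# dao_ai/tools/instructed_retriever.py and is not part of this diff.
# Standard RRF: score(doc) = sum over result lists of 1 / (k + rank).
from langchain_core.documents import Document


def rrf_merge(
    result_lists: list[list[Document]],
    k: int = 60,  # assumed default; the tool passes instructed_config.rrf_k
    primary_key: str = "id",
) -> list[Document]:
    scores: dict[str, float] = {}
    docs: dict[str, Document] = {}
    for results in result_lists:
        for rank, doc in enumerate(results, start=1):
            # Key documents by primary key so duplicates across subqueries merge
            doc_id = str(doc.metadata.get(primary_key, doc.page_content))
            scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
            docs.setdefault(doc_id, doc)
    return [docs[doc_id] for doc_id in sorted(scores, key=scores.get, reverse=True)]
```

With k = 60, a document ranked first in two subquery result lists scores 2/61 ≈ 0.0328 and outranks one placed first and third (1/61 + 1/63 ≈ 0.0323).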
@@ -146,6 +74,11 @@ def _rerank_documents(
         model=rerank_config.model,
     )
 
+    # Early return if no documents to rerank
+    if not documents:
+        logger.debug("No documents to rerank, skipping")
+        return documents
+
     # Prepare passages for reranking
     passages: list[dict[str, Any]] = [
         {"text": doc.page_content, "meta": doc.metadata} for doc in documents
@@ -231,9 +164,12 @@ def create_vector_search_tool(
         raise ValueError("vector_store.index is required for vector search")
 
     index_name: str = vector_store.index.full_name
-    columns: list[str] = list(retriever.columns or [])
+    columns: list[str] = list(retriever.columns or vector_store.index.columns or [])
     search_parameters: SearchParametersModel = retriever.search_parameters
+    router_config: Optional[RouterModel] = retriever.router
    rerank_config: Optional[RerankParametersModel] = retriever.rerank
+    instructed_config: Optional[InstructedRetrieverModel] = retriever.instructed
+    verifier_config: Optional[VerifierModel] = retriever.verifier
 
     # Initialize FlashRank ranker if configured
     ranker: Optional[Ranker] = None
@@ -244,50 +180,121 @@ def create_vector_search_tool(
             top_n=rerank_config.top_n or "auto",
         )
         try:
-
+            # Use /tmp for cache in Model Serving (home dir may not be writable)
+            if is_in_model_serving():
+                cache_dir = "/tmp/dao_ai/cache/flashrank"
+                if rerank_config.cache_dir != cache_dir:
+                    logger.warning(
+                        "FlashRank cache_dir overridden in Model Serving",
+                        configured=rerank_config.cache_dir,
+                        actual=cache_dir,
+                    )
+            else:
+                cache_dir = os.path.expanduser(rerank_config.cache_dir)
             ranker = Ranker(model_name=rerank_config.model, cache_dir=cache_dir)
+
+            # Patch rerank to always include token_type_ids for ONNX compatibility
+            # Some ONNX runtimes require token_type_ids even when the model doesn't use them
+            # FlashRank conditionally excludes them when all zeros, but ONNX may still expect them
+            # See: https://github.com/huggingface/optimum/issues/1500
+            if ranker.session is not None:
+                import numpy as np
+
+                _original_rerank = ranker.rerank
+
+                def _patched_rerank(request):
+                    query = request.query
+                    passages = request.passages
+                    query_passage_pairs = [[query, p["text"]] for p in passages]
+
+                    input_text = ranker.tokenizer.encode_batch(query_passage_pairs)
+                    input_ids = np.array([e.ids for e in input_text])
+                    token_type_ids = np.array([e.type_ids for e in input_text])
+                    attention_mask = np.array([e.attention_mask for e in input_text])
+
+                    # Always include token_type_ids (the fix for ONNX compatibility)
+                    onnx_input = {
+                        "input_ids": input_ids.astype(np.int64),
+                        "attention_mask": attention_mask.astype(np.int64),
+                        "token_type_ids": token_type_ids.astype(np.int64),
+                    }
+
+                    outputs = ranker.session.run(None, onnx_input)
+                    logits = outputs[0]
+
+                    if logits.shape[1] == 1:
+                        scores = 1 / (1 + np.exp(-logits.flatten()))
+                    else:
+                        exp_logits = np.exp(logits)
+                        scores = exp_logits[:, 1] / np.sum(exp_logits, axis=1)
+
+                    for score, passage in zip(scores, passages):
+                        passage["score"] = score
+
+                    passages.sort(key=lambda x: x["score"], reverse=True)
+                    return passages
+
+                ranker.rerank = _patched_rerank
+
             logger.success("FlashRank ranker initialized", model=rerank_config.model)
         except Exception as e:
             logger.warning("Failed to initialize FlashRank ranker", error=str(e))
             rerank_config = None
 
+    # Log instructed retrieval configuration
+    if instructed_config:
+        logger.success(
+            "Instructed retrieval configured",
+            decomposition_model=instructed_config.decomposition_model.name
+            if instructed_config.decomposition_model
+            else None,
+            max_subqueries=instructed_config.max_subqueries,
+            rrf_k=instructed_config.rrf_k,
+        )
+
+    # Log instruction-aware reranking configuration
+    if rerank_config and rerank_config.instruction_aware:
+        logger.success(
+            "Instruction-aware reranking configured",
+            model=rerank_config.instruction_aware.model.name
+            if rerank_config.instruction_aware.model
+            else None,
+            top_n=rerank_config.instruction_aware.top_n,
+        )
+
     # Build client_args for VectorSearchClient
-    # Use getattr to safely access attributes that may not exist (e.g., in mocks)
     client_args: dict[str, Any] = {}
     has_explicit_auth = any(
         [
             os.environ.get("DATABRICKS_TOKEN"),
             os.environ.get("DATABRICKS_CLIENT_ID"),
-
-
-
+            vector_store.pat,
+            vector_store.client_id,
+            vector_store.on_behalf_of_user,
         ]
     )
 
     if has_explicit_auth:
         databricks_host = os.environ.get("DATABRICKS_HOST")
-        if
-
-            and getattr(vector_store, "_workspace_client", None) is not None
-        ):
-            databricks_host = vector_store.workspace_client.config.host
+        if not databricks_host and vector_store.workspace_host:
+            databricks_host = value_of(vector_store.workspace_host)
         if databricks_host:
             client_args["workspace_url"] = normalize_host(databricks_host)
 
         token = os.environ.get("DATABRICKS_TOKEN")
-        if not token and
+        if not token and vector_store.pat:
             token = value_of(vector_store.pat)
         if token:
             client_args["personal_access_token"] = token
 
         client_id = os.environ.get("DATABRICKS_CLIENT_ID")
-        if not client_id and
+        if not client_id and vector_store.client_id:
             client_id = value_of(vector_store.client_id)
         if client_id:
             client_args["service_principal_client_id"] = client_id
 
         client_secret = os.environ.get("DATABRICKS_CLIENT_SECRET")
-        if not client_secret and
+        if not client_secret and vector_store.client_secret:
             client_secret = value_of(vector_store.client_secret)
         if client_secret:
             client_args["service_principal_client_secret"] = client_secret
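The try-block above swaps `ranker.rerank` for `_patched_rerank` in place, so downstream callers such as `_rerank_documents` need no changes. For reference, a standalone call exercising the same FlashRank `rerank` interface that the patch preserves; the model name and passages here are illustrative stand-ins, not values from this diff:

```python
# Illustrative FlashRank usage; model name and passages are stand-ins.
from flashrank import Ranker, RerankRequest

ranker = Ranker(
    model_name="ms-marco-MiniLM-L-12-v2",
    cache_dir="/tmp/dao_ai/cache/flashrank",
)
request = RerankRequest(
    query="wireless headphones",
    passages=[
        {"id": 1, "text": "Bluetooth over-ear headphones with noise cancelling"},
        {"id": 2, "text": "Wired desk lamp with a USB charging port"},
    ],
)
# rerank() returns the passages with a "score" field added, sorted best-first;
# the same contract _patched_rerank upholds while forcing token_type_ids.
results = ranker.rerank(request)
```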
@@ -299,71 +306,406 @@ def create_vector_search_tool(
         client_args_keys=list(client_args.keys()) if client_args else [],
     )
 
-    #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    )
-
+    # Cache for DatabricksVectorSearch - created lazily for OBO support
+    _cached_vector_search: DatabricksVectorSearch | None = None
+
+    def _get_vector_search(context: Context | None) -> DatabricksVectorSearch:
+        """Get or create DatabricksVectorSearch, using context for OBO auth if available."""
+        nonlocal _cached_vector_search
+
+        # Use cached instance if available and not OBO
+        if _cached_vector_search is not None and not vector_store.on_behalf_of_user:
+            return _cached_vector_search
+
+        # Get workspace client with OBO support via context
+        workspace_client: WorkspaceClient = vector_store.workspace_client_from(context)
+
+        # Create DatabricksVectorSearch
+        # Note: text_column should be None for Databricks-managed embeddings
+        # (it's automatically determined from the index)
+        vs: DatabricksVectorSearch = DatabricksVectorSearch(
+            index_name=index_name,
+            text_column=None,
+            columns=columns,
+            workspace_client=workspace_client,
+            client_args=client_args if client_args else None,
+            primary_key=vector_store.primary_key,
+            doc_uri=vector_store.doc_uri,
+            include_score=True,
+            reranker=(
+                DatabricksReranker(columns_to_rerank=rerank_config.columns)
+                if rerank_config and rerank_config.columns
+                else None
+            ),
+        )
 
-
-
-
-
+        # Cache for non-OBO scenarios
+        if not vector_store.on_behalf_of_user:
+            _cached_vector_search = vs
+
+        return vs
+
+    # Determine tool name and description
+    tool_name: str = name or f"vector_search_{vector_store.index.name}"
+
+    # Build tool description with available columns for filtering
+    base_description: str = description or f"Search documents in {index_name}"
+    if columns:
+        columns_list = ", ".join(columns)
+        tool_description = (
+            f"{base_description}. "
+            f"Available filter columns: {columns_list}. "
+            f"Filter operators: 'column' for equality, 'column NOT' for exclusion, "
+            f"'column <', 'column <=', 'column >', 'column >=' for comparison, "
+            f"'column LIKE' for token matching, 'column NOT LIKE' to exclude tokens."
+        )
+    else:
+        tool_description = base_description
+
+    @mlflow.trace(name="execute_instructed_retrieval", span_type=SpanType.RETRIEVER)
+    def _execute_instructed_retrieval(
+        vs: DatabricksVectorSearch,
+        query: str,
+        base_filters: dict[str, Any],
+        previous_feedback: str | None = None,
+    ) -> list[Document]:
+        """Execute instructed retrieval with query decomposition and RRF merging."""
+        logger.trace(
+            "Executing instructed retrieval", query=query, base_filters=base_filters
+        )
+        try:
+            decomposition_llm = _get_cached_llm(instructed_config.decomposition_model)
+
+            # Fall back to retriever columns if instructed columns not provided
+            instructed_columns: list[ColumnInfo] | None = instructed_config.columns
+            if instructed_columns is None and columns:
+                instructed_columns = [ColumnInfo(name=col) for col in columns]
+
+            subqueries: list[SearchQuery] = decompose_query(
+                llm=decomposition_llm,
+                query=query,
+                schema_description=instructed_config.schema_description,
+                constraints=instructed_config.constraints,
+                max_subqueries=instructed_config.max_subqueries,
+                examples=instructed_config.examples,
+                previous_feedback=previous_feedback,
+                columns=instructed_columns,
+            )
+
+            if not subqueries:
+                logger.warning(
+                    "Query decomposition returned no subqueries, using original"
+                )
+                return vs.similarity_search(
+                    query=query,
+                    k=search_parameters.num_results or 5,
+                    filter=base_filters if base_filters else None,
+                    query_type=search_parameters.query_type or "ANN",
+                )
+
+            def normalize_filter_values(
+                filters: dict[str, Any], case: str | None
+            ) -> dict[str, Any]:
+                """Normalize string filter values to specified case."""
+                logger.trace("Normalizing filter values", filters=filters, case=case)
+                if not case or not filters:
+                    return filters
+                normalized = {}
+                for key, value in filters.items():
+                    if isinstance(value, str):
+                        normalized[key] = (
+                            value.upper() if case == "uppercase" else value.lower()
+                        )
+                    elif isinstance(value, list):
+                        normalized[key] = [
+                            v.upper()
+                            if case == "uppercase"
+                            else v.lower()
+                            if isinstance(v, str)
+                            else v
+                            for v in value
+                        ]
+                    else:
+                        normalized[key] = value
+                return normalized
+
+            def execute_search(sq: SearchQuery) -> list[Document]:
+                logger.trace("Executing search", query=sq.text, filters=sq.filters)
+                # Convert FilterItem list to dict
+                sq_filters_dict: dict[str, Any] = {}
+                if sq.filters:
+                    for item in sq.filters:
+                        sq_filters_dict[item.key] = item.value
+                sq_filters = normalize_filter_values(
+                    sq_filters_dict, instructed_config.normalize_filter_case
+                )
+                k: int = search_parameters.num_results or 5
+                query_type: str = search_parameters.query_type or "ANN"
+                combined_filters: dict[str, Any] = {**sq_filters, **base_filters}
+                logger.trace(
+                    "Executing search",
+                    query=sq.text,
+                    k=k,
+                    query_type=query_type,
+                    filters=combined_filters,
+                )
+                return vs.similarity_search(
+                    query=sq.text,
+                    k=k,
+                    filter=combined_filters if combined_filters else None,
+                    query_type=query_type,
+                )
+
+            logger.debug(
+                "Executing parallel searches",
+                num_subqueries=len(subqueries),
+                queries=[sq.text[:50] for sq in subqueries],
+            )
+
+            with ThreadPoolExecutor(
+                max_workers=instructed_config.max_subqueries
+            ) as executor:
+                all_results = list(executor.map(execute_search, subqueries))
+
+            merged = rrf_merge(
+                all_results,
+                k=instructed_config.rrf_k,
+                primary_key=vector_store.primary_key,
+            )
+
+            logger.debug(
+                "Instructed retrieval complete",
+                num_subqueries=len(subqueries),
+                total_results=sum(len(r) for r in all_results),
+                merged_results=len(merged),
+            )
+
+            return merged
+
+        except Exception as e:
+            logger.warning(
+                "Instructed retrieval failed, falling back to standard search",
+                error=str(e),
+            )
+            return vs.similarity_search(
+                query=query,
+                k=search_parameters.num_results or 5,
+                filter=base_filters if base_filters else None,
+                query_type=search_parameters.query_type or "ANN",
+            )
 
-
-    def
-
+    @mlflow.trace(name="execute_standard_search", span_type=SpanType.RETRIEVER)
+    def _execute_standard_search(
+        vs: DatabricksVectorSearch,
+        query: str,
+        base_filters: dict[str, Any],
+    ) -> list[Document]:
+        """Execute standard single-query search."""
+        logger.trace("Performing standard vector search", query_preview=query[:50])
+        return vs.similarity_search(
+            query=query,
+            k=search_parameters.num_results or 5,
+            filter=base_filters if base_filters else None,
+            query_type=search_parameters.query_type or "ANN",
+        )
+
+    @mlflow.trace(name="apply_post_processing", span_type=SpanType.RETRIEVER)
+    def _apply_post_processing(
+        documents: list[Document],
+        query: str,
+        mode: Literal["standard", "instructed"],
+        auto_bypass: bool,
+    ) -> list[Document]:
+        """Apply instruction-aware reranking and verification based on mode and bypass settings."""
+        # Skip post-processing for standard mode when auto_bypass is enabled
+        if mode == "standard" and auto_bypass:
+            mlflow.set_tag("router.bypassed_stages", "true")
+            return documents
+
+        # Apply instruction-aware reranking if configured
+        if rerank_config and rerank_config.instruction_aware:
+            instruction_config = rerank_config.instruction_aware
+            instruction_llm = (
+                _get_cached_llm(instruction_config.model)
+                if instruction_config.model
+                else None
+            )
+
+            if instruction_llm:
+                schema_desc = (
+                    instructed_config.schema_description if instructed_config else None
+                )
+                # Get columns for dynamic instruction generation
+                rerank_columns: list[ColumnInfo] | None = None
+                if instructed_config and instructed_config.columns:
+                    rerank_columns = instructed_config.columns
+                elif columns:
+                    rerank_columns = [ColumnInfo(name=col) for col in columns]
+
+                documents = instruction_aware_rerank(
+                    llm=instruction_llm,
+                    query=query,
+                    documents=documents,
+                    instructions=instruction_config.instructions,
+                    schema_description=schema_desc,
+                    columns=rerank_columns,
+                    top_n=instruction_config.top_n,
+                )
+
+        # Apply verification if configured
+        if verifier_config:
+            verifier_llm = (
+                _get_cached_llm(verifier_config.model)
+                if verifier_config.model
+                else None
+            )
+
+            if verifier_llm:
+                schema_desc = (
+                    instructed_config.schema_description if instructed_config else ""
+                )
+                constraints = (
+                    instructed_config.constraints if instructed_config else None
+                )
+                retry_count = 0
+                verification_result: VerificationResult | None = None
+                previous_feedback: str | None = None
+
+                while retry_count <= verifier_config.max_retries:
+                    verification_result = verify_results(
+                        llm=verifier_llm,
+                        query=query,
+                        documents=documents,
+                        schema_description=schema_desc,
+                        constraints=constraints,
+                        previous_feedback=previous_feedback,
+                    )
+
+                    if verification_result.passed:
+                        mlflow.set_tag("verifier.outcome", "passed")
+                        mlflow.set_tag("verifier.retries", str(retry_count))
+                        break
+
+                    # Handle failure based on configuration
+                    if verifier_config.on_failure == "warn":
+                        mlflow.set_tag("verifier.outcome", "warned")
+                        documents = add_verification_metadata(
+                            documents, verification_result
+                        )
+                        break
+
+                    if retry_count >= verifier_config.max_retries:
+                        mlflow.set_tag("verifier.outcome", "exhausted")
+                        mlflow.set_tag("verifier.retries", str(retry_count))
+                        documents = add_verification_metadata(
+                            documents, verification_result, exhausted=True
+                        )
+                        break
+
+                    # Retry with feedback
+                    mlflow.set_tag("verifier.outcome", "retried")
+                    previous_feedback = verification_result.feedback
+                    retry_count += 1
+                    logger.debug(
+                        "Retrying search with verification feedback", retry=retry_count
+                    )
+
+        return documents
+
+    # Use @tool decorator for proper ToolRuntime injection
+    @tool(name_or_callable=tool_name, description=tool_description)
+    def _vector_search_tool(
+        query: Annotated[str, "The search query to find relevant documents"],
+        filters: Annotated[
+            Optional[list[FilterItem]],
+            "Optional filters as key-value pairs. "
+            "Key operators: 'column' (equality), 'column NOT' (exclusion), "
+            "'column <', '<=', '>', '>=' (comparison), "
+            "'column LIKE' (token match), 'column NOT LIKE' (exclude token). "
+            f"Valid columns: {', '.join(columns) if columns else 'none'}.",
+        ] = None,
+        runtime: ToolRuntime[Context] = None,
     ) -> str:
         """Search for relevant documents using vector similarity."""
-
+        context: Context | None = runtime.context if runtime else None
+        vs: DatabricksVectorSearch = _get_vector_search(context)
+
         filters_dict: dict[str, Any] = {}
         if filters:
             for item in filters:
                 filters_dict[item.key] = item.value
 
-
-        combined_filters: dict[str, Any] = {
+        base_filters: dict[str, Any] = {
             **filters_dict,
             **(search_parameters.filters or {}),
         }
 
-        #
-
-
-
-
-
-
+        # Determine execution mode via router or config
+        mode: Literal["standard", "instructed"] = "standard"
+        auto_bypass = True
+
+        logger.trace("Router configuration", router_config=router_config)
+        logger.trace("Instructed configuration", instructed_config=instructed_config)
+        logger.trace(
+            "Instruction-aware rerank configuration",
+            instruction_aware=rerank_config.instruction_aware
+            if rerank_config
+            else None,
         )
 
+        if router_config:
+            router_llm = (
+                _get_cached_llm(router_config.model) if router_config.model else None
+            )
+            auto_bypass = router_config.auto_bypass
+
+            if router_llm and instructed_config:
+                try:
+                    mode = route_query(
+                        llm=router_llm,
+                        query=query,
+                        schema_description=instructed_config.schema_description,
+                    )
+                except Exception as e:
+                    # Router fail-safe: default to standard mode
+                    logger.warning(
+                        "Router failed, defaulting to standard mode", error=str(e)
+                    )
+                    mlflow.set_tag("router.fallback", "true")
+                    mode = router_config.default_mode
+            else:
+                mode = router_config.default_mode
+        elif instructed_config:
+            # No router but instructed is configured - use instructed mode
+            mode = "instructed"
+            auto_bypass = False
+        elif rerank_config and rerank_config.instruction_aware:
+            # No router/instructed but instruction_aware reranking is configured
+            # Disable auto_bypass to ensure instruction_aware reranking runs
+            auto_bypass = False
+
+        logger.trace("Routing mode", mode=mode, auto_bypass=auto_bypass)
+        mlflow.set_tag("router.mode", mode)
+
+        # Execute search based on mode
+        if mode == "instructed" and instructed_config:
+            documents = _execute_instructed_retrieval(vs, query, base_filters)
+        else:
+            documents = _execute_standard_search(vs, query, base_filters)
+
         # Apply FlashRank reranking if configured
         if ranker and rerank_config:
             logger.debug("Applying FlashRank reranking")
             documents = _rerank_documents(query, documents, ranker, rerank_config)
 
+        # Apply post-processing (instruction reranking + verification)
+        documents = _apply_post_processing(documents, query, mode, auto_bypass)
+
         # Serialize documents to JSON format for LLM consumption
-        # Convert Document objects to dicts with page_content and metadata
-        # Need to handle numpy types in metadata (e.g., float32, int64)
         serialized_docs: list[dict[str, Any]] = []
         for doc in documents:
-            doc: Document
-            # Convert metadata values to JSON-serializable types
             metadata_serializable: dict[str, Any] = {}
             for key, value in doc.metadata.items():
-                # Handle numpy types
                 if hasattr(value, "item"):  # numpy scalar
                     metadata_serializable[key] = value.item()
                 else:
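The generated `_vector_search_tool` documents its filter grammar directly in the `Annotated` metadata above, so an agent invocation might look like the following. The operator spellings ('column', 'column NOT', 'price <', 'description LIKE') come from the tool description in this diff; the concrete column names and values are illustrative.

```python
# Illustrative tool arguments; operators follow the grammar documented in the
# tool description above, and the concrete columns/values are made up.
args = {
    "query": "wireless headphones",
    "filters": [
        {"key": "category", "value": "electronics"},        # equality
        {"key": "price <", "value": 100},                   # comparison
        {"key": "status NOT", "value": "archived"},         # exclusion
        {"key": "description LIKE", "value": "wireless"},   # token match
    ],
}
```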
@@ -376,17 +718,8 @@ def create_vector_search_tool(
             }
         )
 
-        # Return as JSON string
         return json.dumps(serialized_docs)
 
-
-    tool: StructuredTool = StructuredTool.from_function(
-        func=vector_search_func,
-        name=name or f"vector_search_{vector_store.index.name}",
-        description=description or f"Search documents in {index_name}",
-        args_schema=input_schema,
-    )
-
-    logger.success("Vector search tool created", name=tool.name, index=index_name)
+    logger.success("Vector search tool created", name=tool_name, index=index_name)
 
-    return
+    return _vector_search_tool