dao-ai 0.1.17__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/cli.py +8 -3
- dao_ai/config.py +414 -32
- dao_ai/evaluation.py +543 -0
- dao_ai/memory/postgres.py +146 -35
- dao_ai/orchestration/core.py +33 -9
- dao_ai/orchestration/supervisor.py +23 -8
- dao_ai/{prompts.py → prompts/__init__.py} +10 -1
- dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
- dao_ai/prompts/instruction_reranker.yaml +14 -0
- dao_ai/prompts/router.yaml +37 -0
- dao_ai/prompts/verifier.yaml +46 -0
- dao_ai/providers/databricks.py +33 -12
- dao_ai/tools/instructed_retriever.py +366 -0
- dao_ai/tools/instruction_reranker.py +202 -0
- dao_ai/tools/router.py +89 -0
- dao_ai/tools/vector_search.py +441 -134
- dao_ai/tools/verifier.py +159 -0
- dao_ai/utils.py +182 -2
- dao_ai/vector_search.py +9 -1
- {dao_ai-0.1.17.dist-info → dao_ai-0.1.18.dist-info}/METADATA +2 -2
- {dao_ai-0.1.17.dist-info → dao_ai-0.1.18.dist-info}/RECORD +24 -15
- {dao_ai-0.1.17.dist-info → dao_ai-0.1.18.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.17.dist-info → dao_ai-0.1.18.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.17.dist-info → dao_ai-0.1.18.dist-info}/licenses/LICENSE +0 -0
dao_ai/tools/vector_search.py
CHANGED
@@ -2,12 +2,15 @@
 Vector search tool for retrieving documents from Databricks Vector Search.
 
 This module provides a tool factory for creating semantic search tools
-with dynamic filter schemas based on table columns
+with dynamic filter schemas based on table columns, FlashRank reranking support,
+instructed retrieval with query decomposition and RRF merging, and optional
+query routing, result verification, and instruction-aware reranking.
 """
 
 import json
 import os
-from
+from concurrent.futures import ThreadPoolExecutor
+from typing import Annotated, Any, Literal, Optional
 
 import mlflow
 from databricks.sdk import WorkspaceClient
@@ -19,111 +22,34 @@ from langchain_core.documents import Document
 from langchain_core.tools import StructuredTool
 from loguru import logger
 from mlflow.entities import SpanType
-from pydantic import BaseModel, ConfigDict, Field, create_model
 
 from dao_ai.config import (
+    ColumnInfo,
+    FilterItem,
+    InstructedRetrieverModel,
     RerankParametersModel,
     RetrieverModel,
+    RouterModel,
     SearchParametersModel,
+    SearchQuery,
     VectorStoreModel,
+    VerificationResult,
+    VerifierModel,
     value_of,
 )
 from dao_ai.state import Context
-from dao_ai.
-
-
-
-    "FilterItem",
-    key=(
-        str,
-        Field(
-            description="The filter key, which includes the column name and can include operators like 'NOT', '<', '>=', 'LIKE', 'OR'"
-        ),
-    ),
-    value=(
-        Any,
-        Field(
-            description="The filter value, which can be a single value or an array of values"
-        ),
-    ),
-    __config__=ConfigDict(extra="forbid"),
+from dao_ai.tools.instructed_retriever import (
+    _get_cached_llm,
+    decompose_query,
+    rrf_merge,
 )
+from dao_ai.tools.instruction_reranker import instruction_aware_rerank
+from dao_ai.tools.router import route_query
+from dao_ai.tools.verifier import add_verification_metadata, verify_results
+from dao_ai.utils import is_in_model_serving, normalize_host
 
 
-
-    index_name: str, workspace_client: WorkspaceClient
-) -> type[BaseModel]:
-    """
-    Create dynamic input schema with column information from the table.
-
-    Args:
-        index_name: Full name of the vector search index
-        workspace_client: Workspace client to query table metadata
-
-    Returns:
-        Pydantic model class for tool input
-    """
-
-    # Try to get column information
-    column_descriptions = []
-    try:
-        table_info = workspace_client.tables.get(full_name=index_name)
-        for column_info in table_info.columns:
-            name = column_info.name
-            col_type = column_info.type_name.name
-            if not name.startswith("__"):
-                column_descriptions.append(f"{name} ({col_type})")
-    except Exception:
-        logger.debug(
-            "Could not retrieve column information for dynamic schema",
-            index=index_name,
-        )
-
-    # Build filter description matching VectorSearchRetrieverTool format
-    filter_description = (
-        "Optional filters to refine vector search results as an array of key-value pairs. "
-        "IMPORTANT: If unsure about filter values, try searching WITHOUT filters first to get broad results, "
-        "then optionally add filters to narrow down if needed. This ensures you don't miss relevant results due to incorrect filter values. "
-    )
-
-    if column_descriptions:
-        filter_description += (
-            f"Available columns for filtering: {', '.join(column_descriptions)}. "
-        )
-
-    filter_description += (
-        "Supports the following operators:\n\n"
-        '- Inclusion: [{"key": "column", "value": value}] or [{"key": "column", "value": [value1, value2]}] (matches if the column equals any of the provided values)\n'
-        '- Exclusion: [{"key": "column NOT", "value": value}]\n'
-        '- Comparisons: [{"key": "column <", "value": value}], [{"key": "column >=", "value": value}], etc.\n'
-        '- Pattern match: [{"key": "column LIKE", "value": "word"}] (matches full tokens separated by whitespace)\n'
-        '- OR logic: [{"key": "column1 OR column2", "value": [value1, value2]}] '
-        "(matches if column1 equals value1 or column2 equals value2; matches are position-specific)\n\n"
-        "Examples:\n"
-        '- Filter by category: [{"key": "category", "value": "electronics"}]\n'
-        '- Filter by price range: [{"key": "price >=", "value": 100}, {"key": "price <", "value": 500}]\n'
-        '- Exclude specific status: [{"key": "status NOT", "value": "archived"}]\n'
-        '- Pattern matching: [{"key": "description LIKE", "value": "wireless"}]'
-    )
-
-    # Create the input model
-    VectorSearchInput = create_model(
-        "VectorSearchInput",
-        query=(
-            str,
-            Field(description="The search query string to find relevant documents"),
-        ),
-        filters=(
-            Optional[list[FilterItem]],
-            Field(default=None, description=filter_description),
-        ),
-        __config__=ConfigDict(extra="forbid"),
-    )
-
-    return VectorSearchInput
-
-
-@mlflow.trace(name="rerank_documents", span_type=SpanType.RETRIEVER)
+@mlflow.trace(name="rerank_documents", span_type=SpanType.RERANKER)
 def _rerank_documents(
     query: str,
     documents: list[Document],
@@ -148,6 +74,11 @@ def _rerank_documents(
         model=rerank_config.model,
     )
 
+    # Early return if no documents to rerank
+    if not documents:
+        logger.debug("No documents to rerank, skipping")
+        return documents
+
     # Prepare passages for reranking
     passages: list[dict[str, Any]] = [
         {"text": doc.page_content, "meta": doc.metadata} for doc in documents
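Note: the early-return guard above runs before FlashRank is invoked on an empty result set. For context, a minimal sketch of the FlashRank API that `_rerank_documents` wraps, assuming the `flashrank` package is installed; the model name, cache path, and sample passages here are illustrative, not taken from this package:

    from flashrank import Ranker, RerankRequest

    # Cross-encoder reranker; downloads the ONNX model into cache_dir on first use.
    ranker = Ranker(model_name="ms-marco-MiniLM-L-12-v2", cache_dir="/tmp/flashrank")

    passages = [
        {"text": "Databricks Vector Search indexes Delta tables.", "meta": {"source": "a"}},
        {"text": "FlashRank reranks passages with a local cross-encoder.", "meta": {"source": "b"}},
    ]
    request = RerankRequest(query="how does reranking work?", passages=passages)
    ranked = ranker.rerank(request)  # passages annotated with "score", best first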
@@ -233,9 +164,12 @@ def create_vector_search_tool(
         raise ValueError("vector_store.index is required for vector search")
 
     index_name: str = vector_store.index.full_name
-    columns: list[str] = list(retriever.columns or [])
+    columns: list[str] = list(retriever.columns or vector_store.index.columns or [])
     search_parameters: SearchParametersModel = retriever.search_parameters
+    router_config: Optional[RouterModel] = retriever.router
    rerank_config: Optional[RerankParametersModel] = retriever.rerank
+    instructed_config: Optional[InstructedRetrieverModel] = retriever.instructed
+    verifier_config: Optional[VerifierModel] = retriever.verifier
 
     # Initialize FlashRank ranker if configured
     ranker: Optional[Ranker] = None
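Note: `router`, `instructed`, and `verifier` are new optional blocks on the retriever config, alongside the existing `rerank` block. A hypothetical wiring sketch using the model classes imported above; the authoritative schema lives in dao_ai/config.py (changed +414 -32 in this release), so constructor fields and types may differ:

    # Hypothetical sketch only; `my_llm` is a placeholder for an LLM model
    # reference defined elsewhere in the configuration.
    my_llm = None
    retriever = RetrieverModel(
        columns=["category", "price", "status"],
        search_parameters=SearchParametersModel(num_results=10, query_type="ANN"),
        rerank=RerankParametersModel(model="ms-marco-MiniLM-L-12-v2", top_n=5),
        router=RouterModel(model=my_llm, default_mode="standard", auto_bypass=True),
        instructed=InstructedRetrieverModel(
            decomposition_model=my_llm, max_subqueries=4, rrf_k=60
        ),
        verifier=VerifierModel(model=my_llm, max_retries=1, on_failure="warn"),
    )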
@@ -246,50 +180,121 @@ def create_vector_search_tool(
             top_n=rerank_config.top_n or "auto",
         )
         try:
-
+            # Use /tmp for cache in Model Serving (home dir may not be writable)
+            if is_in_model_serving():
+                cache_dir = "/tmp/dao_ai/cache/flashrank"
+                if rerank_config.cache_dir != cache_dir:
+                    logger.warning(
+                        "FlashRank cache_dir overridden in Model Serving",
+                        configured=rerank_config.cache_dir,
+                        actual=cache_dir,
+                    )
+            else:
+                cache_dir = os.path.expanduser(rerank_config.cache_dir)
             ranker = Ranker(model_name=rerank_config.model, cache_dir=cache_dir)
+
+            # Patch rerank to always include token_type_ids for ONNX compatibility
+            # Some ONNX runtimes require token_type_ids even when the model doesn't use them
+            # FlashRank conditionally excludes them when all zeros, but ONNX may still expect them
+            # See: https://github.com/huggingface/optimum/issues/1500
+            if ranker.session is not None:
+                import numpy as np
+
+                _original_rerank = ranker.rerank
+
+                def _patched_rerank(request):
+                    query = request.query
+                    passages = request.passages
+                    query_passage_pairs = [[query, p["text"]] for p in passages]
+
+                    input_text = ranker.tokenizer.encode_batch(query_passage_pairs)
+                    input_ids = np.array([e.ids for e in input_text])
+                    token_type_ids = np.array([e.type_ids for e in input_text])
+                    attention_mask = np.array([e.attention_mask for e in input_text])
+
+                    # Always include token_type_ids (the fix for ONNX compatibility)
+                    onnx_input = {
+                        "input_ids": input_ids.astype(np.int64),
+                        "attention_mask": attention_mask.astype(np.int64),
+                        "token_type_ids": token_type_ids.astype(np.int64),
+                    }
+
+                    outputs = ranker.session.run(None, onnx_input)
+                    logits = outputs[0]
+
+                    if logits.shape[1] == 1:
+                        scores = 1 / (1 + np.exp(-logits.flatten()))
+                    else:
+                        exp_logits = np.exp(logits)
+                        scores = exp_logits[:, 1] / np.sum(exp_logits, axis=1)
+
+                    for score, passage in zip(scores, passages):
+                        passage["score"] = score
+
+                    passages.sort(key=lambda x: x["score"], reverse=True)
+                    return passages
+
+                ranker.rerank = _patched_rerank
+
             logger.success("FlashRank ranker initialized", model=rerank_config.model)
         except Exception as e:
             logger.warning("Failed to initialize FlashRank ranker", error=str(e))
             rerank_config = None
 
+    # Log instructed retrieval configuration
+    if instructed_config:
+        logger.success(
+            "Instructed retrieval configured",
+            decomposition_model=instructed_config.decomposition_model.name
+            if instructed_config.decomposition_model
+            else None,
+            max_subqueries=instructed_config.max_subqueries,
+            rrf_k=instructed_config.rrf_k,
+        )
+
+    # Log instruction-aware reranking configuration
+    if rerank_config and rerank_config.instruction_aware:
+        logger.success(
+            "Instruction-aware reranking configured",
+            model=rerank_config.instruction_aware.model.name
+            if rerank_config.instruction_aware.model
+            else None,
+            top_n=rerank_config.instruction_aware.top_n,
+        )
+
     # Build client_args for VectorSearchClient
-    # Use getattr to safely access attributes that may not exist (e.g., in mocks)
     client_args: dict[str, Any] = {}
     has_explicit_auth = any(
         [
             os.environ.get("DATABRICKS_TOKEN"),
             os.environ.get("DATABRICKS_CLIENT_ID"),
-
-
-
+            vector_store.pat,
+            vector_store.client_id,
+            vector_store.on_behalf_of_user,
         ]
     )
 
     if has_explicit_auth:
         databricks_host = os.environ.get("DATABRICKS_HOST")
-        if
-
-            and getattr(vector_store, "_workspace_client", None) is not None
-        ):
-            databricks_host = vector_store.workspace_client.config.host
+        if not databricks_host and vector_store.workspace_host:
+            databricks_host = value_of(vector_store.workspace_host)
         if databricks_host:
             client_args["workspace_url"] = normalize_host(databricks_host)
 
         token = os.environ.get("DATABRICKS_TOKEN")
-        if not token and
+        if not token and vector_store.pat:
             token = value_of(vector_store.pat)
         if token:
             client_args["personal_access_token"] = token
 
         client_id = os.environ.get("DATABRICKS_CLIENT_ID")
-        if not client_id and
+        if not client_id and vector_store.client_id:
             client_id = value_of(vector_store.client_id)
         if client_id:
             client_args["service_principal_client_id"] = client_id
 
         client_secret = os.environ.get("DATABRICKS_CLIENT_SECRET")
-        if not client_secret and
+        if not client_secret and vector_store.client_secret:
             client_secret = value_of(vector_store.client_secret)
         if client_secret:
             client_args["service_principal_client_secret"] = client_secret
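Note on the scoring inside `_patched_rerank` above: single-logit cross-encoder heads are squashed with a sigmoid, while two-logit heads take the positive-class softmax probability. A standalone sketch of just that normalization, runnable with only numpy; the sample logits are illustrative:

    import numpy as np

    def scores_from_logits(logits: np.ndarray) -> np.ndarray:
        # Mirrors the branch in _patched_rerank: sigmoid for (N, 1) logits,
        # positive-class softmax for (N, 2) logits.
        if logits.shape[1] == 1:
            return 1 / (1 + np.exp(-logits.flatten()))
        exp_logits = np.exp(logits)
        return exp_logits[:, 1] / np.sum(exp_logits, axis=1)

    print(scores_from_logits(np.array([[2.0], [-2.0]])))           # ~[0.88, 0.12]
    print(scores_from_logits(np.array([[0.0, 2.0], [2.0, 0.0]])))  # ~[0.88, 0.12]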
@@ -342,62 +347,365 @@ def create_vector_search_tool(
 
     # Determine tool name and description
     tool_name: str = name or f"vector_search_{vector_store.index.name}"
-
+
+    # Build tool description with available columns for filtering
+    base_description: str = description or f"Search documents in {index_name}"
+    if columns:
+        columns_list = ", ".join(columns)
+        tool_description = (
+            f"{base_description}. "
+            f"Available filter columns: {columns_list}. "
+            f"Filter operators: 'column' for equality, 'column NOT' for exclusion, "
+            f"'column <', 'column <=', 'column >', 'column >=' for comparison, "
+            f"'column LIKE' for token matching, 'column NOT LIKE' to exclude tokens."
+        )
+    else:
+        tool_description = base_description
+
+    @mlflow.trace(name="execute_instructed_retrieval", span_type=SpanType.RETRIEVER)
+    def _execute_instructed_retrieval(
+        vs: DatabricksVectorSearch,
+        query: str,
+        base_filters: dict[str, Any],
+        previous_feedback: str | None = None,
+    ) -> list[Document]:
+        """Execute instructed retrieval with query decomposition and RRF merging."""
+        logger.trace(
+            "Executing instructed retrieval", query=query, base_filters=base_filters
+        )
+        try:
+            decomposition_llm = _get_cached_llm(instructed_config.decomposition_model)
+
+            # Fall back to retriever columns if instructed columns not provided
+            instructed_columns: list[ColumnInfo] | None = instructed_config.columns
+            if instructed_columns is None and columns:
+                instructed_columns = [ColumnInfo(name=col) for col in columns]
+
+            subqueries: list[SearchQuery] = decompose_query(
+                llm=decomposition_llm,
+                query=query,
+                schema_description=instructed_config.schema_description,
+                constraints=instructed_config.constraints,
+                max_subqueries=instructed_config.max_subqueries,
+                examples=instructed_config.examples,
+                previous_feedback=previous_feedback,
+                columns=instructed_columns,
+            )
+
+            if not subqueries:
+                logger.warning(
+                    "Query decomposition returned no subqueries, using original"
+                )
+                return vs.similarity_search(
+                    query=query,
+                    k=search_parameters.num_results or 5,
+                    filter=base_filters if base_filters else None,
+                    query_type=search_parameters.query_type or "ANN",
+                )
+
+            def normalize_filter_values(
+                filters: dict[str, Any], case: str | None
+            ) -> dict[str, Any]:
+                """Normalize string filter values to specified case."""
+                logger.trace("Normalizing filter values", filters=filters, case=case)
+                if not case or not filters:
+                    return filters
+                normalized = {}
+                for key, value in filters.items():
+                    if isinstance(value, str):
+                        normalized[key] = (
+                            value.upper() if case == "uppercase" else value.lower()
+                        )
+                    elif isinstance(value, list):
+                        normalized[key] = [
+                            v.upper()
+                            if case == "uppercase"
+                            else v.lower()
+                            if isinstance(v, str)
+                            else v
+                            for v in value
+                        ]
+                    else:
+                        normalized[key] = value
+                return normalized
+
+            def execute_search(sq: SearchQuery) -> list[Document]:
+                logger.trace("Executing search", query=sq.text, filters=sq.filters)
+                # Convert FilterItem list to dict
+                sq_filters_dict: dict[str, Any] = {}
+                if sq.filters:
+                    for item in sq.filters:
+                        sq_filters_dict[item.key] = item.value
+                sq_filters = normalize_filter_values(
+                    sq_filters_dict, instructed_config.normalize_filter_case
+                )
+                k: int = search_parameters.num_results or 5
+                query_type: str = search_parameters.query_type or "ANN"
+                combined_filters: dict[str, Any] = {**sq_filters, **base_filters}
+                logger.trace(
+                    "Executing search",
+                    query=sq.text,
+                    k=k,
+                    query_type=query_type,
+                    filters=combined_filters,
+                )
+                return vs.similarity_search(
+                    query=sq.text,
+                    k=k,
+                    filter=combined_filters if combined_filters else None,
+                    query_type=query_type,
+                )
+
+            logger.debug(
+                "Executing parallel searches",
+                num_subqueries=len(subqueries),
+                queries=[sq.text[:50] for sq in subqueries],
+            )
+
+            with ThreadPoolExecutor(
+                max_workers=instructed_config.max_subqueries
+            ) as executor:
+                all_results = list(executor.map(execute_search, subqueries))
+
+            merged = rrf_merge(
+                all_results,
+                k=instructed_config.rrf_k,
+                primary_key=vector_store.primary_key,
+            )
+
+            logger.debug(
+                "Instructed retrieval complete",
+                num_subqueries=len(subqueries),
+                total_results=sum(len(r) for r in all_results),
+                merged_results=len(merged),
+            )
+
+            return merged
+
+        except Exception as e:
+            logger.warning(
+                "Instructed retrieval failed, falling back to standard search",
+                error=str(e),
+            )
+            return vs.similarity_search(
+                query=query,
+                k=search_parameters.num_results or 5,
+                filter=base_filters if base_filters else None,
+                query_type=search_parameters.query_type or "ANN",
+            )
+
+    @mlflow.trace(name="execute_standard_search", span_type=SpanType.RETRIEVER)
+    def _execute_standard_search(
+        vs: DatabricksVectorSearch,
+        query: str,
+        base_filters: dict[str, Any],
+    ) -> list[Document]:
+        """Execute standard single-query search."""
+        logger.trace("Performing standard vector search", query_preview=query[:50])
+        return vs.similarity_search(
+            query=query,
+            k=search_parameters.num_results or 5,
+            filter=base_filters if base_filters else None,
+            query_type=search_parameters.query_type or "ANN",
+        )
+
+    @mlflow.trace(name="apply_post_processing", span_type=SpanType.RETRIEVER)
+    def _apply_post_processing(
+        documents: list[Document],
+        query: str,
+        mode: Literal["standard", "instructed"],
+        auto_bypass: bool,
+    ) -> list[Document]:
+        """Apply instruction-aware reranking and verification based on mode and bypass settings."""
+        # Skip post-processing for standard mode when auto_bypass is enabled
+        if mode == "standard" and auto_bypass:
+            mlflow.set_tag("router.bypassed_stages", "true")
+            return documents
+
+        # Apply instruction-aware reranking if configured
+        if rerank_config and rerank_config.instruction_aware:
+            instruction_config = rerank_config.instruction_aware
+            instruction_llm = (
+                _get_cached_llm(instruction_config.model)
+                if instruction_config.model
+                else None
+            )
+
+            if instruction_llm:
+                schema_desc = (
+                    instructed_config.schema_description if instructed_config else None
+                )
+                # Get columns for dynamic instruction generation
+                rerank_columns: list[ColumnInfo] | None = None
+                if instructed_config and instructed_config.columns:
+                    rerank_columns = instructed_config.columns
+                elif columns:
+                    rerank_columns = [ColumnInfo(name=col) for col in columns]
+
+                documents = instruction_aware_rerank(
+                    llm=instruction_llm,
+                    query=query,
+                    documents=documents,
+                    instructions=instruction_config.instructions,
+                    schema_description=schema_desc,
+                    columns=rerank_columns,
+                    top_n=instruction_config.top_n,
+                )
+
+        # Apply verification if configured
+        if verifier_config:
+            verifier_llm = (
+                _get_cached_llm(verifier_config.model)
+                if verifier_config.model
+                else None
+            )
+
+            if verifier_llm:
+                schema_desc = (
+                    instructed_config.schema_description if instructed_config else ""
+                )
+                constraints = (
+                    instructed_config.constraints if instructed_config else None
+                )
+                retry_count = 0
+                verification_result: VerificationResult | None = None
+                previous_feedback: str | None = None
+
+                while retry_count <= verifier_config.max_retries:
+                    verification_result = verify_results(
+                        llm=verifier_llm,
+                        query=query,
+                        documents=documents,
+                        schema_description=schema_desc,
+                        constraints=constraints,
+                        previous_feedback=previous_feedback,
+                    )
+
+                    if verification_result.passed:
+                        mlflow.set_tag("verifier.outcome", "passed")
+                        mlflow.set_tag("verifier.retries", str(retry_count))
+                        break
+
+                    # Handle failure based on configuration
+                    if verifier_config.on_failure == "warn":
+                        mlflow.set_tag("verifier.outcome", "warned")
+                        documents = add_verification_metadata(
+                            documents, verification_result
+                        )
+                        break
+
+                    if retry_count >= verifier_config.max_retries:
+                        mlflow.set_tag("verifier.outcome", "exhausted")
+                        mlflow.set_tag("verifier.retries", str(retry_count))
+                        documents = add_verification_metadata(
+                            documents, verification_result, exhausted=True
+                        )
+                        break
+
+                    # Retry with feedback
+                    mlflow.set_tag("verifier.outcome", "retried")
+                    previous_feedback = verification_result.feedback
+                    retry_count += 1
+                    logger.debug(
+                        "Retrying search with verification feedback", retry=retry_count
+                    )
+
+        return documents
 
     # Use @tool decorator for proper ToolRuntime injection
-    # The decorator ensures runtime is automatically injected and hidden from the LLM
     @tool(name_or_callable=tool_name, description=tool_description)
-    def
+    def _vector_search_tool(
         query: Annotated[str, "The search query to find relevant documents"],
         filters: Annotated[
             Optional[list[FilterItem]],
-            "Optional filters
+            "Optional filters as key-value pairs. "
+            "Key operators: 'column' (equality), 'column NOT' (exclusion), "
+            "'column <', '<=', '>', '>=' (comparison), "
+            "'column LIKE' (token match), 'column NOT LIKE' (exclude token). "
+            f"Valid columns: {', '.join(columns) if columns else 'none'}.",
         ] = None,
         runtime: ToolRuntime[Context] = None,
     ) -> str:
         """Search for relevant documents using vector similarity."""
-        # Get context for OBO support
         context: Context | None = runtime.context if runtime else None
+        vs: DatabricksVectorSearch = _get_vector_search(context)
 
-        # Get vector search instance with OBO support
-        vector_search: DatabricksVectorSearch = _get_vector_search(context)
-
-        # Convert FilterItem Pydantic models to dict format for DatabricksVectorSearch
         filters_dict: dict[str, Any] = {}
         if filters:
             for item in filters:
                 filters_dict[item.key] = item.value
 
-
-        combined_filters: dict[str, Any] = {
+        base_filters: dict[str, Any] = {
             **filters_dict,
             **(search_parameters.filters or {}),
         }
 
-        #
-
-
-
-
-
-
+        # Determine execution mode via router or config
+        mode: Literal["standard", "instructed"] = "standard"
+        auto_bypass = True
+
+        logger.trace("Router configuration", router_config=router_config)
+        logger.trace("Instructed configuration", instructed_config=instructed_config)
+        logger.trace(
+            "Instruction-aware rerank configuration",
+            instruction_aware=rerank_config.instruction_aware
+            if rerank_config
+            else None,
         )
 
+        if router_config:
+            router_llm = (
+                _get_cached_llm(router_config.model) if router_config.model else None
+            )
+            auto_bypass = router_config.auto_bypass
+
+            if router_llm and instructed_config:
+                try:
+                    mode = route_query(
+                        llm=router_llm,
+                        query=query,
+                        schema_description=instructed_config.schema_description,
+                    )
+                except Exception as e:
+                    # Router fail-safe: default to standard mode
+                    logger.warning(
+                        "Router failed, defaulting to standard mode", error=str(e)
+                    )
+                    mlflow.set_tag("router.fallback", "true")
+                    mode = router_config.default_mode
+            else:
+                mode = router_config.default_mode
+        elif instructed_config:
+            # No router but instructed is configured - use instructed mode
+            mode = "instructed"
+            auto_bypass = False
+        elif rerank_config and rerank_config.instruction_aware:
+            # No router/instructed but instruction_aware reranking is configured
+            # Disable auto_bypass to ensure instruction_aware reranking runs
+            auto_bypass = False
+
+        logger.trace("Routing mode", mode=mode, auto_bypass=auto_bypass)
+        mlflow.set_tag("router.mode", mode)
+
+        # Execute search based on mode
+        if mode == "instructed" and instructed_config:
+            documents = _execute_instructed_retrieval(vs, query, base_filters)
+        else:
+            documents = _execute_standard_search(vs, query, base_filters)
+
         # Apply FlashRank reranking if configured
         if ranker and rerank_config:
             logger.debug("Applying FlashRank reranking")
             documents = _rerank_documents(query, documents, ranker, rerank_config)
 
+        # Apply post-processing (instruction reranking + verification)
+        documents = _apply_post_processing(documents, query, mode, auto_bypass)
+
         # Serialize documents to JSON format for LLM consumption
-        # Convert Document objects to dicts with page_content and metadata
-        # Need to handle numpy types in metadata (e.g., float32, int64)
         serialized_docs: list[dict[str, Any]] = []
         for doc in documents:
-            doc: Document
-            # Convert metadata values to JSON-serializable types
             metadata_serializable: dict[str, Any] = {}
             for key, value in doc.metadata.items():
-                # Handle numpy types
                 if hasattr(value, "item"):  # numpy scalar
                     metadata_serializable[key] = value.item()
                 else:
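Note: `rrf_merge`, imported from the new dao_ai/tools/instructed_retriever.py, fuses the per-subquery result lists with reciprocal rank fusion, where each document scores the sum of 1/(k + rank) across lists. A minimal sketch of that technique under the same signature assumptions; the shipped implementation may differ:

    from collections import defaultdict
    from langchain_core.documents import Document

    def rrf_merge_sketch(
        result_lists: list[list[Document]], k: int = 60, primary_key: str = "id"
    ) -> list[Document]:
        # Deduplicate by the primary-key metadata field and accumulate
        # reciprocal-rank contributions from every result list.
        scores: dict[str, float] = defaultdict(float)
        by_key: dict[str, Document] = {}
        for results in result_lists:
            for rank, doc in enumerate(results):
                key = str(doc.metadata.get(primary_key, doc.page_content))
                scores[key] += 1.0 / (k + rank + 1)
                by_key.setdefault(key, doc)
        return [by_key[key] for key in sorted(scores, key=scores.get, reverse=True)]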
@@ -410,9 +718,8 @@ def create_vector_search_tool(
                 }
             )
 
-        # Return as JSON string
         return json.dumps(serialized_docs)
 
     logger.success("Vector search tool created", name=tool_name, index=index_name)
 
-    return
+    return _vector_search_tool