dao-ai 0.0.28__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +2 -5
- dao_ai/cli.py +245 -40
- dao_ai/config.py +1491 -370
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +232 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +245 -159
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +278 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +573 -601
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -294
- dao_ai/tools/mcp.py +223 -155
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +331 -221
- dao_ai/utils.py +166 -20
- dao_ai-0.1.2.dist-info/METADATA +455 -0
- dao_ai-0.1.2.dist-info/RECORD +64 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.28.dist-info/METADATA +0 -1168
- dao_ai-0.0.28.dist-info/RECORD +0 -41
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/tools/vector_search.py
CHANGED
|
@@ -1,282 +1,392 @@
|
|
|
1
|
-
|
|
1
|
+
"""
|
|
2
|
+
Vector search tool for retrieving documents from Databricks Vector Search.
|
|
3
|
+
|
|
4
|
+
This module provides a tool factory for creating semantic search tools
|
|
5
|
+
with dynamic filter schemas based on table columns and FlashRank reranking support.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import os
|
|
10
|
+
from typing import Any, Optional
|
|
2
11
|
|
|
3
12
|
import mlflow
|
|
13
|
+
from databricks.sdk import WorkspaceClient
|
|
4
14
|
from databricks.vector_search.reranker import DatabricksReranker
|
|
5
|
-
from
|
|
6
|
-
FilterItem,
|
|
7
|
-
VectorSearchRetrieverToolInput,
|
|
8
|
-
)
|
|
9
|
-
from databricks_langchain.vectorstores import DatabricksVectorSearch
|
|
15
|
+
from databricks_langchain import DatabricksVectorSearch
|
|
10
16
|
from flashrank import Ranker, RerankRequest
|
|
11
17
|
from langchain_core.documents import Document
|
|
12
|
-
from langchain_core.tools import
|
|
13
|
-
from langgraph.prebuilt import InjectedState
|
|
18
|
+
from langchain_core.tools import StructuredTool
|
|
14
19
|
from loguru import logger
|
|
15
20
|
from mlflow.entities import SpanType
|
|
21
|
+
from pydantic import BaseModel, ConfigDict, Field, create_model
|
|
16
22
|
|
|
17
23
|
from dao_ai.config import (
|
|
18
24
|
RerankParametersModel,
|
|
19
25
|
RetrieverModel,
|
|
26
|
+
SearchParametersModel,
|
|
20
27
|
VectorStoreModel,
|
|
28
|
+
value_of,
|
|
21
29
|
)
|
|
30
|
+
from dao_ai.utils import normalize_host
|
|
31
|
+
|
|
32
|
+
# Create FilterItem model at module level so it can be used in type hints
# (both the dynamic input schema and the tool function annotate their
# ``filters`` argument as ``Optional[list[FilterItem]]``).
# Each item is one ``{"key": ..., "value": ...}`` pair; ``extra="forbid"``
# makes the model reject any other field.
FilterItem = create_model(
    "FilterItem",
    key=(
        str,
        Field(
            description="The filter key, which includes the column name and can include operators like 'NOT', '<', '>=', 'LIKE', 'OR'"
        ),
    ),
    value=(
        Any,
        Field(
            description="The filter value, which can be a single value or an array of values"
        ),
    ),
    __config__=ConfigDict(extra="forbid"),
)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _create_dynamic_input_schema(
    index_name: str, workspace_client: WorkspaceClient
) -> type[BaseModel]:
    """
    Create dynamic input schema with column information from the table.

    Column metadata is looked up with ``workspace_client.tables.get`` and, when
    available, listed in the description of the ``filters`` field. The lookup
    is best-effort: any failure is logged at debug level and the schema is
    still returned, just without the column list.

    Args:
        index_name: Full name of the vector search index
        workspace_client: Workspace client to query table metadata

    Returns:
        Pydantic model class for tool input, with a required ``query`` string
        and an optional ``filters`` list of ``FilterItem``; extra fields are
        forbidden
    """

    # Try to get column information
    column_descriptions: list[str] = []
    try:
        table_info = workspace_client.tables.get(full_name=index_name)
        # Guard against a missing column list instead of relying on the
        # broad ``except`` below to absorb the resulting TypeError.
        for column_info in table_info.columns or []:
            column_name: Optional[str] = column_info.name
            # Skip unnamed columns and internal ones (double-underscore prefix).
            if not column_name or column_name.startswith("__"):
                continue
            # A column without a type must not abort the loop and silently
            # truncate the list; advertise it by name only.
            col_type: Optional[str] = getattr(column_info.type_name, "name", None)
            column_descriptions.append(
                f"{column_name} ({col_type})" if col_type else column_name
            )
    except Exception as e:
        # Deliberate best-effort: the tool still works without column hints,
        # but record why the lookup failed so it can be diagnosed.
        logger.debug(
            "Could not retrieve column information for dynamic schema",
            index=index_name,
            error=str(e),
        )

    # Build filter description matching VectorSearchRetrieverTool format
    filter_description = (
        "Optional filters to refine vector search results as an array of key-value pairs. "
        "IMPORTANT: If unsure about filter values, try searching WITHOUT filters first to get broad results, "
        "then optionally add filters to narrow down if needed. This ensures you don't miss relevant results due to incorrect filter values. "
    )

    if column_descriptions:
        filter_description += (
            f"Available columns for filtering: {', '.join(column_descriptions)}. "
        )

    filter_description += (
        "Supports the following operators:\n\n"
        '- Inclusion: [{"key": "column", "value": value}] or [{"key": "column", "value": [value1, value2]}] (matches if the column equals any of the provided values)\n'
        '- Exclusion: [{"key": "column NOT", "value": value}]\n'
        '- Comparisons: [{"key": "column <", "value": value}], [{"key": "column >=", "value": value}], etc.\n'
        '- Pattern match: [{"key": "column LIKE", "value": "word"}] (matches full tokens separated by whitespace)\n'
        '- OR logic: [{"key": "column1 OR column2", "value": [value1, value2]}] '
        "(matches if column1 equals value1 or column2 equals value2; matches are position-specific)\n\n"
        "Examples:\n"
        '- Filter by category: [{"key": "category", "value": "electronics"}]\n'
        '- Filter by price range: [{"key": "price >=", "value": 100}, {"key": "price <", "value": 500}]\n'
        '- Exclude specific status: [{"key": "status NOT", "value": "archived"}]\n'
        '- Pattern matching: [{"key": "description LIKE", "value": "wireless"}]'
    )

    # Create the input model
    VectorSearchInput = create_model(
        "VectorSearchInput",
        query=(
            str,
            Field(description="The search query string to find relevant documents"),
        ),
        filters=(
            Optional[list[FilterItem]],
            Field(default=None, description=filter_description),
        ),
        __config__=ConfigDict(extra="forbid"),
    )

    return VectorSearchInput
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
@mlflow.trace(name="rerank_documents", span_type=SpanType.RETRIEVER)
def _rerank_documents(
    query: str,
    documents: list[Document],
    ranker: Ranker,
    rerank_config: RerankParametersModel,
) -> list[Document]:
    """
    Rerank documents using FlashRank cross-encoder model.

    Each document is sent to the ranker as a passage tagged with its position
    in ``documents``. The ranked passages are mapped back to the original
    documents by that position, so documents that share identical text keep
    their own metadata.

    Args:
        query: The search query string
        documents: List of documents to rerank
        ranker: The FlashRank Ranker instance
        rerank_config: Reranking configuration (``model`` is used for logging,
            ``top_n`` limits the number of results; a falsy ``top_n`` keeps all)

    Returns:
        Reranked list of documents with reranker_score in metadata
    """
    # Nothing to rank; avoid calling the model with an empty batch.
    if not documents:
        return []

    logger.trace(
        "Starting reranking",
        documents_count=len(documents),
        model=rerank_config.model,
    )

    # Prepare passages for reranking; "id" carries the document's position so
    # results can be mapped back without comparing text.
    passages: list[dict[str, Any]] = [
        {"id": position, "text": doc.page_content, "meta": doc.metadata}
        for position, doc in enumerate(documents)
    ]

    # Create reranking request
    rerank_request: RerankRequest = RerankRequest(query=query, passages=passages)

    # Perform reranking
    results: list[dict[str, Any]] = ranker.rerank(rerank_request)

    # Apply top_n filtering
    top_n: int = rerank_config.top_n or len(documents)
    results = results[:top_n]
    logger.debug("Reranking complete", top_n=top_n, candidates_count=len(documents))

    # Convert back to Document objects with reranking scores
    reranked_docs: list[Document] = []
    for result in results:
        position = result.get("id")
        orig_doc: Optional[Document]
        if isinstance(position, int) and 0 <= position < len(documents):
            orig_doc = documents[position]
        else:
            # Fallback for rankers that do not echo the passage id: match on
            # text (first match wins, as before).
            orig_doc = next(
                (doc for doc in documents if doc.page_content == result["text"]),
                None,
            )
        if orig_doc is None:
            continue
        reranked_docs.append(
            Document(
                page_content=orig_doc.page_content,
                metadata={
                    **orig_doc.metadata,
                    "reranker_score": result["score"],
                },
            )
        )

    logger.debug(
        "Documents reranked",
        input_count=len(documents),
        output_count=len(reranked_docs),
        model=rerank_config.model,
    )

    return reranked_docs
|
|
22
189
|
|
|
23
190
|
|
|
24
191
|
def _json_safe(value: Any) -> Any:
    """Recursively convert *value* into something ``json.dumps`` can encode."""
    if isinstance(value, dict):
        return {key: _json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(item) for item in value]
    # numpy scalars and arrays both expose ``tolist``; prefer it over ``item``
    # because ``item()`` raises for arrays holding more than one element.
    to_list = getattr(value, "tolist", None)
    if callable(to_list):
        return to_list()
    to_item = getattr(value, "item", None)
    if callable(to_item):
        return to_item()
    return value


def create_vector_search_tool(
    retriever: Optional[RetrieverModel | dict[str, Any]] = None,
    vector_store: Optional[VectorStoreModel | dict[str, Any]] = None,
    name: Optional[str] = None,
    description: Optional[str] = None,
) -> StructuredTool:
    """
    Create a Vector Search tool with dynamic schema and optional reranking.

    Exactly one of ``retriever`` or ``vector_store`` must be given. The returned
    tool runs a similarity search against the configured index, optionally
    reranks the hits with FlashRank, and returns the documents as a JSON string
    (a list of ``{"page_content": ..., "metadata": ...}`` objects).

    Args:
        retriever: Full retriever configuration with search parameters and reranking
        vector_store: Direct vector store reference (uses default search parameters)
        name: Optional custom name for the tool
        description: Optional custom description for the tool

    Returns:
        A LangChain StructuredTool with proper schema (additionalProperties: false)

    Raises:
        ValueError: If neither or both of ``retriever`` and ``vector_store`` are
            provided, or if the vector store has no index configured.
    """

    # Validate mutually exclusive parameters
    if retriever is None and vector_store is None:
        raise ValueError("Must provide either 'retriever' or 'vector_store' parameter")
    if retriever is not None and vector_store is not None:
        raise ValueError(
            "Cannot provide both 'retriever' and 'vector_store' parameters"
        )

    # Handle vector_store parameter
    if vector_store is not None:
        if isinstance(vector_store, dict):
            vector_store = VectorStoreModel(**vector_store)
        retriever = RetrieverModel(vector_store=vector_store)
    else:
        if isinstance(retriever, dict):
            retriever = RetrieverModel(**retriever)

    vector_store: VectorStoreModel = retriever.vector_store

    # Index is required
    if vector_store.index is None:
        raise ValueError("vector_store.index is required for vector search")

    index_name: str = vector_store.index.full_name
    columns: list[str] = list(retriever.columns or [])
    search_parameters: SearchParametersModel = retriever.search_parameters
    rerank_config: Optional[RerankParametersModel] = retriever.rerank

    # Initialize FlashRank ranker if configured
    ranker: Optional[Ranker] = None
    if rerank_config and rerank_config.model:
        logger.debug(
            "Initializing FlashRank ranker",
            model=rerank_config.model,
            top_n=rerank_config.top_n or "auto",
        )
        try:
            cache_dir = os.path.expanduser(rerank_config.cache_dir)
            ranker = Ranker(model_name=rerank_config.model, cache_dir=cache_dir)
            logger.success("FlashRank ranker initialized", model=rerank_config.model)
        except Exception as e:
            logger.warning("Failed to initialize FlashRank ranker", error=str(e))
            # NOTE(review): clearing the config also disables the
            # DatabricksReranker below, which is gated on rerank_config too.
            rerank_config = None

    # Build client_args for VectorSearchClient
    # Use getattr to safely access attributes that may not exist (e.g., in mocks)
    client_args: dict[str, Any] = {}
    has_explicit_auth = any(
        [
            os.environ.get("DATABRICKS_TOKEN"),
            os.environ.get("DATABRICKS_CLIENT_ID"),
            getattr(vector_store, "pat", None),
            getattr(vector_store, "client_id", None),
            getattr(vector_store, "on_behalf_of_user", None),
        ]
    )

    if has_explicit_auth:
        # Environment variables take precedence over values on the model.
        databricks_host = os.environ.get("DATABRICKS_HOST")
        if (
            not databricks_host
            and getattr(vector_store, "_workspace_client", None) is not None
        ):
            databricks_host = vector_store.workspace_client.config.host
        if databricks_host:
            client_args["workspace_url"] = normalize_host(databricks_host)

        token = os.environ.get("DATABRICKS_TOKEN")
        if not token and getattr(vector_store, "pat", None):
            token = value_of(vector_store.pat)
        if token:
            client_args["personal_access_token"] = token

        client_id = os.environ.get("DATABRICKS_CLIENT_ID")
        if not client_id and getattr(vector_store, "client_id", None):
            client_id = value_of(vector_store.client_id)
        if client_id:
            client_args["service_principal_client_id"] = client_id

        client_secret = os.environ.get("DATABRICKS_CLIENT_SECRET")
        if not client_secret and getattr(vector_store, "client_secret", None):
            client_secret = value_of(vector_store.client_secret)
        if client_secret:
            client_args["service_principal_client_secret"] = client_secret

    # Only the key names are logged, never the credential values.
    logger.debug(
        "Creating vector search tool",
        name=name,
        index=index_name,
        client_args_keys=list(client_args.keys()) if client_args else [],
    )

    # Create DatabricksVectorSearch
    # Note: text_column should be None for Databricks-managed embeddings
    # (it's automatically determined from the index)
    vector_search: DatabricksVectorSearch = DatabricksVectorSearch(
        index_name=index_name,
        text_column=None,
        columns=columns,
        workspace_client=vector_store.workspace_client,
        client_args=client_args if client_args else None,
        primary_key=vector_store.primary_key,
        doc_uri=vector_store.doc_uri,
        include_score=True,
        reranker=(
            DatabricksReranker(columns_to_rerank=rerank_config.columns)
            if rerank_config and rerank_config.columns
            else None
        ),
    )

    # Create dynamic input schema
    input_schema: type[BaseModel] = _create_dynamic_input_schema(
        index_name, vector_store.workspace_client
    )

    # Define the tool function
    def vector_search_func(
        query: str, filters: Optional[list[FilterItem]] = None
    ) -> str:
        """Search for relevant documents using vector similarity."""
        # Convert FilterItem Pydantic models to dict format for DatabricksVectorSearch
        filters_dict: dict[str, Any] = {
            item.key: item.value for item in (filters or [])
        }

        # Merge with configured filters; configured filters win on key clashes
        combined_filters: dict[str, Any] = {
            **filters_dict,
            **(search_parameters.filters or {}),
        }

        # Perform vector search
        logger.trace("Performing vector search", query_preview=query[:50])
        documents: list[Document] = vector_search.similarity_search(
            query=query,
            k=search_parameters.num_results or 5,
            filter=combined_filters if combined_filters else None,
            query_type=search_parameters.query_type or "ANN",
        )

        # Apply FlashRank reranking if configured
        if ranker and rerank_config:
            logger.debug("Applying FlashRank reranking")
            documents = _rerank_documents(query, documents, ranker, rerank_config)

        # Serialize documents to JSON format for LLM consumption.
        # Metadata may hold numpy types (e.g., float32, int64); convert them to
        # plain Python values first.
        serialized_docs: list[dict[str, Any]] = [
            {
                "page_content": doc.page_content,
                "metadata": {
                    key: _json_safe(value) for key, value in doc.metadata.items()
                },
            }
            for doc in documents
        ]

        # Return as JSON string. ``default=str`` is deliberate: the output is
        # read by an LLM, so a stringified datetime/Decimal is preferable to
        # the whole tool call failing with TypeError.
        return json.dumps(serialized_docs, default=str)

    # Create the StructuredTool
    tool: StructuredTool = StructuredTool.from_function(
        func=vector_search_func,
        name=name or f"vector_search_{vector_store.index.name}",
        description=description or f"Search documents in {index_name}",
        args_schema=input_schema,
    )

    logger.success("Vector search tool created", name=tool.name, index=index_name)

    return tool
|