dao-ai 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/agent_as_code.py +2 -5
- dao_ai/cli.py +65 -15
- dao_ai/config.py +672 -218
- dao_ai/genie/cache/core.py +6 -2
- dao_ai/genie/cache/lru.py +29 -11
- dao_ai/genie/cache/semantic.py +95 -44
- dao_ai/hooks/core.py +5 -5
- dao_ai/logging.py +56 -0
- dao_ai/memory/core.py +61 -44
- dao_ai/memory/databricks.py +54 -41
- dao_ai/memory/postgres.py +77 -36
- dao_ai/middleware/assertions.py +45 -17
- dao_ai/middleware/core.py +13 -7
- dao_ai/middleware/guardrails.py +30 -25
- dao_ai/middleware/human_in_the_loop.py +9 -5
- dao_ai/middleware/message_validation.py +61 -29
- dao_ai/middleware/summarization.py +16 -11
- dao_ai/models.py +172 -69
- dao_ai/nodes.py +148 -19
- dao_ai/optimization.py +26 -16
- dao_ai/orchestration/core.py +15 -8
- dao_ai/orchestration/supervisor.py +22 -8
- dao_ai/orchestration/swarm.py +57 -12
- dao_ai/prompts.py +17 -17
- dao_ai/providers/databricks.py +365 -155
- dao_ai/state.py +24 -6
- dao_ai/tools/__init__.py +2 -0
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +7 -7
- dao_ai/tools/email.py +29 -77
- dao_ai/tools/genie.py +18 -13
- dao_ai/tools/mcp.py +223 -156
- dao_ai/tools/python.py +5 -2
- dao_ai/tools/search.py +1 -1
- dao_ai/tools/slack.py +21 -9
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +129 -86
- dao_ai/tools/vector_search.py +318 -244
- dao_ai/utils.py +15 -10
- dao_ai-0.1.3.dist-info/METADATA +455 -0
- dao_ai-0.1.3.dist-info/RECORD +64 -0
- dao_ai-0.1.1.dist-info/METADATA +0 -1878
- dao_ai-0.1.1.dist-info/RECORD +0 -62
- {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/licenses/LICENSE +0 -0
dao_ai/tools/vector_search.py
CHANGED
|
@@ -2,317 +2,391 @@
|
|
|
2
2
|
Vector search tool for retrieving documents from Databricks Vector Search.
|
|
3
3
|
|
|
4
4
|
This module provides a tool factory for creating semantic search tools
|
|
5
|
-
|
|
5
|
+
with dynamic filter schemas based on table columns and FlashRank reranking support.
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
|
+
import json
|
|
8
9
|
import os
|
|
9
|
-
from typing import Any,
|
|
10
|
+
from typing import Any, Optional
|
|
10
11
|
|
|
11
12
|
import mlflow
|
|
13
|
+
from databricks.sdk import WorkspaceClient
|
|
12
14
|
from databricks.vector_search.reranker import DatabricksReranker
|
|
13
|
-
from
|
|
14
|
-
FilterItem,
|
|
15
|
-
VectorSearchRetrieverToolInput,
|
|
16
|
-
)
|
|
17
|
-
from databricks_langchain.vectorstores import DatabricksVectorSearch
|
|
15
|
+
from databricks_langchain import DatabricksVectorSearch
|
|
18
16
|
from flashrank import Ranker, RerankRequest
|
|
19
|
-
from langchain.tools import ToolRuntime, tool
|
|
20
17
|
from langchain_core.documents import Document
|
|
18
|
+
from langchain_core.tools import StructuredTool
|
|
21
19
|
from loguru import logger
|
|
22
20
|
from mlflow.entities import SpanType
|
|
21
|
+
from pydantic import BaseModel, ConfigDict, Field, create_model
|
|
23
22
|
|
|
24
23
|
from dao_ai.config import (
|
|
25
24
|
RerankParametersModel,
|
|
26
25
|
RetrieverModel,
|
|
26
|
+
SearchParametersModel,
|
|
27
27
|
VectorStoreModel,
|
|
28
|
+
value_of,
|
|
28
29
|
)
|
|
29
|
-
from dao_ai.state import AgentState, Context
|
|
30
30
|
from dao_ai.utils import normalize_host
|
|
31
31
|
|
|
32
|
+
# Create FilterItem model at module level so it can be used in type hints
|
|
33
|
+
FilterItem = create_model(
|
|
34
|
+
"FilterItem",
|
|
35
|
+
key=(
|
|
36
|
+
str,
|
|
37
|
+
Field(
|
|
38
|
+
description="The filter key, which includes the column name and can include operators like 'NOT', '<', '>=', 'LIKE', 'OR'"
|
|
39
|
+
),
|
|
40
|
+
),
|
|
41
|
+
value=(
|
|
42
|
+
Any,
|
|
43
|
+
Field(
|
|
44
|
+
description="The filter value, which can be a single value or an array of values"
|
|
45
|
+
),
|
|
46
|
+
),
|
|
47
|
+
__config__=ConfigDict(extra="forbid"),
|
|
48
|
+
)
|
|
32
49
|
|
|
33
|
-
def create_vector_search_tool(
|
|
34
|
-
retriever: RetrieverModel | dict[str, Any],
|
|
35
|
-
name: Optional[str] = None,
|
|
36
|
-
description: Optional[str] = None,
|
|
37
|
-
) -> Callable[..., list[dict[str, Any]]]:
|
|
38
|
-
"""
|
|
39
|
-
Create a Vector Search tool for retrieving documents from a Databricks Vector Search index.
|
|
40
50
|
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
51
|
+
def _create_dynamic_input_schema(
|
|
52
|
+
index_name: str, workspace_client: WorkspaceClient
|
|
53
|
+
) -> type[BaseModel]:
|
|
54
|
+
"""
|
|
55
|
+
Create dynamic input schema with column information from the table.
|
|
44
56
|
|
|
45
57
|
Args:
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
- description: Description of the tool's purpose
|
|
49
|
-
- primary_key: Primary key column for the vector store
|
|
50
|
-
- text_column: Text column used for vector search
|
|
51
|
-
- doc_uri: URI for documentation or additional context
|
|
52
|
-
- vector_store: Dictionary with 'endpoint_name' and 'index' for vector search
|
|
53
|
-
- columns: List of columns to retrieve from the vector store
|
|
54
|
-
- search_parameters: Additional parameters for customizing the search behavior
|
|
55
|
-
- rerank: Optional rerank configuration for result reranking
|
|
56
|
-
name: Optional custom name for the tool
|
|
57
|
-
description: Optional custom description for the tool
|
|
58
|
+
index_name: Full name of the vector search index
|
|
59
|
+
workspace_client: Workspace client to query table metadata
|
|
58
60
|
|
|
59
61
|
Returns:
|
|
60
|
-
|
|
62
|
+
Pydantic model class for tool input
|
|
61
63
|
"""
|
|
62
64
|
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
columns: Sequence[str] = retriever.columns or []
|
|
74
|
-
search_parameters: dict[str, Any] = retriever.search_parameters.model_dump()
|
|
75
|
-
primary_key: str = vector_store_config.primary_key or ""
|
|
76
|
-
doc_uri: str = vector_store_config.doc_uri or ""
|
|
77
|
-
text_column: str = vector_store_config.embedding_source_column
|
|
78
|
-
|
|
79
|
-
# Extract reranker configuration
|
|
80
|
-
reranker_config: Optional[RerankParametersModel] = retriever.rerank
|
|
81
|
-
|
|
82
|
-
# Initialize FlashRank ranker once if reranking is enabled
|
|
83
|
-
# This is expensive (loads model weights), so we do it once and reuse across invocations
|
|
84
|
-
ranker: Optional[Ranker] = None
|
|
85
|
-
if reranker_config:
|
|
86
|
-
logger.debug(
|
|
87
|
-
f"Creating vector search tool with reranking: '{name}' "
|
|
88
|
-
f"(model: {reranker_config.model}, top_n: {reranker_config.top_n or 'auto'})"
|
|
89
|
-
)
|
|
90
|
-
try:
|
|
91
|
-
ranker = Ranker(
|
|
92
|
-
model_name=reranker_config.model, cache_dir=reranker_config.cache_dir
|
|
93
|
-
)
|
|
94
|
-
logger.info(
|
|
95
|
-
f"FlashRank ranker initialized successfully (model: {reranker_config.model})"
|
|
96
|
-
)
|
|
97
|
-
except Exception as e:
|
|
98
|
-
logger.warning(
|
|
99
|
-
f"Failed to initialize FlashRank ranker during tool creation: {e}. "
|
|
100
|
-
"Reranking will be disabled for this tool."
|
|
101
|
-
)
|
|
102
|
-
# Set reranker_config to None so we don't attempt reranking
|
|
103
|
-
reranker_config = None
|
|
104
|
-
else:
|
|
65
|
+
# Try to get column information
|
|
66
|
+
column_descriptions = []
|
|
67
|
+
try:
|
|
68
|
+
table_info = workspace_client.tables.get(full_name=index_name)
|
|
69
|
+
for column_info in table_info.columns:
|
|
70
|
+
name = column_info.name
|
|
71
|
+
col_type = column_info.type_name.name
|
|
72
|
+
if not name.startswith("__"):
|
|
73
|
+
column_descriptions.append(f"{name} ({col_type})")
|
|
74
|
+
except Exception:
|
|
105
75
|
logger.debug(
|
|
106
|
-
|
|
76
|
+
"Could not retrieve column information for dynamic schema",
|
|
77
|
+
index=index_name,
|
|
107
78
|
)
|
|
108
79
|
|
|
109
|
-
#
|
|
110
|
-
|
|
111
|
-
|
|
80
|
+
# Build filter description matching VectorSearchRetrieverTool format
|
|
81
|
+
filter_description = (
|
|
82
|
+
"Optional filters to refine vector search results as an array of key-value pairs. "
|
|
83
|
+
"IMPORTANT: If unsure about filter values, try searching WITHOUT filters first to get broad results, "
|
|
84
|
+
"then optionally add filters to narrow down if needed. This ensures you don't miss relevant results due to incorrect filter values. "
|
|
85
|
+
)
|
|
112
86
|
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
# The workspace_client parameter in DatabricksVectorSearch is only used to detect
|
|
117
|
-
# model serving mode - it doesn't pass credentials to VectorSearchClient.
|
|
118
|
-
client_args: dict[str, Any] = {}
|
|
119
|
-
databricks_host = normalize_host(os.environ.get("DATABRICKS_HOST"))
|
|
120
|
-
if databricks_host:
|
|
121
|
-
client_args["workspace_url"] = databricks_host
|
|
122
|
-
if os.environ.get("DATABRICKS_TOKEN"):
|
|
123
|
-
client_args["personal_access_token"] = os.environ.get("DATABRICKS_TOKEN")
|
|
124
|
-
if os.environ.get("DATABRICKS_CLIENT_ID"):
|
|
125
|
-
client_args["service_principal_client_id"] = os.environ.get(
|
|
126
|
-
"DATABRICKS_CLIENT_ID"
|
|
127
|
-
)
|
|
128
|
-
if os.environ.get("DATABRICKS_CLIENT_SECRET"):
|
|
129
|
-
client_args["service_principal_client_secret"] = os.environ.get(
|
|
130
|
-
"DATABRICKS_CLIENT_SECRET"
|
|
87
|
+
if column_descriptions:
|
|
88
|
+
filter_description += (
|
|
89
|
+
f"Available columns for filtering: {', '.join(column_descriptions)}. "
|
|
131
90
|
)
|
|
132
91
|
|
|
133
|
-
|
|
134
|
-
|
|
92
|
+
filter_description += (
|
|
93
|
+
"Supports the following operators:\n\n"
|
|
94
|
+
'- Inclusion: [{"key": "column", "value": value}] or [{"key": "column", "value": [value1, value2]}] (matches if the column equals any of the provided values)\n'
|
|
95
|
+
'- Exclusion: [{"key": "column NOT", "value": value}]\n'
|
|
96
|
+
'- Comparisons: [{"key": "column <", "value": value}], [{"key": "column >=", "value": value}], etc.\n'
|
|
97
|
+
'- Pattern match: [{"key": "column LIKE", "value": "word"}] (matches full tokens separated by whitespace)\n'
|
|
98
|
+
'- OR logic: [{"key": "column1 OR column2", "value": [value1, value2]}] '
|
|
99
|
+
"(matches if column1 equals value1 or column2 equals value2; matches are position-specific)\n\n"
|
|
100
|
+
"Examples:\n"
|
|
101
|
+
'- Filter by category: [{"key": "category", "value": "electronics"}]\n'
|
|
102
|
+
'- Filter by price range: [{"key": "price >=", "value": 100}, {"key": "price <", "value": 500}]\n'
|
|
103
|
+
'- Exclude specific status: [{"key": "status NOT", "value": "archived"}]\n'
|
|
104
|
+
'- Pattern matching: [{"key": "description LIKE", "value": "wireless"}]'
|
|
135
105
|
)
|
|
136
106
|
|
|
137
|
-
#
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
107
|
+
# Create the input model
|
|
108
|
+
VectorSearchInput = create_model(
|
|
109
|
+
"VectorSearchInput",
|
|
110
|
+
query=(
|
|
111
|
+
str,
|
|
112
|
+
Field(description="The search query string to find relevant documents"),
|
|
113
|
+
),
|
|
114
|
+
filters=(
|
|
115
|
+
Optional[list[FilterItem]],
|
|
116
|
+
Field(default=None, description=filter_description),
|
|
117
|
+
),
|
|
118
|
+
__config__=ConfigDict(extra="forbid"),
|
|
145
119
|
)
|
|
146
120
|
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
121
|
+
return VectorSearchInput
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
@mlflow.trace(name="rerank_documents", span_type=SpanType.RETRIEVER)
|
|
125
|
+
def _rerank_documents(
|
|
126
|
+
query: str,
|
|
127
|
+
documents: list[Document],
|
|
128
|
+
ranker: Ranker,
|
|
129
|
+
rerank_config: RerankParametersModel,
|
|
130
|
+
) -> list[Document]:
|
|
131
|
+
"""
|
|
132
|
+
Rerank documents using FlashRank cross-encoder model.
|
|
133
|
+
|
|
134
|
+
Args:
|
|
135
|
+
query: The search query string
|
|
136
|
+
documents: List of documents to rerank
|
|
137
|
+
ranker: The FlashRank Ranker instance
|
|
138
|
+
rerank_config: Reranking configuration
|
|
139
|
+
|
|
140
|
+
Returns:
|
|
141
|
+
Reranked list of documents with reranker_score in metadata
|
|
142
|
+
"""
|
|
143
|
+
logger.trace(
|
|
144
|
+
"Starting reranking",
|
|
145
|
+
documents_count=len(documents),
|
|
146
|
+
model=rerank_config.model,
|
|
154
147
|
)
|
|
155
148
|
|
|
156
|
-
#
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
) -> List[Document]:
|
|
161
|
-
"""Perform vector similarity search."""
|
|
162
|
-
# Convert filters to dict format
|
|
163
|
-
filters_dict: dict[str, Any] = {}
|
|
164
|
-
if filters:
|
|
165
|
-
for item in filters:
|
|
166
|
-
item_dict = dict(item)
|
|
167
|
-
filters_dict[item_dict["key"]] = item_dict["value"]
|
|
149
|
+
# Prepare passages for reranking
|
|
150
|
+
passages: list[dict[str, Any]] = [
|
|
151
|
+
{"text": doc.page_content, "meta": doc.metadata} for doc in documents
|
|
152
|
+
]
|
|
168
153
|
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
**filters_dict,
|
|
172
|
-
**search_parameters.get("filters", {}),
|
|
173
|
-
}
|
|
154
|
+
# Create reranking request
|
|
155
|
+
rerank_request: RerankRequest = RerankRequest(query=query, passages=passages)
|
|
174
156
|
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
query_type: str = search_parameters.get("query_type", "ANN")
|
|
157
|
+
# Perform reranking
|
|
158
|
+
results: list[dict[str, Any]] = ranker.rerank(rerank_request)
|
|
178
159
|
|
|
179
|
-
|
|
180
|
-
|
|
160
|
+
# Apply top_n filtering
|
|
161
|
+
top_n: int = rerank_config.top_n or len(documents)
|
|
162
|
+
results = results[:top_n]
|
|
163
|
+
logger.debug("Reranking complete", top_n=top_n, candidates_count=len(documents))
|
|
164
|
+
|
|
165
|
+
# Convert back to Document objects with reranking scores
|
|
166
|
+
reranked_docs: list[Document] = []
|
|
167
|
+
for result in results:
|
|
168
|
+
orig_doc: Optional[Document] = next(
|
|
169
|
+
(doc for doc in documents if doc.page_content == result["text"]), None
|
|
181
170
|
)
|
|
171
|
+
if orig_doc:
|
|
172
|
+
reranked_doc: Document = Document(
|
|
173
|
+
page_content=orig_doc.page_content,
|
|
174
|
+
metadata={
|
|
175
|
+
**orig_doc.metadata,
|
|
176
|
+
"reranker_score": result["score"],
|
|
177
|
+
},
|
|
178
|
+
)
|
|
179
|
+
reranked_docs.append(reranked_doc)
|
|
182
180
|
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
}
|
|
181
|
+
logger.debug(
|
|
182
|
+
"Documents reranked",
|
|
183
|
+
input_count=len(documents),
|
|
184
|
+
output_count=len(reranked_docs),
|
|
185
|
+
model=rerank_config.model,
|
|
186
|
+
)
|
|
190
187
|
|
|
191
|
-
|
|
192
|
-
if reranker_config and reranker_config.columns:
|
|
193
|
-
search_kwargs["reranker"] = DatabricksReranker(
|
|
194
|
-
columns_to_rerank=reranker_config.columns
|
|
195
|
-
)
|
|
188
|
+
return reranked_docs
|
|
196
189
|
|
|
197
|
-
documents: List[Document] = vector_store.similarity_search(**search_kwargs)
|
|
198
190
|
|
|
199
|
-
|
|
200
|
-
|
|
191
|
+
def create_vector_search_tool(
|
|
192
|
+
retriever: Optional[RetrieverModel | dict[str, Any]] = None,
|
|
193
|
+
vector_store: Optional[VectorStoreModel | dict[str, Any]] = None,
|
|
194
|
+
name: Optional[str] = None,
|
|
195
|
+
description: Optional[str] = None,
|
|
196
|
+
) -> StructuredTool:
|
|
197
|
+
"""
|
|
198
|
+
Create a Vector Search tool with dynamic schema and optional reranking.
|
|
201
199
|
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
200
|
+
Args:
|
|
201
|
+
retriever: Full retriever configuration with search parameters and reranking
|
|
202
|
+
vector_store: Direct vector store reference (uses default search parameters)
|
|
203
|
+
name: Optional custom name for the tool
|
|
204
|
+
description: Optional custom description for the tool
|
|
206
205
|
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
if not reranker_config or ranker is None:
|
|
211
|
-
return documents
|
|
206
|
+
Returns:
|
|
207
|
+
A LangChain StructuredTool with proper schema (additionalProperties: false)
|
|
208
|
+
"""
|
|
212
209
|
|
|
213
|
-
|
|
214
|
-
|
|
210
|
+
# Validate mutually exclusive parameters
|
|
211
|
+
if retriever is None and vector_store is None:
|
|
212
|
+
raise ValueError("Must provide either 'retriever' or 'vector_store' parameter")
|
|
213
|
+
if retriever is not None and vector_store is not None:
|
|
214
|
+
raise ValueError(
|
|
215
|
+
"Cannot provide both 'retriever' and 'vector_store' parameters"
|
|
215
216
|
)
|
|
216
217
|
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
218
|
+
# Handle vector_store parameter
|
|
219
|
+
if vector_store is not None:
|
|
220
|
+
if isinstance(vector_store, dict):
|
|
221
|
+
vector_store = VectorStoreModel(**vector_store)
|
|
222
|
+
retriever = RetrieverModel(vector_store=vector_store)
|
|
223
|
+
else:
|
|
224
|
+
if isinstance(retriever, dict):
|
|
225
|
+
retriever = RetrieverModel(**retriever)
|
|
226
|
+
|
|
227
|
+
vector_store: VectorStoreModel = retriever.vector_store
|
|
221
228
|
|
|
222
|
-
|
|
223
|
-
|
|
229
|
+
# Index is required
|
|
230
|
+
if vector_store.index is None:
|
|
231
|
+
raise ValueError("vector_store.index is required for vector search")
|
|
224
232
|
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
233
|
+
index_name: str = vector_store.index.full_name
|
|
234
|
+
columns: list[str] = list(retriever.columns or [])
|
|
235
|
+
search_parameters: SearchParametersModel = retriever.search_parameters
|
|
236
|
+
rerank_config: Optional[RerankParametersModel] = retriever.rerank
|
|
228
237
|
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
238
|
+
# Initialize FlashRank ranker if configured
|
|
239
|
+
ranker: Optional[Ranker] = None
|
|
240
|
+
if rerank_config and rerank_config.model:
|
|
232
241
|
logger.debug(
|
|
233
|
-
|
|
242
|
+
"Initializing FlashRank ranker",
|
|
243
|
+
model=rerank_config.model,
|
|
244
|
+
top_n=rerank_config.top_n or "auto",
|
|
234
245
|
)
|
|
246
|
+
try:
|
|
247
|
+
cache_dir = os.path.expanduser(rerank_config.cache_dir)
|
|
248
|
+
ranker = Ranker(model_name=rerank_config.model, cache_dir=cache_dir)
|
|
249
|
+
logger.success("FlashRank ranker initialized", model=rerank_config.model)
|
|
250
|
+
except Exception as e:
|
|
251
|
+
logger.warning("Failed to initialize FlashRank ranker", error=str(e))
|
|
252
|
+
rerank_config = None
|
|
235
253
|
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
)
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
**orig_doc.metadata,
|
|
249
|
-
"reranker_score": result["score"],
|
|
250
|
-
},
|
|
251
|
-
)
|
|
252
|
-
reranked_docs.append(reranked_doc)
|
|
254
|
+
# Build client_args for VectorSearchClient
|
|
255
|
+
# Use getattr to safely access attributes that may not exist (e.g., in mocks)
|
|
256
|
+
client_args: dict[str, Any] = {}
|
|
257
|
+
has_explicit_auth = any(
|
|
258
|
+
[
|
|
259
|
+
os.environ.get("DATABRICKS_TOKEN"),
|
|
260
|
+
os.environ.get("DATABRICKS_CLIENT_ID"),
|
|
261
|
+
getattr(vector_store, "pat", None),
|
|
262
|
+
getattr(vector_store, "client_id", None),
|
|
263
|
+
getattr(vector_store, "on_behalf_of_user", None),
|
|
264
|
+
]
|
|
265
|
+
)
|
|
253
266
|
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
)
|
|
267
|
+
if has_explicit_auth:
|
|
268
|
+
databricks_host = os.environ.get("DATABRICKS_HOST")
|
|
269
|
+
if (
|
|
270
|
+
not databricks_host
|
|
271
|
+
and getattr(vector_store, "_workspace_client", None) is not None
|
|
272
|
+
):
|
|
273
|
+
databricks_host = vector_store.workspace_client.config.host
|
|
274
|
+
if databricks_host:
|
|
275
|
+
client_args["workspace_url"] = normalize_host(databricks_host)
|
|
276
|
+
|
|
277
|
+
token = os.environ.get("DATABRICKS_TOKEN")
|
|
278
|
+
if not token and getattr(vector_store, "pat", None):
|
|
279
|
+
token = value_of(vector_store.pat)
|
|
280
|
+
if token:
|
|
281
|
+
client_args["personal_access_token"] = token
|
|
282
|
+
|
|
283
|
+
client_id = os.environ.get("DATABRICKS_CLIENT_ID")
|
|
284
|
+
if not client_id and getattr(vector_store, "client_id", None):
|
|
285
|
+
client_id = value_of(vector_store.client_id)
|
|
286
|
+
if client_id:
|
|
287
|
+
client_args["service_principal_client_id"] = client_id
|
|
288
|
+
|
|
289
|
+
client_secret = os.environ.get("DATABRICKS_CLIENT_SECRET")
|
|
290
|
+
if not client_secret and getattr(vector_store, "client_secret", None):
|
|
291
|
+
client_secret = value_of(vector_store.client_secret)
|
|
292
|
+
if client_secret:
|
|
293
|
+
client_args["service_principal_client_secret"] = client_secret
|
|
260
294
|
|
|
261
|
-
|
|
295
|
+
logger.debug(
|
|
296
|
+
"Creating vector search tool",
|
|
297
|
+
name=name,
|
|
298
|
+
index=index_name,
|
|
299
|
+
client_args_keys=list(client_args.keys()) if client_args else [],
|
|
300
|
+
)
|
|
262
301
|
|
|
263
|
-
# Create
|
|
264
|
-
#
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
302
|
+
# Create DatabricksVectorSearch
|
|
303
|
+
# Note: text_column should be None for Databricks-managed embeddings
|
|
304
|
+
# (it's automatically determined from the index)
|
|
305
|
+
vector_search: DatabricksVectorSearch = DatabricksVectorSearch(
|
|
306
|
+
index_name=index_name,
|
|
307
|
+
text_column=None,
|
|
308
|
+
columns=columns,
|
|
309
|
+
workspace_client=vector_store.workspace_client,
|
|
310
|
+
client_args=client_args if client_args else None,
|
|
311
|
+
primary_key=vector_store.primary_key,
|
|
312
|
+
doc_uri=vector_store.doc_uri,
|
|
313
|
+
include_score=True,
|
|
314
|
+
reranker=(
|
|
315
|
+
DatabricksReranker(columns_to_rerank=rerank_config.columns)
|
|
316
|
+
if rerank_config and rerank_config.columns
|
|
317
|
+
else None
|
|
318
|
+
),
|
|
269
319
|
)
|
|
270
|
-
def vector_search_tool(
|
|
271
|
-
query: str,
|
|
272
|
-
filters: Optional[List[FilterItem]] = None,
|
|
273
|
-
runtime: ToolRuntime[Context, AgentState] = None,
|
|
274
|
-
) -> list[dict[str, Any]]:
|
|
275
|
-
"""
|
|
276
|
-
Search for documents using vector similarity with optional reranking.
|
|
277
320
|
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
321
|
+
# Create dynamic input schema
|
|
322
|
+
input_schema: type[BaseModel] = _create_dynamic_input_schema(
|
|
323
|
+
index_name, vector_store.workspace_client
|
|
324
|
+
)
|
|
281
325
|
|
|
282
|
-
|
|
326
|
+
# Define the tool function
|
|
327
|
+
def vector_search_func(
|
|
328
|
+
query: str, filters: Optional[list[FilterItem]] = None
|
|
329
|
+
) -> str:
|
|
330
|
+
"""Search for relevant documents using vector similarity."""
|
|
331
|
+
# Convert FilterItem Pydantic models to dict format for DatabricksVectorSearch
|
|
332
|
+
filters_dict: dict[str, Any] = {}
|
|
333
|
+
if filters:
|
|
334
|
+
for item in filters:
|
|
335
|
+
filters_dict[item.key] = item.value
|
|
283
336
|
|
|
284
|
-
|
|
337
|
+
# Merge with configured filters
|
|
338
|
+
combined_filters: dict[str, Any] = {
|
|
339
|
+
**filters_dict,
|
|
340
|
+
**(search_parameters.filters or {}),
|
|
341
|
+
}
|
|
285
342
|
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
343
|
+
# Perform vector search
|
|
344
|
+
logger.trace("Performing vector search", query_preview=query[:50])
|
|
345
|
+
documents: list[Document] = vector_search.similarity_search(
|
|
346
|
+
query=query,
|
|
347
|
+
k=search_parameters.num_results or 5,
|
|
348
|
+
filter=combined_filters if combined_filters else None,
|
|
349
|
+
query_type=search_parameters.query_type or "ANN",
|
|
291
350
|
)
|
|
292
351
|
|
|
293
|
-
#
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
352
|
+
# Apply FlashRank reranking if configured
|
|
353
|
+
if ranker and rerank_config:
|
|
354
|
+
logger.debug("Applying FlashRank reranking")
|
|
355
|
+
documents = _rerank_documents(query, documents, ranker, rerank_config)
|
|
356
|
+
|
|
357
|
+
# Serialize documents to JSON format for LLM consumption
|
|
358
|
+
# Convert Document objects to dicts with page_content and metadata
|
|
359
|
+
# Need to handle numpy types in metadata (e.g., float32, int64)
|
|
360
|
+
serialized_docs: list[dict[str, Any]] = []
|
|
361
|
+
for doc in documents:
|
|
362
|
+
doc: Document
|
|
363
|
+
# Convert metadata values to JSON-serializable types
|
|
364
|
+
metadata_serializable: dict[str, Any] = {}
|
|
365
|
+
for key, value in doc.metadata.items():
|
|
366
|
+
# Handle numpy types
|
|
367
|
+
if hasattr(value, "item"): # numpy scalar
|
|
368
|
+
metadata_serializable[key] = value.item()
|
|
369
|
+
else:
|
|
370
|
+
metadata_serializable[key] = value
|
|
371
|
+
|
|
372
|
+
serialized_docs.append(
|
|
373
|
+
{
|
|
374
|
+
"page_content": doc.page_content,
|
|
375
|
+
"metadata": metadata_serializable,
|
|
376
|
+
}
|
|
300
377
|
)
|
|
301
|
-
documents = _rerank_documents(query, documents)
|
|
302
|
-
logger.debug(f"Returning {len(documents)} reranked documents")
|
|
303
|
-
else:
|
|
304
|
-
logger.debug("Reranking disabled, returning original vector search results")
|
|
305
|
-
|
|
306
|
-
# Return Command with ToolMessage containing the documents
|
|
307
|
-
# Serialize documents to dicts for proper ToolMessage handling
|
|
308
|
-
serialized_docs: list[dict[str, Any]] = [
|
|
309
|
-
{
|
|
310
|
-
"page_content": doc.page_content,
|
|
311
|
-
"metadata": doc.metadata,
|
|
312
|
-
}
|
|
313
|
-
for doc in documents
|
|
314
|
-
]
|
|
315
378
|
|
|
316
|
-
|
|
379
|
+
# Return as JSON string
|
|
380
|
+
return json.dumps(serialized_docs)
|
|
381
|
+
|
|
382
|
+
# Create the StructuredTool
|
|
383
|
+
tool: StructuredTool = StructuredTool.from_function(
|
|
384
|
+
func=vector_search_func,
|
|
385
|
+
name=name or f"vector_search_{vector_store.index.name}",
|
|
386
|
+
description=description or f"Search documents in {index_name}",
|
|
387
|
+
args_schema=input_schema,
|
|
388
|
+
)
|
|
389
|
+
|
|
390
|
+
logger.success("Vector search tool created", name=tool.name, index=index_name)
|
|
317
391
|
|
|
318
|
-
return
|
|
392
|
+
return tool
|