dao-ai 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. dao_ai/agent_as_code.py +2 -5
  2. dao_ai/cli.py +65 -15
  3. dao_ai/config.py +672 -218
  4. dao_ai/genie/cache/core.py +6 -2
  5. dao_ai/genie/cache/lru.py +29 -11
  6. dao_ai/genie/cache/semantic.py +95 -44
  7. dao_ai/hooks/core.py +5 -5
  8. dao_ai/logging.py +56 -0
  9. dao_ai/memory/core.py +61 -44
  10. dao_ai/memory/databricks.py +54 -41
  11. dao_ai/memory/postgres.py +77 -36
  12. dao_ai/middleware/assertions.py +45 -17
  13. dao_ai/middleware/core.py +13 -7
  14. dao_ai/middleware/guardrails.py +30 -25
  15. dao_ai/middleware/human_in_the_loop.py +9 -5
  16. dao_ai/middleware/message_validation.py +61 -29
  17. dao_ai/middleware/summarization.py +16 -11
  18. dao_ai/models.py +172 -69
  19. dao_ai/nodes.py +148 -19
  20. dao_ai/optimization.py +26 -16
  21. dao_ai/orchestration/core.py +15 -8
  22. dao_ai/orchestration/supervisor.py +22 -8
  23. dao_ai/orchestration/swarm.py +57 -12
  24. dao_ai/prompts.py +17 -17
  25. dao_ai/providers/databricks.py +365 -155
  26. dao_ai/state.py +24 -6
  27. dao_ai/tools/__init__.py +2 -0
  28. dao_ai/tools/agent.py +1 -3
  29. dao_ai/tools/core.py +7 -7
  30. dao_ai/tools/email.py +29 -77
  31. dao_ai/tools/genie.py +18 -13
  32. dao_ai/tools/mcp.py +223 -156
  33. dao_ai/tools/python.py +5 -2
  34. dao_ai/tools/search.py +1 -1
  35. dao_ai/tools/slack.py +21 -9
  36. dao_ai/tools/sql.py +202 -0
  37. dao_ai/tools/time.py +30 -7
  38. dao_ai/tools/unity_catalog.py +129 -86
  39. dao_ai/tools/vector_search.py +318 -244
  40. dao_ai/utils.py +15 -10
  41. dao_ai-0.1.3.dist-info/METADATA +455 -0
  42. dao_ai-0.1.3.dist-info/RECORD +64 -0
  43. dao_ai-0.1.1.dist-info/METADATA +0 -1878
  44. dao_ai-0.1.1.dist-info/RECORD +0 -62
  45. {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/WHEEL +0 -0
  46. {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/entry_points.txt +0 -0
  47. {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/licenses/LICENSE +0 -0
dao_ai/tools/vector_search.py
@@ -2,317 +2,391 @@
 Vector search tool for retrieving documents from Databricks Vector Search.
 
 This module provides a tool factory for creating semantic search tools
-using ToolRuntime[Context, AgentState] for type-safe runtime access.
+with dynamic filter schemas based on table columns and FlashRank reranking support.
 """
 
+import json
 import os
-from typing import Any, Callable, List, Optional, Sequence
+from typing import Any, Optional
 
 import mlflow
+from databricks.sdk import WorkspaceClient
 from databricks.vector_search.reranker import DatabricksReranker
-from databricks_ai_bridge.vector_search_retriever_tool import (
-    FilterItem,
-    VectorSearchRetrieverToolInput,
-)
-from databricks_langchain.vectorstores import DatabricksVectorSearch
+from databricks_langchain import DatabricksVectorSearch
 from flashrank import Ranker, RerankRequest
-from langchain.tools import ToolRuntime, tool
 from langchain_core.documents import Document
+from langchain_core.tools import StructuredTool
 from loguru import logger
 from mlflow.entities import SpanType
+from pydantic import BaseModel, ConfigDict, Field, create_model
 
 from dao_ai.config import (
     RerankParametersModel,
     RetrieverModel,
+    SearchParametersModel,
     VectorStoreModel,
+    value_of,
 )
-from dao_ai.state import AgentState, Context
 from dao_ai.utils import normalize_host
 
+# Create FilterItem model at module level so it can be used in type hints
+FilterItem = create_model(
+    "FilterItem",
+    key=(
+        str,
+        Field(
+            description="The filter key, which includes the column name and can include operators like 'NOT', '<', '>=', 'LIKE', 'OR'"
+        ),
+    ),
+    value=(
+        Any,
+        Field(
+            description="The filter value, which can be a single value or an array of values"
+        ),
+    ),
+    __config__=ConfigDict(extra="forbid"),
+)
 
-def create_vector_search_tool(
-    retriever: RetrieverModel | dict[str, Any],
-    name: Optional[str] = None,
-    description: Optional[str] = None,
-) -> Callable[..., list[dict[str, Any]]]:
-    """
-    Create a Vector Search tool for retrieving documents from a Databricks Vector Search index.
 
-    This function creates a tool that enables semantic search over product information,
-    documentation, or other content using the @tool decorator pattern. It supports optional
-    reranking of results using FlashRank for improved relevance.
+def _create_dynamic_input_schema(
+    index_name: str, workspace_client: WorkspaceClient
+) -> type[BaseModel]:
+    """
+    Create dynamic input schema with column information from the table.
 
     Args:
-        retriever: Configuration details for the vector search retriever, including:
-            - name: Name of the tool
-            - description: Description of the tool's purpose
-            - primary_key: Primary key column for the vector store
-            - text_column: Text column used for vector search
-            - doc_uri: URI for documentation or additional context
-            - vector_store: Dictionary with 'endpoint_name' and 'index' for vector search
-            - columns: List of columns to retrieve from the vector store
-            - search_parameters: Additional parameters for customizing the search behavior
-            - rerank: Optional rerank configuration for result reranking
-        name: Optional custom name for the tool
-        description: Optional custom description for the tool
+        index_name: Full name of the vector search index
+        workspace_client: Workspace client to query table metadata
 
     Returns:
-        A LangChain tool that performs vector search with optional reranking
+        Pydantic model class for tool input
     """
 
-    if isinstance(retriever, dict):
-        retriever = RetrieverModel(**retriever)
-
-    vector_store_config: VectorStoreModel = retriever.vector_store
-
-    # Index is required for vector search
-    if vector_store_config.index is None:
-        raise ValueError("vector_store.index is required for vector search")
-
-    index_name: str = vector_store_config.index.full_name
-    columns: Sequence[str] = retriever.columns or []
-    search_parameters: dict[str, Any] = retriever.search_parameters.model_dump()
-    primary_key: str = vector_store_config.primary_key or ""
-    doc_uri: str = vector_store_config.doc_uri or ""
-    text_column: str = vector_store_config.embedding_source_column
-
-    # Extract reranker configuration
-    reranker_config: Optional[RerankParametersModel] = retriever.rerank
-
-    # Initialize FlashRank ranker once if reranking is enabled
-    # This is expensive (loads model weights), so we do it once and reuse across invocations
-    ranker: Optional[Ranker] = None
-    if reranker_config:
-        logger.debug(
-            f"Creating vector search tool with reranking: '{name}' "
-            f"(model: {reranker_config.model}, top_n: {reranker_config.top_n or 'auto'})"
-        )
-        try:
-            ranker = Ranker(
-                model_name=reranker_config.model, cache_dir=reranker_config.cache_dir
-            )
-            logger.info(
-                f"FlashRank ranker initialized successfully (model: {reranker_config.model})"
-            )
-        except Exception as e:
-            logger.warning(
-                f"Failed to initialize FlashRank ranker during tool creation: {e}. "
-                "Reranking will be disabled for this tool."
-            )
-            # Set reranker_config to None so we don't attempt reranking
-            reranker_config = None
-    else:
+    # Try to get column information
+    column_descriptions = []
+    try:
+        table_info = workspace_client.tables.get(full_name=index_name)
+        for column_info in table_info.columns:
+            name = column_info.name
+            col_type = column_info.type_name.name
+            if not name.startswith("__"):
+                column_descriptions.append(f"{name} ({col_type})")
+    except Exception:
         logger.debug(
-            f"Creating vector search tool without reranking: '{name}' (standard similarity search only)"
+            "Could not retrieve column information for dynamic schema",
+            index=index_name,
         )
 
-    # Initialize the vector store
-    # Note: text_column is only required for self-managed embeddings
-    # For Databricks-managed embeddings, it's automatically determined from the index
+    # Build filter description matching VectorSearchRetrieverTool format
+    filter_description = (
+        "Optional filters to refine vector search results as an array of key-value pairs. "
+        "IMPORTANT: If unsure about filter values, try searching WITHOUT filters first to get broad results, "
+        "then optionally add filters to narrow down if needed. This ensures you don't miss relevant results due to incorrect filter values. "
+    )
 
-    # Build client_args for VectorSearchClient from environment variables
-    # This is needed because during MLflow model validation, credentials must be
-    # explicitly passed to VectorSearchClient via client_args.
-    # The workspace_client parameter in DatabricksVectorSearch is only used to detect
-    # model serving mode - it doesn't pass credentials to VectorSearchClient.
-    client_args: dict[str, Any] = {}
-    databricks_host = normalize_host(os.environ.get("DATABRICKS_HOST"))
-    if databricks_host:
-        client_args["workspace_url"] = databricks_host
-    if os.environ.get("DATABRICKS_TOKEN"):
-        client_args["personal_access_token"] = os.environ.get("DATABRICKS_TOKEN")
-    if os.environ.get("DATABRICKS_CLIENT_ID"):
-        client_args["service_principal_client_id"] = os.environ.get(
-            "DATABRICKS_CLIENT_ID"
-        )
-    if os.environ.get("DATABRICKS_CLIENT_SECRET"):
-        client_args["service_principal_client_secret"] = os.environ.get(
-            "DATABRICKS_CLIENT_SECRET"
+    if column_descriptions:
+        filter_description += (
+            f"Available columns for filtering: {', '.join(column_descriptions)}. "
         )
 
-    logger.debug(
-        f"Creating DatabricksVectorSearch with client_args keys: {list(client_args.keys())}"
+    filter_description += (
+        "Supports the following operators:\n\n"
+        '- Inclusion: [{"key": "column", "value": value}] or [{"key": "column", "value": [value1, value2]}] (matches if the column equals any of the provided values)\n'
+        '- Exclusion: [{"key": "column NOT", "value": value}]\n'
+        '- Comparisons: [{"key": "column <", "value": value}], [{"key": "column >=", "value": value}], etc.\n'
+        '- Pattern match: [{"key": "column LIKE", "value": "word"}] (matches full tokens separated by whitespace)\n'
+        '- OR logic: [{"key": "column1 OR column2", "value": [value1, value2]}] '
+        "(matches if column1 equals value1 or column2 equals value2; matches are position-specific)\n\n"
+        "Examples:\n"
+        '- Filter by category: [{"key": "category", "value": "electronics"}]\n'
+        '- Filter by price range: [{"key": "price >=", "value": 100}, {"key": "price <", "value": 500}]\n'
+        '- Exclude specific status: [{"key": "status NOT", "value": "archived"}]\n'
+        '- Pattern matching: [{"key": "description LIKE", "value": "wireless"}]'
     )
 
-    # Pass both workspace_client (for model serving detection) and client_args (for credentials)
-    vector_store: DatabricksVectorSearch = DatabricksVectorSearch(
-        index_name=index_name,
-        text_column=None,  # Let DatabricksVectorSearch determine this from the index
-        columns=columns,
-        include_score=True,
-        workspace_client=vector_store_config.workspace_client,
-        client_args=client_args if client_args else None,
+    # Create the input model
+    VectorSearchInput = create_model(
+        "VectorSearchInput",
+        query=(
+            str,
+            Field(description="The search query string to find relevant documents"),
+        ),
+        filters=(
+            Optional[list[FilterItem]],
+            Field(default=None, description=filter_description),
+        ),
+        __config__=ConfigDict(extra="forbid"),
     )
 
-    # Register the retriever schema with MLflow for model serving integration
-    mlflow.models.set_retriever_schema(
-        name=name or "retriever",
-        primary_key=primary_key,
-        text_column=text_column,
-        doc_uri=doc_uri,
-        other_columns=list(columns),
+    return VectorSearchInput
+
+
+@mlflow.trace(name="rerank_documents", span_type=SpanType.RETRIEVER)
+def _rerank_documents(
+    query: str,
+    documents: list[Document],
+    ranker: Ranker,
+    rerank_config: RerankParametersModel,
+) -> list[Document]:
+    """
+    Rerank documents using FlashRank cross-encoder model.
+
+    Args:
+        query: The search query string
+        documents: List of documents to rerank
+        ranker: The FlashRank Ranker instance
+        rerank_config: Reranking configuration
+
+    Returns:
+        Reranked list of documents with reranker_score in metadata
+    """
+    logger.trace(
+        "Starting reranking",
+        documents_count=len(documents),
+        model=rerank_config.model,
     )
 
-    # Helper function to perform vector similarity search
-    @mlflow.trace(name="find_documents", span_type=SpanType.RETRIEVER)
-    def _find_documents(
-        query: str, filters: Optional[List[FilterItem]] = None
-    ) -> List[Document]:
-        """Perform vector similarity search."""
-        # Convert filters to dict format
-        filters_dict: dict[str, Any] = {}
-        if filters:
-            for item in filters:
-                item_dict = dict(item)
-                filters_dict[item_dict["key"]] = item_dict["value"]
+    # Prepare passages for reranking
+    passages: list[dict[str, Any]] = [
+        {"text": doc.page_content, "meta": doc.metadata} for doc in documents
+    ]
 
-        # Merge with any configured filters
-        combined_filters: dict[str, Any] = {
-            **filters_dict,
-            **search_parameters.get("filters", {}),
-        }
+    # Create reranking request
+    rerank_request: RerankRequest = RerankRequest(query=query, passages=passages)
 
-        # Perform similarity search
-        num_results: int = search_parameters.get("num_results", 10)
-        query_type: str = search_parameters.get("query_type", "ANN")
+    # Perform reranking
+    results: list[dict[str, Any]] = ranker.rerank(rerank_request)
 
-        logger.debug(
-            f"Performing vector search: query='{query[:50]}...', k={num_results}, filters={combined_filters}"
+    # Apply top_n filtering
+    top_n: int = rerank_config.top_n or len(documents)
+    results = results[:top_n]
+    logger.debug("Reranking complete", top_n=top_n, candidates_count=len(documents))
+
+    # Convert back to Document objects with reranking scores
+    reranked_docs: list[Document] = []
+    for result in results:
+        orig_doc: Optional[Document] = next(
+            (doc for doc in documents if doc.page_content == result["text"]), None
         )
+        if orig_doc:
+            reranked_doc: Document = Document(
+                page_content=orig_doc.page_content,
+                metadata={
+                    **orig_doc.metadata,
+                    "reranker_score": result["score"],
+                },
+            )
+            reranked_docs.append(reranked_doc)
 
-        # Build similarity search kwargs
-        search_kwargs = {
-            "query": query,
-            "k": num_results,
-            "filter": combined_filters if combined_filters else None,
-            "query_type": query_type,
-        }
+    logger.debug(
+        "Documents reranked",
+        input_count=len(documents),
+        output_count=len(reranked_docs),
+        model=rerank_config.model,
+    )
 
-        # Add DatabricksReranker if configured with columns
-        if reranker_config and reranker_config.columns:
-            search_kwargs["reranker"] = DatabricksReranker(
-                columns_to_rerank=reranker_config.columns
-            )
+    return reranked_docs
 
-        documents: List[Document] = vector_store.similarity_search(**search_kwargs)
 
-        logger.debug(f"Retrieved {len(documents)} documents from vector search")
-        return documents
+def create_vector_search_tool(
+    retriever: Optional[RetrieverModel | dict[str, Any]] = None,
+    vector_store: Optional[VectorStoreModel | dict[str, Any]] = None,
+    name: Optional[str] = None,
+    description: Optional[str] = None,
+) -> StructuredTool:
+    """
+    Create a Vector Search tool with dynamic schema and optional reranking.
 
-    # Helper function to rerank documents
-    @mlflow.trace(name="rerank_documents", span_type=SpanType.RETRIEVER)
-    def _rerank_documents(query: str, documents: List[Document]) -> List[Document]:
-        """Rerank documents using FlashRank.
+    Args:
+        retriever: Full retriever configuration with search parameters and reranking
+        vector_store: Direct vector store reference (uses default search parameters)
+        name: Optional custom name for the tool
+        description: Optional custom description for the tool
 
-        Uses the ranker instance initialized at tool creation time (captured in closure).
-        This avoids expensive model loading on every invocation.
-        """
-        if not reranker_config or ranker is None:
-            return documents
+    Returns:
+        A LangChain StructuredTool with proper schema (additionalProperties: false)
+    """
 
-        logger.debug(
-            f"Starting reranking for {len(documents)} documents using model '{reranker_config.model}'"
+    # Validate mutually exclusive parameters
+    if retriever is None and vector_store is None:
+        raise ValueError("Must provide either 'retriever' or 'vector_store' parameter")
+    if retriever is not None and vector_store is not None:
+        raise ValueError(
+            "Cannot provide both 'retriever' and 'vector_store' parameters"
         )
 
-        # Prepare passages for reranking
-        passages: List[dict[str, Any]] = [
-            {"text": doc.page_content, "meta": doc.metadata} for doc in documents
-        ]
+    # Handle vector_store parameter
+    if vector_store is not None:
+        if isinstance(vector_store, dict):
+            vector_store = VectorStoreModel(**vector_store)
+        retriever = RetrieverModel(vector_store=vector_store)
+    else:
+        if isinstance(retriever, dict):
+            retriever = RetrieverModel(**retriever)
+
+    vector_store: VectorStoreModel = retriever.vector_store
 
-        # Create reranking request
-        rerank_request: RerankRequest = RerankRequest(query=query, passages=passages)
+    # Index is required
+    if vector_store.index is None:
+        raise ValueError("vector_store.index is required for vector search")
 
-        # Perform reranking
-        logger.debug(f"Reranking {len(passages)} passages for query: '{query[:50]}...'")
-        results: List[dict[str, Any]] = ranker.rerank(rerank_request)
+    index_name: str = vector_store.index.full_name
+    columns: list[str] = list(retriever.columns or [])
+    search_parameters: SearchParametersModel = retriever.search_parameters
+    rerank_config: Optional[RerankParametersModel] = retriever.rerank
 
-        # Apply top_n filtering
-        top_n: int = reranker_config.top_n or len(documents)
-        results = results[:top_n]
+    # Initialize FlashRank ranker if configured
+    ranker: Optional[Ranker] = None
+    if rerank_config and rerank_config.model:
         logger.debug(
-            f"Reranking complete. Filtered to top {top_n} results from {len(documents)} candidates"
+            "Initializing FlashRank ranker",
+            model=rerank_config.model,
+            top_n=rerank_config.top_n or "auto",
         )
+        try:
+            cache_dir = os.path.expanduser(rerank_config.cache_dir)
+            ranker = Ranker(model_name=rerank_config.model, cache_dir=cache_dir)
+            logger.success("FlashRank ranker initialized", model=rerank_config.model)
+        except Exception as e:
+            logger.warning("Failed to initialize FlashRank ranker", error=str(e))
+            rerank_config = None
 
-        # Convert back to Document objects with reranking scores
-        reranked_docs: List[Document] = []
-        for result in results:
-            # Find original document by matching text
-            orig_doc: Optional[Document] = next(
-                (doc for doc in documents if doc.page_content == result["text"]), None
-            )
-            if orig_doc:
-                # Add reranking score to metadata
-                reranked_doc: Document = Document(
-                    page_content=orig_doc.page_content,
-                    metadata={
-                        **orig_doc.metadata,
-                        "reranker_score": result["score"],
-                    },
-                )
-                reranked_docs.append(reranked_doc)
+    # Build client_args for VectorSearchClient
+    # Use getattr to safely access attributes that may not exist (e.g., in mocks)
+    client_args: dict[str, Any] = {}
+    has_explicit_auth = any(
+        [
+            os.environ.get("DATABRICKS_TOKEN"),
+            os.environ.get("DATABRICKS_CLIENT_ID"),
+            getattr(vector_store, "pat", None),
+            getattr(vector_store, "client_id", None),
+            getattr(vector_store, "on_behalf_of_user", None),
+        ]
+    )
 
-        logger.debug(
-            f"Reranked {len(documents)} documents → {len(reranked_docs)} results "
-            f"(model: {reranker_config.model}, top score: {reranked_docs[0].metadata.get('reranker_score', 0):.4f})"
-            if reranked_docs
-            else f"Reranking completed with {len(reranked_docs)} results"
-        )
+    if has_explicit_auth:
+        databricks_host = os.environ.get("DATABRICKS_HOST")
+        if (
+            not databricks_host
+            and getattr(vector_store, "_workspace_client", None) is not None
+        ):
+            databricks_host = vector_store.workspace_client.config.host
+        if databricks_host:
+            client_args["workspace_url"] = normalize_host(databricks_host)
+
+        token = os.environ.get("DATABRICKS_TOKEN")
+        if not token and getattr(vector_store, "pat", None):
+            token = value_of(vector_store.pat)
+        if token:
+            client_args["personal_access_token"] = token
+
+        client_id = os.environ.get("DATABRICKS_CLIENT_ID")
+        if not client_id and getattr(vector_store, "client_id", None):
+            client_id = value_of(vector_store.client_id)
+        if client_id:
+            client_args["service_principal_client_id"] = client_id
+
+        client_secret = os.environ.get("DATABRICKS_CLIENT_SECRET")
+        if not client_secret and getattr(vector_store, "client_secret", None):
+            client_secret = value_of(vector_store.client_secret)
+        if client_secret:
+            client_args["service_principal_client_secret"] = client_secret
 
-        return reranked_docs
+    logger.debug(
+        "Creating vector search tool",
+        name=name,
+        index=index_name,
+        client_args_keys=list(client_args.keys()) if client_args else [],
+    )
 
-    # Create the main vector search tool using @tool decorator
-    # Uses ToolRuntime[Context, AgentState] for type-safe runtime access
-    @tool(
-        name_or_callable=name or index_name,
-        description=description or "Search for documents using vector similarity",
-        args_schema=VectorSearchRetrieverToolInput,
+    # Create DatabricksVectorSearch
+    # Note: text_column should be None for Databricks-managed embeddings
+    # (it's automatically determined from the index)
+    vector_search: DatabricksVectorSearch = DatabricksVectorSearch(
+        index_name=index_name,
+        text_column=None,
+        columns=columns,
+        workspace_client=vector_store.workspace_client,
+        client_args=client_args if client_args else None,
+        primary_key=vector_store.primary_key,
+        doc_uri=vector_store.doc_uri,
+        include_score=True,
+        reranker=(
+            DatabricksReranker(columns_to_rerank=rerank_config.columns)
+            if rerank_config and rerank_config.columns
+            else None
+        ),
     )
-    def vector_search_tool(
-        query: str,
-        filters: Optional[List[FilterItem]] = None,
-        runtime: ToolRuntime[Context, AgentState] = None,
-    ) -> list[dict[str, Any]]:
-        """
-        Search for documents using vector similarity with optional reranking.
 
-        This tool performs a two-stage retrieval process:
-        1. Vector similarity search to find candidate documents
-        2. Optional reranking using cross-encoder model for improved relevance
+    # Create dynamic input schema
+    input_schema: type[BaseModel] = _create_dynamic_input_schema(
+        index_name, vector_store.workspace_client
+    )
 
-        Both stages are traced in MLflow for observability.
+    # Define the tool function
+    def vector_search_func(
+        query: str, filters: Optional[list[FilterItem]] = None
+    ) -> str:
+        """Search for relevant documents using vector similarity."""
+        # Convert FilterItem Pydantic models to dict format for DatabricksVectorSearch
+        filters_dict: dict[str, Any] = {}
+        if filters:
+            for item in filters:
+                filters_dict[item.key] = item.value
 
-        Uses ToolRuntime[Context, AgentState] for type-safe runtime access.
+        # Merge with configured filters
+        combined_filters: dict[str, Any] = {
+            **filters_dict,
+            **(search_parameters.filters or {}),
+        }
 
-        Returns:
-            List of serialized documents with page_content and metadata
-        """
-        logger.debug(
-            f"Vector search tool called: query='{query[:50]}...', reranking={reranker_config is not None}"
+        # Perform vector search
+        logger.trace("Performing vector search", query_preview=query[:50])
+        documents: list[Document] = vector_search.similarity_search(
+            query=query,
+            k=search_parameters.num_results or 5,
+            filter=combined_filters if combined_filters else None,
+            query_type=search_parameters.query_type or "ANN",
         )
 
-        # Step 1: Perform vector similarity search
-        documents: List[Document] = _find_documents(query, filters)
-
-        # Step 2: If reranking is enabled, rerank the documents
-        if reranker_config:
-            logger.debug(
-                f"Reranking enabled (model: '{reranker_config.model}', top_n: {reranker_config.top_n or 'all'})"
+        # Apply FlashRank reranking if configured
+        if ranker and rerank_config:
+            logger.debug("Applying FlashRank reranking")
+            documents = _rerank_documents(query, documents, ranker, rerank_config)
+
+        # Serialize documents to JSON format for LLM consumption
+        # Convert Document objects to dicts with page_content and metadata
+        # Need to handle numpy types in metadata (e.g., float32, int64)
+        serialized_docs: list[dict[str, Any]] = []
+        for doc in documents:
+            doc: Document
+            # Convert metadata values to JSON-serializable types
+            metadata_serializable: dict[str, Any] = {}
+            for key, value in doc.metadata.items():
+                # Handle numpy types
+                if hasattr(value, "item"):  # numpy scalar
+                    metadata_serializable[key] = value.item()
+                else:
+                    metadata_serializable[key] = value
+
+            serialized_docs.append(
+                {
+                    "page_content": doc.page_content,
+                    "metadata": metadata_serializable,
+                }
             )
-            documents = _rerank_documents(query, documents)
-            logger.debug(f"Returning {len(documents)} reranked documents")
-        else:
-            logger.debug("Reranking disabled, returning original vector search results")
-
-        # Return Command with ToolMessage containing the documents
-        # Serialize documents to dicts for proper ToolMessage handling
-        serialized_docs: list[dict[str, Any]] = [
-            {
-                "page_content": doc.page_content,
-                "metadata": doc.metadata,
-            }
-            for doc in documents
-        ]
 
-        return serialized_docs
+        # Return as JSON string
+        return json.dumps(serialized_docs)
+
+
+    # Create the StructuredTool
+    tool: StructuredTool = StructuredTool.from_function(
+        func=vector_search_func,
+        name=name or f"vector_search_{vector_store.index.name}",
+        description=description or f"Search documents in {index_name}",
+        args_schema=input_schema,
+    )
+
+    logger.success("Vector search tool created", name=tool.name, index=index_name)
 
-    return vector_search_tool
+    return tool
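
For orientation, here is a minimal sketch of how the reworked factory might be wired up after this change. Only the function signature, the retriever/vector_store exclusivity, and the filter grammar come from the code above; the catalog, index, and endpoint names, the exact dict shapes for the config models, and the FlashRank model choice are illustrative assumptions.

# Hypothetical usage sketch; names and config field shapes are illustrative,
# following the dao_ai.config models referenced in the diff.
from dao_ai.tools.vector_search import create_vector_search_tool

# Provide either a full retriever config (search parameters + reranking) or a
# bare vector_store reference; passing both raises ValueError as of 0.1.3.
tool = create_vector_search_tool(
    retriever={
        "vector_store": {
            "endpoint_name": "shared-endpoint",          # assumed value
            "index": "main.retail.products_index",       # assumed value
        },
        "columns": ["product_id", "name", "price", "category"],
        "search_parameters": {"num_results": 10, "query_type": "ANN"},
        "rerank": {"model": "ms-marco-MiniLM-L-12-v2", "top_n": 3},
    },
    name="product_search",
)

# The returned StructuredTool validates input against the dynamically built
# schema and accepts optional filters using the operator grammar from the
# generated description (e.g. "price >=", "status NOT", "description LIKE").
results_json = tool.invoke(
    {
        "query": "wireless noise-cancelling headphones",
        "filters": [
            {"key": "price >=", "value": 100},
            {"key": "category", "value": "electronics"},
        ],
    }
)

Note one behavioral difference visible in the diff: the 0.1.1 tool returned a list of serialized document dicts, while the 0.1.3 tool returns a JSON string via json.dumps, so callers that consumed the structured return value directly now need a json.loads step.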