dao-ai 0.0.28__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +2 -5
  3. dao_ai/cli.py +245 -40
  4. dao_ai/config.py +1491 -370
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -253
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +125 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/core.py +67 -0
  25. dao_ai/middleware/guardrails.py +420 -0
  26. dao_ai/middleware/human_in_the_loop.py +232 -0
  27. dao_ai/middleware/message_validation.py +586 -0
  28. dao_ai/middleware/summarization.py +197 -0
  29. dao_ai/models.py +1306 -114
  30. dao_ai/nodes.py +245 -159
  31. dao_ai/optimization.py +674 -0
  32. dao_ai/orchestration/__init__.py +52 -0
  33. dao_ai/orchestration/core.py +294 -0
  34. dao_ai/orchestration/supervisor.py +278 -0
  35. dao_ai/orchestration/swarm.py +271 -0
  36. dao_ai/prompts.py +128 -31
  37. dao_ai/providers/databricks.py +573 -601
  38. dao_ai/state.py +157 -21
  39. dao_ai/tools/__init__.py +13 -5
  40. dao_ai/tools/agent.py +1 -3
  41. dao_ai/tools/core.py +64 -11
  42. dao_ai/tools/email.py +232 -0
  43. dao_ai/tools/genie.py +144 -294
  44. dao_ai/tools/mcp.py +223 -155
  45. dao_ai/tools/memory.py +50 -0
  46. dao_ai/tools/python.py +9 -14
  47. dao_ai/tools/search.py +14 -0
  48. dao_ai/tools/slack.py +22 -10
  49. dao_ai/tools/sql.py +202 -0
  50. dao_ai/tools/time.py +30 -7
  51. dao_ai/tools/unity_catalog.py +165 -88
  52. dao_ai/tools/vector_search.py +331 -221
  53. dao_ai/utils.py +166 -20
  54. dao_ai-0.1.2.dist-info/METADATA +455 -0
  55. dao_ai-0.1.2.dist-info/RECORD +64 -0
  56. dao_ai/chat_models.py +0 -204
  57. dao_ai/guardrails.py +0 -112
  58. dao_ai/tools/human_in_the_loop.py +0 -100
  59. dao_ai-0.0.28.dist-info/METADATA +0 -1168
  60. dao_ai-0.0.28.dist-info/RECORD +0 -41
  61. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +0 -0
  62. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
  63. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
@@ -1,282 +1,392 @@
1
- from typing import Annotated, Any, Callable, List, Optional, Sequence
1
+ """
2
+ Vector search tool for retrieving documents from Databricks Vector Search.
3
+
4
+ This module provides a tool factory for creating semantic search tools
5
+ with dynamic filter schemas based on table columns and FlashRank reranking support.
6
+ """
7
+
8
+ import json
9
+ import os
10
+ from typing import Any, Optional
2
11
 
3
12
  import mlflow
13
+ from databricks.sdk import WorkspaceClient
4
14
  from databricks.vector_search.reranker import DatabricksReranker
5
- from databricks_ai_bridge.vector_search_retriever_tool import (
6
- FilterItem,
7
- VectorSearchRetrieverToolInput,
8
- )
9
- from databricks_langchain.vectorstores import DatabricksVectorSearch
15
+ from databricks_langchain import DatabricksVectorSearch
10
16
  from flashrank import Ranker, RerankRequest
11
17
  from langchain_core.documents import Document
12
- from langchain_core.tools import InjectedToolCallId, tool
13
- from langgraph.prebuilt import InjectedState
18
+ from langchain_core.tools import StructuredTool
14
19
  from loguru import logger
15
20
  from mlflow.entities import SpanType
21
+ from pydantic import BaseModel, ConfigDict, Field, create_model
16
22
 
17
23
  from dao_ai.config import (
18
24
  RerankParametersModel,
19
25
  RetrieverModel,
26
+ SearchParametersModel,
20
27
  VectorStoreModel,
28
+ value_of,
21
29
  )
30
+ from dao_ai.utils import normalize_host
31
+
32
# FilterItem is defined once at module level so it can appear in the type
# hints of every dynamically generated vector-search input schema.
# `extra="forbid"` keeps the emitted JSON schema strict
# (additionalProperties: false), which some LLM tool-calling APIs require.
_FILTER_ITEM_FIELDS: dict[str, Any] = {
    "key": (
        str,
        Field(
            description="The filter key, which includes the column name and can include operators like 'NOT', '<', '>=', 'LIKE', 'OR'"
        ),
    ),
    "value": (
        Any,
        Field(
            description="The filter value, which can be a single value or an array of values"
        ),
    ),
}

FilterItem = create_model(
    "FilterItem",
    __config__=ConfigDict(extra="forbid"),
    **_FILTER_ITEM_FIELDS,
)
49
+
50
+
51
def _create_dynamic_input_schema(
    index_name: str, workspace_client: WorkspaceClient
) -> type[BaseModel]:
    """
    Create dynamic input schema with column information from the table.

    Args:
        index_name: Full name of the vector search index
        workspace_client: Workspace client to query table metadata

    Returns:
        Pydantic model class for tool input
    """

    # Best-effort metadata lookup: if the table cannot be described, the tool
    # still works — the filter description simply omits the column list.
    # NOTE(review): this queries tables.get with the *index* full name —
    # presumably index and source table share the name; confirm for
    # self-managed indexes.
    column_descriptions: list[str] = []
    try:
        table_info = workspace_client.tables.get(full_name=index_name)
        for col in table_info.columns:
            if col.name.startswith("__"):
                continue  # skip internal/system columns
            column_descriptions.append(f"{col.name} ({col.type_name.name})")
    except Exception:
        logger.debug(
            "Could not retrieve column information for dynamic schema",
            index=index_name,
        )

    # Assemble the filter description (matches VectorSearchRetrieverTool format)
    # from parts, inserting the column list only when it could be retrieved.
    description_parts: list[str] = [
        "Optional filters to refine vector search results as an array of key-value pairs. "
        "IMPORTANT: If unsure about filter values, try searching WITHOUT filters first to get broad results, "
        "then optionally add filters to narrow down if needed. This ensures you don't miss relevant results due to incorrect filter values. "
    ]

    if column_descriptions:
        description_parts.append(
            f"Available columns for filtering: {', '.join(column_descriptions)}. "
        )

    description_parts.append(
        "Supports the following operators:\n\n"
        '- Inclusion: [{"key": "column", "value": value}] or [{"key": "column", "value": [value1, value2]}] (matches if the column equals any of the provided values)\n'
        '- Exclusion: [{"key": "column NOT", "value": value}]\n'
        '- Comparisons: [{"key": "column <", "value": value}], [{"key": "column >=", "value": value}], etc.\n'
        '- Pattern match: [{"key": "column LIKE", "value": "word"}] (matches full tokens separated by whitespace)\n'
        '- OR logic: [{"key": "column1 OR column2", "value": [value1, value2]}] '
        "(matches if column1 equals value1 or column2 equals value2; matches are position-specific)\n\n"
        "Examples:\n"
        '- Filter by category: [{"key": "category", "value": "electronics"}]\n'
        '- Filter by price range: [{"key": "price >=", "value": 100}, {"key": "price <", "value": 500}]\n'
        '- Exclude specific status: [{"key": "status NOT", "value": "archived"}]\n'
        '- Pattern matching: [{"key": "description LIKE", "value": "wireless"}]'
    )

    filter_description: str = "".join(description_parts)

    # Build the input model; extra="forbid" keeps the schema strict.
    return create_model(
        "VectorSearchInput",
        query=(
            str,
            Field(description="The search query string to find relevant documents"),
        ),
        filters=(
            Optional[list[FilterItem]],
            Field(default=None, description=filter_description),
        ),
        __config__=ConfigDict(extra="forbid"),
    )
122
+
123
+
124
@mlflow.trace(name="rerank_documents", span_type=SpanType.RETRIEVER)
def _rerank_documents(
    query: str,
    documents: list[Document],
    ranker: Ranker,
    rerank_config: RerankParametersModel,
) -> list[Document]:
    """
    Rerank documents using FlashRank cross-encoder model.

    Args:
        query: The search query string
        documents: List of documents to rerank
        ranker: The FlashRank Ranker instance
        rerank_config: Reranking configuration

    Returns:
        Reranked list of documents with reranker_score in metadata
    """
    logger.trace(
        "Starting reranking",
        documents_count=len(documents),
        model=rerank_config.model,
    )

    # Prepare passages for reranking
    passages: list[dict[str, Any]] = [
        {"text": doc.page_content, "meta": doc.metadata} for doc in documents
    ]

    # Perform reranking
    rerank_request: RerankRequest = RerankRequest(query=query, passages=passages)
    results: list[dict[str, Any]] = ranker.rerank(rerank_request)

    # Apply top_n filtering (top_n unset means keep everything)
    top_n: int = rerank_config.top_n or len(documents)
    results = results[:top_n]
    logger.debug("Reranking complete", top_n=top_n, candidates_count=len(documents))

    # Map each reranked passage back to its source document by page content.
    # Build the lookup once — first occurrence wins for duplicate contents,
    # matching the previous first-match scan, but in O(n) instead of O(n^2).
    doc_by_text: dict[str, Document] = {}
    for doc in documents:
        doc_by_text.setdefault(doc.page_content, doc)

    # Rebuild Document objects with the reranker score attached to metadata;
    # results whose text no longer matches any source document are dropped.
    reranked_docs: list[Document] = []
    for result in results:
        orig_doc: Optional[Document] = doc_by_text.get(result["text"])
        if orig_doc:
            reranked_docs.append(
                Document(
                    page_content=orig_doc.page_content,
                    metadata={
                        **orig_doc.metadata,
                        "reranker_score": result["score"],
                    },
                )
            )

    logger.debug(
        "Documents reranked",
        input_count=len(documents),
        output_count=len(reranked_docs),
        model=rerank_config.model,
    )

    return reranked_docs
22
189
 
23
190
 
24
191
def _build_client_args(vector_store: VectorStoreModel) -> dict[str, Any]:
    """Collect explicit VectorSearchClient auth arguments.

    Environment variables take precedence over config-provided credentials.
    Returns an empty dict when no explicit auth is configured, leaving
    DatabricksVectorSearch to fall back to default authentication.
    """
    # Use getattr to safely access attributes that may not exist (e.g., in mocks)
    has_explicit_auth = any(
        [
            os.environ.get("DATABRICKS_TOKEN"),
            os.environ.get("DATABRICKS_CLIENT_ID"),
            getattr(vector_store, "pat", None),
            getattr(vector_store, "client_id", None),
            getattr(vector_store, "on_behalf_of_user", None),
        ]
    )
    if not has_explicit_auth:
        return {}

    client_args: dict[str, Any] = {}

    databricks_host = os.environ.get("DATABRICKS_HOST")
    # Only consult the workspace client when one is already materialized;
    # checking the private attribute avoids triggering lazy client creation.
    if (
        not databricks_host
        and getattr(vector_store, "_workspace_client", None) is not None
    ):
        databricks_host = vector_store.workspace_client.config.host
    if databricks_host:
        client_args["workspace_url"] = normalize_host(databricks_host)

    token = os.environ.get("DATABRICKS_TOKEN")
    if not token and getattr(vector_store, "pat", None):
        token = value_of(vector_store.pat)
    if token:
        client_args["personal_access_token"] = token

    client_id = os.environ.get("DATABRICKS_CLIENT_ID")
    if not client_id and getattr(vector_store, "client_id", None):
        client_id = value_of(vector_store.client_id)
    if client_id:
        client_args["service_principal_client_id"] = client_id

    client_secret = os.environ.get("DATABRICKS_CLIENT_SECRET")
    if not client_secret and getattr(vector_store, "client_secret", None):
        client_secret = value_of(vector_store.client_secret)
    if client_secret:
        client_args["service_principal_client_secret"] = client_secret

    return client_args


def _serialize_documents(documents: list[Document]) -> str:
    """Serialize documents to a JSON string for LLM consumption.

    Metadata values that are numpy scalars (e.g. float32 similarity scores)
    are coerced to native Python types via ``.item()`` so json.dumps works.
    """
    serialized_docs: list[dict[str, Any]] = []
    for doc in documents:
        metadata_serializable: dict[str, Any] = {
            key: (value.item() if hasattr(value, "item") else value)
            for key, value in doc.metadata.items()
        }
        serialized_docs.append(
            {
                "page_content": doc.page_content,
                "metadata": metadata_serializable,
            }
        )
    return json.dumps(serialized_docs)


def create_vector_search_tool(
    retriever: Optional[RetrieverModel | dict[str, Any]] = None,
    vector_store: Optional[VectorStoreModel | dict[str, Any]] = None,
    name: Optional[str] = None,
    description: Optional[str] = None,
) -> StructuredTool:
    """
    Create a Vector Search tool with dynamic schema and optional reranking.

    Exactly one of ``retriever`` or ``vector_store`` must be provided.

    Args:
        retriever: Full retriever configuration with search parameters and reranking
        vector_store: Direct vector store reference (uses default search parameters)
        name: Optional custom name for the tool
        description: Optional custom description for the tool

    Returns:
        A LangChain StructuredTool with proper schema (additionalProperties: false)

    Raises:
        ValueError: If neither or both of ``retriever``/``vector_store`` are
            provided, or if the resolved vector store has no index configured.
    """

    # Validate mutually exclusive parameters
    if retriever is None and vector_store is None:
        raise ValueError("Must provide either 'retriever' or 'vector_store' parameter")
    if retriever is not None and vector_store is not None:
        raise ValueError(
            "Cannot provide both 'retriever' and 'vector_store' parameters"
        )

    # Normalize both entry paths to a RetrieverModel
    if vector_store is not None:
        if isinstance(vector_store, dict):
            vector_store = VectorStoreModel(**vector_store)
        retriever = RetrieverModel(vector_store=vector_store)
    elif isinstance(retriever, dict):
        retriever = RetrieverModel(**retriever)

    vector_store = retriever.vector_store

    # Index is required
    if vector_store.index is None:
        raise ValueError("vector_store.index is required for vector search")

    index_name: str = vector_store.index.full_name
    columns: list[str] = list(retriever.columns or [])
    search_parameters: SearchParametersModel = retriever.search_parameters
    rerank_config: Optional[RerankParametersModel] = retriever.rerank

    # Initialize the FlashRank ranker once at tool creation (loading model
    # weights is expensive); on failure, reranking is disabled rather than
    # failing tool creation.
    ranker: Optional[Ranker] = None
    if rerank_config and rerank_config.model:
        logger.debug(
            "Initializing FlashRank ranker",
            model=rerank_config.model,
            top_n=rerank_config.top_n or "auto",
        )
        try:
            # cache_dir may be unset; only expand when present so an unset
            # value doesn't raise TypeError from os.path.expanduser(None)
            # and silently disable reranking via the except below.
            cache_dir = (
                os.path.expanduser(rerank_config.cache_dir)
                if rerank_config.cache_dir
                else rerank_config.cache_dir
            )
            ranker = Ranker(model_name=rerank_config.model, cache_dir=cache_dir)
            logger.success("FlashRank ranker initialized", model=rerank_config.model)
        except Exception as e:
            logger.warning("Failed to initialize FlashRank ranker", error=str(e))
            rerank_config = None

    # Build client_args for VectorSearchClient from explicit auth, if any
    client_args: dict[str, Any] = _build_client_args(vector_store)

    logger.debug(
        "Creating vector search tool",
        name=name,
        index=index_name,
        client_args_keys=list(client_args.keys()) if client_args else [],
    )

    # Create DatabricksVectorSearch
    # Note: text_column should be None for Databricks-managed embeddings
    # (it's automatically determined from the index)
    vector_search: DatabricksVectorSearch = DatabricksVectorSearch(
        index_name=index_name,
        text_column=None,
        columns=columns,
        workspace_client=vector_store.workspace_client,
        client_args=client_args if client_args else None,
        primary_key=vector_store.primary_key,
        doc_uri=vector_store.doc_uri,
        include_score=True,
        reranker=(
            DatabricksReranker(columns_to_rerank=rerank_config.columns)
            if rerank_config and rerank_config.columns
            else None
        ),
    )

    # Dynamic input schema advertises the table's filterable columns to the LLM
    input_schema: type[BaseModel] = _create_dynamic_input_schema(
        index_name, vector_store.workspace_client
    )

    # Define the tool function (closes over vector_search, ranker, config)
    def vector_search_func(
        query: str, filters: Optional[list[FilterItem]] = None
    ) -> str:
        """Search for relevant documents using vector similarity."""
        # Convert FilterItem Pydantic models to the dict format expected by
        # DatabricksVectorSearch, then merge with statically configured
        # filters (configured filters win on key collisions).
        filters_dict: dict[str, Any] = {
            item.key: item.value for item in filters or []
        }
        combined_filters: dict[str, Any] = {
            **filters_dict,
            **(search_parameters.filters or {}),
        }

        # Perform vector search
        logger.trace("Performing vector search", query_preview=query[:50])
        documents: list[Document] = vector_search.similarity_search(
            query=query,
            k=search_parameters.num_results or 5,
            filter=combined_filters if combined_filters else None,
            query_type=search_parameters.query_type or "ANN",
        )

        # Apply FlashRank reranking if configured
        if ranker and rerank_config:
            logger.debug("Applying FlashRank reranking")
            documents = _rerank_documents(query, documents, ranker, rerank_config)

        # Return documents as a JSON string for LLM consumption
        return _serialize_documents(documents)

    # Create the StructuredTool
    tool: StructuredTool = StructuredTool.from_function(
        func=vector_search_func,
        name=name or f"vector_search_{vector_store.index.name}",
        description=description or f"Search documents in {index_name}",
        args_schema=input_schema,
    )

    logger.success("Vector search tool created", name=tool.name, index=index_name)

    return tool