dao-ai 0.0.25__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +5 -5
- dao_ai/cli.py +245 -40
- dao_ai/config.py +1863 -338
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -228
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +232 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +261 -166
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +278 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +645 -172
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -295
- dao_ai/tools/mcp.py +220 -133
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +360 -40
- dao_ai/utils.py +218 -16
- dao_ai-0.1.2.dist-info/METADATA +455 -0
- dao_ai-0.1.2.dist-info/RECORD +64 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +1 -1
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.25.dist-info/METADATA +0 -1165
- dao_ai-0.0.25.dist-info/RECORD +0 -41
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/tools/vector_search.py
CHANGED
|
@@ -1,72 +1,392 @@
|
|
|
1
|
-
|
|
1
|
+
"""
|
|
2
|
+
Vector search tool for retrieving documents from Databricks Vector Search.
|
|
3
|
+
|
|
4
|
+
This module provides a tool factory for creating semantic search tools
|
|
5
|
+
with dynamic filter schemas based on table columns and FlashRank reranking support.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
import os
|
|
10
|
+
from typing import Any, Optional
|
|
2
11
|
|
|
3
12
|
import mlflow
|
|
4
|
-
from
|
|
5
|
-
from
|
|
13
|
+
from databricks.sdk import WorkspaceClient
|
|
14
|
+
from databricks.vector_search.reranker import DatabricksReranker
|
|
15
|
+
from databricks_langchain import DatabricksVectorSearch
|
|
16
|
+
from flashrank import Ranker, RerankRequest
|
|
17
|
+
from langchain_core.documents import Document
|
|
18
|
+
from langchain_core.tools import StructuredTool
|
|
19
|
+
from loguru import logger
|
|
20
|
+
from mlflow.entities import SpanType
|
|
21
|
+
from pydantic import BaseModel, ConfigDict, Field, create_model
|
|
6
22
|
|
|
7
23
|
from dao_ai.config import (
|
|
24
|
+
RerankParametersModel,
|
|
8
25
|
RetrieverModel,
|
|
26
|
+
SearchParametersModel,
|
|
9
27
|
VectorStoreModel,
|
|
28
|
+
value_of,
|
|
29
|
+
)
|
|
30
|
+
from dao_ai.utils import normalize_host
|
|
31
|
+
|
|
32
|
+
# Schema for one vector-search filter entry. Declared at module scope so it
# can be referenced from function type hints (e.g. Optional[list[FilterItem]]).
# extra="forbid" keeps the generated JSON schema strict
# (additionalProperties: false), which some tool-calling LLM APIs require.
_filter_key_spec = (
    str,
    Field(
        description="The filter key, which includes the column name and can include operators like 'NOT', '<', '>=', 'LIKE', 'OR'"
    ),
)
_filter_value_spec = (
    Any,
    Field(
        description="The filter value, which can be a single value or an array of values"
    ),
)
FilterItem = create_model(
    "FilterItem",
    key=_filter_key_spec,
    value=_filter_value_spec,
    __config__=ConfigDict(extra="forbid"),
)
|
|
11
49
|
|
|
12
50
|
|
|
51
|
+
def _create_dynamic_input_schema(
    index_name: str, workspace_client: WorkspaceClient
) -> type[BaseModel]:
    """
    Build the tool's input model, enriching the filter docs with table columns.

    The filter field's description embeds the source table's column names and
    types (when readable) so the calling LLM knows which keys are filterable.

    Args:
        index_name: Full name of the vector search index
        workspace_client: Workspace client to query table metadata

    Returns:
        Pydantic model class for tool input
    """

    # Best-effort column discovery; metadata lookup failures are non-fatal
    # because the schema is still usable without column hints.
    column_descriptions = []
    try:
        table_info = workspace_client.tables.get(full_name=index_name)
        for column_info in table_info.columns:
            name = column_info.name
            col_type = column_info.type_name.name
            # Skip internal columns such as __db_* bookkeeping fields.
            if not name.startswith("__"):
                column_descriptions.append(f"{name} ({col_type})")
    except Exception:
        logger.debug(
            "Could not retrieve column information for dynamic schema",
            index=index_name,
        )

    # Assemble the filter description piecewise, mirroring the
    # VectorSearchRetrieverTool documentation format.
    description_parts: list[str] = [
        "Optional filters to refine vector search results as an array of key-value pairs. "
        "IMPORTANT: If unsure about filter values, try searching WITHOUT filters first to get broad results, "
        "then optionally add filters to narrow down if needed. This ensures you don't miss relevant results due to incorrect filter values. "
    ]

    if column_descriptions:
        description_parts.append(
            f"Available columns for filtering: {', '.join(column_descriptions)}. "
        )

    description_parts.append(
        "Supports the following operators:\n\n"
        '- Inclusion: [{"key": "column", "value": value}] or [{"key": "column", "value": [value1, value2]}] (matches if the column equals any of the provided values)\n'
        '- Exclusion: [{"key": "column NOT", "value": value}]\n'
        '- Comparisons: [{"key": "column <", "value": value}], [{"key": "column >=", "value": value}], etc.\n'
        '- Pattern match: [{"key": "column LIKE", "value": "word"}] (matches full tokens separated by whitespace)\n'
        '- OR logic: [{"key": "column1 OR column2", "value": [value1, value2]}] '
        "(matches if column1 equals value1 or column2 equals value2; matches are position-specific)\n\n"
        "Examples:\n"
        '- Filter by category: [{"key": "category", "value": "electronics"}]\n'
        '- Filter by price range: [{"key": "price >=", "value": 100}, {"key": "price <", "value": 500}]\n'
        '- Exclude specific status: [{"key": "status NOT", "value": "archived"}]\n'
        '- Pattern matching: [{"key": "description LIKE", "value": "wireless"}]'
    )

    filter_description = "".join(description_parts)

    # extra="forbid" keeps the emitted JSON schema strict for tool calling.
    VectorSearchInput = create_model(
        "VectorSearchInput",
        query=(
            str,
            Field(description="The search query string to find relevant documents"),
        ),
        filters=(
            Optional[list[FilterItem]],
            Field(default=None, description=filter_description),
        ),
        __config__=ConfigDict(extra="forbid"),
    )

    return VectorSearchInput
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
@mlflow.trace(name="rerank_documents", span_type=SpanType.RETRIEVER)
def _rerank_documents(
    query: str,
    documents: list[Document],
    ranker: Ranker,
    rerank_config: RerankParametersModel,
) -> list[Document]:
    """
    Rerank documents using FlashRank cross-encoder model.

    Args:
        query: The search query string
        documents: List of documents to rerank
        ranker: The FlashRank Ranker instance
        rerank_config: Reranking configuration

    Returns:
        Reranked list of documents with reranker_score in metadata
    """
    logger.trace(
        "Starting reranking",
        documents_count=len(documents),
        model=rerank_config.model,
    )

    # Tag each passage with its list index so results can be mapped back to
    # the originating Document in O(1). Matching on page_content (the previous
    # approach) is O(n) per result and mis-assigns when two documents share
    # identical text. FlashRank preserves passage keys and adds "score".
    passages: list[dict[str, Any]] = [
        {"id": i, "text": doc.page_content, "meta": doc.metadata}
        for i, doc in enumerate(documents)
    ]

    # Create reranking request and score all passages against the query.
    rerank_request: RerankRequest = RerankRequest(query=query, passages=passages)
    results: list[dict[str, Any]] = ranker.rerank(rerank_request)

    # Apply top_n filtering; fall back to keeping every candidate.
    top_n: int = rerank_config.top_n or len(documents)
    results = results[:top_n]
    logger.debug("Reranking complete", top_n=top_n, candidates_count=len(documents))

    # Convert back to Document objects with reranking scores attached.
    reranked_docs: list[Document] = []
    for result in results:
        orig_doc: Document = documents[result["id"]]
        reranked_docs.append(
            Document(
                page_content=orig_doc.page_content,
                metadata={
                    **orig_doc.metadata,
                    "reranker_score": result["score"],
                },
            )
        )

    logger.debug(
        "Documents reranked",
        input_count=len(documents),
        output_count=len(reranked_docs),
        model=rerank_config.model,
    )

    return reranked_docs
|
|
189
|
+
|
|
190
|
+
|
|
13
191
|
def create_vector_search_tool(
    retriever: Optional[RetrieverModel | dict[str, Any]] = None,
    vector_store: Optional[VectorStoreModel | dict[str, Any]] = None,
    name: Optional[str] = None,
    description: Optional[str] = None,
) -> StructuredTool:
    """
    Create a Vector Search tool with dynamic schema and optional reranking.

    Args:
        retriever: Full retriever configuration with search parameters and reranking
        vector_store: Direct vector store reference (uses default search parameters)
        name: Optional custom name for the tool
        description: Optional custom description for the tool

    Returns:
        A LangChain StructuredTool with proper schema (additionalProperties: false)
    """

    # Exactly one of retriever / vector_store must be supplied.
    if retriever is None and vector_store is None:
        raise ValueError("Must provide either 'retriever' or 'vector_store' parameter")
    if retriever is not None and vector_store is not None:
        raise ValueError(
            "Cannot provide both 'retriever' and 'vector_store' parameters"
        )

    # Normalize either input into a RetrieverModel; dicts are coerced first.
    if vector_store is not None:
        if isinstance(vector_store, dict):
            vector_store = VectorStoreModel(**vector_store)
        retriever = RetrieverModel(vector_store=vector_store)
    elif isinstance(retriever, dict):
        retriever = RetrieverModel(**retriever)

    vector_store: VectorStoreModel = retriever.vector_store

    # An index is mandatory — nothing to search without it.
    if vector_store.index is None:
        raise ValueError("vector_store.index is required for vector search")

    index_name: str = vector_store.index.full_name
    columns: list[str] = list(retriever.columns or [])
    search_parameters: SearchParametersModel = retriever.search_parameters
    rerank_config: Optional[RerankParametersModel] = retriever.rerank

    # Optionally stand up a FlashRank cross-encoder; a failure here downgrades
    # the tool to plain vector search instead of aborting tool creation.
    ranker: Optional[Ranker] = None
    if rerank_config and rerank_config.model:
        logger.debug(
            "Initializing FlashRank ranker",
            model=rerank_config.model,
            top_n=rerank_config.top_n or "auto",
        )
        try:
            cache_dir = os.path.expanduser(rerank_config.cache_dir)
            ranker = Ranker(model_name=rerank_config.model, cache_dir=cache_dir)
            logger.success("FlashRank ranker initialized", model=rerank_config.model)
        except Exception as e:
            logger.warning("Failed to initialize FlashRank ranker", error=str(e))
            rerank_config = None

    # Assemble explicit client_args for the VectorSearchClient. getattr is used
    # throughout so attribute-less stand-ins (e.g. mocks) don't blow up.
    client_args: dict[str, Any] = {}

    def _resolve_credential(env_var: str, attr: str) -> Optional[str]:
        # Environment wins; otherwise fall back to the vector_store attribute.
        candidate = os.environ.get(env_var)
        if not candidate and getattr(vector_store, attr, None):
            candidate = value_of(getattr(vector_store, attr))
        return candidate

    has_explicit_auth = any(
        [
            os.environ.get("DATABRICKS_TOKEN"),
            os.environ.get("DATABRICKS_CLIENT_ID"),
            getattr(vector_store, "pat", None),
            getattr(vector_store, "client_id", None),
            getattr(vector_store, "on_behalf_of_user", None),
        ]
    )

    if has_explicit_auth:
        databricks_host = os.environ.get("DATABRICKS_HOST")
        if (
            not databricks_host
            and getattr(vector_store, "_workspace_client", None) is not None
        ):
            databricks_host = vector_store.workspace_client.config.host
        if databricks_host:
            client_args["workspace_url"] = normalize_host(databricks_host)

        token = _resolve_credential("DATABRICKS_TOKEN", "pat")
        if token:
            client_args["personal_access_token"] = token

        client_id = _resolve_credential("DATABRICKS_CLIENT_ID", "client_id")
        if client_id:
            client_args["service_principal_client_id"] = client_id

        client_secret = _resolve_credential("DATABRICKS_CLIENT_SECRET", "client_secret")
        if client_secret:
            client_args["service_principal_client_secret"] = client_secret

    logger.debug(
        "Creating vector search tool",
        name=name,
        index=index_name,
        client_args_keys=list(client_args.keys()) if client_args else [],
    )

    # text_column stays None for Databricks-managed embeddings — it is
    # determined automatically from the index.
    vector_search: DatabricksVectorSearch = DatabricksVectorSearch(
        index_name=index_name,
        text_column=None,
        columns=columns,
        workspace_client=vector_store.workspace_client,
        client_args=client_args if client_args else None,
        primary_key=vector_store.primary_key,
        doc_uri=vector_store.doc_uri,
        include_score=True,
        reranker=(
            DatabricksReranker(columns_to_rerank=rerank_config.columns)
            if rerank_config and rerank_config.columns
            else None
        ),
    )

    # Input schema is built dynamically so the filter docs can list table columns.
    input_schema: type[BaseModel] = _create_dynamic_input_schema(
        index_name, vector_store.workspace_client
    )

    def vector_search_func(
        query: str, filters: Optional[list[FilterItem]] = None
    ) -> str:
        """Search for relevant documents using vector similarity."""
        # Flatten FilterItem models into the plain dict DatabricksVectorSearch expects.
        filters_dict: dict[str, Any] = (
            {entry.key: entry.value for entry in filters} if filters else {}
        )

        # Statically configured filters take precedence over LLM-supplied ones.
        combined_filters: dict[str, Any] = {
            **filters_dict,
            **(search_parameters.filters or {}),
        }

        logger.trace("Performing vector search", query_preview=query[:50])
        documents: list[Document] = vector_search.similarity_search(
            query=query,
            k=search_parameters.num_results or 5,
            filter=combined_filters if combined_filters else None,
            query_type=search_parameters.query_type or "ANN",
        )

        # Optional second-stage FlashRank reranking.
        if ranker and rerank_config:
            logger.debug("Applying FlashRank reranking")
            documents = _rerank_documents(query, documents, ranker, rerank_config)

        # Serialize for LLM consumption. Metadata may contain numpy scalars
        # (e.g. float32 scores); .item() unboxes them to JSON-safe Python types.
        serialized_docs: list[dict[str, Any]] = []
        for doc in documents:
            metadata_serializable: dict[str, Any] = {
                key: (value.item() if hasattr(value, "item") else value)
                for key, value in doc.metadata.items()
            }
            serialized_docs.append(
                {
                    "page_content": doc.page_content,
                    "metadata": metadata_serializable,
                }
            )

        return json.dumps(serialized_docs)

    tool: StructuredTool = StructuredTool.from_function(
        func=vector_search_func,
        name=name or f"vector_search_{vector_store.index.name}",
        description=description or f"Search documents in {index_name}",
        args_schema=input_schema,
    )

    logger.success("Vector search tool created", name=tool.name, index=index_name)

    return tool
|