dao-ai 0.1.2__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. dao_ai/apps/__init__.py +24 -0
  2. dao_ai/apps/handlers.py +105 -0
  3. dao_ai/apps/model_serving.py +29 -0
  4. dao_ai/apps/resources.py +1122 -0
  5. dao_ai/apps/server.py +39 -0
  6. dao_ai/cli.py +546 -37
  7. dao_ai/config.py +1179 -139
  8. dao_ai/evaluation.py +543 -0
  9. dao_ai/genie/__init__.py +55 -7
  10. dao_ai/genie/cache/__init__.py +34 -7
  11. dao_ai/genie/cache/base.py +143 -2
  12. dao_ai/genie/cache/context_aware/__init__.py +31 -0
  13. dao_ai/genie/cache/context_aware/base.py +1151 -0
  14. dao_ai/genie/cache/context_aware/in_memory.py +609 -0
  15. dao_ai/genie/cache/context_aware/persistent.py +802 -0
  16. dao_ai/genie/cache/context_aware/postgres.py +1166 -0
  17. dao_ai/genie/cache/core.py +1 -1
  18. dao_ai/genie/cache/lru.py +257 -75
  19. dao_ai/genie/cache/optimization.py +890 -0
  20. dao_ai/genie/core.py +235 -11
  21. dao_ai/memory/postgres.py +175 -39
  22. dao_ai/middleware/__init__.py +38 -0
  23. dao_ai/middleware/assertions.py +3 -3
  24. dao_ai/middleware/context_editing.py +230 -0
  25. dao_ai/middleware/core.py +4 -4
  26. dao_ai/middleware/guardrails.py +3 -3
  27. dao_ai/middleware/human_in_the_loop.py +3 -2
  28. dao_ai/middleware/message_validation.py +4 -4
  29. dao_ai/middleware/model_call_limit.py +77 -0
  30. dao_ai/middleware/model_retry.py +121 -0
  31. dao_ai/middleware/pii.py +157 -0
  32. dao_ai/middleware/summarization.py +1 -1
  33. dao_ai/middleware/tool_call_limit.py +210 -0
  34. dao_ai/middleware/tool_retry.py +174 -0
  35. dao_ai/middleware/tool_selector.py +129 -0
  36. dao_ai/models.py +327 -370
  37. dao_ai/nodes.py +9 -16
  38. dao_ai/orchestration/core.py +33 -9
  39. dao_ai/orchestration/supervisor.py +29 -13
  40. dao_ai/orchestration/swarm.py +6 -1
  41. dao_ai/{prompts.py → prompts/__init__.py} +12 -61
  42. dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
  43. dao_ai/prompts/instruction_reranker.yaml +14 -0
  44. dao_ai/prompts/router.yaml +37 -0
  45. dao_ai/prompts/verifier.yaml +46 -0
  46. dao_ai/providers/base.py +28 -2
  47. dao_ai/providers/databricks.py +363 -33
  48. dao_ai/state.py +1 -0
  49. dao_ai/tools/__init__.py +5 -3
  50. dao_ai/tools/genie.py +103 -26
  51. dao_ai/tools/instructed_retriever.py +366 -0
  52. dao_ai/tools/instruction_reranker.py +202 -0
  53. dao_ai/tools/mcp.py +539 -97
  54. dao_ai/tools/router.py +89 -0
  55. dao_ai/tools/slack.py +13 -2
  56. dao_ai/tools/sql.py +7 -3
  57. dao_ai/tools/unity_catalog.py +32 -10
  58. dao_ai/tools/vector_search.py +493 -160
  59. dao_ai/tools/verifier.py +159 -0
  60. dao_ai/utils.py +182 -2
  61. dao_ai/vector_search.py +46 -1
  62. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/METADATA +45 -9
  63. dao_ai-0.1.20.dist-info/RECORD +89 -0
  64. dao_ai/agent_as_code.py +0 -22
  65. dao_ai/genie/cache/semantic.py +0 -970
  66. dao_ai-0.1.2.dist-info/RECORD +0 -64
  67. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/WHEEL +0 -0
  68. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/entry_points.txt +0 -0
  69. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/licenses/LICENSE +0 -0
@@ -2,126 +2,54 @@
 Vector search tool for retrieving documents from Databricks Vector Search.
 
 This module provides a tool factory for creating semantic search tools
-with dynamic filter schemas based on table columns and FlashRank reranking support.
+with dynamic filter schemas based on table columns, FlashRank reranking support,
+instructed retrieval with query decomposition and RRF merging, and optional
+query routing, result verification, and instruction-aware reranking.
 """
 
 import json
 import os
-from typing import Any, Optional
+from concurrent.futures import ThreadPoolExecutor
+from typing import Annotated, Any, Literal, Optional
 
 import mlflow
 from databricks.sdk import WorkspaceClient
 from databricks.vector_search.reranker import DatabricksReranker
 from databricks_langchain import DatabricksVectorSearch
 from flashrank import Ranker, RerankRequest
+from langchain.tools import ToolRuntime, tool
 from langchain_core.documents import Document
 from langchain_core.tools import StructuredTool
 from loguru import logger
 from mlflow.entities import SpanType
-from pydantic import BaseModel, ConfigDict, Field, create_model
 
 from dao_ai.config import (
+    ColumnInfo,
+    FilterItem,
+    InstructedRetrieverModel,
     RerankParametersModel,
     RetrieverModel,
+    RouterModel,
     SearchParametersModel,
+    SearchQuery,
     VectorStoreModel,
+    VerificationResult,
+    VerifierModel,
     value_of,
 )
-from dao_ai.utils import normalize_host
-
-# Create FilterItem model at module level so it can be used in type hints
-FilterItem = create_model(
-    "FilterItem",
-    key=(
-        str,
-        Field(
-            description="The filter key, which includes the column name and can include operators like 'NOT', '<', '>=', 'LIKE', 'OR'"
-        ),
-    ),
-    value=(
-        Any,
-        Field(
-            description="The filter value, which can be a single value or an array of values"
-        ),
-    ),
-    __config__=ConfigDict(extra="forbid"),
+from dao_ai.state import Context
+from dao_ai.tools.instructed_retriever import (
+    _get_cached_llm,
+    decompose_query,
+    rrf_merge,
 )
+from dao_ai.tools.instruction_reranker import instruction_aware_rerank
+from dao_ai.tools.router import route_query
+from dao_ai.tools.verifier import add_verification_metadata, verify_results
+from dao_ai.utils import is_in_model_serving, normalize_host
 
 
-def _create_dynamic_input_schema(
-    index_name: str, workspace_client: WorkspaceClient
-) -> type[BaseModel]:
-    """
-    Create dynamic input schema with column information from the table.
-
-    Args:
-        index_name: Full name of the vector search index
-        workspace_client: Workspace client to query table metadata
-
-    Returns:
-        Pydantic model class for tool input
-    """
-
-    # Try to get column information
-    column_descriptions = []
-    try:
-        table_info = workspace_client.tables.get(full_name=index_name)
-        for column_info in table_info.columns:
-            name = column_info.name
-            col_type = column_info.type_name.name
-            if not name.startswith("__"):
-                column_descriptions.append(f"{name} ({col_type})")
-    except Exception:
-        logger.debug(
-            "Could not retrieve column information for dynamic schema",
-            index=index_name,
-        )
-
-    # Build filter description matching VectorSearchRetrieverTool format
-    filter_description = (
-        "Optional filters to refine vector search results as an array of key-value pairs. "
-        "IMPORTANT: If unsure about filter values, try searching WITHOUT filters first to get broad results, "
-        "then optionally add filters to narrow down if needed. This ensures you don't miss relevant results due to incorrect filter values. "
-    )
-
-    if column_descriptions:
-        filter_description += (
-            f"Available columns for filtering: {', '.join(column_descriptions)}. "
-        )
-
-    filter_description += (
-        "Supports the following operators:\n\n"
-        '- Inclusion: [{"key": "column", "value": value}] or [{"key": "column", "value": [value1, value2]}] (matches if the column equals any of the provided values)\n'
-        '- Exclusion: [{"key": "column NOT", "value": value}]\n'
-        '- Comparisons: [{"key": "column <", "value": value}], [{"key": "column >=", "value": value}], etc.\n'
-        '- Pattern match: [{"key": "column LIKE", "value": "word"}] (matches full tokens separated by whitespace)\n'
-        '- OR logic: [{"key": "column1 OR column2", "value": [value1, value2]}] '
-        "(matches if column1 equals value1 or column2 equals value2; matches are position-specific)\n\n"
-        "Examples:\n"
-        '- Filter by category: [{"key": "category", "value": "electronics"}]\n'
-        '- Filter by price range: [{"key": "price >=", "value": 100}, {"key": "price <", "value": 500}]\n'
-        '- Exclude specific status: [{"key": "status NOT", "value": "archived"}]\n'
-        '- Pattern matching: [{"key": "description LIKE", "value": "wireless"}]'
-    )
-
-    # Create the input model
-    VectorSearchInput = create_model(
-        "VectorSearchInput",
-        query=(
-            str,
-            Field(description="The search query string to find relevant documents"),
-        ),
-        filters=(
-            Optional[list[FilterItem]],
-            Field(default=None, description=filter_description),
-        ),
-        __config__=ConfigDict(extra="forbid"),
-    )
-
-    return VectorSearchInput
-
-
-@mlflow.trace(name="rerank_documents", span_type=SpanType.RETRIEVER)
+@mlflow.trace(name="rerank_documents", span_type=SpanType.RERANKER)
 def _rerank_documents(
     query: str,
     documents: list[Document],
146
74
  model=rerank_config.model,
147
75
  )
148
76
 
77
+ # Early return if no documents to rerank
78
+ if not documents:
79
+ logger.debug("No documents to rerank, skipping")
80
+ return documents
81
+
149
82
  # Prepare passages for reranking
150
83
  passages: list[dict[str, Any]] = [
151
84
  {"text": doc.page_content, "meta": doc.metadata} for doc in documents
@@ -231,9 +164,12 @@ def create_vector_search_tool(
231
164
  raise ValueError("vector_store.index is required for vector search")
232
165
 
233
166
  index_name: str = vector_store.index.full_name
234
- columns: list[str] = list(retriever.columns or [])
167
+ columns: list[str] = list(retriever.columns or vector_store.index.columns or [])
235
168
  search_parameters: SearchParametersModel = retriever.search_parameters
169
+ router_config: Optional[RouterModel] = retriever.router
236
170
  rerank_config: Optional[RerankParametersModel] = retriever.rerank
171
+ instructed_config: Optional[InstructedRetrieverModel] = retriever.instructed
172
+ verifier_config: Optional[VerifierModel] = retriever.verifier
237
173
 
238
174
  # Initialize FlashRank ranker if configured
239
175
  ranker: Optional[Ranker] = None
@@ -244,50 +180,121 @@ def create_vector_search_tool(
244
180
  top_n=rerank_config.top_n or "auto",
245
181
  )
246
182
  try:
247
- cache_dir = os.path.expanduser(rerank_config.cache_dir)
183
+ # Use /tmp for cache in Model Serving (home dir may not be writable)
184
+ if is_in_model_serving():
185
+ cache_dir = "/tmp/dao_ai/cache/flashrank"
186
+ if rerank_config.cache_dir != cache_dir:
187
+ logger.warning(
188
+ "FlashRank cache_dir overridden in Model Serving",
189
+ configured=rerank_config.cache_dir,
190
+ actual=cache_dir,
191
+ )
192
+ else:
193
+ cache_dir = os.path.expanduser(rerank_config.cache_dir)
248
194
  ranker = Ranker(model_name=rerank_config.model, cache_dir=cache_dir)
195
+
196
+ # Patch rerank to always include token_type_ids for ONNX compatibility
197
+ # Some ONNX runtimes require token_type_ids even when the model doesn't use them
198
+ # FlashRank conditionally excludes them when all zeros, but ONNX may still expect them
199
+ # See: https://github.com/huggingface/optimum/issues/1500
200
+ if ranker.session is not None:
201
+ import numpy as np
202
+
203
+ _original_rerank = ranker.rerank
204
+
205
+ def _patched_rerank(request):
206
+ query = request.query
207
+ passages = request.passages
208
+ query_passage_pairs = [[query, p["text"]] for p in passages]
209
+
210
+ input_text = ranker.tokenizer.encode_batch(query_passage_pairs)
211
+ input_ids = np.array([e.ids for e in input_text])
212
+ token_type_ids = np.array([e.type_ids for e in input_text])
213
+ attention_mask = np.array([e.attention_mask for e in input_text])
214
+
215
+ # Always include token_type_ids (the fix for ONNX compatibility)
216
+ onnx_input = {
217
+ "input_ids": input_ids.astype(np.int64),
218
+ "attention_mask": attention_mask.astype(np.int64),
219
+ "token_type_ids": token_type_ids.astype(np.int64),
220
+ }
221
+
222
+ outputs = ranker.session.run(None, onnx_input)
223
+ logits = outputs[0]
224
+
225
+ if logits.shape[1] == 1:
226
+ scores = 1 / (1 + np.exp(-logits.flatten()))
227
+ else:
228
+ exp_logits = np.exp(logits)
229
+ scores = exp_logits[:, 1] / np.sum(exp_logits, axis=1)
230
+
231
+ for score, passage in zip(scores, passages):
232
+ passage["score"] = score
233
+
234
+ passages.sort(key=lambda x: x["score"], reverse=True)
235
+ return passages
236
+
237
+ ranker.rerank = _patched_rerank
238
+
249
239
  logger.success("FlashRank ranker initialized", model=rerank_config.model)
250
240
  except Exception as e:
251
241
  logger.warning("Failed to initialize FlashRank ranker", error=str(e))
252
242
  rerank_config = None
253
243
 
244
+ # Log instructed retrieval configuration
245
+ if instructed_config:
246
+ logger.success(
247
+ "Instructed retrieval configured",
248
+ decomposition_model=instructed_config.decomposition_model.name
249
+ if instructed_config.decomposition_model
250
+ else None,
251
+ max_subqueries=instructed_config.max_subqueries,
252
+ rrf_k=instructed_config.rrf_k,
253
+ )
254
+
255
+ # Log instruction-aware reranking configuration
256
+ if rerank_config and rerank_config.instruction_aware:
257
+ logger.success(
258
+ "Instruction-aware reranking configured",
259
+ model=rerank_config.instruction_aware.model.name
260
+ if rerank_config.instruction_aware.model
261
+ else None,
262
+ top_n=rerank_config.instruction_aware.top_n,
263
+ )
264
+
254
265
  # Build client_args for VectorSearchClient
255
- # Use getattr to safely access attributes that may not exist (e.g., in mocks)
256
266
  client_args: dict[str, Any] = {}
257
267
  has_explicit_auth = any(
258
268
  [
259
269
  os.environ.get("DATABRICKS_TOKEN"),
260
270
  os.environ.get("DATABRICKS_CLIENT_ID"),
261
- getattr(vector_store, "pat", None),
262
- getattr(vector_store, "client_id", None),
263
- getattr(vector_store, "on_behalf_of_user", None),
271
+ vector_store.pat,
272
+ vector_store.client_id,
273
+ vector_store.on_behalf_of_user,
264
274
  ]
265
275
  )
266
276
 
267
277
  if has_explicit_auth:
268
278
  databricks_host = os.environ.get("DATABRICKS_HOST")
269
- if (
270
- not databricks_host
271
- and getattr(vector_store, "_workspace_client", None) is not None
272
- ):
273
- databricks_host = vector_store.workspace_client.config.host
279
+ if not databricks_host and vector_store.workspace_host:
280
+ databricks_host = value_of(vector_store.workspace_host)
274
281
  if databricks_host:
275
282
  client_args["workspace_url"] = normalize_host(databricks_host)
276
283
 
277
284
  token = os.environ.get("DATABRICKS_TOKEN")
278
- if not token and getattr(vector_store, "pat", None):
285
+ if not token and vector_store.pat:
279
286
  token = value_of(vector_store.pat)
280
287
  if token:
281
288
  client_args["personal_access_token"] = token
282
289
 
283
290
  client_id = os.environ.get("DATABRICKS_CLIENT_ID")
284
- if not client_id and getattr(vector_store, "client_id", None):
291
+ if not client_id and vector_store.client_id:
285
292
  client_id = value_of(vector_store.client_id)
286
293
  if client_id:
287
294
  client_args["service_principal_client_id"] = client_id
288
295
 
289
296
  client_secret = os.environ.get("DATABRICKS_CLIENT_SECRET")
290
- if not client_secret and getattr(vector_store, "client_secret", None):
297
+ if not client_secret and vector_store.client_secret:
291
298
  client_secret = value_of(vector_store.client_secret)
292
299
  if client_secret:
293
300
  client_args["service_principal_client_secret"] = client_secret
@@ -299,71 +306,406 @@ def create_vector_search_tool(
         client_args_keys=list(client_args.keys()) if client_args else [],
     )
 
-    # Create DatabricksVectorSearch
-    # Note: text_column should be None for Databricks-managed embeddings
-    # (it's automatically determined from the index)
-    vector_search: DatabricksVectorSearch = DatabricksVectorSearch(
-        index_name=index_name,
-        text_column=None,
-        columns=columns,
-        workspace_client=vector_store.workspace_client,
-        client_args=client_args if client_args else None,
-        primary_key=vector_store.primary_key,
-        doc_uri=vector_store.doc_uri,
-        include_score=True,
-        reranker=(
-            DatabricksReranker(columns_to_rerank=rerank_config.columns)
-            if rerank_config and rerank_config.columns
-            else None
-        ),
-    )
+    # Cache for DatabricksVectorSearch - created lazily for OBO support
+    _cached_vector_search: DatabricksVectorSearch | None = None
+
+    def _get_vector_search(context: Context | None) -> DatabricksVectorSearch:
+        """Get or create DatabricksVectorSearch, using context for OBO auth if available."""
+        nonlocal _cached_vector_search
+
+        # Use cached instance if available and not OBO
+        if _cached_vector_search is not None and not vector_store.on_behalf_of_user:
+            return _cached_vector_search
+
+        # Get workspace client with OBO support via context
+        workspace_client: WorkspaceClient = vector_store.workspace_client_from(context)
+
+        # Create DatabricksVectorSearch
+        # Note: text_column should be None for Databricks-managed embeddings
+        # (it's automatically determined from the index)
+        vs: DatabricksVectorSearch = DatabricksVectorSearch(
+            index_name=index_name,
+            text_column=None,
+            columns=columns,
+            workspace_client=workspace_client,
+            client_args=client_args if client_args else None,
+            primary_key=vector_store.primary_key,
+            doc_uri=vector_store.doc_uri,
+            include_score=True,
+            reranker=(
+                DatabricksReranker(columns_to_rerank=rerank_config.columns)
+                if rerank_config and rerank_config.columns
+                else None
+            ),
+        )
 
-    # Create dynamic input schema
-    input_schema: type[BaseModel] = _create_dynamic_input_schema(
-        index_name, vector_store.workspace_client
-    )
+        # Cache for non-OBO scenarios
+        if not vector_store.on_behalf_of_user:
+            _cached_vector_search = vs
+
+        return vs
+
+    # Determine tool name and description
+    tool_name: str = name or f"vector_search_{vector_store.index.name}"
+
+    # Build tool description with available columns for filtering
+    base_description: str = description or f"Search documents in {index_name}"
+    if columns:
+        columns_list = ", ".join(columns)
+        tool_description = (
+            f"{base_description}. "
+            f"Available filter columns: {columns_list}. "
+            f"Filter operators: 'column' for equality, 'column NOT' for exclusion, "
+            f"'column <', 'column <=', 'column >', 'column >=' for comparison, "
+            f"'column LIKE' for token matching, 'column NOT LIKE' to exclude tokens."
+        )
+    else:
+        tool_description = base_description
+
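
The operator strings advertised in this description are the ones the removed `_create_dynamic_input_schema` used to document, and its examples still describe what a model is expected to send. Inside the tool, each `{key, value}` pair collapses into one entry of the filter dict:

```python
# Examples carried over from the removed schema description
filters = [
    {"key": "category", "value": "electronics"},       # equality
    {"key": "price >=", "value": 100},                 # comparison
    {"key": "price <", "value": 500},
    {"key": "status NOT", "value": "archived"},        # exclusion
    {"key": "description LIKE", "value": "wireless"},  # token match
]
# ...which the tool flattens to:
# {"category": "electronics", "price >=": 100, "price <": 500, ...}
```
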
+    @mlflow.trace(name="execute_instructed_retrieval", span_type=SpanType.RETRIEVER)
+    def _execute_instructed_retrieval(
+        vs: DatabricksVectorSearch,
+        query: str,
+        base_filters: dict[str, Any],
+        previous_feedback: str | None = None,
+    ) -> list[Document]:
+        """Execute instructed retrieval with query decomposition and RRF merging."""
+        logger.trace(
+            "Executing instructed retrieval", query=query, base_filters=base_filters
+        )
+        try:
+            decomposition_llm = _get_cached_llm(instructed_config.decomposition_model)
+
+            # Fall back to retriever columns if instructed columns not provided
+            instructed_columns: list[ColumnInfo] | None = instructed_config.columns
+            if instructed_columns is None and columns:
+                instructed_columns = [ColumnInfo(name=col) for col in columns]
+
+            subqueries: list[SearchQuery] = decompose_query(
+                llm=decomposition_llm,
+                query=query,
+                schema_description=instructed_config.schema_description,
+                constraints=instructed_config.constraints,
+                max_subqueries=instructed_config.max_subqueries,
+                examples=instructed_config.examples,
+                previous_feedback=previous_feedback,
+                columns=instructed_columns,
+            )
+
+            if not subqueries:
+                logger.warning(
+                    "Query decomposition returned no subqueries, using original"
+                )
+                return vs.similarity_search(
+                    query=query,
+                    k=search_parameters.num_results or 5,
+                    filter=base_filters if base_filters else None,
+                    query_type=search_parameters.query_type or "ANN",
+                )
+
+            def normalize_filter_values(
+                filters: dict[str, Any], case: str | None
+            ) -> dict[str, Any]:
+                """Normalize string filter values to specified case."""
+                logger.trace("Normalizing filter values", filters=filters, case=case)
+                if not case or not filters:
+                    return filters
+                normalized = {}
+                for key, value in filters.items():
+                    if isinstance(value, str):
+                        normalized[key] = (
+                            value.upper() if case == "uppercase" else value.lower()
+                        )
+                    elif isinstance(value, list):
+                        normalized[key] = [
+                            v.upper()
+                            if case == "uppercase"
+                            else v.lower()
+                            if isinstance(v, str)
+                            else v
+                            for v in value
+                        ]
+                    else:
+                        normalized[key] = value
+                return normalized
+
+            def execute_search(sq: SearchQuery) -> list[Document]:
+                logger.trace("Executing search", query=sq.text, filters=sq.filters)
+                # Convert FilterItem list to dict
+                sq_filters_dict: dict[str, Any] = {}
+                if sq.filters:
+                    for item in sq.filters:
+                        sq_filters_dict[item.key] = item.value
+                sq_filters = normalize_filter_values(
+                    sq_filters_dict, instructed_config.normalize_filter_case
+                )
+                k: int = search_parameters.num_results or 5
+                query_type: str = search_parameters.query_type or "ANN"
+                combined_filters: dict[str, Any] = {**sq_filters, **base_filters}
+                logger.trace(
+                    "Executing search",
+                    query=sq.text,
+                    k=k,
+                    query_type=query_type,
+                    filters=combined_filters,
+                )
+                return vs.similarity_search(
+                    query=sq.text,
+                    k=k,
+                    filter=combined_filters if combined_filters else None,
+                    query_type=query_type,
+                )
+
+            logger.debug(
+                "Executing parallel searches",
+                num_subqueries=len(subqueries),
+                queries=[sq.text[:50] for sq in subqueries],
+            )
+
+            with ThreadPoolExecutor(
+                max_workers=instructed_config.max_subqueries
+            ) as executor:
+                all_results = list(executor.map(execute_search, subqueries))
+
+            merged = rrf_merge(
+                all_results,
+                k=instructed_config.rrf_k,
+                primary_key=vector_store.primary_key,
+            )
+
+            logger.debug(
+                "Instructed retrieval complete",
+                num_subqueries=len(subqueries),
+                total_results=sum(len(r) for r in all_results),
+                merged_results=len(merged),
+            )
+
+            return merged
+
+        except Exception as e:
+            logger.warning(
+                "Instructed retrieval failed, falling back to standard search",
+                error=str(e),
+            )
+            return vs.similarity_search(
+                query=query,
+                k=search_parameters.num_results or 5,
+                filter=base_filters if base_filters else None,
+                query_type=search_parameters.query_type or "ANN",
+            )
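
For intuition, a purely hypothetical decomposition (field values invented; `SearchQuery` and `FilterItem` are the dao_ai.config models imported above, assumed here to accept keyword construction): a compound request fans out into focused subqueries, each searched in parallel, then fused by `rrf_merge`:

```python
# "affordable wireless headphones, or any speaker rated 4.5+"
subqueries = [
    SearchQuery(text="wireless headphones",
                filters=[FilterItem(key="price <", value=100)]),
    SearchQuery(text="speakers",
                filters=[FilterItem(key="rating >=", value=4.5)]),
]
# execute_search() runs once per subquery inside the ThreadPoolExecutor;
# rrf_merge() then fuses the two ranked lists into a single result set.
```
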
 
-    # Define the tool function
-    def vector_search_func(
-        query: str, filters: Optional[list[FilterItem]] = None
+    @mlflow.trace(name="execute_standard_search", span_type=SpanType.RETRIEVER)
+    def _execute_standard_search(
+        vs: DatabricksVectorSearch,
+        query: str,
+        base_filters: dict[str, Any],
+    ) -> list[Document]:
+        """Execute standard single-query search."""
+        logger.trace("Performing standard vector search", query_preview=query[:50])
+        return vs.similarity_search(
+            query=query,
+            k=search_parameters.num_results or 5,
+            filter=base_filters if base_filters else None,
+            query_type=search_parameters.query_type or "ANN",
+        )
+
+    @mlflow.trace(name="apply_post_processing", span_type=SpanType.RETRIEVER)
+    def _apply_post_processing(
+        documents: list[Document],
+        query: str,
+        mode: Literal["standard", "instructed"],
+        auto_bypass: bool,
+    ) -> list[Document]:
+        """Apply instruction-aware reranking and verification based on mode and bypass settings."""
+        # Skip post-processing for standard mode when auto_bypass is enabled
+        if mode == "standard" and auto_bypass:
+            mlflow.set_tag("router.bypassed_stages", "true")
+            return documents
+
+        # Apply instruction-aware reranking if configured
+        if rerank_config and rerank_config.instruction_aware:
+            instruction_config = rerank_config.instruction_aware
+            instruction_llm = (
+                _get_cached_llm(instruction_config.model)
+                if instruction_config.model
+                else None
+            )
+
+            if instruction_llm:
+                schema_desc = (
+                    instructed_config.schema_description if instructed_config else None
+                )
+                # Get columns for dynamic instruction generation
+                rerank_columns: list[ColumnInfo] | None = None
+                if instructed_config and instructed_config.columns:
+                    rerank_columns = instructed_config.columns
+                elif columns:
+                    rerank_columns = [ColumnInfo(name=col) for col in columns]
+
+                documents = instruction_aware_rerank(
+                    llm=instruction_llm,
+                    query=query,
+                    documents=documents,
+                    instructions=instruction_config.instructions,
+                    schema_description=schema_desc,
+                    columns=rerank_columns,
+                    top_n=instruction_config.top_n,
+                )
+
+        # Apply verification if configured
+        if verifier_config:
+            verifier_llm = (
+                _get_cached_llm(verifier_config.model)
+                if verifier_config.model
+                else None
+            )
+
+            if verifier_llm:
+                schema_desc = (
+                    instructed_config.schema_description if instructed_config else ""
+                )
+                constraints = (
+                    instructed_config.constraints if instructed_config else None
+                )
+                retry_count = 0
+                verification_result: VerificationResult | None = None
+                previous_feedback: str | None = None
+
+                while retry_count <= verifier_config.max_retries:
+                    verification_result = verify_results(
+                        llm=verifier_llm,
+                        query=query,
+                        documents=documents,
+                        schema_description=schema_desc,
+                        constraints=constraints,
+                        previous_feedback=previous_feedback,
+                    )
+
+                    if verification_result.passed:
+                        mlflow.set_tag("verifier.outcome", "passed")
+                        mlflow.set_tag("verifier.retries", str(retry_count))
+                        break
+
+                    # Handle failure based on configuration
+                    if verifier_config.on_failure == "warn":
+                        mlflow.set_tag("verifier.outcome", "warned")
+                        documents = add_verification_metadata(
+                            documents, verification_result
+                        )
+                        break
+
+                    if retry_count >= verifier_config.max_retries:
+                        mlflow.set_tag("verifier.outcome", "exhausted")
+                        mlflow.set_tag("verifier.retries", str(retry_count))
+                        documents = add_verification_metadata(
+                            documents, verification_result, exhausted=True
+                        )
+                        break
+
+                    # Retry with feedback
+                    mlflow.set_tag("verifier.outcome", "retried")
+                    previous_feedback = verification_result.feedback
+                    retry_count += 1
+                    logger.debug(
+                        "Retrying search with verification feedback", retry=retry_count
+                    )
+
+        return documents
+
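
The tool below depends on langchain's runtime injection: a parameter annotated with `ToolRuntime` is populated by the framework at invocation time and hidden from the model's schema, which is how the per-request `Context` (used for OBO auth) reaches `_get_vector_search`. A minimal sketch of the pattern, assuming the langchain 1.x API imported at the top of this file:

```python
from dataclasses import dataclass

from langchain.tools import ToolRuntime, tool

@dataclass
class MyContext:
    user_token: str | None = None

@tool
def whoami(runtime: ToolRuntime[MyContext] = None) -> str:
    """Report whether a user token is present."""
    # `runtime` is injected by the framework, not supplied by the model
    token = runtime.context.user_token if runtime and runtime.context else None
    return "authenticated" if token else "anonymous"
```
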
+    # Use @tool decorator for proper ToolRuntime injection
+    @tool(name_or_callable=tool_name, description=tool_description)
+    def _vector_search_tool(
+        query: Annotated[str, "The search query to find relevant documents"],
+        filters: Annotated[
+            Optional[list[FilterItem]],
+            "Optional filters as key-value pairs. "
+            "Key operators: 'column' (equality), 'column NOT' (exclusion), "
+            "'column <', '<=', '>', '>=' (comparison), "
+            "'column LIKE' (token match), 'column NOT LIKE' (exclude token). "
+            f"Valid columns: {', '.join(columns) if columns else 'none'}.",
+        ] = None,
+        runtime: ToolRuntime[Context] = None,
     ) -> str:
         """Search for relevant documents using vector similarity."""
-        # Convert FilterItem Pydantic models to dict format for DatabricksVectorSearch
+        context: Context | None = runtime.context if runtime else None
+        vs: DatabricksVectorSearch = _get_vector_search(context)
+
         filters_dict: dict[str, Any] = {}
         if filters:
             for item in filters:
                 filters_dict[item.key] = item.value
 
-        # Merge with configured filters
-        combined_filters: dict[str, Any] = {
+        base_filters: dict[str, Any] = {
             **filters_dict,
             **(search_parameters.filters or {}),
         }
 
-        # Perform vector search
-        logger.trace("Performing vector search", query_preview=query[:50])
-        documents: list[Document] = vector_search.similarity_search(
-            query=query,
-            k=search_parameters.num_results or 5,
-            filter=combined_filters if combined_filters else None,
-            query_type=search_parameters.query_type or "ANN",
+        # Determine execution mode via router or config
+        mode: Literal["standard", "instructed"] = "standard"
+        auto_bypass = True
+
+        logger.trace("Router configuration", router_config=router_config)
+        logger.trace("Instructed configuration", instructed_config=instructed_config)
+        logger.trace(
+            "Instruction-aware rerank configuration",
+            instruction_aware=rerank_config.instruction_aware
+            if rerank_config
+            else None,
         )
 
+        if router_config:
+            router_llm = (
+                _get_cached_llm(router_config.model) if router_config.model else None
+            )
+            auto_bypass = router_config.auto_bypass
+
+            if router_llm and instructed_config:
+                try:
+                    mode = route_query(
+                        llm=router_llm,
+                        query=query,
+                        schema_description=instructed_config.schema_description,
+                    )
+                except Exception as e:
+                    # Router fail-safe: default to standard mode
+                    logger.warning(
+                        "Router failed, defaulting to standard mode", error=str(e)
+                    )
+                    mlflow.set_tag("router.fallback", "true")
+                    mode = router_config.default_mode
+            else:
+                mode = router_config.default_mode
+        elif instructed_config:
+            # No router but instructed is configured - use instructed mode
+            mode = "instructed"
+            auto_bypass = False
+        elif rerank_config and rerank_config.instruction_aware:
+            # No router/instructed but instruction_aware reranking is configured
+            # Disable auto_bypass to ensure instruction_aware reranking runs
+            auto_bypass = False
+
+        logger.trace("Routing mode", mode=mode, auto_bypass=auto_bypass)
+        mlflow.set_tag("router.mode", mode)
+
+        # Execute search based on mode
+        if mode == "instructed" and instructed_config:
+            documents = _execute_instructed_retrieval(vs, query, base_filters)
+        else:
+            documents = _execute_standard_search(vs, query, base_filters)
+
         # Apply FlashRank reranking if configured
         if ranker and rerank_config:
             logger.debug("Applying FlashRank reranking")
             documents = _rerank_documents(query, documents, ranker, rerank_config)
 
+        # Apply post-processing (instruction reranking + verification)
+        documents = _apply_post_processing(documents, query, mode, auto_bypass)
+
         # Serialize documents to JSON format for LLM consumption
-        # Convert Document objects to dicts with page_content and metadata
-        # Need to handle numpy types in metadata (e.g., float32, int64)
         serialized_docs: list[dict[str, Any]] = []
         for doc in documents:
-            doc: Document
-            # Convert metadata values to JSON-serializable types
             metadata_serializable: dict[str, Any] = {}
             for key, value in doc.metadata.items():
-                # Handle numpy types
                 if hasattr(value, "item"):  # numpy scalar
                     metadata_serializable[key] = value.item()
                 else:
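
The `hasattr(value, "item")` check matters because similarity scores and other metadata can come back as numpy scalars (the removed comment called out float32/int64), and the standard `json` module cannot encode them; `.item()` converts them to native Python types:

```python
import json
import numpy as np

score = np.float32(0.93)
# json.dumps({"score": score})       # TypeError: not JSON serializable
json.dumps({"score": score.item()})  # '{"score": 0.9300000071525574}'
```
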
@@ -376,17 +718,8 @@ def create_vector_search_tool(
                 }
             )
 
-        # Return as JSON string
         return json.dumps(serialized_docs)
 
-    # Create the StructuredTool
-    tool: StructuredTool = StructuredTool.from_function(
-        func=vector_search_func,
-        name=name or f"vector_search_{vector_store.index.name}",
-        description=description or f"Search documents in {index_name}",
-        args_schema=input_schema,
-    )
-
-    logger.success("Vector search tool created", name=tool.name, index=index_name)
+    logger.success("Vector search tool created", name=tool_name, index=index_name)
 
-    return tool
+    return _vector_search_tool
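
Taken together, the factory now returns a `@tool`-decorated function instead of a `StructuredTool`. A hedged usage sketch; the parameter list is inferred from names referenced in this diff (`retriever`, `vector_store`, `name`, `description`), so the full `create_vector_search_tool` signature is an assumption:

```python
import json

tool_instance = create_vector_search_tool(
    retriever=retriever,        # e.g., the RetrieverModel sketched earlier
    vector_store=vector_store,  # VectorStoreModel; presumably also required
    name="search_products",
    description="Search the product catalog",
)

result = tool_instance.invoke({
    "query": "wireless headphones under $100",
    "filters": [{"key": "category", "value": "electronics"}],
})
docs = json.loads(result)  # list of dicts with page content and metadata
```
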