dao-ai 0.1.17__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in that public registry.
@@ -2,12 +2,15 @@
  Vector search tool for retrieving documents from Databricks Vector Search.

  This module provides a tool factory for creating semantic search tools
- with dynamic filter schemas based on table columns and FlashRank reranking support.
+ with dynamic filter schemas based on table columns, FlashRank reranking support,
+ instructed retrieval with query decomposition and RRF merging, and optional
+ query routing, result verification, and instruction-aware reranking.
  """

  import json
  import os
- from typing import Annotated, Any, Optional
+ from concurrent.futures import ThreadPoolExecutor
+ from typing import Annotated, Any, Literal, Optional

  import mlflow
  from databricks.sdk import WorkspaceClient
@@ -19,111 +22,34 @@ from langchain_core.documents import Document
  from langchain_core.tools import StructuredTool
  from loguru import logger
  from mlflow.entities import SpanType
- from pydantic import BaseModel, ConfigDict, Field, create_model

  from dao_ai.config import (
+ ColumnInfo,
+ FilterItem,
+ InstructedRetrieverModel,
  RerankParametersModel,
  RetrieverModel,
+ RouterModel,
  SearchParametersModel,
+ SearchQuery,
  VectorStoreModel,
+ VerificationResult,
+ VerifierModel,
  value_of,
  )
  from dao_ai.state import Context
- from dao_ai.utils import normalize_host
-
- # Create FilterItem model at module level so it can be used in type hints
- FilterItem = create_model(
- "FilterItem",
- key=(
- str,
- Field(
- description="The filter key, which includes the column name and can include operators like 'NOT', '<', '>=', 'LIKE', 'OR'"
- ),
- ),
- value=(
- Any,
- Field(
- description="The filter value, which can be a single value or an array of values"
- ),
- ),
- __config__=ConfigDict(extra="forbid"),
+ from dao_ai.tools.instructed_retriever import (
+ _get_cached_llm,
+ decompose_query,
+ rrf_merge,
  )
+ from dao_ai.tools.instruction_reranker import instruction_aware_rerank
+ from dao_ai.tools.router import route_query
+ from dao_ai.tools.verifier import add_verification_metadata, verify_results
+ from dao_ai.utils import is_in_model_serving, normalize_host


- def _create_dynamic_input_schema(
- index_name: str, workspace_client: WorkspaceClient
- ) -> type[BaseModel]:
- """
- Create dynamic input schema with column information from the table.
-
- Args:
- index_name: Full name of the vector search index
- workspace_client: Workspace client to query table metadata
-
- Returns:
- Pydantic model class for tool input
- """
-
- # Try to get column information
- column_descriptions = []
- try:
- table_info = workspace_client.tables.get(full_name=index_name)
- for column_info in table_info.columns:
- name = column_info.name
- col_type = column_info.type_name.name
- if not name.startswith("__"):
- column_descriptions.append(f"{name} ({col_type})")
- except Exception:
- logger.debug(
- "Could not retrieve column information for dynamic schema",
- index=index_name,
- )
-
- # Build filter description matching VectorSearchRetrieverTool format
- filter_description = (
- "Optional filters to refine vector search results as an array of key-value pairs. "
- "IMPORTANT: If unsure about filter values, try searching WITHOUT filters first to get broad results, "
- "then optionally add filters to narrow down if needed. This ensures you don't miss relevant results due to incorrect filter values. "
- )
-
- if column_descriptions:
- filter_description += (
- f"Available columns for filtering: {', '.join(column_descriptions)}. "
- )
-
- filter_description += (
- "Supports the following operators:\n\n"
- '- Inclusion: [{"key": "column", "value": value}] or [{"key": "column", "value": [value1, value2]}] (matches if the column equals any of the provided values)\n'
- '- Exclusion: [{"key": "column NOT", "value": value}]\n'
- '- Comparisons: [{"key": "column <", "value": value}], [{"key": "column >=", "value": value}], etc.\n'
- '- Pattern match: [{"key": "column LIKE", "value": "word"}] (matches full tokens separated by whitespace)\n'
- '- OR logic: [{"key": "column1 OR column2", "value": [value1, value2]}] '
- "(matches if column1 equals value1 or column2 equals value2; matches are position-specific)\n\n"
- "Examples:\n"
- '- Filter by category: [{"key": "category", "value": "electronics"}]\n'
- '- Filter by price range: [{"key": "price >=", "value": 100}, {"key": "price <", "value": 500}]\n'
- '- Exclude specific status: [{"key": "status NOT", "value": "archived"}]\n'
- '- Pattern matching: [{"key": "description LIKE", "value": "wireless"}]'
- )
-
- # Create the input model
- VectorSearchInput = create_model(
- "VectorSearchInput",
- query=(
- str,
- Field(description="The search query string to find relevant documents"),
- ),
- filters=(
- Optional[list[FilterItem]],
- Field(default=None, description=filter_description),
- ),
- __config__=ConfigDict(extra="forbid"),
- )
-
- return VectorSearchInput
-
-
- @mlflow.trace(name="rerank_documents", span_type=SpanType.RETRIEVER)
+ @mlflow.trace(name="rerank_documents", span_type=SpanType.RERANKER)
  def _rerank_documents(
  query: str,
  documents: list[Document],
@@ -148,6 +74,11 @@ def _rerank_documents(
  model=rerank_config.model,
  )

+ # Early return if no documents to rerank
+ if not documents:
+ logger.debug("No documents to rerank, skipping")
+ return documents
+
  # Prepare passages for reranking
  passages: list[dict[str, Any]] = [
  {"text": doc.page_content, "meta": doc.metadata} for doc in documents
@@ -233,9 +164,12 @@ def create_vector_search_tool(
  raise ValueError("vector_store.index is required for vector search")

  index_name: str = vector_store.index.full_name
- columns: list[str] = list(retriever.columns or [])
+ columns: list[str] = list(retriever.columns or vector_store.index.columns or [])
  search_parameters: SearchParametersModel = retriever.search_parameters
+ router_config: Optional[RouterModel] = retriever.router
  rerank_config: Optional[RerankParametersModel] = retriever.rerank
+ instructed_config: Optional[InstructedRetrieverModel] = retriever.instructed
+ verifier_config: Optional[VerifierModel] = retriever.verifier

  # Initialize FlashRank ranker if configured
  ranker: Optional[Ranker] = None
@@ -246,50 +180,121 @@ def create_vector_search_tool(
  top_n=rerank_config.top_n or "auto",
  )
  try:
- cache_dir = os.path.expanduser(rerank_config.cache_dir)
+ # Use /tmp for cache in Model Serving (home dir may not be writable)
+ if is_in_model_serving():
+ cache_dir = "/tmp/dao_ai/cache/flashrank"
+ if rerank_config.cache_dir != cache_dir:
+ logger.warning(
+ "FlashRank cache_dir overridden in Model Serving",
+ configured=rerank_config.cache_dir,
+ actual=cache_dir,
+ )
+ else:
+ cache_dir = os.path.expanduser(rerank_config.cache_dir)
  ranker = Ranker(model_name=rerank_config.model, cache_dir=cache_dir)
+
+ # Patch rerank to always include token_type_ids for ONNX compatibility
+ # Some ONNX runtimes require token_type_ids even when the model doesn't use them
+ # FlashRank conditionally excludes them when all zeros, but ONNX may still expect them
+ # See: https://github.com/huggingface/optimum/issues/1500
+ if ranker.session is not None:
+ import numpy as np
+
+ _original_rerank = ranker.rerank
+
+ def _patched_rerank(request):
+ query = request.query
+ passages = request.passages
+ query_passage_pairs = [[query, p["text"]] for p in passages]
+
+ input_text = ranker.tokenizer.encode_batch(query_passage_pairs)
+ input_ids = np.array([e.ids for e in input_text])
+ token_type_ids = np.array([e.type_ids for e in input_text])
+ attention_mask = np.array([e.attention_mask for e in input_text])
+
+ # Always include token_type_ids (the fix for ONNX compatibility)
+ onnx_input = {
+ "input_ids": input_ids.astype(np.int64),
+ "attention_mask": attention_mask.astype(np.int64),
+ "token_type_ids": token_type_ids.astype(np.int64),
+ }
+
+ outputs = ranker.session.run(None, onnx_input)
+ logits = outputs[0]
+
+ if logits.shape[1] == 1:
+ scores = 1 / (1 + np.exp(-logits.flatten()))
+ else:
+ exp_logits = np.exp(logits)
+ scores = exp_logits[:, 1] / np.sum(exp_logits, axis=1)
+
+ for score, passage in zip(scores, passages):
+ passage["score"] = score
+
+ passages.sort(key=lambda x: x["score"], reverse=True)
+ return passages
+
+ ranker.rerank = _patched_rerank
+
  logger.success("FlashRank ranker initialized", model=rerank_config.model)
  except Exception as e:
  logger.warning("Failed to initialize FlashRank ranker", error=str(e))
  rerank_config = None

+ # Log instructed retrieval configuration
+ if instructed_config:
+ logger.success(
+ "Instructed retrieval configured",
+ decomposition_model=instructed_config.decomposition_model.name
+ if instructed_config.decomposition_model
+ else None,
+ max_subqueries=instructed_config.max_subqueries,
+ rrf_k=instructed_config.rrf_k,
+ )
+
+ # Log instruction-aware reranking configuration
+ if rerank_config and rerank_config.instruction_aware:
+ logger.success(
+ "Instruction-aware reranking configured",
+ model=rerank_config.instruction_aware.model.name
+ if rerank_config.instruction_aware.model
+ else None,
+ top_n=rerank_config.instruction_aware.top_n,
+ )
+
  # Build client_args for VectorSearchClient
- # Use getattr to safely access attributes that may not exist (e.g., in mocks)
  client_args: dict[str, Any] = {}
  has_explicit_auth = any(
  [
  os.environ.get("DATABRICKS_TOKEN"),
  os.environ.get("DATABRICKS_CLIENT_ID"),
- getattr(vector_store, "pat", None),
- getattr(vector_store, "client_id", None),
- getattr(vector_store, "on_behalf_of_user", None),
+ vector_store.pat,
+ vector_store.client_id,
+ vector_store.on_behalf_of_user,
  ]
  )

  if has_explicit_auth:
  databricks_host = os.environ.get("DATABRICKS_HOST")
- if (
- not databricks_host
- and getattr(vector_store, "_workspace_client", None) is not None
- ):
- databricks_host = vector_store.workspace_client.config.host
+ if not databricks_host and vector_store.workspace_host:
+ databricks_host = value_of(vector_store.workspace_host)
  if databricks_host:
  client_args["workspace_url"] = normalize_host(databricks_host)

  token = os.environ.get("DATABRICKS_TOKEN")
- if not token and getattr(vector_store, "pat", None):
+ if not token and vector_store.pat:
  token = value_of(vector_store.pat)
  if token:
  client_args["personal_access_token"] = token

  client_id = os.environ.get("DATABRICKS_CLIENT_ID")
- if not client_id and getattr(vector_store, "client_id", None):
+ if not client_id and vector_store.client_id:
  client_id = value_of(vector_store.client_id)
  if client_id:
  client_args["service_principal_client_id"] = client_id

  client_secret = os.environ.get("DATABRICKS_CLIENT_SECRET")
- if not client_secret and getattr(vector_store, "client_secret", None):
+ if not client_secret and vector_store.client_secret:
  client_secret = value_of(vector_store.client_secret)
  if client_secret:
  client_args["service_principal_client_secret"] = client_secret
@@ -342,62 +347,365 @@ def create_vector_search_tool(

  # Determine tool name and description
  tool_name: str = name or f"vector_search_{vector_store.index.name}"
- tool_description: str = description or f"Search documents in {index_name}"
+
+ # Build tool description with available columns for filtering
+ base_description: str = description or f"Search documents in {index_name}"
+ if columns:
+ columns_list = ", ".join(columns)
+ tool_description = (
+ f"{base_description}. "
+ f"Available filter columns: {columns_list}. "
+ f"Filter operators: 'column' for equality, 'column NOT' for exclusion, "
+ f"'column <', 'column <=', 'column >', 'column >=' for comparison, "
+ f"'column LIKE' for token matching, 'column NOT LIKE' to exclude tokens."
+ )
+ else:
+ tool_description = base_description
+
+ @mlflow.trace(name="execute_instructed_retrieval", span_type=SpanType.RETRIEVER)
+ def _execute_instructed_retrieval(
+ vs: DatabricksVectorSearch,
+ query: str,
+ base_filters: dict[str, Any],
+ previous_feedback: str | None = None,
+ ) -> list[Document]:
+ """Execute instructed retrieval with query decomposition and RRF merging."""
+ logger.trace(
+ "Executing instructed retrieval", query=query, base_filters=base_filters
+ )
+ try:
+ decomposition_llm = _get_cached_llm(instructed_config.decomposition_model)
+
+ # Fall back to retriever columns if instructed columns not provided
+ instructed_columns: list[ColumnInfo] | None = instructed_config.columns
+ if instructed_columns is None and columns:
+ instructed_columns = [ColumnInfo(name=col) for col in columns]
+
+ subqueries: list[SearchQuery] = decompose_query(
+ llm=decomposition_llm,
+ query=query,
+ schema_description=instructed_config.schema_description,
+ constraints=instructed_config.constraints,
+ max_subqueries=instructed_config.max_subqueries,
+ examples=instructed_config.examples,
+ previous_feedback=previous_feedback,
+ columns=instructed_columns,
+ )
+
+ if not subqueries:
+ logger.warning(
+ "Query decomposition returned no subqueries, using original"
+ )
+ return vs.similarity_search(
+ query=query,
+ k=search_parameters.num_results or 5,
+ filter=base_filters if base_filters else None,
+ query_type=search_parameters.query_type or "ANN",
+ )
+
+ def normalize_filter_values(
+ filters: dict[str, Any], case: str | None
+ ) -> dict[str, Any]:
+ """Normalize string filter values to specified case."""
+ logger.trace("Normalizing filter values", filters=filters, case=case)
+ if not case or not filters:
+ return filters
+ normalized = {}
+ for key, value in filters.items():
+ if isinstance(value, str):
+ normalized[key] = (
+ value.upper() if case == "uppercase" else value.lower()
+ )
+ elif isinstance(value, list):
+ normalized[key] = [
+ v.upper()
+ if case == "uppercase"
+ else v.lower()
+ if isinstance(v, str)
+ else v
+ for v in value
+ ]
+ else:
+ normalized[key] = value
+ return normalized
+
+ def execute_search(sq: SearchQuery) -> list[Document]:
+ logger.trace("Executing search", query=sq.text, filters=sq.filters)
+ # Convert FilterItem list to dict
+ sq_filters_dict: dict[str, Any] = {}
+ if sq.filters:
+ for item in sq.filters:
+ sq_filters_dict[item.key] = item.value
+ sq_filters = normalize_filter_values(
+ sq_filters_dict, instructed_config.normalize_filter_case
+ )
+ k: int = search_parameters.num_results or 5
+ query_type: str = search_parameters.query_type or "ANN"
+ combined_filters: dict[str, Any] = {**sq_filters, **base_filters}
+ logger.trace(
+ "Executing search",
+ query=sq.text,
+ k=k,
+ query_type=query_type,
+ filters=combined_filters,
+ )
+ return vs.similarity_search(
+ query=sq.text,
+ k=k,
+ filter=combined_filters if combined_filters else None,
+ query_type=query_type,
+ )
+
+ logger.debug(
+ "Executing parallel searches",
+ num_subqueries=len(subqueries),
+ queries=[sq.text[:50] for sq in subqueries],
+ )
+
+ with ThreadPoolExecutor(
+ max_workers=instructed_config.max_subqueries
+ ) as executor:
+ all_results = list(executor.map(execute_search, subqueries))
+
+ merged = rrf_merge(
+ all_results,
+ k=instructed_config.rrf_k,
+ primary_key=vector_store.primary_key,
+ )
+
+ logger.debug(
+ "Instructed retrieval complete",
+ num_subqueries=len(subqueries),
+ total_results=sum(len(r) for r in all_results),
+ merged_results=len(merged),
+ )
+
+ return merged
+
+ except Exception as e:
+ logger.warning(
+ "Instructed retrieval failed, falling back to standard search",
+ error=str(e),
+ )
+ return vs.similarity_search(
+ query=query,
+ k=search_parameters.num_results or 5,
+ filter=base_filters if base_filters else None,
+ query_type=search_parameters.query_type or "ANN",
+ )
+
+ @mlflow.trace(name="execute_standard_search", span_type=SpanType.RETRIEVER)
+ def _execute_standard_search(
+ vs: DatabricksVectorSearch,
+ query: str,
+ base_filters: dict[str, Any],
+ ) -> list[Document]:
+ """Execute standard single-query search."""
+ logger.trace("Performing standard vector search", query_preview=query[:50])
+ return vs.similarity_search(
+ query=query,
+ k=search_parameters.num_results or 5,
+ filter=base_filters if base_filters else None,
+ query_type=search_parameters.query_type or "ANN",
+ )
+
+ @mlflow.trace(name="apply_post_processing", span_type=SpanType.RETRIEVER)
+ def _apply_post_processing(
+ documents: list[Document],
+ query: str,
+ mode: Literal["standard", "instructed"],
+ auto_bypass: bool,
+ ) -> list[Document]:
+ """Apply instruction-aware reranking and verification based on mode and bypass settings."""
+ # Skip post-processing for standard mode when auto_bypass is enabled
+ if mode == "standard" and auto_bypass:
+ mlflow.set_tag("router.bypassed_stages", "true")
+ return documents
+
+ # Apply instruction-aware reranking if configured
+ if rerank_config and rerank_config.instruction_aware:
+ instruction_config = rerank_config.instruction_aware
+ instruction_llm = (
+ _get_cached_llm(instruction_config.model)
+ if instruction_config.model
+ else None
+ )
+
+ if instruction_llm:
+ schema_desc = (
+ instructed_config.schema_description if instructed_config else None
+ )
+ # Get columns for dynamic instruction generation
+ rerank_columns: list[ColumnInfo] | None = None
+ if instructed_config and instructed_config.columns:
+ rerank_columns = instructed_config.columns
+ elif columns:
+ rerank_columns = [ColumnInfo(name=col) for col in columns]
+
+ documents = instruction_aware_rerank(
+ llm=instruction_llm,
+ query=query,
+ documents=documents,
+ instructions=instruction_config.instructions,
+ schema_description=schema_desc,
+ columns=rerank_columns,
+ top_n=instruction_config.top_n,
+ )
+
+ # Apply verification if configured
+ if verifier_config:
+ verifier_llm = (
+ _get_cached_llm(verifier_config.model)
+ if verifier_config.model
+ else None
+ )
+
+ if verifier_llm:
+ schema_desc = (
+ instructed_config.schema_description if instructed_config else ""
+ )
+ constraints = (
+ instructed_config.constraints if instructed_config else None
+ )
+ retry_count = 0
+ verification_result: VerificationResult | None = None
+ previous_feedback: str | None = None
+
+ while retry_count <= verifier_config.max_retries:
+ verification_result = verify_results(
+ llm=verifier_llm,
+ query=query,
+ documents=documents,
+ schema_description=schema_desc,
+ constraints=constraints,
+ previous_feedback=previous_feedback,
+ )
+
+ if verification_result.passed:
+ mlflow.set_tag("verifier.outcome", "passed")
+ mlflow.set_tag("verifier.retries", str(retry_count))
+ break
+
+ # Handle failure based on configuration
+ if verifier_config.on_failure == "warn":
+ mlflow.set_tag("verifier.outcome", "warned")
+ documents = add_verification_metadata(
+ documents, verification_result
+ )
+ break
+
+ if retry_count >= verifier_config.max_retries:
+ mlflow.set_tag("verifier.outcome", "exhausted")
+ mlflow.set_tag("verifier.retries", str(retry_count))
+ documents = add_verification_metadata(
+ documents, verification_result, exhausted=True
+ )
+ break
+
+ # Retry with feedback
+ mlflow.set_tag("verifier.outcome", "retried")
+ previous_feedback = verification_result.feedback
+ retry_count += 1
+ logger.debug(
+ "Retrying search with verification feedback", retry=retry_count
+ )
+
+ return documents

  # Use @tool decorator for proper ToolRuntime injection
- # The decorator ensures runtime is automatically injected and hidden from the LLM
  @tool(name_or_callable=tool_name, description=tool_description)
- def vector_search_func(
+ def _vector_search_tool(
  query: Annotated[str, "The search query to find relevant documents"],
  filters: Annotated[
  Optional[list[FilterItem]],
- "Optional filters to apply to the search results",
+ "Optional filters as key-value pairs. "
+ "Key operators: 'column' (equality), 'column NOT' (exclusion), "
+ "'column <', '<=', '>', '>=' (comparison), "
+ "'column LIKE' (token match), 'column NOT LIKE' (exclude token). "
+ f"Valid columns: {', '.join(columns) if columns else 'none'}.",
  ] = None,
  runtime: ToolRuntime[Context] = None,
  ) -> str:
  """Search for relevant documents using vector similarity."""
- # Get context for OBO support
  context: Context | None = runtime.context if runtime else None
+ vs: DatabricksVectorSearch = _get_vector_search(context)

- # Get vector search instance with OBO support
- vector_search: DatabricksVectorSearch = _get_vector_search(context)
-
- # Convert FilterItem Pydantic models to dict format for DatabricksVectorSearch
  filters_dict: dict[str, Any] = {}
  if filters:
  for item in filters:
  filters_dict[item.key] = item.value

- # Merge with configured filters
- combined_filters: dict[str, Any] = {
+ base_filters: dict[str, Any] = {
  **filters_dict,
  **(search_parameters.filters or {}),
  }

- # Perform vector search
- logger.trace("Performing vector search", query_preview=query[:50])
- documents: list[Document] = vector_search.similarity_search(
- query=query,
- k=search_parameters.num_results or 5,
- filter=combined_filters if combined_filters else None,
- query_type=search_parameters.query_type or "ANN",
+ # Determine execution mode via router or config
+ mode: Literal["standard", "instructed"] = "standard"
+ auto_bypass = True
+
+ logger.trace("Router configuration", router_config=router_config)
+ logger.trace("Instructed configuration", instructed_config=instructed_config)
+ logger.trace(
+ "Instruction-aware rerank configuration",
+ instruction_aware=rerank_config.instruction_aware
+ if rerank_config
+ else None,
  )

+ if router_config:
+ router_llm = (
+ _get_cached_llm(router_config.model) if router_config.model else None
+ )
+ auto_bypass = router_config.auto_bypass
+
+ if router_llm and instructed_config:
+ try:
+ mode = route_query(
+ llm=router_llm,
+ query=query,
+ schema_description=instructed_config.schema_description,
+ )
+ except Exception as e:
+ # Router fail-safe: default to standard mode
+ logger.warning(
+ "Router failed, defaulting to standard mode", error=str(e)
+ )
+ mlflow.set_tag("router.fallback", "true")
+ mode = router_config.default_mode
+ else:
+ mode = router_config.default_mode
+ elif instructed_config:
+ # No router but instructed is configured - use instructed mode
+ mode = "instructed"
+ auto_bypass = False
+ elif rerank_config and rerank_config.instruction_aware:
+ # No router/instructed but instruction_aware reranking is configured
+ # Disable auto_bypass to ensure instruction_aware reranking runs
+ auto_bypass = False
+
+ logger.trace("Routing mode", mode=mode, auto_bypass=auto_bypass)
+ mlflow.set_tag("router.mode", mode)
+
+ # Execute search based on mode
+ if mode == "instructed" and instructed_config:
+ documents = _execute_instructed_retrieval(vs, query, base_filters)
+ else:
+ documents = _execute_standard_search(vs, query, base_filters)
+
  # Apply FlashRank reranking if configured
  if ranker and rerank_config:
  logger.debug("Applying FlashRank reranking")
  documents = _rerank_documents(query, documents, ranker, rerank_config)

+ # Apply post-processing (instruction reranking + verification)
+ documents = _apply_post_processing(documents, query, mode, auto_bypass)
+
  # Serialize documents to JSON format for LLM consumption
- # Convert Document objects to dicts with page_content and metadata
- # Need to handle numpy types in metadata (e.g., float32, int64)
  serialized_docs: list[dict[str, Any]] = []
  for doc in documents:
- doc: Document
- # Convert metadata values to JSON-serializable types
  metadata_serializable: dict[str, Any] = {}
  for key, value in doc.metadata.items():
- # Handle numpy types
  if hasattr(value, "item"): # numpy scalar
  metadata_serializable[key] = value.item()
  else:
@@ -410,9 +718,8 @@ def create_vector_search_tool(
  }
  )

- # Return as JSON string
  return json.dumps(serialized_docs)

  logger.success("Vector search tool created", name=tool_name, index=index_name)

- return vector_search_func
+ return _vector_search_tool
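
The instructed retrieval path added in 0.1.18 runs the decomposed subqueries in parallel and fuses the per-subquery result lists with reciprocal rank fusion via rrf_merge(all_results, k=instructed_config.rrf_k, primary_key=vector_store.primary_key). As a rough illustration of what that fusion step does, the sketch below scores each document as the sum of 1 / (k + rank) across every list it appears in and deduplicates on a metadata key. This is a minimal sketch only: the actual implementation lives in dao_ai.tools.instructed_retriever and may differ, and the dedup_key parameter here is a stand-in for the vector store's primary key.

from langchain_core.documents import Document


def rrf_merge_sketch(
    result_lists: list[list[Document]],
    k: int = 60,
    dedup_key: str = "id",
) -> list[Document]:
    """Illustrative reciprocal rank fusion: score(doc) = sum of 1 / (k + rank)."""
    scores: dict[str, float] = {}
    docs: dict[str, Document] = {}
    for results in result_lists:
        for rank, doc in enumerate(results, start=1):
            # Deduplicate on the primary-key metadata field when present
            key = str(doc.metadata.get(dedup_key, doc.page_content))
            scores[key] = scores.get(key, 0.0) + 1.0 / (k + rank)
            docs.setdefault(key, doc)
    # Return documents ordered by fused score, highest first
    return [docs[key] for key in sorted(scores, key=scores.get, reverse=True)]

Under this scheme a document ranked highly by several subqueries outscores one that appears at the top of only a single list, which is why the merged list can differ from any individual subquery's ordering.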