dao-ai 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/cli.py +8 -3
- dao_ai/config.py +513 -32
- dao_ai/evaluation.py +543 -0
- dao_ai/genie/cache/__init__.py +2 -0
- dao_ai/genie/cache/core.py +1 -1
- dao_ai/genie/cache/in_memory_semantic.py +871 -0
- dao_ai/genie/cache/lru.py +15 -11
- dao_ai/genie/cache/semantic.py +52 -18
- dao_ai/memory/postgres.py +146 -35
- dao_ai/orchestration/core.py +33 -9
- dao_ai/orchestration/supervisor.py +23 -8
- dao_ai/{prompts.py → prompts/__init__.py} +10 -1
- dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
- dao_ai/prompts/instruction_reranker.yaml +14 -0
- dao_ai/prompts/router.yaml +37 -0
- dao_ai/prompts/verifier.yaml +46 -0
- dao_ai/providers/databricks.py +33 -12
- dao_ai/tools/genie.py +28 -3
- dao_ai/tools/instructed_retriever.py +366 -0
- dao_ai/tools/instruction_reranker.py +202 -0
- dao_ai/tools/router.py +89 -0
- dao_ai/tools/vector_search.py +441 -134
- dao_ai/tools/verifier.py +159 -0
- dao_ai/utils.py +182 -2
- dao_ai/vector_search.py +9 -1
- {dao_ai-0.1.17.dist-info → dao_ai-0.1.19.dist-info}/METADATA +4 -3
- {dao_ai-0.1.17.dist-info → dao_ai-0.1.19.dist-info}/RECORD +30 -20
- {dao_ai-0.1.17.dist-info → dao_ai-0.1.19.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.17.dist-info → dao_ai-0.1.19.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.17.dist-info → dao_ai-0.1.19.dist-info}/licenses/LICENSE +0 -0
dao_ai/prompts/instructed_retriever_decomposition.yaml
ADDED

@@ -0,0 +1,58 @@
+name: instructed_retriever_decomposition
+description: Decomposes user queries into multiple search queries with metadata filters
+
+template: |
+  You are a search query decomposition expert. Your task is to break down a user query into one or more focused search queries with appropriate metadata filters. Respond with a JSON object.
+
+  ## Current Time
+  {current_time}
+
+  ## Database Schema
+  {schema_description}
+
+  ## Constraints
+  {constraints}
+
+  ## Few-Shot Examples
+  {examples}
+
+  ## Instructions
+  1. Analyze the user query and identify distinct search intents
+  2. For each intent, create a focused search query text
+  3. Extract metadata filters from the query using the exact filter syntax above
+  4. Resolve relative time references (e.g., "last month", "past year") using the current time
+  5. Generate at most {max_subqueries} search queries
+  6. If no filters apply, set filters to null
+
+  ## User Query
+  {query}
+
+  Generate search queries that together capture all aspects of the user's information need.
+
+variables:
+  - current_time
+  - schema_description
+  - constraints
+  - examples
+  - max_subqueries
+  - query
+
+output_format: |
+  The output must be a JSON object with a "queries" field containing an array of search query objects.
+  Each search query object has:
+  - "text": The search query string
+  - "filters": An array of filter objects, each with "key" (column + optional operator) and "value", or null if no filters
+
+  Supported filter operators (append to column name):
+  - Equality: {"key": "column", "value": "val"} or {"key": "column", "value": ["val1", "val2"]}
+  - Exclusion: {"key": "column NOT", "value": "val"}
+  - Comparison: {"key": "column <", "value": 100}, also <=, >, >=
+  - Token match: {"key": "column LIKE", "value": "word"}
+  - Exclude token: {"key": "column NOT LIKE", "value": "word"}
+
+  Examples:
+  - [{"key": "brand_name", "value": "MILWAUKEE"}]
+  - [{"key": "price <", "value": 100}]
+  - [{"key": "brand_name NOT", "value": "DEWALT"}]
+  - [{"key": "brand_name", "value": ["MILWAUKEE", "DEWALT"]}]
+  - [{"key": "description LIKE", "value": "cordless"}]
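To make the contract concrete: a query such as "Milwaukee or DeWalt drills under $200" would decompose along these lines under the syntax above (an illustrative response, not output captured from the package; the query text is an assumption, the column names come from the prompt's own examples):

```python
# Hypothetical decomposition following the output_format contract above.
decomposition = {
    "queries": [
        {
            "text": "cordless power drill",
            "filters": [
                {"key": "brand_name", "value": ["MILWAUKEE", "DEWALT"]},  # equality over a value list
                {"key": "price <", "value": 200},                         # comparison operator suffix
            ],
        }
    ]
}
```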
dao_ai/prompts/instruction_reranker.yaml
ADDED

@@ -0,0 +1,14 @@
+name: instruction_aware_reranking
+version: "1.1"
+description: Rerank documents based on user instructions and constraints
+
+template: |
+  Rerank these search results for the query "{query}".
+
+  {instructions}
+
+  ## Documents
+
+  {documents}
+
+  Score each document 0.0-1.0 based on relevance to the query and instructions. Return results sorted by score (highest first). Only include documents scoring > 0.1.
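A minimal sketch of rendering this template before an LLM call, reusing the YAML-loading pattern that dao_ai/tools/instructed_retriever.py applies to its own prompt file (the actual loader in dao_ai/tools/instruction_reranker.py may differ, and the query, instructions, and document strings here are invented for illustration):

```python
import yaml

# Load the prompt definition and fill its three template variables.
with open("dao_ai/prompts/instruction_reranker.yaml") as f:
    cfg = yaml.safe_load(f)

prompt = cfg["template"].format(
    query="cordless drill",
    instructions="Prefer items under $150.",
    documents="1. M18 drill kit ($129)\n2. Corded hammer drill ($89)",
)
```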
dao_ai/prompts/router.yaml
ADDED

@@ -0,0 +1,37 @@
+name: router_query_classification
+version: "1.0"
+description: Classify query to determine execution mode (standard vs instructed)
+
+template: |
+  You are a query classification system. Your task is to determine the best execution mode for a search query.
+
+  ## Execution Modes
+
+  **standard**: Use for simple keyword or product searches without specific constraints.
+  - General questions about products
+  - Simple keyword searches
+  - Broad category browsing
+
+  **instructed**: Use for queries with explicit constraints that require metadata filtering.
+  - Price constraints ("under $100", "between $50 and $200")
+  - Brand preferences ("Milwaukee", "not DeWalt", "excluding Makita")
+  - Category filters ("power tools", "paint supplies")
+  - Time/recency constraints ("recent", "from last month", "updated this year")
+  - Comparison queries ("compare X and Y")
+  - Multiple combined constraints
+
+  ## Available Schema for Filtering
+
+  {schema_description}
+
+  ## Query to Classify
+
+  "{query}"
+
+  ## Instructions
+
+  Analyze the query and determine:
+  1. Does it contain explicit constraints that can be translated to metadata filters?
+  2. Would the query benefit from being decomposed into subqueries?
+
+  Return your classification as a JSON object with a single field "mode" set to either "standard" or "instructed".
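Applying the rubric to two hypothetical queries (illustrative classifications, not taken from the package):

```python
# "cordless drills" is a simple keyword search with no explicit constraints.
standard_result = {"mode": "standard"}

# "Milwaukee drills under $100" combines a brand preference and a price
# constraint, both of which map to metadata filters.
instructed_result = {"mode": "instructed"}
```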
dao_ai/prompts/verifier.yaml
ADDED

@@ -0,0 +1,46 @@
+name: result_verification
+version: "1.0"
+description: Verify search results satisfy user constraints
+
+template: |
+  You are a result verification system. Your task is to determine whether search results satisfy the user's query constraints.
+
+  ## User Query
+
+  "{query}"
+
+  ## Schema Information
+
+  {schema_description}
+
+  ## Constraints to Verify
+
+  {constraints}
+
+  ## Retrieved Results (Top {num_results})
+
+  {results_summary}
+
+  ## Previous Attempt Feedback (if retry)
+
+  {previous_feedback}
+
+  ## Instructions
+
+  Analyze whether the results satisfy the user's explicit and implicit constraints:
+
+  1. **Intent Match**: Do the results address what the user is looking for?
+  2. **Explicit Constraints**: Are price, brand, category, date constraints met?
+  3. **Relevance**: Are the results actually useful for the user's needs?
+
+  If results do NOT satisfy constraints, suggest specific filter relaxations:
+  - Use "REMOVE" to drop a filter entirely
+  - Use "BROADEN" to widen a range (e.g., price < 100 -> price < 150)
+  - Use specific values to change a filter
+
+  Return a JSON object with:
+  - passed: boolean (true if results are satisfactory)
+  - confidence: float (0.0-1.0, your confidence in the assessment)
+  - feedback: string (brief explanation of issues, if any)
+  - suggested_filter_relaxation: object (filter changes for retry, e.g., {{"brand_name": "REMOVE"}})
+  - unmet_constraints: array of strings (list of constraints not satisfied)
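A hypothetical verifier response for a query like "Milwaukee drills under $50" where every returned item exceeds the price cap, using the fields and relaxation keywords defined above (the concrete values are assumptions for illustration):

```python
verifier_response = {
    "passed": False,
    "confidence": 0.85,
    "feedback": "All returned drills are priced above $50.",
    "suggested_filter_relaxation": {"price <": "BROADEN"},  # widen the price range on retry
    "unmet_constraints": ["price < 50"],
}
```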
dao_ai/providers/databricks.py
CHANGED

@@ -397,6 +397,8 @@ class DatabricksProvider(ServiceProvider):
 
         pip_requirements += get_installed_packages()
 
+        code_paths = list(dict.fromkeys(code_paths))
+
         logger.trace("Pip requirements prepared", count=len(pip_requirements))
         logger.trace("Code paths prepared", count=len(code_paths))
 

@@ -434,19 +436,38 @@
             pip_packages_count=len(pip_requirements),
         )
 
-
-
-
-
-
-
-
-
-
-
-
-
+        # End any stale runs before starting to ensure clean state on retry
+        if mlflow.active_run():
+            logger.warning(
+                "Ending stale MLflow run before creating new agent",
+                run_id=mlflow.active_run().info.run_id,
+            )
+            mlflow.end_run()
+
+        try:
+            with mlflow.start_run(run_name=run_name):
+                mlflow.set_tag("type", "agent")
+                mlflow.set_tag("dao_ai", dao_ai_version())
+                logged_agent_info: ModelInfo = mlflow.pyfunc.log_model(
+                    python_model=model_path.as_posix(),
+                    code_paths=code_paths,
+                    model_config=config.model_dump(mode="json", by_alias=True),
+                    name="agent",
+                    conda_env=conda_env,
+                    input_example=input_example,
+                    # resources=all_resources,
+                    auth_policy=auth_policy,
+                )
+        except Exception as e:
+            # Ensure run is ended on failure to prevent stale state on retry
+            if mlflow.active_run():
+                mlflow.end_run(status="FAILED")
+            logger.error(
+                "Failed to log model",
+                run_name=run_name,
+                error=str(e),
             )
+            raise
 
         registered_model_name: str = config.app.registered_model.full_name
 
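The retry-safety logic added here reduces to a small, reusable pattern; a sketch using standard MLflow APIs (the provider's real method carries additional tags and logging arguments, and log_model_safely is a hypothetical name):

```python
import mlflow

def log_model_safely(run_name: str) -> None:
    # A stale run left behind by an earlier failure would otherwise
    # attach this attempt to the wrong run.
    if mlflow.active_run():
        mlflow.end_run()
    try:
        with mlflow.start_run(run_name=run_name):
            ...  # mlflow.pyfunc.log_model(...) happens here
    except Exception:
        # Close the run as FAILED so the next retry starts clean.
        if mlflow.active_run():
            mlflow.end_run(status="FAILED")
        raise
```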
dao_ai/tools/genie.py
CHANGED

@@ -25,13 +25,19 @@ from pydantic import BaseModel
 from dao_ai.config import (
     AnyVariable,
     CompositeVariableModel,
+    GenieInMemorySemanticCacheParametersModel,
     GenieLRUCacheParametersModel,
     GenieRoomModel,
     GenieSemanticCacheParametersModel,
     value_of,
 )
 from dao_ai.genie import GenieService, GenieServiceBase
-from dao_ai.genie.cache import
+from dao_ai.genie.cache import (
+    CacheResult,
+    InMemorySemanticCacheService,
+    LRUCacheService,
+    SemanticCacheService,
+)
 from dao_ai.state import AgentState, Context, SessionState
 
 

@@ -67,6 +73,9 @@ def create_genie_tool(
     semantic_cache_parameters: GenieSemanticCacheParametersModel
     | dict[str, Any]
     | None = None,
+    in_memory_semantic_cache_parameters: GenieInMemorySemanticCacheParametersModel
+    | dict[str, Any]
+    | None = None,
 ) -> Callable[..., Command]:
     """
     Create a tool for interacting with Databricks Genie for natural language queries to databases.

@@ -84,7 +93,9 @@
         truncate_results: Whether to truncate large query results to fit token limits
         lru_cache_parameters: Optional LRU cache configuration for SQL query caching
         semantic_cache_parameters: Optional semantic cache configuration using pg_vector
-            for similarity-based query matching
+            for similarity-based query matching (requires PostgreSQL/Lakebase)
+        in_memory_semantic_cache_parameters: Optional in-memory semantic cache configuration
+            for similarity-based query matching (no database required)
 
     Returns:
         A LangGraph tool that processes natural language queries through Genie

@@ -97,6 +108,7 @@
         name=name,
         has_lru_cache=lru_cache_parameters is not None,
         has_semantic_cache=semantic_cache_parameters is not None,
+        has_in_memory_semantic_cache=in_memory_semantic_cache_parameters is not None,
     )
 
     if isinstance(genie_room, dict):

@@ -110,6 +122,11 @@
             **semantic_cache_parameters
         )
 
+    if isinstance(in_memory_semantic_cache_parameters, dict):
+        in_memory_semantic_cache_parameters = GenieInMemorySemanticCacheParametersModel(
+            **in_memory_semantic_cache_parameters
+        )
+
     space_id: AnyVariable = genie_room.space_id or os.environ.get(
         "DATABRICKS_GENIE_SPACE_ID"
     )

@@ -165,7 +182,7 @@ GenieResponse: A response object containing the conversation ID and result from
 
     genie_service: GenieServiceBase = GenieService(genie)
 
-    # Wrap with semantic cache first (checked second due to decorator pattern)
+    # Wrap with semantic cache first (checked second/third due to decorator pattern)
     if semantic_cache_parameters is not None:
         genie_service = SemanticCacheService(
             impl=genie_service,

@@ -173,6 +190,14 @@
             workspace_client=workspace_client,
         ).initialize()
 
+    # Wrap with in-memory semantic cache (alternative to PostgreSQL semantic cache)
+    if in_memory_semantic_cache_parameters is not None:
+        genie_service = InMemorySemanticCacheService(
+            impl=genie_service,
+            parameters=in_memory_semantic_cache_parameters,
+            workspace_client=workspace_client,
+        ).initialize()
+
     # Wrap with LRU cache last (checked first - fast O(1) exact match)
     if lru_cache_parameters is not None:
         genie_service = LRUCacheService(
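The wrapping order matters: each cache decorates the service built so far, so the last wrapper applied is the first consulted at query time. A sketch of the resulting chain when all three caches are configured (service class names are from the diff; the layout is inferred from the comments above):

```python
# Lookup order at query time (outermost wrapper first):
#
#   LRUCacheService                     # exact-match, O(1), checked first
#     -> InMemorySemanticCacheService   # similarity match, no database required
#       -> SemanticCacheService         # pg_vector similarity match (PostgreSQL/Lakebase)
#         -> GenieService               # real Genie call on a full cache miss
```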
dao_ai/tools/instructed_retriever.py
ADDED

@@ -0,0 +1,366 @@
+"""
+Instructed retriever for query decomposition and result fusion.
+
+This module provides functions for decomposing user queries into multiple
+subqueries with metadata filters and merging results using Reciprocal Rank Fusion.
+"""
+
+import json
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Optional, Union
+
+import mlflow
+import yaml
+from langchain_core.documents import Document
+from langchain_core.language_models import BaseChatModel
+from langchain_core.runnables import Runnable
+from loguru import logger
+from mlflow.entities import SpanType
+from pydantic import BaseModel, ConfigDict, Field
+
+from dao_ai.config import (
+    ColumnInfo,
+    DecomposedQueries,
+    FilterItem,
+    LLMModel,
+    SearchQuery,
+)
+
+# Module-level cache for LLM clients
+_llm_cache: dict[str, BaseChatModel] = {}
+
+# Load prompt template
+_PROMPT_PATH = (
+    Path(__file__).parent.parent / "prompts" / "instructed_retriever_decomposition.yaml"
+)
+
+
+def _load_prompt_template() -> dict[str, Any]:
+    """Load the decomposition prompt template from YAML."""
+    with open(_PROMPT_PATH) as f:
+        return yaml.safe_load(f)
+
+
+def _get_cached_llm(model_config: LLMModel) -> BaseChatModel:
+    """
+    Get or create cached LLM client for decomposition.
+
+    Uses full config as cache key to avoid collisions when same model name
+    has different parameters (temperature, API keys, etc.).
+    """
+    cache_key = model_config.model_dump_json()
+    if cache_key not in _llm_cache:
+        _llm_cache[cache_key] = model_config.as_chat_model()
+        logger.debug(
+            "Created new LLM client for decomposition", model=model_config.name
+        )
+    return _llm_cache[cache_key]
+
+
+def _format_constraints(constraints: list[str] | None) -> str:
+    """Format constraints list for prompt injection."""
+    if not constraints:
+        return "No additional constraints."
+    return "\n".join(f"- {c}" for c in constraints)
+
+
+def _format_examples(examples: list[dict[str, Any]] | None) -> str:
+    """Format few-shot examples for prompt injection.
+
+    Converts dict-style filters from config to FilterItem array format
+    to match the expected JSON schema output.
+    """
+    if not examples:
+        return "No examples provided."
+
+    formatted = []
+    for i, ex in enumerate(examples, 1):
+        query = ex.get("query", "")
+        filters = ex.get("filters", {})
+        # Convert dict to FilterItem array format
+        filter_items = [{"key": k, "value": v} for k, v in filters.items()]
+        formatted.append(
+            f'Example {i}:\n Query: "{query}"\n Filters: {json.dumps(filter_items)}'
+        )
+    return "\n".join(formatted)
+
+
+def create_decomposition_schema(
+    columns: list[ColumnInfo] | None = None,
+) -> type[BaseModel]:
+    """Create schema-aware DecomposedQueries model with dynamic descriptions.
+
+    When columns are provided, the column names and valid operators are embedded
+    directly into the JSON schema that with_structured_output sends to the LLM.
+    This improves accuracy by making valid filter keys explicit in the schema.
+
+    Args:
+        columns: List of column metadata for dynamic schema generation
+
+    Returns:
+        A DecomposedQueries-compatible Pydantic model class
+    """
+    if not columns:
+        # Fall back to generic models
+        return DecomposedQueries
+
+    # Build column info with types for the schema description
+    column_info = ", ".join(f"{c.name} ({c.type})" for c in columns)
+
+    # Build operator list from column definitions (union of all column operators)
+    all_operators: set[str] = set()
+    for col in columns:
+        all_operators.update(col.operators)
+    # Remove empty string (equality) and sort for consistent output
+    named_operators = sorted(all_operators - {""})
+    operator_list = ", ".join(named_operators) if named_operators else "equality only"
+
+    # Build valid key examples with operators
+    key_examples: list[str] = []
+    for col in columns[:3]:  # Show examples for first 3 columns
+        key_examples.append(f"'{col.name}'")
+        if "<" in col.operators:
+            key_examples.append(f"'{col.name} <'")
+        if "NOT" in col.operators:
+            key_examples.append(f"'{col.name} NOT'")
+
+    # Create dynamic FilterItem with schema-aware description
+    class SchemaFilterItem(BaseModel):
+        """A metadata filter for vector search with schema-specific columns."""
+
+        model_config = ConfigDict(extra="forbid")
+        key: str = Field(
+            description=(
+                f"Column name with optional operator suffix. "
+                f"Valid columns: {column_info}. "
+                f"Operators: (none) for equality, {operator_list}. "
+                f"Examples: {', '.join(key_examples[:5])}"
+            )
+        )
+        value: Union[str, int, float, bool, list[Union[str, int, float, bool]]] = Field(
+            description="The filter value matching the column type."
+        )
+
+    # Create dynamic SearchQuery using SchemaFilterItem
+    class SchemaSearchQuery(BaseModel):
+        """A search query with schema-aware filters."""
+
+        model_config = ConfigDict(extra="forbid")
+        text: str = Field(
+            description=(
+                "Natural language search query text optimized for semantic similarity. "
+                "Should be focused on a single search intent. "
+                "Do NOT include filter criteria in the text; use the filters field instead."
+            )
+        )
+        filters: Optional[list[SchemaFilterItem]] = Field(
+            default=None,
+            description=(
+                f"Metadata filters to constrain search results. "
+                f"Valid filter columns: {column_info}. "
+                f"Set to null if no filters apply."
+            ),
+        )
+
+    # Create dynamic DecomposedQueries using SchemaSearchQuery
+    class SchemaDecomposedQueries(BaseModel):
+        """Decomposed search queries with schema-aware filters."""
+
+        model_config = ConfigDict(extra="forbid")
+        queries: list[SchemaSearchQuery] = Field(
+            description=(
+                "List of search queries extracted from the user request. "
+                "Each query should target a distinct search intent. "
+                "Order queries by importance, with the most relevant first."
+            )
+        )
+
+    return SchemaDecomposedQueries
+
+
+@mlflow.trace(name="decompose_query", span_type=SpanType.LLM)
+def decompose_query(
+    llm: BaseChatModel,
+    query: str,
+    schema_description: str,
+    constraints: list[str] | None = None,
+    max_subqueries: int = 3,
+    examples: list[dict[str, Any]] | None = None,
+    previous_feedback: str | None = None,
+    columns: list[ColumnInfo] | None = None,
+) -> list[SearchQuery]:
+    """
+    Decompose a user query into multiple search queries with filters.
+
+    Uses structured output for reliable parsing and injects current time
+    for resolving relative date references. When columns are provided,
+    schema-aware Pydantic models are used for improved filter accuracy.
+
+    Args:
+        llm: Language model for decomposition
+        query: User's search query
+        schema_description: Column names, types, and valid filter syntax
+        constraints: Default constraints to apply
+        max_subqueries: Maximum number of subqueries to generate
+        examples: Few-shot examples for domain-specific filter translation
+        previous_feedback: Feedback from failed verification (for retry)
+        columns: Structured column info for dynamic schema generation
+
+    Returns:
+        List of SearchQuery objects with text and optional filters
+    """
+    current_time = datetime.now().isoformat()
+
+    # Load and format prompt
+    prompt_config = _load_prompt_template()
+    prompt_template = prompt_config["template"]
+
+    # Add previous feedback section if provided (for retry)
+    feedback_section = ""
+    if previous_feedback:
+        feedback_section = f"\n\n## Previous Attempt Feedback\nThe previous search attempt failed verification: {previous_feedback}\nAdjust your filters to address this feedback."
+
+    prompt = (
+        prompt_template.format(
+            current_time=current_time,
+            schema_description=schema_description,
+            constraints=_format_constraints(constraints),
+            examples=_format_examples(examples),
+            max_subqueries=max_subqueries,
+            query=query,
+        )
+        + feedback_section
+    )
+
+    logger.trace(
+        "Decomposing query",
+        query=query[:100],
+        max_subqueries=max_subqueries,
+        dynamic_schema=columns is not None,
+    )
+
+    # Create schema-aware model when columns are provided
+    DecompositionSchema: type[BaseModel] = create_decomposition_schema(columns)
+
+    # Use LangChain's with_structured_output for automatic strategy selection
+    # (JSON schema vs tool calling based on model capabilities)
+    try:
+        structured_llm: Runnable[str, BaseModel] = llm.with_structured_output(
+            DecompositionSchema
+        )
+        result: BaseModel = structured_llm.invoke(prompt)
+    except Exception as e:
+        logger.warning("Query decomposition failed", error=str(e))
+        raise
+
+    # Extract queries from result (works with both static and dynamic schemas)
+    subqueries: list[SearchQuery] = []
+    for query_obj in result.queries[:max_subqueries]:
+        # Convert dynamic schema objects to SearchQuery for consistent return type
+        filters: list[FilterItem] | None = None
+        if query_obj.filters:
+            filters = [FilterItem(key=f.key, value=f.value) for f in query_obj.filters]
+        subqueries.append(SearchQuery(text=query_obj.text, filters=filters))
+
+    # Log for observability
+    mlflow.set_tag("num_subqueries", len(subqueries))
+    mlflow.log_text(
+        json.dumps([sq.model_dump() for sq in subqueries], indent=2),
+        "decomposition.json",
+    )
+
+    logger.debug(
+        "Query decomposed",
+        num_subqueries=len(subqueries),
+        queries=[sq.text[:50] for sq in subqueries],
+    )
+
+    return subqueries
+
+
+def rrf_merge(
+    results_lists: list[list[Document]],
+    k: int = 60,
+    primary_key: str | None = None,
+) -> list[Document]:
+    """
+    Merge results from multiple queries using Reciprocal Rank Fusion.
+
+    RRF is safer than raw score sorting because Databricks Vector Search
+    scores aren't normalized across query types (HYBRID vs ANN).
+
+    RRF Score = Σ 1 / (k + rank_i) for each result list
+
+    Args:
+        results_lists: List of document lists from different subqueries
+        k: RRF constant (lower values weight top ranks more heavily)
+        primary_key: Metadata key for document identity (for deduplication)
+
+    Returns:
+        Merged and deduplicated documents sorted by RRF score
+    """
+    if not results_lists:
+        return []
+
+    # Filter empty lists first
+    non_empty = [r for r in results_lists if r]
+    if not non_empty:
+        return []
+
+    # Single list optimization (still add RRF scores for consistency)
+    if len(non_empty) == 1:
+        docs_with_scores: list[Document] = []
+        for rank, doc in enumerate(non_empty[0]):
+            rrf_score = 1.0 / (k + rank + 1)
+            docs_with_scores.append(
+                Document(
+                    page_content=doc.page_content,
+                    metadata={**doc.metadata, "rrf_score": rrf_score},
+                )
+            )
+        return docs_with_scores
+
+    # Calculate RRF scores
+    # Key: document identifier, Value: (total_rrf_score, Document)
+    doc_scores: dict[str, tuple[float, Document]] = {}
+
+    def get_doc_id(doc: Document) -> str:
+        """Get unique identifier for document."""
+        if primary_key and primary_key in doc.metadata:
+            return str(doc.metadata[primary_key])
+        # Fallback to content hash
+        return str(hash(doc.page_content))
+
+    for result_list in non_empty:
+        for rank, doc in enumerate(result_list):
+            doc_id = get_doc_id(doc)
+            rrf_score = 1.0 / (k + rank + 1)  # rank is 0-indexed
+
+            if doc_id in doc_scores:
+                # Accumulate RRF score for duplicates
+                existing_score, existing_doc = doc_scores[doc_id]
+                doc_scores[doc_id] = (existing_score + rrf_score, existing_doc)
+            else:
+                doc_scores[doc_id] = (rrf_score, doc)
+
+    # Sort by RRF score descending
+    sorted_docs = sorted(doc_scores.values(), key=lambda x: x[0], reverse=True)
+
+    # Add RRF score to metadata
+    merged_docs: list[Document] = []
+    for rrf_score, doc in sorted_docs:
+        merged_doc = Document(
+            page_content=doc.page_content,
+            metadata={**doc.metadata, "rrf_score": rrf_score},
+        )
+        merged_docs.append(merged_doc)
+
+    logger.debug(
+        "RRF merge complete",
+        input_lists=len(results_lists),
+        total_docs=sum(len(r) for r in results_lists),
+        unique_docs=len(merged_docs),
+    )
+
+    return merged_docs
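A worked instance of the RRF formula in rrf_merge's docstring, with the default k=60 and 0-indexed ranks (document names are hypothetical):

```python
k = 60

# Doc X appears at rank 0 in subquery A and rank 2 in subquery B:
score_x = 1 / (k + 0 + 1) + 1 / (k + 2 + 1)  # 1/61 + 1/63 ≈ 0.0323

# Doc Y appears only at rank 1 in subquery A:
score_y = 1 / (k + 1 + 1)                    # 1/62 ≈ 0.0161

# X outranks Y: appearing in multiple subquery result lists accumulates
# score, so RRF favors documents that several subqueries agree on.
assert score_x > score_y
```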