dao-ai 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/cli.py +8 -3
- dao_ai/config.py +513 -32
- dao_ai/evaluation.py +543 -0
- dao_ai/genie/cache/__init__.py +2 -0
- dao_ai/genie/cache/core.py +1 -1
- dao_ai/genie/cache/in_memory_semantic.py +871 -0
- dao_ai/genie/cache/lru.py +15 -11
- dao_ai/genie/cache/semantic.py +52 -18
- dao_ai/memory/postgres.py +146 -35
- dao_ai/orchestration/core.py +33 -9
- dao_ai/orchestration/supervisor.py +23 -8
- dao_ai/{prompts.py → prompts/__init__.py} +10 -1
- dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
- dao_ai/prompts/instruction_reranker.yaml +14 -0
- dao_ai/prompts/router.yaml +37 -0
- dao_ai/prompts/verifier.yaml +46 -0
- dao_ai/providers/databricks.py +33 -12
- dao_ai/tools/genie.py +28 -3
- dao_ai/tools/instructed_retriever.py +366 -0
- dao_ai/tools/instruction_reranker.py +202 -0
- dao_ai/tools/router.py +89 -0
- dao_ai/tools/vector_search.py +441 -134
- dao_ai/tools/verifier.py +159 -0
- dao_ai/utils.py +182 -2
- dao_ai/vector_search.py +9 -1
- {dao_ai-0.1.17.dist-info → dao_ai-0.1.19.dist-info}/METADATA +4 -3
- {dao_ai-0.1.17.dist-info → dao_ai-0.1.19.dist-info}/RECORD +30 -20
- {dao_ai-0.1.17.dist-info → dao_ai-0.1.19.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.17.dist-info → dao_ai-0.1.19.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.17.dist-info → dao_ai-0.1.19.dist-info}/licenses/LICENSE +0 -0
dao_ai/cli.py
CHANGED
|
@@ -521,14 +521,19 @@ def handle_chat_command(options: Namespace) -> None:
|
|
|
521
521
|
)
|
|
522
522
|
continue
|
|
523
523
|
|
|
524
|
+
# Normalize user_id for memory namespace compatibility (replace . with _)
|
|
525
|
+
# This matches the normalization in models.py _convert_to_context
|
|
526
|
+
if configurable.get("user_id"):
|
|
527
|
+
configurable["user_id"] = configurable["user_id"].replace(".", "_")
|
|
528
|
+
|
|
524
529
|
# Create Context object from configurable dict
|
|
525
530
|
from dao_ai.state import Context
|
|
526
531
|
|
|
527
532
|
context = Context(**configurable)
|
|
528
533
|
|
|
529
|
-
# Prepare config with
|
|
530
|
-
# Note:
|
|
531
|
-
config = {"configurable":
|
|
534
|
+
# Prepare config with all context fields for checkpointer/memory
|
|
535
|
+
# Note: langmem tools require user_id in config.configurable for namespace resolution
|
|
536
|
+
config = {"configurable": context.model_dump()}
|
|
532
537
|
|
|
533
538
|
# Invoke the graph and handle interrupts (HITL)
|
|
534
539
|
# Wrap in async function to maintain connection pool throughout
|
dao_ai/config.py
CHANGED
|
@@ -1402,13 +1402,20 @@ class DatabaseModel(IsDatabricksResource):
|
|
|
1402
1402
|
- Databricks Lakebase: Provide `instance_name` (authentication optional, supports ambient auth)
|
|
1403
1403
|
- Standard PostgreSQL: Provide `host` (authentication required via user/password)
|
|
1404
1404
|
|
|
1405
|
-
Note:
|
|
1405
|
+
Note: For Lakebase connections, `name` is optional and defaults to `instance_name`.
|
|
1406
|
+
For PostgreSQL connections, `name` is required.
|
|
1407
|
+
|
|
1408
|
+
Example Databricks Lakebase (minimal):
|
|
1409
|
+
```yaml
|
|
1410
|
+
databases:
|
|
1411
|
+
my_lakebase:
|
|
1412
|
+
instance_name: my-lakebase-instance # name defaults to instance_name
|
|
1413
|
+
```
|
|
1406
1414
|
|
|
1407
1415
|
Example Databricks Lakebase with Service Principal:
|
|
1408
1416
|
```yaml
|
|
1409
1417
|
databases:
|
|
1410
1418
|
my_lakebase:
|
|
1411
|
-
name: my-database
|
|
1412
1419
|
instance_name: my-lakebase-instance
|
|
1413
1420
|
service_principal:
|
|
1414
1421
|
client_id:
|
|
@@ -1424,7 +1431,6 @@ class DatabaseModel(IsDatabricksResource):
|
|
|
1424
1431
|
```yaml
|
|
1425
1432
|
databases:
|
|
1426
1433
|
my_lakebase:
|
|
1427
|
-
name: my-database
|
|
1428
1434
|
instance_name: my-lakebase-instance
|
|
1429
1435
|
on_behalf_of_user: true
|
|
1430
1436
|
```
|
|
@@ -1444,7 +1450,7 @@ class DatabaseModel(IsDatabricksResource):
|
|
|
1444
1450
|
"""
|
|
1445
1451
|
|
|
1446
1452
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1447
|
-
name: str
|
|
1453
|
+
name: Optional[str] = None
|
|
1448
1454
|
instance_name: Optional[str] = None
|
|
1449
1455
|
description: Optional[str] = None
|
|
1450
1456
|
host: Optional[AnyVariable] = None
|
|
@@ -1493,6 +1499,17 @@ class DatabaseModel(IsDatabricksResource):
|
|
|
1493
1499
|
)
|
|
1494
1500
|
return self
|
|
1495
1501
|
|
|
1502
|
+
@model_validator(mode="after")
|
|
1503
|
+
def populate_name_from_instance_name(self) -> Self:
|
|
1504
|
+
"""Populate name from instance_name if not provided for Lakebase connections."""
|
|
1505
|
+
if self.name is None and self.instance_name:
|
|
1506
|
+
self.name = self.instance_name
|
|
1507
|
+
elif self.name is None:
|
|
1508
|
+
raise ValueError(
|
|
1509
|
+
"Either 'name' or 'instance_name' must be provided for DatabaseModel."
|
|
1510
|
+
)
|
|
1511
|
+
return self
|
|
1512
|
+
|
|
1496
1513
|
@model_validator(mode="after")
|
|
1497
1514
|
def update_user(self) -> Self:
|
|
1498
1515
|
# Skip if using OBO (passive auth), explicit credentials, or explicit user
|
|
@@ -1590,10 +1607,10 @@ class DatabaseModel(IsDatabricksResource):
|
|
|
1590
1607
|
username: str | None = None
|
|
1591
1608
|
password_value: str | None = None
|
|
1592
1609
|
|
|
1593
|
-
# Resolve host -
|
|
1610
|
+
# Resolve host - fetch from API at runtime for Lakebase if not provided
|
|
1594
1611
|
host_value: Any = self.host
|
|
1595
|
-
if host_value is None and self.is_lakebase
|
|
1596
|
-
# Fetch host
|
|
1612
|
+
if host_value is None and self.is_lakebase:
|
|
1613
|
+
# Fetch host from Lakebase instance API
|
|
1597
1614
|
existing_instance: DatabaseInstance = (
|
|
1598
1615
|
self.workspace_client.database.get_database_instance(
|
|
1599
1616
|
name=self.instance_name
|
|
@@ -1756,6 +1773,105 @@ class GenieSemanticCacheParametersModel(BaseModel):
|
|
|
1756
1773
|
return self
|
|
1757
1774
|
|
|
1758
1775
|
|
|
1776
|
+
# Memory estimation for capacity planning:
|
|
1777
|
+
# - Each entry: ~20KB (8KB question embedding + 8KB context embedding + 4KB strings/overhead)
|
|
1778
|
+
# - 1,000 entries: ~20MB (0.4% of 8GB)
|
|
1779
|
+
# - 5,000 entries: ~100MB (2% of 8GB)
|
|
1780
|
+
# - 10,000 entries: ~200MB (4-5% of 8GB) - default for ~30 users
|
|
1781
|
+
# - 20,000 entries: ~400MB (8-10% of 8GB)
|
|
1782
|
+
# Default 10,000 entries provides ~330 queries per user for 30 users.
|
|
1783
|
+
class GenieInMemorySemanticCacheParametersModel(BaseModel):
|
|
1784
|
+
"""
|
|
1785
|
+
Configuration for in-memory semantic cache (no database required).
|
|
1786
|
+
|
|
1787
|
+
This cache stores embeddings and cache entries entirely in memory, providing
|
|
1788
|
+
semantic similarity matching without requiring external database dependencies
|
|
1789
|
+
like PostgreSQL or Databricks Lakebase.
|
|
1790
|
+
|
|
1791
|
+
Default settings are tuned for ~30 users on an 8GB machine:
|
|
1792
|
+
- Capacity: 10,000 entries (~200MB memory, ~330 queries per user)
|
|
1793
|
+
- Eviction: LRU (Least Recently Used) - keeps frequently accessed queries
|
|
1794
|
+
- TTL: 1 week (accommodates weekly work patterns and batch jobs)
|
|
1795
|
+
- Memory overhead: ~4-5% of 8GB system
|
|
1796
|
+
|
|
1797
|
+
The LRU eviction strategy ensures hot queries stay cached while cold queries
|
|
1798
|
+
are evicted, providing better hit rates than FIFO eviction.
|
|
1799
|
+
|
|
1800
|
+
For larger deployments or memory-constrained environments, adjust capacity and TTL accordingly.
|
|
1801
|
+
|
|
1802
|
+
Use this when:
|
|
1803
|
+
- No external database access is available
|
|
1804
|
+
- Single-instance deployments (cache not shared across instances)
|
|
1805
|
+
- Cache persistence across restarts is not required
|
|
1806
|
+
- Cache sizes are moderate (hundreds to low thousands of entries)
|
|
1807
|
+
|
|
1808
|
+
For multi-instance deployments or large cache sizes, use GenieSemanticCacheParametersModel
|
|
1809
|
+
with PostgreSQL backend instead.
|
|
1810
|
+
"""
|
|
1811
|
+
|
|
1812
|
+
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1813
|
+
time_to_live_seconds: int | None = (
|
|
1814
|
+
60 * 60 * 24 * 7
|
|
1815
|
+
) # 1 week default (604800 seconds), None or negative = never expires
|
|
1816
|
+
similarity_threshold: float = 0.85 # Minimum similarity for question matching (L2 distance converted to 0-1 scale)
|
|
1817
|
+
context_similarity_threshold: float = 0.80 # Minimum similarity for context matching (L2 distance converted to 0-1 scale)
|
|
1818
|
+
question_weight: Optional[float] = (
|
|
1819
|
+
0.6 # Weight for question similarity in combined score (0-1). If not provided, computed as 1 - context_weight
|
|
1820
|
+
)
|
|
1821
|
+
context_weight: Optional[float] = (
|
|
1822
|
+
None # Weight for context similarity in combined score (0-1). If not provided, computed as 1 - question_weight
|
|
1823
|
+
)
|
|
1824
|
+
embedding_model: str | LLMModel = "databricks-gte-large-en"
|
|
1825
|
+
embedding_dims: int | None = None # Auto-detected if None
|
|
1826
|
+
warehouse: WarehouseModel
|
|
1827
|
+
capacity: int | None = (
|
|
1828
|
+
10000 # Maximum cache entries. ~200MB for 10000 entries (1024-dim embeddings). LRU eviction when full. None = unlimited (not recommended for production).
|
|
1829
|
+
)
|
|
1830
|
+
context_window_size: int = 3 # Number of previous turns to include for context
|
|
1831
|
+
max_context_tokens: int = (
|
|
1832
|
+
2000 # Maximum context length to prevent extremely long embeddings
|
|
1833
|
+
)
|
|
1834
|
+
|
|
1835
|
+
@model_validator(mode="after")
|
|
1836
|
+
def compute_and_validate_weights(self) -> Self:
|
|
1837
|
+
"""
|
|
1838
|
+
Compute missing weight and validate that question_weight + context_weight = 1.0.
|
|
1839
|
+
|
|
1840
|
+
Either question_weight or context_weight (or both) can be provided.
|
|
1841
|
+
The missing one will be computed as 1.0 - provided_weight.
|
|
1842
|
+
If both are provided, they must sum to 1.0.
|
|
1843
|
+
"""
|
|
1844
|
+
if self.question_weight is None and self.context_weight is None:
|
|
1845
|
+
# Both missing - use defaults
|
|
1846
|
+
self.question_weight = 0.6
|
|
1847
|
+
self.context_weight = 0.4
|
|
1848
|
+
elif self.question_weight is None:
|
|
1849
|
+
# Compute question_weight from context_weight
|
|
1850
|
+
if not (0.0 <= self.context_weight <= 1.0):
|
|
1851
|
+
raise ValueError(
|
|
1852
|
+
f"context_weight must be between 0.0 and 1.0, got {self.context_weight}"
|
|
1853
|
+
)
|
|
1854
|
+
self.question_weight = 1.0 - self.context_weight
|
|
1855
|
+
elif self.context_weight is None:
|
|
1856
|
+
# Compute context_weight from question_weight
|
|
1857
|
+
if not (0.0 <= self.question_weight <= 1.0):
|
|
1858
|
+
raise ValueError(
|
|
1859
|
+
f"question_weight must be between 0.0 and 1.0, got {self.question_weight}"
|
|
1860
|
+
)
|
|
1861
|
+
self.context_weight = 1.0 - self.question_weight
|
|
1862
|
+
else:
|
|
1863
|
+
# Both provided - validate they sum to 1.0
|
|
1864
|
+
total_weight = self.question_weight + self.context_weight
|
|
1865
|
+
if not abs(total_weight - 1.0) < 0.0001: # Allow small floating point error
|
|
1866
|
+
raise ValueError(
|
|
1867
|
+
f"question_weight ({self.question_weight}) + context_weight ({self.context_weight}) "
|
|
1868
|
+
f"must equal 1.0 (got {total_weight}). These weights determine the relative importance "
|
|
1869
|
+
f"of question vs context similarity in the combined score."
|
|
1870
|
+
)
|
|
1871
|
+
|
|
1872
|
+
return self
|
|
1873
|
+
|
|
1874
|
+
|
|
1759
1875
|
class SearchParametersModel(BaseModel):
|
|
1760
1876
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1761
1877
|
num_results: Optional[int] = 10
|
|
@@ -1763,43 +1879,83 @@ class SearchParametersModel(BaseModel):
|
|
|
1763
1879
|
query_type: Optional[str] = "ANN"
|
|
1764
1880
|
|
|
1765
1881
|
|
|
1882
|
+
class InstructionAwareRerankModel(BaseModel):
|
|
1883
|
+
"""
|
|
1884
|
+
LLM-based reranking considering user instructions and constraints.
|
|
1885
|
+
|
|
1886
|
+
Use fast models (GPT-3.5, Haiku, Llama 3 8B) to minimize latency (~100ms).
|
|
1887
|
+
Runs AFTER FlashRank as an additional constraint-aware reranking stage.
|
|
1888
|
+
Skipped for 'standard' mode when auto_bypass=true in router config.
|
|
1889
|
+
|
|
1890
|
+
Example:
|
|
1891
|
+
```yaml
|
|
1892
|
+
rerank:
|
|
1893
|
+
model: ms-marco-MiniLM-L-12-v2
|
|
1894
|
+
top_n: 20
|
|
1895
|
+
instruction_aware:
|
|
1896
|
+
model: *fast_llm
|
|
1897
|
+
instructions: |
|
|
1898
|
+
Prioritize results matching price and brand constraints.
|
|
1899
|
+
top_n: 10
|
|
1900
|
+
```
|
|
1901
|
+
"""
|
|
1902
|
+
|
|
1903
|
+
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1904
|
+
|
|
1905
|
+
model: Optional["LLMModel"] = Field(
|
|
1906
|
+
default=None,
|
|
1907
|
+
description="LLM for instruction reranking (fast model recommended)",
|
|
1908
|
+
)
|
|
1909
|
+
instructions: Optional[str] = Field(
|
|
1910
|
+
default=None,
|
|
1911
|
+
description="Custom reranking instructions for constraint prioritization",
|
|
1912
|
+
)
|
|
1913
|
+
top_n: Optional[int] = Field(
|
|
1914
|
+
default=None,
|
|
1915
|
+
description="Number of documents to return after instruction reranking",
|
|
1916
|
+
)
|
|
1917
|
+
|
|
1918
|
+
|
|
1766
1919
|
class RerankParametersModel(BaseModel):
|
|
1767
1920
|
"""
|
|
1768
|
-
Configuration for reranking retrieved documents
|
|
1921
|
+
Configuration for reranking retrieved documents.
|
|
1769
1922
|
|
|
1770
|
-
|
|
1771
|
-
|
|
1772
|
-
|
|
1923
|
+
Supports three reranking options that can be combined:
|
|
1924
|
+
1. FlashRank (local cross-encoder) - set `model`
|
|
1925
|
+
2. Databricks server-side reranking - set `columns`
|
|
1926
|
+
3. LLM instruction-aware reranking - set `instruction_aware`
|
|
1773
1927
|
|
|
1774
|
-
|
|
1775
|
-
|
|
1776
|
-
|
|
1777
|
-
|
|
1928
|
+
Example with Databricks columns + instruction-aware (no FlashRank):
|
|
1929
|
+
```yaml
|
|
1930
|
+
rerank:
|
|
1931
|
+
columns: # Databricks server-side reranking
|
|
1932
|
+
- product_name
|
|
1933
|
+
- brand_name
|
|
1934
|
+
instruction_aware: # LLM-based constraint reranking
|
|
1935
|
+
model: *fast_llm
|
|
1936
|
+
instructions: "Prioritize by brand preferences"
|
|
1937
|
+
top_n: 10
|
|
1938
|
+
```
|
|
1778
1939
|
|
|
1779
|
-
Example:
|
|
1940
|
+
Example with FlashRank:
|
|
1780
1941
|
```yaml
|
|
1781
|
-
|
|
1782
|
-
|
|
1783
|
-
|
|
1784
|
-
rerank:
|
|
1785
|
-
model: ms-marco-MiniLM-L-12-v2
|
|
1786
|
-
top_n: 5 # Return top 5 after reranking
|
|
1942
|
+
rerank:
|
|
1943
|
+
model: ms-marco-MiniLM-L-12-v2 # FlashRank model
|
|
1944
|
+
top_n: 10
|
|
1787
1945
|
```
|
|
1788
1946
|
|
|
1789
|
-
Available models (see https://github.com/PrithivirajDamodaran/FlashRank):
|
|
1947
|
+
Available FlashRank models (see https://github.com/PrithivirajDamodaran/FlashRank):
|
|
1790
1948
|
- "ms-marco-TinyBERT-L-2-v2" (~4MB, fastest)
|
|
1791
|
-
- "ms-marco-MiniLM-L-12-v2" (~34MB, best cross-encoder
|
|
1949
|
+
- "ms-marco-MiniLM-L-12-v2" (~34MB, best cross-encoder)
|
|
1792
1950
|
- "rank-T5-flan" (~110MB, best non cross-encoder)
|
|
1793
1951
|
- "ms-marco-MultiBERT-L-12" (~150MB, multilingual 100+ languages)
|
|
1794
|
-
- "ce-esci-MiniLM-L12-v2" (e-commerce optimized, Amazon ESCI)
|
|
1795
|
-
- "miniReranker_arabic_v1" (Arabic language)
|
|
1796
1952
|
"""
|
|
1797
1953
|
|
|
1798
1954
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1799
1955
|
|
|
1800
|
-
model: str = Field(
|
|
1801
|
-
default=
|
|
1802
|
-
description="FlashRank model name.
|
|
1956
|
+
model: Optional[str] = Field(
|
|
1957
|
+
default=None,
|
|
1958
|
+
description="FlashRank model name. If None, FlashRank is not used (use columns for Databricks reranking).",
|
|
1803
1959
|
)
|
|
1804
1960
|
top_n: Optional[int] = Field(
|
|
1805
1961
|
default=None,
|
|
@@ -1812,6 +1968,289 @@ class RerankParametersModel(BaseModel):
|
|
|
1812
1968
|
columns: Optional[list[str]] = Field(
|
|
1813
1969
|
default_factory=list, description="Columns to rerank using DatabricksReranker"
|
|
1814
1970
|
)
|
|
1971
|
+
instruction_aware: Optional[InstructionAwareRerankModel] = Field(
|
|
1972
|
+
default=None,
|
|
1973
|
+
description="Optional LLM-based reranking stage after FlashRank",
|
|
1974
|
+
)
|
|
1975
|
+
|
|
1976
|
+
|
|
1977
|
+
class FilterItem(BaseModel):
|
|
1978
|
+
"""A metadata filter for vector search.
|
|
1979
|
+
|
|
1980
|
+
Filters constrain search results by matching column values.
|
|
1981
|
+
Use column names from the provided schema description.
|
|
1982
|
+
"""
|
|
1983
|
+
|
|
1984
|
+
model_config = ConfigDict(extra="forbid")
|
|
1985
|
+
key: str = Field(
|
|
1986
|
+
description=(
|
|
1987
|
+
"Column name with optional operator suffix. "
|
|
1988
|
+
"Operators: (none) for equality, NOT for exclusion, "
|
|
1989
|
+
"< <= > >= for numeric comparison, "
|
|
1990
|
+
"LIKE for token match, NOT LIKE to exclude tokens."
|
|
1991
|
+
)
|
|
1992
|
+
)
|
|
1993
|
+
value: Union[str, int, float, bool, list[Union[str, int, float, bool]]] = Field(
|
|
1994
|
+
description=(
|
|
1995
|
+
"The filter value matching the column type. "
|
|
1996
|
+
"Use an array for IN-style matching multiple values."
|
|
1997
|
+
)
|
|
1998
|
+
)
|
|
1999
|
+
|
|
2000
|
+
|
|
2001
|
+
class SearchQuery(BaseModel):
|
|
2002
|
+
"""A single search query with optional metadata filters.
|
|
2003
|
+
|
|
2004
|
+
Represents one focused search intent extracted from the user's request.
|
|
2005
|
+
The text should be a natural language query optimized for semantic search.
|
|
2006
|
+
Filters constrain results to match specific metadata values.
|
|
2007
|
+
"""
|
|
2008
|
+
|
|
2009
|
+
model_config = ConfigDict(extra="forbid")
|
|
2010
|
+
text: str = Field(
|
|
2011
|
+
description=(
|
|
2012
|
+
"Natural language search query text optimized for semantic similarity. "
|
|
2013
|
+
"Should be focused on a single search intent. "
|
|
2014
|
+
"Do NOT include filter criteria in the text; use the filters field instead."
|
|
2015
|
+
)
|
|
2016
|
+
)
|
|
2017
|
+
filters: Optional[list[FilterItem]] = Field(
|
|
2018
|
+
default=None,
|
|
2019
|
+
description=(
|
|
2020
|
+
"Metadata filters to constrain search results. "
|
|
2021
|
+
"Set to null if no filters apply. "
|
|
2022
|
+
"Extract filter values from explicit constraints in the user query."
|
|
2023
|
+
),
|
|
2024
|
+
)
|
|
2025
|
+
|
|
2026
|
+
|
|
2027
|
+
class DecomposedQueries(BaseModel):
|
|
2028
|
+
"""Decomposed search queries extracted from a user request.
|
|
2029
|
+
|
|
2030
|
+
Break down complex user queries into multiple focused search queries.
|
|
2031
|
+
Each query targets a distinct search intent with appropriate filters.
|
|
2032
|
+
Generate 1-3 queries depending on the complexity of the user request.
|
|
2033
|
+
"""
|
|
2034
|
+
|
|
2035
|
+
model_config = ConfigDict(extra="forbid")
|
|
2036
|
+
queries: list[SearchQuery] = Field(
|
|
2037
|
+
description=(
|
|
2038
|
+
"List of search queries extracted from the user request. "
|
|
2039
|
+
"Each query should target a distinct search intent. "
|
|
2040
|
+
"Order queries by importance, with the most relevant first."
|
|
2041
|
+
)
|
|
2042
|
+
)
|
|
2043
|
+
|
|
2044
|
+
|
|
2045
|
+
class ColumnInfo(BaseModel):
|
|
2046
|
+
"""Column metadata for dynamic schema generation in structured output.
|
|
2047
|
+
|
|
2048
|
+
When provided, column information is embedded directly into the JSON schema
|
|
2049
|
+
that with_structured_output sends to the LLM, improving filter accuracy.
|
|
2050
|
+
"""
|
|
2051
|
+
|
|
2052
|
+
model_config = ConfigDict(extra="forbid")
|
|
2053
|
+
|
|
2054
|
+
name: str = Field(description="Column name as it appears in the database")
|
|
2055
|
+
type: Literal["string", "number", "boolean", "datetime"] = Field(
|
|
2056
|
+
default="string",
|
|
2057
|
+
description="Column data type for value validation",
|
|
2058
|
+
)
|
|
2059
|
+
operators: list[str] = Field(
|
|
2060
|
+
default=["", "NOT", "<", "<=", ">", ">=", "LIKE", "NOT LIKE"],
|
|
2061
|
+
description="Valid filter operators for this column",
|
|
2062
|
+
)
|
|
2063
|
+
|
|
2064
|
+
|
|
2065
|
+
class InstructedRetrieverModel(BaseModel):
|
|
2066
|
+
"""
|
|
2067
|
+
Configuration for instructed retrieval with query decomposition and RRF merging.
|
|
2068
|
+
|
|
2069
|
+
Instructed retrieval decomposes user queries into multiple subqueries with
|
|
2070
|
+
metadata filters, executes them in parallel, and merges results using
|
|
2071
|
+
Reciprocal Rank Fusion (RRF) before reranking.
|
|
2072
|
+
|
|
2073
|
+
Example:
|
|
2074
|
+
```yaml
|
|
2075
|
+
retriever:
|
|
2076
|
+
vector_store: *products_vector_store
|
|
2077
|
+
instructed:
|
|
2078
|
+
decomposition_model: *fast_llm
|
|
2079
|
+
schema_description: |
|
|
2080
|
+
Products table: product_id, brand_name, category, price, updated_at
|
|
2081
|
+
Filter operators: {"col": val}, {"col >": val}, {"col NOT": val}
|
|
2082
|
+
columns:
|
|
2083
|
+
- name: brand_name
|
|
2084
|
+
type: string
|
|
2085
|
+
- name: price
|
|
2086
|
+
type: number
|
|
2087
|
+
operators: ["", "<", "<=", ">", ">="]
|
|
2088
|
+
constraints:
|
|
2089
|
+
- "Prefer recent products"
|
|
2090
|
+
max_subqueries: 3
|
|
2091
|
+
examples:
|
|
2092
|
+
- query: "cheap drills"
|
|
2093
|
+
filters: {"price <": 100}
|
|
2094
|
+
```
|
|
2095
|
+
"""
|
|
2096
|
+
|
|
2097
|
+
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
2098
|
+
|
|
2099
|
+
decomposition_model: Optional["LLMModel"] = Field(
|
|
2100
|
+
default=None,
|
|
2101
|
+
description="LLM for query decomposition (smaller/faster model recommended)",
|
|
2102
|
+
)
|
|
2103
|
+
schema_description: str = Field(
|
|
2104
|
+
description="Column names, types, and valid filter syntax for the LLM"
|
|
2105
|
+
)
|
|
2106
|
+
columns: Optional[list[ColumnInfo]] = Field(
|
|
2107
|
+
default=None,
|
|
2108
|
+
description=(
|
|
2109
|
+
"Structured column info for dynamic schema generation. "
|
|
2110
|
+
"When provided, column names are embedded in the JSON schema for better LLM accuracy."
|
|
2111
|
+
),
|
|
2112
|
+
)
|
|
2113
|
+
constraints: Optional[list[str]] = Field(
|
|
2114
|
+
default=None, description="Default constraints to always apply"
|
|
2115
|
+
)
|
|
2116
|
+
max_subqueries: int = Field(
|
|
2117
|
+
default=3, description="Maximum number of parallel subqueries"
|
|
2118
|
+
)
|
|
2119
|
+
rrf_k: int = Field(
|
|
2120
|
+
default=60,
|
|
2121
|
+
description="RRF constant (lower values weight top ranks more heavily)",
|
|
2122
|
+
)
|
|
2123
|
+
examples: Optional[list[dict[str, Any]]] = Field(
|
|
2124
|
+
default=None,
|
|
2125
|
+
description="Few-shot examples for domain-specific filter translation",
|
|
2126
|
+
)
|
|
2127
|
+
normalize_filter_case: Optional[Literal["uppercase", "lowercase"]] = Field(
|
|
2128
|
+
default=None,
|
|
2129
|
+
description="Auto-normalize filter string values to uppercase or lowercase",
|
|
2130
|
+
)
|
|
2131
|
+
|
|
2132
|
+
|
|
2133
|
+
class RouterModel(BaseModel):
|
|
2134
|
+
"""
|
|
2135
|
+
Select internal execution mode based on query characteristics.
|
|
2136
|
+
|
|
2137
|
+
Use fast models (GPT-3.5, Haiku, Llama 3 8B) to minimize latency (~50-100ms).
|
|
2138
|
+
Routes to internal modes within the same retriever, not external retrievers.
|
|
2139
|
+
Cross-index routing belongs at the agent/tool-selection level.
|
|
2140
|
+
|
|
2141
|
+
Execution Modes:
|
|
2142
|
+
- "standard": Single similarity_search() for simple keyword/product searches
|
|
2143
|
+
- "instructed": Decompose -> Parallel Search -> RRF for constrained queries
|
|
2144
|
+
|
|
2145
|
+
Example:
|
|
2146
|
+
```yaml
|
|
2147
|
+
retriever:
|
|
2148
|
+
router:
|
|
2149
|
+
model: *fast_llm
|
|
2150
|
+
default_mode: standard
|
|
2151
|
+
auto_bypass: true
|
|
2152
|
+
```
|
|
2153
|
+
"""
|
|
2154
|
+
|
|
2155
|
+
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
2156
|
+
|
|
2157
|
+
model: Optional["LLMModel"] = Field(
|
|
2158
|
+
default=None,
|
|
2159
|
+
description="LLM for routing decision (fast model recommended)",
|
|
2160
|
+
)
|
|
2161
|
+
default_mode: Literal["standard", "instructed"] = Field(
|
|
2162
|
+
default="standard",
|
|
2163
|
+
description="Fallback mode if routing fails",
|
|
2164
|
+
)
|
|
2165
|
+
auto_bypass: bool = Field(
|
|
2166
|
+
default=True,
|
|
2167
|
+
description="Skip Instruction Reranker and Verifier for standard mode",
|
|
2168
|
+
)
|
|
2169
|
+
|
|
2170
|
+
|
|
2171
|
+
class VerificationResult(BaseModel):
|
|
2172
|
+
"""Verification of whether search results satisfy the user's constraints.
|
|
2173
|
+
|
|
2174
|
+
Analyze the retrieved results against the original query and any explicit
|
|
2175
|
+
constraints to determine if a retry with modified filters is needed.
|
|
2176
|
+
"""
|
|
2177
|
+
|
|
2178
|
+
model_config = ConfigDict(extra="forbid")
|
|
2179
|
+
|
|
2180
|
+
passed: bool = Field(
|
|
2181
|
+
description="True if results satisfy the user's query intent and constraints."
|
|
2182
|
+
)
|
|
2183
|
+
confidence: float = Field(
|
|
2184
|
+
ge=0.0,
|
|
2185
|
+
le=1.0,
|
|
2186
|
+
description="Confidence in the verification decision, from 0.0 (uncertain) to 1.0 (certain).",
|
|
2187
|
+
)
|
|
2188
|
+
feedback: Optional[str] = Field(
|
|
2189
|
+
default=None,
|
|
2190
|
+
description="Explanation of why verification passed or failed. Include specific issues found.",
|
|
2191
|
+
)
|
|
2192
|
+
suggested_filter_relaxation: Optional[dict[str, Any]] = Field(
|
|
2193
|
+
default=None,
|
|
2194
|
+
description=(
|
|
2195
|
+
"Suggested filter modifications for retry. "
|
|
2196
|
+
"Keys are column names, values indicate changes (e.g., 'REMOVE', 'WIDEN', or new values)."
|
|
2197
|
+
),
|
|
2198
|
+
)
|
|
2199
|
+
unmet_constraints: Optional[list[str]] = Field(
|
|
2200
|
+
default=None,
|
|
2201
|
+
description="List of user constraints that the results failed to satisfy.",
|
|
2202
|
+
)
|
|
2203
|
+
|
|
2204
|
+
|
|
2205
|
+
class VerifierModel(BaseModel):
|
|
2206
|
+
"""
|
|
2207
|
+
Validate results against user constraints with structured feedback.
|
|
2208
|
+
|
|
2209
|
+
Use fast models (GPT-3.5, Haiku, Llama 3 8B) to minimize latency (~50-100ms).
|
|
2210
|
+
Skipped for 'standard' mode when auto_bypass=true in router config.
|
|
2211
|
+
Returns structured feedback for intelligent retry, not blind retry.
|
|
2212
|
+
|
|
2213
|
+
Example:
|
|
2214
|
+
```yaml
|
|
2215
|
+
retriever:
|
|
2216
|
+
verifier:
|
|
2217
|
+
model: *fast_llm
|
|
2218
|
+
on_failure: warn_and_retry
|
|
2219
|
+
max_retries: 1
|
|
2220
|
+
```
|
|
2221
|
+
"""
|
|
2222
|
+
|
|
2223
|
+
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
2224
|
+
|
|
2225
|
+
model: Optional["LLMModel"] = Field(
|
|
2226
|
+
default=None,
|
|
2227
|
+
description="LLM for verification (fast model recommended)",
|
|
2228
|
+
)
|
|
2229
|
+
on_failure: Literal["warn", "retry", "warn_and_retry"] = Field(
|
|
2230
|
+
default="warn",
|
|
2231
|
+
description="Behavior when verification fails",
|
|
2232
|
+
)
|
|
2233
|
+
max_retries: int = Field(
|
|
2234
|
+
default=1,
|
|
2235
|
+
description="Maximum retry attempts before returning with warning",
|
|
2236
|
+
)
|
|
2237
|
+
|
|
2238
|
+
|
|
2239
|
+
class RankedDocument(BaseModel):
|
|
2240
|
+
"""Single ranked document."""
|
|
2241
|
+
|
|
2242
|
+
index: int = Field(description="Document index from input list")
|
|
2243
|
+
score: float = Field(description="0.0-1.0 relevance score")
|
|
2244
|
+
reason: str = Field(default="", description="Why this score")
|
|
2245
|
+
|
|
2246
|
+
|
|
2247
|
+
class RankingResult(BaseModel):
|
|
2248
|
+
"""Reranking output."""
|
|
2249
|
+
|
|
2250
|
+
rankings: list[RankedDocument] = Field(
|
|
2251
|
+
default_factory=list,
|
|
2252
|
+
description="Ranked documents, highest score first",
|
|
2253
|
+
)
|
|
1815
2254
|
|
|
1816
2255
|
|
|
1817
2256
|
class RetrieverModel(BaseModel):
|
|
@@ -1821,10 +2260,22 @@ class RetrieverModel(BaseModel):
|
|
|
1821
2260
|
search_parameters: SearchParametersModel = Field(
|
|
1822
2261
|
default_factory=SearchParametersModel
|
|
1823
2262
|
)
|
|
2263
|
+
router: Optional[RouterModel] = Field(
|
|
2264
|
+
default=None,
|
|
2265
|
+
description="Optional query router for selecting execution mode (standard vs instructed).",
|
|
2266
|
+
)
|
|
1824
2267
|
rerank: Optional[RerankParametersModel | bool] = Field(
|
|
1825
2268
|
default=None,
|
|
1826
2269
|
description="Optional reranking configuration. Set to true for defaults, or provide ReRankParametersModel for custom settings.",
|
|
1827
2270
|
)
|
|
2271
|
+
instructed: Optional[InstructedRetrieverModel] = Field(
|
|
2272
|
+
default=None,
|
|
2273
|
+
description="Optional instructed retrieval with query decomposition and RRF merging.",
|
|
2274
|
+
)
|
|
2275
|
+
verifier: Optional[VerifierModel] = Field(
|
|
2276
|
+
default=None,
|
|
2277
|
+
description="Optional result verification with structured feedback for retry.",
|
|
2278
|
+
)
|
|
1828
2279
|
|
|
1829
2280
|
@model_validator(mode="after")
|
|
1830
2281
|
def set_default_columns(self) -> Self:
|
|
@@ -1835,9 +2286,13 @@ class RetrieverModel(BaseModel):
|
|
|
1835
2286
|
|
|
1836
2287
|
@model_validator(mode="after")
|
|
1837
2288
|
def set_default_reranker(self) -> Self:
|
|
1838
|
-
"""Convert bool to ReRankParametersModel with defaults.
|
|
2289
|
+
"""Convert bool to ReRankParametersModel with defaults.
|
|
2290
|
+
|
|
2291
|
+
When rerank: true is used, sets the default FlashRank model
|
|
2292
|
+
(ms-marco-MiniLM-L-12-v2) to enable reranking.
|
|
2293
|
+
"""
|
|
1839
2294
|
if isinstance(self.rerank, bool) and self.rerank:
|
|
1840
|
-
self.rerank = RerankParametersModel()
|
|
2295
|
+
self.rerank = RerankParametersModel(model="ms-marco-MiniLM-L-12-v2")
|
|
1841
2296
|
return self
|
|
1842
2297
|
|
|
1843
2298
|
|
|
@@ -2985,8 +3440,24 @@ class GuidelineModel(BaseModel):
|
|
|
2985
3440
|
|
|
2986
3441
|
|
|
2987
3442
|
class EvaluationModel(BaseModel):
|
|
3443
|
+
"""
|
|
3444
|
+
Configuration for MLflow GenAI evaluation.
|
|
3445
|
+
|
|
3446
|
+
Attributes:
|
|
3447
|
+
model: LLM model used as the judge for LLM-based scorers (e.g., Guidelines, Safety).
|
|
3448
|
+
This model evaluates agent responses during evaluation.
|
|
3449
|
+
table: Table to store evaluation results.
|
|
3450
|
+
num_evals: Number of evaluation samples to generate.
|
|
3451
|
+
agent_description: Description of the agent for evaluation data generation.
|
|
3452
|
+
question_guidelines: Guidelines for generating evaluation questions.
|
|
3453
|
+
custom_inputs: Custom inputs to pass to the agent during evaluation.
|
|
3454
|
+
guidelines: List of guideline configurations for Guidelines scorers.
|
|
3455
|
+
"""
|
|
3456
|
+
|
|
2988
3457
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
2989
|
-
model: LLMModel
|
|
3458
|
+
model: LLMModel = Field(
|
|
3459
|
+
..., description="LLM model used as the judge for LLM-based evaluation scorers"
|
|
3460
|
+
)
|
|
2990
3461
|
table: TableModel
|
|
2991
3462
|
num_evals: int
|
|
2992
3463
|
agent_description: Optional[str] = None
|
|
@@ -2994,6 +3465,16 @@ class EvaluationModel(BaseModel):
|
|
|
2994
3465
|
custom_inputs: dict[str, Any] = Field(default_factory=dict)
|
|
2995
3466
|
guidelines: list[GuidelineModel] = Field(default_factory=list)
|
|
2996
3467
|
|
|
3468
|
+
@property
|
|
3469
|
+
def judge_model_endpoint(self) -> str:
|
|
3470
|
+
"""
|
|
3471
|
+
Get the judge model endpoint string for MLflow scorers.
|
|
3472
|
+
|
|
3473
|
+
Returns:
|
|
3474
|
+
Endpoint string in format 'databricks:/model-name'
|
|
3475
|
+
"""
|
|
3476
|
+
return f"databricks:/{self.model.name}"
|
|
3477
|
+
|
|
2997
3478
|
|
|
2998
3479
|
class EvaluationDatasetExpectationsModel(BaseModel):
|
|
2999
3480
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|