dao-ai 0.1.2__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/cli.py +104 -25
- dao_ai/config.py +149 -40
- dao_ai/middleware/__init__.py +33 -0
- dao_ai/middleware/assertions.py +3 -3
- dao_ai/middleware/context_editing.py +230 -0
- dao_ai/middleware/core.py +4 -4
- dao_ai/middleware/guardrails.py +3 -3
- dao_ai/middleware/human_in_the_loop.py +3 -2
- dao_ai/middleware/message_validation.py +4 -4
- dao_ai/middleware/model_call_limit.py +77 -0
- dao_ai/middleware/model_retry.py +121 -0
- dao_ai/middleware/pii.py +157 -0
- dao_ai/middleware/summarization.py +1 -1
- dao_ai/middleware/tool_call_limit.py +210 -0
- dao_ai/middleware/tool_retry.py +174 -0
- dao_ai/nodes.py +5 -12
- dao_ai/orchestration/supervisor.py +6 -5
- dao_ai/providers/databricks.py +11 -0
- dao_ai/vector_search.py +37 -0
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.5.dist-info}/METADATA +36 -2
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.5.dist-info}/RECORD +24 -18
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.5.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.5.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.5.dist-info}/licenses/LICENSE +0 -0
dao_ai/cli.py
CHANGED
|
@@ -47,6 +47,57 @@ def get_default_user_id() -> str:
|
|
|
47
47
|
return local_user
|
|
48
48
|
|
|
49
49
|
|
|
50
|
+
def detect_cloud_provider(profile: Optional[str] = None) -> Optional[str]:
    """
    Detect the cloud provider from the Databricks workspace URL.

    The workspace host is lower-cased and matched by substring, most specific
    pattern first:
    - Azure: contains ``azuredatabricks.net``
    - GCP: contains ``.gcp.databricks.com``
    - AWS: any other host containing ``databricks.com``
      (e.g. ``*.cloud.databricks.com``)

    Detection is best-effort: any failure (SDK import, authentication,
    missing host) is logged as a warning and reported as ``None``.

    Args:
        profile: Optional Databricks CLI profile name

    Returns:
        Cloud provider string ('azure', 'aws', 'gcp') or None if detection fails
    """
    try:
        # Imported lazily so a missing/broken SDK is handled by the except below.
        from databricks.sdk import WorkspaceClient

        # Create workspace client with optional profile
        w = WorkspaceClient(profile=profile) if profile else WorkspaceClient()

        # Get the workspace URL from config
        host = w.config.host
        if not host:
            logger.warning("Could not determine workspace URL for cloud detection")
            return None

        host_lower = host.lower()

        if "azuredatabricks.net" in host_lower:
            logger.debug(f"Detected Azure cloud from workspace URL: {host}")
            return "azure"

        # GCP must be tested before the generic databricks.com check below,
        # because GCP hosts contain "databricks.com" as well.
        if ".gcp.databricks.com" in host_lower:
            logger.debug(f"Detected GCP cloud from workspace URL: {host}")
            return "gcp"

        # Every remaining databricks.com host is treated as AWS. A separate
        # ".cloud.databricks.com" test would be redundant: it is implied by
        # this broader match.
        if "databricks.com" in host_lower:
            logger.debug(f"Detected AWS cloud from workspace URL: {host}")
            return "aws"

        logger.warning(f"Could not determine cloud provider from URL: {host}")
        return None

    except Exception as e:
        # Deliberately broad: callers fall back to a default cloud on None.
        logger.warning(f"Could not detect cloud provider: {e}")
        return None
|
|
99
|
+
|
|
100
|
+
|
|
50
101
|
env_path: str = find_dotenv()
|
|
51
102
|
if env_path:
|
|
52
103
|
logger.info(f"Loading environment variables from: {env_path}")
|
|
@@ -220,6 +271,13 @@ Examples:
|
|
|
220
271
|
"-t",
|
|
221
272
|
"--target",
|
|
222
273
|
type=str,
|
|
274
|
+
help="Bundle target name (default: auto-generated from app name and cloud)",
|
|
275
|
+
)
|
|
276
|
+
bundle_parser.add_argument(
|
|
277
|
+
"--cloud",
|
|
278
|
+
type=str,
|
|
279
|
+
choices=["azure", "aws", "gcp"],
|
|
280
|
+
help="Cloud provider (auto-detected from workspace URL if not specified)",
|
|
223
281
|
)
|
|
224
282
|
bundle_parser.add_argument(
|
|
225
283
|
"--dry-run",
|
|
@@ -549,13 +607,6 @@ def handle_chat_command(options: Namespace) -> None:
|
|
|
549
607
|
# Find the last AI message
|
|
550
608
|
for msg in reversed(latest_messages):
|
|
551
609
|
if isinstance(msg, AIMessage):
|
|
552
|
-
logger.debug(f"AI message content: {msg.content}")
|
|
553
|
-
logger.debug(
|
|
554
|
-
f"AI message has tool_calls: {hasattr(msg, 'tool_calls')}"
|
|
555
|
-
)
|
|
556
|
-
if hasattr(msg, "tool_calls"):
|
|
557
|
-
logger.debug(f"Tool calls: {msg.tool_calls}")
|
|
558
|
-
|
|
559
610
|
if hasattr(msg, "content") and msg.content:
|
|
560
611
|
response_content = msg.content
|
|
561
612
|
print(response_content, end="", flush=True)
|
|
@@ -676,7 +727,7 @@ def generate_bundle_from_template(config_path: Path, app_name: str) -> Path:
|
|
|
676
727
|
4. Returns the path to the generated file
|
|
677
728
|
|
|
678
729
|
The generated databricks.yaml is overwritten on each deployment and is not tracked in git.
|
|
679
|
-
|
|
730
|
+
The template contains cloud-specific targets (azure, aws, gcp) with appropriate node types.
|
|
680
731
|
|
|
681
732
|
Args:
|
|
682
733
|
config_path: Path to the app config file
|
|
@@ -713,39 +764,59 @@ def run_databricks_command(
|
|
|
713
764
|
profile: Optional[str] = None,
|
|
714
765
|
config: Optional[str] = None,
|
|
715
766
|
target: Optional[str] = None,
|
|
767
|
+
cloud: Optional[str] = None,
|
|
716
768
|
dry_run: bool = False,
|
|
717
769
|
) -> None:
|
|
718
|
-
"""Execute a databricks CLI command with optional profile and
|
|
770
|
+
"""Execute a databricks CLI command with optional profile, target, and cloud.
|
|
771
|
+
|
|
772
|
+
Args:
|
|
773
|
+
command: The databricks CLI command to execute (e.g., ["bundle", "deploy"])
|
|
774
|
+
profile: Optional Databricks CLI profile name
|
|
775
|
+
config: Optional path to the configuration file
|
|
776
|
+
target: Optional bundle target name (if not provided, auto-generated from app name and cloud)
|
|
777
|
+
cloud: Optional cloud provider ('azure', 'aws', 'gcp'). Auto-detected if not specified.
|
|
778
|
+
dry_run: If True, print the command without executing
|
|
779
|
+
"""
|
|
719
780
|
config_path = Path(config) if config else None
|
|
720
781
|
|
|
721
782
|
if config_path and not config_path.exists():
|
|
722
783
|
logger.error(f"Configuration file {config_path} does not exist.")
|
|
723
784
|
sys.exit(1)
|
|
724
785
|
|
|
725
|
-
# Load app config
|
|
786
|
+
# Load app config
|
|
726
787
|
app_config: AppConfig = AppConfig.from_file(config_path) if config_path else None
|
|
727
788
|
normalized_name: str = normalize_name(app_config.app.name) if app_config else None
|
|
728
789
|
|
|
790
|
+
# Auto-detect cloud provider if not specified
|
|
791
|
+
if not cloud:
|
|
792
|
+
cloud = detect_cloud_provider(profile)
|
|
793
|
+
if cloud:
|
|
794
|
+
logger.info(f"Auto-detected cloud provider: {cloud}")
|
|
795
|
+
else:
|
|
796
|
+
logger.warning("Could not detect cloud provider. Defaulting to 'azure'.")
|
|
797
|
+
cloud = "azure"
|
|
798
|
+
|
|
729
799
|
# Generate app-specific bundle from template (overwrites databricks.yaml temporarily)
|
|
730
800
|
if config_path and app_config:
|
|
731
801
|
generate_bundle_from_template(config_path, normalized_name)
|
|
732
802
|
|
|
733
|
-
# Use
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
target
|
|
737
|
-
logger.debug(f"Using app-specific target: {target}")
|
|
803
|
+
# Use cloud as target (azure, aws, gcp) - can be overridden with explicit --target
|
|
804
|
+
if not target:
|
|
805
|
+
target = cloud
|
|
806
|
+
logger.debug(f"Using cloud-based target: {target}")
|
|
738
807
|
|
|
739
|
-
# Build databricks command
|
|
808
|
+
# Build databricks command
|
|
809
|
+
# --profile is a global flag, --target is a subcommand flag for 'bundle'
|
|
740
810
|
cmd = ["databricks"]
|
|
741
811
|
if profile:
|
|
742
812
|
cmd.extend(["--profile", profile])
|
|
743
813
|
|
|
814
|
+
cmd.extend(command)
|
|
815
|
+
|
|
816
|
+
# --target must come after the bundle subcommand (it's a subcommand-specific flag)
|
|
744
817
|
if target:
|
|
745
818
|
cmd.extend(["--target", target])
|
|
746
819
|
|
|
747
|
-
cmd.extend(command)
|
|
748
|
-
|
|
749
820
|
# Add config_path variable for notebooks
|
|
750
821
|
if config_path and app_config:
|
|
751
822
|
# Calculate relative path from notebooks directory to config file
|
|
@@ -800,30 +871,38 @@ def handle_bundle_command(options: Namespace) -> None:
|
|
|
800
871
|
profile: Optional[str] = options.profile
|
|
801
872
|
config: Optional[str] = options.config
|
|
802
873
|
target: Optional[str] = options.target
|
|
874
|
+
cloud: Optional[str] = options.cloud
|
|
803
875
|
dry_run: bool = options.dry_run
|
|
804
876
|
|
|
805
877
|
if options.deploy:
|
|
806
878
|
logger.info("Deploying DAO AI asset bundle...")
|
|
807
879
|
run_databricks_command(
|
|
808
|
-
["bundle", "deploy"],
|
|
880
|
+
["bundle", "deploy"],
|
|
881
|
+
profile=profile,
|
|
882
|
+
config=config,
|
|
883
|
+
target=target,
|
|
884
|
+
cloud=cloud,
|
|
885
|
+
dry_run=dry_run,
|
|
809
886
|
)
|
|
810
887
|
if options.run:
|
|
811
888
|
logger.info("Running DAO AI system with current configuration...")
|
|
812
889
|
# Use static job resource key that matches databricks.yaml (resources.jobs.deploy_job)
|
|
813
890
|
run_databricks_command(
|
|
814
891
|
["bundle", "run", "deploy_job"],
|
|
815
|
-
profile,
|
|
816
|
-
config,
|
|
817
|
-
target,
|
|
892
|
+
profile=profile,
|
|
893
|
+
config=config,
|
|
894
|
+
target=target,
|
|
895
|
+
cloud=cloud,
|
|
818
896
|
dry_run=dry_run,
|
|
819
897
|
)
|
|
820
898
|
if options.destroy:
|
|
821
899
|
logger.info("Destroying DAO AI system with current configuration...")
|
|
822
900
|
run_databricks_command(
|
|
823
901
|
["bundle", "destroy", "--auto-approve"],
|
|
824
|
-
profile,
|
|
825
|
-
config,
|
|
826
|
-
target,
|
|
902
|
+
profile=profile,
|
|
903
|
+
config=config,
|
|
904
|
+
target=target,
|
|
905
|
+
cloud=cloud,
|
|
827
906
|
dry_run=dry_run,
|
|
828
907
|
)
|
|
829
908
|
else:
|
dao_ai/config.py
CHANGED
|
@@ -601,6 +601,8 @@ class VectorSearchEndpoint(BaseModel):
|
|
|
601
601
|
|
|
602
602
|
|
|
603
603
|
class IndexModel(IsDatabricksResource, HasFullName):
|
|
604
|
+
"""Model representing a Databricks Vector Search index."""
|
|
605
|
+
|
|
604
606
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
605
607
|
schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
|
|
606
608
|
name: str
|
|
@@ -624,6 +626,22 @@ class IndexModel(IsDatabricksResource, HasFullName):
|
|
|
624
626
|
)
|
|
625
627
|
]
|
|
626
628
|
|
|
629
|
+
def exists(self) -> bool:
    """Report whether this vector search index can be fetched.

    The index is looked up by ``self.full_name`` through the workspace
    client's ``vector_search_indexes.get_index`` call.

    Returns:
        True when the lookup succeeds. False when the lookup raises
        ``NotFound`` (logged at debug level) or any other exception
        (logged as a warning).
    """
    try:
        self.workspace_client.vector_search_indexes.get_index(self.full_name)
    except NotFound:
        logger.debug(f"Index not found: {self.full_name}")
    except Exception as e:
        # Any other failure is also reported as "does not exist".
        logger.warning(f"Error checking index existence for {self.full_name}: {e}")
    else:
        return True
    return False
|
|
644
|
+
|
|
627
645
|
|
|
628
646
|
class FunctionModel(IsDatabricksResource, HasFullName):
|
|
629
647
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
@@ -1009,27 +1027,92 @@ class VolumePathModel(BaseModel, HasFullName):
|
|
|
1009
1027
|
|
|
1010
1028
|
|
|
1011
1029
|
class VectorStoreModel(IsDatabricksResource):
|
|
1030
|
+
"""
|
|
1031
|
+
Configuration model for a Databricks Vector Search store.
|
|
1032
|
+
|
|
1033
|
+
Supports two modes:
|
|
1034
|
+
1. **Use Existing Index**: Provide only `index` (fully qualified name).
|
|
1035
|
+
Used for querying an existing vector search index at runtime.
|
|
1036
|
+
2. **Provisioning Mode**: Provide `source_table` + `embedding_source_column`.
|
|
1037
|
+
Used for creating a new vector search index.
|
|
1038
|
+
|
|
1039
|
+
Examples:
|
|
1040
|
+
Minimal configuration (use existing index):
|
|
1041
|
+
```yaml
|
|
1042
|
+
vector_stores:
|
|
1043
|
+
products_search:
|
|
1044
|
+
index:
|
|
1045
|
+
name: catalog.schema.my_index
|
|
1046
|
+
```
|
|
1047
|
+
|
|
1048
|
+
Full provisioning configuration:
|
|
1049
|
+
```yaml
|
|
1050
|
+
vector_stores:
|
|
1051
|
+
products_search:
|
|
1052
|
+
source_table:
|
|
1053
|
+
schema: *my_schema
|
|
1054
|
+
name: products
|
|
1055
|
+
embedding_source_column: description
|
|
1056
|
+
endpoint:
|
|
1057
|
+
name: my_endpoint
|
|
1058
|
+
```
|
|
1059
|
+
"""
|
|
1060
|
+
|
|
1012
1061
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1013
|
-
|
|
1062
|
+
|
|
1063
|
+
# RUNTIME: Only index is truly required for querying existing indexes
|
|
1014
1064
|
index: Optional[IndexModel] = None
|
|
1065
|
+
|
|
1066
|
+
# PROVISIONING ONLY: Required when creating a new index
|
|
1067
|
+
source_table: Optional[TableModel] = None
|
|
1068
|
+
embedding_source_column: Optional[str] = None
|
|
1069
|
+
embedding_model: Optional[LLMModel] = None
|
|
1015
1070
|
endpoint: Optional[VectorSearchEndpoint] = None
|
|
1016
|
-
|
|
1071
|
+
|
|
1072
|
+
# OPTIONAL: For both modes
|
|
1017
1073
|
source_path: Optional[VolumePathModel] = None
|
|
1018
1074
|
checkpoint_path: Optional[VolumePathModel] = None
|
|
1019
1075
|
primary_key: Optional[str] = None
|
|
1020
1076
|
columns: Optional[list[str]] = Field(default_factory=list)
|
|
1021
1077
|
doc_uri: Optional[str] = None
|
|
1022
|
-
|
|
1078
|
+
|
|
1079
|
+
@model_validator(mode="after")
def validate_configuration_mode(self) -> Self:
    """
    Check that the model describes one of the two supported modes.

    - Use existing mode: ``index`` is set.
    - Provisioning mode: ``source_table`` and ``embedding_source_column``
      are both set.

    Raises:
        ValueError: If neither ``index`` nor ``source_table`` is set, or if
            ``source_table`` is set without ``embedding_source_column``.
    """
    if self.source_table is None:
        # Without a source table the only valid configuration is an index.
        if self.index is None:
            raise ValueError(
                "Either 'index' (for existing indexes) or 'source_table' "
                "(for provisioning) must be provided"
            )
    elif self.embedding_source_column is None:
        # A source table means provisioning, which needs the column to embed.
        raise ValueError(
            "embedding_source_column is required when source_table is provided (provisioning mode)"
        )

    return self
|
|
1023
1104
|
|
|
1024
1105
|
@model_validator(mode="after")
|
|
1025
1106
|
def set_default_embedding_model(self) -> Self:
|
|
1026
|
-
|
|
1107
|
+
# Only set default embedding model in provisioning mode
|
|
1108
|
+
if self.source_table is not None and not self.embedding_model:
|
|
1027
1109
|
self.embedding_model = LLMModel(name="databricks-gte-large-en")
|
|
1028
1110
|
return self
|
|
1029
1111
|
|
|
1030
1112
|
@model_validator(mode="after")
|
|
1031
1113
|
def set_default_primary_key(self) -> Self:
|
|
1032
|
-
|
|
1114
|
+
# Only auto-discover primary key in provisioning mode
|
|
1115
|
+
if self.primary_key is None and self.source_table is not None:
|
|
1033
1116
|
from dao_ai.providers.databricks import DatabricksProvider
|
|
1034
1117
|
|
|
1035
1118
|
provider: DatabricksProvider = DatabricksProvider()
|
|
@@ -1050,14 +1133,16 @@ class VectorStoreModel(IsDatabricksResource):
|
|
|
1050
1133
|
|
|
1051
1134
|
@model_validator(mode="after")
|
|
1052
1135
|
def set_default_index(self) -> Self:
|
|
1053
|
-
|
|
1136
|
+
# Only generate index from source_table in provisioning mode
|
|
1137
|
+
if self.index is None and self.source_table is not None:
|
|
1054
1138
|
name: str = f"{self.source_table.name}_index"
|
|
1055
1139
|
self.index = IndexModel(schema=self.source_table.schema_model, name=name)
|
|
1056
1140
|
return self
|
|
1057
1141
|
|
|
1058
1142
|
@model_validator(mode="after")
|
|
1059
1143
|
def set_default_endpoint(self) -> Self:
|
|
1060
|
-
|
|
1144
|
+
# Only find/create endpoint in provisioning mode
|
|
1145
|
+
if self.endpoint is None and self.source_table is not None:
|
|
1061
1146
|
from dao_ai.providers.databricks import (
|
|
1062
1147
|
DatabricksProvider,
|
|
1063
1148
|
with_available_indexes,
|
|
@@ -1092,18 +1177,60 @@ class VectorStoreModel(IsDatabricksResource):
|
|
|
1092
1177
|
return self.index.as_resources()
|
|
1093
1178
|
|
|
1094
1179
|
def as_index(self, vsc: VectorSearchClient | None = None) -> VectorSearchIndex:
|
|
1095
|
-
from dao_ai.providers.base import ServiceProvider
|
|
1096
1180
|
from dao_ai.providers.databricks import DatabricksProvider
|
|
1097
1181
|
|
|
1098
|
-
provider:
|
|
1182
|
+
provider: DatabricksProvider = DatabricksProvider(vsc=vsc)
|
|
1099
1183
|
index: VectorSearchIndex = provider.get_vector_index(self)
|
|
1100
1184
|
return index
|
|
1101
1185
|
|
|
1102
1186
|
def create(self, vsc: VectorSearchClient | None = None) -> None:
|
|
1103
|
-
|
|
1187
|
+
"""
|
|
1188
|
+
Create or validate the vector search index.
|
|
1189
|
+
|
|
1190
|
+
Behavior depends on configuration mode:
|
|
1191
|
+
- **Provisioning Mode** (source_table provided): Creates the index
|
|
1192
|
+
- **Use Existing Mode** (only index provided): Validates the index exists
|
|
1193
|
+
|
|
1194
|
+
Args:
|
|
1195
|
+
vsc: Optional VectorSearchClient instance
|
|
1196
|
+
|
|
1197
|
+
Raises:
|
|
1198
|
+
ValueError: If configuration is invalid or index doesn't exist
|
|
1199
|
+
"""
|
|
1104
1200
|
from dao_ai.providers.databricks import DatabricksProvider
|
|
1105
1201
|
|
|
1106
|
-
provider:
|
|
1202
|
+
provider: DatabricksProvider = DatabricksProvider(vsc=vsc)
|
|
1203
|
+
|
|
1204
|
+
if self.source_table is not None:
|
|
1205
|
+
self._create_new_index(provider)
|
|
1206
|
+
else:
|
|
1207
|
+
self._validate_existing_index(provider)
|
|
1208
|
+
|
|
1209
|
+
def _validate_existing_index(self, provider: Any) -> None:
    """Confirm that the configured index is present.

    Args:
        provider: Accepted for symmetry with the provisioning path; this
            check does not use it.

    Raises:
        ValueError: If ``self.index`` is unset or ``self.index.exists()``
            reports that the index is missing.
    """
    if self.index is None:
        raise ValueError("index is required for 'use existing' mode")

    if not self.index.exists():
        raise ValueError(
            f"Index '{self.index.full_name}' does not exist. "
            "Provide 'source_table' to provision it."
        )

    logger.info(
        "Vector search index exists and ready",
        index_name=self.index.full_name,
    )
|
|
1224
|
+
|
|
1225
|
+
def _create_new_index(self, provider: Any) -> None:
    """Provision the vector search index through ``provider``.

    Args:
        provider: Object whose ``create_vector_store`` method is called
            with this model.

    Raises:
        ValueError: If ``embedding_source_column``, ``endpoint`` or
            ``index`` is unset (checked in that order).
    """
    # Checked in a fixed order so the first missing field is the one reported.
    for required in ("embedding_source_column", "endpoint", "index"):
        if getattr(self, required) is None:
            raise ValueError(f"{required} is required for provisioning")

    provider.create_vector_store(self)
|
|
1108
1235
|
|
|
1109
1236
|
|
|
@@ -1266,32 +1393,12 @@ class DatabaseModel(IsDatabricksResource):
|
|
|
1266
1393
|
|
|
1267
1394
|
@model_validator(mode="after")
|
|
1268
1395
|
def update_host(self) -> Self:
|
|
1269
|
-
|
|
1396
|
+
# Lakebase uses instance_name directly via databricks_langchain - host not needed
|
|
1397
|
+
if self.is_lakebase:
|
|
1270
1398
|
return self
|
|
1271
1399
|
|
|
1272
|
-
#
|
|
1273
|
-
#
|
|
1274
|
-
if self.is_lakebase:
|
|
1275
|
-
try:
|
|
1276
|
-
existing_instance: DatabaseInstance = (
|
|
1277
|
-
self.workspace_client.database.get_database_instance(
|
|
1278
|
-
name=self.instance_name
|
|
1279
|
-
)
|
|
1280
|
-
)
|
|
1281
|
-
self.host = existing_instance.read_write_dns
|
|
1282
|
-
except Exception as e:
|
|
1283
|
-
# For Lakebase with OBO/ambient auth, we can't fetch at config time
|
|
1284
|
-
# The host will need to be provided explicitly or fetched at runtime
|
|
1285
|
-
if self.on_behalf_of_user:
|
|
1286
|
-
logger.debug(
|
|
1287
|
-
f"Could not fetch host for database {self.instance_name} "
|
|
1288
|
-
f"(Lakebase with OBO mode - will be resolved at runtime): {e}"
|
|
1289
|
-
)
|
|
1290
|
-
else:
|
|
1291
|
-
raise ValueError(
|
|
1292
|
-
f"Could not fetch host for database {self.instance_name}. "
|
|
1293
|
-
f"Please provide the 'host' explicitly or ensure the instance exists: {e}"
|
|
1294
|
-
)
|
|
1400
|
+
# For standard PostgreSQL, host must be provided by the user
|
|
1401
|
+
# (enforced by validate_connection_type)
|
|
1295
1402
|
return self
|
|
1296
1403
|
|
|
1297
1404
|
@model_validator(mode="after")
|
|
@@ -1549,11 +1656,13 @@ class RerankParametersModel(BaseModel):
|
|
|
1549
1656
|
top_n: 5 # Return top 5 after reranking
|
|
1550
1657
|
```
|
|
1551
1658
|
|
|
1552
|
-
Available models (
|
|
1553
|
-
- "ms-marco-TinyBERT-L-2-v2" (
|
|
1554
|
-
- "ms-marco-MiniLM-L-
|
|
1555
|
-
- "
|
|
1556
|
-
- "
|
|
1659
|
+
Available models (see https://github.com/PrithivirajDamodaran/FlashRank):
|
|
1660
|
+
- "ms-marco-TinyBERT-L-2-v2" (~4MB, fastest)
|
|
1661
|
+
- "ms-marco-MiniLM-L-12-v2" (~34MB, best cross-encoder, default)
|
|
1662
|
+
- "rank-T5-flan" (~110MB, best non cross-encoder)
|
|
1663
|
+
- "ms-marco-MultiBERT-L-12" (~150MB, multilingual 100+ languages)
|
|
1664
|
+
- "ce-esci-MiniLM-L12-v2" (e-commerce optimized, Amazon ESCI)
|
|
1665
|
+
- "miniReranker_arabic_v1" (Arabic language)
|
|
1557
1666
|
"""
|
|
1558
1667
|
|
|
1559
1668
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
dao_ai/middleware/__init__.py
CHANGED
|
@@ -3,8 +3,15 @@
|
|
|
3
3
|
|
|
4
4
|
# Re-export LangChain built-in middleware
|
|
5
5
|
from langchain.agents.middleware import (
|
|
6
|
+
ClearToolUsesEdit,
|
|
7
|
+
ContextEditingMiddleware,
|
|
6
8
|
HumanInTheLoopMiddleware,
|
|
9
|
+
ModelCallLimitMiddleware,
|
|
10
|
+
ModelRetryMiddleware,
|
|
11
|
+
PIIMiddleware,
|
|
7
12
|
SummarizationMiddleware,
|
|
13
|
+
ToolCallLimitMiddleware,
|
|
14
|
+
ToolRetryMiddleware,
|
|
8
15
|
after_agent,
|
|
9
16
|
after_model,
|
|
10
17
|
before_agent,
|
|
@@ -37,6 +44,10 @@ from dao_ai.middleware.base import (
|
|
|
37
44
|
ModelRequest,
|
|
38
45
|
ModelResponse,
|
|
39
46
|
)
|
|
47
|
+
from dao_ai.middleware.context_editing import (
|
|
48
|
+
create_clear_tool_uses_edit,
|
|
49
|
+
create_context_editing_middleware,
|
|
50
|
+
)
|
|
40
51
|
from dao_ai.middleware.core import create_factory_middleware
|
|
41
52
|
from dao_ai.middleware.guardrails import (
|
|
42
53
|
ContentFilterMiddleware,
|
|
@@ -62,10 +73,15 @@ from dao_ai.middleware.message_validation import (
|
|
|
62
73
|
create_thread_id_validation_middleware,
|
|
63
74
|
create_user_id_validation_middleware,
|
|
64
75
|
)
|
|
76
|
+
from dao_ai.middleware.model_call_limit import create_model_call_limit_middleware
|
|
77
|
+
from dao_ai.middleware.model_retry import create_model_retry_middleware
|
|
78
|
+
from dao_ai.middleware.pii import create_pii_middleware
|
|
65
79
|
from dao_ai.middleware.summarization import (
|
|
66
80
|
LoggingSummarizationMiddleware,
|
|
67
81
|
create_summarization_middleware,
|
|
68
82
|
)
|
|
83
|
+
from dao_ai.middleware.tool_call_limit import create_tool_call_limit_middleware
|
|
84
|
+
from dao_ai.middleware.tool_retry import create_tool_retry_middleware
|
|
69
85
|
|
|
70
86
|
__all__ = [
|
|
71
87
|
# Base class (from LangChain)
|
|
@@ -85,6 +101,13 @@ __all__ = [
|
|
|
85
101
|
"SummarizationMiddleware",
|
|
86
102
|
"LoggingSummarizationMiddleware",
|
|
87
103
|
"HumanInTheLoopMiddleware",
|
|
104
|
+
"ToolCallLimitMiddleware",
|
|
105
|
+
"ModelCallLimitMiddleware",
|
|
106
|
+
"ToolRetryMiddleware",
|
|
107
|
+
"ModelRetryMiddleware",
|
|
108
|
+
"ContextEditingMiddleware",
|
|
109
|
+
"ClearToolUsesEdit",
|
|
110
|
+
"PIIMiddleware",
|
|
88
111
|
# Core factory function
|
|
89
112
|
"create_factory_middleware",
|
|
90
113
|
# DAO AI middleware implementations
|
|
@@ -122,4 +145,14 @@ __all__ = [
|
|
|
122
145
|
"create_assert_middleware",
|
|
123
146
|
"create_suggest_middleware",
|
|
124
147
|
"create_refine_middleware",
|
|
148
|
+
# Limit and retry middleware factory functions
|
|
149
|
+
"create_tool_call_limit_middleware",
|
|
150
|
+
"create_model_call_limit_middleware",
|
|
151
|
+
"create_tool_retry_middleware",
|
|
152
|
+
"create_model_retry_middleware",
|
|
153
|
+
# Context editing middleware factory functions
|
|
154
|
+
"create_context_editing_middleware",
|
|
155
|
+
"create_clear_tool_uses_edit",
|
|
156
|
+
# PII middleware factory functions
|
|
157
|
+
"create_pii_middleware",
|
|
125
158
|
]
|
dao_ai/middleware/assertions.py
CHANGED
|
@@ -688,7 +688,7 @@ def create_assert_middleware(
|
|
|
688
688
|
name: Name for function constraints
|
|
689
689
|
|
|
690
690
|
Returns:
|
|
691
|
-
AssertMiddleware configured with the constraint
|
|
691
|
+
List containing AssertMiddleware configured with the constraint
|
|
692
692
|
|
|
693
693
|
Example:
|
|
694
694
|
# Using a Constraint class
|
|
@@ -737,7 +737,7 @@ def create_suggest_middleware(
|
|
|
737
737
|
name: Name for function constraints
|
|
738
738
|
|
|
739
739
|
Returns:
|
|
740
|
-
SuggestMiddleware configured with the constraint
|
|
740
|
+
List containing SuggestMiddleware configured with the constraint
|
|
741
741
|
|
|
742
742
|
Example:
|
|
743
743
|
def is_professional(response: str, ctx: dict) -> ConstraintResult:
|
|
@@ -783,7 +783,7 @@ def create_refine_middleware(
|
|
|
783
783
|
select_best: Track and return best response across iterations
|
|
784
784
|
|
|
785
785
|
Returns:
|
|
786
|
-
RefineMiddleware configured with the reward function
|
|
786
|
+
List containing RefineMiddleware configured with the reward function
|
|
787
787
|
|
|
788
788
|
Example:
|
|
789
789
|
def evaluate_completeness(response: str, ctx: dict) -> float:
|