dao-ai 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/agent_as_code.py +2 -5
- dao_ai/cli.py +65 -15
- dao_ai/config.py +672 -218
- dao_ai/genie/cache/core.py +6 -2
- dao_ai/genie/cache/lru.py +29 -11
- dao_ai/genie/cache/semantic.py +95 -44
- dao_ai/hooks/core.py +5 -5
- dao_ai/logging.py +56 -0
- dao_ai/memory/core.py +61 -44
- dao_ai/memory/databricks.py +54 -41
- dao_ai/memory/postgres.py +77 -36
- dao_ai/middleware/assertions.py +45 -17
- dao_ai/middleware/core.py +13 -7
- dao_ai/middleware/guardrails.py +30 -25
- dao_ai/middleware/human_in_the_loop.py +9 -5
- dao_ai/middleware/message_validation.py +61 -29
- dao_ai/middleware/summarization.py +16 -11
- dao_ai/models.py +172 -69
- dao_ai/nodes.py +148 -19
- dao_ai/optimization.py +26 -16
- dao_ai/orchestration/core.py +15 -8
- dao_ai/orchestration/supervisor.py +22 -8
- dao_ai/orchestration/swarm.py +57 -12
- dao_ai/prompts.py +17 -17
- dao_ai/providers/databricks.py +365 -155
- dao_ai/state.py +24 -6
- dao_ai/tools/__init__.py +2 -0
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +7 -7
- dao_ai/tools/email.py +29 -77
- dao_ai/tools/genie.py +18 -13
- dao_ai/tools/mcp.py +223 -156
- dao_ai/tools/python.py +5 -2
- dao_ai/tools/search.py +1 -1
- dao_ai/tools/slack.py +21 -9
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +129 -86
- dao_ai/tools/vector_search.py +318 -244
- dao_ai/utils.py +15 -10
- dao_ai-0.1.3.dist-info/METADATA +455 -0
- dao_ai-0.1.3.dist-info/RECORD +64 -0
- dao_ai-0.1.1.dist-info/METADATA +0 -1878
- dao_ai-0.1.1.dist-info/RECORD +0 -62
- {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py
CHANGED
@@ -23,8 +23,11 @@ from databricks.sdk.credentials_provider import (
     CredentialsStrategy,
     ModelServingUserCredentials,
 )
+from databricks.sdk.errors.platform import NotFound
 from databricks.sdk.service.catalog import FunctionInfo, TableInfo
+from databricks.sdk.service.dashboards import GenieSpace
 from databricks.sdk.service.database import DatabaseInstance
+from databricks.sdk.service.sql import GetWarehouseResponse
 from databricks.vector_search.client import VectorSearchClient
 from databricks.vector_search.index import VectorSearchIndex
 from databricks_langchain import (
@@ -65,10 +68,13 @@ from pydantic import (
     BaseModel,
     ConfigDict,
     Field,
+    PrivateAttr,
     field_serializer,
     model_validator,
 )
 
+from dao_ai.utils import normalize_name
+
 
 class HasValue(ABC):
     @abstractmethod
@@ -228,6 +234,9 @@ class IsDatabricksResource(ABC, BaseModel):
     workspace_host: Optional[AnyVariable] = None
     pat: Optional[AnyVariable] = None
 
+    # Private attribute to cache the workspace client (lazy instantiation)
+    _workspace_client: Optional[WorkspaceClient] = PrivateAttr(default=None)
+
     @abstractmethod
     def as_resources(self) -> Sequence[DatabricksResource]: ...
 
@@ -263,6 +272,8 @@ class IsDatabricksResource(ABC, BaseModel):
         """
        Get a WorkspaceClient configured with the appropriate authentication.
 
+        The client is lazily instantiated on first access and cached for subsequent calls.
+
        Authentication priority:
        1. If on_behalf_of_user is True, uses ModelServingUserCredentials (OBO)
        2. If service principal credentials are configured (client_id, client_secret,
@@ -270,6 +281,10 @@ class IsDatabricksResource(ABC, BaseModel):
        3. If PAT is configured, uses token authentication
        4. Otherwise, uses default/ambient authentication
        """
+        # Return cached client if already instantiated
+        if self._workspace_client is not None:
+            return self._workspace_client
+
         from dao_ai.utils import normalize_host
 
         # Check for OBO first (highest priority)
@@ -279,7 +294,10 @@ class IsDatabricksResource(ABC, BaseModel):
                 f"Creating WorkspaceClient for {self.__class__.__name__} "
                 f"with OBO credentials strategy"
             )
-            return WorkspaceClient(credentials_strategy=credentials_strategy)
+            self._workspace_client = WorkspaceClient(
+                credentials_strategy=credentials_strategy
+            )
+            return self._workspace_client
 
         # Check for service principal credentials
         client_id_value: str | None = (
@@ -299,12 +317,13 @@ class IsDatabricksResource(ABC, BaseModel):
                 f"Creating WorkspaceClient for {self.__class__.__name__} with service principal: "
                 f"client_id={client_id_value}, host={workspace_host_value}"
             )
-            return WorkspaceClient(
+            self._workspace_client = WorkspaceClient(
                 host=workspace_host_value,
                 client_id=client_id_value,
                 client_secret=client_secret_value,
                 auth_type="oauth-m2m",
             )
+            return self._workspace_client
 
         # Check for PAT authentication
         pat_value: str | None = value_of(self.pat) if self.pat else None
@@ -312,18 +331,20 @@ class IsDatabricksResource(ABC, BaseModel):
             logger.debug(
                 f"Creating WorkspaceClient for {self.__class__.__name__} with PAT"
             )
-            return WorkspaceClient(
+            self._workspace_client = WorkspaceClient(
                 host=workspace_host_value,
                 token=pat_value,
                 auth_type="pat",
             )
+            return self._workspace_client
 
         # Default: use ambient authentication
         logger.debug(
             f"Creating WorkspaceClient for {self.__class__.__name__} "
             "with default/ambient authentication"
        )
-        return WorkspaceClient()
+        self._workspace_client = WorkspaceClient()
+        return self._workspace_client
 
 
 class Privilege(str, Enum):
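The recurring change in the hunks above swaps each direct `return WorkspaceClient(...)` for assignment to the private `_workspace_client` attribute, so every authentication path now feeds the same cache. A minimal, self-contained sketch of the pattern (illustrative names, not dao-ai's actual classes):

```python
from typing import Optional

from pydantic import BaseModel, PrivateAttr


class Client:
    """Stand-in for an expensive-to-build client such as WorkspaceClient."""

    def __init__(self) -> None:
        print("constructing client")


class Resource(BaseModel):
    name: str
    # PrivateAttr keeps the cached client out of validation and serialization.
    _client: Optional[Client] = PrivateAttr(default=None)

    @property
    def client(self) -> Client:
        if self._client is None:   # first access constructs and caches
            self._client = Client()
        return self._client        # later accesses reuse the same instance


r = Resource(name="demo")
assert r.client is r.client  # "constructing client" prints exactly once
```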
@@ -431,6 +452,22 @@ class TableModel(IsDatabricksResource, HasFullName):
     def api_scopes(self) -> Sequence[str]:
         return []
 
+    def exists(self) -> bool:
+        """Check if the table exists in Unity Catalog.
+
+        Returns:
+            True if the table exists, False otherwise.
+        """
+        try:
+            self.workspace_client.tables.get(full_name=self.full_name)
+            return True
+        except NotFound:
+            logger.debug(f"Table not found: {self.full_name}")
+            return False
+        except Exception as e:
+            logger.warning(f"Error checking table existence for {self.full_name}: {e}")
+            return False
+
     def as_resources(self) -> Sequence[DatabricksResource]:
         resources: list[DatabricksResource] = []
 
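The new `exists()` helper treats `NotFound` as a definitive miss and any other exception as a logged miss. A hedged sketch of the same idiom against the Databricks SDK (the table name is a placeholder):

```python
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors.platform import NotFound


def table_exists(w: WorkspaceClient, full_name: str) -> bool:
    try:
        w.tables.get(full_name=full_name)  # raises NotFound when absent
        return True
    except NotFound:
        return False
    except Exception as exc:
        # Permission or transport failures degrade to False instead of
        # propagating, mirroring the warning-and-return in the diff.
        print(f"could not verify {full_name}: {exc}")
        return False


# table_exists(WorkspaceClient(), "main.default.my_table")
```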
@@ -477,6 +514,7 @@ class TableModel(IsDatabricksResource, HasFullName):
 class LLMModel(IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
+    description: Optional[str] = None
     temperature: Optional[float] = 0.1
     max_tokens: Optional[int] = 8192
     fallbacks: Optional[list[Union[str, "LLMModel"]]] = Field(default_factory=list)
@@ -587,12 +625,297 @@ class IndexModel(IsDatabricksResource, HasFullName):
     ]
 
 
+class FunctionModel(IsDatabricksResource, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
+    name: Optional[str] = None
+
+    @model_validator(mode="after")
+    def validate_name_or_schema_required(self) -> Self:
+        if not self.name and not self.schema_model:
+            raise ValueError(
+                "Either 'name' or 'schema_model' must be provided for FunctionModel"
+            )
+        return self
+
+    @property
+    def full_name(self) -> str:
+        if self.schema_model:
+            name: str = ""
+            if self.name:
+                name = f".{self.name}"
+            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
+        return self.name
+
+    def exists(self) -> bool:
+        """Check if the function exists in Unity Catalog.
+
+        Returns:
+            True if the function exists, False otherwise.
+        """
+        try:
+            self.workspace_client.functions.get(name=self.full_name)
+            return True
+        except NotFound:
+            logger.debug(f"Function not found: {self.full_name}")
+            return False
+        except Exception as e:
+            logger.warning(
+                f"Error checking function existence for {self.full_name}: {e}"
+            )
+            return False
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        resources: list[DatabricksResource] = []
+        if self.name:
+            resources.append(
+                DatabricksFunction(
+                    function_name=self.full_name,
+                    on_behalf_of_user=self.on_behalf_of_user,
+                )
+            )
+        else:
+            w: WorkspaceClient = self.workspace_client
+            schema_full_name: str = self.schema_model.full_name
+            functions: Iterator[FunctionInfo] = w.functions.list(
+                catalog_name=self.schema_model.catalog_name,
+                schema_name=self.schema_model.schema_name,
+            )
+            resources.extend(
+                [
+                    DatabricksFunction(
+                        function_name=f"{schema_full_name}.{function.name}",
+                        on_behalf_of_user=self.on_behalf_of_user,
+                    )
+                    for function in functions
+                ]
+            )
+
+        return resources
+
+    @property
+    def api_scopes(self) -> Sequence[str]:
+        return ["sql.statement-execution"]
+
+
+class WarehouseModel(IsDatabricksResource):
+    model_config = ConfigDict()
+    name: str
+    description: Optional[str] = None
+    warehouse_id: AnyVariable
+
+    @property
+    def api_scopes(self) -> Sequence[str]:
+        return [
+            "sql.warehouses",
+            "sql.statement-execution",
+        ]
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksSQLWarehouse(
+                warehouse_id=value_of(self.warehouse_id),
+                on_behalf_of_user=self.on_behalf_of_user,
+            )
+        ]
+
+    @model_validator(mode="after")
+    def update_warehouse_id(self) -> Self:
+        self.warehouse_id = value_of(self.warehouse_id)
+        return self
+
+
 class GenieRoomModel(IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     description: Optional[str] = None
     space_id: AnyVariable
 
+    _space_details: Optional[GenieSpace] = PrivateAttr(default=None)
+
+    def _get_space_details(self) -> GenieSpace:
+        if self._space_details is None:
+            self._space_details = self.workspace_client.genie.get_space(
+                space_id=self.space_id, include_serialized_space=True
+            )
+        return self._space_details
+
+    def _parse_serialized_space(self) -> dict[str, Any]:
+        """Parse the serialized_space JSON string and return the parsed data."""
+        import json
+
+        space_details = self._get_space_details()
+        if not space_details.serialized_space:
+            return {}
+
+        try:
+            return json.loads(space_details.serialized_space)
+        except json.JSONDecodeError as e:
+            logger.warning(f"Failed to parse serialized_space: {e}")
+            return {}
+
+    @property
+    def warehouse(self) -> Optional[WarehouseModel]:
+        """Extract warehouse information from the Genie space.
+
+        Returns:
+            WarehouseModel instance if warehouse_id is available, None otherwise.
+        """
+        space_details: GenieSpace = self._get_space_details()
+
+        if not space_details.warehouse_id:
+            return None
+
+        try:
+            response: GetWarehouseResponse = self.workspace_client.warehouses.get(
+                space_details.warehouse_id
+            )
+            warehouse_name: str = response.name or space_details.warehouse_id
+
+            warehouse_model = WarehouseModel(
+                name=warehouse_name,
+                warehouse_id=space_details.warehouse_id,
+                on_behalf_of_user=self.on_behalf_of_user,
+                service_principal=self.service_principal,
+                client_id=self.client_id,
+                client_secret=self.client_secret,
+                workspace_host=self.workspace_host,
+                pat=self.pat,
+            )
+
+            # Share the cached workspace client if available
+            if self._workspace_client is not None:
+                warehouse_model._workspace_client = self._workspace_client
+
+            return warehouse_model
+        except Exception as e:
+            logger.warning(
+                f"Failed to fetch warehouse details for {space_details.warehouse_id}: {e}"
+            )
+            return None
+
+    @property
+    def tables(self) -> list[TableModel]:
+        """Extract tables from the serialized Genie space.
+
+        Databricks Genie stores tables in: data_sources.tables[].identifier
+        Only includes tables that actually exist in Unity Catalog.
+        """
+        parsed_space = self._parse_serialized_space()
+        tables_list: list[TableModel] = []
+
+        # Primary structure: data_sources.tables with 'identifier' field
+        if "data_sources" in parsed_space:
+            data_sources = parsed_space["data_sources"]
+            if isinstance(data_sources, dict) and "tables" in data_sources:
+                tables_data = data_sources["tables"]
+                if isinstance(tables_data, list):
+                    for table_item in tables_data:
+                        table_name: str | None = None
+                        if isinstance(table_item, dict):
+                            # Standard Databricks structure uses 'identifier'
+                            table_name = table_item.get("identifier") or table_item.get(
+                                "name"
+                            )
+                        elif isinstance(table_item, str):
+                            table_name = table_item
+
+                        if table_name:
+                            table_model = TableModel(
+                                name=table_name,
+                                on_behalf_of_user=self.on_behalf_of_user,
+                                service_principal=self.service_principal,
+                                client_id=self.client_id,
+                                client_secret=self.client_secret,
+                                workspace_host=self.workspace_host,
+                                pat=self.pat,
+                            )
+                            # Share the cached workspace client if available
+                            if self._workspace_client is not None:
+                                table_model._workspace_client = self._workspace_client
+
+                            # Verify the table exists before adding
+                            if not table_model.exists():
+                                continue
+
+                            tables_list.append(table_model)
+
+        return tables_list
+
+    @property
+    def functions(self) -> list[FunctionModel]:
+        """Extract functions from the serialized Genie space.
+
+        Databricks Genie stores functions in multiple locations:
+        - instructions.sql_functions[].identifier (SQL functions)
+        - data_sources.functions[].identifier (other functions)
+        Only includes functions that actually exist in Unity Catalog.
+        """
+        parsed_space = self._parse_serialized_space()
+        functions_list: list[FunctionModel] = []
+        seen_functions: set[str] = set()
+
+        def add_function_if_exists(function_name: str) -> None:
+            """Helper to add a function if it exists and hasn't been added."""
+            if function_name in seen_functions:
+                return
+
+            seen_functions.add(function_name)
+            function_model = FunctionModel(
+                name=function_name,
+                on_behalf_of_user=self.on_behalf_of_user,
+                service_principal=self.service_principal,
+                client_id=self.client_id,
+                client_secret=self.client_secret,
+                workspace_host=self.workspace_host,
+                pat=self.pat,
+            )
+            # Share the cached workspace client if available
+            if self._workspace_client is not None:
+                function_model._workspace_client = self._workspace_client
+
+            # Verify the function exists before adding
+            if not function_model.exists():
+                return
+
+            functions_list.append(function_model)
+
+        # Primary structure: instructions.sql_functions with 'identifier' field
+        if "instructions" in parsed_space:
+            instructions = parsed_space["instructions"]
+            if isinstance(instructions, dict) and "sql_functions" in instructions:
+                sql_functions_data = instructions["sql_functions"]
+                if isinstance(sql_functions_data, list):
+                    for function_item in sql_functions_data:
+                        if isinstance(function_item, dict):
+                            # SQL functions use 'identifier' field
+                            function_name = function_item.get(
+                                "identifier"
+                            ) or function_item.get("name")
+                            if function_name:
+                                add_function_if_exists(function_name)
+
+        # Secondary structure: data_sources.functions with 'identifier' field
+        if "data_sources" in parsed_space:
+            data_sources = parsed_space["data_sources"]
+            if isinstance(data_sources, dict) and "functions" in data_sources:
+                functions_data = data_sources["functions"]
+                if isinstance(functions_data, list):
+                    for function_item in functions_data:
+                        function_name: str | None = None
+                        if isinstance(function_item, dict):
+                            # Standard Databricks structure uses 'identifier'
+                            function_name = function_item.get(
+                                "identifier"
+                            ) or function_item.get("name")
+                        elif isinstance(function_item, str):
+                            function_name = function_item
+
+                        if function_name:
+                            add_function_if_exists(function_name)
+
+        return functions_list
+
     @property
     def api_scopes(self) -> Sequence[str]:
         return [
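The new `tables` and `functions` properties walk the Genie space's `serialized_space` JSON, reading `data_sources.tables[].identifier`, `instructions.sql_functions[].identifier`, and `data_sources.functions[].identifier`, and tolerating both dict and bare-string entries. A toy run of that traversal (the payload below is invented; the key paths come from the diff):

```python
import json

serialized_space = json.dumps({
    "data_sources": {
        "tables": [{"identifier": "main.sales.orders"}, "main.sales.customers"],
        "functions": [{"identifier": "main.sales.net_revenue"}],
    },
    "instructions": {
        "sql_functions": [{"identifier": "main.sales.top_skus"}],
    },
})

space = json.loads(serialized_space)

# Tables: dict entries use 'identifier' (falling back to 'name'); strings pass through.
tables = [
    (t.get("identifier") or t.get("name")) if isinstance(t, dict) else t
    for t in space.get("data_sources", {}).get("tables", [])
]

# Functions: union of both locations, deduplicated like seen_functions in the diff.
functions: set[str] = set()
for f in space.get("instructions", {}).get("sql_functions", []):
    functions.add(f["identifier"] if isinstance(f, dict) else f)
for f in space.get("data_sources", {}).get("functions", []):
    functions.add(f["identifier"] if isinstance(f, dict) else f)

print(tables)             # ['main.sales.orders', 'main.sales.customers']
print(sorted(functions))  # ['main.sales.net_revenue', 'main.sales.top_skus']
```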
@@ -612,6 +935,18 @@ class GenieRoomModel(IsDatabricksResource):
         self.space_id = value_of(self.space_id)
         return self
 
+    @model_validator(mode="after")
+    def update_description_from_space(self) -> Self:
+        """Populate description from GenieSpace if not provided."""
+        if not self.description:
+            try:
+                space_details = self._get_space_details()
+                if space_details.description:
+                    self.description = space_details.description
+            except Exception as e:
+                logger.debug(f"Could not fetch description from Genie space: {e}")
+        return self
+
 
 class VolumeModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -674,27 +1009,92 @@ class VolumePathModel(BaseModel, HasFullName):
 
 
 class VectorStoreModel(IsDatabricksResource):
+    """
+    Configuration model for a Databricks Vector Search store.
+
+    Supports two modes:
+    1. **Use Existing Index**: Provide only `index` (fully qualified name).
+       Used for querying an existing vector search index at runtime.
+    2. **Provisioning Mode**: Provide `source_table` + `embedding_source_column`.
+       Used for creating a new vector search index.
+
+    Examples:
+        Minimal configuration (use existing index):
+        ```yaml
+        vector_stores:
+          products_search:
+            index:
+              name: catalog.schema.my_index
+        ```
+
+        Full provisioning configuration:
+        ```yaml
+        vector_stores:
+          products_search:
+            source_table:
+              schema: *my_schema
+              name: products
+            embedding_source_column: description
+            endpoint:
+              name: my_endpoint
+        ```
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-
+
+    # RUNTIME: Only index is truly required for querying existing indexes
     index: Optional[IndexModel] = None
+
+    # PROVISIONING ONLY: Required when creating a new index
+    source_table: Optional[TableModel] = None
+    embedding_source_column: Optional[str] = None
+    embedding_model: Optional[LLMModel] = None
     endpoint: Optional[VectorSearchEndpoint] = None
-
+
+    # OPTIONAL: For both modes
     source_path: Optional[VolumePathModel] = None
     checkpoint_path: Optional[VolumePathModel] = None
     primary_key: Optional[str] = None
     columns: Optional[list[str]] = Field(default_factory=list)
     doc_uri: Optional[str] = None
-
+
+    @model_validator(mode="after")
+    def validate_configuration_mode(self) -> Self:
+        """
+        Validate that configuration is valid for either:
+        - Use existing mode: index is provided
+        - Provisioning mode: source_table + embedding_source_column provided
+        """
+        has_index = self.index is not None
+        has_source_table = self.source_table is not None
+        has_embedding_col = self.embedding_source_column is not None
+
+        # Must have at least index OR source_table
+        if not has_index and not has_source_table:
+            raise ValueError(
+                "Either 'index' (for existing indexes) or 'source_table' "
+                "(for provisioning) must be provided"
+            )
+
+        # If provisioning mode, need embedding_source_column
+        if has_source_table and not has_embedding_col:
+            raise ValueError(
+                "embedding_source_column is required when source_table is provided (provisioning mode)"
+            )
+
+        return self
 
     @model_validator(mode="after")
     def set_default_embedding_model(self) -> Self:
-        if not self.embedding_model:
+        # Only set default embedding model in provisioning mode
+        if self.source_table is not None and not self.embedding_model:
             self.embedding_model = LLMModel(name="databricks-gte-large-en")
         return self
 
     @model_validator(mode="after")
     def set_default_primary_key(self) -> Self:
-        if self.primary_key is None:
+        # Only auto-discover primary key in provisioning mode
+        if self.primary_key is None and self.source_table is not None:
             from dao_ai.providers.databricks import DatabricksProvider
 
             provider: DatabricksProvider = DatabricksProvider()
@@ -715,14 +1115,16 @@ class VectorStoreModel(IsDatabricksResource):
 
     @model_validator(mode="after")
     def set_default_index(self) -> Self:
-        if self.index is None:
+        # Only generate index from source_table in provisioning mode
+        if self.index is None and self.source_table is not None:
             name: str = f"{self.source_table.name}_index"
             self.index = IndexModel(schema=self.source_table.schema_model, name=name)
         return self
 
     @model_validator(mode="after")
     def set_default_endpoint(self) -> Self:
-        if self.endpoint is None:
+        # Only find/create endpoint in provisioning mode
+        if self.endpoint is None and self.source_table is not None:
             from dao_ai.providers.databricks import (
                 DatabricksProvider,
                 with_available_indexes,
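Taken together, the validators above mean a `VectorStoreModel` is accepted in exactly two shapes. A compressed sketch of the rule (field names from the diff, everything else illustrative):

```python
from typing import Optional

from pydantic import BaseModel, model_validator


class VectorStoreSketch(BaseModel):
    index: Optional[str] = None                    # existing-index mode
    source_table: Optional[str] = None             # provisioning mode
    embedding_source_column: Optional[str] = None

    @model_validator(mode="after")
    def check_mode(self) -> "VectorStoreSketch":
        if self.index is None and self.source_table is None:
            raise ValueError("need 'index' (existing) or 'source_table' (provisioning)")
        if self.source_table is not None and self.embedding_source_column is None:
            raise ValueError("provisioning mode requires embedding_source_column")
        return self


VectorStoreSketch(index="catalog.schema.my_index")        # ok: existing index
VectorStoreSketch(source_table="catalog.schema.products",
                  embedding_source_column="description")  # ok: provisioning
# VectorStoreSketch()  # would raise: neither mode satisfied
```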
@@ -772,61 +1174,6 @@ class VectorStoreModel(IsDatabricksResource):
         provider.create_vector_store(self)
 
 
-class FunctionModel(IsDatabricksResource, HasFullName):
-    model_config = ConfigDict()
-    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
-    name: Optional[str] = None
-
-    @model_validator(mode="after")
-    def validate_name_or_schema_required(self) -> "FunctionModel":
-        if not self.name and not self.schema_model:
-            raise ValueError(
-                "Either 'name' or 'schema_model' must be provided for FunctionModel"
-            )
-        return self
-
-    @property
-    def full_name(self) -> str:
-        if self.schema_model:
-            name: str = ""
-            if self.name:
-                name = f".{self.name}"
-            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
-        return self.name
-
-    def as_resources(self) -> Sequence[DatabricksResource]:
-        resources: list[DatabricksResource] = []
-        if self.name:
-            resources.append(
-                DatabricksFunction(
-                    function_name=self.full_name,
-                    on_behalf_of_user=self.on_behalf_of_user,
-                )
-            )
-        else:
-            w: WorkspaceClient = self.workspace_client
-            schema_full_name: str = self.schema_model.full_name
-            functions: Iterator[FunctionInfo] = w.functions.list(
-                catalog_name=self.schema_model.catalog_name,
-                schema_name=self.schema_model.schema_name,
-            )
-            resources.extend(
-                [
-                    DatabricksFunction(
-                        function_name=f"{schema_full_name}.{function.name}",
-                        on_behalf_of_user=self.on_behalf_of_user,
-                    )
-                    for function in functions
-                ]
-            )
-
-        return resources
-
-    @property
-    def api_scopes(self) -> Sequence[str]:
-        return ["sql.statement-execution"]
-
-
 class ConnectionModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict()
     name: str
@@ -854,55 +1201,25 @@ class ConnectionModel(IsDatabricksResource, HasFullName):
     ]
 
 
-class WarehouseModel(IsDatabricksResource):
-    model_config = ConfigDict()
-    name: str
-    description: Optional[str] = None
-    warehouse_id: AnyVariable
-
-    @property
-    def api_scopes(self) -> Sequence[str]:
-        return [
-            "sql.warehouses",
-            "sql.statement-execution",
-        ]
-
-    def as_resources(self) -> Sequence[DatabricksResource]:
-        return [
-            DatabricksSQLWarehouse(
-                warehouse_id=value_of(self.warehouse_id),
-                on_behalf_of_user=self.on_behalf_of_user,
-            )
-        ]
-
-    @model_validator(mode="after")
-    def update_warehouse_id(self) -> Self:
-        self.warehouse_id = value_of(self.warehouse_id)
-        return self
-
-
-class DatabaseType(str, Enum):
-    POSTGRES = "postgres"
-    LAKEBASE = "lakebase"
-
-
 class DatabaseModel(IsDatabricksResource):
     """
-    Configuration for
+    Configuration for database connections supporting both Databricks Lakebase and standard PostgreSQL.
 
     Authentication is inherited from IsDatabricksResource. Additionally supports:
     - user/password: For user-based database authentication
 
-
-    -
-    -
+    Connection Types (determined by fields provided):
+    - Databricks Lakebase: Provide `instance_name` (authentication optional, supports ambient auth)
+    - Standard PostgreSQL: Provide `host` (authentication required via user/password)
 
-
+    Note: `instance_name` and `host` are mutually exclusive. Provide one or the other.
+
+    Example Databricks Lakebase with Service Principal:
     ```yaml
     databases:
       my_lakebase:
         name: my-database
-
+        instance_name: my-lakebase-instance
         service_principal:
           client_id:
             env: SERVICE_PRINCIPAL_CLIENT_ID
@@ -913,28 +1230,31 @@ class DatabaseModel(IsDatabricksResource):
             env: DATABRICKS_HOST
     ```
 
-    Example
+    Example Databricks Lakebase with Ambient Authentication:
     ```yaml
     databases:
       my_lakebase:
         name: my-database
-
-
+        instance_name: my-lakebase-instance
+        on_behalf_of_user: true
     ```
 
-    Example
+    Example Standard PostgreSQL:
     ```yaml
     databases:
-
+      my_postgres:
         name: my-database
-
-
+        host: my-postgres-host.example.com
+        port: 5432
+        database: my_db
+        user: my_user
+        password:
+          env: PGPASSWORD
     ```
     """
 
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
-    type: Optional[DatabaseType] = DatabaseType.LAKEBASE
     instance_name: Optional[str] = None
     description: Optional[str] = None
     host: Optional[AnyVariable] = None
@@ -949,27 +1269,38 @@ class DatabaseModel(IsDatabricksResource):
     user: Optional[AnyVariable] = None
     password: Optional[AnyVariable] = None
 
-    @field_serializer("type")
-    def serialize_type(self, value: DatabaseType | None) -> str | None:
-        """Serialize the database type enum to its string value."""
-        return value.value if value is not None else None
-
     @property
     def api_scopes(self) -> Sequence[str]:
         return ["database.database-instances"]
 
+    @property
+    def is_lakebase(self) -> bool:
+        """Returns True if this is a Databricks Lakebase connection (instance_name provided)."""
+        return self.instance_name is not None
+
     def as_resources(self) -> Sequence[DatabricksResource]:
-
-
-
-
-
-
+        if self.is_lakebase:
+            return [
+                DatabricksLakebase(
+                    database_instance_name=self.instance_name,
+                    on_behalf_of_user=self.on_behalf_of_user,
+                )
+            ]
+        return []
 
     @model_validator(mode="after")
-    def
-
-
+    def validate_connection_type(self) -> Self:
+        """Validate connection configuration based on type.
+
+        - If instance_name is provided: Databricks Lakebase connection
+          (host is optional - will be fetched from API if not provided)
+        - If only host is provided: Standard PostgreSQL connection
+          (must not have instance_name)
+        """
+        if not self.instance_name and not self.host:
+            raise ValueError(
+                "Either instance_name (Databricks Lakebase) or host (PostgreSQL) must be provided."
+            )
         return self
 
     @model_validator(mode="after")
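The `is_lakebase` property plus `validate_connection_type` encode a simple dispatch: `instance_name` means Lakebase, `host` means plain PostgreSQL, and at least one must be set. A stripped-down sketch of the rule:

```python
from typing import Optional

from pydantic import BaseModel, model_validator


class DatabaseSketch(BaseModel):
    name: str
    instance_name: Optional[str] = None  # Databricks Lakebase
    host: Optional[str] = None           # standard PostgreSQL

    @property
    def is_lakebase(self) -> bool:
        return self.instance_name is not None

    @model_validator(mode="after")
    def check(self) -> "DatabaseSketch":
        if not self.instance_name and not self.host:
            raise ValueError("provide instance_name (Lakebase) or host (PostgreSQL)")
        return self


assert DatabaseSketch(name="a", instance_name="lb").is_lakebase
assert not DatabaseSketch(name="b", host="pg.example.com").is_lakebase
# DatabaseSketch(name="c")  # would raise: neither connection type configured
```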
@@ -978,10 +1309,10 @@ class DatabaseModel(IsDatabricksResource):
         if self.on_behalf_of_user or self.client_id or self.user or self.pat:
             return self
 
-        # For
-        # For
-        if self.
-            #
+        # For standard PostgreSQL, we need explicit user credentials
+        # For Lakebase with no auth, ambient auth is allowed
+        if not self.is_lakebase:
+            # Standard PostgreSQL - try to determine current user for local development
             try:
                 self.user = self.workspace_client.current_user.me().user_name
             except Exception as e:
@@ -990,12 +1321,12 @@ class DatabaseModel(IsDatabricksResource):
                     f"Please provide explicit user credentials."
                 )
         else:
-            # For
+            # For Lakebase, try to determine current user but don't fail if we can't
             try:
                 self.user = self.workspace_client.current_user.me().user_name
             except Exception:
                 # If we can't determine user and no explicit auth, that's okay
-                # for
+                # for Lakebase with ambient auth - credentials will be injected at runtime
                 pass
 
         return self
@@ -1005,28 +1336,29 @@ class DatabaseModel(IsDatabricksResource):
         if self.host is not None:
             return self
 
-        #
+        # If instance_name is provided (Lakebase), try to fetch host from existing instance
         # This may fail for OBO/ambient auth during model logging (before deployment)
-
-
-
-
-
-
-            self.host = existing_instance.read_write_dns
-        except Exception as e:
-            # For lakebase with OBO/ambient auth, we can't fetch at config time
-            # The host will need to be provided explicitly or fetched at runtime
-            if self.type == DatabaseType.LAKEBASE and self.on_behalf_of_user:
-                logger.debug(
-                    f"Could not fetch host for database {self.instance_name} "
-                    f"(Lakebase with OBO mode - will be resolved at runtime): {e}"
-                )
-            else:
-                raise ValueError(
-                    f"Could not fetch host for database {self.instance_name}. "
-                    f"Please provide the 'host' explicitly or ensure the instance exists: {e}"
+        if self.is_lakebase:
+            try:
+                existing_instance: DatabaseInstance = (
+                    self.workspace_client.database.get_database_instance(
+                        name=self.instance_name
+                    )
                 )
+                self.host = existing_instance.read_write_dns
+            except Exception as e:
+                # For Lakebase with OBO/ambient auth, we can't fetch at config time
+                # The host will need to be provided explicitly or fetched at runtime
+                if self.on_behalf_of_user:
+                    logger.debug(
+                        f"Could not fetch host for database {self.instance_name} "
+                        f"(Lakebase with OBO mode - will be resolved at runtime): {e}"
+                    )
+                else:
+                    raise ValueError(
+                        f"Could not fetch host for database {self.instance_name}. "
+                        f"Please provide the 'host' explicitly or ensure the instance exists: {e}"
+                    )
         return self
 
     @model_validator(mode="after")
@@ -1054,9 +1386,9 @@ class DatabaseModel(IsDatabricksResource):
                 "or user credentials (user)."
             )
 
-        # For
-        # For
-        if self.
+        # For standard PostgreSQL (host-based), at least one auth method must be configured
+        # For Lakebase (instance_name-based), auth is optional (supports ambient authentication)
+        if not self.is_lakebase and auth_methods_count == 0:
             raise ValueError(
                 "PostgreSQL databases require explicit authentication. "
                 "Please provide one of: "
@@ -1074,24 +1406,23 @@ class DatabaseModel(IsDatabricksResource):
         Get database connection parameters as a dictionary.
 
         Returns a dict with connection parameters suitable for psycopg ConnectionPool.
-
-
+
+        For Lakebase: Uses Databricks-generated credentials (token-based auth).
+        For standard PostgreSQL: Uses provided user/password credentials.
         """
         import uuid as _uuid
 
         from databricks.sdk.service.database import DatabaseCredential
 
+        host: str
+        port: int
+        database: str
         username: str | None = None
-
-        if self.client_id and self.client_secret and self.workspace_host:
-            username = value_of(self.client_id)
-        elif self.user:
-            username = value_of(self.user)
-        # For OBO mode, no username is needed - the token identity is used
+        password_value: str | None = None
 
         # Resolve host - may need to fetch at runtime for OBO mode
         host_value: Any = self.host
-        if host_value is None and self.on_behalf_of_user:
+        if host_value is None and self.is_lakebase and self.on_behalf_of_user:
            # Fetch host at runtime for OBO mode
            existing_instance: DatabaseInstance = (
                self.workspace_client.database.get_database_instance(
@@ -1101,29 +1432,50 @@ class DatabaseModel(IsDatabricksResource):
             host_value = existing_instance.read_write_dns
 
         if host_value is None:
+            instance_or_name = self.instance_name if self.is_lakebase else self.name
             raise ValueError(
-                f"Database host not configured for {
+                f"Database host not configured for {instance_or_name}. "
                 "Please provide 'host' explicitly."
             )
 
-        host
-        port
-        database
+        host = value_of(host_value)
+        port = value_of(self.port)
+        database = value_of(self.database)
 
-
-
-
-
-
-
-
+        if self.is_lakebase:
+            # Lakebase: Use Databricks-generated credentials
+            if self.client_id and self.client_secret and self.workspace_host:
+                username = value_of(self.client_id)
+            elif self.user:
+                username = value_of(self.user)
+            # For OBO mode, no username is needed - the token identity is used
+
+            # Generate Databricks database credential (token)
+            w: WorkspaceClient = self.workspace_client
+            cred: DatabaseCredential = w.database.generate_database_credential(
+                request_id=str(_uuid.uuid4()),
+                instance_names=[self.instance_name],
+            )
+            password_value = cred.token
+        else:
+            # Standard PostgreSQL: Use provided credentials
+            if self.user:
+                username = value_of(self.user)
+            if self.password:
+                password_value = value_of(self.password)
+
+            if not username or not password_value:
+                raise ValueError(
+                    f"Standard PostgreSQL databases require both 'user' and 'password'. "
+                    f"Database: {self.name}"
+                )
 
         # Build connection parameters dictionary
         params: dict[str, Any] = {
             "dbname": database,
             "host": host,
             "port": port,
-            "password":
+            "password": password_value,
             "sslmode": "require",
         }
 
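The dict built by `connection_params` is shaped for psycopg3's connection pooling. A hedged usage sketch (placeholder values; for Lakebase the `password` entry would carry a generated Databricks token as shown above):

```python
from psycopg_pool import ConnectionPool

params = {
    "dbname": "my_db",
    "host": "my-postgres-host.example.com",
    "port": 5432,
    "user": "my_user",
    "password": "...",  # secret or generated Databricks credential
    "sslmode": "require",
}

# psycopg_pool forwards `kwargs` to psycopg.connect() for each new connection.
pool = ConnectionPool(min_size=1, max_size=4, kwargs=params, open=False)
pool.open()
with pool.connection() as conn:
    conn.execute("SELECT 1")  # smoke test: one round trip per checkout
pool.close()
```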
@@ -1264,11 +1616,13 @@ class RerankParametersModel(BaseModel):
             top_n: 5  # Return top 5 after reranking
         ```
 
-    Available models (
-    - "ms-marco-TinyBERT-L-2-v2" (
-    - "ms-marco-MiniLM-L-
-    - "
-    - "
+    Available models (see https://github.com/PrithivirajDamodaran/FlashRank):
+    - "ms-marco-TinyBERT-L-2-v2" (~4MB, fastest)
+    - "ms-marco-MiniLM-L-12-v2" (~34MB, best cross-encoder, default)
+    - "rank-T5-flan" (~110MB, best non cross-encoder)
+    - "ms-marco-MultiBERT-L-12" (~150MB, multilingual 100+ languages)
+    - "ce-esci-MiniLM-L12-v2" (e-commerce optimized, Amazon ESCI)
+    - "miniReranker_arabic_v1" (Arabic language)
     """
 
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -1282,8 +1636,8 @@ class RerankParametersModel(BaseModel):
         description="Number of documents to return after reranking. If None, uses search_parameters.num_results.",
     )
     cache_dir: Optional[str] = Field(
-        default="/
-        description="Directory to cache downloaded model weights.",
+        default="~/.dao_ai/cache/flashrank",
+        description="Directory to cache downloaded model weights. Supports tilde expansion (e.g., ~/.dao_ai).",
     )
     columns: Optional[list[str]] = Field(
         default_factory=list, description="Columns to rerank using DatabricksReranker"
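The model list and the new cache default correspond to the FlashRank reranker. A hedged example of how those knobs are typically wired up (API as documented in the FlashRank README; the query and passages are invented):

```python
import os

from flashrank import Ranker, RerankRequest

ranker = Ranker(
    model_name="ms-marco-MiniLM-L-12-v2",  # the default cross-encoder listed above
    # dao-ai's cache_dir supports "~", so expand before handing it to FlashRank
    cache_dir=os.path.expanduser("~/.dao_ai/cache/flashrank"),
)

request = RerankRequest(
    query="waterproof hiking boots",
    passages=[
        {"id": 1, "text": "Lightweight trail runners with breathable mesh."},
        {"id": 2, "text": "Gore-Tex hiking boots rated for wet terrain."},
    ],
)

for hit in ranker.rerank(request):  # passages come back scored, best first
    print(hit["id"], hit["score"])
```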
@@ -1373,7 +1727,6 @@ class BaseFunctionModel(ABC, BaseModel):
         discriminator="type",
     )
     type: FunctionType
-    name: str
     human_in_the_loop: Optional[HumanInTheLoopModel] = None
 
     @abstractmethod
@@ -1390,6 +1743,7 @@ class BaseFunctionModel(ABC, BaseModel):
 class PythonFunctionModel(BaseFunctionModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     type: Literal[FunctionType.PYTHON] = FunctionType.PYTHON
+    name: str
 
     @property
     def full_name(self) -> str:
@@ -1403,8 +1757,9 @@ class PythonFunctionModel(BaseFunctionModel, HasFullName):
 
 class FactoryFunctionModel(BaseFunctionModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    args: Optional[dict[str, Any]] = Field(default_factory=dict)
     type: Literal[FunctionType.FACTORY] = FunctionType.FACTORY
+    name: str
+    args: Optional[dict[str, Any]] = Field(default_factory=dict)
 
     @property
     def full_name(self) -> str:
@@ -1427,7 +1782,7 @@ class TransportType(str, Enum):
     STDIO = "stdio"
 
 
-class McpFunctionModel(BaseFunctionModel, IsDatabricksResource, HasFullName):
+class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
     """
     MCP Function Model with authentication inherited from IsDatabricksResource.
 
@@ -1466,10 +1821,6 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource, HasFullName):
         """MCP functions don't declare static resources."""
         return []
 
-    @property
-    def full_name(self) -> str:
-        return self.name
-
     def _get_workspace_host(self) -> str:
         """
         Get the workspace host, either from config or from workspace client.
@@ -1635,17 +1986,11 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource, HasFullName):
         return create_mcp_tools(self)
 
 
-class UnityCatalogFunctionModel(BaseFunctionModel
+class UnityCatalogFunctionModel(BaseFunctionModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
-    partial_args: Optional[dict[str, AnyVariable]] = Field(default_factory=dict)
     type: Literal[FunctionType.UNITY_CATALOG] = FunctionType.UNITY_CATALOG
-
-
-    def full_name(self) -> str:
-        if self.schema_model:
-            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}.{self.name}"
-        return self.name
+    resource: FunctionModel
+    partial_args: Optional[dict[str, AnyVariable]] = Field(default_factory=dict)
 
     def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
         from dao_ai.tools import create_uc_tools
@@ -1679,6 +2024,12 @@ class PromptModel(BaseModel, HasFullName):
     alias: Optional[str] = None
     version: Optional[int] = None
     tags: Optional[dict[str, Any]] = Field(default_factory=dict)
+    auto_register: bool = Field(
+        default=False,
+        description="Whether to automatically register the default_template to the prompt registry. "
+        "If False, the prompt will only be loaded from the registry (never created/updated). "
+        "Defaults to True for backward compatibility.",
+    )
 
     @property
     def template(self) -> str:
@@ -1760,23 +2111,17 @@ class MiddlewareModel(BaseModel):
 class StorageType(str, Enum):
     POSTGRES = "postgres"
     MEMORY = "memory"
-    LAKEBASE = "lakebase"
 
 
 class CheckpointerModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
-    type: Optional[StorageType] = StorageType.MEMORY
     database: Optional[DatabaseModel] = None
 
-    @
-    def
-
-
-            and not self.database
-        ):
-            raise ValueError("Database must be provided when storage type is POSTGRES")
-        return self
+    @property
+    def storage_type(self) -> StorageType:
+        """Infer storage type from database presence."""
+        return StorageType.POSTGRES if self.database else StorageType.MEMORY
 
     def as_checkpointer(self) -> BaseCheckpointSaver:
         from dao_ai.memory import CheckpointManager
@@ -1792,16 +2137,14 @@ class StoreModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     embedding_model: Optional[LLMModel] = None
-    type: Optional[StorageType] = StorageType.MEMORY
     dims: Optional[int] = 1536
     database: Optional[DatabaseModel] = None
     namespace: Optional[str] = None
 
-    @
-    def
-
-
-        return self
+    @property
+    def storage_type(self) -> StorageType:
+        """Infer storage type from database presence."""
+        return StorageType.POSTGRES if self.database else StorageType.MEMORY
 
     def as_store(self) -> BaseStore:
         from dao_ai.memory import StoreManager
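Both models above drop the explicit `type` field (and the `LAKEBASE` enum value) in favor of inferring storage from whether a database is configured. The rule in isolation:

```python
from enum import Enum
from typing import Optional

from pydantic import BaseModel


class StorageType(str, Enum):
    POSTGRES = "postgres"
    MEMORY = "memory"


class CheckpointerSketch(BaseModel):
    name: str
    database: Optional[dict] = None  # stand-in for DatabaseModel

    @property
    def storage_type(self) -> StorageType:
        # No database configured -> in-memory; otherwise Postgres-backed.
        return StorageType.POSTGRES if self.database else StorageType.MEMORY


assert CheckpointerSketch(name="cp").storage_type is StorageType.MEMORY
assert CheckpointerSketch(name="cp", database={"host": "pg"}).storage_type is StorageType.POSTGRES
```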
@@ -2044,6 +2387,10 @@ class SwarmModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     model: LLMModel
     default_agent: Optional[AgentModel | str] = None
+    middleware: list[MiddlewareModel] = Field(
+        default_factory=list,
+        description="List of middleware to apply to all agents in the swarm",
+    )
     handoffs: Optional[dict[str, Optional[list[AgentModel | str]]]] = Field(
         default_factory=dict
     )
@@ -2606,7 +2953,7 @@ class UnityCatalogFunctionSqlTestModel(BaseModel):
 
 class UnityCatalogFunctionSqlModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    function:
+    function: FunctionModel
     ddl: str
     parameters: Optional[dict[str, Any]] = Field(default_factory=dict)
     test: Optional[UnityCatalogFunctionSqlTestModel] = None
@@ -2636,6 +2983,113 @@ class ResourcesModel(BaseModel):
     connections: dict[str, ConnectionModel] = Field(default_factory=dict)
     apps: dict[str, DatabricksAppModel] = Field(default_factory=dict)
 
+    @model_validator(mode="after")
+    def update_genie_warehouses(self) -> Self:
+        """
+        Automatically populate warehouses from genie_rooms.
+
+        Warehouses are extracted from each Genie room and added to the
+        resources if they don't already exist (based on warehouse_id).
+        """
+        if not self.genie_rooms:
+            return self
+
+        # Process warehouses from all genie rooms
+        for genie_room in self.genie_rooms.values():
+            genie_room: GenieRoomModel
+            warehouse: Optional[WarehouseModel] = genie_room.warehouse
+
+            if warehouse is None:
+                continue
+
+            # Check if warehouse already exists based on warehouse_id
+            warehouse_exists: bool = any(
+                existing_warehouse.warehouse_id == warehouse.warehouse_id
+                for existing_warehouse in self.warehouses.values()
+            )
+
+            if not warehouse_exists:
+                warehouse_key: str = normalize_name(
+                    "_".join([genie_room.name, warehouse.warehouse_id])
+                )
+                self.warehouses[warehouse_key] = warehouse
+                logger.trace(
+                    "Added warehouse from Genie room",
+                    room=genie_room.name,
+                    warehouse=warehouse.warehouse_id,
+                    key=warehouse_key,
+                )
+
+        return self
+
+    @model_validator(mode="after")
+    def update_genie_tables(self) -> Self:
+        """
+        Automatically populate tables from genie_rooms.
+
+        Tables are extracted from each Genie room and added to the
+        resources if they don't already exist (based on full_name).
+        """
+        if not self.genie_rooms:
+            return self
+
+        # Process tables from all genie rooms
+        for genie_room in self.genie_rooms.values():
+            genie_room: GenieRoomModel
+            for table in genie_room.tables:
+                table: TableModel
+                table_exists: bool = any(
+                    existing_table.full_name == table.full_name
+                    for existing_table in self.tables.values()
+                )
+                if not table_exists:
+                    table_key: str = normalize_name(
+                        "_".join([genie_room.name, table.full_name])
+                    )
+                    self.tables[table_key] = table
+                    logger.trace(
+                        "Added table from Genie room",
+                        room=genie_room.name,
+                        table=table.name,
+                        key=table_key,
+                    )
+
+        return self
+
+    @model_validator(mode="after")
+    def update_genie_functions(self) -> Self:
+        """
+        Automatically populate functions from genie_rooms.
+
+        Functions are extracted from each Genie room and added to the
+        resources if they don't already exist (based on full_name).
+        """
+        if not self.genie_rooms:
+            return self
+
+        # Process functions from all genie rooms
+        for genie_room in self.genie_rooms.values():
+            genie_room: GenieRoomModel
+            for function in genie_room.functions:
+                function: FunctionModel
+                function_exists: bool = any(
+                    existing_function.full_name == function.full_name
+                    for existing_function in self.functions.values()
+                )
+                if not function_exists:
+                    function_key: str = normalize_name(
+                        "_".join([genie_room.name, function.full_name])
+                    )
+                    self.functions[function_key] = function
+                    logger.trace(
+                        "Added function from Genie room",
+                        room=genie_room.name,
+                        function=function.name,
+                        key=function_key,
+                    )
+
+        return self
+
 
 class AppConfig(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
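All three validators share the same merge shape: derive resources from each Genie room, skip ones already registered (matching on `warehouse_id` or `full_name`), and key new entries with `normalize_name(room + identity)`. A sketch of that dedup-and-merge (the `normalize_name` behavior shown is an assumption about `dao_ai.utils.normalize_name`):

```python
import re


def normalize_name(name: str) -> str:
    # Assumed behavior: lowercase and collapse unsafe characters to underscores.
    return re.sub(r"[^a-z0-9_]+", "_", name.lower()).strip("_")


def merge(existing: dict[str, str], room: str, discovered: list[str]) -> None:
    for full_name in discovered:
        if any(v == full_name for v in existing.values()):
            continue  # already registered under some key, matching on identity
        existing[normalize_name(f"{room}_{full_name}")] = full_name


resources: dict[str, str] = {"orders": "main.sales.orders"}
merge(resources, "Sales Genie", ["main.sales.orders", "main.sales.customers"])
print(resources)
# {'orders': 'main.sales.orders',
#  'sales_genie_main_sales_customers': 'main.sales.customers'}
```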
@@ -2674,10 +3128,10 @@ class AppConfig(BaseModel):
 
     def initialize(self) -> None:
         from dao_ai.hooks.core import create_hooks
+        from dao_ai.logging import configure_logging
 
         if self.app and self.app.log_level:
-
-            logger.add(sys.stderr, level=self.app.log_level)
+            configure_logging(level=self.app.log_level)
 
         logger.debug("Calling initialization hooks...")
         initialization_functions: Sequence[Callable[..., Any]] = create_hooks(
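`initialize()` now delegates log setup to the new `dao_ai/logging.py` module instead of calling `logger.add` inline. A sketch of what such a loguru helper commonly looks like (the real `configure_logging` may differ):

```python
import sys

from loguru import logger


def configure_logging(level: str = "INFO") -> None:
    logger.remove()  # drop previously registered sinks to avoid duplicates
    logger.add(sys.stderr, level=level.upper())


configure_logging(level="debug")
logger.debug("logging configured once, centrally")
```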
|