dao-ai 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. dao_ai/agent_as_code.py +2 -5
  2. dao_ai/cli.py +65 -15
  3. dao_ai/config.py +672 -218
  4. dao_ai/genie/cache/core.py +6 -2
  5. dao_ai/genie/cache/lru.py +29 -11
  6. dao_ai/genie/cache/semantic.py +95 -44
  7. dao_ai/hooks/core.py +5 -5
  8. dao_ai/logging.py +56 -0
  9. dao_ai/memory/core.py +61 -44
  10. dao_ai/memory/databricks.py +54 -41
  11. dao_ai/memory/postgres.py +77 -36
  12. dao_ai/middleware/assertions.py +45 -17
  13. dao_ai/middleware/core.py +13 -7
  14. dao_ai/middleware/guardrails.py +30 -25
  15. dao_ai/middleware/human_in_the_loop.py +9 -5
  16. dao_ai/middleware/message_validation.py +61 -29
  17. dao_ai/middleware/summarization.py +16 -11
  18. dao_ai/models.py +172 -69
  19. dao_ai/nodes.py +148 -19
  20. dao_ai/optimization.py +26 -16
  21. dao_ai/orchestration/core.py +15 -8
  22. dao_ai/orchestration/supervisor.py +22 -8
  23. dao_ai/orchestration/swarm.py +57 -12
  24. dao_ai/prompts.py +17 -17
  25. dao_ai/providers/databricks.py +365 -155
  26. dao_ai/state.py +24 -6
  27. dao_ai/tools/__init__.py +2 -0
  28. dao_ai/tools/agent.py +1 -3
  29. dao_ai/tools/core.py +7 -7
  30. dao_ai/tools/email.py +29 -77
  31. dao_ai/tools/genie.py +18 -13
  32. dao_ai/tools/mcp.py +223 -156
  33. dao_ai/tools/python.py +5 -2
  34. dao_ai/tools/search.py +1 -1
  35. dao_ai/tools/slack.py +21 -9
  36. dao_ai/tools/sql.py +202 -0
  37. dao_ai/tools/time.py +30 -7
  38. dao_ai/tools/unity_catalog.py +129 -86
  39. dao_ai/tools/vector_search.py +318 -244
  40. dao_ai/utils.py +15 -10
  41. dao_ai-0.1.3.dist-info/METADATA +455 -0
  42. dao_ai-0.1.3.dist-info/RECORD +64 -0
  43. dao_ai-0.1.1.dist-info/METADATA +0 -1878
  44. dao_ai-0.1.1.dist-info/RECORD +0 -62
  45. {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/WHEEL +0 -0
  46. {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/entry_points.txt +0 -0
  47. {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py CHANGED
@@ -23,8 +23,11 @@ from databricks.sdk.credentials_provider import (
  CredentialsStrategy,
  ModelServingUserCredentials,
  )
+ from databricks.sdk.errors.platform import NotFound
  from databricks.sdk.service.catalog import FunctionInfo, TableInfo
+ from databricks.sdk.service.dashboards import GenieSpace
  from databricks.sdk.service.database import DatabaseInstance
+ from databricks.sdk.service.sql import GetWarehouseResponse
  from databricks.vector_search.client import VectorSearchClient
  from databricks.vector_search.index import VectorSearchIndex
  from databricks_langchain import (
@@ -65,10 +68,13 @@ from pydantic import (
  BaseModel,
  ConfigDict,
  Field,
+ PrivateAttr,
  field_serializer,
  model_validator,
  )

+ from dao_ai.utils import normalize_name
+

  class HasValue(ABC):
  @abstractmethod
@@ -228,6 +234,9 @@ class IsDatabricksResource(ABC, BaseModel):
  workspace_host: Optional[AnyVariable] = None
  pat: Optional[AnyVariable] = None

+ # Private attribute to cache the workspace client (lazy instantiation)
+ _workspace_client: Optional[WorkspaceClient] = PrivateAttr(default=None)
+
  @abstractmethod
  def as_resources(self) -> Sequence[DatabricksResource]: ...

@@ -263,6 +272,8 @@ class IsDatabricksResource(ABC, BaseModel):
  """
  Get a WorkspaceClient configured with the appropriate authentication.

+ The client is lazily instantiated on first access and cached for subsequent calls.
+
  Authentication priority:
  1. If on_behalf_of_user is True, uses ModelServingUserCredentials (OBO)
  2. If service principal credentials are configured (client_id, client_secret,
@@ -270,6 +281,10 @@ class IsDatabricksResource(ABC, BaseModel):
  3. If PAT is configured, uses token authentication
  4. Otherwise, uses default/ambient authentication
  """
+ # Return cached client if already instantiated
+ if self._workspace_client is not None:
+ return self._workspace_client
+
  from dao_ai.utils import normalize_host

  # Check for OBO first (highest priority)
@@ -279,7 +294,10 @@ class IsDatabricksResource(ABC, BaseModel):
  f"Creating WorkspaceClient for {self.__class__.__name__} "
  f"with OBO credentials strategy"
  )
- return WorkspaceClient(credentials_strategy=credentials_strategy)
+ self._workspace_client = WorkspaceClient(
+ credentials_strategy=credentials_strategy
+ )
+ return self._workspace_client

  # Check for service principal credentials
  client_id_value: str | None = (
@@ -299,12 +317,13 @@ class IsDatabricksResource(ABC, BaseModel):
  f"Creating WorkspaceClient for {self.__class__.__name__} with service principal: "
  f"client_id={client_id_value}, host={workspace_host_value}"
  )
- return WorkspaceClient(
+ self._workspace_client = WorkspaceClient(
  host=workspace_host_value,
  client_id=client_id_value,
  client_secret=client_secret_value,
  auth_type="oauth-m2m",
  )
+ return self._workspace_client

  # Check for PAT authentication
  pat_value: str | None = value_of(self.pat) if self.pat else None
@@ -312,18 +331,20 @@ class IsDatabricksResource(ABC, BaseModel):
  logger.debug(
  f"Creating WorkspaceClient for {self.__class__.__name__} with PAT"
  )
- return WorkspaceClient(
+ self._workspace_client = WorkspaceClient(
  host=workspace_host_value,
  token=pat_value,
  auth_type="pat",
  )
+ return self._workspace_client

  # Default: use ambient authentication
  logger.debug(
  f"Creating WorkspaceClient for {self.__class__.__name__} "
  "with default/ambient authentication"
  )
- return WorkspaceClient()
+ self._workspace_client = WorkspaceClient()
+ return self._workspace_client


  class Privilege(str, Enum):
@@ -431,6 +452,22 @@ class TableModel(IsDatabricksResource, HasFullName):
  def api_scopes(self) -> Sequence[str]:
  return []

+ def exists(self) -> bool:
+ """Check if the table exists in Unity Catalog.
+
+ Returns:
+ True if the table exists, False otherwise.
+ """
+ try:
+ self.workspace_client.tables.get(full_name=self.full_name)
+ return True
+ except NotFound:
+ logger.debug(f"Table not found: {self.full_name}")
+ return False
+ except Exception as e:
+ logger.warning(f"Error checking table existence for {self.full_name}: {e}")
+ return False
+
  def as_resources(self) -> Sequence[DatabricksResource]:
  resources: list[DatabricksResource] = []

@@ -477,6 +514,7 @@
  class LLMModel(IsDatabricksResource):
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
  name: str
+ description: Optional[str] = None
  temperature: Optional[float] = 0.1
  max_tokens: Optional[int] = 8192
  fallbacks: Optional[list[Union[str, "LLMModel"]]] = Field(default_factory=list)
@@ -587,12 +625,297 @@ class IndexModel(IsDatabricksResource, HasFullName):
  ]


+ class FunctionModel(IsDatabricksResource, HasFullName):
+ model_config = ConfigDict(use_enum_values=True, extra="forbid")
+ schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
+ name: Optional[str] = None
+
+ @model_validator(mode="after")
+ def validate_name_or_schema_required(self) -> Self:
+ if not self.name and not self.schema_model:
+ raise ValueError(
+ "Either 'name' or 'schema_model' must be provided for FunctionModel"
+ )
+ return self
+
+ @property
+ def full_name(self) -> str:
+ if self.schema_model:
+ name: str = ""
+ if self.name:
+ name = f".{self.name}"
+ return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
+ return self.name
+
+ def exists(self) -> bool:
+ """Check if the function exists in Unity Catalog.
+
+ Returns:
+ True if the function exists, False otherwise.
+ """
+ try:
+ self.workspace_client.functions.get(name=self.full_name)
+ return True
+ except NotFound:
+ logger.debug(f"Function not found: {self.full_name}")
+ return False
+ except Exception as e:
+ logger.warning(
+ f"Error checking function existence for {self.full_name}: {e}"
+ )
+ return False
+
+ def as_resources(self) -> Sequence[DatabricksResource]:
+ resources: list[DatabricksResource] = []
+ if self.name:
+ resources.append(
+ DatabricksFunction(
+ function_name=self.full_name,
+ on_behalf_of_user=self.on_behalf_of_user,
+ )
+ )
+ else:
+ w: WorkspaceClient = self.workspace_client
+ schema_full_name: str = self.schema_model.full_name
+ functions: Iterator[FunctionInfo] = w.functions.list(
+ catalog_name=self.schema_model.catalog_name,
+ schema_name=self.schema_model.schema_name,
+ )
+ resources.extend(
+ [
+ DatabricksFunction(
+ function_name=f"{schema_full_name}.{function.name}",
+ on_behalf_of_user=self.on_behalf_of_user,
+ )
+ for function in functions
+ ]
+ )
+
+ return resources
+
+ @property
+ def api_scopes(self) -> Sequence[str]:
+ return ["sql.statement-execution"]
+
+
+ class WarehouseModel(IsDatabricksResource):
+ model_config = ConfigDict()
+ name: str
+ description: Optional[str] = None
+ warehouse_id: AnyVariable
+
+ @property
+ def api_scopes(self) -> Sequence[str]:
+ return [
+ "sql.warehouses",
+ "sql.statement-execution",
+ ]
+
+ def as_resources(self) -> Sequence[DatabricksResource]:
+ return [
+ DatabricksSQLWarehouse(
+ warehouse_id=value_of(self.warehouse_id),
+ on_behalf_of_user=self.on_behalf_of_user,
+ )
+ ]
+
+ @model_validator(mode="after")
+ def update_warehouse_id(self) -> Self:
+ self.warehouse_id = value_of(self.warehouse_id)
+ return self
+
+
  class GenieRoomModel(IsDatabricksResource):
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
  name: str
  description: Optional[str] = None
  space_id: AnyVariable

+ _space_details: Optional[GenieSpace] = PrivateAttr(default=None)
+
+ def _get_space_details(self) -> GenieSpace:
+ if self._space_details is None:
+ self._space_details = self.workspace_client.genie.get_space(
+ space_id=self.space_id, include_serialized_space=True
+ )
+ return self._space_details
+
+ def _parse_serialized_space(self) -> dict[str, Any]:
+ """Parse the serialized_space JSON string and return the parsed data."""
+ import json
+
+ space_details = self._get_space_details()
+ if not space_details.serialized_space:
+ return {}
+
+ try:
+ return json.loads(space_details.serialized_space)
+ except json.JSONDecodeError as e:
+ logger.warning(f"Failed to parse serialized_space: {e}")
+ return {}
+
+ @property
+ def warehouse(self) -> Optional[WarehouseModel]:
+ """Extract warehouse information from the Genie space.
+
+ Returns:
+ WarehouseModel instance if warehouse_id is available, None otherwise.
+ """
+ space_details: GenieSpace = self._get_space_details()
+
+ if not space_details.warehouse_id:
+ return None
+
+ try:
+ response: GetWarehouseResponse = self.workspace_client.warehouses.get(
+ space_details.warehouse_id
+ )
+ warehouse_name: str = response.name or space_details.warehouse_id
+
+ warehouse_model = WarehouseModel(
+ name=warehouse_name,
+ warehouse_id=space_details.warehouse_id,
+ on_behalf_of_user=self.on_behalf_of_user,
+ service_principal=self.service_principal,
+ client_id=self.client_id,
+ client_secret=self.client_secret,
+ workspace_host=self.workspace_host,
+ pat=self.pat,
+ )
+
+ # Share the cached workspace client if available
+ if self._workspace_client is not None:
+ warehouse_model._workspace_client = self._workspace_client
+
+ return warehouse_model
+ except Exception as e:
+ logger.warning(
+ f"Failed to fetch warehouse details for {space_details.warehouse_id}: {e}"
+ )
+ return None
+
+ @property
+ def tables(self) -> list[TableModel]:
+ """Extract tables from the serialized Genie space.
+
+ Databricks Genie stores tables in: data_sources.tables[].identifier
+ Only includes tables that actually exist in Unity Catalog.
+ """
+ parsed_space = self._parse_serialized_space()
+ tables_list: list[TableModel] = []
+
+ # Primary structure: data_sources.tables with 'identifier' field
+ if "data_sources" in parsed_space:
+ data_sources = parsed_space["data_sources"]
+ if isinstance(data_sources, dict) and "tables" in data_sources:
+ tables_data = data_sources["tables"]
+ if isinstance(tables_data, list):
+ for table_item in tables_data:
+ table_name: str | None = None
+ if isinstance(table_item, dict):
+ # Standard Databricks structure uses 'identifier'
+ table_name = table_item.get("identifier") or table_item.get(
+ "name"
+ )
+ elif isinstance(table_item, str):
+ table_name = table_item
+
+ if table_name:
+ table_model = TableModel(
+ name=table_name,
+ on_behalf_of_user=self.on_behalf_of_user,
+ service_principal=self.service_principal,
+ client_id=self.client_id,
+ client_secret=self.client_secret,
+ workspace_host=self.workspace_host,
+ pat=self.pat,
+ )
+ # Share the cached workspace client if available
+ if self._workspace_client is not None:
+ table_model._workspace_client = self._workspace_client
+
+ # Verify the table exists before adding
+ if not table_model.exists():
+ continue
+
+ tables_list.append(table_model)
+
+ return tables_list
+
+ @property
+ def functions(self) -> list[FunctionModel]:
+ """Extract functions from the serialized Genie space.
+
+ Databricks Genie stores functions in multiple locations:
+ - instructions.sql_functions[].identifier (SQL functions)
+ - data_sources.functions[].identifier (other functions)
+ Only includes functions that actually exist in Unity Catalog.
+ """
+ parsed_space = self._parse_serialized_space()
+ functions_list: list[FunctionModel] = []
+ seen_functions: set[str] = set()
+
+ def add_function_if_exists(function_name: str) -> None:
+ """Helper to add a function if it exists and hasn't been added."""
+ if function_name in seen_functions:
+ return
+
+ seen_functions.add(function_name)
+ function_model = FunctionModel(
+ name=function_name,
+ on_behalf_of_user=self.on_behalf_of_user,
+ service_principal=self.service_principal,
+ client_id=self.client_id,
+ client_secret=self.client_secret,
+ workspace_host=self.workspace_host,
+ pat=self.pat,
+ )
+ # Share the cached workspace client if available
+ if self._workspace_client is not None:
+ function_model._workspace_client = self._workspace_client
+
+ # Verify the function exists before adding
+ if not function_model.exists():
+ return
+
+ functions_list.append(function_model)
+
+ # Primary structure: instructions.sql_functions with 'identifier' field
+ if "instructions" in parsed_space:
+ instructions = parsed_space["instructions"]
+ if isinstance(instructions, dict) and "sql_functions" in instructions:
+ sql_functions_data = instructions["sql_functions"]
+ if isinstance(sql_functions_data, list):
+ for function_item in sql_functions_data:
+ if isinstance(function_item, dict):
+ # SQL functions use 'identifier' field
+ function_name = function_item.get(
+ "identifier"
+ ) or function_item.get("name")
+ if function_name:
+ add_function_if_exists(function_name)
+
+ # Secondary structure: data_sources.functions with 'identifier' field
+ if "data_sources" in parsed_space:
+ data_sources = parsed_space["data_sources"]
+ if isinstance(data_sources, dict) and "functions" in data_sources:
+ functions_data = data_sources["functions"]
+ if isinstance(functions_data, list):
+ for function_item in functions_data:
+ function_name: str | None = None
+ if isinstance(function_item, dict):
+ # Standard Databricks structure uses 'identifier'
+ function_name = function_item.get(
+ "identifier"
+ ) or function_item.get("name")
+ elif isinstance(function_item, str):
+ function_name = function_item
+
+ if function_name:
+ add_function_if_exists(function_name)
+
+ return functions_list
+
  @property
  def api_scopes(self) -> Sequence[str]:
  return [
@@ -612,6 +935,18 @@ class GenieRoomModel(IsDatabricksResource):
  self.space_id = value_of(self.space_id)
  return self

+ @model_validator(mode="after")
+ def update_description_from_space(self) -> Self:
+ """Populate description from GenieSpace if not provided."""
+ if not self.description:
+ try:
+ space_details = self._get_space_details()
+ if space_details.description:
+ self.description = space_details.description
+ except Exception as e:
+ logger.debug(f"Could not fetch description from Genie space: {e}")
+ return self
+

  class VolumeModel(IsDatabricksResource, HasFullName):
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -674,27 +1009,92 @@ class VolumePathModel(BaseModel, HasFullName):


  class VectorStoreModel(IsDatabricksResource):
+ """
+ Configuration model for a Databricks Vector Search store.
+
+ Supports two modes:
+ 1. **Use Existing Index**: Provide only `index` (fully qualified name).
+ Used for querying an existing vector search index at runtime.
+ 2. **Provisioning Mode**: Provide `source_table` + `embedding_source_column`.
+ Used for creating a new vector search index.
+
+ Examples:
+ Minimal configuration (use existing index):
+ ```yaml
+ vector_stores:
+ products_search:
+ index:
+ name: catalog.schema.my_index
+ ```
+
+ Full provisioning configuration:
+ ```yaml
+ vector_stores:
+ products_search:
+ source_table:
+ schema: *my_schema
+ name: products
+ embedding_source_column: description
+ endpoint:
+ name: my_endpoint
+ ```
+ """
+
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
- embedding_model: Optional[LLMModel] = None
+
+ # RUNTIME: Only index is truly required for querying existing indexes
  index: Optional[IndexModel] = None
+
+ # PROVISIONING ONLY: Required when creating a new index
+ source_table: Optional[TableModel] = None
+ embedding_source_column: Optional[str] = None
+ embedding_model: Optional[LLMModel] = None
  endpoint: Optional[VectorSearchEndpoint] = None
- source_table: TableModel
+
+ # OPTIONAL: For both modes
  source_path: Optional[VolumePathModel] = None
  checkpoint_path: Optional[VolumePathModel] = None
  primary_key: Optional[str] = None
  columns: Optional[list[str]] = Field(default_factory=list)
  doc_uri: Optional[str] = None
- embedding_source_column: str
+
+ @model_validator(mode="after")
+ def validate_configuration_mode(self) -> Self:
+ """
+ Validate that configuration is valid for either:
+ - Use existing mode: index is provided
+ - Provisioning mode: source_table + embedding_source_column provided
+ """
+ has_index = self.index is not None
+ has_source_table = self.source_table is not None
+ has_embedding_col = self.embedding_source_column is not None
+
+ # Must have at least index OR source_table
+ if not has_index and not has_source_table:
+ raise ValueError(
+ "Either 'index' (for existing indexes) or 'source_table' "
+ "(for provisioning) must be provided"
+ )
+
+ # If provisioning mode, need embedding_source_column
+ if has_source_table and not has_embedding_col:
+ raise ValueError(
+ "embedding_source_column is required when source_table is provided (provisioning mode)"
+ )
+
+ return self

  @model_validator(mode="after")
  def set_default_embedding_model(self) -> Self:
- if not self.embedding_model:
+ # Only set default embedding model in provisioning mode
+ if self.source_table is not None and not self.embedding_model:
  self.embedding_model = LLMModel(name="databricks-gte-large-en")
  return self

  @model_validator(mode="after")
  def set_default_primary_key(self) -> Self:
- if self.primary_key is None:
+ # Only auto-discover primary key in provisioning mode
+ if self.primary_key is None and self.source_table is not None:
  from dao_ai.providers.databricks import DatabricksProvider

  provider: DatabricksProvider = DatabricksProvider()
@@ -715,14 +1115,16 @@ class VectorStoreModel(IsDatabricksResource):

  @model_validator(mode="after")
  def set_default_index(self) -> Self:
- if self.index is None:
+ # Only generate index from source_table in provisioning mode
+ if self.index is None and self.source_table is not None:
  name: str = f"{self.source_table.name}_index"
  self.index = IndexModel(schema=self.source_table.schema_model, name=name)
  return self

  @model_validator(mode="after")
  def set_default_endpoint(self) -> Self:
- if self.endpoint is None:
+ # Only find/create endpoint in provisioning mode
+ if self.endpoint is None and self.source_table is not None:
  from dao_ai.providers.databricks import (
  DatabricksProvider,
  with_available_indexes,
@@ -772,61 +1174,6 @@ class VectorStoreModel(IsDatabricksResource):
  provider.create_vector_store(self)


- class FunctionModel(IsDatabricksResource, HasFullName):
- model_config = ConfigDict()
- schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
- name: Optional[str] = None
-
- @model_validator(mode="after")
- def validate_name_or_schema_required(self) -> "FunctionModel":
- if not self.name and not self.schema_model:
- raise ValueError(
- "Either 'name' or 'schema_model' must be provided for FunctionModel"
- )
- return self
-
- @property
- def full_name(self) -> str:
- if self.schema_model:
- name: str = ""
- if self.name:
- name = f".{self.name}"
- return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
- return self.name
-
- def as_resources(self) -> Sequence[DatabricksResource]:
- resources: list[DatabricksResource] = []
- if self.name:
- resources.append(
- DatabricksFunction(
- function_name=self.full_name,
- on_behalf_of_user=self.on_behalf_of_user,
- )
- )
- else:
- w: WorkspaceClient = self.workspace_client
- schema_full_name: str = self.schema_model.full_name
- functions: Iterator[FunctionInfo] = w.functions.list(
- catalog_name=self.schema_model.catalog_name,
- schema_name=self.schema_model.schema_name,
- )
- resources.extend(
- [
- DatabricksFunction(
- function_name=f"{schema_full_name}.{function.name}",
- on_behalf_of_user=self.on_behalf_of_user,
- )
- for function in functions
- ]
- )
-
- return resources
-
- @property
- def api_scopes(self) -> Sequence[str]:
- return ["sql.statement-execution"]
-
-
  class ConnectionModel(IsDatabricksResource, HasFullName):
  model_config = ConfigDict()
  name: str
@@ -854,55 +1201,25 @@ class ConnectionModel(IsDatabricksResource, HasFullName):
  ]


- class WarehouseModel(IsDatabricksResource):
- model_config = ConfigDict()
- name: str
- description: Optional[str] = None
- warehouse_id: AnyVariable
-
- @property
- def api_scopes(self) -> Sequence[str]:
- return [
- "sql.warehouses",
- "sql.statement-execution",
- ]
-
- def as_resources(self) -> Sequence[DatabricksResource]:
- return [
- DatabricksSQLWarehouse(
- warehouse_id=value_of(self.warehouse_id),
- on_behalf_of_user=self.on_behalf_of_user,
- )
- ]
-
- @model_validator(mode="after")
- def update_warehouse_id(self) -> Self:
- self.warehouse_id = value_of(self.warehouse_id)
- return self
-
-
- class DatabaseType(str, Enum):
- POSTGRES = "postgres"
- LAKEBASE = "lakebase"
-
-
  class DatabaseModel(IsDatabricksResource):
  """
- Configuration for a Databricks Lakebase (PostgreSQL) database instance.
+ Configuration for database connections supporting both Databricks Lakebase and standard PostgreSQL.

  Authentication is inherited from IsDatabricksResource. Additionally supports:
  - user/password: For user-based database authentication

- Database Type:
- - lakebase: Databricks-managed Lakebase instance (authentication optional, supports ambient auth)
- - postgres: Standard PostgreSQL database (authentication required)
+ Connection Types (determined by fields provided):
+ - Databricks Lakebase: Provide `instance_name` (authentication optional, supports ambient auth)
+ - Standard PostgreSQL: Provide `host` (authentication required via user/password)

- Example Service Principal Configuration:
+ Note: `instance_name` and `host` are mutually exclusive. Provide one or the other.
+
+ Example Databricks Lakebase with Service Principal:
  ```yaml
  databases:
  my_lakebase:
  name: my-database
- type: lakebase
+ instance_name: my-lakebase-instance
  service_principal:
  client_id:
  env: SERVICE_PRINCIPAL_CLIENT_ID
@@ -913,28 +1230,31 @@ class DatabaseModel(IsDatabricksResource):
  env: DATABRICKS_HOST
  ```

- Example User Configuration:
+ Example Databricks Lakebase with Ambient Authentication:
  ```yaml
  databases:
  my_lakebase:
  name: my-database
- type: lakebase
- user: my-user@databricks.com
+ instance_name: my-lakebase-instance
+ on_behalf_of_user: true
  ```

- Example Ambient Authentication (Lakebase only):
+ Example Standard PostgreSQL:
  ```yaml
  databases:
- my_lakebase:
+ my_postgres:
  name: my-database
- type: lakebase
- on_behalf_of_user: true
+ host: my-postgres-host.example.com
+ port: 5432
+ database: my_db
+ user: my_user
+ password:
+ env: PGPASSWORD
  ```
  """

  model_config = ConfigDict(use_enum_values=True, extra="forbid")
  name: str
- type: Optional[DatabaseType] = DatabaseType.LAKEBASE
  instance_name: Optional[str] = None
  description: Optional[str] = None
  host: Optional[AnyVariable] = None
@@ -949,27 +1269,38 @@ class DatabaseModel(IsDatabricksResource):
  user: Optional[AnyVariable] = None
  password: Optional[AnyVariable] = None

- @field_serializer("type")
- def serialize_type(self, value: DatabaseType | None) -> str | None:
- """Serialize the database type enum to its string value."""
- return value.value if value is not None else None
-
  @property
  def api_scopes(self) -> Sequence[str]:
  return ["database.database-instances"]

+ @property
+ def is_lakebase(self) -> bool:
+ """Returns True if this is a Databricks Lakebase connection (instance_name provided)."""
+ return self.instance_name is not None
+
  def as_resources(self) -> Sequence[DatabricksResource]:
- return [
- DatabricksLakebase(
- database_instance_name=self.instance_name,
- on_behalf_of_user=self.on_behalf_of_user,
- )
- ]
+ if self.is_lakebase:
+ return [
+ DatabricksLakebase(
+ database_instance_name=self.instance_name,
+ on_behalf_of_user=self.on_behalf_of_user,
+ )
+ ]
+ return []

  @model_validator(mode="after")
- def update_instance_name(self) -> Self:
- if self.instance_name is None:
- self.instance_name = self.name
+ def validate_connection_type(self) -> Self:
+ """Validate connection configuration based on type.
+
+ - If instance_name is provided: Databricks Lakebase connection
+ (host is optional - will be fetched from API if not provided)
+ - If only host is provided: Standard PostgreSQL connection
+ (must not have instance_name)
+ """
+ if not self.instance_name and not self.host:
+ raise ValueError(
+ "Either instance_name (Databricks Lakebase) or host (PostgreSQL) must be provided."
+ )
  return self

  @model_validator(mode="after")
@@ -978,10 +1309,10 @@ class DatabaseModel(IsDatabricksResource):
  if self.on_behalf_of_user or self.client_id or self.user or self.pat:
  return self

- # For postgres, we need explicit user credentials
- # For lakebase with no auth, ambient auth is allowed
- if self.type == DatabaseType.POSTGRES:
- # Try to determine current user for local development
+ # For standard PostgreSQL, we need explicit user credentials
+ # For Lakebase with no auth, ambient auth is allowed
+ if not self.is_lakebase:
+ # Standard PostgreSQL - try to determine current user for local development
  try:
  self.user = self.workspace_client.current_user.me().user_name
  except Exception as e:
@@ -990,12 +1321,12 @@ class DatabaseModel(IsDatabricksResource):
  f"Please provide explicit user credentials."
  )
  else:
- # For lakebase, try to determine current user but don't fail if we can't
+ # For Lakebase, try to determine current user but don't fail if we can't
  try:
  self.user = self.workspace_client.current_user.me().user_name
  except Exception:
  # If we can't determine user and no explicit auth, that's okay
- # for lakebase with ambient auth - credentials will be injected at runtime
+ # for Lakebase with ambient auth - credentials will be injected at runtime
  pass

  return self
@@ -1005,28 +1336,29 @@ class DatabaseModel(IsDatabricksResource):
  if self.host is not None:
  return self

- # Try to fetch host from existing instance
+ # If instance_name is provided (Lakebase), try to fetch host from existing instance
  # This may fail for OBO/ambient auth during model logging (before deployment)
- try:
- existing_instance: DatabaseInstance = (
- self.workspace_client.database.get_database_instance(
- name=self.instance_name
- )
- )
- self.host = existing_instance.read_write_dns
- except Exception as e:
- # For lakebase with OBO/ambient auth, we can't fetch at config time
- # The host will need to be provided explicitly or fetched at runtime
- if self.type == DatabaseType.LAKEBASE and self.on_behalf_of_user:
- logger.debug(
- f"Could not fetch host for database {self.instance_name} "
- f"(Lakebase with OBO mode - will be resolved at runtime): {e}"
- )
- else:
- raise ValueError(
- f"Could not fetch host for database {self.instance_name}. "
- f"Please provide the 'host' explicitly or ensure the instance exists: {e}"
+ if self.is_lakebase:
+ try:
+ existing_instance: DatabaseInstance = (
+ self.workspace_client.database.get_database_instance(
+ name=self.instance_name
+ )
  )
+ self.host = existing_instance.read_write_dns
+ except Exception as e:
+ # For Lakebase with OBO/ambient auth, we can't fetch at config time
+ # The host will need to be provided explicitly or fetched at runtime
+ if self.on_behalf_of_user:
+ logger.debug(
+ f"Could not fetch host for database {self.instance_name} "
+ f"(Lakebase with OBO mode - will be resolved at runtime): {e}"
+ )
+ else:
+ raise ValueError(
+ f"Could not fetch host for database {self.instance_name}. "
+ f"Please provide the 'host' explicitly or ensure the instance exists: {e}"
+ )
  return self

  @model_validator(mode="after")
@@ -1054,9 +1386,9 @@ class DatabaseModel(IsDatabricksResource):
  "or user credentials (user)."
  )

- # For postgres type, at least one auth method must be configured
- # For lakebase type, auth is optional (supports ambient authentication)
- if self.type == DatabaseType.POSTGRES and auth_methods_count == 0:
+ # For standard PostgreSQL (host-based), at least one auth method must be configured
+ # For Lakebase (instance_name-based), auth is optional (supports ambient authentication)
+ if not self.is_lakebase and auth_methods_count == 0:
  raise ValueError(
  "PostgreSQL databases require explicit authentication. "
  "Please provide one of: "
@@ -1074,24 +1406,23 @@ class DatabaseModel(IsDatabricksResource):
  Get database connection parameters as a dictionary.

  Returns a dict with connection parameters suitable for psycopg ConnectionPool.
- If username is configured, it will be included; otherwise it will be omitted
- to allow Lakebase to authenticate using the token's identity.
+
+ For Lakebase: Uses Databricks-generated credentials (token-based auth).
+ For standard PostgreSQL: Uses provided user/password credentials.
  """
  import uuid as _uuid

  from databricks.sdk.service.database import DatabaseCredential

+ host: str
+ port: int
+ database: str
  username: str | None = None
-
- if self.client_id and self.client_secret and self.workspace_host:
- username = value_of(self.client_id)
- elif self.user:
- username = value_of(self.user)
- # For OBO mode, no username is needed - the token identity is used
+ password_value: str | None = None

  # Resolve host - may need to fetch at runtime for OBO mode
  host_value: Any = self.host
- if host_value is None and self.on_behalf_of_user:
+ if host_value is None and self.is_lakebase and self.on_behalf_of_user:
  # Fetch host at runtime for OBO mode
  existing_instance: DatabaseInstance = (
  self.workspace_client.database.get_database_instance(
@@ -1101,29 +1432,50 @@ class DatabaseModel(IsDatabricksResource):
  host_value = existing_instance.read_write_dns

  if host_value is None:
+ instance_or_name = self.instance_name if self.is_lakebase else self.name
  raise ValueError(
- f"Database host not configured for {self.instance_name}. "
+ f"Database host not configured for {instance_or_name}. "
  "Please provide 'host' explicitly."
  )

- host: str = value_of(host_value)
- port: int = value_of(self.port)
- database: str = value_of(self.database)
+ host = value_of(host_value)
+ port = value_of(self.port)
+ database = value_of(self.database)

- # Use the resource's own workspace_client to generate the database credential
- w: WorkspaceClient = self.workspace_client
- cred: DatabaseCredential = w.database.generate_database_credential(
- request_id=str(_uuid.uuid4()),
- instance_names=[self.instance_name],
- )
- token: str = cred.token
+ if self.is_lakebase:
+ # Lakebase: Use Databricks-generated credentials
+ if self.client_id and self.client_secret and self.workspace_host:
+ username = value_of(self.client_id)
+ elif self.user:
+ username = value_of(self.user)
+ # For OBO mode, no username is needed - the token identity is used
+
+ # Generate Databricks database credential (token)
+ w: WorkspaceClient = self.workspace_client
+ cred: DatabaseCredential = w.database.generate_database_credential(
+ request_id=str(_uuid.uuid4()),
+ instance_names=[self.instance_name],
+ )
+ password_value = cred.token
+ else:
+ # Standard PostgreSQL: Use provided credentials
+ if self.user:
+ username = value_of(self.user)
+ if self.password:
+ password_value = value_of(self.password)
+
+ if not username or not password_value:
+ raise ValueError(
+ f"Standard PostgreSQL databases require both 'user' and 'password'. "
+ f"Database: {self.name}"
+ )

  # Build connection parameters dictionary
  params: dict[str, Any] = {
  "dbname": database,
  "host": host,
  "port": port,
- "password": token,
+ "password": password_value,
  "sslmode": "require",
  }

@@ -1264,11 +1616,13 @@ class RerankParametersModel(BaseModel):
  top_n: 5 # Return top 5 after reranking
  ```

- Available models (from fastest to most accurate):
- - "ms-marco-TinyBERT-L-2-v2" (fastest, smallest)
- - "ms-marco-MiniLM-L-6-v2"
- - "ms-marco-MiniLM-L-12-v2" (default, good balance)
- - "rank-T5-flan" (most accurate, slower)
+ Available models (see https://github.com/PrithivirajDamodaran/FlashRank):
+ - "ms-marco-TinyBERT-L-2-v2" (~4MB, fastest)
+ - "ms-marco-MiniLM-L-12-v2" (~34MB, best cross-encoder, default)
+ - "rank-T5-flan" (~110MB, best non cross-encoder)
+ - "ms-marco-MultiBERT-L-12" (~150MB, multilingual 100+ languages)
+ - "ce-esci-MiniLM-L12-v2" (e-commerce optimized, Amazon ESCI)
+ - "miniReranker_arabic_v1" (Arabic language)
  """

  model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -1282,8 +1636,8 @@ class RerankParametersModel(BaseModel):
  description="Number of documents to return after reranking. If None, uses search_parameters.num_results.",
  )
  cache_dir: Optional[str] = Field(
- default="/tmp/flashrank_cache",
- description="Directory to cache downloaded model weights.",
+ default="~/.dao_ai/cache/flashrank",
+ description="Directory to cache downloaded model weights. Supports tilde expansion (e.g., ~/.dao_ai).",
  )
  columns: Optional[list[str]] = Field(
  default_factory=list, description="Columns to rerank using DatabricksReranker"
@@ -1373,7 +1727,6 @@ class BaseFunctionModel(ABC, BaseModel):
  discriminator="type",
  )
  type: FunctionType
- name: str
  human_in_the_loop: Optional[HumanInTheLoopModel] = None

  @abstractmethod
@@ -1390,6 +1743,7 @@ class BaseFunctionModel(ABC, BaseModel):
  class PythonFunctionModel(BaseFunctionModel, HasFullName):
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
  type: Literal[FunctionType.PYTHON] = FunctionType.PYTHON
+ name: str

  @property
  def full_name(self) -> str:
@@ -1403,8 +1757,9 @@ class PythonFunctionModel(BaseFunctionModel, HasFullName):

  class FactoryFunctionModel(BaseFunctionModel, HasFullName):
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
- args: Optional[dict[str, Any]] = Field(default_factory=dict)
  type: Literal[FunctionType.FACTORY] = FunctionType.FACTORY
+ name: str
+ args: Optional[dict[str, Any]] = Field(default_factory=dict)

  @property
  def full_name(self) -> str:
@@ -1427,7 +1782,7 @@ class TransportType(str, Enum):
  STDIO = "stdio"


- class McpFunctionModel(BaseFunctionModel, IsDatabricksResource, HasFullName):
+ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
  """
  MCP Function Model with authentication inherited from IsDatabricksResource.

@@ -1466,10 +1821,6 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource, HasFullName):
  """MCP functions don't declare static resources."""
  return []

- @property
- def full_name(self) -> str:
- return self.name
-
  def _get_workspace_host(self) -> str:
  """
  Get the workspace host, either from config or from workspace client.
@@ -1635,17 +1986,11 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource, HasFullName):
  return create_mcp_tools(self)


- class UnityCatalogFunctionModel(BaseFunctionModel, HasFullName):
+ class UnityCatalogFunctionModel(BaseFunctionModel):
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
- schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
- partial_args: Optional[dict[str, AnyVariable]] = Field(default_factory=dict)
  type: Literal[FunctionType.UNITY_CATALOG] = FunctionType.UNITY_CATALOG
-
- @property
- def full_name(self) -> str:
- if self.schema_model:
- return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}.{self.name}"
- return self.name
+ resource: FunctionModel
+ partial_args: Optional[dict[str, AnyVariable]] = Field(default_factory=dict)

  def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
  from dao_ai.tools import create_uc_tools
@@ -1679,6 +2024,12 @@ class PromptModel(BaseModel, HasFullName):
  alias: Optional[str] = None
  version: Optional[int] = None
  tags: Optional[dict[str, Any]] = Field(default_factory=dict)
+ auto_register: bool = Field(
+ default=False,
+ description="Whether to automatically register the default_template to the prompt registry. "
+ "If False, the prompt will only be loaded from the registry (never created/updated). "
+ "Defaults to True for backward compatibility.",
+ )

  @property
  def template(self) -> str:
@@ -1760,23 +2111,17 @@ class MiddlewareModel(BaseModel):
  class StorageType(str, Enum):
  POSTGRES = "postgres"
  MEMORY = "memory"
- LAKEBASE = "lakebase"


  class CheckpointerModel(BaseModel):
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
  name: str
- type: Optional[StorageType] = StorageType.MEMORY
  database: Optional[DatabaseModel] = None

- @model_validator(mode="after")
- def validate_storage_requires_database(self) -> Self:
- if (
- self.type in [StorageType.POSTGRES, StorageType.LAKEBASE]
- and not self.database
- ):
- raise ValueError("Database must be provided when storage type is POSTGRES")
- return self
+ @property
+ def storage_type(self) -> StorageType:
+ """Infer storage type from database presence."""
+ return StorageType.POSTGRES if self.database else StorageType.MEMORY

  def as_checkpointer(self) -> BaseCheckpointSaver:
  from dao_ai.memory import CheckpointManager
@@ -1792,16 +2137,14 @@ class StoreModel(BaseModel):
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
  name: str
  embedding_model: Optional[LLMModel] = None
- type: Optional[StorageType] = StorageType.MEMORY
  dims: Optional[int] = 1536
  database: Optional[DatabaseModel] = None
  namespace: Optional[str] = None

- @model_validator(mode="after")
- def validate_postgres_requires_database(self) -> Self:
- if self.type == StorageType.POSTGRES and not self.database:
- raise ValueError("Database must be provided when storage type is POSTGRES")
- return self
+ @property
+ def storage_type(self) -> StorageType:
+ """Infer storage type from database presence."""
+ return StorageType.POSTGRES if self.database else StorageType.MEMORY

  def as_store(self) -> BaseStore:
  from dao_ai.memory import StoreManager
@@ -2044,6 +2387,10 @@ class SwarmModel(BaseModel):
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
  model: LLMModel
  default_agent: Optional[AgentModel | str] = None
+ middleware: list[MiddlewareModel] = Field(
+ default_factory=list,
+ description="List of middleware to apply to all agents in the swarm",
+ )
  handoffs: Optional[dict[str, Optional[list[AgentModel | str]]]] = Field(
  default_factory=dict
  )
@@ -2606,7 +2953,7 @@ class UnityCatalogFunctionSqlTestModel(BaseModel):

  class UnityCatalogFunctionSqlModel(BaseModel):
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
- function: UnityCatalogFunctionModel
+ function: FunctionModel
  ddl: str
  parameters: Optional[dict[str, Any]] = Field(default_factory=dict)
  test: Optional[UnityCatalogFunctionSqlTestModel] = None
@@ -2636,6 +2983,113 @@ class ResourcesModel(BaseModel):
  connections: dict[str, ConnectionModel] = Field(default_factory=dict)
  apps: dict[str, DatabricksAppModel] = Field(default_factory=dict)

+ @model_validator(mode="after")
+ def update_genie_warehouses(self) -> Self:
+ """
+ Automatically populate warehouses from genie_rooms.
+
+ Warehouses are extracted from each Genie room and added to the
+ resources if they don't already exist (based on warehouse_id).
+ """
+ if not self.genie_rooms:
+ return self
+
+ # Process warehouses from all genie rooms
+ for genie_room in self.genie_rooms.values():
+ genie_room: GenieRoomModel
+ warehouse: Optional[WarehouseModel] = genie_room.warehouse
+
+ if warehouse is None:
+ continue
+
+ # Check if warehouse already exists based on warehouse_id
+ warehouse_exists: bool = any(
+ existing_warehouse.warehouse_id == warehouse.warehouse_id
+ for existing_warehouse in self.warehouses.values()
+ )
+
+ if not warehouse_exists:
+ warehouse_key: str = normalize_name(
+ "_".join([genie_room.name, warehouse.warehouse_id])
+ )
+ self.warehouses[warehouse_key] = warehouse
+ logger.trace(
+ "Added warehouse from Genie room",
+ room=genie_room.name,
+ warehouse=warehouse.warehouse_id,
+ key=warehouse_key,
+ )
+
+ return self
+
+ @model_validator(mode="after")
+ def update_genie_tables(self) -> Self:
+ """
+ Automatically populate tables from genie_rooms.
+
+ Tables are extracted from each Genie room and added to the
+ resources if they don't already exist (based on full_name).
+ """
+ if not self.genie_rooms:
+ return self
+
+ # Process tables from all genie rooms
+ for genie_room in self.genie_rooms.values():
+ genie_room: GenieRoomModel
+ for table in genie_room.tables:
+ table: TableModel
+ table_exists: bool = any(
+ existing_table.full_name == table.full_name
+ for existing_table in self.tables.values()
+ )
+ if not table_exists:
+ table_key: str = normalize_name(
+ "_".join([genie_room.name, table.full_name])
+ )
+ self.tables[table_key] = table
+ logger.trace(
+ "Added table from Genie room",
+ room=genie_room.name,
+ table=table.name,
+ key=table_key,
+ )
+
+ return self
+
+ @model_validator(mode="after")
+ def update_genie_functions(self) -> Self:
+ """
+ Automatically populate functions from genie_rooms.
+
+ Functions are extracted from each Genie room and added to the
+ resources if they don't already exist (based on full_name).
+ """
+ if not self.genie_rooms:
+ return self
+
+ # Process functions from all genie rooms
+ for genie_room in self.genie_rooms.values():
+ genie_room: GenieRoomModel
+ for function in genie_room.functions:
+ function: FunctionModel
+ function_exists: bool = any(
+ existing_function.full_name == function.full_name
+ for existing_function in self.functions.values()
+ )
+ if not function_exists:
+ function_key: str = normalize_name(
+ "_".join([genie_room.name, function.full_name])
+ )
+ self.functions[function_key] = function
+ logger.trace(
+ "Added function from Genie room",
+ room=genie_room.name,
+ function=function.name,
+ key=function_key,
+ )
+
+ return self
+

  class AppConfig(BaseModel):
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -2674,10 +3128,10 @@ class AppConfig(BaseModel):

  def initialize(self) -> None:
  from dao_ai.hooks.core import create_hooks
+ from dao_ai.logging import configure_logging

  if self.app and self.app.log_level:
- logger.remove()
- logger.add(sys.stderr, level=self.app.log_level)
+ configure_logging(level=self.app.log_level)

  logger.debug("Calling initialization hooks...")
  initialization_functions: Sequence[Callable[..., Any]] = create_hooks(