dao-ai 0.0.31__py3-none-any.whl → 0.0.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dao_ai/config.py CHANGED
@@ -12,6 +12,7 @@ from typing import (
12
12
  Iterator,
13
13
  Literal,
14
14
  Optional,
15
+ Self,
15
16
  Sequence,
16
17
  TypeAlias,
17
18
  Union,
@@ -200,6 +201,15 @@ AnyVariable: TypeAlias = (
200
201
  )
201
202
 
202
203
 
204
+ class ServicePrincipalModel(BaseModel):
205
+ model_config = ConfigDict(
206
+ frozen=True,
207
+ use_enum_values=True,
208
+ )
209
+ client_id: AnyVariable
210
+ client_secret: AnyVariable
211
+
212
+
203
213
  class Privilege(str, Enum):
204
214
  ALL_PRIVILEGES = "ALL_PRIVILEGES"
205
215
  USE_CATALOG = "USE_CATALOG"
@@ -451,7 +461,7 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
451
461
  ]
452
462
 
453
463
  @model_validator(mode="after")
454
- def update_space_id(self):
464
+ def update_space_id(self) -> Self:
455
465
  self.space_id = value_of(self.space_id)
456
466
  return self
457
467
 
@@ -530,13 +540,13 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
530
540
  embedding_source_column: str
531
541
 
532
542
  @model_validator(mode="after")
533
- def set_default_embedding_model(self):
543
+ def set_default_embedding_model(self) -> Self:
534
544
  if not self.embedding_model:
535
545
  self.embedding_model = LLMModel(name="databricks-gte-large-en")
536
546
  return self
537
547
 
538
548
  @model_validator(mode="after")
539
- def set_default_primary_key(self):
549
+ def set_default_primary_key(self) -> Self:
540
550
  if self.primary_key is None:
541
551
  from dao_ai.providers.databricks import DatabricksProvider
542
552
 
@@ -557,14 +567,14 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
557
567
  return self
558
568
 
559
569
  @model_validator(mode="after")
560
- def set_default_index(self):
570
+ def set_default_index(self) -> Self:
561
571
  if self.index is None:
562
572
  name: str = f"{self.source_table.name}_index"
563
573
  self.index = IndexModel(schema=self.source_table.schema_model, name=name)
564
574
  return self
565
575
 
566
576
  @model_validator(mode="after")
567
- def set_default_endpoint(self):
577
+ def set_default_endpoint(self) -> Self:
568
578
  if self.endpoint is None:
569
579
  from dao_ai.providers.databricks import (
570
580
  DatabricksProvider,
@@ -719,12 +729,68 @@ class WarehouseModel(BaseModel, IsDatabricksResource):
719
729
  ]
720
730
 
721
731
  @model_validator(mode="after")
722
- def update_warehouse_id(self):
732
+ def update_warehouse_id(self) -> Self:
723
733
  self.warehouse_id = value_of(self.warehouse_id)
724
734
  return self
725
735
 
726
736
 
727
737
  class DatabaseModel(BaseModel, IsDatabricksResource):
738
+ """
739
+ Configuration for a Databricks Lakebase (PostgreSQL) database instance.
740
+
741
+ Authentication Model:
742
+ --------------------
743
+ This model uses TWO separate authentication contexts:
744
+
745
+ 1. **Workspace API Authentication** (inherited from IsDatabricksResource):
746
+ - Uses ambient/default authentication (environment variables, notebook context, app service principal)
747
+ - Used for: discovering database instance, getting host DNS, checking instance status
748
+ - Controlled by: DATABRICKS_HOST, DATABRICKS_TOKEN env vars, or SDK default config
749
+
750
+ 2. **Database Connection Authentication** (configured via service_principal, client_id/client_secret, OR user):
751
+ - Used for: connecting to the PostgreSQL database as a specific identity
752
+ - Service Principal: Set service_principal with workspace_host to connect as a service principal
753
+ - OAuth M2M: Set client_id, client_secret, workspace_host to connect as a service principal
754
+ - User Auth: Set user (and optionally password) to connect as a user identity
755
+
756
+ Example Service Principal Configuration:
757
+ ```yaml
758
+ databases:
759
+ my_lakebase:
760
+ name: my-database
761
+ service_principal:
762
+ client_id:
763
+ env: SERVICE_PRINCIPAL_CLIENT_ID
764
+ client_secret:
765
+ scope: my-scope
766
+ secret: sp-client-secret
767
+ workspace_host:
768
+ env: DATABRICKS_HOST
769
+ ```
770
+
771
+ Example OAuth M2M Configuration (alternative):
772
+ ```yaml
773
+ databases:
774
+ my_lakebase:
775
+ name: my-database
776
+ client_id:
777
+ env: SERVICE_PRINCIPAL_CLIENT_ID
778
+ client_secret:
779
+ scope: my-scope
780
+ secret: sp-client-secret
781
+ workspace_host:
782
+ env: DATABRICKS_HOST
783
+ ```
784
+
785
+ Example User Configuration:
786
+ ```yaml
787
+ databases:
788
+ my_lakebase:
789
+ name: my-database
790
+ user: my-user@databricks.com
791
+ ```
792
+ """
793
+
728
794
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
729
795
  name: str
730
796
  instance_name: Optional[str] = None
@@ -739,6 +805,7 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
739
805
  node_count: Optional[int] = None
740
806
  user: Optional[AnyVariable] = None
741
807
  password: Optional[AnyVariable] = None
808
+ service_principal: Optional[ServicePrincipalModel] = None
742
809
  client_id: Optional[AnyVariable] = None
743
810
  client_secret: Optional[AnyVariable] = None
744
811
  workspace_host: Optional[AnyVariable] = None
@@ -756,14 +823,24 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
756
823
  ]
757
824
 
758
825
  @model_validator(mode="after")
759
- def update_instance_name(self):
826
+ def update_instance_name(self) -> Self:
760
827
  if self.instance_name is None:
761
828
  self.instance_name = self.name
762
829
 
763
830
  return self
764
831
 
765
832
  @model_validator(mode="after")
766
- def update_user(self):
833
+ def expand_service_principal(self) -> Self:
834
+ """Expand service_principal into client_id and client_secret if provided."""
835
+ if self.service_principal is not None:
836
+ if self.client_id is None:
837
+ self.client_id = self.service_principal.client_id
838
+ if self.client_secret is None:
839
+ self.client_secret = self.service_principal.client_secret
840
+ return self
841
+
842
+ @model_validator(mode="after")
843
+ def update_user(self) -> Self:
767
844
  if self.client_id or self.user:
768
845
  return self
769
846
 
@@ -776,7 +853,7 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
776
853
  return self
777
854
 
778
855
  @model_validator(mode="after")
779
- def update_host(self):
856
+ def update_host(self) -> Self:
780
857
  if self.host is not None:
781
858
  return self
782
859
 
@@ -789,7 +866,7 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
789
866
  return self
790
867
 
791
868
  @model_validator(mode="after")
792
- def validate_auth_methods(self):
869
+ def validate_auth_methods(self) -> Self:
793
870
  oauth_fields: Sequence[Any] = [
794
871
  self.workspace_host,
795
872
  self.client_id,
@@ -809,8 +886,8 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
809
886
  if not has_oauth and not has_user_auth:
810
887
  raise ValueError(
811
888
  "At least one authentication method must be provided: "
812
- "either OAuth credentials (workspace_host, client_id, client_secret) "
813
- "or user credentials (user, password)."
889
+ "either OAuth credentials (workspace_host, client_id, client_secret), "
890
+ "service_principal with workspace_host, or user credentials (user, password)."
814
891
  )
815
892
 
816
893
  return self
@@ -883,7 +960,7 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
883
960
  def create(self, w: WorkspaceClient | None = None) -> None:
884
961
  from dao_ai.providers.databricks import DatabricksProvider
885
962
 
886
- provider: DatabricksProvider = DatabricksProvider()
963
+ provider: DatabricksProvider = DatabricksProvider(w=w)
887
964
  provider.create_lakebase(self)
888
965
  provider.create_lakebase_instance_role(self)
889
966
 
@@ -957,14 +1034,14 @@ class RetrieverModel(BaseModel):
957
1034
  )
958
1035
 
959
1036
  @model_validator(mode="after")
960
- def set_default_columns(self):
1037
+ def set_default_columns(self) -> Self:
961
1038
  if not self.columns:
962
1039
  columns: Sequence[str] = self.vector_store.columns
963
1040
  self.columns = columns
964
1041
  return self
965
1042
 
966
1043
  @model_validator(mode="after")
967
- def set_default_reranker(self):
1044
+ def set_default_reranker(self) -> Self:
968
1045
  """Convert bool to ReRankParametersModel with defaults."""
969
1046
  if isinstance(self.rerank, bool) and self.rerank:
970
1047
  self.rerank = RerankParametersModel()
@@ -1051,7 +1128,7 @@ class FactoryFunctionModel(BaseFunctionModel, HasFullName):
1051
1128
  return [create_factory_tool(self, **kwargs)]
1052
1129
 
1053
1130
  @model_validator(mode="after")
1054
- def update_args(self):
1131
+ def update_args(self) -> Self:
1055
1132
  for key, value in self.args.items():
1056
1133
  self.args[key] = value_of(value)
1057
1134
  return self
@@ -1071,6 +1148,7 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
1071
1148
  headers: dict[str, AnyVariable] = Field(default_factory=dict)
1072
1149
  args: list[str] = Field(default_factory=list)
1073
1150
  pat: Optional[AnyVariable] = None
1151
+ service_principal: Optional[ServicePrincipalModel] = None
1074
1152
  client_id: Optional[AnyVariable] = None
1075
1153
  client_secret: Optional[AnyVariable] = None
1076
1154
  workspace_host: Optional[AnyVariable] = None
@@ -1080,6 +1158,16 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
1080
1158
  sql: Optional[bool] = None
1081
1159
  vector_search: Optional[VectorStoreModel] = None
1082
1160
 
1161
+ @model_validator(mode="after")
1162
+ def expand_service_principal(self) -> Self:
1163
+ """Expand service_principal into client_id and client_secret if provided."""
1164
+ if self.service_principal is not None:
1165
+ if self.client_id is None:
1166
+ self.client_id = self.service_principal.client_id
1167
+ if self.client_secret is None:
1168
+ self.client_secret = self.service_principal.client_secret
1169
+ return self
1170
+
1083
1171
  @property
1084
1172
  def full_name(self) -> str:
1085
1173
  return self.name
@@ -1089,12 +1177,12 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
1089
1177
  Get the workspace host, either from config or from workspace client.
1090
1178
 
1091
1179
  If connection is provided, uses its workspace client.
1092
- Otherwise, falls back to creating a new workspace client.
1180
+ Otherwise, falls back to the default Databricks host.
1093
1181
 
1094
1182
  Returns:
1095
1183
  str: The workspace host URL without trailing slash
1096
1184
  """
1097
- from databricks.sdk import WorkspaceClient
1185
+ from dao_ai.utils import get_default_databricks_host
1098
1186
 
1099
1187
  # Try to get workspace_host from config
1100
1188
  workspace_host: str | None = (
@@ -1107,9 +1195,13 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
1107
1195
  if self.connection:
1108
1196
  workspace_host = self.connection.workspace_client.config.host
1109
1197
  else:
1110
- # Create a default workspace client
1111
- w: WorkspaceClient = WorkspaceClient()
1112
- workspace_host = w.config.host
1198
+ workspace_host = get_default_databricks_host()
1199
+
1200
+ if not workspace_host:
1201
+ raise ValueError(
1202
+ "Could not determine workspace host. "
1203
+ "Please set workspace_host in config or DATABRICKS_HOST environment variable."
1204
+ )
1113
1205
 
1114
1206
  # Remove trailing slash
1115
1207
  return workspace_host.rstrip("/")
@@ -1316,7 +1408,7 @@ class CheckpointerModel(BaseModel):
1316
1408
  database: Optional[DatabaseModel] = None
1317
1409
 
1318
1410
  @model_validator(mode="after")
1319
- def validate_postgres_requires_database(self):
1411
+ def validate_postgres_requires_database(self) -> Self:
1320
1412
  if self.type == StorageType.POSTGRES and not self.database:
1321
1413
  raise ValueError("Database must be provided when storage type is POSTGRES")
1322
1414
  return self
@@ -1341,7 +1433,7 @@ class StoreModel(BaseModel):
1341
1433
  namespace: Optional[str] = None
1342
1434
 
1343
1435
  @model_validator(mode="after")
1344
- def validate_postgres_requires_database(self):
1436
+ def validate_postgres_requires_database(self) -> Self:
1345
1437
  if self.type == StorageType.POSTGRES and not self.database:
1346
1438
  raise ValueError("Database must be provided when storage type is POSTGRES")
1347
1439
  return self
@@ -1405,7 +1497,7 @@ class PromptModel(BaseModel, HasFullName):
1405
1497
  return prompt_version
1406
1498
 
1407
1499
  @model_validator(mode="after")
1408
- def validate_mutually_exclusive(self):
1500
+ def validate_mutually_exclusive(self) -> Self:
1409
1501
  if self.alias and self.version:
1410
1502
  raise ValueError("Cannot specify both alias and version")
1411
1503
  return self
@@ -1459,7 +1551,7 @@ class OrchestrationModel(BaseModel):
1459
1551
  memory: Optional[MemoryModel] = None
1460
1552
 
1461
1553
  @model_validator(mode="after")
1462
- def validate_mutually_exclusive(self):
1554
+ def validate_mutually_exclusive(self) -> Self:
1463
1555
  if self.supervisor is not None and self.swarm is not None:
1464
1556
  raise ValueError("Cannot specify both supervisor and swarm")
1465
1557
  if self.supervisor is None and self.swarm is None:
@@ -1586,6 +1678,7 @@ class ChatHistoryModel(BaseModel):
1586
1678
  class AppModel(BaseModel):
1587
1679
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1588
1680
  name: str
1681
+ service_principal: Optional[ServicePrincipalModel] = None
1589
1682
  description: Optional[str] = None
1590
1683
  log_level: Optional[LogLevel] = "WARNING"
1591
1684
  registered_model: RegisteredModelModel
@@ -1613,16 +1706,50 @@ class AppModel(BaseModel):
1613
1706
  chat_history: Optional[ChatHistoryModel] = None
1614
1707
  code_paths: list[str] = Field(default_factory=list)
1615
1708
  pip_requirements: list[str] = Field(default_factory=list)
1709
+ python_version: Optional[str] = Field(
1710
+ default="3.12",
1711
+ description="Python version for Model Serving deployment. Defaults to 3.12 "
1712
+ "which is supported by Databricks Model Serving. This allows deploying from "
1713
+ "environments with different Python versions (e.g., Databricks Apps with 3.11).",
1714
+ )
1715
+
1716
+ @model_validator(mode="after")
1717
+ def set_databricks_env_vars(self) -> Self:
1718
+ """Set Databricks environment variables for Model Serving.
1719
+
1720
+ Sets DATABRICKS_HOST, DATABRICKS_CLIENT_ID, and DATABRICKS_CLIENT_SECRET.
1721
+ Values explicitly provided in environment_vars take precedence.
1722
+ """
1723
+ from dao_ai.utils import get_default_databricks_host
1724
+
1725
+ # Set DATABRICKS_HOST if not already provided
1726
+ if "DATABRICKS_HOST" not in self.environment_vars:
1727
+ host: str | None = get_default_databricks_host()
1728
+ if host:
1729
+ self.environment_vars["DATABRICKS_HOST"] = host
1730
+
1731
+ # Set service principal credentials if provided
1732
+ if self.service_principal is not None:
1733
+ if "DATABRICKS_CLIENT_ID" not in self.environment_vars:
1734
+ self.environment_vars["DATABRICKS_CLIENT_ID"] = (
1735
+ self.service_principal.client_id
1736
+ )
1737
+ if "DATABRICKS_CLIENT_SECRET" not in self.environment_vars:
1738
+ self.environment_vars["DATABRICKS_CLIENT_SECRET"] = (
1739
+ self.service_principal.client_secret
1740
+ )
1741
+ return self
1616
1742
 
1617
1743
  @model_validator(mode="after")
1618
- def validate_agents_not_empty(self):
1744
+ def validate_agents_not_empty(self) -> Self:
1619
1745
  if not self.agents:
1620
1746
  raise ValueError("At least one agent must be specified")
1621
1747
  return self
1622
1748
 
1623
1749
  @model_validator(mode="after")
1624
- def update_environment_vars(self):
1750
+ def resolve_environment_vars(self) -> Self:
1625
1751
  for key, value in self.environment_vars.items():
1752
+ updated_value: str
1626
1753
  if isinstance(value, SecretVariableModel):
1627
1754
  updated_value = str(value)
1628
1755
  else:
@@ -1632,7 +1759,7 @@ class AppModel(BaseModel):
1632
1759
  return self
1633
1760
 
1634
1761
  @model_validator(mode="after")
1635
- def set_default_orchestration(self):
1762
+ def set_default_orchestration(self) -> Self:
1636
1763
  if self.orchestration is None:
1637
1764
  if len(self.agents) > 1:
1638
1765
  default_agent: AgentModel = self.agents[0]
@@ -1652,14 +1779,14 @@ class AppModel(BaseModel):
1652
1779
  return self
1653
1780
 
1654
1781
  @model_validator(mode="after")
1655
- def set_default_endpoint_name(self):
1782
+ def set_default_endpoint_name(self) -> Self:
1656
1783
  if self.endpoint_name is None:
1657
1784
  self.endpoint_name = self.name
1658
1785
  return self
1659
1786
 
1660
1787
  @model_validator(mode="after")
1661
- def set_default_agent(self):
1662
- default_agent_name = self.agents[0].name
1788
+ def set_default_agent(self) -> Self:
1789
+ default_agent_name: str = self.agents[0].name
1663
1790
 
1664
1791
  if self.orchestration.swarm and not self.orchestration.swarm.default_agent:
1665
1792
  self.orchestration.swarm.default_agent = default_agent_name
@@ -1667,7 +1794,7 @@ class AppModel(BaseModel):
1667
1794
  return self
1668
1795
 
1669
1796
  @model_validator(mode="after")
1670
- def add_code_paths_to_sys_path(self):
1797
+ def add_code_paths_to_sys_path(self) -> Self:
1671
1798
  for code_path in self.code_paths:
1672
1799
  parent_path: str = str(Path(code_path).parent)
1673
1800
  if parent_path not in sys.path:
@@ -1700,7 +1827,7 @@ class EvaluationDatasetExpectationsModel(BaseModel):
1700
1827
  expected_facts: Optional[list[str]] = None
1701
1828
 
1702
1829
  @model_validator(mode="after")
1703
- def validate_mutually_exclusive(self):
1830
+ def validate_mutually_exclusive(self) -> Self:
1704
1831
  if self.expected_response is not None and self.expected_facts is not None:
1705
1832
  raise ValueError("Cannot specify both expected_response and expected_facts")
1706
1833
  return self
@@ -1808,7 +1935,7 @@ class PromptOptimizationModel(BaseModel):
1808
1935
  return optimized_prompt
1809
1936
 
1810
1937
  @model_validator(mode="after")
1811
- def set_defaults(self):
1938
+ def set_defaults(self) -> Self:
1812
1939
  # If no prompt is specified, try to use the agent's prompt
1813
1940
  if self.prompt is None:
1814
1941
  if isinstance(self.agent.prompt, PromptModel):
@@ -1930,6 +2057,7 @@ class ResourcesModel(BaseModel):
1930
2057
  class AppConfig(BaseModel):
1931
2058
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1932
2059
  variables: dict[str, AnyVariable] = Field(default_factory=dict)
2060
+ service_principals: dict[str, ServicePrincipalModel] = Field(default_factory=dict)
1933
2061
  schemas: dict[str, SchemaModel] = Field(default_factory=dict)
1934
2062
  resources: Optional[ResourcesModel] = None
1935
2063
  retrievers: dict[str, RetrieverModel] = Field(default_factory=dict)
dao_ai/prompts.py CHANGED
@@ -1,10 +1,10 @@
1
1
  from typing import Any, Callable, Optional, Sequence
2
2
 
3
- from langchain.prompts import PromptTemplate
4
3
  from langchain_core.messages import (
5
4
  BaseMessage,
6
5
  SystemMessage,
7
6
  )
7
+ from langchain_core.prompts import PromptTemplate
8
8
  from langchain_core.runnables import RunnableConfig
9
9
  from loguru import logger
10
10