dao-ai 0.0.32__py3-none-any.whl → 0.0.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/config.py +117 -35
- dao_ai/tools/core.py +1 -1
- dao_ai/tools/unity_catalog.py +31 -2
- dao_ai/utils.py +26 -0
- {dao_ai-0.0.32.dist-info → dao_ai-0.0.33.dist-info}/METADATA +7 -7
- {dao_ai-0.0.32.dist-info → dao_ai-0.0.33.dist-info}/RECORD +9 -9
- {dao_ai-0.0.32.dist-info → dao_ai-0.0.33.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.32.dist-info → dao_ai-0.0.33.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.32.dist-info → dao_ai-0.0.33.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py
CHANGED
|
@@ -12,6 +12,7 @@ from typing import (
|
|
|
12
12
|
Iterator,
|
|
13
13
|
Literal,
|
|
14
14
|
Optional,
|
|
15
|
+
Self,
|
|
15
16
|
Sequence,
|
|
16
17
|
TypeAlias,
|
|
17
18
|
Union,
|
|
@@ -200,6 +201,15 @@ AnyVariable: TypeAlias = (
|
|
|
200
201
|
)
|
|
201
202
|
|
|
202
203
|
|
|
204
|
+
class ServicePrincipalModel(BaseModel):
|
|
205
|
+
model_config = ConfigDict(
|
|
206
|
+
frozen=True,
|
|
207
|
+
use_enum_values=True,
|
|
208
|
+
)
|
|
209
|
+
client_id: AnyVariable
|
|
210
|
+
client_secret: AnyVariable
|
|
211
|
+
|
|
212
|
+
|
|
203
213
|
class Privilege(str, Enum):
|
|
204
214
|
ALL_PRIVILEGES = "ALL_PRIVILEGES"
|
|
205
215
|
USE_CATALOG = "USE_CATALOG"
|
|
@@ -451,7 +461,7 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
|
|
|
451
461
|
]
|
|
452
462
|
|
|
453
463
|
@model_validator(mode="after")
|
|
454
|
-
def update_space_id(self):
|
|
464
|
+
def update_space_id(self) -> Self:
|
|
455
465
|
self.space_id = value_of(self.space_id)
|
|
456
466
|
return self
|
|
457
467
|
|
|
@@ -530,13 +540,13 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
|
|
|
530
540
|
embedding_source_column: str
|
|
531
541
|
|
|
532
542
|
@model_validator(mode="after")
|
|
533
|
-
def set_default_embedding_model(self):
|
|
543
|
+
def set_default_embedding_model(self) -> Self:
|
|
534
544
|
if not self.embedding_model:
|
|
535
545
|
self.embedding_model = LLMModel(name="databricks-gte-large-en")
|
|
536
546
|
return self
|
|
537
547
|
|
|
538
548
|
@model_validator(mode="after")
|
|
539
|
-
def set_default_primary_key(self):
|
|
549
|
+
def set_default_primary_key(self) -> Self:
|
|
540
550
|
if self.primary_key is None:
|
|
541
551
|
from dao_ai.providers.databricks import DatabricksProvider
|
|
542
552
|
|
|
@@ -557,14 +567,14 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
|
|
|
557
567
|
return self
|
|
558
568
|
|
|
559
569
|
@model_validator(mode="after")
|
|
560
|
-
def set_default_index(self):
|
|
570
|
+
def set_default_index(self) -> Self:
|
|
561
571
|
if self.index is None:
|
|
562
572
|
name: str = f"{self.source_table.name}_index"
|
|
563
573
|
self.index = IndexModel(schema=self.source_table.schema_model, name=name)
|
|
564
574
|
return self
|
|
565
575
|
|
|
566
576
|
@model_validator(mode="after")
|
|
567
|
-
def set_default_endpoint(self):
|
|
577
|
+
def set_default_endpoint(self) -> Self:
|
|
568
578
|
if self.endpoint is None:
|
|
569
579
|
from dao_ai.providers.databricks import (
|
|
570
580
|
DatabricksProvider,
|
|
@@ -719,7 +729,7 @@ class WarehouseModel(BaseModel, IsDatabricksResource):
|
|
|
719
729
|
]
|
|
720
730
|
|
|
721
731
|
@model_validator(mode="after")
|
|
722
|
-
def update_warehouse_id(self):
|
|
732
|
+
def update_warehouse_id(self) -> Self:
|
|
723
733
|
self.warehouse_id = value_of(self.warehouse_id)
|
|
724
734
|
return self
|
|
725
735
|
|
|
@@ -737,12 +747,28 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
|
|
|
737
747
|
- Used for: discovering database instance, getting host DNS, checking instance status
|
|
738
748
|
- Controlled by: DATABRICKS_HOST, DATABRICKS_TOKEN env vars, or SDK default config
|
|
739
749
|
|
|
740
|
-
2. **Database Connection Authentication** (configured via client_id/client_secret OR user):
|
|
750
|
+
2. **Database Connection Authentication** (configured via service_principal, client_id/client_secret, OR user):
|
|
741
751
|
- Used for: connecting to the PostgreSQL database as a specific identity
|
|
752
|
+
- Service Principal: Set service_principal with workspace_host to connect as a service principal
|
|
742
753
|
- OAuth M2M: Set client_id, client_secret, workspace_host to connect as a service principal
|
|
743
754
|
- User Auth: Set user (and optionally password) to connect as a user identity
|
|
744
755
|
|
|
745
|
-
Example
|
|
756
|
+
Example Service Principal Configuration:
|
|
757
|
+
```yaml
|
|
758
|
+
databases:
|
|
759
|
+
my_lakebase:
|
|
760
|
+
name: my-database
|
|
761
|
+
service_principal:
|
|
762
|
+
client_id:
|
|
763
|
+
env: SERVICE_PRINCIPAL_CLIENT_ID
|
|
764
|
+
client_secret:
|
|
765
|
+
scope: my-scope
|
|
766
|
+
secret: sp-client-secret
|
|
767
|
+
workspace_host:
|
|
768
|
+
env: DATABRICKS_HOST
|
|
769
|
+
```
|
|
770
|
+
|
|
771
|
+
Example OAuth M2M Configuration (alternative):
|
|
746
772
|
```yaml
|
|
747
773
|
databases:
|
|
748
774
|
my_lakebase:
|
|
@@ -779,6 +805,7 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
|
|
|
779
805
|
node_count: Optional[int] = None
|
|
780
806
|
user: Optional[AnyVariable] = None
|
|
781
807
|
password: Optional[AnyVariable] = None
|
|
808
|
+
service_principal: Optional[ServicePrincipalModel] = None
|
|
782
809
|
client_id: Optional[AnyVariable] = None
|
|
783
810
|
client_secret: Optional[AnyVariable] = None
|
|
784
811
|
workspace_host: Optional[AnyVariable] = None
|
|
@@ -796,14 +823,24 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
|
|
|
796
823
|
]
|
|
797
824
|
|
|
798
825
|
@model_validator(mode="after")
|
|
799
|
-
def update_instance_name(self):
|
|
826
|
+
def update_instance_name(self) -> Self:
|
|
800
827
|
if self.instance_name is None:
|
|
801
828
|
self.instance_name = self.name
|
|
802
829
|
|
|
803
830
|
return self
|
|
804
831
|
|
|
805
832
|
@model_validator(mode="after")
|
|
806
|
-
def
|
|
833
|
+
def expand_service_principal(self) -> Self:
|
|
834
|
+
"""Expand service_principal into client_id and client_secret if provided."""
|
|
835
|
+
if self.service_principal is not None:
|
|
836
|
+
if self.client_id is None:
|
|
837
|
+
self.client_id = self.service_principal.client_id
|
|
838
|
+
if self.client_secret is None:
|
|
839
|
+
self.client_secret = self.service_principal.client_secret
|
|
840
|
+
return self
|
|
841
|
+
|
|
842
|
+
@model_validator(mode="after")
|
|
843
|
+
def update_user(self) -> Self:
|
|
807
844
|
if self.client_id or self.user:
|
|
808
845
|
return self
|
|
809
846
|
|
|
@@ -816,7 +853,7 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
|
|
|
816
853
|
return self
|
|
817
854
|
|
|
818
855
|
@model_validator(mode="after")
|
|
819
|
-
def update_host(self):
|
|
856
|
+
def update_host(self) -> Self:
|
|
820
857
|
if self.host is not None:
|
|
821
858
|
return self
|
|
822
859
|
|
|
@@ -829,7 +866,7 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
|
|
|
829
866
|
return self
|
|
830
867
|
|
|
831
868
|
@model_validator(mode="after")
|
|
832
|
-
def validate_auth_methods(self):
|
|
869
|
+
def validate_auth_methods(self) -> Self:
|
|
833
870
|
oauth_fields: Sequence[Any] = [
|
|
834
871
|
self.workspace_host,
|
|
835
872
|
self.client_id,
|
|
@@ -849,8 +886,8 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
|
|
|
849
886
|
if not has_oauth and not has_user_auth:
|
|
850
887
|
raise ValueError(
|
|
851
888
|
"At least one authentication method must be provided: "
|
|
852
|
-
"either OAuth credentials (workspace_host, client_id, client_secret) "
|
|
853
|
-
"or user credentials (user, password)."
|
|
889
|
+
"either OAuth credentials (workspace_host, client_id, client_secret), "
|
|
890
|
+
"service_principal with workspace_host, or user credentials (user, password)."
|
|
854
891
|
)
|
|
855
892
|
|
|
856
893
|
return self
|
|
@@ -997,14 +1034,14 @@ class RetrieverModel(BaseModel):
|
|
|
997
1034
|
)
|
|
998
1035
|
|
|
999
1036
|
@model_validator(mode="after")
|
|
1000
|
-
def set_default_columns(self):
|
|
1037
|
+
def set_default_columns(self) -> Self:
|
|
1001
1038
|
if not self.columns:
|
|
1002
1039
|
columns: Sequence[str] = self.vector_store.columns
|
|
1003
1040
|
self.columns = columns
|
|
1004
1041
|
return self
|
|
1005
1042
|
|
|
1006
1043
|
@model_validator(mode="after")
|
|
1007
|
-
def set_default_reranker(self):
|
|
1044
|
+
def set_default_reranker(self) -> Self:
|
|
1008
1045
|
"""Convert bool to ReRankParametersModel with defaults."""
|
|
1009
1046
|
if isinstance(self.rerank, bool) and self.rerank:
|
|
1010
1047
|
self.rerank = RerankParametersModel()
|
|
@@ -1091,7 +1128,7 @@ class FactoryFunctionModel(BaseFunctionModel, HasFullName):
|
|
|
1091
1128
|
return [create_factory_tool(self, **kwargs)]
|
|
1092
1129
|
|
|
1093
1130
|
@model_validator(mode="after")
|
|
1094
|
-
def update_args(self):
|
|
1131
|
+
def update_args(self) -> Self:
|
|
1095
1132
|
for key, value in self.args.items():
|
|
1096
1133
|
self.args[key] = value_of(value)
|
|
1097
1134
|
return self
|
|
@@ -1111,6 +1148,7 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
|
|
|
1111
1148
|
headers: dict[str, AnyVariable] = Field(default_factory=dict)
|
|
1112
1149
|
args: list[str] = Field(default_factory=list)
|
|
1113
1150
|
pat: Optional[AnyVariable] = None
|
|
1151
|
+
service_principal: Optional[ServicePrincipalModel] = None
|
|
1114
1152
|
client_id: Optional[AnyVariable] = None
|
|
1115
1153
|
client_secret: Optional[AnyVariable] = None
|
|
1116
1154
|
workspace_host: Optional[AnyVariable] = None
|
|
@@ -1120,6 +1158,16 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
|
|
|
1120
1158
|
sql: Optional[bool] = None
|
|
1121
1159
|
vector_search: Optional[VectorStoreModel] = None
|
|
1122
1160
|
|
|
1161
|
+
@model_validator(mode="after")
|
|
1162
|
+
def expand_service_principal(self) -> Self:
|
|
1163
|
+
"""Expand service_principal into client_id and client_secret if provided."""
|
|
1164
|
+
if self.service_principal is not None:
|
|
1165
|
+
if self.client_id is None:
|
|
1166
|
+
self.client_id = self.service_principal.client_id
|
|
1167
|
+
if self.client_secret is None:
|
|
1168
|
+
self.client_secret = self.service_principal.client_secret
|
|
1169
|
+
return self
|
|
1170
|
+
|
|
1123
1171
|
@property
|
|
1124
1172
|
def full_name(self) -> str:
|
|
1125
1173
|
return self.name
|
|
@@ -1129,12 +1177,12 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
|
|
|
1129
1177
|
Get the workspace host, either from config or from workspace client.
|
|
1130
1178
|
|
|
1131
1179
|
If connection is provided, uses its workspace client.
|
|
1132
|
-
Otherwise, falls back to
|
|
1180
|
+
Otherwise, falls back to the default Databricks host.
|
|
1133
1181
|
|
|
1134
1182
|
Returns:
|
|
1135
1183
|
str: The workspace host URL without trailing slash
|
|
1136
1184
|
"""
|
|
1137
|
-
from
|
|
1185
|
+
from dao_ai.utils import get_default_databricks_host
|
|
1138
1186
|
|
|
1139
1187
|
# Try to get workspace_host from config
|
|
1140
1188
|
workspace_host: str | None = (
|
|
@@ -1147,9 +1195,13 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
|
|
|
1147
1195
|
if self.connection:
|
|
1148
1196
|
workspace_host = self.connection.workspace_client.config.host
|
|
1149
1197
|
else:
|
|
1150
|
-
|
|
1151
|
-
|
|
1152
|
-
|
|
1198
|
+
workspace_host = get_default_databricks_host()
|
|
1199
|
+
|
|
1200
|
+
if not workspace_host:
|
|
1201
|
+
raise ValueError(
|
|
1202
|
+
"Could not determine workspace host. "
|
|
1203
|
+
"Please set workspace_host in config or DATABRICKS_HOST environment variable."
|
|
1204
|
+
)
|
|
1153
1205
|
|
|
1154
1206
|
# Remove trailing slash
|
|
1155
1207
|
return workspace_host.rstrip("/")
|
|
@@ -1356,7 +1408,7 @@ class CheckpointerModel(BaseModel):
|
|
|
1356
1408
|
database: Optional[DatabaseModel] = None
|
|
1357
1409
|
|
|
1358
1410
|
@model_validator(mode="after")
|
|
1359
|
-
def validate_postgres_requires_database(self):
|
|
1411
|
+
def validate_postgres_requires_database(self) -> Self:
|
|
1360
1412
|
if self.type == StorageType.POSTGRES and not self.database:
|
|
1361
1413
|
raise ValueError("Database must be provided when storage type is POSTGRES")
|
|
1362
1414
|
return self
|
|
@@ -1381,7 +1433,7 @@ class StoreModel(BaseModel):
|
|
|
1381
1433
|
namespace: Optional[str] = None
|
|
1382
1434
|
|
|
1383
1435
|
@model_validator(mode="after")
|
|
1384
|
-
def validate_postgres_requires_database(self):
|
|
1436
|
+
def validate_postgres_requires_database(self) -> Self:
|
|
1385
1437
|
if self.type == StorageType.POSTGRES and not self.database:
|
|
1386
1438
|
raise ValueError("Database must be provided when storage type is POSTGRES")
|
|
1387
1439
|
return self
|
|
@@ -1445,7 +1497,7 @@ class PromptModel(BaseModel, HasFullName):
|
|
|
1445
1497
|
return prompt_version
|
|
1446
1498
|
|
|
1447
1499
|
@model_validator(mode="after")
|
|
1448
|
-
def validate_mutually_exclusive(self):
|
|
1500
|
+
def validate_mutually_exclusive(self) -> Self:
|
|
1449
1501
|
if self.alias and self.version:
|
|
1450
1502
|
raise ValueError("Cannot specify both alias and version")
|
|
1451
1503
|
return self
|
|
@@ -1499,7 +1551,7 @@ class OrchestrationModel(BaseModel):
|
|
|
1499
1551
|
memory: Optional[MemoryModel] = None
|
|
1500
1552
|
|
|
1501
1553
|
@model_validator(mode="after")
|
|
1502
|
-
def validate_mutually_exclusive(self):
|
|
1554
|
+
def validate_mutually_exclusive(self) -> Self:
|
|
1503
1555
|
if self.supervisor is not None and self.swarm is not None:
|
|
1504
1556
|
raise ValueError("Cannot specify both supervisor and swarm")
|
|
1505
1557
|
if self.supervisor is None and self.swarm is None:
|
|
@@ -1626,6 +1678,7 @@ class ChatHistoryModel(BaseModel):
|
|
|
1626
1678
|
class AppModel(BaseModel):
|
|
1627
1679
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1628
1680
|
name: str
|
|
1681
|
+
service_principal: Optional[ServicePrincipalModel] = None
|
|
1629
1682
|
description: Optional[str] = None
|
|
1630
1683
|
log_level: Optional[LogLevel] = "WARNING"
|
|
1631
1684
|
registered_model: RegisteredModelModel
|
|
@@ -1661,14 +1714,42 @@ class AppModel(BaseModel):
|
|
|
1661
1714
|
)
|
|
1662
1715
|
|
|
1663
1716
|
@model_validator(mode="after")
|
|
1664
|
-
def
|
|
1717
|
+
def set_databricks_env_vars(self) -> Self:
|
|
1718
|
+
"""Set Databricks environment variables for Model Serving.
|
|
1719
|
+
|
|
1720
|
+
Sets DATABRICKS_HOST, DATABRICKS_CLIENT_ID, and DATABRICKS_CLIENT_SECRET.
|
|
1721
|
+
Values explicitly provided in environment_vars take precedence.
|
|
1722
|
+
"""
|
|
1723
|
+
from dao_ai.utils import get_default_databricks_host
|
|
1724
|
+
|
|
1725
|
+
# Set DATABRICKS_HOST if not already provided
|
|
1726
|
+
if "DATABRICKS_HOST" not in self.environment_vars:
|
|
1727
|
+
host: str | None = get_default_databricks_host()
|
|
1728
|
+
if host:
|
|
1729
|
+
self.environment_vars["DATABRICKS_HOST"] = host
|
|
1730
|
+
|
|
1731
|
+
# Set service principal credentials if provided
|
|
1732
|
+
if self.service_principal is not None:
|
|
1733
|
+
if "DATABRICKS_CLIENT_ID" not in self.environment_vars:
|
|
1734
|
+
self.environment_vars["DATABRICKS_CLIENT_ID"] = (
|
|
1735
|
+
self.service_principal.client_id
|
|
1736
|
+
)
|
|
1737
|
+
if "DATABRICKS_CLIENT_SECRET" not in self.environment_vars:
|
|
1738
|
+
self.environment_vars["DATABRICKS_CLIENT_SECRET"] = (
|
|
1739
|
+
self.service_principal.client_secret
|
|
1740
|
+
)
|
|
1741
|
+
return self
|
|
1742
|
+
|
|
1743
|
+
@model_validator(mode="after")
|
|
1744
|
+
def validate_agents_not_empty(self) -> Self:
|
|
1665
1745
|
if not self.agents:
|
|
1666
1746
|
raise ValueError("At least one agent must be specified")
|
|
1667
1747
|
return self
|
|
1668
1748
|
|
|
1669
1749
|
@model_validator(mode="after")
|
|
1670
|
-
def
|
|
1750
|
+
def resolve_environment_vars(self) -> Self:
|
|
1671
1751
|
for key, value in self.environment_vars.items():
|
|
1752
|
+
updated_value: str
|
|
1672
1753
|
if isinstance(value, SecretVariableModel):
|
|
1673
1754
|
updated_value = str(value)
|
|
1674
1755
|
else:
|
|
@@ -1678,7 +1759,7 @@ class AppModel(BaseModel):
|
|
|
1678
1759
|
return self
|
|
1679
1760
|
|
|
1680
1761
|
@model_validator(mode="after")
|
|
1681
|
-
def set_default_orchestration(self):
|
|
1762
|
+
def set_default_orchestration(self) -> Self:
|
|
1682
1763
|
if self.orchestration is None:
|
|
1683
1764
|
if len(self.agents) > 1:
|
|
1684
1765
|
default_agent: AgentModel = self.agents[0]
|
|
@@ -1698,14 +1779,14 @@ class AppModel(BaseModel):
|
|
|
1698
1779
|
return self
|
|
1699
1780
|
|
|
1700
1781
|
@model_validator(mode="after")
|
|
1701
|
-
def set_default_endpoint_name(self):
|
|
1782
|
+
def set_default_endpoint_name(self) -> Self:
|
|
1702
1783
|
if self.endpoint_name is None:
|
|
1703
1784
|
self.endpoint_name = self.name
|
|
1704
1785
|
return self
|
|
1705
1786
|
|
|
1706
1787
|
@model_validator(mode="after")
|
|
1707
|
-
def set_default_agent(self):
|
|
1708
|
-
default_agent_name = self.agents[0].name
|
|
1788
|
+
def set_default_agent(self) -> Self:
|
|
1789
|
+
default_agent_name: str = self.agents[0].name
|
|
1709
1790
|
|
|
1710
1791
|
if self.orchestration.swarm and not self.orchestration.swarm.default_agent:
|
|
1711
1792
|
self.orchestration.swarm.default_agent = default_agent_name
|
|
@@ -1713,7 +1794,7 @@ class AppModel(BaseModel):
|
|
|
1713
1794
|
return self
|
|
1714
1795
|
|
|
1715
1796
|
@model_validator(mode="after")
|
|
1716
|
-
def add_code_paths_to_sys_path(self):
|
|
1797
|
+
def add_code_paths_to_sys_path(self) -> Self:
|
|
1717
1798
|
for code_path in self.code_paths:
|
|
1718
1799
|
parent_path: str = str(Path(code_path).parent)
|
|
1719
1800
|
if parent_path not in sys.path:
|
|
@@ -1746,7 +1827,7 @@ class EvaluationDatasetExpectationsModel(BaseModel):
|
|
|
1746
1827
|
expected_facts: Optional[list[str]] = None
|
|
1747
1828
|
|
|
1748
1829
|
@model_validator(mode="after")
|
|
1749
|
-
def validate_mutually_exclusive(self):
|
|
1830
|
+
def validate_mutually_exclusive(self) -> Self:
|
|
1750
1831
|
if self.expected_response is not None and self.expected_facts is not None:
|
|
1751
1832
|
raise ValueError("Cannot specify both expected_response and expected_facts")
|
|
1752
1833
|
return self
|
|
@@ -1854,7 +1935,7 @@ class PromptOptimizationModel(BaseModel):
|
|
|
1854
1935
|
return optimized_prompt
|
|
1855
1936
|
|
|
1856
1937
|
@model_validator(mode="after")
|
|
1857
|
-
def set_defaults(self):
|
|
1938
|
+
def set_defaults(self) -> Self:
|
|
1858
1939
|
# If no prompt is specified, try to use the agent's prompt
|
|
1859
1940
|
if self.prompt is None:
|
|
1860
1941
|
if isinstance(self.agent.prompt, PromptModel):
|
|
@@ -1976,6 +2057,7 @@ class ResourcesModel(BaseModel):
|
|
|
1976
2057
|
class AppConfig(BaseModel):
|
|
1977
2058
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1978
2059
|
variables: dict[str, AnyVariable] = Field(default_factory=dict)
|
|
2060
|
+
service_principals: dict[str, ServicePrincipalModel] = Field(default_factory=dict)
|
|
1979
2061
|
schemas: dict[str, SchemaModel] = Field(default_factory=dict)
|
|
1980
2062
|
resources: Optional[ResourcesModel] = None
|
|
1981
2063
|
retrievers: dict[str, RetrieverModel] = Field(default_factory=dict)
|
dao_ai/tools/core.py
CHANGED
|
@@ -35,7 +35,7 @@ def create_tools(tool_models: Sequence[ToolModel]) -> Sequence[RunnableLike]:
|
|
|
35
35
|
if name in tools:
|
|
36
36
|
logger.warning(f"Tools already registered for: {name}, skipping creation.")
|
|
37
37
|
continue
|
|
38
|
-
registered_tools: Sequence[RunnableLike] = tool_registry.get(name)
|
|
38
|
+
registered_tools: Sequence[RunnableLike] | None = tool_registry.get(name)
|
|
39
39
|
if registered_tools is None:
|
|
40
40
|
logger.debug(f"Creating tools for: {name}...")
|
|
41
41
|
function: AnyTool = tool_config.function
|
dao_ai/tools/unity_catalog.py
CHANGED
|
@@ -265,23 +265,52 @@ def with_partial_args(
|
|
|
265
265
|
|
|
266
266
|
Args:
|
|
267
267
|
tool: ToolModel containing the Unity Catalog function configuration
|
|
268
|
-
partial_args: Dictionary of arguments to pre-fill in the tool
|
|
268
|
+
partial_args: Dictionary of arguments to pre-fill in the tool.
|
|
269
|
+
Supports:
|
|
270
|
+
- client_id, client_secret: OAuth credentials directly
|
|
271
|
+
- service_principal: ServicePrincipalModel with client_id and client_secret
|
|
272
|
+
- host or workspace_host: Databricks workspace host
|
|
269
273
|
|
|
270
274
|
Returns:
|
|
271
275
|
StructuredTool: A LangChain tool with partial arguments pre-filled
|
|
272
276
|
"""
|
|
273
277
|
from unitycatalog.ai.langchain.toolkit import generate_function_input_params_schema
|
|
274
278
|
|
|
279
|
+
from dao_ai.config import ServicePrincipalModel
|
|
280
|
+
|
|
275
281
|
logger.debug(f"with_partial_args: {tool}")
|
|
276
282
|
|
|
277
283
|
# Convert dict-based variables to CompositeVariableModel and resolve their values
|
|
278
|
-
resolved_args = {}
|
|
284
|
+
resolved_args: dict[str, Any] = {}
|
|
279
285
|
for k, v in partial_args.items():
|
|
280
286
|
if isinstance(v, dict):
|
|
281
287
|
resolved_args[k] = value_of(CompositeVariableModel(**v))
|
|
282
288
|
else:
|
|
283
289
|
resolved_args[k] = value_of(v)
|
|
284
290
|
|
|
291
|
+
# Handle service_principal - expand into client_id and client_secret
|
|
292
|
+
if "service_principal" in resolved_args:
|
|
293
|
+
sp = resolved_args.pop("service_principal")
|
|
294
|
+
if isinstance(sp, dict):
|
|
295
|
+
sp = ServicePrincipalModel(**sp)
|
|
296
|
+
if isinstance(sp, ServicePrincipalModel):
|
|
297
|
+
if "client_id" not in resolved_args:
|
|
298
|
+
resolved_args["client_id"] = value_of(sp.client_id)
|
|
299
|
+
if "client_secret" not in resolved_args:
|
|
300
|
+
resolved_args["client_secret"] = value_of(sp.client_secret)
|
|
301
|
+
|
|
302
|
+
# Normalize host/workspace_host - accept either key
|
|
303
|
+
if "workspace_host" in resolved_args and "host" not in resolved_args:
|
|
304
|
+
resolved_args["host"] = resolved_args.pop("workspace_host")
|
|
305
|
+
|
|
306
|
+
# Default host from WorkspaceClient if not provided
|
|
307
|
+
if "host" not in resolved_args:
|
|
308
|
+
from dao_ai.utils import get_default_databricks_host
|
|
309
|
+
|
|
310
|
+
host: str | None = get_default_databricks_host()
|
|
311
|
+
if host:
|
|
312
|
+
resolved_args["host"] = host
|
|
313
|
+
|
|
285
314
|
logger.debug(f"Resolved partial args: {resolved_args.keys()}")
|
|
286
315
|
|
|
287
316
|
if isinstance(tool, dict):
|
dao_ai/utils.py
CHANGED
|
@@ -38,6 +38,32 @@ def normalize_name(name: str) -> str:
|
|
|
38
38
|
return normalized.strip("_")
|
|
39
39
|
|
|
40
40
|
|
|
41
|
+
def get_default_databricks_host() -> str | None:
|
|
42
|
+
"""Get the default Databricks workspace host.
|
|
43
|
+
|
|
44
|
+
Attempts to get the host from:
|
|
45
|
+
1. DATABRICKS_HOST environment variable
|
|
46
|
+
2. WorkspaceClient ambient authentication (e.g., from ~/.databrickscfg)
|
|
47
|
+
|
|
48
|
+
Returns:
|
|
49
|
+
The Databricks workspace host URL, or None if not available.
|
|
50
|
+
"""
|
|
51
|
+
# Try environment variable first
|
|
52
|
+
host: str | None = os.environ.get("DATABRICKS_HOST")
|
|
53
|
+
if host:
|
|
54
|
+
return host
|
|
55
|
+
|
|
56
|
+
# Fall back to WorkspaceClient
|
|
57
|
+
try:
|
|
58
|
+
from databricks.sdk import WorkspaceClient
|
|
59
|
+
|
|
60
|
+
w: WorkspaceClient = WorkspaceClient()
|
|
61
|
+
return w.config.host
|
|
62
|
+
except Exception:
|
|
63
|
+
logger.debug("Could not get default Databricks host from WorkspaceClient")
|
|
64
|
+
return None
|
|
65
|
+
|
|
66
|
+
|
|
41
67
|
def dao_ai_version() -> str:
|
|
42
68
|
"""
|
|
43
69
|
Get the dao-ai package version, with fallback for source installations.
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: dao-ai
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.33
|
|
4
4
|
Summary: DAO AI: A modular, multi-agent orchestration framework for complex AI workflows. Supports agent handoff, tool integration, and dynamic configuration via YAML.
|
|
5
5
|
Project-URL: Homepage, https://github.com/natefleming/dao-ai
|
|
6
6
|
Project-URL: Documentation, https://natefleming.github.io/dao-ai
|
|
@@ -25,7 +25,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
|
25
25
|
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
26
26
|
Classifier: Topic :: System :: Distributed Computing
|
|
27
27
|
Requires-Python: >=3.11
|
|
28
|
-
Requires-Dist: databricks-agents>=1.
|
|
28
|
+
Requires-Dist: databricks-agents>=1.8.2
|
|
29
29
|
Requires-Dist: databricks-langchain>=0.11.0
|
|
30
30
|
Requires-Dist: databricks-mcp>=0.3.0
|
|
31
31
|
Requires-Dist: databricks-sdk[openai]>=0.67.0
|
|
@@ -33,21 +33,21 @@ Requires-Dist: ddgs>=9.9.3
|
|
|
33
33
|
Requires-Dist: flashrank>=0.2.8
|
|
34
34
|
Requires-Dist: gepa>=0.0.17
|
|
35
35
|
Requires-Dist: grandalf>=0.8
|
|
36
|
-
Requires-Dist: langchain-mcp-adapters>=0.1
|
|
36
|
+
Requires-Dist: langchain-mcp-adapters>=0.2.1
|
|
37
37
|
Requires-Dist: langchain-tavily>=0.2.11
|
|
38
38
|
Requires-Dist: langchain>=1.1.3
|
|
39
|
-
Requires-Dist: langgraph-checkpoint-postgres>=
|
|
39
|
+
Requires-Dist: langgraph-checkpoint-postgres>=3.0.2
|
|
40
40
|
Requires-Dist: langgraph-supervisor>=0.0.31
|
|
41
41
|
Requires-Dist: langgraph-swarm>=0.1.0
|
|
42
42
|
Requires-Dist: langgraph>=1.0.4
|
|
43
|
-
Requires-Dist: langmem>=0.0.
|
|
43
|
+
Requires-Dist: langmem>=0.0.30
|
|
44
44
|
Requires-Dist: loguru>=0.7.3
|
|
45
|
-
Requires-Dist: mcp>=1.
|
|
45
|
+
Requires-Dist: mcp>=1.23.3
|
|
46
46
|
Requires-Dist: mlflow>=3.7.0
|
|
47
47
|
Requires-Dist: nest-asyncio>=1.6.0
|
|
48
48
|
Requires-Dist: openevals>=0.0.19
|
|
49
49
|
Requires-Dist: openpyxl>=3.1.5
|
|
50
|
-
Requires-Dist: psycopg[binary,pool]>=3.2
|
|
50
|
+
Requires-Dist: psycopg[binary,pool]>=3.3.2
|
|
51
51
|
Requires-Dist: pydantic>=2.12.0
|
|
52
52
|
Requires-Dist: python-dotenv>=1.1.0
|
|
53
53
|
Requires-Dist: pyyaml>=6.0.2
|
|
@@ -3,7 +3,7 @@ dao_ai/agent_as_code.py,sha256=sviZQV7ZPxE5zkZ9jAbfegI681nra5i8yYxw05e3X7U,552
|
|
|
3
3
|
dao_ai/catalog.py,sha256=sPZpHTD3lPx4EZUtIWeQV7VQM89WJ6YH__wluk1v2lE,4947
|
|
4
4
|
dao_ai/chat_models.py,sha256=uhwwOTeLyHWqoTTgHrs4n5iSyTwe4EQcLKnh3jRxPWI,8626
|
|
5
5
|
dao_ai/cli.py,sha256=gq-nsapWxDA1M6Jua3vajBvIwf0Oa6YLcB58lEtMKUo,22503
|
|
6
|
-
dao_ai/config.py,sha256=
|
|
6
|
+
dao_ai/config.py,sha256=Uj0FgOhjnYp0qEmY44mCnp3Ijafg-381FNXt8R_QuWw,78513
|
|
7
7
|
dao_ai/graph.py,sha256=9kjJx0oFZKq5J9-Kpri4-0VCJILHYdYyhqQnj0_noxQ,8913
|
|
8
8
|
dao_ai/guardrails.py,sha256=4TKArDONRy8RwHzOT1plZ1rhy3x9GF_aeGpPCRl6wYA,4016
|
|
9
9
|
dao_ai/messages.py,sha256=xl_3-WcFqZKCFCiov8sZOPljTdM3gX3fCHhxq-xFg2U,7005
|
|
@@ -12,7 +12,7 @@ dao_ai/nodes.py,sha256=iQ_5vL6mt1UcRnhwgz-l1D8Ww4CMQrSMVnP_Lu7fFjU,8781
|
|
|
12
12
|
dao_ai/prompts.py,sha256=iA2Iaky7yzjwWT5cxg0cUIgwo1z1UVQua__8WPnvV6g,1633
|
|
13
13
|
dao_ai/state.py,sha256=_lF9krAYYjvFDMUwZzVKOn0ZnXKcOrbjWKdre0C5B54,1137
|
|
14
14
|
dao_ai/types.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
15
|
-
dao_ai/utils.py,sha256=
|
|
15
|
+
dao_ai/utils.py,sha256=oIPmz02kZ3LMntbqxUajFXh4nswOhbvEjOTi4e5_cvI,8500
|
|
16
16
|
dao_ai/vector_search.py,sha256=jlaFS_iizJ55wblgzZmswMM3UOL-qOp2BGJc0JqXYSg,2839
|
|
17
17
|
dao_ai/hooks/__init__.py,sha256=LlHGIuiZt6vGW8K5AQo1XJEkBP5vDVtMhq0IdjcLrD4,417
|
|
18
18
|
dao_ai/hooks/core.py,sha256=ZShHctUSoauhBgdf1cecy9-D7J6-sGn-pKjuRMumW5U,6663
|
|
@@ -25,17 +25,17 @@ dao_ai/providers/base.py,sha256=-fjKypCOk28h6vioPfMj9YZSw_3Kcbi2nMuAyY7vX9k,1383
|
|
|
25
25
|
dao_ai/providers/databricks.py,sha256=rPBMdGcJvdGBRK9FZeBxkLfcTpXyxU1cs14YllyZKbY,67857
|
|
26
26
|
dao_ai/tools/__init__.py,sha256=G5-5Yi6zpQOH53b5IzLdtsC6g0Ep6leI5GxgxOmgw7Q,1203
|
|
27
27
|
dao_ai/tools/agent.py,sha256=WbQnyziiT12TLMrA7xK0VuOU029tdmUBXbUl-R1VZ0Q,1886
|
|
28
|
-
dao_ai/tools/core.py,sha256=
|
|
28
|
+
dao_ai/tools/core.py,sha256=kN77fWOzVY7qOs4NiW72cUxCsSTC0DnPp73s6VJEZOQ,1991
|
|
29
29
|
dao_ai/tools/genie.py,sha256=BPM_1Sk5bf7QSCFPPboWWkZKYwBwDwbGhMVp5-QDd10,5956
|
|
30
30
|
dao_ai/tools/human_in_the_loop.py,sha256=yk35MO9eNETnYFH-sqlgR-G24TrEgXpJlnZUustsLkI,3681
|
|
31
31
|
dao_ai/tools/mcp.py,sha256=5aQoRtx2z4xm6zgRslc78rSfEQe-mfhqov2NsiybYfc,8416
|
|
32
32
|
dao_ai/tools/python.py,sha256=XcQiTMshZyLUTVR5peB3vqsoUoAAy8gol9_pcrhddfI,1831
|
|
33
33
|
dao_ai/tools/slack.py,sha256=SCvyVcD9Pv_XXPXePE_fSU1Pd8VLTEkKDLvoGTZWy2Y,4775
|
|
34
34
|
dao_ai/tools/time.py,sha256=Y-23qdnNHzwjvnfkWvYsE7PoWS1hfeKy44tA7sCnNac,8759
|
|
35
|
-
dao_ai/tools/unity_catalog.py,sha256=
|
|
35
|
+
dao_ai/tools/unity_catalog.py,sha256=K9t8M4spsbxbecWmV5yEZy16s_AG7AfaoxT-7IDW43I,14438
|
|
36
36
|
dao_ai/tools/vector_search.py,sha256=3cdiUaFpox25GSRNec7FKceY3DuLp7dLVH8FRA0BgeY,12624
|
|
37
|
-
dao_ai-0.0.
|
|
38
|
-
dao_ai-0.0.
|
|
39
|
-
dao_ai-0.0.
|
|
40
|
-
dao_ai-0.0.
|
|
41
|
-
dao_ai-0.0.
|
|
37
|
+
dao_ai-0.0.33.dist-info/METADATA,sha256=aa4BvkiG1dEvLorpgADosf1LCKRVBg-n8LtReVYJNxc,42761
|
|
38
|
+
dao_ai-0.0.33.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
39
|
+
dao_ai-0.0.33.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
|
|
40
|
+
dao_ai-0.0.33.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
|
|
41
|
+
dao_ai-0.0.33.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|