dao-ai 0.0.31__py3-none-any.whl → 0.0.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/config.py +162 -34
- dao_ai/prompts.py +1 -1
- dao_ai/providers/databricks.py +204 -146
- dao_ai/tools/core.py +1 -1
- dao_ai/tools/genie.py +26 -262
- dao_ai/tools/unity_catalog.py +31 -2
- dao_ai/tools/vector_search.py +4 -2
- dao_ai/utils.py +60 -7
- {dao_ai-0.0.31.dist-info → dao_ai-0.0.33.dist-info}/METADATA +15 -15
- {dao_ai-0.0.31.dist-info → dao_ai-0.0.33.dist-info}/RECORD +13 -13
- {dao_ai-0.0.31.dist-info → dao_ai-0.0.33.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.31.dist-info → dao_ai-0.0.33.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.31.dist-info → dao_ai-0.0.33.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py
CHANGED
|
@@ -12,6 +12,7 @@ from typing import (
|
|
|
12
12
|
Iterator,
|
|
13
13
|
Literal,
|
|
14
14
|
Optional,
|
|
15
|
+
Self,
|
|
15
16
|
Sequence,
|
|
16
17
|
TypeAlias,
|
|
17
18
|
Union,
|
|
@@ -200,6 +201,15 @@ AnyVariable: TypeAlias = (
|
|
|
200
201
|
)
|
|
201
202
|
|
|
202
203
|
|
|
204
|
+
class ServicePrincipalModel(BaseModel):
|
|
205
|
+
model_config = ConfigDict(
|
|
206
|
+
frozen=True,
|
|
207
|
+
use_enum_values=True,
|
|
208
|
+
)
|
|
209
|
+
client_id: AnyVariable
|
|
210
|
+
client_secret: AnyVariable
|
|
211
|
+
|
|
212
|
+
|
|
203
213
|
class Privilege(str, Enum):
|
|
204
214
|
ALL_PRIVILEGES = "ALL_PRIVILEGES"
|
|
205
215
|
USE_CATALOG = "USE_CATALOG"
|
|
@@ -451,7 +461,7 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
|
|
|
451
461
|
]
|
|
452
462
|
|
|
453
463
|
@model_validator(mode="after")
|
|
454
|
-
def update_space_id(self):
|
|
464
|
+
def update_space_id(self) -> Self:
|
|
455
465
|
self.space_id = value_of(self.space_id)
|
|
456
466
|
return self
|
|
457
467
|
|
|
@@ -530,13 +540,13 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
|
|
|
530
540
|
embedding_source_column: str
|
|
531
541
|
|
|
532
542
|
@model_validator(mode="after")
|
|
533
|
-
def set_default_embedding_model(self):
|
|
543
|
+
def set_default_embedding_model(self) -> Self:
|
|
534
544
|
if not self.embedding_model:
|
|
535
545
|
self.embedding_model = LLMModel(name="databricks-gte-large-en")
|
|
536
546
|
return self
|
|
537
547
|
|
|
538
548
|
@model_validator(mode="after")
|
|
539
|
-
def set_default_primary_key(self):
|
|
549
|
+
def set_default_primary_key(self) -> Self:
|
|
540
550
|
if self.primary_key is None:
|
|
541
551
|
from dao_ai.providers.databricks import DatabricksProvider
|
|
542
552
|
|
|
@@ -557,14 +567,14 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
|
|
|
557
567
|
return self
|
|
558
568
|
|
|
559
569
|
@model_validator(mode="after")
|
|
560
|
-
def set_default_index(self):
|
|
570
|
+
def set_default_index(self) -> Self:
|
|
561
571
|
if self.index is None:
|
|
562
572
|
name: str = f"{self.source_table.name}_index"
|
|
563
573
|
self.index = IndexModel(schema=self.source_table.schema_model, name=name)
|
|
564
574
|
return self
|
|
565
575
|
|
|
566
576
|
@model_validator(mode="after")
|
|
567
|
-
def set_default_endpoint(self):
|
|
577
|
+
def set_default_endpoint(self) -> Self:
|
|
568
578
|
if self.endpoint is None:
|
|
569
579
|
from dao_ai.providers.databricks import (
|
|
570
580
|
DatabricksProvider,
|
|
@@ -719,12 +729,68 @@ class WarehouseModel(BaseModel, IsDatabricksResource):
|
|
|
719
729
|
]
|
|
720
730
|
|
|
721
731
|
@model_validator(mode="after")
|
|
722
|
-
def update_warehouse_id(self):
|
|
732
|
+
def update_warehouse_id(self) -> Self:
|
|
723
733
|
self.warehouse_id = value_of(self.warehouse_id)
|
|
724
734
|
return self
|
|
725
735
|
|
|
726
736
|
|
|
727
737
|
class DatabaseModel(BaseModel, IsDatabricksResource):
|
|
738
|
+
"""
|
|
739
|
+
Configuration for a Databricks Lakebase (PostgreSQL) database instance.
|
|
740
|
+
|
|
741
|
+
Authentication Model:
|
|
742
|
+
--------------------
|
|
743
|
+
This model uses TWO separate authentication contexts:
|
|
744
|
+
|
|
745
|
+
1. **Workspace API Authentication** (inherited from IsDatabricksResource):
|
|
746
|
+
- Uses ambient/default authentication (environment variables, notebook context, app service principal)
|
|
747
|
+
- Used for: discovering database instance, getting host DNS, checking instance status
|
|
748
|
+
- Controlled by: DATABRICKS_HOST, DATABRICKS_TOKEN env vars, or SDK default config
|
|
749
|
+
|
|
750
|
+
2. **Database Connection Authentication** (configured via service_principal, client_id/client_secret, OR user):
|
|
751
|
+
- Used for: connecting to the PostgreSQL database as a specific identity
|
|
752
|
+
- Service Principal: Set service_principal with workspace_host to connect as a service principal
|
|
753
|
+
- OAuth M2M: Set client_id, client_secret, workspace_host to connect as a service principal
|
|
754
|
+
- User Auth: Set user (and optionally password) to connect as a user identity
|
|
755
|
+
|
|
756
|
+
Example Service Principal Configuration:
|
|
757
|
+
```yaml
|
|
758
|
+
databases:
|
|
759
|
+
my_lakebase:
|
|
760
|
+
name: my-database
|
|
761
|
+
service_principal:
|
|
762
|
+
client_id:
|
|
763
|
+
env: SERVICE_PRINCIPAL_CLIENT_ID
|
|
764
|
+
client_secret:
|
|
765
|
+
scope: my-scope
|
|
766
|
+
secret: sp-client-secret
|
|
767
|
+
workspace_host:
|
|
768
|
+
env: DATABRICKS_HOST
|
|
769
|
+
```
|
|
770
|
+
|
|
771
|
+
Example OAuth M2M Configuration (alternative):
|
|
772
|
+
```yaml
|
|
773
|
+
databases:
|
|
774
|
+
my_lakebase:
|
|
775
|
+
name: my-database
|
|
776
|
+
client_id:
|
|
777
|
+
env: SERVICE_PRINCIPAL_CLIENT_ID
|
|
778
|
+
client_secret:
|
|
779
|
+
scope: my-scope
|
|
780
|
+
secret: sp-client-secret
|
|
781
|
+
workspace_host:
|
|
782
|
+
env: DATABRICKS_HOST
|
|
783
|
+
```
|
|
784
|
+
|
|
785
|
+
Example User Configuration:
|
|
786
|
+
```yaml
|
|
787
|
+
databases:
|
|
788
|
+
my_lakebase:
|
|
789
|
+
name: my-database
|
|
790
|
+
user: my-user@databricks.com
|
|
791
|
+
```
|
|
792
|
+
"""
|
|
793
|
+
|
|
728
794
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
729
795
|
name: str
|
|
730
796
|
instance_name: Optional[str] = None
|
|
@@ -739,6 +805,7 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
|
|
|
739
805
|
node_count: Optional[int] = None
|
|
740
806
|
user: Optional[AnyVariable] = None
|
|
741
807
|
password: Optional[AnyVariable] = None
|
|
808
|
+
service_principal: Optional[ServicePrincipalModel] = None
|
|
742
809
|
client_id: Optional[AnyVariable] = None
|
|
743
810
|
client_secret: Optional[AnyVariable] = None
|
|
744
811
|
workspace_host: Optional[AnyVariable] = None
|
|
@@ -756,14 +823,24 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
|
|
|
756
823
|
]
|
|
757
824
|
|
|
758
825
|
@model_validator(mode="after")
|
|
759
|
-
def update_instance_name(self):
|
|
826
|
+
def update_instance_name(self) -> Self:
|
|
760
827
|
if self.instance_name is None:
|
|
761
828
|
self.instance_name = self.name
|
|
762
829
|
|
|
763
830
|
return self
|
|
764
831
|
|
|
765
832
|
@model_validator(mode="after")
|
|
766
|
-
def
|
|
833
|
+
def expand_service_principal(self) -> Self:
|
|
834
|
+
"""Expand service_principal into client_id and client_secret if provided."""
|
|
835
|
+
if self.service_principal is not None:
|
|
836
|
+
if self.client_id is None:
|
|
837
|
+
self.client_id = self.service_principal.client_id
|
|
838
|
+
if self.client_secret is None:
|
|
839
|
+
self.client_secret = self.service_principal.client_secret
|
|
840
|
+
return self
|
|
841
|
+
|
|
842
|
+
@model_validator(mode="after")
|
|
843
|
+
def update_user(self) -> Self:
|
|
767
844
|
if self.client_id or self.user:
|
|
768
845
|
return self
|
|
769
846
|
|
|
@@ -776,7 +853,7 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
|
|
|
776
853
|
return self
|
|
777
854
|
|
|
778
855
|
@model_validator(mode="after")
|
|
779
|
-
def update_host(self):
|
|
856
|
+
def update_host(self) -> Self:
|
|
780
857
|
if self.host is not None:
|
|
781
858
|
return self
|
|
782
859
|
|
|
@@ -789,7 +866,7 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
|
|
|
789
866
|
return self
|
|
790
867
|
|
|
791
868
|
@model_validator(mode="after")
|
|
792
|
-
def validate_auth_methods(self):
|
|
869
|
+
def validate_auth_methods(self) -> Self:
|
|
793
870
|
oauth_fields: Sequence[Any] = [
|
|
794
871
|
self.workspace_host,
|
|
795
872
|
self.client_id,
|
|
@@ -809,8 +886,8 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
|
|
|
809
886
|
if not has_oauth and not has_user_auth:
|
|
810
887
|
raise ValueError(
|
|
811
888
|
"At least one authentication method must be provided: "
|
|
812
|
-
"either OAuth credentials (workspace_host, client_id, client_secret) "
|
|
813
|
-
"or user credentials (user, password)."
|
|
889
|
+
"either OAuth credentials (workspace_host, client_id, client_secret), "
|
|
890
|
+
"service_principal with workspace_host, or user credentials (user, password)."
|
|
814
891
|
)
|
|
815
892
|
|
|
816
893
|
return self
|
|
@@ -883,7 +960,7 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
|
|
|
883
960
|
def create(self, w: WorkspaceClient | None = None) -> None:
|
|
884
961
|
from dao_ai.providers.databricks import DatabricksProvider
|
|
885
962
|
|
|
886
|
-
provider: DatabricksProvider = DatabricksProvider()
|
|
963
|
+
provider: DatabricksProvider = DatabricksProvider(w=w)
|
|
887
964
|
provider.create_lakebase(self)
|
|
888
965
|
provider.create_lakebase_instance_role(self)
|
|
889
966
|
|
|
@@ -957,14 +1034,14 @@ class RetrieverModel(BaseModel):
|
|
|
957
1034
|
)
|
|
958
1035
|
|
|
959
1036
|
@model_validator(mode="after")
|
|
960
|
-
def set_default_columns(self):
|
|
1037
|
+
def set_default_columns(self) -> Self:
|
|
961
1038
|
if not self.columns:
|
|
962
1039
|
columns: Sequence[str] = self.vector_store.columns
|
|
963
1040
|
self.columns = columns
|
|
964
1041
|
return self
|
|
965
1042
|
|
|
966
1043
|
@model_validator(mode="after")
|
|
967
|
-
def set_default_reranker(self):
|
|
1044
|
+
def set_default_reranker(self) -> Self:
|
|
968
1045
|
"""Convert bool to ReRankParametersModel with defaults."""
|
|
969
1046
|
if isinstance(self.rerank, bool) and self.rerank:
|
|
970
1047
|
self.rerank = RerankParametersModel()
|
|
@@ -1051,7 +1128,7 @@ class FactoryFunctionModel(BaseFunctionModel, HasFullName):
|
|
|
1051
1128
|
return [create_factory_tool(self, **kwargs)]
|
|
1052
1129
|
|
|
1053
1130
|
@model_validator(mode="after")
|
|
1054
|
-
def update_args(self):
|
|
1131
|
+
def update_args(self) -> Self:
|
|
1055
1132
|
for key, value in self.args.items():
|
|
1056
1133
|
self.args[key] = value_of(value)
|
|
1057
1134
|
return self
|
|
@@ -1071,6 +1148,7 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
|
|
|
1071
1148
|
headers: dict[str, AnyVariable] = Field(default_factory=dict)
|
|
1072
1149
|
args: list[str] = Field(default_factory=list)
|
|
1073
1150
|
pat: Optional[AnyVariable] = None
|
|
1151
|
+
service_principal: Optional[ServicePrincipalModel] = None
|
|
1074
1152
|
client_id: Optional[AnyVariable] = None
|
|
1075
1153
|
client_secret: Optional[AnyVariable] = None
|
|
1076
1154
|
workspace_host: Optional[AnyVariable] = None
|
|
@@ -1080,6 +1158,16 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
|
|
|
1080
1158
|
sql: Optional[bool] = None
|
|
1081
1159
|
vector_search: Optional[VectorStoreModel] = None
|
|
1082
1160
|
|
|
1161
|
+
@model_validator(mode="after")
|
|
1162
|
+
def expand_service_principal(self) -> Self:
|
|
1163
|
+
"""Expand service_principal into client_id and client_secret if provided."""
|
|
1164
|
+
if self.service_principal is not None:
|
|
1165
|
+
if self.client_id is None:
|
|
1166
|
+
self.client_id = self.service_principal.client_id
|
|
1167
|
+
if self.client_secret is None:
|
|
1168
|
+
self.client_secret = self.service_principal.client_secret
|
|
1169
|
+
return self
|
|
1170
|
+
|
|
1083
1171
|
@property
|
|
1084
1172
|
def full_name(self) -> str:
|
|
1085
1173
|
return self.name
|
|
@@ -1089,12 +1177,12 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
|
|
|
1089
1177
|
Get the workspace host, either from config or from workspace client.
|
|
1090
1178
|
|
|
1091
1179
|
If connection is provided, uses its workspace client.
|
|
1092
|
-
Otherwise, falls back to
|
|
1180
|
+
Otherwise, falls back to the default Databricks host.
|
|
1093
1181
|
|
|
1094
1182
|
Returns:
|
|
1095
1183
|
str: The workspace host URL without trailing slash
|
|
1096
1184
|
"""
|
|
1097
|
-
from
|
|
1185
|
+
from dao_ai.utils import get_default_databricks_host
|
|
1098
1186
|
|
|
1099
1187
|
# Try to get workspace_host from config
|
|
1100
1188
|
workspace_host: str | None = (
|
|
@@ -1107,9 +1195,13 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
|
|
|
1107
1195
|
if self.connection:
|
|
1108
1196
|
workspace_host = self.connection.workspace_client.config.host
|
|
1109
1197
|
else:
|
|
1110
|
-
|
|
1111
|
-
|
|
1112
|
-
|
|
1198
|
+
workspace_host = get_default_databricks_host()
|
|
1199
|
+
|
|
1200
|
+
if not workspace_host:
|
|
1201
|
+
raise ValueError(
|
|
1202
|
+
"Could not determine workspace host. "
|
|
1203
|
+
"Please set workspace_host in config or DATABRICKS_HOST environment variable."
|
|
1204
|
+
)
|
|
1113
1205
|
|
|
1114
1206
|
# Remove trailing slash
|
|
1115
1207
|
return workspace_host.rstrip("/")
|
|
@@ -1316,7 +1408,7 @@ class CheckpointerModel(BaseModel):
|
|
|
1316
1408
|
database: Optional[DatabaseModel] = None
|
|
1317
1409
|
|
|
1318
1410
|
@model_validator(mode="after")
|
|
1319
|
-
def validate_postgres_requires_database(self):
|
|
1411
|
+
def validate_postgres_requires_database(self) -> Self:
|
|
1320
1412
|
if self.type == StorageType.POSTGRES and not self.database:
|
|
1321
1413
|
raise ValueError("Database must be provided when storage type is POSTGRES")
|
|
1322
1414
|
return self
|
|
@@ -1341,7 +1433,7 @@ class StoreModel(BaseModel):
|
|
|
1341
1433
|
namespace: Optional[str] = None
|
|
1342
1434
|
|
|
1343
1435
|
@model_validator(mode="after")
|
|
1344
|
-
def validate_postgres_requires_database(self):
|
|
1436
|
+
def validate_postgres_requires_database(self) -> Self:
|
|
1345
1437
|
if self.type == StorageType.POSTGRES and not self.database:
|
|
1346
1438
|
raise ValueError("Database must be provided when storage type is POSTGRES")
|
|
1347
1439
|
return self
|
|
@@ -1405,7 +1497,7 @@ class PromptModel(BaseModel, HasFullName):
|
|
|
1405
1497
|
return prompt_version
|
|
1406
1498
|
|
|
1407
1499
|
@model_validator(mode="after")
|
|
1408
|
-
def validate_mutually_exclusive(self):
|
|
1500
|
+
def validate_mutually_exclusive(self) -> Self:
|
|
1409
1501
|
if self.alias and self.version:
|
|
1410
1502
|
raise ValueError("Cannot specify both alias and version")
|
|
1411
1503
|
return self
|
|
@@ -1459,7 +1551,7 @@ class OrchestrationModel(BaseModel):
|
|
|
1459
1551
|
memory: Optional[MemoryModel] = None
|
|
1460
1552
|
|
|
1461
1553
|
@model_validator(mode="after")
|
|
1462
|
-
def validate_mutually_exclusive(self):
|
|
1554
|
+
def validate_mutually_exclusive(self) -> Self:
|
|
1463
1555
|
if self.supervisor is not None and self.swarm is not None:
|
|
1464
1556
|
raise ValueError("Cannot specify both supervisor and swarm")
|
|
1465
1557
|
if self.supervisor is None and self.swarm is None:
|
|
@@ -1586,6 +1678,7 @@ class ChatHistoryModel(BaseModel):
|
|
|
1586
1678
|
class AppModel(BaseModel):
|
|
1587
1679
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1588
1680
|
name: str
|
|
1681
|
+
service_principal: Optional[ServicePrincipalModel] = None
|
|
1589
1682
|
description: Optional[str] = None
|
|
1590
1683
|
log_level: Optional[LogLevel] = "WARNING"
|
|
1591
1684
|
registered_model: RegisteredModelModel
|
|
@@ -1613,16 +1706,50 @@ class AppModel(BaseModel):
|
|
|
1613
1706
|
chat_history: Optional[ChatHistoryModel] = None
|
|
1614
1707
|
code_paths: list[str] = Field(default_factory=list)
|
|
1615
1708
|
pip_requirements: list[str] = Field(default_factory=list)
|
|
1709
|
+
python_version: Optional[str] = Field(
|
|
1710
|
+
default="3.12",
|
|
1711
|
+
description="Python version for Model Serving deployment. Defaults to 3.12 "
|
|
1712
|
+
"which is supported by Databricks Model Serving. This allows deploying from "
|
|
1713
|
+
"environments with different Python versions (e.g., Databricks Apps with 3.11).",
|
|
1714
|
+
)
|
|
1715
|
+
|
|
1716
|
+
@model_validator(mode="after")
|
|
1717
|
+
def set_databricks_env_vars(self) -> Self:
|
|
1718
|
+
"""Set Databricks environment variables for Model Serving.
|
|
1719
|
+
|
|
1720
|
+
Sets DATABRICKS_HOST, DATABRICKS_CLIENT_ID, and DATABRICKS_CLIENT_SECRET.
|
|
1721
|
+
Values explicitly provided in environment_vars take precedence.
|
|
1722
|
+
"""
|
|
1723
|
+
from dao_ai.utils import get_default_databricks_host
|
|
1724
|
+
|
|
1725
|
+
# Set DATABRICKS_HOST if not already provided
|
|
1726
|
+
if "DATABRICKS_HOST" not in self.environment_vars:
|
|
1727
|
+
host: str | None = get_default_databricks_host()
|
|
1728
|
+
if host:
|
|
1729
|
+
self.environment_vars["DATABRICKS_HOST"] = host
|
|
1730
|
+
|
|
1731
|
+
# Set service principal credentials if provided
|
|
1732
|
+
if self.service_principal is not None:
|
|
1733
|
+
if "DATABRICKS_CLIENT_ID" not in self.environment_vars:
|
|
1734
|
+
self.environment_vars["DATABRICKS_CLIENT_ID"] = (
|
|
1735
|
+
self.service_principal.client_id
|
|
1736
|
+
)
|
|
1737
|
+
if "DATABRICKS_CLIENT_SECRET" not in self.environment_vars:
|
|
1738
|
+
self.environment_vars["DATABRICKS_CLIENT_SECRET"] = (
|
|
1739
|
+
self.service_principal.client_secret
|
|
1740
|
+
)
|
|
1741
|
+
return self
|
|
1616
1742
|
|
|
1617
1743
|
@model_validator(mode="after")
|
|
1618
|
-
def validate_agents_not_empty(self):
|
|
1744
|
+
def validate_agents_not_empty(self) -> Self:
|
|
1619
1745
|
if not self.agents:
|
|
1620
1746
|
raise ValueError("At least one agent must be specified")
|
|
1621
1747
|
return self
|
|
1622
1748
|
|
|
1623
1749
|
@model_validator(mode="after")
|
|
1624
|
-
def
|
|
1750
|
+
def resolve_environment_vars(self) -> Self:
|
|
1625
1751
|
for key, value in self.environment_vars.items():
|
|
1752
|
+
updated_value: str
|
|
1626
1753
|
if isinstance(value, SecretVariableModel):
|
|
1627
1754
|
updated_value = str(value)
|
|
1628
1755
|
else:
|
|
@@ -1632,7 +1759,7 @@ class AppModel(BaseModel):
|
|
|
1632
1759
|
return self
|
|
1633
1760
|
|
|
1634
1761
|
@model_validator(mode="after")
|
|
1635
|
-
def set_default_orchestration(self):
|
|
1762
|
+
def set_default_orchestration(self) -> Self:
|
|
1636
1763
|
if self.orchestration is None:
|
|
1637
1764
|
if len(self.agents) > 1:
|
|
1638
1765
|
default_agent: AgentModel = self.agents[0]
|
|
@@ -1652,14 +1779,14 @@ class AppModel(BaseModel):
|
|
|
1652
1779
|
return self
|
|
1653
1780
|
|
|
1654
1781
|
@model_validator(mode="after")
|
|
1655
|
-
def set_default_endpoint_name(self):
|
|
1782
|
+
def set_default_endpoint_name(self) -> Self:
|
|
1656
1783
|
if self.endpoint_name is None:
|
|
1657
1784
|
self.endpoint_name = self.name
|
|
1658
1785
|
return self
|
|
1659
1786
|
|
|
1660
1787
|
@model_validator(mode="after")
|
|
1661
|
-
def set_default_agent(self):
|
|
1662
|
-
default_agent_name = self.agents[0].name
|
|
1788
|
+
def set_default_agent(self) -> Self:
|
|
1789
|
+
default_agent_name: str = self.agents[0].name
|
|
1663
1790
|
|
|
1664
1791
|
if self.orchestration.swarm and not self.orchestration.swarm.default_agent:
|
|
1665
1792
|
self.orchestration.swarm.default_agent = default_agent_name
|
|
@@ -1667,7 +1794,7 @@ class AppModel(BaseModel):
|
|
|
1667
1794
|
return self
|
|
1668
1795
|
|
|
1669
1796
|
@model_validator(mode="after")
|
|
1670
|
-
def add_code_paths_to_sys_path(self):
|
|
1797
|
+
def add_code_paths_to_sys_path(self) -> Self:
|
|
1671
1798
|
for code_path in self.code_paths:
|
|
1672
1799
|
parent_path: str = str(Path(code_path).parent)
|
|
1673
1800
|
if parent_path not in sys.path:
|
|
@@ -1700,7 +1827,7 @@ class EvaluationDatasetExpectationsModel(BaseModel):
|
|
|
1700
1827
|
expected_facts: Optional[list[str]] = None
|
|
1701
1828
|
|
|
1702
1829
|
@model_validator(mode="after")
|
|
1703
|
-
def validate_mutually_exclusive(self):
|
|
1830
|
+
def validate_mutually_exclusive(self) -> Self:
|
|
1704
1831
|
if self.expected_response is not None and self.expected_facts is not None:
|
|
1705
1832
|
raise ValueError("Cannot specify both expected_response and expected_facts")
|
|
1706
1833
|
return self
|
|
@@ -1808,7 +1935,7 @@ class PromptOptimizationModel(BaseModel):
|
|
|
1808
1935
|
return optimized_prompt
|
|
1809
1936
|
|
|
1810
1937
|
@model_validator(mode="after")
|
|
1811
|
-
def set_defaults(self):
|
|
1938
|
+
def set_defaults(self) -> Self:
|
|
1812
1939
|
# If no prompt is specified, try to use the agent's prompt
|
|
1813
1940
|
if self.prompt is None:
|
|
1814
1941
|
if isinstance(self.agent.prompt, PromptModel):
|
|
@@ -1930,6 +2057,7 @@ class ResourcesModel(BaseModel):
|
|
|
1930
2057
|
class AppConfig(BaseModel):
|
|
1931
2058
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1932
2059
|
variables: dict[str, AnyVariable] = Field(default_factory=dict)
|
|
2060
|
+
service_principals: dict[str, ServicePrincipalModel] = Field(default_factory=dict)
|
|
1933
2061
|
schemas: dict[str, SchemaModel] = Field(default_factory=dict)
|
|
1934
2062
|
resources: Optional[ResourcesModel] = None
|
|
1935
2063
|
retrievers: dict[str, RetrieverModel] = Field(default_factory=dict)
|
dao_ai/prompts.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
1
|
from typing import Any, Callable, Optional, Sequence
|
|
2
2
|
|
|
3
|
-
from langchain.prompts import PromptTemplate
|
|
4
3
|
from langchain_core.messages import (
|
|
5
4
|
BaseMessage,
|
|
6
5
|
SystemMessage,
|
|
7
6
|
)
|
|
7
|
+
from langchain_core.prompts import PromptTemplate
|
|
8
8
|
from langchain_core.runnables import RunnableConfig
|
|
9
9
|
from loguru import logger
|
|
10
10
|
|