dao-ai 0.0.31__py3-none-any.whl → 0.0.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/config.py +162 -34
- dao_ai/prompts.py +1 -1
- dao_ai/providers/databricks.py +204 -146
- dao_ai/tools/core.py +1 -1
- dao_ai/tools/genie.py +26 -262
- dao_ai/tools/unity_catalog.py +31 -2
- dao_ai/tools/vector_search.py +4 -2
- dao_ai/utils.py +60 -7
- {dao_ai-0.0.31.dist-info → dao_ai-0.0.33.dist-info}/METADATA +15 -15
- {dao_ai-0.0.31.dist-info → dao_ai-0.0.33.dist-info}/RECORD +13 -13
- {dao_ai-0.0.31.dist-info → dao_ai-0.0.33.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.31.dist-info → dao_ai-0.0.33.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.31.dist-info → dao_ai-0.0.33.dist-info}/licenses/LICENSE +0 -0
dao_ai/providers/databricks.py
CHANGED
|
@@ -332,6 +332,23 @@ class DatabricksProvider(ServiceProvider):
|
|
|
332
332
|
|
|
333
333
|
logger.debug(f"input_example: {input_example}")
|
|
334
334
|
|
|
335
|
+
# Create conda environment with configured Python version
|
|
336
|
+
# This allows deploying from environments with different Python versions
|
|
337
|
+
# (e.g., Databricks Apps with Python 3.11 can deploy to Model Serving with 3.12)
|
|
338
|
+
target_python_version: str = config.app.python_version
|
|
339
|
+
logger.debug(f"target_python_version: {target_python_version}")
|
|
340
|
+
|
|
341
|
+
conda_env: dict[str, Any] = {
|
|
342
|
+
"name": "mlflow-env",
|
|
343
|
+
"channels": ["conda-forge"],
|
|
344
|
+
"dependencies": [
|
|
345
|
+
f"python={target_python_version}",
|
|
346
|
+
"pip",
|
|
347
|
+
{"pip": list(pip_requirements)},
|
|
348
|
+
],
|
|
349
|
+
}
|
|
350
|
+
logger.debug(f"conda_env: {conda_env}")
|
|
351
|
+
|
|
335
352
|
with mlflow.start_run(run_name=run_name):
|
|
336
353
|
mlflow.set_tag("type", "agent")
|
|
337
354
|
mlflow.set_tag("dao_ai", dao_ai_version())
|
|
@@ -340,7 +357,7 @@ class DatabricksProvider(ServiceProvider):
|
|
|
340
357
|
code_paths=code_paths,
|
|
341
358
|
model_config=config.model_dump(mode="json", by_alias=True),
|
|
342
359
|
name="agent",
|
|
343
|
-
|
|
360
|
+
conda_env=conda_env,
|
|
344
361
|
input_example=input_example,
|
|
345
362
|
# resources=all_resources,
|
|
346
363
|
auth_policy=auth_policy,
|
|
@@ -773,6 +790,72 @@ class DatabricksProvider(ServiceProvider):
|
|
|
773
790
|
logger.debug(f"Vector search index found: {found_endpoint_name}")
|
|
774
791
|
return found_endpoint_name
|
|
775
792
|
|
|
793
|
+
def _wait_for_database_available(
|
|
794
|
+
self,
|
|
795
|
+
workspace_client: WorkspaceClient,
|
|
796
|
+
instance_name: str,
|
|
797
|
+
max_wait_time: int = 600,
|
|
798
|
+
wait_interval: int = 10,
|
|
799
|
+
) -> None:
|
|
800
|
+
"""
|
|
801
|
+
Wait for a database instance to become AVAILABLE.
|
|
802
|
+
|
|
803
|
+
Args:
|
|
804
|
+
workspace_client: The Databricks workspace client
|
|
805
|
+
instance_name: Name of the database instance to wait for
|
|
806
|
+
max_wait_time: Maximum time to wait in seconds (default: 600 = 10 minutes)
|
|
807
|
+
wait_interval: Time between status checks in seconds (default: 10)
|
|
808
|
+
|
|
809
|
+
Raises:
|
|
810
|
+
TimeoutError: If the database doesn't become AVAILABLE within max_wait_time
|
|
811
|
+
RuntimeError: If the database enters a failed or deleted state
|
|
812
|
+
"""
|
|
813
|
+
import time
|
|
814
|
+
from typing import Any
|
|
815
|
+
|
|
816
|
+
logger.info(
|
|
817
|
+
f"Waiting for database instance {instance_name} to become AVAILABLE..."
|
|
818
|
+
)
|
|
819
|
+
elapsed: int = 0
|
|
820
|
+
|
|
821
|
+
while elapsed < max_wait_time:
|
|
822
|
+
try:
|
|
823
|
+
current_instance: Any = workspace_client.database.get_database_instance(
|
|
824
|
+
name=instance_name
|
|
825
|
+
)
|
|
826
|
+
current_state: str = current_instance.state
|
|
827
|
+
logger.debug(
|
|
828
|
+
f"Database instance {instance_name} state: {current_state}"
|
|
829
|
+
)
|
|
830
|
+
|
|
831
|
+
if current_state == "AVAILABLE":
|
|
832
|
+
logger.info(f"Database instance {instance_name} is now AVAILABLE")
|
|
833
|
+
return
|
|
834
|
+
elif current_state in ["STARTING", "UPDATING", "PROVISIONING"]:
|
|
835
|
+
logger.debug(
|
|
836
|
+
f"Database instance still in {current_state} state, waiting {wait_interval} seconds..."
|
|
837
|
+
)
|
|
838
|
+
time.sleep(wait_interval)
|
|
839
|
+
elapsed += wait_interval
|
|
840
|
+
elif current_state in ["STOPPED", "DELETING", "FAILED"]:
|
|
841
|
+
raise RuntimeError(
|
|
842
|
+
f"Database instance {instance_name} entered unexpected state: {current_state}"
|
|
843
|
+
)
|
|
844
|
+
else:
|
|
845
|
+
logger.warning(
|
|
846
|
+
f"Unknown database state: {current_state}, continuing to wait..."
|
|
847
|
+
)
|
|
848
|
+
time.sleep(wait_interval)
|
|
849
|
+
elapsed += wait_interval
|
|
850
|
+
except NotFound:
|
|
851
|
+
raise RuntimeError(
|
|
852
|
+
f"Database instance {instance_name} was deleted while waiting for it to become AVAILABLE"
|
|
853
|
+
)
|
|
854
|
+
|
|
855
|
+
raise TimeoutError(
|
|
856
|
+
f"Timed out waiting for database instance {instance_name} to become AVAILABLE after {max_wait_time} seconds"
|
|
857
|
+
)
|
|
858
|
+
|
|
776
859
|
def create_lakebase(self, database: DatabaseModel) -> None:
|
|
777
860
|
"""
|
|
778
861
|
Create a Lakebase database instance using the Databricks workspace client.
|
|
@@ -907,6 +990,12 @@ class DatabricksProvider(ServiceProvider):
|
|
|
907
990
|
f"Successfully created database instance: {database.instance_name}"
|
|
908
991
|
)
|
|
909
992
|
|
|
993
|
+
# Wait for the newly created database to become AVAILABLE
|
|
994
|
+
self._wait_for_database_available(
|
|
995
|
+
workspace_client, database.instance_name
|
|
996
|
+
)
|
|
997
|
+
return
|
|
998
|
+
|
|
910
999
|
except Exception as create_error:
|
|
911
1000
|
error_msg: str = str(create_error)
|
|
912
1001
|
|
|
@@ -918,6 +1007,10 @@ class DatabricksProvider(ServiceProvider):
|
|
|
918
1007
|
logger.info(
|
|
919
1008
|
f"Database instance {database.instance_name} was created concurrently by another process"
|
|
920
1009
|
)
|
|
1010
|
+
# Still need to wait for the database to become AVAILABLE
|
|
1011
|
+
self._wait_for_database_available(
|
|
1012
|
+
workspace_client, database.instance_name
|
|
1013
|
+
)
|
|
921
1014
|
return
|
|
922
1015
|
else:
|
|
923
1016
|
# Re-raise unexpected errors
|
|
@@ -1057,9 +1150,10 @@ class DatabricksProvider(ServiceProvider):
|
|
|
1057
1150
|
|
|
1058
1151
|
If an explicit version or alias is specified in the prompt_model, uses that directly.
|
|
1059
1152
|
Otherwise, tries to load prompts in this order:
|
|
1060
|
-
1. champion alias
|
|
1061
|
-
2. latest
|
|
1062
|
-
3.
|
|
1153
|
+
1. champion alias
|
|
1154
|
+
2. latest version (max version number from search_prompt_versions)
|
|
1155
|
+
3. default alias
|
|
1156
|
+
4. Register default_template if provided
|
|
1063
1157
|
|
|
1064
1158
|
Args:
|
|
1065
1159
|
prompt_model: The prompt model configuration
|
|
@@ -1070,9 +1164,11 @@ class DatabricksProvider(ServiceProvider):
|
|
|
1070
1164
|
Raises:
|
|
1071
1165
|
ValueError: If no prompt can be loaded from any source
|
|
1072
1166
|
"""
|
|
1167
|
+
|
|
1073
1168
|
prompt_name: str = prompt_model.full_name
|
|
1169
|
+
mlflow_client: MlflowClient = MlflowClient()
|
|
1074
1170
|
|
|
1075
|
-
# If explicit version or alias is specified, use it directly
|
|
1171
|
+
# If explicit version or alias is specified, use it directly
|
|
1076
1172
|
if prompt_model.version or prompt_model.alias:
|
|
1077
1173
|
try:
|
|
1078
1174
|
prompt_version: PromptVersion = prompt_model.as_prompt()
|
|
@@ -1086,73 +1182,48 @@ class DatabricksProvider(ServiceProvider):
|
|
|
1086
1182
|
f"Failed to load prompt '{prompt_name}' with explicit "
|
|
1087
1183
|
f"{'version ' + str(prompt_model.version) if prompt_model.version else 'alias ' + prompt_model.alias}: {e}"
|
|
1088
1184
|
)
|
|
1089
|
-
# Fall through to
|
|
1090
|
-
else:
|
|
1091
|
-
# No explicit version/alias specified - check if default_template needs syncing first
|
|
1092
|
-
logger.debug(
|
|
1093
|
-
f"No explicit version/alias specified for '{prompt_name}', "
|
|
1094
|
-
"checking if default_template needs syncing"
|
|
1095
|
-
)
|
|
1096
|
-
|
|
1097
|
-
# If we have a default_template, check if it differs from what's in the registry
|
|
1098
|
-
# This ensures we always sync config changes before returning any alias
|
|
1099
|
-
if prompt_model.default_template:
|
|
1100
|
-
try:
|
|
1101
|
-
default_uri: str = f"prompts:/{prompt_name}@default"
|
|
1102
|
-
default_version: PromptVersion = load_prompt(default_uri)
|
|
1103
|
-
|
|
1104
|
-
if (
|
|
1105
|
-
default_version.to_single_brace_format().strip()
|
|
1106
|
-
!= prompt_model.default_template.strip()
|
|
1107
|
-
):
|
|
1108
|
-
logger.info(
|
|
1109
|
-
f"Config default_template for '{prompt_name}' differs from registry, syncing..."
|
|
1110
|
-
)
|
|
1111
|
-
return self._sync_default_template_to_registry(
|
|
1112
|
-
prompt_name,
|
|
1113
|
-
prompt_model.default_template,
|
|
1114
|
-
prompt_model.description,
|
|
1115
|
-
)
|
|
1116
|
-
except Exception as e:
|
|
1117
|
-
logger.debug(f"Could not check default alias for sync: {e}")
|
|
1185
|
+
# Fall through to try other methods
|
|
1118
1186
|
|
|
1119
|
-
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
|
|
1187
|
+
# Try to load in priority order: champion → latest → default
|
|
1188
|
+
logger.debug(
|
|
1189
|
+
f"Trying fallback order for '{prompt_name}': champion → latest → default"
|
|
1190
|
+
)
|
|
1123
1191
|
|
|
1124
|
-
|
|
1125
|
-
|
|
1126
|
-
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
logger.debug(f"Champion alias not found for '{prompt_name}': {e}")
|
|
1192
|
+
# 1. Try champion alias
|
|
1193
|
+
try:
|
|
1194
|
+
prompt_version = load_prompt(f"prompts:/{prompt_name}@champion")
|
|
1195
|
+
logger.info(f"Loaded prompt '{prompt_name}' from champion alias")
|
|
1196
|
+
return prompt_version
|
|
1197
|
+
except Exception as e:
|
|
1198
|
+
logger.debug(f"Champion alias not found for '{prompt_name}': {e}")
|
|
1132
1199
|
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
|
|
1136
|
-
|
|
1137
|
-
|
|
1138
|
-
|
|
1139
|
-
|
|
1140
|
-
logger.
|
|
1200
|
+
# 2. Try to get latest version by finding the max version number
|
|
1201
|
+
try:
|
|
1202
|
+
versions = mlflow_client.search_prompt_versions(
|
|
1203
|
+
prompt_name, max_results=100
|
|
1204
|
+
)
|
|
1205
|
+
if versions:
|
|
1206
|
+
latest = max(versions, key=lambda v: int(v.version))
|
|
1207
|
+
logger.info(
|
|
1208
|
+
f"Loaded prompt '{prompt_name}' version {latest.version} (latest by max version)"
|
|
1209
|
+
)
|
|
1210
|
+
return latest
|
|
1211
|
+
except Exception as e:
|
|
1212
|
+
logger.debug(f"Failed to find latest version for '{prompt_name}': {e}")
|
|
1141
1213
|
|
|
1142
|
-
|
|
1143
|
-
|
|
1144
|
-
|
|
1145
|
-
|
|
1146
|
-
|
|
1147
|
-
|
|
1148
|
-
|
|
1149
|
-
logger.debug(f"Default alias not found for '{prompt_name}': {e}")
|
|
1214
|
+
# 3. Try default alias
|
|
1215
|
+
try:
|
|
1216
|
+
prompt_version = load_prompt(f"prompts:/{prompt_name}@default")
|
|
1217
|
+
logger.info(f"Loaded prompt '{prompt_name}' from default alias")
|
|
1218
|
+
return prompt_version
|
|
1219
|
+
except Exception as e:
|
|
1220
|
+
logger.debug(f"Default alias not found for '{prompt_name}': {e}")
|
|
1150
1221
|
|
|
1151
|
-
#
|
|
1222
|
+
# 4. Try to register default_template if provided
|
|
1152
1223
|
if prompt_model.default_template:
|
|
1153
1224
|
logger.info(
|
|
1154
|
-
f"
|
|
1155
|
-
"
|
|
1225
|
+
f"No existing prompt found for '{prompt_name}', "
|
|
1226
|
+
"attempting to register default_template"
|
|
1156
1227
|
)
|
|
1157
1228
|
return self._sync_default_template_to_registry(
|
|
1158
1229
|
prompt_name, prompt_model.default_template, prompt_model.description
|
|
@@ -1160,72 +1231,58 @@ class DatabricksProvider(ServiceProvider):
|
|
|
1160
1231
|
|
|
1161
1232
|
raise ValueError(
|
|
1162
1233
|
f"Prompt '{prompt_name}' not found in registry "
|
|
1163
|
-
"(tried champion, latest, default
|
|
1234
|
+
"(tried champion alias, latest version, default alias) "
|
|
1235
|
+
"and no default_template provided"
|
|
1164
1236
|
)
|
|
1165
1237
|
|
|
1166
1238
|
def _sync_default_template_to_registry(
|
|
1167
1239
|
self, prompt_name: str, default_template: str, description: str | None = None
|
|
1168
1240
|
) -> PromptVersion:
|
|
1169
|
-
"""
|
|
1170
|
-
|
|
1241
|
+
"""Get the best available prompt version, or register default_template if possible.
|
|
1242
|
+
|
|
1243
|
+
Tries to load prompts in order: champion → latest (max version) → default.
|
|
1244
|
+
If none found and we have write permissions, registers the default_template.
|
|
1245
|
+
If registration fails (e.g., in Model Serving), logs the error and raises.
|
|
1246
|
+
"""
|
|
1247
|
+
mlflow_client: MlflowClient = MlflowClient()
|
|
1171
1248
|
|
|
1249
|
+
# Try to find an existing prompt version in priority order
|
|
1250
|
+
# 1. Try champion alias
|
|
1172
1251
|
try:
|
|
1173
|
-
|
|
1174
|
-
|
|
1175
|
-
|
|
1176
|
-
|
|
1177
|
-
|
|
1178
|
-
)
|
|
1179
|
-
if (
|
|
1180
|
-
existing.to_single_brace_format().strip()
|
|
1181
|
-
== default_template.strip()
|
|
1182
|
-
):
|
|
1183
|
-
logger.debug(f"Prompt '{prompt_name}' is already up-to-date")
|
|
1252
|
+
champion = mlflow.genai.load_prompt(f"prompts:/{prompt_name}@champion")
|
|
1253
|
+
logger.info(f"Loaded prompt '{prompt_name}' from champion alias")
|
|
1254
|
+
return champion
|
|
1255
|
+
except Exception as e:
|
|
1256
|
+
logger.debug(f"Champion alias not found for '{prompt_name}': {e}")
|
|
1184
1257
|
|
|
1185
|
-
|
|
1186
|
-
|
|
1187
|
-
|
|
1188
|
-
|
|
1189
|
-
|
|
1190
|
-
|
|
1191
|
-
|
|
1192
|
-
|
|
1193
|
-
|
|
1194
|
-
|
|
1195
|
-
|
|
1196
|
-
|
|
1197
|
-
|
|
1198
|
-
mlflow.genai.set_prompt_alias(
|
|
1199
|
-
name=prompt_name,
|
|
1200
|
-
alias="latest",
|
|
1201
|
-
version=existing.version,
|
|
1202
|
-
)
|
|
1258
|
+
# 2. Try to get the latest version by finding the max version number
|
|
1259
|
+
try:
|
|
1260
|
+
versions = mlflow_client.search_prompt_versions(
|
|
1261
|
+
prompt_name, max_results=100
|
|
1262
|
+
)
|
|
1263
|
+
if versions:
|
|
1264
|
+
latest = max(versions, key=lambda v: int(v.version))
|
|
1265
|
+
logger.info(
|
|
1266
|
+
f"Loaded prompt '{prompt_name}' version {latest.version} (latest by max version)"
|
|
1267
|
+
)
|
|
1268
|
+
return latest
|
|
1269
|
+
except Exception as e:
|
|
1270
|
+
logger.debug(f"Failed to search versions for '{prompt_name}': {e}")
|
|
1203
1271
|
|
|
1204
|
-
|
|
1205
|
-
|
|
1206
|
-
|
|
1207
|
-
|
|
1208
|
-
|
|
1209
|
-
|
|
1210
|
-
|
|
1211
|
-
)
|
|
1212
|
-
except Exception:
|
|
1213
|
-
logger.info(
|
|
1214
|
-
f"Setting 'champion' alias for existing prompt '{prompt_name}' v{existing.version}"
|
|
1215
|
-
)
|
|
1216
|
-
mlflow.genai.set_prompt_alias(
|
|
1217
|
-
name=prompt_name,
|
|
1218
|
-
alias="champion",
|
|
1219
|
-
version=existing.version,
|
|
1220
|
-
)
|
|
1272
|
+
# 3. Try default alias
|
|
1273
|
+
try:
|
|
1274
|
+
default = mlflow.genai.load_prompt(f"prompts:/{prompt_name}@default")
|
|
1275
|
+
logger.info(f"Loaded prompt '{prompt_name}' from default alias")
|
|
1276
|
+
return default
|
|
1277
|
+
except Exception as e:
|
|
1278
|
+
logger.debug(f"Default alias not found for '{prompt_name}': {e}")
|
|
1221
1279
|
|
|
1222
|
-
|
|
1223
|
-
|
|
1224
|
-
|
|
1225
|
-
|
|
1226
|
-
)
|
|
1280
|
+
# No existing prompt found - try to register if we have a template
|
|
1281
|
+
logger.info(
|
|
1282
|
+
f"No existing prompt found for '{prompt_name}', attempting to register default_template"
|
|
1283
|
+
)
|
|
1227
1284
|
|
|
1228
|
-
|
|
1285
|
+
try:
|
|
1229
1286
|
commit_message = description or "Auto-synced from default_template"
|
|
1230
1287
|
prompt_version = mlflow.genai.register_prompt(
|
|
1231
1288
|
name=prompt_name,
|
|
@@ -1234,35 +1291,36 @@ class DatabricksProvider(ServiceProvider):
|
|
|
1234
1291
|
tags={"dao_ai": dao_ai_version()},
|
|
1235
1292
|
)
|
|
1236
1293
|
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
|
|
1240
|
-
|
|
1241
|
-
|
|
1242
|
-
|
|
1243
|
-
|
|
1244
|
-
|
|
1245
|
-
|
|
1246
|
-
|
|
1247
|
-
|
|
1248
|
-
|
|
1249
|
-
|
|
1250
|
-
|
|
1251
|
-
|
|
1252
|
-
|
|
1253
|
-
version=prompt_version.version,
|
|
1254
|
-
)
|
|
1294
|
+
# Try to set aliases (may fail in restricted environments)
|
|
1295
|
+
try:
|
|
1296
|
+
mlflow.genai.set_prompt_alias(
|
|
1297
|
+
name=prompt_name, alias="default", version=prompt_version.version
|
|
1298
|
+
)
|
|
1299
|
+
mlflow.genai.set_prompt_alias(
|
|
1300
|
+
name=prompt_name, alias="champion", version=prompt_version.version
|
|
1301
|
+
)
|
|
1302
|
+
logger.info(
|
|
1303
|
+
f"Registered prompt '{prompt_name}' v{prompt_version.version} with aliases"
|
|
1304
|
+
)
|
|
1305
|
+
except Exception as alias_error:
|
|
1306
|
+
logger.warning(
|
|
1307
|
+
f"Registered prompt '{prompt_name}' v{prompt_version.version} "
|
|
1308
|
+
f"but failed to set aliases: {alias_error}"
|
|
1309
|
+
)
|
|
1255
1310
|
|
|
1256
|
-
logger.info(
|
|
1257
|
-
f"Synced prompt '{prompt_name}' v{prompt_version.version} to registry with 'default', 'latest', and 'champion' aliases"
|
|
1258
|
-
)
|
|
1259
1311
|
return prompt_version
|
|
1260
1312
|
|
|
1261
|
-
except Exception as
|
|
1262
|
-
logger.error(
|
|
1263
|
-
|
|
1264
|
-
f"
|
|
1265
|
-
)
|
|
1313
|
+
except Exception as reg_error:
|
|
1314
|
+
logger.error(
|
|
1315
|
+
f"Failed to register prompt '{prompt_name}': {reg_error}. "
|
|
1316
|
+
f"Please register the prompt from a notebook with write permissions before deployment."
|
|
1317
|
+
)
|
|
1318
|
+
return PromptVersion(
|
|
1319
|
+
name=prompt_name,
|
|
1320
|
+
version=1,
|
|
1321
|
+
template=default_template,
|
|
1322
|
+
tags={"dao_ai": dao_ai_version()},
|
|
1323
|
+
)
|
|
1266
1324
|
|
|
1267
1325
|
def optimize_prompt(self, optimization: PromptOptimizationModel) -> PromptModel:
|
|
1268
1326
|
"""
|
dao_ai/tools/core.py
CHANGED
|
@@ -35,7 +35,7 @@ def create_tools(tool_models: Sequence[ToolModel]) -> Sequence[RunnableLike]:
|
|
|
35
35
|
if name in tools:
|
|
36
36
|
logger.warning(f"Tools already registered for: {name}, skipping creation.")
|
|
37
37
|
continue
|
|
38
|
-
registered_tools: Sequence[RunnableLike] = tool_registry.get(name)
|
|
38
|
+
registered_tools: Sequence[RunnableLike] | None = tool_registry.get(name)
|
|
39
39
|
if registered_tools is None:
|
|
40
40
|
logger.debug(f"Creating tools for: {name}...")
|
|
41
41
|
function: AnyTool = tool_config.function
|