dao-ai 0.0.31__py3-none-any.whl → 0.0.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/config.py +47 -1
- dao_ai/prompts.py +1 -1
- dao_ai/providers/databricks.py +204 -146
- dao_ai/tools/genie.py +26 -262
- dao_ai/tools/vector_search.py +4 -2
- dao_ai/utils.py +34 -7
- {dao_ai-0.0.31.dist-info → dao_ai-0.0.32.dist-info}/METADATA +9 -9
- {dao_ai-0.0.31.dist-info → dao_ai-0.0.32.dist-info}/RECORD +11 -11
- {dao_ai-0.0.31.dist-info → dao_ai-0.0.32.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.31.dist-info → dao_ai-0.0.32.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.31.dist-info → dao_ai-0.0.32.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py
CHANGED
@@ -725,6 +725,46 @@ class WarehouseModel(BaseModel, IsDatabricksResource):
 
 
 class DatabaseModel(BaseModel, IsDatabricksResource):
+    """
+    Configuration for a Databricks Lakebase (PostgreSQL) database instance.
+
+    Authentication Model:
+    --------------------
+    This model uses TWO separate authentication contexts:
+
+    1. **Workspace API Authentication** (inherited from IsDatabricksResource):
+       - Uses ambient/default authentication (environment variables, notebook context, app service principal)
+       - Used for: discovering database instance, getting host DNS, checking instance status
+       - Controlled by: DATABRICKS_HOST, DATABRICKS_TOKEN env vars, or SDK default config
+
+    2. **Database Connection Authentication** (configured via client_id/client_secret OR user):
+       - Used for: connecting to the PostgreSQL database as a specific identity
+       - OAuth M2M: Set client_id, client_secret, workspace_host to connect as a service principal
+       - User Auth: Set user (and optionally password) to connect as a user identity
+
+    Example OAuth M2M Configuration:
+    ```yaml
+    databases:
+      my_lakebase:
+        name: my-database
+        client_id:
+          env: SERVICE_PRINCIPAL_CLIENT_ID
+        client_secret:
+          scope: my-scope
+          secret: sp-client-secret
+        workspace_host:
+          env: DATABRICKS_HOST
+    ```
+
+    Example User Configuration:
+    ```yaml
+    databases:
+      my_lakebase:
+        name: my-database
+        user: my-user@databricks.com
+    ```
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     instance_name: Optional[str] = None

@@ -883,7 +923,7 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
     def create(self, w: WorkspaceClient | None = None) -> None:
         from dao_ai.providers.databricks import DatabricksProvider
 
-        provider: DatabricksProvider = DatabricksProvider()
+        provider: DatabricksProvider = DatabricksProvider(w=w)
         provider.create_lakebase(self)
         provider.create_lakebase_instance_role(self)
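
The `w=w` pass-through above lets callers reuse an existing workspace client (and its credentials) instead of forcing the provider to construct a default one. A minimal sketch of what that enables, assuming `DatabaseModel` is importable from `dao_ai.config` and that `name` is the only field required here (the instance name is illustrative):

```python
from databricks.sdk import WorkspaceClient

from dao_ai.config import DatabaseModel  # assumed public export

# Workspace API context: ambient/default auth (env vars, notebook context,
# or an app service principal), per the docstring above.
w = WorkspaceClient()

database = DatabaseModel(name="my-database")  # illustrative instance name
database.create(w=w)  # the provider now reuses this client's credentials
```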
@@ -1613,6 +1653,12 @@ class AppModel(BaseModel):
     chat_history: Optional[ChatHistoryModel] = None
     code_paths: list[str] = Field(default_factory=list)
     pip_requirements: list[str] = Field(default_factory=list)
+    python_version: Optional[str] = Field(
+        default="3.12",
+        description="Python version for Model Serving deployment. Defaults to 3.12 "
+        "which is supported by Databricks Model Serving. This allows deploying from "
+        "environments with different Python versions (e.g., Databricks Apps with 3.11).",
+    )
 
     @model_validator(mode="after")
     def validate_agents_not_empty(self):
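
Because `AppModel` is a pydantic model, the new default can be inspected without building a full validated instance; a quick sketch, assuming `AppModel` is exported from `dao_ai.config`:

```python
from dao_ai.config import AppModel  # assumed public export

# Inspect the field definition rather than instantiating the model, since
# AppModel's validators require a non-empty agents configuration.
field = AppModel.model_fields["python_version"]
print(field.default)  # "3.12"
```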
dao_ai/prompts.py
CHANGED
@@ -1,10 +1,10 @@
 from typing import Any, Callable, Optional, Sequence
 
-from langchain.prompts import PromptTemplate
 from langchain_core.messages import (
     BaseMessage,
     SystemMessage,
 )
+from langchain_core.prompts import PromptTemplate
 from langchain_core.runnables import RunnableConfig
 from loguru import logger
 
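
The import swap tracks `PromptTemplate`'s canonical home in `langchain_core`; `langchain.prompts` re-exported it in earlier releases but is not a stable location under langchain 1.x. Usage is unchanged:

```python
from langchain_core.prompts import PromptTemplate

template = PromptTemplate.from_template("Answer the question about {topic}.")
print(template.format(topic="retail sales"))
```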
dao_ai/providers/databricks.py
CHANGED
@@ -332,6 +332,23 @@ class DatabricksProvider(ServiceProvider):
 
         logger.debug(f"input_example: {input_example}")
 
+        # Create conda environment with configured Python version
+        # This allows deploying from environments with different Python versions
+        # (e.g., Databricks Apps with Python 3.11 can deploy to Model Serving with 3.12)
+        target_python_version: str = config.app.python_version
+        logger.debug(f"target_python_version: {target_python_version}")
+
+        conda_env: dict[str, Any] = {
+            "name": "mlflow-env",
+            "channels": ["conda-forge"],
+            "dependencies": [
+                f"python={target_python_version}",
+                "pip",
+                {"pip": list(pip_requirements)},
+            ],
+        }
+        logger.debug(f"conda_env: {conda_env}")
+
         with mlflow.start_run(run_name=run_name):
             mlflow.set_tag("type", "agent")
             mlflow.set_tag("dao_ai", dao_ai_version())

@@ -340,7 +357,7 @@ class DatabricksProvider(ServiceProvider):
                 code_paths=code_paths,
                 model_config=config.model_dump(mode="json", by_alias=True),
                 name="agent",
-
+                conda_env=conda_env,
                 input_example=input_example,
                 # resources=all_resources,
                 auth_policy=auth_policy,
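
The `conda_env` built above is what decouples the serving container's interpreter from the deploying environment's. A self-contained sketch of the same pattern with a stand-in model (the requirement pins and model class are illustrative, not dao-ai's actual logging call):

```python
from typing import Any

import mlflow
from mlflow.pyfunc import PythonModel


class EchoModel(PythonModel):
    """Stand-in model; dao-ai logs its agent code instead."""

    def predict(self, context: Any, model_input: Any) -> Any:
        return model_input


pip_requirements = ["mlflow>=3.7.0", "langgraph>=1.0.4"]  # illustrative pins
conda_env: dict[str, Any] = {
    "name": "mlflow-env",
    "channels": ["conda-forge"],
    "dependencies": [
        "python=3.12",  # serving interpreter, independent of the local one
        "pip",
        {"pip": pip_requirements},
    ],
}

with mlflow.start_run():
    # conda_env replaces the default environment MLflow would infer from the
    # host Python, so a 3.11 deployer can target a 3.12 serving image.
    mlflow.pyfunc.log_model(
        name="agent",
        python_model=EchoModel(),
        conda_env=conda_env,
    )
```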
@@ -773,6 +790,72 @@ class DatabricksProvider(ServiceProvider):
         logger.debug(f"Vector search index found: {found_endpoint_name}")
         return found_endpoint_name
 
+    def _wait_for_database_available(
+        self,
+        workspace_client: WorkspaceClient,
+        instance_name: str,
+        max_wait_time: int = 600,
+        wait_interval: int = 10,
+    ) -> None:
+        """
+        Wait for a database instance to become AVAILABLE.
+
+        Args:
+            workspace_client: The Databricks workspace client
+            instance_name: Name of the database instance to wait for
+            max_wait_time: Maximum time to wait in seconds (default: 600 = 10 minutes)
+            wait_interval: Time between status checks in seconds (default: 10)
+
+        Raises:
+            TimeoutError: If the database doesn't become AVAILABLE within max_wait_time
+            RuntimeError: If the database enters a failed or deleted state
+        """
+        import time
+        from typing import Any
+
+        logger.info(
+            f"Waiting for database instance {instance_name} to become AVAILABLE..."
+        )
+        elapsed: int = 0
+
+        while elapsed < max_wait_time:
+            try:
+                current_instance: Any = workspace_client.database.get_database_instance(
+                    name=instance_name
+                )
+                current_state: str = current_instance.state
+                logger.debug(
+                    f"Database instance {instance_name} state: {current_state}"
+                )
+
+                if current_state == "AVAILABLE":
+                    logger.info(f"Database instance {instance_name} is now AVAILABLE")
+                    return
+                elif current_state in ["STARTING", "UPDATING", "PROVISIONING"]:
+                    logger.debug(
+                        f"Database instance still in {current_state} state, waiting {wait_interval} seconds..."
+                    )
+                    time.sleep(wait_interval)
+                    elapsed += wait_interval
+                elif current_state in ["STOPPED", "DELETING", "FAILED"]:
+                    raise RuntimeError(
+                        f"Database instance {instance_name} entered unexpected state: {current_state}"
+                    )
+                else:
+                    logger.warning(
+                        f"Unknown database state: {current_state}, continuing to wait..."
+                    )
+                    time.sleep(wait_interval)
+                    elapsed += wait_interval
+            except NotFound:
+                raise RuntimeError(
+                    f"Database instance {instance_name} was deleted while waiting for it to become AVAILABLE"
+                )
+
+        raise TimeoutError(
+            f"Timed out waiting for database instance {instance_name} to become AVAILABLE after {max_wait_time} seconds"
+        )
+
     def create_lakebase(self, database: DatabaseModel) -> None:
         """
         Create a Lakebase database instance using the Databricks workspace client.

@@ -907,6 +990,12 @@ class DatabricksProvider(ServiceProvider):
                     f"Successfully created database instance: {database.instance_name}"
                 )
 
+                # Wait for the newly created database to become AVAILABLE
+                self._wait_for_database_available(
+                    workspace_client, database.instance_name
+                )
+                return
+
             except Exception as create_error:
                 error_msg: str = str(create_error)
 
@@ -918,6 +1007,10 @@ class DatabricksProvider(ServiceProvider):
                 logger.info(
                     f"Database instance {database.instance_name} was created concurrently by another process"
                 )
+                # Still need to wait for the database to become AVAILABLE
+                self._wait_for_database_available(
+                    workspace_client, database.instance_name
+                )
                 return
             else:
                 # Re-raise unexpected errors

@@ -1057,9 +1150,10 @@ class DatabricksProvider(ServiceProvider):
 
         If an explicit version or alias is specified in the prompt_model, uses that directly.
         Otherwise, tries to load prompts in this order:
-        1. champion alias
-        2. latest
-        3.
+        1. champion alias
+        2. latest version (max version number from search_prompt_versions)
+        3. default alias
+        4. Register default_template if provided
 
         Args:
             prompt_model: The prompt model configuration

@@ -1070,9 +1164,11 @@ class DatabricksProvider(ServiceProvider):
         Raises:
             ValueError: If no prompt can be loaded from any source
         """
+
         prompt_name: str = prompt_model.full_name
+        mlflow_client: MlflowClient = MlflowClient()
 
-        # If explicit version or alias is specified, use it directly
+        # If explicit version or alias is specified, use it directly
         if prompt_model.version or prompt_model.alias:
             try:
                 prompt_version: PromptVersion = prompt_model.as_prompt()

@@ -1086,73 +1182,48 @@ class DatabricksProvider(ServiceProvider):
                     f"Failed to load prompt '{prompt_name}' with explicit "
                     f"{'version ' + str(prompt_model.version) if prompt_model.version else 'alias ' + prompt_model.alias}: {e}"
                 )
-                # Fall through to
-        else:
-            # No explicit version/alias specified - check if default_template needs syncing first
-            logger.debug(
-                f"No explicit version/alias specified for '{prompt_name}', "
-                "checking if default_template needs syncing"
-            )
-
-            # If we have a default_template, check if it differs from what's in the registry
-            # This ensures we always sync config changes before returning any alias
-            if prompt_model.default_template:
-                try:
-                    default_uri: str = f"prompts:/{prompt_name}@default"
-                    default_version: PromptVersion = load_prompt(default_uri)
-
-                    if (
-                        default_version.to_single_brace_format().strip()
-                        != prompt_model.default_template.strip()
-                    ):
-                        logger.info(
-                            f"Config default_template for '{prompt_name}' differs from registry, syncing..."
-                        )
-                        return self._sync_default_template_to_registry(
-                            prompt_name,
-                            prompt_model.default_template,
-                            prompt_model.description,
-                        )
-                except Exception as e:
-                    logger.debug(f"Could not check default alias for sync: {e}")
+                # Fall through to try other methods
 
-
-
-
-
+        # Try to load in priority order: champion → latest → default
+        logger.debug(
+            f"Trying fallback order for '{prompt_name}': champion → latest → default"
+        )
 
-
-
-
-
-
-
-                logger.debug(f"Champion alias not found for '{prompt_name}': {e}")
+        # 1. Try champion alias
+        try:
+            prompt_version = load_prompt(f"prompts:/{prompt_name}@champion")
+            logger.info(f"Loaded prompt '{prompt_name}' from champion alias")
+            return prompt_version
+        except Exception as e:
+            logger.debug(f"Champion alias not found for '{prompt_name}': {e}")
 
-
-
-
-
-
-
-                logger.
+        # 2. Try to get latest version by finding the max version number
+        try:
+            versions = mlflow_client.search_prompt_versions(
+                prompt_name, max_results=100
+            )
+            if versions:
+                latest = max(versions, key=lambda v: int(v.version))
+                logger.info(
+                    f"Loaded prompt '{prompt_name}' version {latest.version} (latest by max version)"
+                )
+                return latest
+        except Exception as e:
+            logger.debug(f"Failed to find latest version for '{prompt_name}': {e}")
 
-
-
-
-
-
-
-                logger.debug(f"Default alias not found for '{prompt_name}': {e}")
+        # 3. Try default alias
+        try:
+            prompt_version = load_prompt(f"prompts:/{prompt_name}@default")
+            logger.info(f"Loaded prompt '{prompt_name}' from default alias")
+            return prompt_version
+        except Exception as e:
+            logger.debug(f"Default alias not found for '{prompt_name}': {e}")
 
-        #
+        # 4. Try to register default_template if provided
         if prompt_model.default_template:
             logger.info(
-                f"
-                "
+                f"No existing prompt found for '{prompt_name}', "
+                "attempting to register default_template"
             )
             return self._sync_default_template_to_registry(
                 prompt_name, prompt_model.default_template, prompt_model.description

@@ -1160,72 +1231,58 @@ class DatabricksProvider(ServiceProvider):
 
         raise ValueError(
             f"Prompt '{prompt_name}' not found in registry "
-            "(tried champion, latest, default
+            "(tried champion alias, latest version, default alias) "
+            "and no default_template provided"
         )
 
     def _sync_default_template_to_registry(
         self, prompt_name: str, default_template: str, description: str | None = None
    ) -> PromptVersion:
-        """
-
+        """Get the best available prompt version, or register default_template if possible.
+
+        Tries to load prompts in order: champion → latest (max version) → default.
+        If none found and we have write permissions, registers the default_template.
+        If registration fails (e.g., in Model Serving), logs the error and raises.
+        """
+        mlflow_client: MlflowClient = MlflowClient()
 
+        # Try to find an existing prompt version in priority order
+        # 1. Try champion alias
         try:
-
-
-
-
-
-            )
-            if (
-                existing.to_single_brace_format().strip()
-                == default_template.strip()
-            ):
-                logger.debug(f"Prompt '{prompt_name}' is already up-to-date")
+            champion = mlflow.genai.load_prompt(f"prompts:/{prompt_name}@champion")
+            logger.info(f"Loaded prompt '{prompt_name}' from champion alias")
+            return champion
+        except Exception as e:
+            logger.debug(f"Champion alias not found for '{prompt_name}': {e}")
 
-
-
-
-
-
-
-
-
-
-
-
-
-            mlflow.genai.set_prompt_alias(
-                name=prompt_name,
-                alias="latest",
-                version=existing.version,
-            )
+        # 2. Try to get the latest version by finding the max version number
+        try:
+            versions = mlflow_client.search_prompt_versions(
+                prompt_name, max_results=100
+            )
+            if versions:
+                latest = max(versions, key=lambda v: int(v.version))
+                logger.info(
+                    f"Loaded prompt '{prompt_name}' version {latest.version} (latest by max version)"
+                )
+                return latest
+        except Exception as e:
+            logger.debug(f"Failed to search versions for '{prompt_name}': {e}")
 
-
-
-
-
-
-
-
-            )
-        except Exception:
-            logger.info(
-                f"Setting 'champion' alias for existing prompt '{prompt_name}' v{existing.version}"
-            )
-            mlflow.genai.set_prompt_alias(
-                name=prompt_name,
-                alias="champion",
-                version=existing.version,
-            )
+        # 3. Try default alias
+        try:
+            default = mlflow.genai.load_prompt(f"prompts:/{prompt_name}@default")
+            logger.info(f"Loaded prompt '{prompt_name}' from default alias")
+            return default
+        except Exception as e:
+            logger.debug(f"Default alias not found for '{prompt_name}': {e}")
 
-
-
-
-
-            )
+        # No existing prompt found - try to register if we have a template
+        logger.info(
+            f"No existing prompt found for '{prompt_name}', attempting to register default_template"
+        )
 
-
+        try:
            commit_message = description or "Auto-synced from default_template"
            prompt_version = mlflow.genai.register_prompt(
                name=prompt_name,

@@ -1234,35 +1291,36 @@ class DatabricksProvider(ServiceProvider):
                 tags={"dao_ai": dao_ai_version()},
             )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                version=prompt_version.version,
-            )
+            # Try to set aliases (may fail in restricted environments)
+            try:
+                mlflow.genai.set_prompt_alias(
+                    name=prompt_name, alias="default", version=prompt_version.version
+                )
+                mlflow.genai.set_prompt_alias(
+                    name=prompt_name, alias="champion", version=prompt_version.version
+                )
+                logger.info(
+                    f"Registered prompt '{prompt_name}' v{prompt_version.version} with aliases"
+                )
+            except Exception as alias_error:
+                logger.warning(
+                    f"Registered prompt '{prompt_name}' v{prompt_version.version} "
+                    f"but failed to set aliases: {alias_error}"
+                )
 
-            logger.info(
-                f"Synced prompt '{prompt_name}' v{prompt_version.version} to registry with 'default', 'latest', and 'champion' aliases"
-            )
             return prompt_version
 
-        except Exception as
-            logger.error(
-
-                f"
-            )
+        except Exception as reg_error:
+            logger.error(
+                f"Failed to register prompt '{prompt_name}': {reg_error}. "
+                f"Please register the prompt from a notebook with write permissions before deployment."
+            )
+            return PromptVersion(
+                name=prompt_name,
+                version=1,
+                template=default_template,
+                tags={"dao_ai": dao_ai_version()},
+            )
 
     def optimize_prompt(self, optimization: PromptOptimizationModel) -> PromptModel:
         """
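
Pulled out of the provider, the fallback chain above reduces to a few independent attempts; a condensed sketch using the same MLflow calls (the prompt name is illustrative):

```python
import mlflow
from mlflow import MlflowClient

prompt_name = "main.default.agent_prompt"  # illustrative Unity Catalog name
client = MlflowClient()


def resolve_prompt():
    # 1. champion alias
    try:
        return mlflow.genai.load_prompt(f"prompts:/{prompt_name}@champion")
    except Exception:
        pass
    # 2. latest version by max version number
    try:
        versions = client.search_prompt_versions(prompt_name, max_results=100)
        if versions:
            return max(versions, key=lambda v: int(v.version))
    except Exception:
        pass
    # 3. default alias
    try:
        return mlflow.genai.load_prompt(f"prompts:/{prompt_name}@default")
    except Exception:
        pass
    # 4. would register default_template here, as in the hunk above
    raise ValueError(f"Prompt '{prompt_name}' not found in registry")
```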
dao_ai/tools/genie.py
CHANGED
@@ -1,15 +1,10 @@
-import bisect
 import json
 import os
-import time
-from dataclasses import asdict, dataclass
-from datetime import datetime
 from textwrap import dedent
-from typing import Annotated, Any, Callable
+from typing import Annotated, Any, Callable
 
-import mlflow
 import pandas as pd
-from
+from databricks_ai_bridge.genie import Genie, GenieResponse
 from langchain_core.messages import ToolMessage
 from langchain_core.tools import InjectedToolCallId, tool
 from langgraph.prebuilt import InjectedState

@@ -19,28 +14,6 @@ from pydantic import BaseModel, Field
 
 from dao_ai.config import AnyVariable, CompositeVariableModel, GenieRoomModel, value_of
 
-MAX_TOKENS_OF_DATA: int = 20000
-MAX_ITERATIONS: int = 50
-DEFAULT_POLLING_INTERVAL_SECS: int = 2
-
-
-def _count_tokens(text):
-    import tiktoken
-
-    encoding = tiktoken.encoding_for_model("gpt-4o")
-    return len(encoding.encode(text))
-
-
-@dataclass
-class GenieResponse:
-    conversation_id: str
-    result: Union[str, pd.DataFrame]
-    query: Optional[str] = ""
-    description: Optional[str] = ""
-
-    def to_json(self):
-        return json.dumps(asdict(self))
-
 
 class GenieToolInput(BaseModel):
     """Input schema for the Genie tool."""

@@ -50,235 +23,29 @@ class GenieToolInput(BaseModel):
     )
 
 
-def
-
-
-
-
-
-    return query_result.strip()
-
-    def is_too_big(n):
-        return _count_tokens(dataframe.iloc[:n].to_markdown()) > MAX_TOKENS_OF_DATA
-
-    # Use bisect_left to find the cutoff point of rows within the max token data limit in a O(log n) complexity
-    # Passing True, as this is the target value we are looking for when _is_too_big returns
-    cutoff = bisect.bisect_left(range(len(dataframe) + 1), True, key=is_too_big)
-
-    # Slice to the found limit
-    truncated_df = dataframe.iloc[:cutoff]
-
-    # Edge case: Cannot return any rows because of tokens so return an empty string
-    if len(truncated_df) == 0:
-        return ""
-
-    truncated_result = truncated_df.to_markdown()
-
-    # Double-check edge case if we overshot by one
-    if _count_tokens(truncated_result) > MAX_TOKENS_OF_DATA:
-        truncated_result = truncated_df.iloc[:-1].to_markdown()
-    return truncated_result
-
-
-@mlflow.trace(span_type="PARSER")
-def _parse_query_result(resp, truncate_results) -> Union[str, pd.DataFrame]:
-    output = resp["result"]
-    if not output:
-        return "EMPTY"
-
-    columns = resp["manifest"]["schema"]["columns"]
-    header = [str(col["name"]) for col in columns]
-    rows = []
-
-    for item in output["data_array"]:
-        row = []
-        for column, value in zip(columns, item):
-            type_name = column["type_name"]
-            if value is None:
-                row.append(None)
-                continue
-
-            if type_name in ["INT", "LONG", "SHORT", "BYTE"]:
-                row.append(int(value))
-            elif type_name in ["FLOAT", "DOUBLE", "DECIMAL"]:
-                row.append(float(value))
-            elif type_name == "BOOLEAN":
-                row.append(value.lower() == "true")
-            elif type_name == "DATE" or type_name == "TIMESTAMP":
-                row.append(datetime.strptime(value[:10], "%Y-%m-%d").date())
-            elif type_name == "BINARY":
-                row.append(bytes(value, "utf-8"))
-            else:
-                row.append(value)
-
-        rows.append(row)
+def _response_to_json(response: GenieResponse) -> str:
+    """Convert GenieResponse to JSON string, handling DataFrame results."""
+    # Convert result to string if it's a DataFrame
+    result: str | pd.DataFrame = response.result
+    if isinstance(result, pd.DataFrame):
+        result = result.to_markdown()
 
-
-
-
-
-
-
-
-    return query_result.strip()
-
-
-class Genie:
-    def __init__(
-        self,
-        space_id,
-        client: WorkspaceClient | None = None,
-        truncate_results: bool = False,
-        polling_interval: int = DEFAULT_POLLING_INTERVAL_SECS,
-    ):
-        self.space_id = space_id
-        workspace_client = client or WorkspaceClient()
-        self.genie = workspace_client.genie
-        self.description = self.genie.get_space(space_id).description
-        self.headers = {
-            "Accept": "application/json",
-            "Content-Type": "application/json",
-        }
-        self.truncate_results = truncate_results
-        if polling_interval < 1 or polling_interval > 30:
-            raise ValueError("poll_interval must be between 1 and 30 seconds")
-        self.poll_interval = polling_interval
-
-    @mlflow.trace()
-    def start_conversation(self, content):
-        resp = self.genie._api.do(
-            "POST",
-            f"/api/2.0/genie/spaces/{self.space_id}/start-conversation",
-            body={"content": content},
-            headers=self.headers,
-        )
-        return resp
-
-    @mlflow.trace()
-    def create_message(self, conversation_id, content):
-        resp = self.genie._api.do(
-            "POST",
-            f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages",
-            body={"content": content},
-            headers=self.headers,
-        )
-        return resp
-
-    @mlflow.trace()
-    def poll_for_result(self, conversation_id, message_id):
-        @mlflow.trace()
-        def poll_query_results(attachment_id, query_str, description):
-            iteration_count = 0
-            while iteration_count < MAX_ITERATIONS:
-                iteration_count += 1
-                resp = self.genie._api.do(
-                    "GET",
-                    f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}/attachments/{attachment_id}/query-result",
-                    headers=self.headers,
-                )["statement_response"]
-                state = resp["status"]["state"]
-                if state == "SUCCEEDED":
-                    result = _parse_query_result(resp, self.truncate_results)
-                    return GenieResponse(
-                        conversation_id, result, query_str, description
-                    )
-                elif state in ["RUNNING", "PENDING"]:
-                    logger.debug("Waiting for query result...")
-                    time.sleep(self.poll_interval)
-                else:
-                    return GenieResponse(
-                        conversation_id,
-                        f"No query result: {resp['state']}",
-                        query_str,
-                        description,
-                    )
-            return GenieResponse(
-                conversation_id,
-                f"Genie query for result timed out after {MAX_ITERATIONS} iterations of {self.poll_interval} seconds",
-                query_str,
-                description,
-            )
-
-        @mlflow.trace()
-        def poll_result():
-            iteration_count = 0
-            while iteration_count < MAX_ITERATIONS:
-                iteration_count += 1
-                resp = self.genie._api.do(
-                    "GET",
-                    f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}",
-                    headers=self.headers,
-                )
-                if resp["status"] == "COMPLETED":
-                    # Check if attachments key exists in response
-                    attachments = resp.get("attachments", [])
-                    if not attachments:
-                        # Handle case where response has no attachments
-                        return GenieResponse(
-                            conversation_id,
-                            result=f"Genie query completed but no attachments found. Response: {resp}",
-                        )
-
-                    attachment = next((r for r in attachments if "query" in r), None)
-                    if attachment:
-                        query_obj = attachment["query"]
-                        description = query_obj.get("description", "")
-                        query_str = query_obj.get("query", "")
-                        attachment_id = attachment["attachment_id"]
-                        return poll_query_results(attachment_id, query_str, description)
-                if resp["status"] == "COMPLETED":
-                    text_content = next(
-                        (r for r in attachments if "text" in r), None
-                    )
-                    if text_content:
-                        return GenieResponse(
-                            conversation_id, result=text_content["text"]["content"]
-                        )
-                    return GenieResponse(
-                        conversation_id,
-                        result="Genie query completed but no text content found in attachments.",
-                    )
-                elif resp["status"] in {"CANCELLED", "QUERY_RESULT_EXPIRED"}:
-                    return GenieResponse(
-                        conversation_id, result=f"Genie query {resp['status'].lower()}."
-                    )
-                elif resp["status"] == "FAILED":
-                    return GenieResponse(
-                        conversation_id,
-                        result=f"Genie query failed with error: {resp.get('error', 'Unknown error')}",
-                    )
-                # includes EXECUTING_QUERY, Genie can retry after this status
-                else:
-                    logger.debug(f"Waiting...: {resp['status']}")
-                    time.sleep(self.poll_interval)
-            return GenieResponse(
-                conversation_id,
-                f"Genie query timed out after {MAX_ITERATIONS} iterations of {self.poll_interval} seconds",
-            )
-
-        return poll_result()
-
-    @mlflow.trace()
-    def ask_question(self, question: str, conversation_id: str | None = None):
-        logger.debug(
-            f"ask_question called with question: {question}, conversation_id: {conversation_id}"
-        )
-        if conversation_id:
-            resp = self.create_message(conversation_id, question)
-        else:
-            resp = self.start_conversation(question)
-        logger.debug(f"ask_question response: {resp}")
-        return self.poll_for_result(resp["conversation_id"], resp["message_id"])
+    data: dict[str, Any] = {
+        "result": result,
+        "query": response.query,
+        "description": response.description,
+        "conversation_id": response.conversation_id,
+    }
+    return json.dumps(data)
 
 
 def create_genie_tool(
     genie_room: GenieRoomModel | dict[str, Any],
-    name:
-    description:
+    name: str | None = None,
+    description: str | None = None,
     persist_conversation: bool = False,
     truncate_results: bool = False,
-
-) -> Callable[[str], GenieResponse]:
+) -> Callable[..., Command]:
     """
     Create a tool for interacting with Databricks Genie for natural language queries to databases.
 

@@ -290,6 +57,9 @@ def create_genie_tool(
         genie_room: GenieRoomModel or dict containing Genie configuration
         name: Optional custom name for the tool. If None, uses default "genie_tool"
         description: Optional custom description for the tool. If None, uses default description
+        persist_conversation: Whether to persist conversation IDs across tool calls for
+            multi-turn conversations within the same Genie space
+        truncate_results: Whether to truncate large query results to fit token limits
 
     Returns:
         A LangGraph tool that processes natural language queries through Genie

@@ -305,13 +75,6 @@ def create_genie_tool(
     space_id = CompositeVariableModel(**space_id)
     space_id = value_of(space_id)
 
-    # genie: Genie = Genie(
-    #     space_id=space_id,
-    #     client=genie_room.workspace_client,
-    #     truncate_results=truncate_results,
-    #     polling_interval=poll_interval,
-    # )
-
     default_description: str = dedent("""
 This tool lets you have a conversation and chat with tabular data about <topic>. You should ask
 questions about the data and the tool will try to answer them.

@@ -343,14 +106,14 @@ GenieResponse: A response object containing the conversation ID and result from
         state: Annotated[dict, InjectedState],
         tool_call_id: Annotated[str, InjectedToolCallId],
     ) -> Command:
+        """Process a natural language question through Databricks Genie."""
+        # Create Genie instance using databricks_langchain implementation
         genie: Genie = Genie(
             space_id=space_id,
            client=genie_room.workspace_client,
             truncate_results=truncate_results,
-            polling_interval=poll_interval,
         )
 
-        """Process a natural language question through Databricks Genie."""
         # Get existing conversation mapping and retrieve conversation ID for this space
         conversation_ids: dict[str, str] = state.get("genie_conversation_ids", {})
         existing_conversation_id: str | None = conversation_ids.get(space_id)

@@ -368,9 +131,10 @@ GenieResponse: A response object containing the conversation ID and result from
         )
 
         # Update the conversation mapping with the new conversation ID for this space
-
         update: dict[str, Any] = {
-            "messages": [
+            "messages": [
+                ToolMessage(_response_to_json(response), tool_call_id=tool_call_id)
+            ],
         }
 
         if persist_conversation:
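
Dropping the hand-rolled polling client in favor of `databricks_ai_bridge.genie.Genie` leaves only serialization in this module. The `_response_to_json` handling can be sanity-checked standalone (the payload values are made up; `to_markdown()` needs the `tabulate` package installed):

```python
import json

import pandas as pd

# Stand-in for a GenieResponse whose result came back as a DataFrame.
df = pd.DataFrame({"store": ["A", "B"], "total_sales": [100, 250]})

result = df.to_markdown() if isinstance(df, pd.DataFrame) else df
tool_output = json.dumps(
    {
        "result": result,
        "query": "SELECT store, SUM(sales) AS total_sales FROM sales GROUP BY store",
        "description": "Total sales by store",
        "conversation_id": "0123456789",
    }
)
print(tool_output)
```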
dao_ai/tools/vector_search.py
CHANGED
@@ -101,7 +101,7 @@ def create_vector_search_tool(
     # Initialize the vector store
     # Note: text_column is only required for self-managed embeddings
     # For Databricks-managed embeddings, it's automatically determined from the index
-
+
     # Build client_args for VectorSearchClient from environment variables
     # This is needed because during MLflow model validation, credentials must be
     # explicitly passed to VectorSearchClient via client_args.

@@ -121,7 +121,9 @@ def create_vector_search_tool(
             "DATABRICKS_CLIENT_SECRET"
         )
 
-        logger.debug(
+        logger.debug(
+            f"Creating DatabricksVectorSearch with client_args keys: {list(client_args.keys())}"
+        )
 
     # Pass both workspace_client (for model serving detection) and client_args (for credentials)
     vector_store: DatabricksVectorSearch = DatabricksVectorSearch(
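
Logging only `client_args.keys()` keeps credential values out of the logs. A sketch of how such a dict might be assembled from the environment variables named in this hunk; the keyword names follow `VectorSearchClient`'s constructor and should be treated as assumptions here:

```python
import os

# Explicit credentials for VectorSearchClient: during MLflow model validation
# there is no ambient auth context, so these must be passed via client_args.
client_args: dict[str, str] = {}
if os.environ.get("DATABRICKS_HOST"):
    client_args["workspace_url"] = os.environ["DATABRICKS_HOST"]
if os.environ.get("DATABRICKS_TOKEN"):
    client_args["personal_access_token"] = os.environ["DATABRICKS_TOKEN"]
if os.environ.get("DATABRICKS_CLIENT_ID"):
    client_args["service_principal_client_id"] = os.environ["DATABRICKS_CLIENT_ID"]
if os.environ.get("DATABRICKS_CLIENT_SECRET"):
    client_args["service_principal_client_secret"] = os.environ[
        "DATABRICKS_CLIENT_SECRET"
    ]

# Log key names only, never the secret values.
print(f"client_args keys: {list(client_args.keys())}")
```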
dao_ai/utils.py
CHANGED
@@ -99,7 +99,7 @@ def get_installed_packages() -> dict[str, str]:
         f"databricks-langchain=={version('databricks-langchain')}",
         f"databricks-mcp=={version('databricks-mcp')}",
         f"databricks-sdk[openai]=={version('databricks-sdk')}",
-        f"
+        f"ddgs=={version('ddgs')}",
         f"flashrank=={version('flashrank')}",
         f"langchain=={version('langchain')}",
         f"langchain-mcp-adapters=={version('langchain-mcp-adapters')}",

@@ -141,12 +141,12 @@ def load_function(function_name: str) -> Callable[..., Any]:
         "module.submodule.function_name"
 
     Returns:
-        The imported callable function
+        The imported callable function or langchain tool
 
     Raises:
         ImportError: If the module cannot be imported
         AttributeError: If the function doesn't exist in the module
-        TypeError: If the resolved object is not callable
+        TypeError: If the resolved object is not callable or invocable
 
     Example:
         >>> func = callable_from_fqn("dao_ai.models.get_latest_model_version")

@@ -164,9 +164,14 @@ def load_function(function_name: str) -> Callable[..., Any]:
         # Get the function from the module
         func = getattr(module, func_name)
 
-        # Verify that the resolved object is callable
-
-
+        # Verify that the resolved object is callable or is a langchain tool
+        # In langchain 1.x, StructuredTool objects are not directly callable
+        # but have an invoke() method
+        is_callable = callable(func)
+        is_langchain_tool = hasattr(func, "invoke") and hasattr(func, "name")
+
+        if not is_callable and not is_langchain_tool:
+            raise TypeError(f"Function {func_name} is not callable or invocable.")
 
         return func
     except (ImportError, AttributeError, TypeError) as e:

@@ -175,4 +180,26 @@ def load_function(function_name: str) -> Callable[..., Any]:
 
 
 def is_in_model_serving() -> bool:
-
+    """Check if running in Databricks Model Serving environment.
+
+    Detects Model Serving by checking for environment variables that are
+    typically set in that environment.
+    """
+    # Primary check - explicit Databricks Model Serving env var
+    if os.environ.get("IS_IN_DB_MODEL_SERVING_ENV", "false").lower() == "true":
+        return True
+
+    # Secondary check - Model Serving sets these environment variables
+    if os.environ.get("DATABRICKS_MODEL_SERVING_ENV"):
+        return True
+
+    # Check for cluster type indicator
+    cluster_type = os.environ.get("DATABRICKS_CLUSTER_TYPE", "")
+    if "model-serving" in cluster_type.lower():
+        return True
+
+    # Check for model serving specific paths
+    if os.path.exists("/opt/conda/envs/mlflow-env"):
+        return True
+
+    return False
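
The relaxed check matters because a `@tool`-decorated function becomes a `StructuredTool`, which in langchain 1.x exposes `invoke()` and `name` rather than being directly callable. A quick check of the duck-typing used:

```python
from langchain_core.tools import tool


@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b


# Mirrors load_function's acceptance test: callable OR (.invoke and .name).
is_callable = callable(add)
is_langchain_tool = hasattr(add, "invoke") and hasattr(add, "name")
assert is_callable or is_langchain_tool

print(add.invoke({"a": 1, "b": 2}))  # 3
```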
{dao_ai-0.0.31.dist-info → dao_ai-0.0.32.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dao-ai
-Version: 0.0.31
+Version: 0.0.32
 Summary: DAO AI: A modular, multi-agent orchestration framework for complex AI workflows. Supports agent handoff, tool integration, and dynamic configuration via YAML.
 Project-URL: Homepage, https://github.com/natefleming/dao-ai
 Project-URL: Documentation, https://natefleming.github.io/dao-ai

@@ -26,24 +26,24 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
 Classifier: Topic :: System :: Distributed Computing
 Requires-Python: >=3.11
 Requires-Dist: databricks-agents>=1.7.0
-Requires-Dist: databricks-langchain>=0.
+Requires-Dist: databricks-langchain>=0.11.0
 Requires-Dist: databricks-mcp>=0.3.0
 Requires-Dist: databricks-sdk[openai]>=0.67.0
-Requires-Dist:
+Requires-Dist: ddgs>=9.9.3
 Requires-Dist: flashrank>=0.2.8
 Requires-Dist: gepa>=0.0.17
 Requires-Dist: grandalf>=0.8
 Requires-Dist: langchain-mcp-adapters>=0.1.10
 Requires-Dist: langchain-tavily>=0.2.11
-Requires-Dist: langchain>=
+Requires-Dist: langchain>=1.1.3
 Requires-Dist: langgraph-checkpoint-postgres>=2.0.25
-Requires-Dist: langgraph-supervisor>=0.0.
-Requires-Dist: langgraph-swarm>=0.0
-Requires-Dist: langgraph>=0.
+Requires-Dist: langgraph-supervisor>=0.0.31
+Requires-Dist: langgraph-swarm>=0.1.0
+Requires-Dist: langgraph>=1.0.4
 Requires-Dist: langmem>=0.0.29
 Requires-Dist: loguru>=0.7.3
 Requires-Dist: mcp>=1.17.0
-Requires-Dist: mlflow>=3.
+Requires-Dist: mlflow>=3.7.0
 Requires-Dist: nest-asyncio>=1.6.0
 Requires-Dist: openevals>=0.0.19
 Requires-Dist: openpyxl>=3.1.5

@@ -55,7 +55,7 @@ Requires-Dist: rich>=14.0.0
 Requires-Dist: scipy<=1.15
 Requires-Dist: sqlparse>=0.5.3
 Requires-Dist: tomli>=2.3.0
-Requires-Dist: unitycatalog-ai[databricks]>=0.3.
+Requires-Dist: unitycatalog-ai[databricks]>=0.3.2
 Provides-Extra: databricks
 Requires-Dist: databricks-connect>=15.0.0; extra == 'databricks'
 Requires-Dist: databricks-vectorsearch>=0.63; extra == 'databricks'
{dao_ai-0.0.31.dist-info → dao_ai-0.0.32.dist-info}/RECORD
CHANGED
@@ -3,16 +3,16 @@ dao_ai/agent_as_code.py,sha256=sviZQV7ZPxE5zkZ9jAbfegI681nra5i8yYxw05e3X7U,552
 dao_ai/catalog.py,sha256=sPZpHTD3lPx4EZUtIWeQV7VQM89WJ6YH__wluk1v2lE,4947
 dao_ai/chat_models.py,sha256=uhwwOTeLyHWqoTTgHrs4n5iSyTwe4EQcLKnh3jRxPWI,8626
 dao_ai/cli.py,sha256=gq-nsapWxDA1M6Jua3vajBvIwf0Oa6YLcB58lEtMKUo,22503
-dao_ai/config.py,sha256=
+dao_ai/config.py,sha256=sc9iYPui5tHitG5kmOTd9LVjzgLJ2Dn0M6s-Zu3dw04,75022
 dao_ai/graph.py,sha256=9kjJx0oFZKq5J9-Kpri4-0VCJILHYdYyhqQnj0_noxQ,8913
 dao_ai/guardrails.py,sha256=4TKArDONRy8RwHzOT1plZ1rhy3x9GF_aeGpPCRl6wYA,4016
 dao_ai/messages.py,sha256=xl_3-WcFqZKCFCiov8sZOPljTdM3gX3fCHhxq-xFg2U,7005
 dao_ai/models.py,sha256=8r8GIG3EGxtVyWsRNI56lVaBjiNrPkzh4HdwMZRq8iw,31689
 dao_ai/nodes.py,sha256=iQ_5vL6mt1UcRnhwgz-l1D8Ww4CMQrSMVnP_Lu7fFjU,8781
-dao_ai/prompts.py,sha256=
+dao_ai/prompts.py,sha256=iA2Iaky7yzjwWT5cxg0cUIgwo1z1UVQua__8WPnvV6g,1633
 dao_ai/state.py,sha256=_lF9krAYYjvFDMUwZzVKOn0ZnXKcOrbjWKdre0C5B54,1137
 dao_ai/types.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-dao_ai/utils.py,sha256=
+dao_ai/utils.py,sha256=FLXbiUaCeBva4vJ-czs-sRP7QSxjoKjyDt1Q4yeI7sU,7727
 dao_ai/vector_search.py,sha256=jlaFS_iizJ55wblgzZmswMM3UOL-qOp2BGJc0JqXYSg,2839
 dao_ai/hooks/__init__.py,sha256=LlHGIuiZt6vGW8K5AQo1XJEkBP5vDVtMhq0IdjcLrD4,417
 dao_ai/hooks/core.py,sha256=ZShHctUSoauhBgdf1cecy9-D7J6-sGn-pKjuRMumW5U,6663

@@ -22,20 +22,20 @@ dao_ai/memory/core.py,sha256=DnEjQO3S7hXr3CDDd7C2eE7fQUmcCS_8q9BXEgjPH3U,4271
 dao_ai/memory/postgres.py,sha256=vvI3osjx1EoU5GBA6SCUstTBKillcmLl12hVgDMjfJY,15346
 dao_ai/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dao_ai/providers/base.py,sha256=-fjKypCOk28h6vioPfMj9YZSw_3Kcbi2nMuAyY7vX9k,1383
-dao_ai/providers/databricks.py,sha256=
+dao_ai/providers/databricks.py,sha256=rPBMdGcJvdGBRK9FZeBxkLfcTpXyxU1cs14YllyZKbY,67857
 dao_ai/tools/__init__.py,sha256=G5-5Yi6zpQOH53b5IzLdtsC6g0Ep6leI5GxgxOmgw7Q,1203
 dao_ai/tools/agent.py,sha256=WbQnyziiT12TLMrA7xK0VuOU029tdmUBXbUl-R1VZ0Q,1886
 dao_ai/tools/core.py,sha256=Kei33S8vrmvPOAyrFNekaWmV2jqZ-IPS1QDSvU7RZF0,1984
-dao_ai/tools/genie.py,sha256=
+dao_ai/tools/genie.py,sha256=BPM_1Sk5bf7QSCFPPboWWkZKYwBwDwbGhMVp5-QDd10,5956
 dao_ai/tools/human_in_the_loop.py,sha256=yk35MO9eNETnYFH-sqlgR-G24TrEgXpJlnZUustsLkI,3681
 dao_ai/tools/mcp.py,sha256=5aQoRtx2z4xm6zgRslc78rSfEQe-mfhqov2NsiybYfc,8416
 dao_ai/tools/python.py,sha256=XcQiTMshZyLUTVR5peB3vqsoUoAAy8gol9_pcrhddfI,1831
 dao_ai/tools/slack.py,sha256=SCvyVcD9Pv_XXPXePE_fSU1Pd8VLTEkKDLvoGTZWy2Y,4775
 dao_ai/tools/time.py,sha256=Y-23qdnNHzwjvnfkWvYsE7PoWS1hfeKy44tA7sCnNac,8759
 dao_ai/tools/unity_catalog.py,sha256=uX_h52BuBAr4c9UeqSMI7DNz3BPRLeai5tBVW4sJqRI,13113
-dao_ai/tools/vector_search.py,sha256=
-dao_ai-0.0.
-dao_ai-0.0.
-dao_ai-0.0.
-dao_ai-0.0.
-dao_ai-0.0.
+dao_ai/tools/vector_search.py,sha256=3cdiUaFpox25GSRNec7FKceY3DuLp7dLVH8FRA0BgeY,12624
+dao_ai-0.0.32.dist-info/METADATA,sha256=1_BlILYdzDHCILhIxFNeWdM6CRg4uKqBNPiP_hjbXtE,42763
+dao_ai-0.0.32.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+dao_ai-0.0.32.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
+dao_ai-0.0.32.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
+dao_ai-0.0.32.dist-info/RECORD,,
{dao_ai-0.0.31.dist-info → dao_ai-0.0.32.dist-info}/WHEEL
File without changes
{dao_ai-0.0.31.dist-info → dao_ai-0.0.32.dist-info}/entry_points.txt
File without changes
{dao_ai-0.0.31.dist-info → dao_ai-0.0.32.dist-info}/licenses/LICENSE
File without changes