dao-ai 0.0.25__py3-none-any.whl → 0.0.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/agent_as_code.py +3 -0
- dao_ai/config.py +431 -27
- dao_ai/graph.py +29 -4
- dao_ai/nodes.py +29 -20
- dao_ai/providers/databricks.py +536 -35
- dao_ai/tools/genie.py +2 -3
- dao_ai/tools/mcp.py +46 -27
- dao_ai/tools/vector_search.py +232 -22
- dao_ai/utils.py +57 -1
- {dao_ai-0.0.25.dist-info → dao_ai-0.0.28.dist-info}/METADATA +6 -3
- {dao_ai-0.0.25.dist-info → dao_ai-0.0.28.dist-info}/RECORD +14 -14
- {dao_ai-0.0.25.dist-info → dao_ai-0.0.28.dist-info}/WHEEL +1 -1
- {dao_ai-0.0.25.dist-info → dao_ai-0.0.28.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.0.28.dist-info}/licenses/LICENSE +0 -0
dao_ai/providers/databricks.py
CHANGED
```diff
@@ -1,6 +1,6 @@
 import base64
+import re
 import uuid
-from importlib.metadata import version
 from pathlib import Path
 from typing import Any, Callable, Final, Sequence
 
@@ -32,11 +32,14 @@ from mlflow import MlflowClient
 from mlflow.entities import Experiment
 from mlflow.entities.model_registry import PromptVersion
 from mlflow.entities.model_registry.model_version import ModelVersion
+from mlflow.genai.datasets import EvaluationDataset, get_dataset
+from mlflow.genai.prompts import load_prompt
 from mlflow.models.auth_policy import AuthPolicy, SystemAuthPolicy, UserAuthPolicy
 from mlflow.models.model import ModelInfo
 from mlflow.models.resources import (
     DatabricksResource,
 )
+from mlflow.pyfunc import ResponsesAgent
 from pyspark.sql import SparkSession
 from unitycatalog.ai.core.base import FunctionExecutionResult
 from unitycatalog.ai.core.databricks import DatabricksFunctionClient
@@ -54,6 +57,7 @@ from dao_ai.config import (
     IsDatabricksResource,
     LLMModel,
     PromptModel,
+    PromptOptimizationModel,
     SchemaModel,
     TableModel,
     UnityCatalogFunctionSqlModel,
@@ -65,6 +69,7 @@ from dao_ai.config import (
 from dao_ai.models import get_latest_model_version
 from dao_ai.providers.base import ServiceProvider
 from dao_ai.utils import (
+    dao_ai_version,
    get_installed_packages,
     is_installed,
     is_lib_provided,
```
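The removed top-level `from importlib.metadata import version` is replaced by a `dao_ai_version()` helper imported from `dao_ai.utils` (which grew by +57 -1 in this release). The helper's body is not expanded in this diff; a minimal sketch of the shape such a helper plausibly takes, assuming it wraps `importlib.metadata` with a fallback for source checkouts:

```python
# Hypothetical sketch only — the real helper lives in dao_ai/utils.py and is
# not shown in this diff; the fallback value is an assumption.
from importlib.metadata import PackageNotFoundError, version


def dao_ai_version() -> str:
    """Return the installed dao-ai version, or a dev placeholder."""
    try:
        return version("dao-ai")
    except PackageNotFoundError:
        # Assumed fallback for source checkouts where the wheel isn't installed.
        return "0.0.0.dev0"
```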
```diff
@@ -296,7 +301,7 @@ class DatabricksProvider(ServiceProvider):
         if is_installed():
             if not is_lib_provided("dao-ai", pip_requirements):
                 pip_requirements += [
-                    f"dao-ai=={
+                    f"dao-ai=={dao_ai_version()}",
                 ]
         else:
             src_path: Path = model_root_path.parent
```
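With the helper in place, the pin added to `pip_requirements` is resolved at logging time, so the requirement recorded with the logged model matches the library version doing the logging. Illustration (the stub and the starting requirements are assumptions; `is_lib_provided` is approximated by the `any(...)` check):

```python
# Illustration of the resulting pin; dao_ai_version() is stubbed here.
def dao_ai_version() -> str:
    return "0.0.28"  # stands in for dao_ai.utils.dao_ai_version


pip_requirements: list[str] = ["mlflow"]  # illustrative starting requirements
if not any(req.split("==")[0] == "dao-ai" for req in pip_requirements):
    pip_requirements += [f"dao-ai=={dao_ai_version()}"]
print(pip_requirements)  # ['mlflow', 'dao-ai==0.0.28']
```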
```diff
@@ -322,10 +327,11 @@ class DatabricksProvider(ServiceProvider):
 
         with mlflow.start_run(run_name=run_name):
             mlflow.set_tag("type", "agent")
+            mlflow.set_tag("dao_ai", dao_ai_version())
             logged_agent_info: ModelInfo = mlflow.pyfunc.log_model(
                 python_model=model_path.as_posix(),
                 code_paths=code_paths,
-                model_config=config.model_dump(by_alias=True),
+                model_config=config.model_dump(mode="json", by_alias=True),
                 name="agent",
                 pip_requirements=pip_requirements,
                 input_example=input_example,
```
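The switch to `config.model_dump(mode="json", by_alias=True)` matters because pydantic v2's default (`mode="python"`) leaves non-JSON-native values such as `Path` and `Enum` members in the dump, which can break when MLflow serializes `model_config`. With `mode="json"`, every value is coerced to a JSON-safe primitive. A small self-contained demonstration (the model below is illustrative, not a dao-ai class):

```python
from enum import Enum
from pathlib import Path

from pydantic import BaseModel


class WorkloadSize(str, Enum):
    SMALL = "Small"


class DemoConfig(BaseModel):  # illustrative stand-in for the dao-ai config model
    model_path: Path
    workload_size: WorkloadSize


cfg = DemoConfig(model_path=Path("agents/app.py"), workload_size=WorkloadSize.SMALL)
print(cfg.model_dump(by_alias=True))
# {'model_path': PosixPath('agents/app.py'), 'workload_size': <WorkloadSize.SMALL: 'Small'>}
print(cfg.model_dump(mode="json", by_alias=True))
# {'model_path': 'agents/app.py', 'workload_size': 'Small'}
```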
```diff
@@ -344,9 +350,18 @@ class DatabricksProvider(ServiceProvider):
 
         client: MlflowClient = MlflowClient()
 
+        # Set tags on the model version
+        client.set_model_version_tag(
+            name=registered_model_name,
+            version=model_version.version,
+            key="dao_ai",
+            value=dao_ai_version(),
+        )
+        logger.debug(f"Set dao_ai tag on model version {model_version.version}")
+
         client.set_registered_model_alias(
             name=registered_model_name,
-            alias="
+            alias="Champion",
             version=model_version.version,
         )
 
```
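Every registered model version now carries a `dao_ai` tag with the framework version, and the alias written after registration is spelled out as `Champion` (the old value is truncated in this view). Both can be read back with the standard client; the model name below is illustrative:

```python
from mlflow import MlflowClient

client = MlflowClient()
# Resolve whichever version the "Champion" alias points at, then read its tags.
mv = client.get_model_version_by_alias("main.agents.my_agent", "Champion")
print(mv.version, mv.tags.get("dao_ai"))  # e.g. "3" and "0.0.28"
```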
```diff
@@ -372,7 +387,10 @@ class DatabricksProvider(ServiceProvider):
         scale_to_zero: bool = config.app.scale_to_zero
         environment_vars: dict[str, str] = config.app.environment_vars
         workload_size: str = config.app.workload_size
-        tags: dict[str, str] = config.app.tags
+        tags: dict[str, str] = config.app.tags.copy() if config.app.tags else {}
+
+        # Add dao_ai framework tag
+        tags["dao_ai"] = dao_ai_version()
 
         latest_version: int = get_latest_model_version(registered_model_name)
 
```
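Beyond adding the `dao_ai` tag, the `.copy()` fixes an aliasing hazard: the old binding pointed at the config's own dict, so the subsequent `tags[...] = ...` would have mutated `config.app.tags` in place; it also tolerates a missing or `None` tags field. The hazard in miniature:

```python
# Old binding: both names reference one dict, so the config object is mutated.
app_tags = {"team": "retail"}  # stands in for config.app.tags
tags = app_tags
tags["dao_ai"] = "0.0.28"
print(app_tags)  # {'team': 'retail', 'dao_ai': '0.0.28'}  <- config changed!

# New binding: an independent copy (or a fresh dict when tags is falsy).
app_tags = {"team": "retail"}
tags = app_tags.copy() if app_tags else {}
tags["dao_ai"] = "0.0.28"
print(app_tags)  # {'team': 'retail'}  <- config untouched
```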
```diff
@@ -1026,42 +1044,124 @@ class DatabricksProvider(ServiceProvider):
             )
             raise
 
-    def get_prompt(self, prompt_model: PromptModel) ->
-        """
-
+    def get_prompt(self, prompt_model: PromptModel) -> PromptVersion:
+        """
+        Load prompt from MLflow Prompt Registry with fallback logic.
 
-
-
-
-
-
-        else:
-            prompt_uri = f"prompts:/{prompt_name}@latest"
+        If an explicit version or alias is specified in the prompt_model, uses that directly.
+        Otherwise, tries to load prompts in this order:
+        1. champion alias (if it exists)
+        2. latest alias (if it exists)
+        3. default_template (if provided)
 
-
-
+        Args:
+            prompt_model: The prompt model configuration
 
-
-
+        Returns:
+            PromptVersion: The loaded prompt version
 
-
-
+        Raises:
+            ValueError: If no prompt can be loaded from any source
+        """
+        prompt_name: str = prompt_model.full_name
 
-
-
-
-
+        # If explicit version or alias is specified, use it directly without fallback
+        if prompt_model.version or prompt_model.alias:
+            try:
+                prompt_version: PromptVersion = prompt_model.as_prompt()
+                logger.debug(
+                    f"Loaded prompt '{prompt_name}' with explicit "
+                    f"{'version ' + str(prompt_model.version) if prompt_model.version else 'alias ' + prompt_model.alias}"
+                )
+                return prompt_version
+            except Exception as e:
+                logger.warning(
+                    f"Failed to load prompt '{prompt_name}' with explicit "
+                    f"{'version ' + str(prompt_model.version) if prompt_model.version else 'alias ' + prompt_model.alias}: {e}"
                 )
-
+                # Fall through to default_template if available
+        else:
+            # No explicit version/alias specified - check if default_template needs syncing first
+            logger.debug(
+                f"No explicit version/alias specified for '{prompt_name}', "
+                "checking if default_template needs syncing"
+            )
 
-
-
-
+        # If we have a default_template, check if it differs from what's in the registry
+        # This ensures we always sync config changes before returning any alias
+        if prompt_model.default_template:
+            try:
+                default_uri: str = f"prompts:/{prompt_name}@default"
+                default_version: PromptVersion = load_prompt(default_uri)
+
+                if (
+                    default_version.to_single_brace_format().strip()
+                    != prompt_model.default_template.strip()
+                ):
+                    logger.info(
+                        f"Config default_template for '{prompt_name}' differs from registry, syncing..."
+                    )
+                    return self._sync_default_template_to_registry(
+                        prompt_name,
+                        prompt_model.default_template,
+                        prompt_model.description,
+                    )
+            except Exception as e:
+                logger.debug(f"Could not check default alias for sync: {e}")
+
+        # Now try aliases in order: champion → latest → default
+        logger.debug(
+            f"Trying fallback order for '{prompt_name}': champion → latest → default"
+        )
+
+        # Try champion alias first
+        try:
+            champion_uri: str = f"prompts:/{prompt_name}@champion"
+            prompt_version: PromptVersion = load_prompt(champion_uri)
+            logger.info(f"Loaded prompt '{prompt_name}' from champion alias")
+            return prompt_version
+        except Exception as e:
+            logger.debug(f"Champion alias not found for '{prompt_name}': {e}")
+
+        # Try latest alias next
+        try:
+            latest_uri: str = f"prompts:/{prompt_name}@latest"
+            prompt_version: PromptVersion = load_prompt(latest_uri)
+            logger.info(f"Loaded prompt '{prompt_name}' from latest alias")
+            return prompt_version
+        except Exception as e:
+            logger.debug(f"Latest alias not found for '{prompt_name}': {e}")
+
+        # Try default alias last
+        try:
+            default_uri: str = f"prompts:/{prompt_name}@default"
+            prompt_version: PromptVersion = load_prompt(default_uri)
+            logger.info(f"Loaded prompt '{prompt_name}' from default alias")
+            return prompt_version
+        except Exception as e:
+            logger.debug(f"Default alias not found for '{prompt_name}': {e}")
+
+        # Fall back to registering default_template if provided
+        if prompt_model.default_template:
+            logger.info(
+                f"Registering default_template for '{prompt_name}' "
+                "(no aliases found in registry)"
+            )
+            return self._sync_default_template_to_registry(
+                prompt_name, prompt_model.default_template, prompt_model.description
+            )
+
+        raise ValueError(
+            f"Prompt '{prompt_name}' not found in registry "
+            "(tried champion, latest, default aliases) and no default_template provided"
+        )
 
     def _sync_default_template_to_registry(
         self, prompt_name: str, default_template: str, description: str | None = None
-    ) ->
+    ) -> PromptVersion:
         """Register default_template to prompt registry under 'default' alias if changed."""
+        prompt_version: PromptVersion | None = None
+
         try:
             # Check if default alias already has the same template
             try:
```
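The rewritten `get_prompt` first honors an explicit `version`/`alias`, then syncs a changed `default_template`, and only then walks the alias chain. The resolution loop, reduced to a standalone sketch (the prompt name is illustrative, and this assumes a configured MLflow prompt registry such as `databricks-uc`):

```python
from mlflow.genai.prompts import load_prompt

name = "main.agents.qa_prompt"  # illustrative Unity Catalog prompt name
prompt = None
for alias in ("champion", "latest", "default"):
    try:
        prompt = load_prompt(f"prompts:/{name}@{alias}")
        print(f"resolved via @{alias}: v{prompt.version}")
        break
    except Exception:
        continue
if prompt is None:
    # get_prompt() would register default_template here, or raise ValueError.
    print("no alias found; fall back to default_template")
```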
```diff
@@ -1074,7 +1174,45 @@ class DatabricksProvider(ServiceProvider):
                     == default_template.strip()
                 ):
                     logger.debug(f"Prompt '{prompt_name}' is already up-to-date")
-
+
+                    # Ensure the "latest" and "champion" aliases also exist and point to the same version
+                    # This handles prompts created before the fix that added these aliases
+                    try:
+                        latest_version: PromptVersion = mlflow.genai.load_prompt(
+                            f"prompts:/{prompt_name}@latest"
+                        )
+                        logger.debug(
+                            f"Latest alias already exists for '{prompt_name}' pointing to version {latest_version.version}"
+                        )
+                    except Exception:
+                        logger.info(
+                            f"Setting 'latest' alias for existing prompt '{prompt_name}' v{existing.version}"
+                        )
+                        mlflow.genai.set_prompt_alias(
+                            name=prompt_name,
+                            alias="latest",
+                            version=existing.version,
+                        )
+
+                    # Ensure champion alias exists for first-time deployments
+                    try:
+                        champion_version: PromptVersion = mlflow.genai.load_prompt(
+                            f"prompts:/{prompt_name}@champion"
+                        )
+                        logger.debug(
+                            f"Champion alias already exists for '{prompt_name}' pointing to version {champion_version.version}"
+                        )
+                    except Exception:
+                        logger.info(
+                            f"Setting 'champion' alias for existing prompt '{prompt_name}' v{existing.version}"
+                        )
+                        mlflow.genai.set_prompt_alias(
+                            name=prompt_name,
+                            alias="champion",
+                            version=existing.version,
+                        )
+
+                    return existing  # Already up-to-date, return existing version
             except Exception:
                 logger.debug(
                     f"Default alias for prompt '{prompt_name}' doesn't exist yet"
```
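The new block backfills the `latest` and `champion` aliases for prompts registered before this release, where only `default` existed. The pattern is idempotent: probe each alias and only write it if the probe fails. Reduced to its essentials (name and version are illustrative):

```python
import mlflow

name = "main.agents.qa_prompt"  # illustrative
existing_version = 1            # illustrative: version the "default" alias points at
for alias in ("latest", "champion"):
    try:
        mlflow.genai.load_prompt(f"prompts:/{name}@{alias}")  # already set?
    except Exception:
        mlflow.genai.set_prompt_alias(name=name, alias=alias, version=existing_version)
```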
```diff
@@ -1082,22 +1220,385 @@ class DatabricksProvider(ServiceProvider):
 
             # Register new version and set as default alias
             commit_message = description or "Auto-synced from default_template"
-            prompt_version
+            prompt_version = mlflow.genai.register_prompt(
                 name=prompt_name,
                 template=default_template,
                 commit_message=commit_message,
+                tags={"dao_ai": dao_ai_version()},
             )
 
-            logger.debug(
+            logger.debug(
+                f"Setting default, latest, and champion aliases for prompt '{prompt_name}'"
+            )
             mlflow.genai.set_prompt_alias(
                 name=prompt_name,
                 alias="default",
                 version=prompt_version.version,
             )
+            mlflow.genai.set_prompt_alias(
+                name=prompt_name,
+                alias="latest",
+                version=prompt_version.version,
+            )
+            mlflow.genai.set_prompt_alias(
+                name=prompt_name,
+                alias="champion",
+                version=prompt_version.version,
+            )
 
             logger.info(
-                f"Synced prompt '{prompt_name}' v{prompt_version.version} to registry"
+                f"Synced prompt '{prompt_name}' v{prompt_version.version} to registry with 'default', 'latest', and 'champion' aliases"
             )
+            return prompt_version
 
         except Exception as e:
-            logger.
+            logger.error(f"Failed to sync '{prompt_name}' to registry: {e}")
+            raise ValueError(
+                f"Failed to sync prompt '{prompt_name}' to registry and unable to retrieve existing version"
+            ) from e
+
+    def optimize_prompt(self, optimization: PromptOptimizationModel) -> PromptModel:
+        """
+        Optimize a prompt using MLflow's prompt optimization (MLflow 3.5+).
+
+        This uses the MLflow GenAI optimize_prompts API with GepaPromptOptimizer as documented at:
+        https://mlflow.org/docs/latest/genai/prompt-registry/optimize-prompts/
+
+        Args:
+            optimization: PromptOptimizationModel containing configuration
+
+        Returns:
+            PromptModel: The optimized prompt with new URI
+        """
+        from mlflow.genai.optimize import GepaPromptOptimizer, optimize_prompts
+        from mlflow.genai.scorers import Correctness
+
+        from dao_ai.config import AgentModel, PromptModel
+
+        logger.info(f"Optimizing prompt: {optimization.name}")
+
+        # Get agent and prompt (prompt is guaranteed to be set by validator)
+        agent_model: AgentModel = optimization.agent
+        prompt: PromptModel = optimization.prompt  # type: ignore[assignment]
+        agent_model.prompt = prompt.uri
+
+        print(f"prompt={agent_model.prompt}")
+        # Log the prompt URI scheme being used
+        # Supports three schemes:
+        # 1. Specific version: "prompts:/qa/1" (when version is specified)
+        # 2. Alias: "prompts:/qa@champion" (when alias is specified)
+        # 3. Latest: "prompts:/qa@latest" (default when neither version nor alias specified)
+        prompt_uri: str = prompt.uri
+        logger.info(f"Using prompt URI for optimization: {prompt_uri}")
+
+        # Load the specific prompt version by URI for comparison
+        # Try to load the exact version specified, but if it doesn't exist,
+        # use get_prompt to create it from default_template
+        prompt_version: PromptVersion
+        try:
+            prompt_version = load_prompt(prompt_uri)
+            logger.info(f"Successfully loaded prompt from registry: {prompt_uri}")
+        except Exception as e:
+            logger.warning(
+                f"Could not load prompt '{prompt_uri}' directly: {e}. "
+                "Attempting to create from default_template..."
+            )
+            # Use get_prompt which will create from default_template if needed
+            prompt_version = self.get_prompt(prompt)
+            logger.info(
+                f"Created/loaded prompt '{prompt.full_name}' (will optimize against this version)"
+            )
+
+        # Load the evaluation dataset by name
+        logger.debug(f"Looking up dataset: {optimization.dataset}")
+        dataset: EvaluationDataset
+        if isinstance(optimization.dataset, str):
+            dataset = get_dataset(name=optimization.dataset)
+        else:
+            dataset = optimization.dataset.as_dataset()
+
+        # Set up reflection model for the optimizer
+        reflection_model_name: str
+        if optimization.reflection_model:
+            if isinstance(optimization.reflection_model, str):
+                reflection_model_name = optimization.reflection_model
+            else:
+                reflection_model_name = optimization.reflection_model.uri
+        else:
+            reflection_model_name = agent_model.model.uri
+        logger.debug(f"Using reflection model: {reflection_model_name}")
+
+        # Create the GepaPromptOptimizer
+        optimizer: GepaPromptOptimizer = GepaPromptOptimizer(
+            reflection_model=reflection_model_name,
+            max_metric_calls=optimization.num_candidates,
+            display_progress_bar=True,
+        )
+
+        # Set up scorer (judge model for evaluation)
+        scorer_model: str
+        if optimization.scorer_model:
+            if isinstance(optimization.scorer_model, str):
+                scorer_model = optimization.scorer_model
+            else:
+                scorer_model = optimization.scorer_model.uri
+        else:
+            scorer_model = agent_model.model.uri  # Use Databricks default
+        logger.debug(f"Using scorer with model: {scorer_model}")
+
+        scorers: list[Correctness] = [Correctness(model=scorer_model)]
+
+        # Use prompt_uri from line 1188 - already set to prompt.uri (configured URI)
+        # DO NOT overwrite with prompt_version.uri as that uses fallback logic
+        logger.debug(f"Optimizing prompt: {prompt_uri}")
+
+        agent: ResponsesAgent = agent_model.as_responses_agent()
+
+        # Create predict function that will be optimized
+        def predict_fn(**inputs: dict[str, Any]) -> str:
+            """Prediction function that uses the ResponsesAgent with ChatPayload.
+
+            The agent already has the prompt referenced/applied, so we just need to
+            convert the ChatPayload inputs to ResponsesAgentRequest format and call predict.
+
+            Args:
+                **inputs: Dictionary containing ChatPayload fields (messages/input, custom_inputs)
+
+            Returns:
+                str: The agent's response content
+            """
+            from mlflow.types.responses import (
+                ResponsesAgentRequest,
+                ResponsesAgentResponse,
+            )
+            from mlflow.types.responses_helpers import Message
+
+            from dao_ai.config import ChatPayload
+
+            # Verify agent is accessible (should be captured from outer scope)
+            if agent is None:
+                raise RuntimeError(
+                    "Agent object is not available in predict_fn. "
+                    "This may indicate a serialization issue with the ResponsesAgent."
+                )
+
+            # Convert inputs to ChatPayload
+            chat_payload: ChatPayload = ChatPayload(**inputs)
+
+            # Convert ChatPayload messages to MLflow Message format
+            mlflow_messages: list[Message] = [
+                Message(role=msg.role, content=msg.content)
+                for msg in chat_payload.messages
+            ]
+
+            # Create ResponsesAgentRequest
+            request: ResponsesAgentRequest = ResponsesAgentRequest(
+                input=mlflow_messages,
+                custom_inputs=chat_payload.custom_inputs,
+            )
+
+            # Call the ResponsesAgent's predict method
+            response: ResponsesAgentResponse = agent.predict(request)
+
+            if response.output and len(response.output) > 0:
+                content = response.output[0].content
+                logger.debug(f"Response content type: {type(content)}")
+                logger.debug(f"Response content: {content}")
+
+                # Extract text from content using same logic as LanggraphResponsesAgent._extract_text_from_content
+                # Content can be:
+                # - A string (return as is)
+                # - A list of items with 'text' keys (extract and join)
+                # - Other types (try to get 'text' attribute or convert to string)
+                if isinstance(content, str):
+                    return content
+                elif isinstance(content, list):
+                    text_parts = []
+                    for content_item in content:
+                        if isinstance(content_item, str):
+                            text_parts.append(content_item)
+                        elif isinstance(content_item, dict) and "text" in content_item:
+                            text_parts.append(content_item["text"])
+                        elif hasattr(content_item, "text"):
+                            text_parts.append(content_item.text)
+                    return "".join(text_parts) if text_parts else str(content)
+                else:
+                    # Fallback for unknown types - try to extract text attribute
+                    return getattr(content, "text", str(content))
+            else:
+                return ""
+
+        # Set registry URI for Databricks Unity Catalog
+        mlflow.set_registry_uri("databricks-uc")
+
+        # Run optimization with tracking disabled to prevent auto-registering all candidates
+        logger.info("Running prompt optimization with GepaPromptOptimizer...")
+        logger.info(
+            f"Generating {optimization.num_candidates} candidate prompts for evaluation"
+        )
+
+        from mlflow.genai.optimize.types import (
+            PromptOptimizationResult,
+        )
+
+        result: PromptOptimizationResult = optimize_prompts(
+            predict_fn=predict_fn,
+            train_data=dataset,
+            prompt_uris=[prompt_uri],  # Use the configured URI (version/alias/latest)
+            optimizer=optimizer,
+            scorers=scorers,
+            enable_tracking=False,  # Don't auto-register all candidates
+        )
+
+        # Log the optimization results
+        logger.info("Optimization complete!")
+        logger.info(f"Optimizer used: {result.optimizer_name}")
+
+        if result.optimized_prompts:
+            optimized_prompt_version: PromptVersion = result.optimized_prompts[0]
+
+            # Check if the optimized prompt is actually different from the original
+            original_template: str = prompt_version.to_single_brace_format().strip()
+            optimized_template: str = (
+                optimized_prompt_version.to_single_brace_format().strip()
+            )
+
+            # Normalize whitespace for more robust comparison
+            original_normalized: str = re.sub(r"\s+", " ", original_template).strip()
+            optimized_normalized: str = re.sub(r"\s+", " ", optimized_template).strip()
+
+            logger.debug(f"Original template length: {len(original_template)} chars")
+            logger.debug(f"Optimized template length: {len(optimized_template)} chars")
+            logger.debug(
+                f"Templates identical: {original_normalized == optimized_normalized}"
+            )
+
+            if original_normalized == optimized_normalized:
+                logger.info(
+                    f"Optimized prompt is identical to original for '{prompt.full_name}'. "
+                    "No new version will be registered."
+                )
+                return prompt
+
+            logger.info("Optimized prompt is DIFFERENT from original")
+            logger.info(
+                f"Original length: {len(original_template)}, Optimized length: {len(optimized_template)}"
+            )
+            logger.debug(
+                f"Original template (first 300 chars): {original_template[:300]}..."
+            )
+            logger.debug(
+                f"Optimized template (first 300 chars): {optimized_template[:300]}..."
+            )
+
+            # Check evaluation scores to determine if we should register the optimized prompt
+            should_register: bool = False
+            has_improvement: bool = False
+
+            if (
+                result.initial_eval_score is not None
+                and result.final_eval_score is not None
+            ):
+                logger.info("Evaluation scores:")
+                logger.info(f"  Initial score: {result.initial_eval_score}")
+                logger.info(f"  Final score: {result.final_eval_score}")
+
+                # Only register if there's improvement
+                if result.final_eval_score > result.initial_eval_score:
+                    improvement: float = (
+                        (result.final_eval_score - result.initial_eval_score)
+                        / result.initial_eval_score
+                    ) * 100
+                    logger.info(
+                        f"Optimized prompt improved by {improvement:.2f}% "
+                        f"(initial: {result.initial_eval_score}, final: {result.final_eval_score})"
+                    )
+                    should_register = True
+                    has_improvement = True
+                else:
+                    logger.info(
+                        f"Optimized prompt (score: {result.final_eval_score}) did NOT improve over baseline (score: {result.initial_eval_score}). "
+                        "No new version will be registered."
+                    )
+            else:
+                # No scores available - register anyway but warn
+                logger.warning(
+                    "No evaluation scores available to compare performance. "
+                    "Registering optimized prompt without performance validation."
+                )
+                should_register = True
+
+            if not should_register:
+                logger.info(
+                    f"Skipping registration for '{prompt.full_name}' (no improvement)"
+                )
+                return prompt
+
+            # Register the optimized prompt manually
+            try:
+                logger.info(f"Registering optimized prompt '{prompt.full_name}'")
+                registered_version: PromptVersion = mlflow.genai.register_prompt(
+                    name=prompt.full_name,
+                    template=optimized_template,
+                    commit_message=f"Optimized for {agent_model.model.uri} using GepaPromptOptimizer",
+                    tags={
+                        "dao_ai": dao_ai_version(),
+                        "target_model": agent_model.model.uri,
+                    },
+                )
+                logger.info(
+                    f"Registered optimized prompt as version {registered_version.version}"
+                )
+
+                # Always set "latest" alias (represents most recently registered prompt)
+                logger.info(
+                    f"Setting 'latest' alias for optimized prompt '{prompt.full_name}' version {registered_version.version}"
+                )
+                mlflow.genai.set_prompt_alias(
+                    name=prompt.full_name,
+                    alias="latest",
+                    version=registered_version.version,
+                )
+                logger.info(
+                    f"Successfully set 'latest' alias for '{prompt.full_name}' v{registered_version.version}"
+                )
+
+                # If there's confirmed improvement, also set the "champion" alias
+                # (represents the prompt that should be used by deployed agents)
+                if has_improvement:
+                    logger.info(
+                        f"Setting 'champion' alias for improved prompt '{prompt.full_name}' version {registered_version.version}"
+                    )
+                    mlflow.genai.set_prompt_alias(
+                        name=prompt.full_name,
+                        alias="champion",
+                        version=registered_version.version,
+                    )
+                    logger.info(
+                        f"Successfully set 'champion' alias for '{prompt.full_name}' v{registered_version.version}"
+                    )
+
+                # Add target_model and dao_ai tags
+                tags: dict[str, Any] = prompt.tags.copy() if prompt.tags else {}
+                tags["target_model"] = agent_model.model.uri
+                tags["dao_ai"] = dao_ai_version()
+
+                # Return the optimized prompt with the appropriate alias
+                # Use "champion" if there was improvement, otherwise "latest"
+                result_alias: str = "champion" if has_improvement else "latest"
+                return PromptModel(
+                    name=prompt.name,
+                    schema=prompt.schema_model,
+                    description=f"Optimized version of {prompt.name} for {agent_model.model.uri}",
+                    alias=result_alias,
+                    tags=tags,
+                )
+
+            except Exception as e:
+                logger.error(
+                    f"Failed to register optimized prompt '{prompt.full_name}': {e}"
+                )
+                return prompt
+        else:
+            logger.warning("No optimized prompts returned from optimization")
+            return prompt
```