dao-ai 0.0.36__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- dao_ai/__init__.py +29 -0
- dao_ai/cli.py +195 -30
- dao_ai/config.py +770 -244
- dao_ai/genie/__init__.py +1 -22
- dao_ai/genie/cache/__init__.py +1 -2
- dao_ai/genie/cache/base.py +20 -70
- dao_ai/genie/cache/core.py +75 -0
- dao_ai/genie/cache/lru.py +44 -21
- dao_ai/genie/cache/semantic.py +390 -109
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +22 -190
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +23 -5
- dao_ai/memory/databricks.py +389 -0
- dao_ai/memory/postgres.py +2 -2
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +778 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +61 -0
- dao_ai/middleware/guardrails.py +415 -0
- dao_ai/middleware/human_in_the_loop.py +228 -0
- dao_ai/middleware/message_validation.py +554 -0
- dao_ai/middleware/summarization.py +192 -0
- dao_ai/models.py +1177 -108
- dao_ai/nodes.py +118 -161
- dao_ai/optimization.py +664 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +287 -0
- dao_ai/orchestration/supervisor.py +264 -0
- dao_ai/orchestration/swarm.py +226 -0
- dao_ai/prompts.py +126 -29
- dao_ai/providers/databricks.py +126 -381
- dao_ai/state.py +139 -21
- dao_ai/tools/__init__.py +8 -5
- dao_ai/tools/core.py +57 -4
- dao_ai/tools/email.py +280 -0
- dao_ai/tools/genie.py +47 -24
- dao_ai/tools/mcp.py +4 -3
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +4 -12
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +1 -1
- dao_ai/tools/unity_catalog.py +8 -6
- dao_ai/tools/vector_search.py +16 -9
- dao_ai/utils.py +72 -8
- dao_ai-0.1.0.dist-info/METADATA +1878 -0
- dao_ai-0.1.0.dist-info/RECORD +62 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/genie/__init__.py +0 -236
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.36.dist-info/METADATA +0 -951
- dao_ai-0.0.36.dist-info/RECORD +0 -47
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/licenses/LICENSE +0 -0
dao_ai/providers/databricks.py  CHANGED

@@ -1,5 +1,4 @@
 import base64
-import re
 import uuid
 from pathlib import Path
 from typing import Any, Callable, Final, Sequence
@@ -32,14 +31,12 @@ from mlflow import MlflowClient
 from mlflow.entities import Experiment
 from mlflow.entities.model_registry import PromptVersion
 from mlflow.entities.model_registry.model_version import ModelVersion
-from mlflow.genai.datasets import EvaluationDataset, get_dataset
 from mlflow.genai.prompts import load_prompt
 from mlflow.models.auth_policy import AuthPolicy, SystemAuthPolicy, UserAuthPolicy
 from mlflow.models.model import ModelInfo
 from mlflow.models.resources import (
     DatabricksResource,
 )
-from mlflow.pyfunc import ResponsesAgent
 from pyspark.sql import SparkSession
 from unitycatalog.ai.core.base import FunctionExecutionResult
 from unitycatalog.ai.core.databricks import DatabricksFunctionClient
@@ -49,6 +46,7 @@ from dao_ai.config import (
     AppConfig,
     ConnectionModel,
     DatabaseModel,
+    DatabricksAppModel,
     DatasetModel,
     FunctionModel,
     GenieRoomModel,
@@ -57,7 +55,6 @@ from dao_ai.config import (
     IsDatabricksResource,
     LLMModel,
     PromptModel,
-    PromptOptimizationModel,
     SchemaModel,
     TableModel,
     UnityCatalogFunctionSqlModel,
@@ -250,6 +247,7 @@ class DatabricksProvider(ServiceProvider):
         )
         databases: Sequence[DatabaseModel] = list(config.resources.databases.values())
         volumes: Sequence[VolumeModel] = list(config.resources.volumes.values())
+        apps: Sequence[DatabricksAppModel] = list(config.resources.apps.values())

         resources: Sequence[IsDatabricksResource] = (
             llms
@@ -261,6 +259,7 @@ class DatabricksProvider(ServiceProvider):
             + connections
             + databases
             + volumes
+            + apps
         )

         # Flatten all resources from all models into a single list
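The two hunks above add Databricks Apps as a first-class resource type. Purely as an illustration of the pattern (the `ResourceEntry` class below is a hypothetical stand-in, not a dao_ai type; in the real code the elements are config models implementing `IsDatabricksResource`), the provider gathers one list per resource type from the config, concatenates them, and flattens them into a single resource list:

```python
# Illustrative only -- toy stand-in for how the provider folds the new `apps`
# collection into the flat resource list; ResourceEntry is hypothetical.
from dataclasses import dataclass
from typing import Sequence


@dataclass
class ResourceEntry:
    name: str
    resources: list[str]

    def as_resources(self) -> list[str]:
        return self.resources


llms = [ResourceEntry("chat-llm", ["serving-endpoint:chat"])]
databases = [ResourceEntry("lakebase", ["database:main"])]
volumes = [ResourceEntry("raw", ["volume:catalog.schema.raw"])]
apps = [ResourceEntry("frontend", ["app:frontend"])]  # new resource type in 0.1.0

models: Sequence[ResourceEntry] = [*llms, *databases, *volumes, *apps]
flattened = [r for model in models for r in model.as_resources()]
print(flattened)
```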
@@ -1190,20 +1189,94 @@ class DatabricksProvider(ServiceProvider):
             )
             # Fall through to try other methods

-        # Try to load in priority order: champion →
+        # Try to load in priority order: champion → default (with sync check)
         logger.debug(
-            f"Trying fallback order for '{prompt_name}': champion →
+            f"Trying fallback order for '{prompt_name}': champion → default (with auto-sync)"
         )

-        #
+        # First, sync default alias if template has changed (even if champion exists)
+        if prompt_model.default_template:
+            try:
+                # Try to load existing default
+                existing_default = load_prompt(f"prompts:/{prompt_name}@default")
+
+                # Check if champion exists and if it matches default
+                champion_matches_default = False
+                try:
+                    existing_champion = load_prompt(f"prompts:/{prompt_name}@champion")
+                    champion_matches_default = (
+                        existing_champion.version == existing_default.version
+                    )
+                    logger.debug(
+                        f"Champion v{existing_champion.version} vs Default v{existing_default.version}: "
+                        f"{'tracking' if champion_matches_default else 'pinned separately'}"
+                    )
+                except Exception:
+                    # No champion exists
+                    logger.debug(f"No champion alias found for '{prompt_name}'")
+
+                # Check if default_template differs from existing default
+                if (
+                    existing_default.template.strip()
+                    != prompt_model.default_template.strip()
+                ):
+                    logger.info(
+                        f"Default template for '{prompt_name}' has changed, "
+                        "registering new version with default alias"
+                    )
+
+                    # Only update champion if it was pointing to the old default
+                    if champion_matches_default:
+                        logger.info(
+                            f"Champion was tracking default (v{existing_default.version}), "
+                            "will update champion to new default version"
+                        )
+                        set_champion = True
+                    else:
+                        logger.info("Champion is pinned separately, preserving it")
+                        set_champion = False
+
+                    self._register_default_template(
+                        prompt_name,
+                        prompt_model.default_template,
+                        prompt_model.description,
+                        set_champion=set_champion,
+                    )
+            except Exception as e:
+                # No default exists yet, register it
+                logger.debug(f"No default alias found for '{prompt_name}': {e}")
+                logger.info(
+                    f"Registering default_template for '{prompt_name}' as default alias"
+                )
+                # First registration - set both default and champion
+                self._register_default_template(
+                    prompt_name,
+                    prompt_model.default_template,
+                    prompt_model.description,
+                    set_champion=True,
+                )
+
+        # 1. Try champion alias (highest priority for execution)
         try:
             prompt_version = load_prompt(f"prompts:/{prompt_name}@champion")
-            logger.info(
+            logger.info(
+                f"Loaded prompt '{prompt_name}' from champion alias (default was synced separately)"
+            )
             return prompt_version
         except Exception as e:
             logger.debug(f"Champion alias not found for '{prompt_name}': {e}")

-        # 2. Try
+        # 2. Try default alias (already synced above)
+        if prompt_model.default_template:
+            try:
+                prompt_version = load_prompt(f"prompts:/{prompt_name}@default")
+                logger.info(f"Loaded prompt '{prompt_name}' from default alias")
+                return prompt_version
+            except Exception as e:
+                # Should not happen since we just registered it above, but handle anyway
+                logger.debug(f"Default alias not found for '{prompt_name}': {e}")
+
+        # 3. Try latest alias as final fallback
         try:
             prompt_version = load_prompt(f"prompts:/{prompt_name}@latest")
             logger.info(f"Loaded prompt '{prompt_name}' from latest alias")
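Taken together, the new `get_prompt` flow first syncs the `default` alias when `default_template` has changed, then resolves in the order champion → default → latest → raw `default_template`. A minimal sketch of that resolution order, assuming a hypothetical `resolve_prompt` helper and omitting the sync and logging shown above; the alias URIs, `load_prompt`, and the `PromptVersion` fallback mirror the hunk:

```python
# Hedged illustration of the 0.1.0 fallback order; resolve_prompt() is not part
# of dao_ai's public API.
from mlflow.entities.model_registry import PromptVersion
from mlflow.genai.prompts import load_prompt


def resolve_prompt(prompt_name: str, default_template: str | None = None) -> PromptVersion:
    # 1-3. Try the registry aliases in priority order.
    for alias in ("champion", "default", "latest"):
        try:
            return load_prompt(f"prompts:/{prompt_name}@{alias}")
        except Exception:
            continue  # alias not set (yet), fall through to the next one
    # 4. Last resort (e.g. in tests): wrap the raw template in a PromptVersion.
    if default_template:
        return PromptVersion(name=prompt_name, version=1, template=default_template)
    raise ValueError(
        f"Prompt '{prompt_name}' not found in registry and no default_template provided"
    )
```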
@@ -1211,43 +1284,49 @@ class DatabricksProvider(ServiceProvider):
         except Exception as e:
             logger.debug(f"Latest alias not found for '{prompt_name}': {e}")

-        #
-        try:
-            prompt_version = load_prompt(f"prompts:/{prompt_name}@default")
-            logger.info(f"Loaded prompt '{prompt_name}' from default alias")
-            return prompt_version
-        except Exception as e:
-            logger.debug(f"Default alias not found for '{prompt_name}': {e}")
-
-        # 4. Try to register default_template if provided
+        # 4. Final fallback: use default_template directly if available
         if prompt_model.default_template:
-            logger.
-                f"
-                "
+            logger.warning(
+                f"Could not load prompt '{prompt_name}' from registry. "
+                "Using default_template directly (likely in test environment)"
             )
-            return
-                prompt_name,
+            return PromptVersion(
+                name=prompt_name,
+                version=1,
+                template=prompt_model.default_template,
+                tags={"dao_ai": dao_ai_version()},
             )

         raise ValueError(
             f"Prompt '{prompt_name}' not found in registry "
-            "(tried champion
+            "(tried champion, default, latest aliases) "
             "and no default_template provided"
         )

     def _register_default_template(
-        self,
+        self,
+        prompt_name: str,
+        default_template: str,
+        description: str | None = None,
+        set_champion: bool = True,
     ) -> PromptVersion:
         """Register default_template as a new prompt version.

-
-
+        Registers the template and sets the 'default' alias.
+        Optionally sets 'champion' alias if no champion exists.
+
+        Args:
+            prompt_name: Full name of the prompt
+            default_template: The template content
+            description: Optional description for commit message
+            set_champion: Whether to also set champion alias (default: True)

         If registration fails (e.g., in Model Serving with restricted permissions),
         logs the error and raises.
         """
         logger.info(
-            f"
+            f"Registering default_template for '{prompt_name}' "
+            f"(set_champion={set_champion})"
         )

         try:
@@ -1259,23 +1338,35 @@ class DatabricksProvider(ServiceProvider):
                 tags={"dao_ai": dao_ai_version()},
             )

-            #
+            # Always set default alias
             try:
                 mlflow.genai.set_prompt_alias(
                     name=prompt_name, alias="default", version=prompt_version.version
                 )
-                mlflow.genai.set_prompt_alias(
-                    name=prompt_name, alias="champion", version=prompt_version.version
-                )
                 logger.info(
-                    f"
+                    f"Set default alias for '{prompt_name}' v{prompt_version.version}"
                 )
             except Exception as alias_error:
                 logger.warning(
-                    f"
-                    f"but failed to set aliases: {alias_error}"
+                    f"Could not set default alias for '{prompt_name}': {alias_error}"
                 )

+            # Optionally set champion alias (only if no champion exists or explicitly requested)
+            if set_champion:
+                try:
+                    mlflow.genai.set_prompt_alias(
+                        name=prompt_name,
+                        alias="champion",
+                        version=prompt_version.version,
+                    )
+                    logger.info(
+                        f"Set champion alias for '{prompt_name}' v{prompt_version.version}"
+                    )
+                except Exception as alias_error:
+                    logger.warning(
+                        f"Could not set champion alias for '{prompt_name}': {alias_error}"
+                    )
+
             return prompt_version

         except Exception as reg_error:
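Condensed, `_register_default_template` now always points the `default` alias at the freshly registered version and only touches `champion` when `set_champion` is true (first registration, or when champion was previously tracking default). A hedged sketch with the error handling from the hunk omitted; `register_template` is a placeholder name, while `register_prompt` and `set_prompt_alias` are the same `mlflow.genai` calls used above:

```python
import mlflow
from mlflow.entities.model_registry import PromptVersion


def register_template(prompt_name: str, template: str, set_champion: bool = True) -> PromptVersion:
    # Register the template as a new version of the prompt.
    prompt_version = mlflow.genai.register_prompt(name=prompt_name, template=template)
    # 'default' always tracks the version that came from default_template.
    mlflow.genai.set_prompt_alias(
        name=prompt_name, alias="default", version=prompt_version.version
    )
    # 'champion' is only moved on first registration or when it was tracking default.
    if set_champion:
        mlflow.genai.set_prompt_alias(
            name=prompt_name, alias="champion", version=prompt_version.version
        )
    return prompt_version
```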
@@ -1289,349 +1380,3 @@ class DatabricksProvider(ServiceProvider):
                 template=default_template,
                 tags={"dao_ai": dao_ai_version()},
             )
-
-    def optimize_prompt(self, optimization: PromptOptimizationModel) -> PromptModel:
-        """
-        Optimize a prompt using MLflow's prompt optimization (MLflow 3.5+).
-
-        This uses the MLflow GenAI optimize_prompts API with GepaPromptOptimizer as documented at:
-        https://mlflow.org/docs/latest/genai/prompt-registry/optimize-prompts/
-
-        Args:
-            optimization: PromptOptimizationModel containing configuration
-
-        Returns:
-            PromptModel: The optimized prompt with new URI
-        """
-        from mlflow.genai.optimize import GepaPromptOptimizer, optimize_prompts
-        from mlflow.genai.scorers import Correctness
-
-        from dao_ai.config import AgentModel, PromptModel
-
-        logger.info(f"Optimizing prompt: {optimization.name}")
-
-        # Get agent and prompt (prompt is guaranteed to be set by validator)
-        agent_model: AgentModel = optimization.agent
-        prompt: PromptModel = optimization.prompt  # type: ignore[assignment]
-        agent_model.prompt = prompt.uri
-
-        print(f"prompt={agent_model.prompt}")
-        # Log the prompt URI scheme being used
-        # Supports three schemes:
-        # 1. Specific version: "prompts:/qa/1" (when version is specified)
-        # 2. Alias: "prompts:/qa@champion" (when alias is specified)
-        # 3. Latest: "prompts:/qa@latest" (default when neither version nor alias specified)
-        prompt_uri: str = prompt.uri
-        logger.info(f"Using prompt URI for optimization: {prompt_uri}")
-
-        # Load the specific prompt version by URI for comparison
-        # Try to load the exact version specified, but if it doesn't exist,
-        # use get_prompt to create it from default_template
-        prompt_version: PromptVersion
-        try:
-            prompt_version = load_prompt(prompt_uri)
-            logger.info(f"Successfully loaded prompt from registry: {prompt_uri}")
-        except Exception as e:
-            logger.warning(
-                f"Could not load prompt '{prompt_uri}' directly: {e}. "
-                "Attempting to create from default_template..."
-            )
-            # Use get_prompt which will create from default_template if needed
-            prompt_version = self.get_prompt(prompt)
-            logger.info(
-                f"Created/loaded prompt '{prompt.full_name}' (will optimize against this version)"
-            )
-
-        # Load the evaluation dataset by name
-        logger.debug(f"Looking up dataset: {optimization.dataset}")
-        dataset: EvaluationDataset
-        if isinstance(optimization.dataset, str):
-            dataset = get_dataset(name=optimization.dataset)
-        else:
-            dataset = optimization.dataset.as_dataset()
-
-        # Set up reflection model for the optimizer
-        reflection_model_name: str
-        if optimization.reflection_model:
-            if isinstance(optimization.reflection_model, str):
-                reflection_model_name = optimization.reflection_model
-            else:
-                reflection_model_name = optimization.reflection_model.uri
-        else:
-            reflection_model_name = agent_model.model.uri
-        logger.debug(f"Using reflection model: {reflection_model_name}")
-
-        # Create the GepaPromptOptimizer
-        optimizer: GepaPromptOptimizer = GepaPromptOptimizer(
-            reflection_model=reflection_model_name,
-            max_metric_calls=optimization.num_candidates,
-            display_progress_bar=True,
-        )
-
-        # Set up scorer (judge model for evaluation)
-        scorer_model: str
-        if optimization.scorer_model:
-            if isinstance(optimization.scorer_model, str):
-                scorer_model = optimization.scorer_model
-            else:
-                scorer_model = optimization.scorer_model.uri
-        else:
-            scorer_model = agent_model.model.uri  # Use Databricks default
-        logger.debug(f"Using scorer with model: {scorer_model}")
-
-        scorers: list[Correctness] = [Correctness(model=scorer_model)]
-
-        # Use prompt_uri from line 1188 - already set to prompt.uri (configured URI)
-        # DO NOT overwrite with prompt_version.uri as that uses fallback logic
-        logger.debug(f"Optimizing prompt: {prompt_uri}")
-
-        agent: ResponsesAgent = agent_model.as_responses_agent()
-
-        # Create predict function that will be optimized
-        def predict_fn(**inputs: dict[str, Any]) -> str:
-            """Prediction function that uses the ResponsesAgent with ChatPayload.
-
-            The agent already has the prompt referenced/applied, so we just need to
-            convert the ChatPayload inputs to ResponsesAgentRequest format and call predict.
-
-            Args:
-                **inputs: Dictionary containing ChatPayload fields (messages/input, custom_inputs)
-
-            Returns:
-                str: The agent's response content
-            """
-            from mlflow.types.responses import (
-                ResponsesAgentRequest,
-                ResponsesAgentResponse,
-            )
-            from mlflow.types.responses_helpers import Message
-
-            from dao_ai.config import ChatPayload
-
-            # Verify agent is accessible (should be captured from outer scope)
-            if agent is None:
-                raise RuntimeError(
-                    "Agent object is not available in predict_fn. "
-                    "This may indicate a serialization issue with the ResponsesAgent."
-                )
-
-            # Convert inputs to ChatPayload
-            chat_payload: ChatPayload = ChatPayload(**inputs)
-
-            # Convert ChatPayload messages to MLflow Message format
-            mlflow_messages: list[Message] = [
-                Message(role=msg.role, content=msg.content)
-                for msg in chat_payload.messages
-            ]
-
-            # Create ResponsesAgentRequest
-            request: ResponsesAgentRequest = ResponsesAgentRequest(
-                input=mlflow_messages,
-                custom_inputs=chat_payload.custom_inputs,
-            )
-
-            # Call the ResponsesAgent's predict method
-            response: ResponsesAgentResponse = agent.predict(request)
-
-            if response.output and len(response.output) > 0:
-                content = response.output[0].content
-                logger.debug(f"Response content type: {type(content)}")
-                logger.debug(f"Response content: {content}")
-
-                # Extract text from content using same logic as LanggraphResponsesAgent._extract_text_from_content
-                # Content can be:
-                # - A string (return as is)
-                # - A list of items with 'text' keys (extract and join)
-                # - Other types (try to get 'text' attribute or convert to string)
-                if isinstance(content, str):
-                    return content
-                elif isinstance(content, list):
-                    text_parts = []
-                    for content_item in content:
-                        if isinstance(content_item, str):
-                            text_parts.append(content_item)
-                        elif isinstance(content_item, dict) and "text" in content_item:
-                            text_parts.append(content_item["text"])
-                        elif hasattr(content_item, "text"):
-                            text_parts.append(content_item.text)
-                    return "".join(text_parts) if text_parts else str(content)
-                else:
-                    # Fallback for unknown types - try to extract text attribute
-                    return getattr(content, "text", str(content))
-            else:
-                return ""
-
-        # Set registry URI for Databricks Unity Catalog
-        mlflow.set_registry_uri("databricks-uc")
-
-        # Run optimization with tracking disabled to prevent auto-registering all candidates
-        logger.info("Running prompt optimization with GepaPromptOptimizer...")
-        logger.info(
-            f"Generating {optimization.num_candidates} candidate prompts for evaluation"
-        )
-
-        from mlflow.genai.optimize.types import (
-            PromptOptimizationResult,
-        )
-
-        result: PromptOptimizationResult = optimize_prompts(
-            predict_fn=predict_fn,
-            train_data=dataset,
-            prompt_uris=[prompt_uri],  # Use the configured URI (version/alias/latest)
-            optimizer=optimizer,
-            scorers=scorers,
-            enable_tracking=False,  # Don't auto-register all candidates
-        )
-
-        # Log the optimization results
-        logger.info("Optimization complete!")
-        logger.info(f"Optimizer used: {result.optimizer_name}")
-
-        if result.optimized_prompts:
-            optimized_prompt_version: PromptVersion = result.optimized_prompts[0]
-
-            # Check if the optimized prompt is actually different from the original
-            original_template: str = prompt_version.to_single_brace_format().strip()
-            optimized_template: str = (
-                optimized_prompt_version.to_single_brace_format().strip()
-            )
-
-            # Normalize whitespace for more robust comparison
-            original_normalized: str = re.sub(r"\s+", " ", original_template).strip()
-            optimized_normalized: str = re.sub(r"\s+", " ", optimized_template).strip()
-
-            logger.debug(f"Original template length: {len(original_template)} chars")
-            logger.debug(f"Optimized template length: {len(optimized_template)} chars")
-            logger.debug(
-                f"Templates identical: {original_normalized == optimized_normalized}"
-            )
-
-            if original_normalized == optimized_normalized:
-                logger.info(
-                    f"Optimized prompt is identical to original for '{prompt.full_name}'. "
-                    "No new version will be registered."
-                )
-                return prompt
-
-            logger.info("Optimized prompt is DIFFERENT from original")
-            logger.info(
-                f"Original length: {len(original_template)}, Optimized length: {len(optimized_template)}"
-            )
-            logger.debug(
-                f"Original template (first 300 chars): {original_template[:300]}..."
-            )
-            logger.debug(
-                f"Optimized template (first 300 chars): {optimized_template[:300]}..."
-            )
-
-            # Check evaluation scores to determine if we should register the optimized prompt
-            should_register: bool = False
-            has_improvement: bool = False
-
-            if (
-                result.initial_eval_score is not None
-                and result.final_eval_score is not None
-            ):
-                logger.info("Evaluation scores:")
-                logger.info(f"  Initial score: {result.initial_eval_score}")
-                logger.info(f"  Final score: {result.final_eval_score}")
-
-                # Only register if there's improvement
-                if result.final_eval_score > result.initial_eval_score:
-                    improvement: float = (
-                        (result.final_eval_score - result.initial_eval_score)
-                        / result.initial_eval_score
-                    ) * 100
-                    logger.info(
-                        f"Optimized prompt improved by {improvement:.2f}% "
-                        f"(initial: {result.initial_eval_score}, final: {result.final_eval_score})"
-                    )
-                    should_register = True
-                    has_improvement = True
-                else:
-                    logger.info(
-                        f"Optimized prompt (score: {result.final_eval_score}) did NOT improve over baseline (score: {result.initial_eval_score}). "
-                        "No new version will be registered."
-                    )
-            else:
-                # No scores available - register anyway but warn
-                logger.warning(
-                    "No evaluation scores available to compare performance. "
-                    "Registering optimized prompt without performance validation."
-                )
-                should_register = True
-
-            if not should_register:
-                logger.info(
-                    f"Skipping registration for '{prompt.full_name}' (no improvement)"
-                )
-                return prompt
-
-            # Register the optimized prompt manually
-            try:
-                logger.info(f"Registering optimized prompt '{prompt.full_name}'")
-                registered_version: PromptVersion = mlflow.genai.register_prompt(
-                    name=prompt.full_name,
-                    template=optimized_template,
-                    commit_message=f"Optimized for {agent_model.model.uri} using GepaPromptOptimizer",
-                    tags={
-                        "dao_ai": dao_ai_version(),
-                        "target_model": agent_model.model.uri,
-                    },
-                )
-                logger.info(
-                    f"Registered optimized prompt as version {registered_version.version}"
-                )
-
-                # Always set "latest" alias (represents most recently registered prompt)
-                logger.info(
-                    f"Setting 'latest' alias for optimized prompt '{prompt.full_name}' version {registered_version.version}"
-                )
-                mlflow.genai.set_prompt_alias(
-                    name=prompt.full_name,
-                    alias="latest",
-                    version=registered_version.version,
-                )
-                logger.info(
-                    f"Successfully set 'latest' alias for '{prompt.full_name}' v{registered_version.version}"
-                )
-
-                # If there's confirmed improvement, also set the "champion" alias
-                # (represents the prompt that should be used by deployed agents)
-                if has_improvement:
-                    logger.info(
-                        f"Setting 'champion' alias for improved prompt '{prompt.full_name}' version {registered_version.version}"
-                    )
-                    mlflow.genai.set_prompt_alias(
-                        name=prompt.full_name,
-                        alias="champion",
-                        version=registered_version.version,
-                    )
-                    logger.info(
-                        f"Successfully set 'champion' alias for '{prompt.full_name}' v{registered_version.version}"
-                    )
-
-                # Add target_model and dao_ai tags
-                tags: dict[str, Any] = prompt.tags.copy() if prompt.tags else {}
-                tags["target_model"] = agent_model.model.uri
-                tags["dao_ai"] = dao_ai_version()
-
-                # Return the optimized prompt with the appropriate alias
-                # Use "champion" if there was improvement, otherwise "latest"
-                result_alias: str = "champion" if has_improvement else "latest"
-                return PromptModel(
-                    name=prompt.name,
-                    schema=prompt.schema_model,
-                    description=f"Optimized version of {prompt.name} for {agent_model.model.uri}",
-                    alias=result_alias,
-                    tags=tags,
-                )
-
-            except Exception as e:
-                logger.error(
-                    f"Failed to register optimized prompt '{prompt.full_name}': {e}"
-                )
-                return prompt
-        else:
-            logger.warning("No optimized prompts returned from optimization")
-            return prompt