dao-ai 0.0.25__py3-none-any.whl → 0.0.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  import base64
2
+ import re
2
3
  import uuid
3
- from importlib.metadata import version
4
4
  from pathlib import Path
5
5
  from typing import Any, Callable, Final, Sequence
6
6
 
@@ -32,11 +32,14 @@ from mlflow import MlflowClient
32
32
  from mlflow.entities import Experiment
33
33
  from mlflow.entities.model_registry import PromptVersion
34
34
  from mlflow.entities.model_registry.model_version import ModelVersion
35
+ from mlflow.genai.datasets import EvaluationDataset, get_dataset
36
+ from mlflow.genai.prompts import load_prompt
35
37
  from mlflow.models.auth_policy import AuthPolicy, SystemAuthPolicy, UserAuthPolicy
36
38
  from mlflow.models.model import ModelInfo
37
39
  from mlflow.models.resources import (
38
40
  DatabricksResource,
39
41
  )
42
+ from mlflow.pyfunc import ResponsesAgent
40
43
  from pyspark.sql import SparkSession
41
44
  from unitycatalog.ai.core.base import FunctionExecutionResult
42
45
  from unitycatalog.ai.core.databricks import DatabricksFunctionClient
@@ -54,6 +57,7 @@ from dao_ai.config import (
54
57
  IsDatabricksResource,
55
58
  LLMModel,
56
59
  PromptModel,
60
+ PromptOptimizationModel,
57
61
  SchemaModel,
58
62
  TableModel,
59
63
  UnityCatalogFunctionSqlModel,
@@ -65,6 +69,7 @@ from dao_ai.config import (
65
69
  from dao_ai.models import get_latest_model_version
66
70
  from dao_ai.providers.base import ServiceProvider
67
71
  from dao_ai.utils import (
72
+ dao_ai_version,
68
73
  get_installed_packages,
69
74
  is_installed,
70
75
  is_lib_provided,
@@ -296,7 +301,7 @@ class DatabricksProvider(ServiceProvider):
296
301
  if is_installed():
297
302
  if not is_lib_provided("dao-ai", pip_requirements):
298
303
  pip_requirements += [
299
- f"dao-ai=={version('dao-ai')}",
304
+ f"dao-ai=={dao_ai_version()}",
300
305
  ]
301
306
  else:
302
307
  src_path: Path = model_root_path.parent
@@ -322,10 +327,11 @@ class DatabricksProvider(ServiceProvider):
322
327
 
323
328
  with mlflow.start_run(run_name=run_name):
324
329
  mlflow.set_tag("type", "agent")
330
+ mlflow.set_tag("dao_ai", dao_ai_version())
325
331
  logged_agent_info: ModelInfo = mlflow.pyfunc.log_model(
326
332
  python_model=model_path.as_posix(),
327
333
  code_paths=code_paths,
328
- model_config=config.model_dump(by_alias=True),
334
+ model_config=config.model_dump(mode="json", by_alias=True),
329
335
  name="agent",
330
336
  pip_requirements=pip_requirements,
331
337
  input_example=input_example,
@@ -344,9 +350,18 @@ class DatabricksProvider(ServiceProvider):
344
350
 
345
351
  client: MlflowClient = MlflowClient()
346
352
 
353
+ # Set tags on the model version
354
+ client.set_model_version_tag(
355
+ name=registered_model_name,
356
+ version=model_version.version,
357
+ key="dao_ai",
358
+ value=dao_ai_version(),
359
+ )
360
+ logger.debug(f"Set dao_ai tag on model version {model_version.version}")
361
+
347
362
  client.set_registered_model_alias(
348
363
  name=registered_model_name,
349
- alias="Current",
364
+ alias="Champion",
350
365
  version=model_version.version,
351
366
  )
352
367
 
@@ -372,7 +387,10 @@ class DatabricksProvider(ServiceProvider):
372
387
  scale_to_zero: bool = config.app.scale_to_zero
373
388
  environment_vars: dict[str, str] = config.app.environment_vars
374
389
  workload_size: str = config.app.workload_size
375
- tags: dict[str, str] = config.app.tags
390
+ tags: dict[str, str] = config.app.tags.copy() if config.app.tags else {}
391
+
392
+ # Add dao_ai framework tag
393
+ tags["dao_ai"] = dao_ai_version()
376
394
 
377
395
  latest_version: int = get_latest_model_version(registered_model_name)
378
396
 
@@ -1026,42 +1044,124 @@ class DatabricksProvider(ServiceProvider):
1026
1044
  )
1027
1045
  raise
1028
1046
 
1029
- def get_prompt(self, prompt_model: PromptModel) -> str:
1030
- """Load prompt from MLflow Prompt Registry or fall back to default_template."""
1031
- prompt_name: str = prompt_model.full_name
1047
+ def get_prompt(self, prompt_model: PromptModel) -> PromptVersion:
1048
+ """
1049
+ Load prompt from MLflow Prompt Registry with fallback logic.
1032
1050
 
1033
- # Build prompt URI based on alias, version, or default to latest
1034
- if prompt_model.alias:
1035
- prompt_uri = f"prompts:/{prompt_name}@{prompt_model.alias}"
1036
- elif prompt_model.version:
1037
- prompt_uri = f"prompts:/{prompt_name}/{prompt_model.version}"
1038
- else:
1039
- prompt_uri = f"prompts:/{prompt_name}@latest"
1051
+ If an explicit version or alias is specified in the prompt_model, uses that directly.
1052
+ Otherwise, tries to load prompts in this order:
1053
+ 1. champion alias (if it exists)
1054
+ 2. latest alias (if it exists)
1055
+ 3. default_template (if provided)
1040
1056
 
1041
- try:
1042
- from mlflow.genai.prompts import Prompt
1057
+ Args:
1058
+ prompt_model: The prompt model configuration
1043
1059
 
1044
- prompt_obj: Prompt = mlflow.genai.load_prompt(prompt_uri)
1045
- return prompt_obj.to_single_brace_format()
1060
+ Returns:
1061
+ PromptVersion: The loaded prompt version
1046
1062
 
1047
- except Exception as e:
1048
- logger.warning(f"Failed to load prompt '{prompt_name}' from registry: {e}")
1063
+ Raises:
1064
+ ValueError: If no prompt can be loaded from any source
1065
+ """
1066
+ prompt_name: str = prompt_model.full_name
1049
1067
 
1050
- if prompt_model.default_template:
1051
- logger.info(f"Using default_template for '{prompt_name}'")
1052
- self._sync_default_template_to_registry(
1053
- prompt_name, prompt_model.default_template, prompt_model.description
1068
+ # If explicit version or alias is specified, use it directly without fallback
1069
+ if prompt_model.version or prompt_model.alias:
1070
+ try:
1071
+ prompt_version: PromptVersion = prompt_model.as_prompt()
1072
+ logger.debug(
1073
+ f"Loaded prompt '{prompt_name}' with explicit "
1074
+ f"{'version ' + str(prompt_model.version) if prompt_model.version else 'alias ' + prompt_model.alias}"
1075
+ )
1076
+ return prompt_version
1077
+ except Exception as e:
1078
+ logger.warning(
1079
+ f"Failed to load prompt '{prompt_name}' with explicit "
1080
+ f"{'version ' + str(prompt_model.version) if prompt_model.version else 'alias ' + prompt_model.alias}: {e}"
1054
1081
  )
1055
- return prompt_model.default_template
1082
+ # Fall through to default_template if available
1083
+ else:
1084
+ # No explicit version/alias specified - check if default_template needs syncing first
1085
+ logger.debug(
1086
+ f"No explicit version/alias specified for '{prompt_name}', "
1087
+ "checking if default_template needs syncing"
1088
+ )
1056
1089
 
1057
- raise ValueError(
1058
- f"Prompt '{prompt_name}' not found in registry and no default_template provided"
1059
- ) from e
1090
+ # If we have a default_template, check if it differs from what's in the registry
1091
+ # This ensures we always sync config changes before returning any alias
1092
+ if prompt_model.default_template:
1093
+ try:
1094
+ default_uri: str = f"prompts:/{prompt_name}@default"
1095
+ default_version: PromptVersion = load_prompt(default_uri)
1096
+
1097
+ if (
1098
+ default_version.to_single_brace_format().strip()
1099
+ != prompt_model.default_template.strip()
1100
+ ):
1101
+ logger.info(
1102
+ f"Config default_template for '{prompt_name}' differs from registry, syncing..."
1103
+ )
1104
+ return self._sync_default_template_to_registry(
1105
+ prompt_name,
1106
+ prompt_model.default_template,
1107
+ prompt_model.description,
1108
+ )
1109
+ except Exception as e:
1110
+ logger.debug(f"Could not check default alias for sync: {e}")
1111
+
1112
+ # Now try aliases in order: champion → latest → default
1113
+ logger.debug(
1114
+ f"Trying fallback order for '{prompt_name}': champion → latest → default"
1115
+ )
1116
+
1117
+ # Try champion alias first
1118
+ try:
1119
+ champion_uri: str = f"prompts:/{prompt_name}@champion"
1120
+ prompt_version: PromptVersion = load_prompt(champion_uri)
1121
+ logger.info(f"Loaded prompt '{prompt_name}' from champion alias")
1122
+ return prompt_version
1123
+ except Exception as e:
1124
+ logger.debug(f"Champion alias not found for '{prompt_name}': {e}")
1125
+
1126
+ # Try latest alias next
1127
+ try:
1128
+ latest_uri: str = f"prompts:/{prompt_name}@latest"
1129
+ prompt_version: PromptVersion = load_prompt(latest_uri)
1130
+ logger.info(f"Loaded prompt '{prompt_name}' from latest alias")
1131
+ return prompt_version
1132
+ except Exception as e:
1133
+ logger.debug(f"Latest alias not found for '{prompt_name}': {e}")
1134
+
1135
+ # Try default alias last
1136
+ try:
1137
+ default_uri: str = f"prompts:/{prompt_name}@default"
1138
+ prompt_version: PromptVersion = load_prompt(default_uri)
1139
+ logger.info(f"Loaded prompt '{prompt_name}' from default alias")
1140
+ return prompt_version
1141
+ except Exception as e:
1142
+ logger.debug(f"Default alias not found for '{prompt_name}': {e}")
1143
+
1144
+ # Fall back to registering default_template if provided
1145
+ if prompt_model.default_template:
1146
+ logger.info(
1147
+ f"Registering default_template for '{prompt_name}' "
1148
+ "(no aliases found in registry)"
1149
+ )
1150
+ return self._sync_default_template_to_registry(
1151
+ prompt_name, prompt_model.default_template, prompt_model.description
1152
+ )
1153
+
1154
+ raise ValueError(
1155
+ f"Prompt '{prompt_name}' not found in registry "
1156
+ "(tried champion, latest, default aliases) and no default_template provided"
1157
+ )
1060
1158
 
1061
1159
  def _sync_default_template_to_registry(
1062
1160
  self, prompt_name: str, default_template: str, description: str | None = None
1063
- ) -> None:
1161
+ ) -> PromptVersion:
1064
1162
  """Register default_template to prompt registry under 'default' alias if changed."""
1163
+ prompt_version: PromptVersion | None = None
1164
+
1065
1165
  try:
1066
1166
  # Check if default alias already has the same template
1067
1167
  try:
@@ -1074,7 +1174,45 @@ class DatabricksProvider(ServiceProvider):
1074
1174
  == default_template.strip()
1075
1175
  ):
1076
1176
  logger.debug(f"Prompt '{prompt_name}' is already up-to-date")
1077
- return # Already up-to-date
1177
+
1178
+ # Ensure the "latest" and "champion" aliases also exist and point to the same version
1179
+ # This handles prompts created before the fix that added these aliases
1180
+ try:
1181
+ latest_version: PromptVersion = mlflow.genai.load_prompt(
1182
+ f"prompts:/{prompt_name}@latest"
1183
+ )
1184
+ logger.debug(
1185
+ f"Latest alias already exists for '{prompt_name}' pointing to version {latest_version.version}"
1186
+ )
1187
+ except Exception:
1188
+ logger.info(
1189
+ f"Setting 'latest' alias for existing prompt '{prompt_name}' v{existing.version}"
1190
+ )
1191
+ mlflow.genai.set_prompt_alias(
1192
+ name=prompt_name,
1193
+ alias="latest",
1194
+ version=existing.version,
1195
+ )
1196
+
1197
+ # Ensure champion alias exists for first-time deployments
1198
+ try:
1199
+ champion_version: PromptVersion = mlflow.genai.load_prompt(
1200
+ f"prompts:/{prompt_name}@champion"
1201
+ )
1202
+ logger.debug(
1203
+ f"Champion alias already exists for '{prompt_name}' pointing to version {champion_version.version}"
1204
+ )
1205
+ except Exception:
1206
+ logger.info(
1207
+ f"Setting 'champion' alias for existing prompt '{prompt_name}' v{existing.version}"
1208
+ )
1209
+ mlflow.genai.set_prompt_alias(
1210
+ name=prompt_name,
1211
+ alias="champion",
1212
+ version=existing.version,
1213
+ )
1214
+
1215
+ return existing # Already up-to-date, return existing version
1078
1216
  except Exception:
1079
1217
  logger.debug(
1080
1218
  f"Default alias for prompt '{prompt_name}' doesn't exist yet"
@@ -1082,22 +1220,385 @@ class DatabricksProvider(ServiceProvider):
1082
1220
 
1083
1221
  # Register new version and set as default alias
1084
1222
  commit_message = description or "Auto-synced from default_template"
1085
- prompt_version: PromptVersion = mlflow.genai.register_prompt(
1223
+ prompt_version = mlflow.genai.register_prompt(
1086
1224
  name=prompt_name,
1087
1225
  template=default_template,
1088
1226
  commit_message=commit_message,
1227
+ tags={"dao_ai": dao_ai_version()},
1089
1228
  )
1090
1229
 
1091
- logger.debug(f"Setting default alias for prompt '{prompt_name}'")
1230
+ logger.debug(
1231
+ f"Setting default, latest, and champion aliases for prompt '{prompt_name}'"
1232
+ )
1092
1233
  mlflow.genai.set_prompt_alias(
1093
1234
  name=prompt_name,
1094
1235
  alias="default",
1095
1236
  version=prompt_version.version,
1096
1237
  )
1238
+ mlflow.genai.set_prompt_alias(
1239
+ name=prompt_name,
1240
+ alias="latest",
1241
+ version=prompt_version.version,
1242
+ )
1243
+ mlflow.genai.set_prompt_alias(
1244
+ name=prompt_name,
1245
+ alias="champion",
1246
+ version=prompt_version.version,
1247
+ )
1097
1248
 
1098
1249
  logger.info(
1099
- f"Synced prompt '{prompt_name}' v{prompt_version.version} to registry"
1250
+ f"Synced prompt '{prompt_name}' v{prompt_version.version} to registry with 'default', 'latest', and 'champion' aliases"
1100
1251
  )
1252
+ return prompt_version
1101
1253
 
1102
1254
  except Exception as e:
1103
- logger.warning(f"Failed to sync '{prompt_name}' to registry: {e}")
1255
+ logger.error(f"Failed to sync '{prompt_name}' to registry: {e}")
1256
+ raise ValueError(
1257
+ f"Failed to sync prompt '{prompt_name}' to registry and unable to retrieve existing version"
1258
+ ) from e
1259
+
1260
def optimize_prompt(self, optimization: PromptOptimizationModel) -> PromptModel:
    """
    Optimize a prompt using MLflow's prompt optimization (MLflow 3.5+).

    This uses the MLflow GenAI optimize_prompts API with GepaPromptOptimizer as documented at:
    https://mlflow.org/docs/latest/genai/prompt-registry/optimize-prompts/

    The optimized template is registered as a new prompt version only when it
    differs from the original (after whitespace normalization) AND either the
    final evaluation score beats the initial score, or no scores were reported
    at all. A newly registered version always receives the "latest" alias; it
    additionally receives the "champion" alias only when an improvement was
    confirmed by the scores.

    Args:
        optimization: PromptOptimizationModel containing configuration
            (agent, prompt, dataset, optional reflection/scorer models,
            num_candidates).

    Returns:
        PromptModel: A new PromptModel pointing at the "champion" or "latest"
        alias of the newly registered version. The original ``optimization.prompt``
        is returned unchanged when the optimizer produced no prompt, the
        optimized template is identical to the original, the scores show no
        improvement, or registration fails.

    Side effects:
        Sets ``optimization.agent.prompt`` to the prompt URI and sets the MLflow
        registry URI to "databricks-uc".
    """
    from mlflow.genai.optimize import GepaPromptOptimizer, optimize_prompts
    from mlflow.genai.optimize.types import PromptOptimizationResult
    from mlflow.genai.scorers import Correctness

    from dao_ai.config import AgentModel, PromptModel

    logger.info(f"Optimizing prompt: {optimization.name}")

    # Get agent and prompt (prompt is guaranteed to be set by validator)
    agent_model: AgentModel = optimization.agent
    prompt: PromptModel = optimization.prompt  # type: ignore[assignment]
    agent_model.prompt = prompt.uri
    logger.debug(f"prompt={agent_model.prompt}")

    # The prompt URI supports three schemes:
    #   1. Specific version: "prompts:/qa/1" (when version is specified)
    #   2. Alias: "prompts:/qa@champion" (when alias is specified)
    #   3. Latest: "prompts:/qa@latest" (default when neither version nor alias specified)
    prompt_uri: str = prompt.uri
    logger.info(f"Using prompt URI for optimization: {prompt_uri}")

    # Load the exact prompt version named by the URI so the optimized template
    # can be compared against it later. If it does not exist yet, get_prompt
    # will create it from default_template.
    prompt_version: PromptVersion
    try:
        prompt_version = load_prompt(prompt_uri)
        logger.info(f"Successfully loaded prompt from registry: {prompt_uri}")
    except Exception as e:
        logger.warning(
            f"Could not load prompt '{prompt_uri}' directly: {e}. "
            "Attempting to create from default_template..."
        )
        prompt_version = self.get_prompt(prompt)
        logger.info(
            f"Created/loaded prompt '{prompt.full_name}' (will optimize against this version)"
        )

    # Load the evaluation dataset (either by name or from its model)
    logger.debug(f"Looking up dataset: {optimization.dataset}")
    dataset: EvaluationDataset
    if isinstance(optimization.dataset, str):
        dataset = get_dataset(name=optimization.dataset)
    else:
        dataset = optimization.dataset.as_dataset()

    def resolve_model_uri(candidate: Any) -> str:
        """Return the URI for a str-or-model setting, defaulting to the agent's model."""
        if not candidate:
            return agent_model.model.uri
        if isinstance(candidate, str):
            return candidate
        return candidate.uri

    # Reflection model used by the optimizer (falls back to the agent's model)
    reflection_model_name: str = resolve_model_uri(optimization.reflection_model)
    logger.debug(f"Using reflection model: {reflection_model_name}")

    # NOTE(review): num_candidates is passed as the optimizer's max_metric_calls
    # budget — confirm this is the intended mapping.
    optimizer: GepaPromptOptimizer = GepaPromptOptimizer(
        reflection_model=reflection_model_name,
        max_metric_calls=optimization.num_candidates,
        display_progress_bar=True,
    )

    # Judge model used by the scorer (falls back to the agent's model)
    scorer_model: str = resolve_model_uri(optimization.scorer_model)
    logger.debug(f"Using scorer with model: {scorer_model}")

    scorers: list[Correctness] = [Correctness(model=scorer_model)]

    # Optimize against the configured URI (prompt.uri). Do NOT switch to
    # prompt_version.uri: that version may have come from get_prompt's
    # fallback logic rather than from the configured URI.
    logger.debug(f"Optimizing prompt: {prompt_uri}")

    agent: ResponsesAgent = agent_model.as_responses_agent()

    def predict_fn(**inputs: Any) -> str:
        """Prediction function that uses the ResponsesAgent with ChatPayload.

        The agent already has the prompt referenced/applied, so we just need to
        convert the ChatPayload inputs to ResponsesAgentRequest format and call predict.

        Args:
            **inputs: ChatPayload fields (messages/input, custom_inputs)

        Returns:
            str: The agent's response content ("" when the response has no output)
        """
        from mlflow.types.responses import (
            ResponsesAgentRequest,
            ResponsesAgentResponse,
        )
        from mlflow.types.responses_helpers import Message

        from dao_ai.config import ChatPayload

        # Verify agent is accessible (should be captured from outer scope)
        if agent is None:
            raise RuntimeError(
                "Agent object is not available in predict_fn. "
                "This may indicate a serialization issue with the ResponsesAgent."
            )

        chat_payload: ChatPayload = ChatPayload(**inputs)

        # Convert ChatPayload messages to MLflow Message format
        mlflow_messages: list[Message] = [
            Message(role=msg.role, content=msg.content)
            for msg in chat_payload.messages
        ]

        request: ResponsesAgentRequest = ResponsesAgentRequest(
            input=mlflow_messages,
            custom_inputs=chat_payload.custom_inputs,
        )

        response: ResponsesAgentResponse = agent.predict(request)

        if not response.output:
            return ""

        content = response.output[0].content
        logger.debug(f"Response content type: {type(content)}")
        logger.debug(f"Response content: {content}")

        # Content can be:
        # - A string (return as is)
        # - A list of items carrying text (extract and join)
        # - Other types (try the 'text' attribute, else convert to string)
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            text_parts = []
            for content_item in content:
                if isinstance(content_item, str):
                    text_parts.append(content_item)
                elif isinstance(content_item, dict) and "text" in content_item:
                    text_parts.append(content_item["text"])
                elif hasattr(content_item, "text"):
                    text_parts.append(content_item.text)
            return "".join(text_parts) if text_parts else str(content)
        return getattr(content, "text", str(content))

    # Set registry URI for Databricks Unity Catalog
    mlflow.set_registry_uri("databricks-uc")

    # Run optimization with tracking disabled to prevent auto-registering all candidates
    logger.info("Running prompt optimization with GepaPromptOptimizer...")
    logger.info(
        f"Generating {optimization.num_candidates} candidate prompts for evaluation"
    )

    result: PromptOptimizationResult = optimize_prompts(
        predict_fn=predict_fn,
        train_data=dataset,
        prompt_uris=[prompt_uri],  # Use the configured URI (version/alias/latest)
        optimizer=optimizer,
        scorers=scorers,
        enable_tracking=False,  # Don't auto-register all candidates
    )

    logger.info("Optimization complete!")
    logger.info(f"Optimizer used: {result.optimizer_name}")

    if not result.optimized_prompts:
        logger.warning("No optimized prompts returned from optimization")
        return prompt

    optimized_prompt_version: PromptVersion = result.optimized_prompts[0]

    # Check if the optimized prompt is actually different from the original
    original_template: str = prompt_version.to_single_brace_format().strip()
    optimized_template: str = (
        optimized_prompt_version.to_single_brace_format().strip()
    )

    # Collapse whitespace runs so formatting-only differences do not count
    original_normalized: str = re.sub(r"\s+", " ", original_template).strip()
    optimized_normalized: str = re.sub(r"\s+", " ", optimized_template).strip()
    templates_identical: bool = original_normalized == optimized_normalized

    logger.debug(f"Original template length: {len(original_template)} chars")
    logger.debug(f"Optimized template length: {len(optimized_template)} chars")
    logger.debug(f"Templates identical: {templates_identical}")

    if templates_identical:
        logger.info(
            f"Optimized prompt is identical to original for '{prompt.full_name}'. "
            "No new version will be registered."
        )
        return prompt

    logger.info("Optimized prompt is DIFFERENT from original")
    logger.info(
        f"Original length: {len(original_template)}, Optimized length: {len(optimized_template)}"
    )
    logger.debug(
        f"Original template (first 300 chars): {original_template[:300]}..."
    )
    logger.debug(
        f"Optimized template (first 300 chars): {optimized_template[:300]}..."
    )

    # Check evaluation scores to determine if we should register the optimized prompt
    should_register: bool = False
    has_improvement: bool = False
    initial_score = result.initial_eval_score
    final_score = result.final_eval_score

    if initial_score is not None and final_score is not None:
        logger.info("Evaluation scores:")
        logger.info(f"  Initial score: {initial_score}")
        logger.info(f"  Final score: {final_score}")

        # Only register if there's improvement
        if final_score > initial_score:
            if initial_score:
                improvement: float = (
                    (final_score - initial_score) / initial_score
                ) * 100
                logger.info(
                    f"Optimized prompt improved by {improvement:.2f}% "
                    f"(initial: {initial_score}, final: {final_score})"
                )
            else:
                # A zero baseline has no meaningful relative improvement;
                # dividing by it would raise ZeroDivisionError and discard
                # the whole optimization run.
                logger.info(
                    "Optimized prompt improved from a zero baseline "
                    f"(initial: {initial_score}, final: {final_score})"
                )
            should_register = True
            has_improvement = True
        else:
            logger.info(
                f"Optimized prompt (score: {final_score}) did NOT improve over baseline (score: {initial_score}). "
                "No new version will be registered."
            )
    else:
        # No scores available - register anyway but warn
        logger.warning(
            "No evaluation scores available to compare performance. "
            "Registering optimized prompt without performance validation."
        )
        should_register = True

    if not should_register:
        logger.info(
            f"Skipping registration for '{prompt.full_name}' (no improvement)"
        )
        return prompt

    # Register the optimized prompt manually (tracking was disabled above)
    try:
        logger.info(f"Registering optimized prompt '{prompt.full_name}'")
        registered_version: PromptVersion = mlflow.genai.register_prompt(
            name=prompt.full_name,
            template=optimized_template,
            commit_message=f"Optimized for {agent_model.model.uri} using GepaPromptOptimizer",
            tags={
                "dao_ai": dao_ai_version(),
                "target_model": agent_model.model.uri,
            },
        )
        logger.info(
            f"Registered optimized prompt as version {registered_version.version}"
        )

        # Always set "latest" alias (represents most recently registered prompt)
        logger.info(
            f"Setting 'latest' alias for optimized prompt '{prompt.full_name}' version {registered_version.version}"
        )
        mlflow.genai.set_prompt_alias(
            name=prompt.full_name,
            alias="latest",
            version=registered_version.version,
        )
        logger.info(
            f"Successfully set 'latest' alias for '{prompt.full_name}' v{registered_version.version}"
        )

        # If there's confirmed improvement, also set the "champion" alias
        # (represents the prompt that should be used by deployed agents)
        if has_improvement:
            logger.info(
                f"Setting 'champion' alias for improved prompt '{prompt.full_name}' version {registered_version.version}"
            )
            mlflow.genai.set_prompt_alias(
                name=prompt.full_name,
                alias="champion",
                version=registered_version.version,
            )
            logger.info(
                f"Successfully set 'champion' alias for '{prompt.full_name}' v{registered_version.version}"
            )

        # Carry over the caller's tags, then add target_model and dao_ai tags
        tags: dict[str, Any] = prompt.tags.copy() if prompt.tags else {}
        tags["target_model"] = agent_model.model.uri
        tags["dao_ai"] = dao_ai_version()

        # Use "champion" if there was improvement, otherwise "latest"
        result_alias: str = "champion" if has_improvement else "latest"
        return PromptModel(
            name=prompt.name,
            schema=prompt.schema_model,
            description=f"Optimized version of {prompt.name} for {agent_model.model.uri}",
            alias=result_alias,
            tags=tags,
        )

    except Exception as e:
        # Best-effort: a registry failure must not discard the caller's
        # existing prompt, so fall back to it.
        logger.error(
            f"Failed to register optimized prompt '{prompt.full_name}': {e}"
        )
        return prompt