dao-ai 0.0.28__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +2 -5
  3. dao_ai/cli.py +342 -58
  4. dao_ai/config.py +1610 -380
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -253
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +158 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/context_editing.py +230 -0
  25. dao_ai/middleware/core.py +67 -0
  26. dao_ai/middleware/guardrails.py +420 -0
  27. dao_ai/middleware/human_in_the_loop.py +233 -0
  28. dao_ai/middleware/message_validation.py +586 -0
  29. dao_ai/middleware/model_call_limit.py +77 -0
  30. dao_ai/middleware/model_retry.py +121 -0
  31. dao_ai/middleware/pii.py +157 -0
  32. dao_ai/middleware/summarization.py +197 -0
  33. dao_ai/middleware/tool_call_limit.py +210 -0
  34. dao_ai/middleware/tool_retry.py +174 -0
  35. dao_ai/models.py +1306 -114
  36. dao_ai/nodes.py +240 -161
  37. dao_ai/optimization.py +674 -0
  38. dao_ai/orchestration/__init__.py +52 -0
  39. dao_ai/orchestration/core.py +294 -0
  40. dao_ai/orchestration/supervisor.py +279 -0
  41. dao_ai/orchestration/swarm.py +271 -0
  42. dao_ai/prompts.py +128 -31
  43. dao_ai/providers/databricks.py +584 -601
  44. dao_ai/state.py +157 -21
  45. dao_ai/tools/__init__.py +13 -5
  46. dao_ai/tools/agent.py +1 -3
  47. dao_ai/tools/core.py +64 -11
  48. dao_ai/tools/email.py +232 -0
  49. dao_ai/tools/genie.py +144 -294
  50. dao_ai/tools/mcp.py +223 -155
  51. dao_ai/tools/memory.py +50 -0
  52. dao_ai/tools/python.py +9 -14
  53. dao_ai/tools/search.py +14 -0
  54. dao_ai/tools/slack.py +22 -10
  55. dao_ai/tools/sql.py +202 -0
  56. dao_ai/tools/time.py +30 -7
  57. dao_ai/tools/unity_catalog.py +165 -88
  58. dao_ai/tools/vector_search.py +331 -221
  59. dao_ai/utils.py +166 -20
  60. dao_ai/vector_search.py +37 -0
  61. dao_ai-0.1.5.dist-info/METADATA +489 -0
  62. dao_ai-0.1.5.dist-info/RECORD +70 -0
  63. dao_ai/chat_models.py +0 -204
  64. dao_ai/guardrails.py +0 -112
  65. dao_ai/tools/human_in_the_loop.py +0 -100
  66. dao_ai-0.0.28.dist-info/METADATA +0 -1168
  67. dao_ai-0.0.28.dist-info/RECORD +0 -41
  68. {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/WHEEL +0 -0
  69. {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/entry_points.txt +0 -0
  70. {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/licenses/LICENSE +0 -0
@@ -1,5 +1,4 @@
  import base64
- import re
  import uuid
  from pathlib import Path
  from typing import Any, Callable, Final, Sequence
@@ -32,14 +31,12 @@ from mlflow import MlflowClient
  from mlflow.entities import Experiment
  from mlflow.entities.model_registry import PromptVersion
  from mlflow.entities.model_registry.model_version import ModelVersion
- from mlflow.genai.datasets import EvaluationDataset, get_dataset
  from mlflow.genai.prompts import load_prompt
  from mlflow.models.auth_policy import AuthPolicy, SystemAuthPolicy, UserAuthPolicy
  from mlflow.models.model import ModelInfo
  from mlflow.models.resources import (
      DatabricksResource,
  )
- from mlflow.pyfunc import ResponsesAgent
  from pyspark.sql import SparkSession
  from unitycatalog.ai.core.base import FunctionExecutionResult
  from unitycatalog.ai.core.databricks import DatabricksFunctionClient
@@ -49,6 +46,7 @@ from dao_ai.config import (
      AppConfig,
      ConnectionModel,
      DatabaseModel,
+     DatabricksAppModel,
      DatasetModel,
      FunctionModel,
      GenieRoomModel,
@@ -57,7 +55,6 @@ from dao_ai.config import (
      IsDatabricksResource,
      LLMModel,
      PromptModel,
-     PromptOptimizationModel,
      SchemaModel,
      TableModel,
      UnityCatalogFunctionSqlModel,
@@ -73,6 +70,7 @@ from dao_ai.utils import (
      get_installed_packages,
      is_installed,
      is_lib_provided,
+     normalize_host,
      normalize_name,
  )
  from dao_ai.vector_search import endpoint_exists, index_exists
@@ -94,15 +92,18 @@ def _workspace_client(
      Create a WorkspaceClient instance with the provided parameters.
      If no parameters are provided, it will use the default configuration.
      """
-     if client_id and client_secret and workspace_host:
+     # Normalize the workspace host to ensure it has https:// scheme
+     normalized_host = normalize_host(workspace_host)
+
+     if client_id and client_secret and normalized_host:
          return WorkspaceClient(
-             host=workspace_host,
+             host=normalized_host,
              client_id=client_id,
              client_secret=client_secret,
              auth_type="oauth-m2m",
          )
      elif pat:
-         return WorkspaceClient(host=workspace_host, token=pat, auth_type="pat")
+         return WorkspaceClient(host=normalized_host, token=pat, auth_type="pat")
      else:
          return WorkspaceClient()

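For orientation, a hedged sketch of how the three authentication paths above resolve. The parameter names are inferred from this hunk; _workspace_client is module-private and its full signature is not shown in the diff:

    # Illustrative only; not part of dao_ai's public API.
    w = _workspace_client(
        workspace_host="my-workspace.cloud.databricks.com",  # normalize_host adds https://
        client_id="sp-client-id",      # hypothetical service principal
        client_secret="sp-secret",
    )  # -> OAuth M2M
    w = _workspace_client(workspace_host="my-workspace.cloud.databricks.com", pat="dapi-...")  # -> PAT
    w = _workspace_client()  # -> ambient default Databricks configuration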
@@ -117,15 +118,18 @@ def _vector_search_client(
      Create a VectorSearchClient instance with the provided parameters.
      If no parameters are provided, it will use the default configuration.
      """
-     if client_id and client_secret and workspace_host:
+     # Normalize the workspace host to ensure it has https:// scheme
+     normalized_host = normalize_host(workspace_host)
+
+     if client_id and client_secret and normalized_host:
          return VectorSearchClient(
-             workspace_url=workspace_host,
+             workspace_url=normalized_host,
              service_principal_client_id=client_id,
              service_principal_client_secret=client_secret,
          )
-     elif pat and workspace_host:
+     elif pat and normalized_host:
          return VectorSearchClient(
-             workspace_url=workspace_host,
+             workspace_url=normalized_host,
              personal_access_token=pat,
          )
      else:
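normalize_host is imported from dao_ai.utils, whose body is not part of this diff. Based on the comment above ("ensure it has https:// scheme"), a minimal sketch of what such a helper might look like; this is an assumption, not the package's actual implementation:

    def normalize_host(host: str | None) -> str | None:
        """Illustrative sketch: ensure a workspace host carries an https:// scheme."""
        if not host:
            return host
        host = host.strip().rstrip("/")
        if not host.startswith(("http://", "https://")):
            host = f"https://{host}"
        return host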
@@ -177,15 +181,17 @@ class DatabricksProvider(ServiceProvider):
          experiment: Experiment | None = mlflow.get_experiment_by_name(experiment_name)
          if experiment is None:
              experiment_id: str = mlflow.create_experiment(name=experiment_name)
-             logger.info(
-                 f"Created new experiment: {experiment_name} (ID: {experiment_id})"
+             logger.success(
+                 "Created new MLflow experiment",
+                 experiment_name=experiment_name,
+                 experiment_id=experiment_id,
              )
              experiment = mlflow.get_experiment(experiment_id)
          return experiment

      def create_token(self) -> str:
          current_user: User = self.w.current_user.me()
-         logger.debug(f"Authenticated to Databricks as {current_user}")
+         logger.debug("Authenticated to Databricks", user=str(current_user))
          headers: dict[str, str] = self.w.config.authenticate()
          token: str = headers["Authorization"].replace("Bearer ", "")
          return token
@@ -197,17 +203,24 @@ class DatabricksProvider(ServiceProvider):
              secret_response: GetSecretResponse = self.w.secrets.get_secret(
                  secret_scope, secret_key
              )
-             logger.debug(f"Retrieved secret {secret_key} from scope {secret_scope}")
+             logger.trace(
+                 "Retrieved secret", secret_key=secret_key, secret_scope=secret_scope
+             )
              encoded_secret: str = secret_response.value
              decoded_secret: str = base64.b64decode(encoded_secret).decode("utf-8")
              return decoded_secret
          except NotFound:
              logger.warning(
-                 f"Secret {secret_key} not found in scope {secret_scope}, using default value"
+                 "Secret not found, using default value",
+                 secret_key=secret_key,
+                 secret_scope=secret_scope,
              )
          except Exception as e:
              logger.error(
-                 f"Error retrieving secret {secret_key} from scope {secret_scope}: {e}"
+                 "Error retrieving secret",
+                 secret_key=secret_key,
+                 secret_scope=secret_scope,
+                 error=str(e),
              )

          return default_value
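The decode step above exists because the Databricks secrets API returns values base64-encoded. As a standalone sketch (scope and key names are placeholders):

    import base64
    from databricks.sdk import WorkspaceClient

    w = WorkspaceClient()
    # GetSecretResponse.value is base64-encoded, hence the explicit decode.
    raw = w.secrets.get_secret("my-scope", "my-key").value
    secret = base64.b64decode(raw).decode("utf-8")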
@@ -216,9 +229,18 @@ class DatabricksProvider(ServiceProvider):
          self,
          config: AppConfig,
      ) -> ModelInfo:
-         logger.debug("Creating agent...")
+         logger.info("Creating agent")
          mlflow.set_registry_uri("databricks-uc")

+         # Set up experiment for proper tracking
+         experiment: Experiment = self.get_or_create_experiment(config)
+         mlflow.set_experiment(experiment_id=experiment.experiment_id)
+         logger.debug(
+             "Using MLflow experiment",
+             experiment_name=experiment.name,
+             experiment_id=experiment.experiment_id,
+         )
+
          llms: Sequence[LLMModel] = list(config.resources.llms.values())
          vector_indexes: Sequence[IndexModel] = list(
              config.resources.vector_stores.values()
@@ -236,6 +258,7 @@ class DatabricksProvider(ServiceProvider):
          )
          databases: Sequence[DatabaseModel] = list(config.resources.databases.values())
          volumes: Sequence[VolumeModel] = list(config.resources.volumes.values())
+         apps: Sequence[DatabricksAppModel] = list(config.resources.apps.values())

          resources: Sequence[IsDatabricksResource] = (
              llms
@@ -247,6 +270,7 @@ class DatabricksProvider(ServiceProvider):
              + connections
              + databases
              + volumes
+             + apps
          )

          # Flatten all resources from all models into a single list
@@ -260,12 +284,16 @@ class DatabricksProvider(ServiceProvider):
              for resource in r.as_resources()
              if not r.on_behalf_of_user
          ]
-         logger.debug(f"system_resources: {[r.name for r in system_resources]}")
+         logger.trace(
+             "System resources identified",
+             count=len(system_resources),
+             resources=[r.name for r in system_resources],
+         )

          system_auth_policy: SystemAuthPolicy = SystemAuthPolicy(
              resources=system_resources
          )
-         logger.debug(f"system_auth_policy: {system_auth_policy}")
+         logger.trace("System auth policy created", policy=str(system_auth_policy))

          api_scopes: Sequence[str] = list(
              set(
@@ -277,15 +305,19 @@ class DatabricksProvider(ServiceProvider):
                  ]
              )
          )
-         logger.debug(f"api_scopes: {api_scopes}")
+         logger.trace("API scopes identified", scopes=api_scopes)

          user_auth_policy: UserAuthPolicy = UserAuthPolicy(api_scopes=api_scopes)
-         logger.debug(f"user_auth_policy: {user_auth_policy}")
+         logger.trace("User auth policy created", policy=str(user_auth_policy))

          auth_policy: AuthPolicy = AuthPolicy(
              system_auth_policy=system_auth_policy, user_auth_policy=user_auth_policy
          )
-         logger.debug(f"auth_policy: {auth_policy}")
+         logger.debug(
+             "Auth policy created",
+             has_system_auth=system_auth_policy is not None,
+             has_user_auth=user_auth_policy is not None,
+         )

          code_paths: list[str] = config.app.code_paths
          for path in code_paths:
@@ -312,18 +344,42 @@ class DatabricksProvider(ServiceProvider):

          pip_requirements += get_installed_packages()

-         logger.debug(f"pip_requirements: {pip_requirements}")
-         logger.debug(f"code_paths: {code_paths}")
+         logger.trace("Pip requirements prepared", count=len(pip_requirements))
+         logger.trace("Code paths prepared", count=len(code_paths))

          run_name: str = normalize_name(config.app.name)
-         logger.debug(f"run_name: {run_name}")
-         logger.debug(f"model_path: {model_path.as_posix()}")
+         logger.debug(
+             "Agent run configuration",
+             run_name=run_name,
+             model_path=model_path.as_posix(),
+         )

          input_example: dict[str, Any] = None
          if config.app.input_example:
              input_example = config.app.input_example.model_dump()

-         logger.debug(f"input_example: {input_example}")
+         logger.trace("Input example configured", has_example=input_example is not None)
+
+         # Create conda environment with configured Python version
+         # This allows deploying from environments with different Python versions
+         # (e.g., Databricks Apps with Python 3.11 can deploy to Model Serving with 3.12)
+         target_python_version: str = config.app.python_version
+         logger.debug("Target Python version configured", version=target_python_version)
+
+         conda_env: dict[str, Any] = {
+             "name": "mlflow-env",
+             "channels": ["conda-forge"],
+             "dependencies": [
+                 f"python={target_python_version}",
+                 "pip",
+                 {"pip": list(pip_requirements)},
+             ],
+         }
+         logger.trace(
+             "Conda environment configured",
+             python_version=target_python_version,
+             pip_packages_count=len(pip_requirements),
+         )

          with mlflow.start_run(run_name=run_name):
              mlflow.set_tag("type", "agent")
@@ -333,7 +389,7 @@ class DatabricksProvider(ServiceProvider):
                  code_paths=code_paths,
                  model_config=config.model_dump(mode="json", by_alias=True),
                  name="agent",
-                 pip_requirements=pip_requirements,
+                 conda_env=conda_env,
                  input_example=input_example,
                  # resources=all_resources,
                  auth_policy=auth_policy,
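Switching log_model from pip_requirements to conda_env is what lets the logged agent pin its own Python interpreter, decoupling the deploy-time environment (e.g. a Databricks App on 3.11) from the serving environment (e.g. 3.12). A condensed sketch of the pattern under those assumptions; MLflow accepts the conda environment as a plain dict, and the model path and requirement list here are illustrative:

    import mlflow

    conda_env = {
        "name": "mlflow-env",
        "channels": ["conda-forge"],
        "dependencies": [
            "python=3.12",               # serving-side interpreter, independent of the caller's
            "pip",
            {"pip": ["dao-ai==0.1.5"]},  # illustrative requirement list
        ],
    }

    with mlflow.start_run():
        mlflow.pyfunc.log_model(
            name="agent",
            python_model="agent_as_code.py",  # illustrative models-from-code path
            conda_env=conda_env,
        )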
@@ -344,8 +400,10 @@ class DatabricksProvider(ServiceProvider):
              model_version: ModelVersion = mlflow.register_model(
                  name=registered_model_name, model_uri=logged_agent_info.model_uri
              )
-             logger.debug(
-                 f"Registered model: {registered_model_name} with version: {model_version.version}"
+             logger.success(
+                 "Model registered",
+                 model_name=registered_model_name,
+                 version=model_version.version,
              )

              client: MlflowClient = MlflowClient()
@@ -357,7 +415,7 @@ class DatabricksProvider(ServiceProvider):
                  key="dao_ai",
                  value=dao_ai_version(),
              )
-             logger.debug(f"Set dao_ai tag on model version {model_version.version}")
+             logger.trace("Set dao_ai tag on model version", version=model_version.version)

              client.set_registered_model_alias(
                  name=registered_model_name,
@@ -374,12 +432,15 @@ class DatabricksProvider(ServiceProvider):
              aliased_model: ModelVersion = client.get_model_version_by_alias(
                  registered_model_name, config.app.alias
              )
-             logger.debug(
-                 f"Model {registered_model_name} aliased to {config.app.alias} with version: {aliased_model.version}"
+             logger.info(
+                 "Model aliased",
+                 model_name=registered_model_name,
+                 alias=config.app.alias,
+                 version=aliased_model.version,
              )

      def deploy_agent(self, config: AppConfig) -> None:
-         logger.debug("Deploying agent...")
+         logger.info("Deploying agent", endpoint_name=config.app.endpoint_name)
          mlflow.set_registry_uri("databricks-uc")

          endpoint_name: str = config.app.endpoint_name
@@ -400,12 +461,10 @@ class DatabricksProvider(ServiceProvider):
              agents.get_deployments(endpoint_name)
              endpoint_exists = True
              logger.debug(
-                 f"Endpoint {endpoint_name} already exists, updating without tags to avoid conflicts..."
+                 "Endpoint already exists, updating", endpoint_name=endpoint_name
              )
          except Exception:
-             logger.debug(
-                 f"Endpoint {endpoint_name} doesn't exist, creating new with tags..."
-             )
+             logger.debug("Creating new endpoint", endpoint_name=endpoint_name)

          # Deploy - skip tags for existing endpoints to avoid conflicts
          agents.deploy(
@@ -421,8 +480,11 @@ class DatabricksProvider(ServiceProvider):
          registered_model_name: str = config.app.registered_model.full_name
          permissions: Sequence[dict[str, Any]] = config.app.permissions

-         logger.debug(registered_model_name)
-         logger.debug(permissions)
+         logger.debug(
+             "Configuring model permissions",
+             model_name=registered_model_name,
+             permissions_count=len(permissions),
+         )

          for permission in permissions:
              principals: Sequence[str] = permission.principals
@@ -442,7 +504,7 @@ class DatabricksProvider(ServiceProvider):
          try:
              catalog_info = self.w.catalogs.get(name=schema.catalog_name)
          except NotFound:
-             logger.debug(f"Creating catalog: {schema.catalog_name}")
+             logger.info("Creating catalog", catalog_name=schema.catalog_name)
              catalog_info = self.w.catalogs.create(name=schema.catalog_name)
          return catalog_info

@@ -452,7 +514,7 @@ class DatabricksProvider(ServiceProvider):
          try:
              schema_info = self.w.schemas.get(full_name=schema.full_name)
          except NotFound:
-             logger.debug(f"Creating schema: {schema.full_name}")
+             logger.info("Creating schema", schema_name=schema.full_name)
              schema_info = self.w.schemas.create(
                  name=schema.schema_name, catalog_name=catalog_info.name
              )
@@ -464,7 +526,7 @@ class DatabricksProvider(ServiceProvider):
          try:
              volume_info = self.w.volumes.read(name=volume.full_name)
          except NotFound:
-             logger.debug(f"Creating volume: {volume.full_name}")
+             logger.info("Creating volume", volume_name=volume.full_name)
              volume_info = self.w.volumes.create(
                  catalog_name=schema_info.catalog_name,
                  schema_name=schema_info.name,
@@ -475,7 +537,7 @@ class DatabricksProvider(ServiceProvider):

      def create_path(self, volume_path: VolumePathModel) -> Path:
          path: Path = volume_path.full_name
-         logger.info(f"Creating volume path: {path}")
+         logger.info("Creating volume path", path=str(path))
          self.w.files.create_directory(path)
          return path

@@ -516,11 +578,12 @@ class DatabricksProvider(ServiceProvider):

          if ddl:
              ddl_path: Path = Path(ddl)
-             logger.debug(f"Executing DDL from: {ddl_path}")
+             logger.debug("Executing DDL", ddl_path=str(ddl_path))
              statements: Sequence[str] = sqlparse.parse(ddl_path.read_text())
              for statement in statements:
-                 logger.debug(statement)
-                 logger.debug(f"args: {args}")
+                 logger.trace(
+                     "Executing DDL statement", statement=str(statement)[:100], args=args
+                 )
                  spark.sql(
                      str(statement),
                      args=args,
@@ -529,20 +592,23 @@ class DatabricksProvider(ServiceProvider):
          if data:
              data_path: Path = Path(data)
              if format == "sql":
-                 logger.debug(f"Executing SQL from: {data_path}")
+                 logger.debug("Executing SQL from file", data_path=str(data_path))
                  data_statements: Sequence[str] = sqlparse.parse(data_path.read_text())
                  for statement in data_statements:
-                     logger.debug(statement)
-                     logger.debug(f"args: {args}")
+                     logger.trace(
+                         "Executing SQL statement",
+                         statement=str(statement)[:100],
+                         args=args,
+                     )
                      spark.sql(
                          str(statement),
                          args=args,
                      )
              else:
-                 logger.debug(f"Writing to: {table}")
+                 logger.debug("Writing dataset to table", table=table)
                  if not data_path.is_absolute():
                      data_path = current_dir / data_path
-                 logger.debug(f"Data path: {data_path.as_posix()}")
+                 logger.trace("Data path resolved", path=data_path.as_posix())
                  if format == "excel":
                      pdf = pd.read_excel(data_path.as_posix())
                      df = spark.createDataFrame(pdf, schema=dataset.table_schema)
@@ -559,6 +625,17 @@ class DatabricksProvider(ServiceProvider):
              df.write.mode("overwrite").saveAsTable(table)

      def create_vector_store(self, vector_store: VectorStoreModel) -> None:
+         """
+         Create a vector search index from a source table.
+
+         This method expects a VectorStoreModel in provisioning mode with all
+         required fields validated. Use VectorStoreModel.create() which handles
+         mode detection and validation.
+
+         Args:
+             vector_store: VectorStoreModel configured for provisioning
+         """
+         # Ensure endpoint exists
          if not endpoint_exists(self.vsc, vector_store.endpoint.name):
              self.vsc.create_endpoint_and_wait(
                  name=vector_store.endpoint.name,
@@ -566,13 +643,17 @@ class DatabricksProvider(ServiceProvider):
                  verbose=True,
              )

-         logger.debug(f"Endpoint named {vector_store.endpoint.name} is ready.")
+         logger.success(
+             "Vector search endpoint ready", endpoint_name=vector_store.endpoint.name
+         )

          if not index_exists(
              self.vsc, vector_store.endpoint.name, vector_store.index.full_name
          ):
-             logger.debug(
-                 f"Creating index {vector_store.index.full_name} on endpoint {vector_store.endpoint.name}..."
+             logger.info(
+                 "Creating vector search index",
+                 index_name=vector_store.index.full_name,
+                 endpoint_name=vector_store.endpoint.name,
              )
              self.vsc.create_delta_sync_index_and_wait(
                  endpoint_name=vector_store.endpoint.name,
@@ -586,7 +667,8 @@ class DatabricksProvider(ServiceProvider):
              )
          else:
              logger.debug(
-                 f"Index {vector_store.index.full_name} already exists, checking status and syncing..."
+                 "Vector search index already exists, checking status",
+                 index_name=vector_store.index.full_name,
              )
              index = self.vsc.get_index(
                  vector_store.endpoint.name, vector_store.index.full_name
@@ -609,54 +691,61 @@ class DatabricksProvider(ServiceProvider):

                  if pipeline_status in [
                      "COMPLETED",
+                     "ONLINE",
                      "FAILED",
                      "CANCELED",
                      "ONLINE_PIPELINE_FAILED",
                  ]:
-                     logger.debug(
-                         f"Index is ready to sync (status: {pipeline_status})"
-                     )
+                     logger.debug("Index ready to sync", status=pipeline_status)
                      break
                  elif pipeline_status in [
                      "WAITING_FOR_RESOURCES",
                      "PROVISIONING",
                      "INITIALIZING",
                      "INDEXING",
-                     "ONLINE",
                  ]:
-                     logger.debug(
-                         f"Index not ready yet (status: {pipeline_status}), waiting {wait_interval} seconds..."
+                     logger.trace(
+                         "Index not ready, waiting",
+                         status=pipeline_status,
+                         wait_seconds=wait_interval,
                      )
                      time.sleep(wait_interval)
                      elapsed += wait_interval
                  else:
                      logger.warning(
-                         f"Unknown pipeline status: {pipeline_status}, attempting sync anyway"
+                         "Unknown pipeline status, attempting sync",
+                         status=pipeline_status,
                      )
                      break
              except Exception as status_error:
                  logger.warning(
-                     f"Could not check index status: {status_error}, attempting sync anyway"
+                     "Could not check index status, attempting sync",
+                     error=str(status_error),
                  )
                  break

          if elapsed >= max_wait_time:
              logger.warning(
-                 f"Timed out waiting for index to be ready after {max_wait_time} seconds"
+                 "Timed out waiting for index to be ready",
+                 max_wait_seconds=max_wait_time,
              )

          # Now attempt to sync
          try:
              index.sync()
-             logger.debug("Index sync completed successfully")
+             logger.success("Index sync completed")
          except Exception as sync_error:
              if "not ready to sync yet" in str(sync_error).lower():
-                 logger.warning(f"Index still not ready to sync: {sync_error}")
+                 logger.warning(
+                     "Index still not ready to sync", error=str(sync_error)
+                 )
              else:
                  raise sync_error

-         logger.debug(
-             f"index {vector_store.index.full_name} on table {vector_store.source_table.full_name} is ready"
+         logger.success(
+             "Vector search index ready",
+             index_name=vector_store.index.full_name,
+             source_table=vector_store.source_table.full_name,
          )
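Note the behavioral fix buried in this hunk: "ONLINE" moved from the still-provisioning list to the ready list, so a healthy online index now proceeds straight to sync instead of waiting until the timeout. The resulting partition, written out as plain sets for reference:

    # Pipeline states after this change (taken directly from the hunk above).
    READY_TO_SYNC = {"COMPLETED", "ONLINE", "FAILED", "CANCELED", "ONLINE_PIPELINE_FAILED"}
    STILL_PROVISIONING = {"WAITING_FOR_RESOURCES", "PROVISIONING", "INITIALIZING", "INDEXING"}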

      def get_vector_index(self, vector_store: VectorStoreModel) -> None:
@@ -692,12 +781,16 @@ class DatabricksProvider(ServiceProvider):
          # sql = sql.replace("{catalog_name}", schema.catalog_name)
          # sql = sql.replace("{schema_name}", schema.schema_name)

-         logger.info(function.name)
-         logger.info(sql)
+         logger.info("Creating SQL function", function_name=function.name)
+         logger.trace("SQL function body", sql=sql[:200])
          _: FunctionInfo = self.dfs.create_function(sql_function_body=sql)

          if unity_catalog_function.test:
-             logger.info(unity_catalog_function.test.parameters)
+             logger.debug(
+                 "Testing function",
+                 function_name=function.full_name,
+                 parameters=unity_catalog_function.test.parameters,
+             )

              result: FunctionExecutionResult = self.dfs.execute_function(
                  function_name=function.full_name,
@@ -705,37 +798,50 @@ class DatabricksProvider(ServiceProvider):
              )

              if result.error:
-                 logger.error(result.error)
+                 logger.error(
+                     "Function test failed",
+                     function_name=function.full_name,
+                     error=result.error,
+                 )
              else:
-                 logger.info(f"Function {function.full_name} executed successfully.")
-                 logger.info(f"Result: {result}")
+                 logger.success(
+                     "Function test passed", function_name=function.full_name
+                 )
+                 logger.debug("Function test result", result=str(result))

      def find_columns(self, table_model: TableModel) -> Sequence[str]:
-         logger.debug(f"Finding columns for table: {table_model.full_name}")
+         logger.trace("Finding columns for table", table=table_model.full_name)
          table_info: TableInfo = self.w.tables.get(full_name=table_model.full_name)
          columns: Sequence[ColumnInfo] = table_info.columns
          column_names: Sequence[str] = [c.name for c in columns]
-         logger.debug(f"Columns found: {column_names}")
+         logger.debug(
+             "Columns found",
+             table=table_model.full_name,
+             columns_count=len(column_names),
+         )
          return column_names

      def find_primary_key(self, table_model: TableModel) -> Sequence[str] | None:
-         logger.debug(f"Finding primary key for table: {table_model.full_name}")
+         logger.trace("Finding primary key for table", table=table_model.full_name)
          primary_keys: Sequence[str] | None = None
          table_info: TableInfo = self.w.tables.get(full_name=table_model.full_name)
          constraints: Sequence[TableConstraint] = table_info.table_constraints
          primary_key_constraint: PrimaryKeyConstraint | None = next(
-             c.primary_key_constraint for c in constraints if c.primary_key_constraint
+             (c.primary_key_constraint for c in constraints if c.primary_key_constraint),
+             None,
          )
          if primary_key_constraint:
              primary_keys = primary_key_constraint.child_columns

-         logger.debug(f"Primary key for table {table_model.full_name}: {primary_keys}")
+         logger.debug(
+             "Primary key found", table=table_model.full_name, primary_keys=primary_keys
+         )
          return primary_keys

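The find_primary_key change above is a genuine bug fix, not just a logging cleanup: bare next() over an empty generator raises StopIteration for tables with no primary-key constraint, while the two-argument form returns the default. In isolation:

    constraints: list = []  # a table with no primary-key constraint

    # Old form: next(c for c in constraints if c) raises StopIteration here.
    # New form: the supplied default is returned instead.
    first = next((c for c in constraints if c), None)
    assert first is None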
      def find_vector_search_endpoint(
          self, predicate: Callable[[dict[str, Any]], bool]
      ) -> str | None:
-         logger.debug("Finding vector search endpoint...")
+         logger.trace("Finding vector search endpoint")
          endpoint_name: str | None = None
          vector_search_endpoints: Sequence[dict[str, Any]] = (
              self.vsc.list_endpoints().get("endpoints", [])
@@ -744,11 +850,13 @@ class DatabricksProvider(ServiceProvider):
              if predicate(endpoint):
                  endpoint_name = endpoint["name"]
                  break
-         logger.debug(f"Vector search endpoint found: {endpoint_name}")
+         logger.debug("Vector search endpoint found", endpoint_name=endpoint_name)
          return endpoint_name

      def find_endpoint_for_index(self, index_model: IndexModel) -> str | None:
-         logger.debug(f"Finding vector search index: {index_model.full_name}")
+         logger.trace(
+             "Finding endpoint for vector search index", index_name=index_model.full_name
+         )
          all_endpoints: Sequence[dict[str, Any]] = self.vsc.list_endpoints().get(
              "endpoints", []
          )
@@ -758,14 +866,99 @@ class DatabricksProvider(ServiceProvider):
              endpoint_name: str = endpoint["name"]
              indexes = self.vsc.list_indexes(name=endpoint_name)
              vector_indexes: Sequence[dict[str, Any]] = indexes.get("vector_indexes", [])
-             logger.trace(f"Endpoint: {endpoint_name}, vector_indexes: {vector_indexes}")
+             logger.trace(
+                 "Checking endpoint for indexes",
+                 endpoint_name=endpoint_name,
+                 indexes_count=len(vector_indexes),
+             )
              index_names = [vector_index["name"] for vector_index in vector_indexes]
              if index_name in index_names:
                  found_endpoint_name = endpoint_name
                  break
-         logger.debug(f"Vector search index found: {found_endpoint_name}")
+         logger.debug(
+             "Vector search index endpoint found",
+             index_name=index_model.full_name,
+             endpoint_name=found_endpoint_name,
+         )
          return found_endpoint_name

+     def _wait_for_database_available(
+         self,
+         workspace_client: WorkspaceClient,
+         instance_name: str,
+         max_wait_time: int = 600,
+         wait_interval: int = 10,
+     ) -> None:
+         """
+         Wait for a database instance to become AVAILABLE.
+
+         Args:
+             workspace_client: The Databricks workspace client
+             instance_name: Name of the database instance to wait for
+             max_wait_time: Maximum time to wait in seconds (default: 600 = 10 minutes)
+             wait_interval: Time between status checks in seconds (default: 10)
+
+         Raises:
+             TimeoutError: If the database doesn't become AVAILABLE within max_wait_time
+             RuntimeError: If the database enters a failed or deleted state
+         """
+         import time
+         from typing import Any
+
+         logger.info(
+             "Waiting for database instance to become AVAILABLE",
+             instance_name=instance_name,
+         )
+         elapsed: int = 0
+
+         while elapsed < max_wait_time:
+             try:
+                 current_instance: Any = workspace_client.database.get_database_instance(
+                     name=instance_name
+                 )
+                 current_state: str = current_instance.state
+                 logger.trace(
+                     "Database instance state checked",
+                     instance_name=instance_name,
+                     state=current_state,
+                 )
+
+                 if current_state == "AVAILABLE":
+                     logger.success(
+                         "Database instance is now AVAILABLE",
+                         instance_name=instance_name,
+                     )
+                     return
+                 elif current_state in ["STARTING", "UPDATING", "PROVISIONING"]:
+                     logger.trace(
+                         "Database instance not ready, waiting",
+                         instance_name=instance_name,
+                         state=current_state,
+                         wait_seconds=wait_interval,
+                     )
+                     time.sleep(wait_interval)
+                     elapsed += wait_interval
+                 elif current_state in ["STOPPED", "DELETING", "FAILED"]:
+                     raise RuntimeError(
+                         f"Database instance {instance_name} entered unexpected state: {current_state}"
+                     )
+                 else:
+                     logger.warning(
+                         "Unknown database state, continuing to wait",
+                         instance_name=instance_name,
+                         state=current_state,
+                     )
+                     time.sleep(wait_interval)
+                     elapsed += wait_interval
+             except NotFound:
+                 raise RuntimeError(
+                     f"Database instance {instance_name} was deleted while waiting for it to become AVAILABLE"
+                 )
+
+         raise TimeoutError(
+             f"Timed out waiting for database instance {instance_name} to become AVAILABLE after {max_wait_time} seconds"
+         )
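A minimal usage sketch of the new helper; the construction and instance name are illustrative, and in practice the method is invoked internally by create_lakebase below:

    provider = DatabricksProvider()  # illustrative; real init args not shown in this diff
    provider._wait_for_database_available(
        provider.w,
        "my-lakebase-instance",
        max_wait_time=300,   # give up after 5 minutes
        wait_interval=15,
    )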
+
      def create_lakebase(self, database: DatabaseModel) -> None:
          """
          Create a Lakebase database instance using the Databricks workspace client.
@@ -796,13 +989,17 @@ class DatabricksProvider(ServiceProvider):

              if existing_instance:
                  logger.debug(
-                     f"Database instance {database.instance_name} already exists with state: {existing_instance.state}"
+                     "Database instance already exists",
+                     instance_name=database.instance_name,
+                     state=existing_instance.state,
                  )

                  # Check if database is in an intermediate state
                  if existing_instance.state in ["STARTING", "UPDATING"]:
                      logger.info(
-                         f"Database instance {database.instance_name} is in {existing_instance.state} state, waiting for it to become AVAILABLE..."
+                         "Database instance in intermediate state, waiting",
+                         instance_name=database.instance_name,
+                         state=existing_instance.state,
                      )

                      # Wait for database to reach a stable state
@@ -818,65 +1015,87 @@ class DatabricksProvider(ServiceProvider):
                                  )
                              )
                              current_state: str = current_instance.state
-                             logger.debug(f"Database instance state: {current_state}")
+                             logger.trace(
+                                 "Checking database instance state",
+                                 instance_name=database.instance_name,
+                                 state=current_state,
+                             )

                              if current_state == "AVAILABLE":
-                                 logger.info(
-                                     f"Database instance {database.instance_name} is now AVAILABLE"
+                                 logger.success(
+                                     "Database instance is now AVAILABLE",
+                                     instance_name=database.instance_name,
                                  )
                                  break
                              elif current_state in ["STARTING", "UPDATING"]:
-                                 logger.debug(
-                                     f"Database instance still in {current_state} state, waiting {wait_interval} seconds..."
+                                 logger.trace(
+                                     "Database instance not ready, waiting",
+                                     instance_name=database.instance_name,
+                                     state=current_state,
+                                     wait_seconds=wait_interval,
                                  )
                                  time.sleep(wait_interval)
                                  elapsed += wait_interval
                              elif current_state in ["STOPPED", "DELETING"]:
                                  logger.warning(
-                                     f"Database instance {database.instance_name} is in unexpected state: {current_state}"
+                                     "Database instance in unexpected state",
+                                     instance_name=database.instance_name,
+                                     state=current_state,
                                  )
                                  break
                              else:
                                  logger.warning(
-                                     f"Unknown database state: {current_state}, proceeding anyway"
+                                     "Unknown database state, proceeding",
+                                     instance_name=database.instance_name,
+                                     state=current_state,
                                  )
                                  break
                          except NotFound:
                              logger.warning(
-                                 f"Database instance {database.instance_name} no longer exists, will attempt to recreate"
+                                 "Database instance no longer exists, will recreate",
+                                 instance_name=database.instance_name,
                              )
                              break
                          except Exception as state_error:
                              logger.warning(
-                                 f"Could not check database state: {state_error}, proceeding anyway"
+                                 "Could not check database state, proceeding",
+                                 instance_name=database.instance_name,
+                                 error=str(state_error),
                              )
                              break

                      if elapsed >= max_wait_time:
                          logger.warning(
-                             f"Timed out waiting for database instance {database.instance_name} to become AVAILABLE after {max_wait_time} seconds"
+                             "Timed out waiting for database to become AVAILABLE",
+                             instance_name=database.instance_name,
+                             max_wait_seconds=max_wait_time,
                          )

                  elif existing_instance.state == "AVAILABLE":
                      logger.info(
-                         f"Database instance {database.instance_name} already exists and is AVAILABLE"
+                         "Database instance already exists and is AVAILABLE",
+                         instance_name=database.instance_name,
                      )
                      return
                  elif existing_instance.state in ["STOPPED", "DELETING"]:
                      logger.warning(
-                         f"Database instance {database.instance_name} is in {existing_instance.state} state"
+                         "Database instance in terminal state",
+                         instance_name=database.instance_name,
+                         state=existing_instance.state,
                      )
                      return
                  else:
                      logger.info(
-                         f"Database instance {database.instance_name} already exists with state: {existing_instance.state}"
+                         "Database instance already exists",
+                         instance_name=database.instance_name,
+                         state=existing_instance.state,
                      )
                      return

          except NotFound:
              # Database doesn't exist, proceed with creation
-             logger.debug(
-                 f"Database instance {database.instance_name} not found, creating new instance..."
+             logger.info(
+                 "Creating new database instance", instance_name=database.instance_name
              )

          try:
@@ -896,10 +1115,17 @@ class DatabricksProvider(ServiceProvider):
              workspace_client.database.create_database_instance(
                  database_instance=database_instance
              )
-             logger.info(
-                 f"Successfully created database instance: {database.instance_name}"
+             logger.success(
+                 "Database instance created successfully",
+                 instance_name=database.instance_name,
              )

+             # Wait for the newly created database to become AVAILABLE
+             self._wait_for_database_available(
+                 workspace_client, database.instance_name
+             )
+             return
+
          except Exception as create_error:
              error_msg: str = str(create_error)

@@ -909,13 +1135,20 @@ class DatabricksProvider(ServiceProvider):
                  or "RESOURCE_ALREADY_EXISTS" in error_msg
              ):
                  logger.info(
-                     f"Database instance {database.instance_name} was created concurrently by another process"
+                     "Database instance was created concurrently",
+                     instance_name=database.instance_name,
+                 )
+                 # Still need to wait for the database to become AVAILABLE
+                 self._wait_for_database_available(
+                     workspace_client, database.instance_name
                  )
                  return
              else:
                  # Re-raise unexpected errors
                  logger.error(
-                     f"Error creating database instance {database.instance_name}: {create_error}"
+                     "Error creating database instance",
+                     instance_name=database.instance_name,
+                     error=str(create_error),
                  )
                  raise

@@ -929,12 +1162,15 @@ class DatabricksProvider(ServiceProvider):
                  or "RESOURCE_ALREADY_EXISTS" in error_msg
              ):
                  logger.info(
-                     f"Database instance {database.instance_name} already exists (detected via exception)"
+                     "Database instance already exists (detected via exception)",
+                     instance_name=database.instance_name,
                  )
                  return
              else:
                  logger.error(
-                     f"Unexpected error while handling database {database.instance_name}: {e}"
+                     "Unexpected error while handling database",
+                     instance_name=database.instance_name,
+                     error=str(e),
                  )
                  raise

@@ -942,7 +1178,9 @@ class DatabricksProvider(ServiceProvider):
          """
          Ask Databricks to mint a fresh DB credential for this instance.
          """
-         logger.debug(f"Generating password for lakebase instance: {instance_name}")
+         logger.trace(
+             "Generating password for lakebase instance", instance_name=instance_name
+         )
          w: WorkspaceClient = self.w
          cred: DatabaseCredential = w.database.generate_database_credential(
              request_id=str(uuid.uuid4()),
@@ -978,7 +1216,8 @@ class DatabricksProvider(ServiceProvider):
          # Validate that client_id is provided
          if not database.client_id:
              logger.warning(
-                 f"client_id is required to create instance role for database {database.instance_name}"
+                 "client_id required to create instance role",
+                 instance_name=database.instance_name,
              )
              return

@@ -988,7 +1227,10 @@ class DatabricksProvider(ServiceProvider):
          instance_name: str = database.instance_name

          logger.debug(
-             f"Creating instance role '{role_name}' for database {instance_name} with principal {client_id}"
+             "Creating instance role",
+             role_name=role_name,
+             instance_name=instance_name,
+             principal=client_id,
          )

          try:
@@ -999,13 +1241,15 @@ class DatabricksProvider(ServiceProvider):
                  name=role_name,
              )
              logger.info(
-                 f"Instance role '{role_name}' already exists for database {instance_name}"
+                 "Instance role already exists",
+                 role_name=role_name,
+                 instance_name=instance_name,
              )
              return
          except NotFound:
              # Role doesn't exist, proceed with creation
              logger.debug(
-                 f"Instance role '{role_name}' not found, creating new role..."
+                 "Instance role not found, creating new role", role_name=role_name
              )

              # Create the database instance role
@@ -1021,8 +1265,10 @@ class DatabricksProvider(ServiceProvider):
                  database_instance_role=role,
              )

-             logger.info(
-                 f"Successfully created instance role '{role_name}' for database {instance_name}"
+             logger.success(
+                 "Instance role created successfully",
+                 role_name=role_name,
+                 instance_name=instance_name,
              )

          except Exception as e:
@@ -1034,13 +1280,18 @@ class DatabricksProvider(ServiceProvider):
                  or "RESOURCE_ALREADY_EXISTS" in error_msg
              ):
                  logger.info(
-                     f"Instance role '{role_name}' was created concurrently for database {instance_name}"
+                     "Instance role was created concurrently",
+                     role_name=role_name,
+                     instance_name=instance_name,
                  )
                  return

              # Re-raise unexpected errors
              logger.error(
-                 f"Error creating instance role '{role_name}' for database {instance_name}: {e}"
+                 "Error creating instance role",
+                 role_name=role_name,
+                 instance_name=instance_name,
+                 error=str(e),
              )
              raise

@@ -1050,9 +1301,17 @@ class DatabricksProvider(ServiceProvider):

          If an explicit version or alias is specified in the prompt_model, uses that directly.
          Otherwise, tries to load prompts in this order:
-         1. champion alias (if it exists)
-         2. latest alias (if it exists)
-         3. default_template (if provided)
+         1. champion alias
+         2. latest alias
+         3. default alias
+         4. Register default_template if provided (only if register_to_registry=True)
+         5. Use default_template directly (fallback)
+
+         The auto_register field controls whether the default_template is automatically
+         synced to the prompt registry:
+         - If True (default): Auto-registers/updates the default_template in the registry
+         - If False: Never registers, but can still load existing prompts from registry
+           or use default_template directly as a local-only prompt

          Args:
              prompt_model: The prompt model configuration
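The fallback order documented in this docstring maps onto MLflow 3 prompt-registry URIs (note the docstring lists latest before default, while the implementation in the next hunk tries default before latest). A condensed sketch of the resolution loop the new code performs; the helper name here is illustrative:

    from mlflow.genai.prompts import load_prompt

    def resolve_prompt(name: str, default_template: str | None = None):
        """Illustrative only: mirrors the champion -> default -> latest order."""
        for alias in ("champion", "default", "latest"):
            try:
                return load_prompt(f"prompts:/{name}@{alias}")
            except Exception:
                continue
        if default_template is not None:
            return default_template  # local-only fallback, never registered
        raise ValueError(f"Prompt '{name}' not found and no default_template provided")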
@@ -1063,542 +1322,266 @@ class DatabricksProvider(ServiceProvider):
          Raises:
              ValueError: If no prompt can be loaded from any source
          """
+
          prompt_name: str = prompt_model.full_name

-         # If explicit version or alias is specified, use it directly without fallback
+         # If explicit version or alias is specified, use it directly
          if prompt_model.version or prompt_model.alias:
              try:
                  prompt_version: PromptVersion = prompt_model.as_prompt()
+                 version_or_alias = (
+                     f"version {prompt_model.version}"
+                     if prompt_model.version
+                     else f"alias {prompt_model.alias}"
+                 )
                  logger.debug(
-                     f"Loaded prompt '{prompt_name}' with explicit "
-                     f"{'version ' + str(prompt_model.version) if prompt_model.version else 'alias ' + prompt_model.alias}"
+                     "Loaded prompt with explicit version/alias",
+                     prompt_name=prompt_name,
+                     version_or_alias=version_or_alias,
                  )
                  return prompt_version
              except Exception as e:
+                 version_or_alias = (
+                     f"version {prompt_model.version}"
+                     if prompt_model.version
+                     else f"alias {prompt_model.alias}"
+                 )
                  logger.warning(
-                     f"Failed to load prompt '{prompt_name}' with explicit "
-                     f"{'version ' + str(prompt_model.version) if prompt_model.version else 'alias ' + prompt_model.alias}: {e}"
+                     "Failed to load prompt with explicit version/alias",
+                     prompt_name=prompt_name,
+                     version_or_alias=version_or_alias,
+                     error=str(e),
                  )
-                 # Fall through to default_template if available
-         else:
-             # No explicit version/alias specified - check if default_template needs syncing first
-             logger.debug(
-                 f"No explicit version/alias specified for '{prompt_name}', "
-                 "checking if default_template needs syncing"
-             )
-
-             # If we have a default_template, check if it differs from what's in the registry
-             # This ensures we always sync config changes before returning any alias
-             if prompt_model.default_template:
-                 try:
-                     default_uri: str = f"prompts:/{prompt_name}@default"
-                     default_version: PromptVersion = load_prompt(default_uri)
+                 # Fall through to try other methods

-                     if (
-                         default_version.to_single_brace_format().strip()
-                         != prompt_model.default_template.strip()
-                     ):
-                         logger.info(
-                             f"Config default_template for '{prompt_name}' differs from registry, syncing..."
-                         )
-                         return self._sync_default_template_to_registry(
-                             prompt_name,
-                             prompt_model.default_template,
-                             prompt_model.description,
-                         )
-                 except Exception as e:
-                     logger.debug(f"Could not check default alias for sync: {e}")
-
-             # Now try aliases in order: champion → latest → default
-             logger.debug(
-                 f"Trying fallback order for '{prompt_name}': champion → latest → default"
-             )
-
-             # Try champion alias first
-             try:
-                 champion_uri: str = f"prompts:/{prompt_name}@champion"
-                 prompt_version: PromptVersion = load_prompt(champion_uri)
-                 logger.info(f"Loaded prompt '{prompt_name}' from champion alias")
-                 return prompt_version
-             except Exception as e:
-                 logger.debug(f"Champion alias not found for '{prompt_name}': {e}")
-
-             # Try latest alias next
-             try:
-                 latest_uri: str = f"prompts:/{prompt_name}@latest"
-                 prompt_version: PromptVersion = load_prompt(latest_uri)
-                 logger.info(f"Loaded prompt '{prompt_name}' from latest alias")
-                 return prompt_version
-             except Exception as e:
-                 logger.debug(f"Latest alias not found for '{prompt_name}': {e}")
+         # Try to load in priority order: champion → default (with sync check)
+         logger.trace(
+             "Trying prompt fallback order",
+             prompt_name=prompt_name,
+             order="champion → default",
+         )

-             # Try default alias last
+         # First, sync default alias if template has changed (even if champion exists)
+         # Only do this if auto_register is True
+         if prompt_model.default_template and prompt_model.auto_register:
              try:
-                 default_uri: str = f"prompts:/{prompt_name}@default"
-                 prompt_version: PromptVersion = load_prompt(default_uri)
-                 logger.info(f"Loaded prompt '{prompt_name}' from default alias")
-                 return prompt_version
-             except Exception as e:
-                 logger.debug(f"Default alias not found for '{prompt_name}': {e}")
+                 # Try to load existing default
+                 existing_default = load_prompt(f"prompts:/{prompt_name}@default")

-             # Fall back to registering default_template if provided
-             if prompt_model.default_template:
-                 logger.info(
-                     f"Registering default_template for '{prompt_name}' "
-                     "(no aliases found in registry)"
-                 )
-                 return self._sync_default_template_to_registry(
-                     prompt_name, prompt_model.default_template, prompt_model.description
-                 )
-
-             raise ValueError(
-                 f"Prompt '{prompt_name}' not found in registry "
-                 "(tried champion, latest, default aliases) and no default_template provided"
-             )
-
-     def _sync_default_template_to_registry(
-         self, prompt_name: str, default_template: str, description: str | None = None
-     ) -> PromptVersion:
-         """Register default_template to prompt registry under 'default' alias if changed."""
-         prompt_version: PromptVersion | None = None
+                 # Check if champion exists and if it matches default
+                 champion_matches_default = False
+                 try:
+                     existing_champion = load_prompt(f"prompts:/{prompt_name}@champion")
+                     champion_matches_default = (
+                         existing_champion.version == existing_default.version
+                     )
+                     status = (
+                         "tracking" if champion_matches_default else "pinned separately"
+                     )
+                     logger.trace(
+                         "Champion vs default version",
+                         prompt_name=prompt_name,
+                         champion_version=existing_champion.version,
+                         default_version=existing_default.version,
+                         status=status,
+                     )
+                 except Exception:
+                     # No champion exists
+                     logger.trace("No champion alias found", prompt_name=prompt_name)

-         try:
-             # Check if default alias already has the same template
-             try:
-                 logger.debug(f"Loading prompt '{prompt_name}' from registry...")
-                 existing: PromptVersion = mlflow.genai.load_prompt(
-                     f"prompts:/{prompt_name}@default"
-                 )
+                 # Check if default_template differs from existing default
                  if (
-                     existing.to_single_brace_format().strip()
-                     == default_template.strip()
+                     existing_default.template.strip()
+                     != prompt_model.default_template.strip()
                  ):
-                     logger.debug(f"Prompt '{prompt_name}' is already up-to-date")
+                     logger.info(
+                         "Default template changed, registering new version",
+                         prompt_name=prompt_name,
+                     )

-                     # Ensure the "latest" and "champion" aliases also exist and point to the same version
-                     # This handles prompts created before the fix that added these aliases
-                     try:
-                         latest_version: PromptVersion = mlflow.genai.load_prompt(
-                             f"prompts:/{prompt_name}@latest"
-                         )
-                         logger.debug(
-                             f"Latest alias already exists for '{prompt_name}' pointing to version {latest_version.version}"
-                         )
-                     except Exception:
+                     # Only update champion if it was pointing to the old default
+                     if champion_matches_default:
                          logger.info(
-                             f"Setting 'latest' alias for existing prompt '{prompt_name}' v{existing.version}"
+                             "Champion was tracking default, will update to new version",
+                             prompt_name=prompt_name,
+                             old_version=existing_default.version,
                          )
-                         mlflow.genai.set_prompt_alias(
-                             name=prompt_name,
-                             alias="latest",
-                             version=existing.version,
-                         )
-
-                     # Ensure champion alias exists for first-time deployments
-                     try:
-                         champion_version: PromptVersion = mlflow.genai.load_prompt(
-                             f"prompts:/{prompt_name}@champion"
-                         )
-                         logger.debug(
-                             f"Champion alias already exists for '{prompt_name}' pointing to version {champion_version.version}"
-                         )
-                     except Exception:
+                         set_champion = True
+                     else:
                          logger.info(
-                             f"Setting 'champion' alias for existing prompt '{prompt_name}' v{existing.version}"
-                         )
-                         mlflow.genai.set_prompt_alias(
-                             name=prompt_name,
-                             alias="champion",
-                             version=existing.version,
+                             "Champion is pinned separately, preserving it",
+                             prompt_name=prompt_name,
                          )
+                         set_champion = False

-                     return existing  # Already up-to-date, return existing version
-             except Exception:
-                 logger.debug(
-                     f"Default alias for prompt '{prompt_name}' doesn't exist yet"
+                     self._register_default_template(
+                         prompt_name,
+                         prompt_model.default_template,
+                         prompt_model.description,
+                         set_champion=set_champion,
+                     )
+             except Exception as e:
+                 # No default exists yet, register it
+                 logger.trace(
+                     "No default alias found", prompt_name=prompt_name, error=str(e)
                  )
-
-             # Register new version and set as default alias
-             commit_message = description or "Auto-synced from default_template"
-             prompt_version = mlflow.genai.register_prompt(
-                 name=prompt_name,
-                 template=default_template,
-                 commit_message=commit_message,
-                 tags={"dao_ai": dao_ai_version()},
-             )
-
-             logger.debug(
-                 f"Setting default, latest, and champion aliases for prompt '{prompt_name}'"
-             )
-             mlflow.genai.set_prompt_alias(
-                 name=prompt_name,
-                 alias="default",
-                 version=prompt_version.version,
-             )
-             mlflow.genai.set_prompt_alias(
-                 name=prompt_name,
-                 alias="latest",
-                 version=prompt_version.version,
-             )
-             mlflow.genai.set_prompt_alias(
-                 name=prompt_name,
-                 alias="champion",
-                 version=prompt_version.version,
+                 logger.info(
+                     "Registering default template as default alias",
+                     prompt_name=prompt_name,
+                 )
+                 # First registration - set both default and champion
+                 self._register_default_template(
+                     prompt_name,
+                     prompt_model.default_template,
+                     prompt_model.description,
+                     set_champion=True,
+                 )
+         elif prompt_model.default_template and not prompt_model.auto_register:
+             logger.trace(
+                 "Prompt has auto_register=False, skipping registration",
+                 prompt_name=prompt_name,
              )

-             logger.info(
-                 f"Synced prompt '{prompt_name}' v{prompt_version.version} to registry with 'default', 'latest', and 'champion' aliases"
-             )
+         # 1. Try champion alias (highest priority for execution)
+         try:
+             prompt_version = load_prompt(f"prompts:/{prompt_name}@champion")
+             logger.info("Loaded prompt from champion alias", prompt_name=prompt_name)
              return prompt_version
-
          except Exception as e:
-             logger.error(f"Failed to sync '{prompt_name}' to registry: {e}")
-             raise ValueError(
-                 f"Failed to sync prompt '{prompt_name}' to registry and unable to retrieve existing version"
-             ) from e
-
-     def optimize_prompt(self, optimization: PromptOptimizationModel) -> PromptModel:
-         """
-         Optimize a prompt using MLflow's prompt optimization (MLflow 3.5+).
-
-         This uses the MLflow GenAI optimize_prompts API with GepaPromptOptimizer as documented at:
-         https://mlflow.org/docs/latest/genai/prompt-registry/optimize-prompts/
+             logger.trace(
+                 "Champion alias not found", prompt_name=prompt_name, error=str(e)
+             )

-         Args:
-             optimization: PromptOptimizationModel containing configuration
+         # 2. Try default alias (already synced above)
+         if prompt_model.default_template:
+             try:
+                 prompt_version = load_prompt(f"prompts:/{prompt_name}@default")
+                 logger.info("Loaded prompt from default alias", prompt_name=prompt_name)
+                 return prompt_version
+             except Exception as e:
+                 # Should not happen since we just registered it above, but handle anyway
+                 logger.trace(
+                     "Default alias not found", prompt_name=prompt_name, error=str(e)
+                 )

-         Returns:
-             PromptModel: The optimized prompt with new URI
-         """
-         from mlflow.genai.optimize import GepaPromptOptimizer, optimize_prompts
-         from mlflow.genai.scorers import Correctness
-
-         from dao_ai.config import AgentModel, PromptModel
-
-         logger.info(f"Optimizing prompt: {optimization.name}")
-
-         # Get agent and prompt (prompt is guaranteed to be set by validator)
-         agent_model: AgentModel = optimization.agent
-         prompt: PromptModel = optimization.prompt  # type: ignore[assignment]
-         agent_model.prompt = prompt.uri
-
-         print(f"prompt={agent_model.prompt}")
-         # Log the prompt URI scheme being used
-         # Supports three schemes:
-         # 1. Specific version: "prompts:/qa/1" (when version is specified)
-         # 2. Alias: "prompts:/qa@champion" (when alias is specified)
-         # 3. Latest: "prompts:/qa@latest" (default when neither version nor alias specified)
-         prompt_uri: str = prompt.uri
-         logger.info(f"Using prompt URI for optimization: {prompt_uri}")
-
-         # Load the specific prompt version by URI for comparison
-         # Try to load the exact version specified, but if it doesn't exist,
-         # use get_prompt to create it from default_template
-         prompt_version: PromptVersion
+         # 3. Try latest alias as final fallback
          try:
-             prompt_version = load_prompt(prompt_uri)
-             logger.info(f"Successfully loaded prompt from registry: {prompt_uri}")
+             prompt_version = load_prompt(f"prompts:/{prompt_name}@latest")
+             logger.info("Loaded prompt from latest alias", prompt_name=prompt_name)
+             return prompt_version
          except Exception as e:
+             logger.trace(
+                 "Latest alias not found", prompt_name=prompt_name, error=str(e)
+             )
+
+         # 4. Final fallback: use default_template directly if available
+         if prompt_model.default_template:
              logger.warning(
-                 f"Could not load prompt '{prompt_uri}' directly: {e}. "
-                 "Attempting to create from default_template..."
+                 "Could not load prompt from registry, using default_template directly",
+                 prompt_name=prompt_name,
              )
-             # Use get_prompt which will create from default_template if needed
-             prompt_version = self.get_prompt(prompt)
-             logger.info(
-                 f"Created/loaded prompt '{prompt.full_name}' (will optimize against this version)"
+             return PromptVersion(
+                 name=prompt_name,
+                 version=1,
+                 template=prompt_model.default_template,
+                 tags={"dao_ai": dao_ai_version()},
              )

-         # Load the evaluation dataset by name
-         logger.debug(f"Looking up dataset: {optimization.dataset}")
-         dataset: EvaluationDataset
-         if isinstance(optimization.dataset, str):
-             dataset = get_dataset(name=optimization.dataset)
-         else:
-             dataset = optimization.dataset.as_dataset()
-
-         # Set up reflection model for the optimizer
-         reflection_model_name: str
-         if optimization.reflection_model:
-             if isinstance(optimization.reflection_model, str):
-                 reflection_model_name = optimization.reflection_model
-             else:
-                 reflection_model_name = optimization.reflection_model.uri
-         else:
-             reflection_model_name = agent_model.model.uri
-         logger.debug(f"Using reflection model: {reflection_model_name}")
-
-         # Create the GepaPromptOptimizer
-         optimizer: GepaPromptOptimizer = GepaPromptOptimizer(
-             reflection_model=reflection_model_name,
-             max_metric_calls=optimization.num_candidates,
-             display_progress_bar=True,
+         raise ValueError(
+             f"Prompt '{prompt_name}' not found in registry "
+             "(tried champion, default, latest aliases) "
1493
+ "and no default_template provided"
1336
1494
  )
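The four numbered steps above give the prompt loader a strict resolution order: champion, then default, then latest, then the inline default_template, and only then an error. A condensed sketch of the same fallback, using mlflow.genai.load_prompt as the loader (the real method also logs each miss and wraps the template fallback in a PromptVersion rather than returning the bare string):

import mlflow.genai


def resolve_prompt(prompt_name: str, default_template: str | None = None):
    # Try aliases from most to least authoritative.
    for alias in ("champion", "default", "latest"):
        try:
            return mlflow.genai.load_prompt(f"prompts:/{prompt_name}@{alias}")
        except Exception:
            continue  # alias missing or registry unreachable; try the next one
    if default_template is not None:
        return default_template  # last resort: serve the inline template
    raise ValueError(f"Prompt '{prompt_name}' not found and no default_template provided")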
1337
1495
 
1338
- # Set up scorer (judge model for evaluation)
1339
- scorer_model: str
1340
- if optimization.scorer_model:
1341
- if isinstance(optimization.scorer_model, str):
1342
- scorer_model = optimization.scorer_model
1343
- else:
1344
- scorer_model = optimization.scorer_model.uri
1345
- else:
1346
- scorer_model = agent_model.model.uri # Use Databricks default
1347
- logger.debug(f"Using scorer with model: {scorer_model}")
1348
-
1349
- scorers: list[Correctness] = [Correctness(model=scorer_model)]
1350
-
1351
- # Use prompt_uri from line 1188 - already set to prompt.uri (configured URI)
1352
- # DO NOT overwrite with prompt_version.uri as that uses fallback logic
1353
- logger.debug(f"Optimizing prompt: {prompt_uri}")
1354
-
1355
- agent: ResponsesAgent = agent_model.as_responses_agent()
1356
-
1357
- # Create predict function that will be optimized
1358
- def predict_fn(**inputs: dict[str, Any]) -> str:
1359
- """Prediction function that uses the ResponsesAgent with ChatPayload.
1360
-
1361
- The agent already has the prompt referenced/applied, so we just need to
1362
- convert the ChatPayload inputs to ResponsesAgentRequest format and call predict.
1363
-
1364
- Args:
1365
- **inputs: Dictionary containing ChatPayload fields (messages/input, custom_inputs)
1366
-
1367
- Returns:
1368
- str: The agent's response content
1369
- """
1370
- from mlflow.types.responses import (
1371
- ResponsesAgentRequest,
1372
- ResponsesAgentResponse,
1373
- )
1374
- from mlflow.types.responses_helpers import Message
1375
-
1376
- from dao_ai.config import ChatPayload
1377
-
1378
- # Verify agent is accessible (should be captured from outer scope)
1379
- if agent is None:
1380
- raise RuntimeError(
1381
- "Agent object is not available in predict_fn. "
1382
- "This may indicate a serialization issue with the ResponsesAgent."
1383
- )
1384
-
1385
- # Convert inputs to ChatPayload
1386
- chat_payload: ChatPayload = ChatPayload(**inputs)
1387
-
1388
- # Convert ChatPayload messages to MLflow Message format
1389
- mlflow_messages: list[Message] = [
1390
- Message(role=msg.role, content=msg.content)
1391
- for msg in chat_payload.messages
1392
- ]
1393
-
1394
- # Create ResponsesAgentRequest
1395
- request: ResponsesAgentRequest = ResponsesAgentRequest(
1396
- input=mlflow_messages,
1397
- custom_inputs=chat_payload.custom_inputs,
1398
- )
1496
+ def _register_default_template(
1497
+ self,
1498
+ prompt_name: str,
1499
+ default_template: str,
1500
+ description: str | None = None,
1501
+ set_champion: bool = True,
1502
+ ) -> PromptVersion:
1503
+ """Register default_template as a new prompt version.
1399
1504
 
1400
- # Call the ResponsesAgent's predict method
1401
- response: ResponsesAgentResponse = agent.predict(request)
1402
-
1403
- if response.output and len(response.output) > 0:
1404
- content = response.output[0].content
1405
- logger.debug(f"Response content type: {type(content)}")
1406
- logger.debug(f"Response content: {content}")
1407
-
1408
- # Extract text from content using same logic as LanggraphResponsesAgent._extract_text_from_content
1409
- # Content can be:
1410
- # - A string (return as is)
1411
- # - A list of items with 'text' keys (extract and join)
1412
- # - Other types (try to get 'text' attribute or convert to string)
1413
- if isinstance(content, str):
1414
- return content
1415
- elif isinstance(content, list):
1416
- text_parts = []
1417
- for content_item in content:
1418
- if isinstance(content_item, str):
1419
- text_parts.append(content_item)
1420
- elif isinstance(content_item, dict) and "text" in content_item:
1421
- text_parts.append(content_item["text"])
1422
- elif hasattr(content_item, "text"):
1423
- text_parts.append(content_item.text)
1424
- return "".join(text_parts) if text_parts else str(content)
1425
- else:
1426
- # Fallback for unknown types - try to extract text attribute
1427
- return getattr(content, "text", str(content))
1428
- else:
1429
- return ""
1505
+ Registers the template and sets the 'default' alias.
1506
+ Optionally sets the 'champion' alias when set_champion is True.
1430
1507
 
1431
- # Set registry URI for Databricks Unity Catalog
1432
- mlflow.set_registry_uri("databricks-uc")
1508
+ Args:
1509
+ prompt_name: Full name of the prompt
1510
+ default_template: The template content
1511
+ description: Optional description for commit message
1512
+ set_champion: Whether to also set champion alias (default: True)
1433
1513
 
1434
- # Run optimization with tracking disabled to prevent auto-registering all candidates
1435
- logger.info("Running prompt optimization with GepaPromptOptimizer...")
1514
+ If registration fails (e.g., in Model Serving with restricted permissions),
1515
+ logs the error and falls back to an unregistered in-memory PromptVersion built from the template.
1516
+ """
1436
1517
  logger.info(
1437
- f"Generating {optimization.num_candidates} candidate prompts for evaluation"
1438
- )
1439
-
1440
- from mlflow.genai.optimize.types import (
1441
- PromptOptimizationResult,
1442
- )
1443
-
1444
- result: PromptOptimizationResult = optimize_prompts(
1445
- predict_fn=predict_fn,
1446
- train_data=dataset,
1447
- prompt_uris=[prompt_uri], # Use the configured URI (version/alias/latest)
1448
- optimizer=optimizer,
1449
- scorers=scorers,
1450
- enable_tracking=False, # Don't auto-register all candidates
1518
+ "Registering default template",
1519
+ prompt_name=prompt_name,
1520
+ set_champion=set_champion,
1451
1521
  )
1452
1522
 
1453
- # Log the optimization results
1454
- logger.info("Optimization complete!")
1455
- logger.info(f"Optimizer used: {result.optimizer_name}")
1456
-
1457
- if result.optimized_prompts:
1458
- optimized_prompt_version: PromptVersion = result.optimized_prompts[0]
1459
-
1460
- # Check if the optimized prompt is actually different from the original
1461
- original_template: str = prompt_version.to_single_brace_format().strip()
1462
- optimized_template: str = (
1463
- optimized_prompt_version.to_single_brace_format().strip()
1464
- )
1465
-
1466
- # Normalize whitespace for more robust comparison
1467
- original_normalized: str = re.sub(r"\s+", " ", original_template).strip()
1468
- optimized_normalized: str = re.sub(r"\s+", " ", optimized_template).strip()
1469
-
1470
- logger.debug(f"Original template length: {len(original_template)} chars")
1471
- logger.debug(f"Optimized template length: {len(optimized_template)} chars")
1472
- logger.debug(
1473
- f"Templates identical: {original_normalized == optimized_normalized}"
1474
- )
1475
-
1476
- if original_normalized == optimized_normalized:
1477
- logger.info(
1478
- f"Optimized prompt is identical to original for '{prompt.full_name}'. "
1479
- "No new version will be registered."
1480
- )
1481
- return prompt
1482
-
1483
- logger.info("Optimized prompt is DIFFERENT from original")
1484
- logger.info(
1485
- f"Original length: {len(original_template)}, Optimized length: {len(optimized_template)}"
1486
- )
1487
- logger.debug(
1488
- f"Original template (first 300 chars): {original_template[:300]}..."
1489
- )
1490
- logger.debug(
1491
- f"Optimized template (first 300 chars): {optimized_template[:300]}..."
1523
+ try:
1524
+ commit_message = description or "Auto-synced from default_template"
1525
+ prompt_version = mlflow.genai.register_prompt(
1526
+ name=prompt_name,
1527
+ template=default_template,
1528
+ commit_message=commit_message,
1529
+ tags={"dao_ai": dao_ai_version()},
1492
1530
  )
1493
1531
 
1494
- # Check evaluation scores to determine if we should register the optimized prompt
1495
- should_register: bool = False
1496
- has_improvement: bool = False
1497
-
1498
- if (
1499
- result.initial_eval_score is not None
1500
- and result.final_eval_score is not None
1501
- ):
1502
- logger.info("Evaluation scores:")
1503
- logger.info(f" Initial score: {result.initial_eval_score}")
1504
- logger.info(f" Final score: {result.final_eval_score}")
1505
-
1506
- # Only register if there's improvement
1507
- if result.final_eval_score > result.initial_eval_score:
1508
- improvement: float = (
1509
- (result.final_eval_score - result.initial_eval_score)
1510
- / result.initial_eval_score
1511
- ) * 100
1512
- logger.info(
1513
- f"Optimized prompt improved by {improvement:.2f}% "
1514
- f"(initial: {result.initial_eval_score}, final: {result.final_eval_score})"
1515
- )
1516
- should_register = True
1517
- has_improvement = True
1518
- else:
1519
- logger.info(
1520
- f"Optimized prompt (score: {result.final_eval_score}) did NOT improve over baseline (score: {result.initial_eval_score}). "
1521
- "No new version will be registered."
1522
- )
1523
- else:
1524
- # No scores available - register anyway but warn
1525
- logger.warning(
1526
- "No evaluation scores available to compare performance. "
1527
- "Registering optimized prompt without performance validation."
1528
- )
1529
- should_register = True
1530
-
1531
- if not should_register:
1532
- logger.info(
1533
- f"Skipping registration for '{prompt.full_name}' (no improvement)"
1534
- )
1535
- return prompt
1536
-
1537
- # Register the optimized prompt manually
1532
+ # Always set default alias
1538
1533
  try:
1539
- logger.info(f"Registering optimized prompt '{prompt.full_name}'")
1540
- registered_version: PromptVersion = mlflow.genai.register_prompt(
1541
- name=prompt.full_name,
1542
- template=optimized_template,
1543
- commit_message=f"Optimized for {agent_model.model.uri} using GepaPromptOptimizer",
1544
- tags={
1545
- "dao_ai": dao_ai_version(),
1546
- "target_model": agent_model.model.uri,
1547
- },
1548
- )
1549
- logger.info(
1550
- f"Registered optimized prompt as version {registered_version.version}"
1551
- )
1552
-
1553
- # Always set "latest" alias (represents most recently registered prompt)
1554
- logger.info(
1555
- f"Setting 'latest' alias for optimized prompt '{prompt.full_name}' version {registered_version.version}"
1534
+ logger.debug(
1535
+ "Setting default alias",
1536
+ prompt_name=prompt_name,
1537
+ version=prompt_version.version,
1556
1538
  )
1557
1539
  mlflow.genai.set_prompt_alias(
1558
- name=prompt.full_name,
1559
- alias="latest",
1560
- version=registered_version.version,
1540
+ name=prompt_name, alias="default", version=prompt_version.version
1561
1541
  )
1562
- logger.info(
1563
- f"Successfully set 'latest' alias for '{prompt.full_name}' v{registered_version.version}"
1542
+ logger.success(
1543
+ "Set default alias for prompt",
1544
+ prompt_name=prompt_name,
1545
+ version=prompt_version.version,
1546
+ )
1547
+ except Exception as alias_error:
1548
+ logger.warning(
1549
+ "Could not set default alias",
1550
+ prompt_name=prompt_name,
1551
+ error=str(alias_error),
1564
1552
  )
1565
1553
 
1566
- # If there's confirmed improvement, also set the "champion" alias
1567
- # (represents the prompt that should be used by deployed agents)
1568
- if has_improvement:
1569
- logger.info(
1570
- f"Setting 'champion' alias for improved prompt '{prompt.full_name}' version {registered_version.version}"
1571
- )
1554
+ # Optionally set champion alias (controlled by the set_champion flag; enabled on first registration)
1555
+ if set_champion:
1556
+ try:
1572
1557
  mlflow.genai.set_prompt_alias(
1573
- name=prompt.full_name,
1558
+ name=prompt_name,
1574
1559
  alias="champion",
1575
- version=registered_version.version,
1560
+ version=prompt_version.version,
1576
1561
  )
1577
- logger.info(
1578
- f"Successfully set 'champion' alias for '{prompt.full_name}' v{registered_version.version}"
1562
+ logger.success(
1563
+ "Set champion alias for prompt",
1564
+ prompt_name=prompt_name,
1565
+ version=prompt_version.version,
1566
+ )
1567
+ except Exception as alias_error:
1568
+ logger.warning(
1569
+ "Could not set champion alias",
1570
+ prompt_name=prompt_name,
1571
+ error=str(alias_error),
1579
1572
  )
1580
1573
 
1581
- # Add target_model and dao_ai tags
1582
- tags: dict[str, Any] = prompt.tags.copy() if prompt.tags else {}
1583
- tags["target_model"] = agent_model.model.uri
1584
- tags["dao_ai"] = dao_ai_version()
1585
-
1586
- # Return the optimized prompt with the appropriate alias
1587
- # Use "champion" if there was improvement, otherwise "latest"
1588
- result_alias: str = "champion" if has_improvement else "latest"
1589
- return PromptModel(
1590
- name=prompt.name,
1591
- schema=prompt.schema_model,
1592
- description=f"Optimized version of {prompt.name} for {agent_model.model.uri}",
1593
- alias=result_alias,
1594
- tags=tags,
1595
- )
1574
+ return prompt_version
1596
1575
 
1597
- except Exception as e:
1598
- logger.error(
1599
- f"Failed to register optimized prompt '{prompt.full_name}': {e}"
1600
- )
1601
- return prompt
1602
- else:
1603
- logger.warning("No optimized prompts returned from optimization")
1604
- return prompt
1576
+ except Exception as reg_error:
1577
+ logger.error(
1578
+ "Failed to register prompt - please register from notebook with write permissions",
1579
+ prompt_name=prompt_name,
1580
+ error=str(reg_error),
1581
+ )
1582
+ return PromptVersion(
1583
+ name=prompt_name,
1584
+ version=1,
1585
+ template=default_template,
1586
+ tags={"dao_ai": dao_ai_version()},
1587
+ )
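For reference, the happy path of _register_default_template reduces to one register call plus two alias updates. A standalone sketch using only the mlflow.genai calls that appear above, with a hypothetical Unity Catalog prompt name and the logging and fallback handling omitted:

import mlflow
import mlflow.genai

mlflow.set_registry_uri("databricks-uc")  # Unity Catalog prompt registry

# Register the template as a new prompt version (versions auto-increment).
version = mlflow.genai.register_prompt(
    name="main.default.qa_prompt",  # hypothetical catalog.schema.prompt name
    template="Answer the question: {question}",
    commit_message="Auto-synced from default_template",
)

# Point 'default' (and, on first registration, 'champion') at that version.
for alias in ("default", "champion"):
    mlflow.genai.set_prompt_alias(
        name="main.default.qa_prompt", alias=alias, version=version.version
    )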