dao-ai 0.0.28__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (63)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +2 -5
  3. dao_ai/cli.py +245 -40
  4. dao_ai/config.py +1491 -370
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -253
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +125 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/core.py +67 -0
  25. dao_ai/middleware/guardrails.py +420 -0
  26. dao_ai/middleware/human_in_the_loop.py +232 -0
  27. dao_ai/middleware/message_validation.py +586 -0
  28. dao_ai/middleware/summarization.py +197 -0
  29. dao_ai/models.py +1306 -114
  30. dao_ai/nodes.py +245 -159
  31. dao_ai/optimization.py +674 -0
  32. dao_ai/orchestration/__init__.py +52 -0
  33. dao_ai/orchestration/core.py +294 -0
  34. dao_ai/orchestration/supervisor.py +278 -0
  35. dao_ai/orchestration/swarm.py +271 -0
  36. dao_ai/prompts.py +128 -31
  37. dao_ai/providers/databricks.py +573 -601
  38. dao_ai/state.py +157 -21
  39. dao_ai/tools/__init__.py +13 -5
  40. dao_ai/tools/agent.py +1 -3
  41. dao_ai/tools/core.py +64 -11
  42. dao_ai/tools/email.py +232 -0
  43. dao_ai/tools/genie.py +144 -294
  44. dao_ai/tools/mcp.py +223 -155
  45. dao_ai/tools/memory.py +50 -0
  46. dao_ai/tools/python.py +9 -14
  47. dao_ai/tools/search.py +14 -0
  48. dao_ai/tools/slack.py +22 -10
  49. dao_ai/tools/sql.py +202 -0
  50. dao_ai/tools/time.py +30 -7
  51. dao_ai/tools/unity_catalog.py +165 -88
  52. dao_ai/tools/vector_search.py +331 -221
  53. dao_ai/utils.py +166 -20
  54. dao_ai-0.1.2.dist-info/METADATA +455 -0
  55. dao_ai-0.1.2.dist-info/RECORD +64 -0
  56. dao_ai/chat_models.py +0 -204
  57. dao_ai/guardrails.py +0 -112
  58. dao_ai/tools/human_in_the_loop.py +0 -100
  59. dao_ai-0.0.28.dist-info/METADATA +0 -1168
  60. dao_ai-0.0.28.dist-info/RECORD +0 -41
  61. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +0 -0
  62. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
  63. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
--- a/dao_ai/providers/databricks.py
+++ b/dao_ai/providers/databricks.py
@@ -1,5 +1,4 @@
 import base64
-import re
 import uuid
 from pathlib import Path
 from typing import Any, Callable, Final, Sequence
@@ -32,14 +31,12 @@ from mlflow import MlflowClient
 from mlflow.entities import Experiment
 from mlflow.entities.model_registry import PromptVersion
 from mlflow.entities.model_registry.model_version import ModelVersion
-from mlflow.genai.datasets import EvaluationDataset, get_dataset
 from mlflow.genai.prompts import load_prompt
 from mlflow.models.auth_policy import AuthPolicy, SystemAuthPolicy, UserAuthPolicy
 from mlflow.models.model import ModelInfo
 from mlflow.models.resources import (
     DatabricksResource,
 )
-from mlflow.pyfunc import ResponsesAgent
 from pyspark.sql import SparkSession
 from unitycatalog.ai.core.base import FunctionExecutionResult
 from unitycatalog.ai.core.databricks import DatabricksFunctionClient
@@ -49,6 +46,7 @@ from dao_ai.config import (
     AppConfig,
     ConnectionModel,
     DatabaseModel,
+    DatabricksAppModel,
     DatasetModel,
     FunctionModel,
     GenieRoomModel,
@@ -57,7 +55,6 @@ from dao_ai.config import (
     IsDatabricksResource,
     LLMModel,
     PromptModel,
-    PromptOptimizationModel,
     SchemaModel,
     TableModel,
     UnityCatalogFunctionSqlModel,
@@ -73,6 +70,7 @@ from dao_ai.utils import (
     get_installed_packages,
     is_installed,
     is_lib_provided,
+    normalize_host,
     normalize_name,
 )
 from dao_ai.vector_search import endpoint_exists, index_exists
@@ -94,15 +92,18 @@ def _workspace_client(
     Create a WorkspaceClient instance with the provided parameters.
     If no parameters are provided, it will use the default configuration.
     """
-    if client_id and client_secret and workspace_host:
+    # Normalize the workspace host to ensure it has https:// scheme
+    normalized_host = normalize_host(workspace_host)
+
+    if client_id and client_secret and normalized_host:
         return WorkspaceClient(
-            host=workspace_host,
+            host=normalized_host,
             client_id=client_id,
             client_secret=client_secret,
             auth_type="oauth-m2m",
         )
     elif pat:
-        return WorkspaceClient(host=workspace_host, token=pat, auth_type="pat")
+        return WorkspaceClient(host=normalized_host, token=pat, auth_type="pat")
     else:
         return WorkspaceClient()
 
@@ -117,15 +118,18 @@ def _vector_search_client(
     Create a VectorSearchClient instance with the provided parameters.
     If no parameters are provided, it will use the default configuration.
     """
-    if client_id and client_secret and workspace_host:
+    # Normalize the workspace host to ensure it has https:// scheme
+    normalized_host = normalize_host(workspace_host)
+
+    if client_id and client_secret and normalized_host:
         return VectorSearchClient(
-            workspace_url=workspace_host,
+            workspace_url=normalized_host,
             service_principal_client_id=client_id,
             service_principal_client_secret=client_secret,
         )
-    elif pat and workspace_host:
+    elif pat and normalized_host:
         return VectorSearchClient(
-            workspace_url=workspace_host,
+            workspace_url=normalized_host,
             personal_access_token=pat,
         )
     else:
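Note: both client factories above now route the host through normalize_host, newly imported from dao_ai.utils (whose own diff is not shown on this page). As a rough sketch only — the real helper may differ — a function satisfying the "ensure it has https:// scheme" comment could look like:

    def normalize_host(host: str | None) -> str | None:
        # Hypothetical sketch; the actual implementation lives in dao_ai/utils.py.
        if not host:
            return host  # preserve None/empty so callers fall through to defaults
        host = host.strip().rstrip("/")
        if not host.startswith(("http://", "https://")):
            host = f"https://{host}"  # bare hostnames gain an https:// scheme
        return host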
@@ -177,15 +181,17 @@ class DatabricksProvider(ServiceProvider):
         experiment: Experiment | None = mlflow.get_experiment_by_name(experiment_name)
         if experiment is None:
             experiment_id: str = mlflow.create_experiment(name=experiment_name)
-            logger.info(
-                f"Created new experiment: {experiment_name} (ID: {experiment_id})"
+            logger.success(
+                "Created new MLflow experiment",
+                experiment_name=experiment_name,
+                experiment_id=experiment_id,
             )
             experiment = mlflow.get_experiment(experiment_id)
         return experiment
 
     def create_token(self) -> str:
         current_user: User = self.w.current_user.me()
-        logger.debug(f"Authenticated to Databricks as {current_user}")
+        logger.debug("Authenticated to Databricks", user=str(current_user))
         headers: dict[str, str] = self.w.config.authenticate()
         token: str = headers["Authorization"].replace("Bearer ", "")
         return token
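Note: from here on, nearly every log call in this file moves from f-string interpolation to a constant message plus keyword arguments, and logger.success/logger.trace levels appear alongside the new dao_ai/logging.py module (file 15 above). This is consistent with loguru — an assumption, since the logger setup itself is not part of this page:

    from loguru import logger  # assumed backend; see dao_ai/logging.py

    # Keyword arguments not consumed by the message template are attached to the
    # record's "extra" dict, so sinks can render them as structured fields.
    logger.debug("Authenticated to Databricks", user="someone@example.com")

    # SUCCESS (level 25) and TRACE (level 5) are built-in loguru levels.
    logger.success("Model registered", model_name="main.agents.dao_agent", version=3)
    logger.trace("Retrieved secret", secret_key="api_key", secret_scope="dao")

The keyword values above are illustrative only.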
@@ -197,17 +203,24 @@ class DatabricksProvider(ServiceProvider):
             secret_response: GetSecretResponse = self.w.secrets.get_secret(
                 secret_scope, secret_key
             )
-            logger.debug(f"Retrieved secret {secret_key} from scope {secret_scope}")
+            logger.trace(
+                "Retrieved secret", secret_key=secret_key, secret_scope=secret_scope
+            )
             encoded_secret: str = secret_response.value
             decoded_secret: str = base64.b64decode(encoded_secret).decode("utf-8")
             return decoded_secret
         except NotFound:
             logger.warning(
-                f"Secret {secret_key} not found in scope {secret_scope}, using default value"
+                "Secret not found, using default value",
+                secret_key=secret_key,
+                secret_scope=secret_scope,
             )
         except Exception as e:
             logger.error(
-                f"Error retrieving secret {secret_key} from scope {secret_scope}: {e}"
+                "Error retrieving secret",
+                secret_key=secret_key,
+                secret_scope=secret_scope,
+                error=str(e),
             )
 
         return default_value
@@ -216,9 +229,18 @@ class DatabricksProvider(ServiceProvider):
         self,
         config: AppConfig,
     ) -> ModelInfo:
-        logger.debug("Creating agent...")
+        logger.info("Creating agent")
         mlflow.set_registry_uri("databricks-uc")
 
+        # Set up experiment for proper tracking
+        experiment: Experiment = self.get_or_create_experiment(config)
+        mlflow.set_experiment(experiment_id=experiment.experiment_id)
+        logger.debug(
+            "Using MLflow experiment",
+            experiment_name=experiment.name,
+            experiment_id=experiment.experiment_id,
+        )
+
         llms: Sequence[LLMModel] = list(config.resources.llms.values())
         vector_indexes: Sequence[IndexModel] = list(
             config.resources.vector_stores.values()
@@ -236,6 +258,7 @@ class DatabricksProvider(ServiceProvider):
         )
         databases: Sequence[DatabaseModel] = list(config.resources.databases.values())
         volumes: Sequence[VolumeModel] = list(config.resources.volumes.values())
+        apps: Sequence[DatabricksAppModel] = list(config.resources.apps.values())
 
         resources: Sequence[IsDatabricksResource] = (
             llms
@@ -247,6 +270,7 @@ class DatabricksProvider(ServiceProvider):
             + connections
             + databases
             + volumes
+            + apps
         )
 
         # Flatten all resources from all models into a single list
@@ -260,12 +284,16 @@ class DatabricksProvider(ServiceProvider):
             for resource in r.as_resources()
             if not r.on_behalf_of_user
         ]
-        logger.debug(f"system_resources: {[r.name for r in system_resources]}")
+        logger.trace(
+            "System resources identified",
+            count=len(system_resources),
+            resources=[r.name for r in system_resources],
+        )
 
         system_auth_policy: SystemAuthPolicy = SystemAuthPolicy(
             resources=system_resources
         )
-        logger.debug(f"system_auth_policy: {system_auth_policy}")
+        logger.trace("System auth policy created", policy=str(system_auth_policy))
 
         api_scopes: Sequence[str] = list(
             set(
@@ -277,15 +305,19 @@ class DatabricksProvider(ServiceProvider):
                 ]
             )
         )
-        logger.debug(f"api_scopes: {api_scopes}")
+        logger.trace("API scopes identified", scopes=api_scopes)
 
         user_auth_policy: UserAuthPolicy = UserAuthPolicy(api_scopes=api_scopes)
-        logger.debug(f"user_auth_policy: {user_auth_policy}")
+        logger.trace("User auth policy created", policy=str(user_auth_policy))
 
         auth_policy: AuthPolicy = AuthPolicy(
             system_auth_policy=system_auth_policy, user_auth_policy=user_auth_policy
         )
-        logger.debug(f"auth_policy: {auth_policy}")
+        logger.debug(
+            "Auth policy created",
+            has_system_auth=system_auth_policy is not None,
+            has_user_auth=user_auth_policy is not None,
+        )
 
         code_paths: list[str] = config.app.code_paths
         for path in code_paths:
@@ -312,18 +344,42 @@ class DatabricksProvider(ServiceProvider):
 
         pip_requirements += get_installed_packages()
 
-        logger.debug(f"pip_requirements: {pip_requirements}")
-        logger.debug(f"code_paths: {code_paths}")
+        logger.trace("Pip requirements prepared", count=len(pip_requirements))
+        logger.trace("Code paths prepared", count=len(code_paths))
 
         run_name: str = normalize_name(config.app.name)
-        logger.debug(f"run_name: {run_name}")
-        logger.debug(f"model_path: {model_path.as_posix()}")
+        logger.debug(
+            "Agent run configuration",
+            run_name=run_name,
+            model_path=model_path.as_posix(),
+        )
 
         input_example: dict[str, Any] = None
         if config.app.input_example:
             input_example = config.app.input_example.model_dump()
 
-        logger.debug(f"input_example: {input_example}")
+        logger.trace("Input example configured", has_example=input_example is not None)
+
+        # Create conda environment with configured Python version
+        # This allows deploying from environments with different Python versions
+        # (e.g., Databricks Apps with Python 3.11 can deploy to Model Serving with 3.12)
+        target_python_version: str = config.app.python_version
+        logger.debug("Target Python version configured", version=target_python_version)
+
+        conda_env: dict[str, Any] = {
+            "name": "mlflow-env",
+            "channels": ["conda-forge"],
+            "dependencies": [
+                f"python={target_python_version}",
+                "pip",
+                {"pip": list(pip_requirements)},
+            ],
+        }
+        logger.trace(
+            "Conda environment configured",
+            python_version=target_python_version,
+            pip_packages_count=len(pip_requirements),
+        )
 
         with mlflow.start_run(run_name=run_name):
             mlflow.set_tag("type", "agent")
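Note: swapping pip_requirements= for conda_env= (next hunk) lets the logged model pin the Python interpreter itself, not just pip packages. For reference, a dict shaped like the one above serializes to the conda.yaml MLflow stores with the model; the version and pins below are illustrative, not taken from this diff:

    import yaml  # pyyaml, for illustration only

    conda_env = {
        "name": "mlflow-env",
        "channels": ["conda-forge"],
        "dependencies": ["python=3.12", "pip", {"pip": ["mlflow", "dao-ai==0.1.2"]}],
    }
    print(yaml.safe_dump(conda_env, sort_keys=False))
    # name: mlflow-env
    # channels:
    # - conda-forge
    # dependencies:
    # - python=3.12
    # - pip
    # - pip:
    #   - mlflow
    #   - dao-ai==0.1.2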
@@ -333,7 +389,7 @@ class DatabricksProvider(ServiceProvider):
                 code_paths=code_paths,
                 model_config=config.model_dump(mode="json", by_alias=True),
                 name="agent",
-                pip_requirements=pip_requirements,
+                conda_env=conda_env,
                 input_example=input_example,
                 # resources=all_resources,
                 auth_policy=auth_policy,
@@ -344,8 +400,10 @@ class DatabricksProvider(ServiceProvider):
             model_version: ModelVersion = mlflow.register_model(
                 name=registered_model_name, model_uri=logged_agent_info.model_uri
             )
-            logger.debug(
-                f"Registered model: {registered_model_name} with version: {model_version.version}"
+            logger.success(
+                "Model registered",
+                model_name=registered_model_name,
+                version=model_version.version,
             )
 
             client: MlflowClient = MlflowClient()
@@ -357,7 +415,7 @@ class DatabricksProvider(ServiceProvider):
                 key="dao_ai",
                 value=dao_ai_version(),
             )
-            logger.debug(f"Set dao_ai tag on model version {model_version.version}")
+            logger.trace("Set dao_ai tag on model version", version=model_version.version)
 
             client.set_registered_model_alias(
                 name=registered_model_name,
@@ -374,12 +432,15 @@ class DatabricksProvider(ServiceProvider):
             aliased_model: ModelVersion = client.get_model_version_by_alias(
                 registered_model_name, config.app.alias
             )
-            logger.debug(
-                f"Model {registered_model_name} aliased to {config.app.alias} with version: {aliased_model.version}"
+            logger.info(
+                "Model aliased",
+                model_name=registered_model_name,
+                alias=config.app.alias,
+                version=aliased_model.version,
             )
 
     def deploy_agent(self, config: AppConfig) -> None:
-        logger.debug("Deploying agent...")
+        logger.info("Deploying agent", endpoint_name=config.app.endpoint_name)
         mlflow.set_registry_uri("databricks-uc")
 
         endpoint_name: str = config.app.endpoint_name
@@ -400,12 +461,10 @@ class DatabricksProvider(ServiceProvider):
             agents.get_deployments(endpoint_name)
             endpoint_exists = True
             logger.debug(
-                f"Endpoint {endpoint_name} already exists, updating without tags to avoid conflicts..."
+                "Endpoint already exists, updating", endpoint_name=endpoint_name
             )
         except Exception:
-            logger.debug(
-                f"Endpoint {endpoint_name} doesn't exist, creating new with tags..."
-            )
+            logger.debug("Creating new endpoint", endpoint_name=endpoint_name)
 
         # Deploy - skip tags for existing endpoints to avoid conflicts
         agents.deploy(
@@ -421,8 +480,11 @@ class DatabricksProvider(ServiceProvider):
         registered_model_name: str = config.app.registered_model.full_name
         permissions: Sequence[dict[str, Any]] = config.app.permissions
 
-        logger.debug(registered_model_name)
-        logger.debug(permissions)
+        logger.debug(
+            "Configuring model permissions",
+            model_name=registered_model_name,
+            permissions_count=len(permissions),
+        )
 
         for permission in permissions:
             principals: Sequence[str] = permission.principals
@@ -442,7 +504,7 @@ class DatabricksProvider(ServiceProvider):
         try:
             catalog_info = self.w.catalogs.get(name=schema.catalog_name)
         except NotFound:
-            logger.debug(f"Creating catalog: {schema.catalog_name}")
+            logger.info("Creating catalog", catalog_name=schema.catalog_name)
             catalog_info = self.w.catalogs.create(name=schema.catalog_name)
         return catalog_info
 
@@ -452,7 +514,7 @@ class DatabricksProvider(ServiceProvider):
         try:
             schema_info = self.w.schemas.get(full_name=schema.full_name)
         except NotFound:
-            logger.debug(f"Creating schema: {schema.full_name}")
+            logger.info("Creating schema", schema_name=schema.full_name)
             schema_info = self.w.schemas.create(
                 name=schema.schema_name, catalog_name=catalog_info.name
             )
@@ -464,7 +526,7 @@ class DatabricksProvider(ServiceProvider):
         try:
             volume_info = self.w.volumes.read(name=volume.full_name)
         except NotFound:
-            logger.debug(f"Creating volume: {volume.full_name}")
+            logger.info("Creating volume", volume_name=volume.full_name)
             volume_info = self.w.volumes.create(
                 catalog_name=schema_info.catalog_name,
                 schema_name=schema_info.name,
@@ -475,7 +537,7 @@ class DatabricksProvider(ServiceProvider):
 
     def create_path(self, volume_path: VolumePathModel) -> Path:
         path: Path = volume_path.full_name
-        logger.info(f"Creating volume path: {path}")
+        logger.info("Creating volume path", path=str(path))
         self.w.files.create_directory(path)
         return path
 
@@ -516,11 +578,12 @@ class DatabricksProvider(ServiceProvider):
 
         if ddl:
             ddl_path: Path = Path(ddl)
-            logger.debug(f"Executing DDL from: {ddl_path}")
+            logger.debug("Executing DDL", ddl_path=str(ddl_path))
             statements: Sequence[str] = sqlparse.parse(ddl_path.read_text())
             for statement in statements:
-                logger.debug(statement)
-                logger.debug(f"args: {args}")
+                logger.trace(
+                    "Executing DDL statement", statement=str(statement)[:100], args=args
+                )
                 spark.sql(
                     str(statement),
                     args=args,
@@ -529,20 +592,23 @@ class DatabricksProvider(ServiceProvider):
         if data:
             data_path: Path = Path(data)
             if format == "sql":
-                logger.debug(f"Executing SQL from: {data_path}")
+                logger.debug("Executing SQL from file", data_path=str(data_path))
                 data_statements: Sequence[str] = sqlparse.parse(data_path.read_text())
                 for statement in data_statements:
-                    logger.debug(statement)
-                    logger.debug(f"args: {args}")
+                    logger.trace(
+                        "Executing SQL statement",
+                        statement=str(statement)[:100],
+                        args=args,
+                    )
                     spark.sql(
                         str(statement),
                         args=args,
                     )
             else:
-                logger.debug(f"Writing to: {table}")
+                logger.debug("Writing dataset to table", table=table)
                 if not data_path.is_absolute():
                     data_path = current_dir / data_path
-                logger.debug(f"Data path: {data_path.as_posix()}")
+                logger.trace("Data path resolved", path=data_path.as_posix())
                 if format == "excel":
                     pdf = pd.read_excel(data_path.as_posix())
                     df = spark.createDataFrame(pdf, schema=dataset.table_schema)
@@ -566,13 +632,17 @@ class DatabricksProvider(ServiceProvider):
                 verbose=True,
             )
 
-        logger.debug(f"Endpoint named {vector_store.endpoint.name} is ready.")
+        logger.success(
+            "Vector search endpoint ready", endpoint_name=vector_store.endpoint.name
+        )
 
         if not index_exists(
             self.vsc, vector_store.endpoint.name, vector_store.index.full_name
         ):
-            logger.debug(
-                f"Creating index {vector_store.index.full_name} on endpoint {vector_store.endpoint.name}..."
+            logger.info(
+                "Creating vector search index",
+                index_name=vector_store.index.full_name,
+                endpoint_name=vector_store.endpoint.name,
             )
             self.vsc.create_delta_sync_index_and_wait(
                 endpoint_name=vector_store.endpoint.name,
@@ -586,7 +656,8 @@ class DatabricksProvider(ServiceProvider):
             )
         else:
             logger.debug(
-                f"Index {vector_store.index.full_name} already exists, checking status and syncing..."
+                "Vector search index already exists, checking status",
+                index_name=vector_store.index.full_name,
             )
             index = self.vsc.get_index(
                 vector_store.endpoint.name, vector_store.index.full_name
@@ -609,54 +680,61 @@ class DatabricksProvider(ServiceProvider):
 
                     if pipeline_status in [
                         "COMPLETED",
+                        "ONLINE",
                         "FAILED",
                         "CANCELED",
                         "ONLINE_PIPELINE_FAILED",
                     ]:
-                        logger.debug(
-                            f"Index is ready to sync (status: {pipeline_status})"
-                        )
+                        logger.debug("Index ready to sync", status=pipeline_status)
                         break
                     elif pipeline_status in [
                         "WAITING_FOR_RESOURCES",
                         "PROVISIONING",
                         "INITIALIZING",
                         "INDEXING",
-                        "ONLINE",
                     ]:
-                        logger.debug(
-                            f"Index not ready yet (status: {pipeline_status}), waiting {wait_interval} seconds..."
+                        logger.trace(
+                            "Index not ready, waiting",
+                            status=pipeline_status,
+                            wait_seconds=wait_interval,
                         )
                         time.sleep(wait_interval)
                         elapsed += wait_interval
                     else:
                         logger.warning(
-                            f"Unknown pipeline status: {pipeline_status}, attempting sync anyway"
+                            "Unknown pipeline status, attempting sync",
+                            status=pipeline_status,
                         )
                         break
                 except Exception as status_error:
                     logger.warning(
-                        f"Could not check index status: {status_error}, attempting sync anyway"
+                        "Could not check index status, attempting sync",
+                        error=str(status_error),
                     )
                     break
 
            if elapsed >= max_wait_time:
                logger.warning(
-                    f"Timed out waiting for index to be ready after {max_wait_time} seconds"
+                    "Timed out waiting for index to be ready",
+                    max_wait_seconds=max_wait_time,
                )
 
            # Now attempt to sync
            try:
                index.sync()
-                logger.debug("Index sync completed successfully")
+                logger.success("Index sync completed")
            except Exception as sync_error:
                if "not ready to sync yet" in str(sync_error).lower():
-                    logger.warning(f"Index still not ready to sync: {sync_error}")
+                    logger.warning(
+                        "Index still not ready to sync", error=str(sync_error)
+                    )
                else:
                    raise sync_error
 
-        logger.debug(
-            f"index {vector_store.index.full_name} on table {vector_store.source_table.full_name} is ready"
+        logger.success(
+            "Vector search index ready",
+            index_name=vector_store.index.full_name,
+            source_table=vector_store.source_table.full_name,
        )
 
     def get_vector_index(self, vector_store: VectorStoreModel) -> None:
@@ -692,12 +770,16 @@ class DatabricksProvider(ServiceProvider):
         # sql = sql.replace("{catalog_name}", schema.catalog_name)
         # sql = sql.replace("{schema_name}", schema.schema_name)
 
-        logger.info(function.name)
-        logger.info(sql)
+        logger.info("Creating SQL function", function_name=function.name)
+        logger.trace("SQL function body", sql=sql[:200])
         _: FunctionInfo = self.dfs.create_function(sql_function_body=sql)
 
         if unity_catalog_function.test:
-            logger.info(unity_catalog_function.test.parameters)
+            logger.debug(
+                "Testing function",
+                function_name=function.full_name,
+                parameters=unity_catalog_function.test.parameters,
+            )
 
             result: FunctionExecutionResult = self.dfs.execute_function(
                 function_name=function.full_name,
@@ -705,37 +787,50 @@ class DatabricksProvider(ServiceProvider):
             )
 
             if result.error:
-                logger.error(result.error)
+                logger.error(
+                    "Function test failed",
+                    function_name=function.full_name,
+                    error=result.error,
+                )
             else:
-                logger.info(f"Function {function.full_name} executed successfully.")
-                logger.info(f"Result: {result}")
+                logger.success(
+                    "Function test passed", function_name=function.full_name
+                )
+                logger.debug("Function test result", result=str(result))
 
     def find_columns(self, table_model: TableModel) -> Sequence[str]:
-        logger.debug(f"Finding columns for table: {table_model.full_name}")
+        logger.trace("Finding columns for table", table=table_model.full_name)
         table_info: TableInfo = self.w.tables.get(full_name=table_model.full_name)
         columns: Sequence[ColumnInfo] = table_info.columns
         column_names: Sequence[str] = [c.name for c in columns]
-        logger.debug(f"Columns found: {column_names}")
+        logger.debug(
+            "Columns found",
+            table=table_model.full_name,
+            columns_count=len(column_names),
+        )
         return column_names
 
     def find_primary_key(self, table_model: TableModel) -> Sequence[str] | None:
-        logger.debug(f"Finding primary key for table: {table_model.full_name}")
+        logger.trace("Finding primary key for table", table=table_model.full_name)
         primary_keys: Sequence[str] | None = None
         table_info: TableInfo = self.w.tables.get(full_name=table_model.full_name)
         constraints: Sequence[TableConstraint] = table_info.table_constraints
         primary_key_constraint: PrimaryKeyConstraint | None = next(
-            c.primary_key_constraint for c in constraints if c.primary_key_constraint
+            (c.primary_key_constraint for c in constraints if c.primary_key_constraint),
+            None,
         )
         if primary_key_constraint:
             primary_keys = primary_key_constraint.child_columns
 
-        logger.debug(f"Primary key for table {table_model.full_name}: {primary_keys}")
+        logger.debug(
+            "Primary key found", table=table_model.full_name, primary_keys=primary_keys
+        )
         return primary_keys
 
     def find_vector_search_endpoint(
         self, predicate: Callable[[dict[str, Any]], bool]
     ) -> str | None:
-        logger.debug("Finding vector search endpoint...")
+        logger.trace("Finding vector search endpoint")
         endpoint_name: str | None = None
         vector_search_endpoints: Sequence[dict[str, Any]] = (
             self.vsc.list_endpoints().get("endpoints", [])
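Note: beyond the logging rewrites, this hunk fixes a real bug in find_primary_key: next() over an empty generator raises StopIteration, so tables without a primary-key constraint previously crashed the lookup. The added None default makes the miss explicit. A minimal repro of the pattern:

    constraints: list = []  # stands in for table_info.table_constraints

    # Old form raised StopIteration when nothing matched:
    #     next(c for c in constraints if c)

    # New form returns None instead:
    first = next((c for c in constraints if c), None)
    assert first is None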
@@ -744,11 +839,13 @@ class DatabricksProvider(ServiceProvider):
             if predicate(endpoint):
                 endpoint_name = endpoint["name"]
                 break
-        logger.debug(f"Vector search endpoint found: {endpoint_name}")
+        logger.debug("Vector search endpoint found", endpoint_name=endpoint_name)
         return endpoint_name
 
     def find_endpoint_for_index(self, index_model: IndexModel) -> str | None:
-        logger.debug(f"Finding vector search index: {index_model.full_name}")
+        logger.trace(
+            "Finding endpoint for vector search index", index_name=index_model.full_name
+        )
         all_endpoints: Sequence[dict[str, Any]] = self.vsc.list_endpoints().get(
             "endpoints", []
         )
@@ -758,14 +855,99 @@ class DatabricksProvider(ServiceProvider):
             endpoint_name: str = endpoint["name"]
             indexes = self.vsc.list_indexes(name=endpoint_name)
             vector_indexes: Sequence[dict[str, Any]] = indexes.get("vector_indexes", [])
-            logger.trace(f"Endpoint: {endpoint_name}, vector_indexes: {vector_indexes}")
+            logger.trace(
+                "Checking endpoint for indexes",
+                endpoint_name=endpoint_name,
+                indexes_count=len(vector_indexes),
+            )
             index_names = [vector_index["name"] for vector_index in vector_indexes]
             if index_name in index_names:
                 found_endpoint_name = endpoint_name
                 break
-        logger.debug(f"Vector search index found: {found_endpoint_name}")
+        logger.debug(
+            "Vector search index endpoint found",
+            index_name=index_model.full_name,
+            endpoint_name=found_endpoint_name,
+        )
         return found_endpoint_name
 
+    def _wait_for_database_available(
+        self,
+        workspace_client: WorkspaceClient,
+        instance_name: str,
+        max_wait_time: int = 600,
+        wait_interval: int = 10,
+    ) -> None:
+        """
+        Wait for a database instance to become AVAILABLE.
+
+        Args:
+            workspace_client: The Databricks workspace client
+            instance_name: Name of the database instance to wait for
+            max_wait_time: Maximum time to wait in seconds (default: 600 = 10 minutes)
+            wait_interval: Time between status checks in seconds (default: 10)
+
+        Raises:
+            TimeoutError: If the database doesn't become AVAILABLE within max_wait_time
+            RuntimeError: If the database enters a failed or deleted state
+        """
+        import time
+        from typing import Any
+
+        logger.info(
+            "Waiting for database instance to become AVAILABLE",
+            instance_name=instance_name,
+        )
+        elapsed: int = 0
+
+        while elapsed < max_wait_time:
+            try:
+                current_instance: Any = workspace_client.database.get_database_instance(
+                    name=instance_name
+                )
+                current_state: str = current_instance.state
+                logger.trace(
+                    "Database instance state checked",
+                    instance_name=instance_name,
+                    state=current_state,
+                )
+
+                if current_state == "AVAILABLE":
+                    logger.success(
+                        "Database instance is now AVAILABLE",
+                        instance_name=instance_name,
+                    )
+                    return
+                elif current_state in ["STARTING", "UPDATING", "PROVISIONING"]:
+                    logger.trace(
+                        "Database instance not ready, waiting",
+                        instance_name=instance_name,
+                        state=current_state,
+                        wait_seconds=wait_interval,
+                    )
+                    time.sleep(wait_interval)
+                    elapsed += wait_interval
+                elif current_state in ["STOPPED", "DELETING", "FAILED"]:
+                    raise RuntimeError(
+                        f"Database instance {instance_name} entered unexpected state: {current_state}"
+                    )
+                else:
+                    logger.warning(
+                        "Unknown database state, continuing to wait",
+                        instance_name=instance_name,
+                        state=current_state,
+                    )
+                    time.sleep(wait_interval)
+                    elapsed += wait_interval
+            except NotFound:
+                raise RuntimeError(
+                    f"Database instance {instance_name} was deleted while waiting for it to become AVAILABLE"
+                )
+
+        raise TimeoutError(
+            f"Timed out waiting for database instance {instance_name} to become AVAILABLE after {max_wait_time} seconds"
+        )
+
     def create_lakebase(self, database: DatabaseModel) -> None:
         """
         Create a Lakebase database instance using the Databricks workspace client.
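Note: a hypothetical call site for the new _wait_for_database_available helper, showing the error contract from its docstring (the provider construction and instance name are assumptions, not taken from this diff):

    provider = DatabricksProvider()  # assumed default construction

    try:
        provider._wait_for_database_available(
            workspace_client=provider.w,
            instance_name="my-lakebase",  # illustrative name
            max_wait_time=600,
            wait_interval=10,
        )
    except TimeoutError:
        ...  # never reached AVAILABLE within max_wait_time seconds
    except RuntimeError:
        ...  # entered STOPPED/DELETING/FAILED, or was deleted mid-wait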
@@ -796,13 +978,17 @@ class DatabricksProvider(ServiceProvider):
 
             if existing_instance:
                 logger.debug(
-                    f"Database instance {database.instance_name} already exists with state: {existing_instance.state}"
+                    "Database instance already exists",
+                    instance_name=database.instance_name,
+                    state=existing_instance.state,
                 )
 
                 # Check if database is in an intermediate state
                 if existing_instance.state in ["STARTING", "UPDATING"]:
                     logger.info(
-                        f"Database instance {database.instance_name} is in {existing_instance.state} state, waiting for it to become AVAILABLE..."
+                        "Database instance in intermediate state, waiting",
+                        instance_name=database.instance_name,
+                        state=existing_instance.state,
                     )
 
                     # Wait for database to reach a stable state
@@ -818,65 +1004,87 @@ class DatabricksProvider(ServiceProvider):
                                 )
                             )
                             current_state: str = current_instance.state
-                            logger.debug(f"Database instance state: {current_state}")
+                            logger.trace(
+                                "Checking database instance state",
+                                instance_name=database.instance_name,
+                                state=current_state,
+                            )
 
                             if current_state == "AVAILABLE":
-                                logger.info(
-                                    f"Database instance {database.instance_name} is now AVAILABLE"
+                                logger.success(
+                                    "Database instance is now AVAILABLE",
+                                    instance_name=database.instance_name,
                                 )
                                 break
                             elif current_state in ["STARTING", "UPDATING"]:
-                                logger.debug(
-                                    f"Database instance still in {current_state} state, waiting {wait_interval} seconds..."
+                                logger.trace(
+                                    "Database instance not ready, waiting",
+                                    instance_name=database.instance_name,
+                                    state=current_state,
+                                    wait_seconds=wait_interval,
                                 )
                                 time.sleep(wait_interval)
                                 elapsed += wait_interval
                             elif current_state in ["STOPPED", "DELETING"]:
                                 logger.warning(
-                                    f"Database instance {database.instance_name} is in unexpected state: {current_state}"
+                                    "Database instance in unexpected state",
+                                    instance_name=database.instance_name,
+                                    state=current_state,
                                 )
                                 break
                             else:
                                 logger.warning(
-                                    f"Unknown database state: {current_state}, proceeding anyway"
+                                    "Unknown database state, proceeding",
+                                    instance_name=database.instance_name,
+                                    state=current_state,
                                 )
                                 break
                         except NotFound:
                             logger.warning(
-                                f"Database instance {database.instance_name} no longer exists, will attempt to recreate"
+                                "Database instance no longer exists, will recreate",
+                                instance_name=database.instance_name,
                             )
                             break
                         except Exception as state_error:
                             logger.warning(
-                                f"Could not check database state: {state_error}, proceeding anyway"
+                                "Could not check database state, proceeding",
+                                instance_name=database.instance_name,
+                                error=str(state_error),
                             )
                             break
 
                     if elapsed >= max_wait_time:
                         logger.warning(
-                            f"Timed out waiting for database instance {database.instance_name} to become AVAILABLE after {max_wait_time} seconds"
+                            "Timed out waiting for database to become AVAILABLE",
+                            instance_name=database.instance_name,
+                            max_wait_seconds=max_wait_time,
                         )
 
                 elif existing_instance.state == "AVAILABLE":
                     logger.info(
-                        f"Database instance {database.instance_name} already exists and is AVAILABLE"
+                        "Database instance already exists and is AVAILABLE",
+                        instance_name=database.instance_name,
                     )
                     return
                 elif existing_instance.state in ["STOPPED", "DELETING"]:
                     logger.warning(
-                        f"Database instance {database.instance_name} is in {existing_instance.state} state"
+                        "Database instance in terminal state",
+                        instance_name=database.instance_name,
+                        state=existing_instance.state,
                     )
                     return
                 else:
                     logger.info(
-                        f"Database instance {database.instance_name} already exists with state: {existing_instance.state}"
+                        "Database instance already exists",
+                        instance_name=database.instance_name,
+                        state=existing_instance.state,
                     )
                     return
 
        except NotFound:
            # Database doesn't exist, proceed with creation
-            logger.debug(
-                f"Database instance {database.instance_name} not found, creating new instance..."
+            logger.info(
+                "Creating new database instance", instance_name=database.instance_name
            )
 
        try:
@@ -896,10 +1104,17 @@ class DatabricksProvider(ServiceProvider):
             workspace_client.database.create_database_instance(
                 database_instance=database_instance
             )
-            logger.info(
-                f"Successfully created database instance: {database.instance_name}"
+            logger.success(
+                "Database instance created successfully",
+                instance_name=database.instance_name,
             )
 
+            # Wait for the newly created database to become AVAILABLE
+            self._wait_for_database_available(
+                workspace_client, database.instance_name
+            )
+            return
+
         except Exception as create_error:
             error_msg: str = str(create_error)
 
@@ -909,13 +1124,20 @@ class DatabricksProvider(ServiceProvider):
                 or "RESOURCE_ALREADY_EXISTS" in error_msg
             ):
                 logger.info(
-                    f"Database instance {database.instance_name} was created concurrently by another process"
+                    "Database instance was created concurrently",
+                    instance_name=database.instance_name,
+                )
+                # Still need to wait for the database to become AVAILABLE
+                self._wait_for_database_available(
+                    workspace_client, database.instance_name
                 )
                 return
             else:
                 # Re-raise unexpected errors
                 logger.error(
-                    f"Error creating database instance {database.instance_name}: {create_error}"
+                    "Error creating database instance",
+                    instance_name=database.instance_name,
+                    error=str(create_error),
                 )
                 raise
 
@@ -929,12 +1151,15 @@ class DatabricksProvider(ServiceProvider):
                 or "RESOURCE_ALREADY_EXISTS" in error_msg
             ):
                 logger.info(
-                    f"Database instance {database.instance_name} already exists (detected via exception)"
+                    "Database instance already exists (detected via exception)",
+                    instance_name=database.instance_name,
                 )
                 return
             else:
                 logger.error(
-                    f"Unexpected error while handling database {database.instance_name}: {e}"
+                    "Unexpected error while handling database",
+                    instance_name=database.instance_name,
+                    error=str(e),
                 )
                 raise
 
@@ -942,7 +1167,9 @@ class DatabricksProvider(ServiceProvider):
         """
         Ask Databricks to mint a fresh DB credential for this instance.
         """
-        logger.debug(f"Generating password for lakebase instance: {instance_name}")
+        logger.trace(
+            "Generating password for lakebase instance", instance_name=instance_name
+        )
         w: WorkspaceClient = self.w
         cred: DatabaseCredential = w.database.generate_database_credential(
             request_id=str(uuid.uuid4()),
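Note: the minted DatabaseCredential is short-lived; a sketch of how such a credential is commonly used as the PostgreSQL password for a Lakebase instance. The instance_names parameter and token attribute are assumptions beyond the request_id shown above:

    import uuid
    from databricks.sdk import WorkspaceClient

    w = WorkspaceClient()
    cred = w.database.generate_database_credential(
        request_id=str(uuid.uuid4()),
        instance_names=["my-lakebase"],  # assumed parameter; illustrative name
    )
    password = cred.token  # assumed attribute; regenerate when it expires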
@@ -978,7 +1205,8 @@ class DatabricksProvider(ServiceProvider):
         # Validate that client_id is provided
         if not database.client_id:
             logger.warning(
-                f"client_id is required to create instance role for database {database.instance_name}"
+                "client_id required to create instance role",
+                instance_name=database.instance_name,
             )
             return
 
@@ -988,7 +1216,10 @@ class DatabricksProvider(ServiceProvider):
         instance_name: str = database.instance_name
 
         logger.debug(
-            f"Creating instance role '{role_name}' for database {instance_name} with principal {client_id}"
+            "Creating instance role",
+            role_name=role_name,
+            instance_name=instance_name,
+            principal=client_id,
         )
 
         try:
@@ -999,13 +1230,15 @@ class DatabricksProvider(ServiceProvider):
                 name=role_name,
             )
             logger.info(
-                f"Instance role '{role_name}' already exists for database {instance_name}"
+                "Instance role already exists",
+                role_name=role_name,
+                instance_name=instance_name,
             )
             return
         except NotFound:
             # Role doesn't exist, proceed with creation
             logger.debug(
-                f"Instance role '{role_name}' not found, creating new role..."
+                "Instance role not found, creating new role", role_name=role_name
             )
 
         # Create the database instance role
@@ -1021,8 +1254,10 @@ class DatabricksProvider(ServiceProvider):
                 database_instance_role=role,
             )
 
-            logger.info(
-                f"Successfully created instance role '{role_name}' for database {instance_name}"
+            logger.success(
+                "Instance role created successfully",
+                role_name=role_name,
+                instance_name=instance_name,
             )
 
         except Exception as e:
@@ -1034,13 +1269,18 @@ class DatabricksProvider(ServiceProvider):
                 or "RESOURCE_ALREADY_EXISTS" in error_msg
             ):
                 logger.info(
-                    f"Instance role '{role_name}' was created concurrently for database {instance_name}"
+                    "Instance role was created concurrently",
+                    role_name=role_name,
+                    instance_name=instance_name,
                 )
                 return
 
             # Re-raise unexpected errors
             logger.error(
-                f"Error creating instance role '{role_name}' for database {instance_name}: {e}"
+                "Error creating instance role",
+                role_name=role_name,
+                instance_name=instance_name,
+                error=str(e),
             )
             raise
 
@@ -1050,9 +1290,17 @@ class DatabricksProvider(ServiceProvider):
 
         If an explicit version or alias is specified in the prompt_model, uses that directly.
         Otherwise, tries to load prompts in this order:
-        1. champion alias (if it exists)
-        2. latest alias (if it exists)
-        3. default_template (if provided)
+        1. champion alias
+        2. latest alias
+        3. default alias
+        4. Register default_template if provided (only if register_to_registry=True)
+        5. Use default_template directly (fallback)
+
+        The auto_register field controls whether the default_template is automatically
+        synced to the prompt registry:
+        - If True (default): Auto-registers/updates the default_template in the registry
+        - If False: Never registers, but can still load existing prompts from registry
+          or use default_template directly as a local-only prompt
 
         Args:
             prompt_model: The prompt model configuration
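Note: a condensed sketch of the lookup order the rewritten implementation (final hunk below) walks once explicit version/alias handling and default-template syncing are done; load_prompt is MLflow's prompt-registry loader:

    # Sketch only; mirrors the fallback chain implemented in the hunk below.
    for alias in ("champion", "default", "latest"):
        try:
            return load_prompt(f"prompts:/{prompt_name}@{alias}")
        except Exception:
            continue  # alias missing; try the next one
    # Last resort: wrap default_template from the config, if provided.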
@@ -1063,542 +1311,266 @@ class DatabricksProvider(ServiceProvider):
1063
1311
  Raises:
1064
1312
  ValueError: If no prompt can be loaded from any source
1065
1313
  """
1314
+
1066
1315
  prompt_name: str = prompt_model.full_name
1067
1316
 
1068
- # If explicit version or alias is specified, use it directly without fallback
1317
+ # If explicit version or alias is specified, use it directly
1069
1318
  if prompt_model.version or prompt_model.alias:
1070
1319
  try:
1071
1320
  prompt_version: PromptVersion = prompt_model.as_prompt()
1321
+ version_or_alias = (
1322
+ f"version {prompt_model.version}"
1323
+ if prompt_model.version
1324
+ else f"alias {prompt_model.alias}"
1325
+ )
1072
1326
  logger.debug(
1073
- f"Loaded prompt '{prompt_name}' with explicit "
1074
- f"{'version ' + str(prompt_model.version) if prompt_model.version else 'alias ' + prompt_model.alias}"
1327
+ "Loaded prompt with explicit version/alias",
1328
+ prompt_name=prompt_name,
1329
+ version_or_alias=version_or_alias,
1075
1330
  )
1076
1331
  return prompt_version
1077
1332
  except Exception as e:
1333
+ version_or_alias = (
1334
+ f"version {prompt_model.version}"
1335
+ if prompt_model.version
1336
+ else f"alias {prompt_model.alias}"
1337
+ )
1078
1338
  logger.warning(
1079
- f"Failed to load prompt '{prompt_name}' with explicit "
1080
- f"{'version ' + str(prompt_model.version) if prompt_model.version else 'alias ' + prompt_model.alias}: {e}"
1339
+ "Failed to load prompt with explicit version/alias",
1340
+ prompt_name=prompt_name,
1341
+ version_or_alias=version_or_alias,
1342
+ error=str(e),
1081
1343
  )
1082
- # Fall through to default_template if available
1083
- else:
1084
- # No explicit version/alias specified - check if default_template needs syncing first
1085
- logger.debug(
1086
- f"No explicit version/alias specified for '{prompt_name}', "
1087
- "checking if default_template needs syncing"
1088
- )
1089
-
1090
- # If we have a default_template, check if it differs from what's in the registry
1091
- # This ensures we always sync config changes before returning any alias
1092
- if prompt_model.default_template:
1093
- try:
1094
- default_uri: str = f"prompts:/{prompt_name}@default"
1095
- default_version: PromptVersion = load_prompt(default_uri)
1096
-
1097
- if (
1098
- default_version.to_single_brace_format().strip()
1099
- != prompt_model.default_template.strip()
1100
- ):
1101
- logger.info(
1102
- f"Config default_template for '{prompt_name}' differs from registry, syncing..."
1103
- )
1104
- return self._sync_default_template_to_registry(
1105
- prompt_name,
1106
- prompt_model.default_template,
1107
- prompt_model.description,
1108
- )
1109
- except Exception as e:
1110
- logger.debug(f"Could not check default alias for sync: {e}")
1344
+ # Fall through to try other methods
1111
1345
 
1112
- # Now try aliases in order: champion → latest default
1113
- logger.debug(
1114
- f"Trying fallback order for '{prompt_name}': champion → latest → default"
1115
- )
1116
-
1117
- # Try champion alias first
1118
- try:
1119
- champion_uri: str = f"prompts:/{prompt_name}@champion"
1120
- prompt_version: PromptVersion = load_prompt(champion_uri)
1121
- logger.info(f"Loaded prompt '{prompt_name}' from champion alias")
1122
- return prompt_version
1123
- except Exception as e:
1124
- logger.debug(f"Champion alias not found for '{prompt_name}': {e}")
1125
-
1126
- # Try latest alias next
1127
- try:
1128
- latest_uri: str = f"prompts:/{prompt_name}@latest"
1129
- prompt_version: PromptVersion = load_prompt(latest_uri)
1130
- logger.info(f"Loaded prompt '{prompt_name}' from latest alias")
1131
- return prompt_version
1132
- except Exception as e:
1133
- logger.debug(f"Latest alias not found for '{prompt_name}': {e}")
1346
+ # Try to load in priority order: champion → default (with sync check)
1347
+ logger.trace(
1348
+ "Trying prompt fallback order",
1349
+ prompt_name=prompt_name,
1350
+ order="champion → default",
1351
+ )
1134
1352
 
1135
- # Try default alias last
1353
+ # First, sync default alias if template has changed (even if champion exists)
1354
+ # Only do this if auto_register is True
1355
+ if prompt_model.default_template and prompt_model.auto_register:
1136
1356
  try:
1137
- default_uri: str = f"prompts:/{prompt_name}@default"
1138
- prompt_version: PromptVersion = load_prompt(default_uri)
1139
- logger.info(f"Loaded prompt '{prompt_name}' from default alias")
1140
- return prompt_version
1141
- except Exception as e:
1142
- logger.debug(f"Default alias not found for '{prompt_name}': {e}")
1357
+ # Try to load existing default
1358
+ existing_default = load_prompt(f"prompts:/{prompt_name}@default")
1143
1359
 
1144
- # Fall back to registering default_template if provided
1145
- if prompt_model.default_template:
1146
- logger.info(
1147
- f"Registering default_template for '{prompt_name}' "
1148
- "(no aliases found in registry)"
1149
- )
1150
- return self._sync_default_template_to_registry(
1151
- prompt_name, prompt_model.default_template, prompt_model.description
1152
- )
1153
-
1154
- raise ValueError(
1155
- f"Prompt '{prompt_name}' not found in registry "
1156
- "(tried champion, latest, default aliases) and no default_template provided"
1157
- )
1158
-
1159
- def _sync_default_template_to_registry(
1160
- self, prompt_name: str, default_template: str, description: str | None = None
1161
- ) -> PromptVersion:
1162
- """Register default_template to prompt registry under 'default' alias if changed."""
1163
- prompt_version: PromptVersion | None = None
1360
+ # Check if champion exists and if it matches default
1361
+ champion_matches_default = False
1362
+ try:
1363
+ existing_champion = load_prompt(f"prompts:/{prompt_name}@champion")
1364
+ champion_matches_default = (
1365
+ existing_champion.version == existing_default.version
1366
+ )
1367
+ status = (
1368
+ "tracking" if champion_matches_default else "pinned separately"
1369
+ )
1370
+ logger.trace(
1371
+ "Champion vs default version",
1372
+ prompt_name=prompt_name,
1373
+ champion_version=existing_champion.version,
1374
+ default_version=existing_default.version,
1375
+ status=status,
1376
+ )
1377
+ except Exception:
1378
+ # No champion exists
1379
+ logger.trace("No champion alias found", prompt_name=prompt_name)
1164
1380
 
1165
- try:
1166
- # Check if default alias already has the same template
1167
- try:
1168
- logger.debug(f"Loading prompt '{prompt_name}' from registry...")
1169
- existing: PromptVersion = mlflow.genai.load_prompt(
1170
- f"prompts:/{prompt_name}@default"
1171
- )
1381
+ # Check if default_template differs from existing default
1172
1382
  if (
1173
- existing.to_single_brace_format().strip()
1174
- == default_template.strip()
1383
+ existing_default.template.strip()
1384
+ != prompt_model.default_template.strip()
1175
1385
  ):
1176
- logger.debug(f"Prompt '{prompt_name}' is already up-to-date")
1386
+ logger.info(
1387
+ "Default template changed, registering new version",
1388
+ prompt_name=prompt_name,
1389
+ )
1177
1390
 
1178
- # Ensure the "latest" and "champion" aliases also exist and point to the same version
1179
- # This handles prompts created before the fix that added these aliases
1180
- try:
1181
- latest_version: PromptVersion = mlflow.genai.load_prompt(
1182
- f"prompts:/{prompt_name}@latest"
1183
- )
1184
- logger.debug(
1185
- f"Latest alias already exists for '{prompt_name}' pointing to version {latest_version.version}"
1186
- )
1187
- except Exception:
1391
+ # Only update champion if it was pointing to the old default
1392
+ if champion_matches_default:
1188
1393
  logger.info(
1189
- f"Setting 'latest' alias for existing prompt '{prompt_name}' v{existing.version}"
1190
- )
1191
- mlflow.genai.set_prompt_alias(
1192
- name=prompt_name,
1193
- alias="latest",
1194
- version=existing.version,
1195
- )
1196
-
1197
- # Ensure champion alias exists for first-time deployments
1198
- try:
1199
- champion_version: PromptVersion = mlflow.genai.load_prompt(
1200
- f"prompts:/{prompt_name}@champion"
1394
+ "Champion was tracking default, will update to new version",
1395
+ prompt_name=prompt_name,
1396
+ old_version=existing_default.version,
1201
1397
  )
1202
- logger.debug(
1203
- f"Champion alias already exists for '{prompt_name}' pointing to version {champion_version.version}"
1204
- )
1205
- except Exception:
1398
+ set_champion = True
1399
+ else:
1206
1400
  logger.info(
1207
- f"Setting 'champion' alias for existing prompt '{prompt_name}' v{existing.version}"
1208
- )
1209
- mlflow.genai.set_prompt_alias(
1210
- name=prompt_name,
1211
- alias="champion",
1212
- version=existing.version,
1401
+ "Champion is pinned separately, preserving it",
1402
+ prompt_name=prompt_name,
1213
1403
  )
1404
+ set_champion = False
1214
1405
 
1215
- return existing # Already up-to-date, return existing version
1216
- except Exception:
1217
- logger.debug(
1218
- f"Default alias for prompt '{prompt_name}' doesn't exist yet"
1406
+ self._register_default_template(
1407
+ prompt_name,
1408
+ prompt_model.default_template,
1409
+ prompt_model.description,
1410
+ set_champion=set_champion,
1411
+ )
1412
+ except Exception as e:
1413
+ # No default exists yet, register it
1414
+ logger.trace(
1415
+ "No default alias found", prompt_name=prompt_name, error=str(e)
1219
1416
  )
1220
-
1221
- # Register new version and set as default alias
1222
- commit_message = description or "Auto-synced from default_template"
1223
- prompt_version = mlflow.genai.register_prompt(
1224
- name=prompt_name,
1225
- template=default_template,
1226
- commit_message=commit_message,
1227
- tags={"dao_ai": dao_ai_version()},
1228
- )
1229
-
1230
- logger.debug(
1231
- f"Setting default, latest, and champion aliases for prompt '{prompt_name}'"
1232
- )
1233
- mlflow.genai.set_prompt_alias(
1234
- name=prompt_name,
1235
- alias="default",
1236
- version=prompt_version.version,
1237
- )
1238
- mlflow.genai.set_prompt_alias(
1239
- name=prompt_name,
1240
- alias="latest",
1241
- version=prompt_version.version,
1242
- )
1243
- mlflow.genai.set_prompt_alias(
1244
- name=prompt_name,
1245
- alias="champion",
1246
- version=prompt_version.version,
1417
+ logger.info(
1418
+ "Registering default template as default alias",
1419
+ prompt_name=prompt_name,
1420
+ )
1421
+ # First registration - set both default and champion
1422
+ self._register_default_template(
1423
+ prompt_name,
1424
+ prompt_model.default_template,
1425
+ prompt_model.description,
1426
+ set_champion=True,
1427
+ )
1428
+ elif prompt_model.default_template and not prompt_model.auto_register:
1429
+ logger.trace(
1430
+ "Prompt has auto_register=False, skipping registration",
1431
+ prompt_name=prompt_name,
1247
1432
  )
1248
1433
 
-        logger.info(
-            f"Synced prompt '{prompt_name}' v{prompt_version.version} to registry with 'default', 'latest', and 'champion' aliases"
-        )
+        # 1. Try champion alias (highest priority for execution)
+        try:
+            prompt_version = load_prompt(f"prompts:/{prompt_name}@champion")
+            logger.info("Loaded prompt from champion alias", prompt_name=prompt_name)
             return prompt_version
-
         except Exception as e:
-            logger.error(f"Failed to sync '{prompt_name}' to registry: {e}")
-            raise ValueError(
-                f"Failed to sync prompt '{prompt_name}' to registry and unable to retrieve existing version"
-            ) from e
-
-    def optimize_prompt(self, optimization: PromptOptimizationModel) -> PromptModel:
-        """
-        Optimize a prompt using MLflow's prompt optimization (MLflow 3.5+).
-
-        This uses the MLflow GenAI optimize_prompts API with GepaPromptOptimizer as documented at:
-        https://mlflow.org/docs/latest/genai/prompt-registry/optimize-prompts/
+            logger.trace(
+                "Champion alias not found", prompt_name=prompt_name, error=str(e)
+            )
 
-        Args:
-            optimization: PromptOptimizationModel containing configuration
+        # 2. Try default alias (already synced above)
+        if prompt_model.default_template:
+            try:
+                prompt_version = load_prompt(f"prompts:/{prompt_name}@default")
+                logger.info("Loaded prompt from default alias", prompt_name=prompt_name)
+                return prompt_version
+            except Exception as e:
+                # Should not happen since we just registered it above, but handle anyway
+                logger.trace(
+                    "Default alias not found", prompt_name=prompt_name, error=str(e)
+                )
 
-        Returns:
-            PromptModel: The optimized prompt with new URI
-        """
-        from mlflow.genai.optimize import GepaPromptOptimizer, optimize_prompts
-        from mlflow.genai.scorers import Correctness
-
-        from dao_ai.config import AgentModel, PromptModel
-
-        logger.info(f"Optimizing prompt: {optimization.name}")
-
-        # Get agent and prompt (prompt is guaranteed to be set by validator)
-        agent_model: AgentModel = optimization.agent
-        prompt: PromptModel = optimization.prompt  # type: ignore[assignment]
-        agent_model.prompt = prompt.uri
-
-        print(f"prompt={agent_model.prompt}")
-        # Log the prompt URI scheme being used
-        # Supports three schemes:
-        # 1. Specific version: "prompts:/qa/1" (when version is specified)
-        # 2. Alias: "prompts:/qa@champion" (when alias is specified)
-        # 3. Latest: "prompts:/qa@latest" (default when neither version nor alias specified)
-        prompt_uri: str = prompt.uri
-        logger.info(f"Using prompt URI for optimization: {prompt_uri}")
-
-        # Load the specific prompt version by URI for comparison
-        # Try to load the exact version specified, but if it doesn't exist,
-        # use get_prompt to create it from default_template
-        prompt_version: PromptVersion
+        # 3. Try latest alias as final fallback
         try:
-            prompt_version = load_prompt(prompt_uri)
-            logger.info(f"Successfully loaded prompt from registry: {prompt_uri}")
+            prompt_version = load_prompt(f"prompts:/{prompt_name}@latest")
+            logger.info("Loaded prompt from latest alias", prompt_name=prompt_name)
+            return prompt_version
         except Exception as e:
+            logger.trace(
+                "Latest alias not found", prompt_name=prompt_name, error=str(e)
+            )
+
+        # 4. Final fallback: use default_template directly if available
+        if prompt_model.default_template:
             logger.warning(
-                f"Could not load prompt '{prompt_uri}' directly: {e}. "
-                "Attempting to create from default_template..."
+                "Could not load prompt from registry, using default_template directly",
+                prompt_name=prompt_name,
             )
-            # Use get_prompt which will create from default_template if needed
-            prompt_version = self.get_prompt(prompt)
-            logger.info(
-                f"Created/loaded prompt '{prompt.full_name}' (will optimize against this version)"
+            return PromptVersion(
+                name=prompt_name,
+                version=1,
+                template=prompt_model.default_template,
+                tags={"dao_ai": dao_ai_version()},
             )
 
-        # Load the evaluation dataset by name
-        logger.debug(f"Looking up dataset: {optimization.dataset}")
-        dataset: EvaluationDataset
-        if isinstance(optimization.dataset, str):
-            dataset = get_dataset(name=optimization.dataset)
-        else:
-            dataset = optimization.dataset.as_dataset()
-
-        # Set up reflection model for the optimizer
-        reflection_model_name: str
-        if optimization.reflection_model:
-            if isinstance(optimization.reflection_model, str):
-                reflection_model_name = optimization.reflection_model
-            else:
-                reflection_model_name = optimization.reflection_model.uri
-        else:
-            reflection_model_name = agent_model.model.uri
-        logger.debug(f"Using reflection model: {reflection_model_name}")
-
-        # Create the GepaPromptOptimizer
-        optimizer: GepaPromptOptimizer = GepaPromptOptimizer(
-            reflection_model=reflection_model_name,
-            max_metric_calls=optimization.num_candidates,
-            display_progress_bar=True,
+        raise ValueError(
+            f"Prompt '{prompt_name}' not found in registry "
+            "(tried champion, default, latest aliases) "
+            "and no default_template provided"
         )
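
Taken together, the new lookup resolves a prompt through a fixed alias chain: champion first, then default, then latest, finally falling back to the inline default_template before raising. A standalone sketch of the same resolution order, using the load_prompt import already present in this module (resolve_prompt is an illustrative name, not part of the package):

    from mlflow.genai.prompts import load_prompt

    def resolve_prompt(prompt_name: str, default_template: str | None = None):
        # Try registry aliases in priority order: champion > default > latest
        for alias in ("champion", "default", "latest"):
            try:
                return load_prompt(f"prompts:/{prompt_name}@{alias}")
            except Exception:
                continue  # alias missing or registry unreachable; try the next one
        # Last resort: fall back to the inline template text, if configured
        if default_template is not None:
            return default_template
        raise ValueError(f"Prompt '{prompt_name}' not found and no default_template provided")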
 
-        # Set up scorer (judge model for evaluation)
-        scorer_model: str
-        if optimization.scorer_model:
-            if isinstance(optimization.scorer_model, str):
-                scorer_model = optimization.scorer_model
-            else:
-                scorer_model = optimization.scorer_model.uri
-        else:
-            scorer_model = agent_model.model.uri  # Use Databricks default
-        logger.debug(f"Using scorer with model: {scorer_model}")
-
-        scorers: list[Correctness] = [Correctness(model=scorer_model)]
-
-        # Use prompt_uri from line 1188 - already set to prompt.uri (configured URI)
-        # DO NOT overwrite with prompt_version.uri as that uses fallback logic
-        logger.debug(f"Optimizing prompt: {prompt_uri}")
-
-        agent: ResponsesAgent = agent_model.as_responses_agent()
-
-        # Create predict function that will be optimized
-        def predict_fn(**inputs: dict[str, Any]) -> str:
-            """Prediction function that uses the ResponsesAgent with ChatPayload.
-
-            The agent already has the prompt referenced/applied, so we just need to
-            convert the ChatPayload inputs to ResponsesAgentRequest format and call predict.
-
-            Args:
-                **inputs: Dictionary containing ChatPayload fields (messages/input, custom_inputs)
-
-            Returns:
-                str: The agent's response content
-            """
-            from mlflow.types.responses import (
-                ResponsesAgentRequest,
-                ResponsesAgentResponse,
-            )
-            from mlflow.types.responses_helpers import Message
-
-            from dao_ai.config import ChatPayload
-
-            # Verify agent is accessible (should be captured from outer scope)
-            if agent is None:
-                raise RuntimeError(
-                    "Agent object is not available in predict_fn. "
-                    "This may indicate a serialization issue with the ResponsesAgent."
-                )
-
-            # Convert inputs to ChatPayload
-            chat_payload: ChatPayload = ChatPayload(**inputs)
-
-            # Convert ChatPayload messages to MLflow Message format
-            mlflow_messages: list[Message] = [
-                Message(role=msg.role, content=msg.content)
-                for msg in chat_payload.messages
-            ]
-
-            # Create ResponsesAgentRequest
-            request: ResponsesAgentRequest = ResponsesAgentRequest(
-                input=mlflow_messages,
-                custom_inputs=chat_payload.custom_inputs,
-            )
+    def _register_default_template(
+        self,
+        prompt_name: str,
+        default_template: str,
+        description: str | None = None,
+        set_champion: bool = True,
+    ) -> PromptVersion:
+        """Register default_template as a new prompt version.
 
-            # Call the ResponsesAgent's predict method
-            response: ResponsesAgentResponse = agent.predict(request)
-
-            if response.output and len(response.output) > 0:
-                content = response.output[0].content
-                logger.debug(f"Response content type: {type(content)}")
-                logger.debug(f"Response content: {content}")
-
-                # Extract text from content using same logic as LanggraphResponsesAgent._extract_text_from_content
-                # Content can be:
-                # - A string (return as is)
-                # - A list of items with 'text' keys (extract and join)
-                # - Other types (try to get 'text' attribute or convert to string)
-                if isinstance(content, str):
-                    return content
-                elif isinstance(content, list):
-                    text_parts = []
-                    for content_item in content:
-                        if isinstance(content_item, str):
-                            text_parts.append(content_item)
-                        elif isinstance(content_item, dict) and "text" in content_item:
-                            text_parts.append(content_item["text"])
-                        elif hasattr(content_item, "text"):
-                            text_parts.append(content_item.text)
-                    return "".join(text_parts) if text_parts else str(content)
-                else:
-                    # Fallback for unknown types - try to extract text attribute
-                    return getattr(content, "text", str(content))
-            else:
-                return ""
+        Registers the template and sets the 'default' alias.
+        Optionally sets 'champion' alias if no champion exists.
 
-        # Set registry URI for Databricks Unity Catalog
-        mlflow.set_registry_uri("databricks-uc")
+        Args:
+            prompt_name: Full name of the prompt
+            default_template: The template content
+            description: Optional description for commit message
+            set_champion: Whether to also set champion alias (default: True)
 
-        # Run optimization with tracking disabled to prevent auto-registering all candidates
-        logger.info("Running prompt optimization with GepaPromptOptimizer...")
+        If registration fails (e.g., in Model Serving with restricted permissions),
+        logs the error and raises.
+        """
         logger.info(
-            f"Generating {optimization.num_candidates} candidate prompts for evaluation"
+            "Registering default template",
+            prompt_name=prompt_name,
+            set_champion=set_champion,
         )
 
-        from mlflow.genai.optimize.types import (
-            PromptOptimizationResult,
-        )
-
-        result: PromptOptimizationResult = optimize_prompts(
-            predict_fn=predict_fn,
-            train_data=dataset,
-            prompt_uris=[prompt_uri],  # Use the configured URI (version/alias/latest)
-            optimizer=optimizer,
-            scorers=scorers,
-            enable_tracking=False,  # Don't auto-register all candidates
-        )
-
-        # Log the optimization results
-        logger.info("Optimization complete!")
-        logger.info(f"Optimizer used: {result.optimizer_name}")
-
-        if result.optimized_prompts:
-            optimized_prompt_version: PromptVersion = result.optimized_prompts[0]
-
-            # Check if the optimized prompt is actually different from the original
-            original_template: str = prompt_version.to_single_brace_format().strip()
-            optimized_template: str = (
-                optimized_prompt_version.to_single_brace_format().strip()
-            )
-
-            # Normalize whitespace for more robust comparison
-            original_normalized: str = re.sub(r"\s+", " ", original_template).strip()
-            optimized_normalized: str = re.sub(r"\s+", " ", optimized_template).strip()
-
-            logger.debug(f"Original template length: {len(original_template)} chars")
-            logger.debug(f"Optimized template length: {len(optimized_template)} chars")
-            logger.debug(
-                f"Templates identical: {original_normalized == optimized_normalized}"
-            )
-
-            if original_normalized == optimized_normalized:
-                logger.info(
-                    f"Optimized prompt is identical to original for '{prompt.full_name}'. "
-                    "No new version will be registered."
-                )
-                return prompt
-
-            logger.info("Optimized prompt is DIFFERENT from original")
-            logger.info(
-                f"Original length: {len(original_template)}, Optimized length: {len(optimized_template)}"
-            )
-            logger.debug(
-                f"Original template (first 300 chars): {original_template[:300]}..."
-            )
-            logger.debug(
-                f"Optimized template (first 300 chars): {optimized_template[:300]}..."
+        try:
+            commit_message = description or "Auto-synced from default_template"
+            prompt_version = mlflow.genai.register_prompt(
+                name=prompt_name,
+                template=default_template,
+                commit_message=commit_message,
+                tags={"dao_ai": dao_ai_version()},
             )
 
-            # Check evaluation scores to determine if we should register the optimized prompt
-            should_register: bool = False
-            has_improvement: bool = False
-
-            if (
-                result.initial_eval_score is not None
-                and result.final_eval_score is not None
-            ):
-                logger.info("Evaluation scores:")
-                logger.info(f"  Initial score: {result.initial_eval_score}")
-                logger.info(f"  Final score: {result.final_eval_score}")
-
-                # Only register if there's improvement
-                if result.final_eval_score > result.initial_eval_score:
-                    improvement: float = (
-                        (result.final_eval_score - result.initial_eval_score)
-                        / result.initial_eval_score
-                    ) * 100
-                    logger.info(
-                        f"Optimized prompt improved by {improvement:.2f}% "
-                        f"(initial: {result.initial_eval_score}, final: {result.final_eval_score})"
-                    )
-                    should_register = True
-                    has_improvement = True
-                else:
-                    logger.info(
-                        f"Optimized prompt (score: {result.final_eval_score}) did NOT improve over baseline (score: {result.initial_eval_score}). "
-                        "No new version will be registered."
-                    )
-            else:
-                # No scores available - register anyway but warn
-                logger.warning(
-                    "No evaluation scores available to compare performance. "
-                    "Registering optimized prompt without performance validation."
-                )
-                should_register = True
-
-            if not should_register:
-                logger.info(
-                    f"Skipping registration for '{prompt.full_name}' (no improvement)"
-                )
-                return prompt
-
-            # Register the optimized prompt manually
+            # Always set default alias
             try:
-                logger.info(f"Registering optimized prompt '{prompt.full_name}'")
-                registered_version: PromptVersion = mlflow.genai.register_prompt(
-                    name=prompt.full_name,
-                    template=optimized_template,
-                    commit_message=f"Optimized for {agent_model.model.uri} using GepaPromptOptimizer",
-                    tags={
-                        "dao_ai": dao_ai_version(),
-                        "target_model": agent_model.model.uri,
-                    },
-                )
-                logger.info(
-                    f"Registered optimized prompt as version {registered_version.version}"
-                )
-
-                # Always set "latest" alias (represents most recently registered prompt)
-                logger.info(
-                    f"Setting 'latest' alias for optimized prompt '{prompt.full_name}' version {registered_version.version}"
+                logger.debug(
+                    "Setting default alias",
+                    prompt_name=prompt_name,
+                    version=prompt_version.version,
                 )
                 mlflow.genai.set_prompt_alias(
-                    name=prompt.full_name,
-                    alias="latest",
-                    version=registered_version.version,
+                    name=prompt_name, alias="default", version=prompt_version.version
                 )
-                logger.info(
-                    f"Successfully set 'latest' alias for '{prompt.full_name}' v{registered_version.version}"
+                logger.success(
+                    "Set default alias for prompt",
+                    prompt_name=prompt_name,
+                    version=prompt_version.version,
+                )
+            except Exception as alias_error:
+                logger.warning(
+                    "Could not set default alias",
+                    prompt_name=prompt_name,
+                    error=str(alias_error),
                 )
 
-                # If there's confirmed improvement, also set the "champion" alias
-                # (represents the prompt that should be used by deployed agents)
-                if has_improvement:
-                    logger.info(
-                        f"Setting 'champion' alias for improved prompt '{prompt.full_name}' version {registered_version.version}"
-                    )
+            # Optionally set champion alias (only if no champion exists or explicitly requested)
+            if set_champion:
+                try:
                     mlflow.genai.set_prompt_alias(
-                        name=prompt.full_name,
+                        name=prompt_name,
                         alias="champion",
-                        version=registered_version.version,
+                        version=prompt_version.version,
                     )
-                    logger.info(
-                        f"Successfully set 'champion' alias for '{prompt.full_name}' v{registered_version.version}"
+                    logger.success(
+                        "Set champion alias for prompt",
+                        prompt_name=prompt_name,
+                        version=prompt_version.version,
+                    )
+                except Exception as alias_error:
+                    logger.warning(
+                        "Could not set champion alias",
+                        prompt_name=prompt_name,
+                        error=str(alias_error),
                     )
 
-                # Add target_model and dao_ai tags
-                tags: dict[str, Any] = prompt.tags.copy() if prompt.tags else {}
-                tags["target_model"] = agent_model.model.uri
-                tags["dao_ai"] = dao_ai_version()
-
-                # Return the optimized prompt with the appropriate alias
-                # Use "champion" if there was improvement, otherwise "latest"
-                result_alias: str = "champion" if has_improvement else "latest"
-                return PromptModel(
-                    name=prompt.name,
-                    schema=prompt.schema_model,
-                    description=f"Optimized version of {prompt.name} for {agent_model.model.uri}",
-                    alias=result_alias,
-                    tags=tags,
-                )
+            return prompt_version
 
-            except Exception as e:
-                logger.error(
-                    f"Failed to register optimized prompt '{prompt.full_name}': {e}"
-                )
-                return prompt
-        else:
-            logger.warning("No optimized prompts returned from optimization")
-            return prompt
+        except Exception as reg_error:
+            logger.error(
+                "Failed to register prompt - please register from notebook with write permissions",
+                prompt_name=prompt_name,
+                error=str(reg_error),
+            )
+            return PromptVersion(
+                name=prompt_name,
+                version=1,
+                template=default_template,
+                tags={"dao_ai": dao_ai_version()},
+            )
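
For reference, the happy path of the new _register_default_template boils down to one register_prompt call followed by set_prompt_alias calls for the rollout aliases. A hedged usage sketch against the MLflow prompt registry (the prompt name and template below are made up; the provider tags versions via dao_ai_version(), replaced here with a literal):

    import mlflow

    # Hypothetical Unity Catalog prompt name and template, for illustration only
    prompt_name = "main.default.qa_prompt"
    template = "Answer the user's question: {{question}}"

    version = mlflow.genai.register_prompt(
        name=prompt_name,
        template=template,
        commit_message="Auto-synced from default_template",
        tags={"dao_ai": "0.1.2"},  # the provider computes this via dao_ai_version()
    )

    # Point both rollout aliases at the freshly registered version:
    # 'default' is always set; 'champion' only when set_champion is requested.
    for alias in ("default", "champion"):
        mlflow.genai.set_prompt_alias(
            name=prompt_name, alias=alias, version=version.version
        )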