dao-ai 0.0.25__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +5 -5
  3. dao_ai/cli.py +245 -40
  4. dao_ai/config.py +1863 -338
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -228
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +125 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/core.py +67 -0
  25. dao_ai/middleware/guardrails.py +420 -0
  26. dao_ai/middleware/human_in_the_loop.py +232 -0
  27. dao_ai/middleware/message_validation.py +586 -0
  28. dao_ai/middleware/summarization.py +197 -0
  29. dao_ai/models.py +1306 -114
  30. dao_ai/nodes.py +261 -166
  31. dao_ai/optimization.py +674 -0
  32. dao_ai/orchestration/__init__.py +52 -0
  33. dao_ai/orchestration/core.py +294 -0
  34. dao_ai/orchestration/supervisor.py +278 -0
  35. dao_ai/orchestration/swarm.py +271 -0
  36. dao_ai/prompts.py +128 -31
  37. dao_ai/providers/databricks.py +645 -172
  38. dao_ai/state.py +157 -21
  39. dao_ai/tools/__init__.py +13 -5
  40. dao_ai/tools/agent.py +1 -3
  41. dao_ai/tools/core.py +64 -11
  42. dao_ai/tools/email.py +232 -0
  43. dao_ai/tools/genie.py +144 -295
  44. dao_ai/tools/mcp.py +220 -133
  45. dao_ai/tools/memory.py +50 -0
  46. dao_ai/tools/python.py +9 -14
  47. dao_ai/tools/search.py +14 -0
  48. dao_ai/tools/slack.py +22 -10
  49. dao_ai/tools/sql.py +202 -0
  50. dao_ai/tools/time.py +30 -7
  51. dao_ai/tools/unity_catalog.py +165 -88
  52. dao_ai/tools/vector_search.py +360 -40
  53. dao_ai/utils.py +218 -16
  54. dao_ai-0.1.2.dist-info/METADATA +455 -0
  55. dao_ai-0.1.2.dist-info/RECORD +64 -0
  56. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +1 -1
  57. dao_ai/chat_models.py +0 -204
  58. dao_ai/guardrails.py +0 -112
  59. dao_ai/tools/human_in_the_loop.py +0 -100
  60. dao_ai-0.0.25.dist-info/METADATA +0 -1165
  61. dao_ai-0.0.25.dist-info/RECORD +0 -41
  62. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
  63. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/providers/databricks.py
@@ -1,6 +1,5 @@
  import base64
  import uuid
- from importlib.metadata import version
  from pathlib import Path
  from typing import Any, Callable, Final, Sequence

@@ -32,6 +31,7 @@ from mlflow import MlflowClient
  from mlflow.entities import Experiment
  from mlflow.entities.model_registry import PromptVersion
  from mlflow.entities.model_registry.model_version import ModelVersion
+ from mlflow.genai.prompts import load_prompt
  from mlflow.models.auth_policy import AuthPolicy, SystemAuthPolicy, UserAuthPolicy
  from mlflow.models.model import ModelInfo
  from mlflow.models.resources import (
@@ -46,6 +46,7 @@ from dao_ai.config import (
  AppConfig,
  ConnectionModel,
  DatabaseModel,
+ DatabricksAppModel,
  DatasetModel,
  FunctionModel,
  GenieRoomModel,
@@ -65,9 +66,11 @@ from dao_ai.config import (
  from dao_ai.models import get_latest_model_version
  from dao_ai.providers.base import ServiceProvider
  from dao_ai.utils import (
+ dao_ai_version,
  get_installed_packages,
  is_installed,
  is_lib_provided,
+ normalize_host,
  normalize_name,
  )
  from dao_ai.vector_search import endpoint_exists, index_exists
@@ -89,15 +92,18 @@ def _workspace_client(
  Create a WorkspaceClient instance with the provided parameters.
  If no parameters are provided, it will use the default configuration.
  """
- if client_id and client_secret and workspace_host:
+ # Normalize the workspace host to ensure it has https:// scheme
+ normalized_host = normalize_host(workspace_host)
+
+ if client_id and client_secret and normalized_host:
  return WorkspaceClient(
- host=workspace_host,
+ host=normalized_host,
  client_id=client_id,
  client_secret=client_secret,
  auth_type="oauth-m2m",
  )
  elif pat:
- return WorkspaceClient(host=workspace_host, token=pat, auth_type="pat")
+ return WorkspaceClient(host=normalized_host, token=pat, auth_type="pat")
  else:
  return WorkspaceClient()

@@ -112,15 +118,18 @@ def _vector_search_client(
  Create a VectorSearchClient instance with the provided parameters.
  If no parameters are provided, it will use the default configuration.
  """
- if client_id and client_secret and workspace_host:
+ # Normalize the workspace host to ensure it has https:// scheme
+ normalized_host = normalize_host(workspace_host)
+
+ if client_id and client_secret and normalized_host:
  return VectorSearchClient(
- workspace_url=workspace_host,
+ workspace_url=normalized_host,
  service_principal_client_id=client_id,
  service_principal_client_secret=client_secret,
  )
- elif pat and workspace_host:
+ elif pat and normalized_host:
  return VectorSearchClient(
- workspace_url=workspace_host,
+ workspace_url=normalized_host,
  personal_access_token=pat,
  )
  else:
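
Note: `normalize_host` comes from `dao_ai/utils.py` (+218 -16 in this release); its implementation is not shown in this diff. Going by the comment in both hunks, it presumably does something like this sketch (hypothetical, not the released code):

```python
# Hypothetical sketch of dao_ai.utils.normalize_host -- the real helper
# lives in dao_ai/utils.py and is not part of this diff.
def normalize_host(host: str | None) -> str | None:
    """Return the host with an https:// scheme, or None if unset."""
    if not host:
        return None
    host = host.strip().rstrip("/")
    if not host.startswith(("http://", "https://")):
        host = f"https://{host}"
    return host
```

Since both client factories now branch on `normalized_host`, a bare hostname such as `adb-123.azuredatabricks.net` no longer reaches `WorkspaceClient`/`VectorSearchClient` without a scheme.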
@@ -172,15 +181,17 @@ class DatabricksProvider(ServiceProvider):
  experiment: Experiment | None = mlflow.get_experiment_by_name(experiment_name)
  if experiment is None:
  experiment_id: str = mlflow.create_experiment(name=experiment_name)
- logger.info(
- f"Created new experiment: {experiment_name} (ID: {experiment_id})"
+ logger.success(
+ "Created new MLflow experiment",
+ experiment_name=experiment_name,
+ experiment_id=experiment_id,
  )
  experiment = mlflow.get_experiment(experiment_id)
  return experiment

  def create_token(self) -> str:
  current_user: User = self.w.current_user.me()
- logger.debug(f"Authenticated to Databricks as {current_user}")
+ logger.debug("Authenticated to Databricks", user=str(current_user))
  headers: dict[str, str] = self.w.config.authenticate()
  token: str = headers["Authorization"].replace("Bearer ", "")
  return token
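
Note: a pattern repeated throughout this file is the move from f-string messages to a fixed message plus keyword fields, along with the `success` and `trace` levels. That API shape matches loguru, and the new `dao_ai/logging.py` (+56 -0) suggests a central logging setup. A minimal sketch of how such calls behave, assuming loguru is the backend:

```python
# Minimal sketch, assuming loguru (the logger.success/logger.trace levels
# used in this diff match loguru's API; dao-ai's actual setup lives in
# the new dao_ai/logging.py, not shown here).
import sys
from loguru import logger

logger.remove()
logger.add(sys.stderr, level="TRACE", serialize=True)  # one JSON record per line

# With loguru, extra keyword arguments are attached to record["extra"],
# so downstream processors can filter on fields instead of parsing strings.
logger.success("Created new MLflow experiment", experiment_name="agents", experiment_id="123")
logger.trace("Retrieved secret", secret_key="api_key", secret_scope="dao")
```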
@@ -192,17 +203,24 @@ class DatabricksProvider(ServiceProvider):
  secret_response: GetSecretResponse = self.w.secrets.get_secret(
  secret_scope, secret_key
  )
- logger.debug(f"Retrieved secret {secret_key} from scope {secret_scope}")
+ logger.trace(
+ "Retrieved secret", secret_key=secret_key, secret_scope=secret_scope
+ )
  encoded_secret: str = secret_response.value
  decoded_secret: str = base64.b64decode(encoded_secret).decode("utf-8")
  return decoded_secret
  except NotFound:
  logger.warning(
- f"Secret {secret_key} not found in scope {secret_scope}, using default value"
+ "Secret not found, using default value",
+ secret_key=secret_key,
+ secret_scope=secret_scope,
  )
  except Exception as e:
  logger.error(
- f"Error retrieving secret {secret_key} from scope {secret_scope}: {e}"
+ "Error retrieving secret",
+ secret_key=secret_key,
+ secret_scope=secret_scope,
+ error=str(e),
  )

  return default_value
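
Note for the unchanged lines in this hunk: the Databricks Secrets API returns the payload base64-encoded, which is why the code decodes `secret_response.value` before returning it. The same pattern in isolation:

```python
# Sketch: reading a Databricks secret directly. Scope and key names are
# placeholders; GetSecretResponse.value is base64-encoded, matching the
# decode step in the hunk above.
import base64
from databricks.sdk import WorkspaceClient

w = WorkspaceClient()
resp = w.secrets.get_secret("my-scope", "my-key")
secret = base64.b64decode(resp.value).decode("utf-8")
```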
@@ -211,9 +229,18 @@ class DatabricksProvider(ServiceProvider):
  self,
  config: AppConfig,
  ) -> ModelInfo:
- logger.debug("Creating agent...")
+ logger.info("Creating agent")
  mlflow.set_registry_uri("databricks-uc")

+ # Set up experiment for proper tracking
+ experiment: Experiment = self.get_or_create_experiment(config)
+ mlflow.set_experiment(experiment_id=experiment.experiment_id)
+ logger.debug(
+ "Using MLflow experiment",
+ experiment_name=experiment.name,
+ experiment_id=experiment.experiment_id,
+ )
+
  llms: Sequence[LLMModel] = list(config.resources.llms.values())
  vector_indexes: Sequence[IndexModel] = list(
  config.resources.vector_stores.values()
@@ -231,6 +258,7 @@ class DatabricksProvider(ServiceProvider):
  )
  databases: Sequence[DatabaseModel] = list(config.resources.databases.values())
  volumes: Sequence[VolumeModel] = list(config.resources.volumes.values())
+ apps: Sequence[DatabricksAppModel] = list(config.resources.apps.values())

  resources: Sequence[IsDatabricksResource] = (
  llms
@@ -242,6 +270,7 @@ class DatabricksProvider(ServiceProvider):
  + connections
  + databases
  + volumes
+ + apps
  )

  # Flatten all resources from all models into a single list
@@ -255,12 +284,16 @@ class DatabricksProvider(ServiceProvider):
  for resource in r.as_resources()
  if not r.on_behalf_of_user
  ]
- logger.debug(f"system_resources: {[r.name for r in system_resources]}")
+ logger.trace(
+ "System resources identified",
+ count=len(system_resources),
+ resources=[r.name for r in system_resources],
+ )

  system_auth_policy: SystemAuthPolicy = SystemAuthPolicy(
  resources=system_resources
  )
- logger.debug(f"system_auth_policy: {system_auth_policy}")
+ logger.trace("System auth policy created", policy=str(system_auth_policy))

  api_scopes: Sequence[str] = list(
  set(
@@ -272,15 +305,19 @@ class DatabricksProvider(ServiceProvider):
  ]
  )
  )
- logger.debug(f"api_scopes: {api_scopes}")
+ logger.trace("API scopes identified", scopes=api_scopes)

  user_auth_policy: UserAuthPolicy = UserAuthPolicy(api_scopes=api_scopes)
- logger.debug(f"user_auth_policy: {user_auth_policy}")
+ logger.trace("User auth policy created", policy=str(user_auth_policy))

  auth_policy: AuthPolicy = AuthPolicy(
  system_auth_policy=system_auth_policy, user_auth_policy=user_auth_policy
  )
- logger.debug(f"auth_policy: {auth_policy}")
+ logger.debug(
+ "Auth policy created",
+ has_system_auth=system_auth_policy is not None,
+ has_user_auth=user_auth_policy is not None,
+ )

  code_paths: list[str] = config.app.code_paths
  for path in code_paths:
@@ -296,7 +333,7 @@ class DatabricksProvider(ServiceProvider):
  if is_installed():
  if not is_lib_provided("dao-ai", pip_requirements):
  pip_requirements += [
- f"dao-ai=={version('dao-ai')}",
+ f"dao-ai=={dao_ai_version()}",
  ]
  else:
  src_path: Path = model_root_path.parent
@@ -307,27 +344,52 @@ class DatabricksProvider(ServiceProvider):

  pip_requirements += get_installed_packages()

- logger.debug(f"pip_requirements: {pip_requirements}")
- logger.debug(f"code_paths: {code_paths}")
+ logger.trace("Pip requirements prepared", count=len(pip_requirements))
+ logger.trace("Code paths prepared", count=len(code_paths))

  run_name: str = normalize_name(config.app.name)
- logger.debug(f"run_name: {run_name}")
- logger.debug(f"model_path: {model_path.as_posix()}")
+ logger.debug(
+ "Agent run configuration",
+ run_name=run_name,
+ model_path=model_path.as_posix(),
+ )

  input_example: dict[str, Any] = None
  if config.app.input_example:
  input_example = config.app.input_example.model_dump()

- logger.debug(f"input_example: {input_example}")
+ logger.trace("Input example configured", has_example=input_example is not None)
+
+ # Create conda environment with configured Python version
+ # This allows deploying from environments with different Python versions
+ # (e.g., Databricks Apps with Python 3.11 can deploy to Model Serving with 3.12)
+ target_python_version: str = config.app.python_version
+ logger.debug("Target Python version configured", version=target_python_version)
+
+ conda_env: dict[str, Any] = {
+ "name": "mlflow-env",
+ "channels": ["conda-forge"],
+ "dependencies": [
+ f"python={target_python_version}",
+ "pip",
+ {"pip": list(pip_requirements)},
+ ],
+ }
+ logger.trace(
+ "Conda environment configured",
+ python_version=target_python_version,
+ pip_packages_count=len(pip_requirements),
+ )

  with mlflow.start_run(run_name=run_name):
  mlflow.set_tag("type", "agent")
+ mlflow.set_tag("dao_ai", dao_ai_version())
  logged_agent_info: ModelInfo = mlflow.pyfunc.log_model(
  python_model=model_path.as_posix(),
  code_paths=code_paths,
- model_config=config.model_dump(by_alias=True),
+ model_config=config.model_dump(mode="json", by_alias=True),
  name="agent",
- pip_requirements=pip_requirements,
+ conda_env=conda_env,
  input_example=input_example,
  # resources=all_resources,
  auth_policy=auth_policy,
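
Note: swapping `pip_requirements` for `conda_env` is what makes the cross-version comment above work: `pip_requirements` inherits the Python version of the environment doing the logging, while an explicit conda env pins its own. MLflow accepts at most one of `conda_env`, `pip_requirements`, or `extra_pip_requirements` on `log_model`, so the two parameters swap rather than stack. A minimal sketch of the same construction (requirement strings are placeholders):

```python
# Sketch: pinning the serving-side Python version via conda_env.
from typing import Any

def build_conda_env(python_version: str, pip_requirements: list[str]) -> dict[str, Any]:
    # Shape expected by mlflow.pyfunc.log_model(conda_env=...)
    return {
        "name": "mlflow-env",
        "channels": ["conda-forge"],
        "dependencies": [
            f"python={python_version}",
            "pip",
            {"pip": list(pip_requirements)},
        ],
    }

conda_env = build_conda_env("3.12", ["mlflow", "dao-ai==0.1.2"])
```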
@@ -338,15 +400,26 @@ class DatabricksProvider(ServiceProvider):
  model_version: ModelVersion = mlflow.register_model(
  name=registered_model_name, model_uri=logged_agent_info.model_uri
  )
- logger.debug(
- f"Registered model: {registered_model_name} with version: {model_version.version}"
+ logger.success(
+ "Model registered",
+ model_name=registered_model_name,
+ version=model_version.version,
  )

  client: MlflowClient = MlflowClient()

+ # Set tags on the model version
+ client.set_model_version_tag(
+ name=registered_model_name,
+ version=model_version.version,
+ key="dao_ai",
+ value=dao_ai_version(),
+ )
+ logger.trace("Set dao_ai tag on model version", version=model_version.version)
+
  client.set_registered_model_alias(
  name=registered_model_name,
- alias="Current",
+ alias="Champion",
  version=model_version.version,
  )

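Note: the default registered-model alias changes from `Current` to `Champion`, matching the champion/challenger convention the new prompt-registry code (below) also adopts. Aliases are resolved at load time, so anything still pinned to `Current` will presumably keep resolving to the last version that alias was set on. Consumer-side sketch (model name is a placeholder):

```python
# Sketch: resolving a Unity Catalog model by alias; "Champion" is the
# alias set in the hunk above, the model name is a placeholder.
import mlflow
from mlflow import MlflowClient

mlflow.set_registry_uri("databricks-uc")
client = MlflowClient()

mv = client.get_model_version_by_alias("main.agents.my_agent", "Champion")
model = mlflow.pyfunc.load_model("models:/main.agents.my_agent@Champion")
```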
@@ -359,12 +432,15 @@ class DatabricksProvider(ServiceProvider):
  aliased_model: ModelVersion = client.get_model_version_by_alias(
  registered_model_name, config.app.alias
  )
- logger.debug(
- f"Model {registered_model_name} aliased to {config.app.alias} with version: {aliased_model.version}"
+ logger.info(
+ "Model aliased",
+ model_name=registered_model_name,
+ alias=config.app.alias,
+ version=aliased_model.version,
  )

  def deploy_agent(self, config: AppConfig) -> None:
- logger.debug("Deploying agent...")
+ logger.info("Deploying agent", endpoint_name=config.app.endpoint_name)
  mlflow.set_registry_uri("databricks-uc")

  endpoint_name: str = config.app.endpoint_name
@@ -372,7 +448,10 @@ class DatabricksProvider(ServiceProvider):
  scale_to_zero: bool = config.app.scale_to_zero
  environment_vars: dict[str, str] = config.app.environment_vars
  workload_size: str = config.app.workload_size
- tags: dict[str, str] = config.app.tags
+ tags: dict[str, str] = config.app.tags.copy() if config.app.tags else {}
+
+ # Add dao_ai framework tag
+ tags["dao_ai"] = dao_ai_version()

  latest_version: int = get_latest_model_version(registered_model_name)

@@ -382,12 +461,10 @@ class DatabricksProvider(ServiceProvider):
  agents.get_deployments(endpoint_name)
  endpoint_exists = True
  logger.debug(
- f"Endpoint {endpoint_name} already exists, updating without tags to avoid conflicts..."
+ "Endpoint already exists, updating", endpoint_name=endpoint_name
  )
  except Exception:
- logger.debug(
- f"Endpoint {endpoint_name} doesn't exist, creating new with tags..."
- )
+ logger.debug("Creating new endpoint", endpoint_name=endpoint_name)

  # Deploy - skip tags for existing endpoints to avoid conflicts
  agents.deploy(
@@ -403,8 +480,11 @@ class DatabricksProvider(ServiceProvider):
  registered_model_name: str = config.app.registered_model.full_name
  permissions: Sequence[dict[str, Any]] = config.app.permissions

- logger.debug(registered_model_name)
- logger.debug(permissions)
+ logger.debug(
+ "Configuring model permissions",
+ model_name=registered_model_name,
+ permissions_count=len(permissions),
+ )

  for permission in permissions:
  principals: Sequence[str] = permission.principals
@@ -424,7 +504,7 @@ class DatabricksProvider(ServiceProvider):
  try:
  catalog_info = self.w.catalogs.get(name=schema.catalog_name)
  except NotFound:
- logger.debug(f"Creating catalog: {schema.catalog_name}")
+ logger.info("Creating catalog", catalog_name=schema.catalog_name)
  catalog_info = self.w.catalogs.create(name=schema.catalog_name)
  return catalog_info

@@ -434,7 +514,7 @@ class DatabricksProvider(ServiceProvider):
  try:
  schema_info = self.w.schemas.get(full_name=schema.full_name)
  except NotFound:
- logger.debug(f"Creating schema: {schema.full_name}")
+ logger.info("Creating schema", schema_name=schema.full_name)
  schema_info = self.w.schemas.create(
  name=schema.schema_name, catalog_name=catalog_info.name
  )
@@ -446,7 +526,7 @@ class DatabricksProvider(ServiceProvider):
  try:
  volume_info = self.w.volumes.read(name=volume.full_name)
  except NotFound:
- logger.debug(f"Creating volume: {volume.full_name}")
+ logger.info("Creating volume", volume_name=volume.full_name)
  volume_info = self.w.volumes.create(
  catalog_name=schema_info.catalog_name,
  schema_name=schema_info.name,
@@ -457,7 +537,7 @@ class DatabricksProvider(ServiceProvider):

  def create_path(self, volume_path: VolumePathModel) -> Path:
  path: Path = volume_path.full_name
- logger.info(f"Creating volume path: {path}")
+ logger.info("Creating volume path", path=str(path))
  self.w.files.create_directory(path)
  return path

@@ -498,11 +578,12 @@ class DatabricksProvider(ServiceProvider):

  if ddl:
  ddl_path: Path = Path(ddl)
- logger.debug(f"Executing DDL from: {ddl_path}")
+ logger.debug("Executing DDL", ddl_path=str(ddl_path))
  statements: Sequence[str] = sqlparse.parse(ddl_path.read_text())
  for statement in statements:
- logger.debug(statement)
- logger.debug(f"args: {args}")
+ logger.trace(
+ "Executing DDL statement", statement=str(statement)[:100], args=args
+ )
  spark.sql(
  str(statement),
  args=args,
@@ -511,20 +592,23 @@ class DatabricksProvider(ServiceProvider):
  if data:
  data_path: Path = Path(data)
  if format == "sql":
- logger.debug(f"Executing SQL from: {data_path}")
+ logger.debug("Executing SQL from file", data_path=str(data_path))
  data_statements: Sequence[str] = sqlparse.parse(data_path.read_text())
  for statement in data_statements:
- logger.debug(statement)
- logger.debug(f"args: {args}")
+ logger.trace(
+ "Executing SQL statement",
+ statement=str(statement)[:100],
+ args=args,
+ )
  spark.sql(
  str(statement),
  args=args,
  )
  else:
- logger.debug(f"Writing to: {table}")
+ logger.debug("Writing dataset to table", table=table)
  if not data_path.is_absolute():
  data_path = current_dir / data_path
- logger.debug(f"Data path: {data_path.as_posix()}")
+ logger.trace("Data path resolved", path=data_path.as_posix())
  if format == "excel":
  pdf = pd.read_excel(data_path.as_posix())
  df = spark.createDataFrame(pdf, schema=dataset.table_schema)
@@ -548,13 +632,17 @@ class DatabricksProvider(ServiceProvider):
  verbose=True,
  )

- logger.debug(f"Endpoint named {vector_store.endpoint.name} is ready.")
+ logger.success(
+ "Vector search endpoint ready", endpoint_name=vector_store.endpoint.name
+ )

  if not index_exists(
  self.vsc, vector_store.endpoint.name, vector_store.index.full_name
  ):
- logger.debug(
- f"Creating index {vector_store.index.full_name} on endpoint {vector_store.endpoint.name}..."
+ logger.info(
+ "Creating vector search index",
+ index_name=vector_store.index.full_name,
+ endpoint_name=vector_store.endpoint.name,
  )
  self.vsc.create_delta_sync_index_and_wait(
  endpoint_name=vector_store.endpoint.name,
@@ -568,7 +656,8 @@ class DatabricksProvider(ServiceProvider):
  )
  else:
  logger.debug(
- f"Index {vector_store.index.full_name} already exists, checking status and syncing..."
+ "Vector search index already exists, checking status",
+ index_name=vector_store.index.full_name,
  )
  index = self.vsc.get_index(
  vector_store.endpoint.name, vector_store.index.full_name
@@ -591,54 +680,61 @@ class DatabricksProvider(ServiceProvider):

  if pipeline_status in [
  "COMPLETED",
+ "ONLINE",
  "FAILED",
  "CANCELED",
  "ONLINE_PIPELINE_FAILED",
  ]:
- logger.debug(
- f"Index is ready to sync (status: {pipeline_status})"
- )
+ logger.debug("Index ready to sync", status=pipeline_status)
  break
  elif pipeline_status in [
  "WAITING_FOR_RESOURCES",
  "PROVISIONING",
  "INITIALIZING",
  "INDEXING",
- "ONLINE",
  ]:
- logger.debug(
- f"Index not ready yet (status: {pipeline_status}), waiting {wait_interval} seconds..."
+ logger.trace(
+ "Index not ready, waiting",
+ status=pipeline_status,
+ wait_seconds=wait_interval,
  )
  time.sleep(wait_interval)
  elapsed += wait_interval
  else:
  logger.warning(
- f"Unknown pipeline status: {pipeline_status}, attempting sync anyway"
+ "Unknown pipeline status, attempting sync",
+ status=pipeline_status,
  )
  break
  except Exception as status_error:
  logger.warning(
- f"Could not check index status: {status_error}, attempting sync anyway"
+ "Could not check index status, attempting sync",
+ error=str(status_error),
  )
  break

  if elapsed >= max_wait_time:
  logger.warning(
- f"Timed out waiting for index to be ready after {max_wait_time} seconds"
+ "Timed out waiting for index to be ready",
+ max_wait_seconds=max_wait_time,
  )

  # Now attempt to sync
  try:
  index.sync()
- logger.debug("Index sync completed successfully")
+ logger.success("Index sync completed")
  except Exception as sync_error:
  if "not ready to sync yet" in str(sync_error).lower():
- logger.warning(f"Index still not ready to sync: {sync_error}")
+ logger.warning(
+ "Index still not ready to sync", error=str(sync_error)
+ )
  else:
  raise sync_error

- logger.debug(
- f"index {vector_store.index.full_name} on table {vector_store.source_table.full_name} is ready"
+ logger.success(
+ "Vector search index ready",
+ index_name=vector_store.index.full_name,
+ source_table=vector_store.source_table.full_name,
  )

  def get_vector_index(self, vector_store: VectorStoreModel) -> None:
@@ -674,12 +770,16 @@ class DatabricksProvider(ServiceProvider):
  # sql = sql.replace("{catalog_name}", schema.catalog_name)
  # sql = sql.replace("{schema_name}", schema.schema_name)

- logger.info(function.name)
- logger.info(sql)
+ logger.info("Creating SQL function", function_name=function.name)
+ logger.trace("SQL function body", sql=sql[:200])
  _: FunctionInfo = self.dfs.create_function(sql_function_body=sql)

  if unity_catalog_function.test:
- logger.info(unity_catalog_function.test.parameters)
+ logger.debug(
+ "Testing function",
+ function_name=function.full_name,
+ parameters=unity_catalog_function.test.parameters,
+ )

  result: FunctionExecutionResult = self.dfs.execute_function(
  function_name=function.full_name,
@@ -687,37 +787,50 @@ class DatabricksProvider(ServiceProvider):
  )

  if result.error:
- logger.error(result.error)
+ logger.error(
+ "Function test failed",
+ function_name=function.full_name,
+ error=result.error,
+ )
  else:
- logger.info(f"Function {function.full_name} executed successfully.")
- logger.info(f"Result: {result}")
+ logger.success(
+ "Function test passed", function_name=function.full_name
+ )
+ logger.debug("Function test result", result=str(result))

  def find_columns(self, table_model: TableModel) -> Sequence[str]:
- logger.debug(f"Finding columns for table: {table_model.full_name}")
+ logger.trace("Finding columns for table", table=table_model.full_name)
  table_info: TableInfo = self.w.tables.get(full_name=table_model.full_name)
  columns: Sequence[ColumnInfo] = table_info.columns
  column_names: Sequence[str] = [c.name for c in columns]
- logger.debug(f"Columns found: {column_names}")
+ logger.debug(
+ "Columns found",
+ table=table_model.full_name,
+ columns_count=len(column_names),
+ )
  return column_names

  def find_primary_key(self, table_model: TableModel) -> Sequence[str] | None:
- logger.debug(f"Finding primary key for table: {table_model.full_name}")
+ logger.trace("Finding primary key for table", table=table_model.full_name)
  primary_keys: Sequence[str] | None = None
  table_info: TableInfo = self.w.tables.get(full_name=table_model.full_name)
  constraints: Sequence[TableConstraint] = table_info.table_constraints
  primary_key_constraint: PrimaryKeyConstraint | None = next(
- c.primary_key_constraint for c in constraints if c.primary_key_constraint
+ (c.primary_key_constraint for c in constraints if c.primary_key_constraint),
+ None,
  )
  if primary_key_constraint:
  primary_keys = primary_key_constraint.child_columns

- logger.debug(f"Primary key for table {table_model.full_name}: {primary_keys}")
+ logger.debug(
+ "Primary key found", table=table_model.full_name, primary_keys=primary_keys
+ )
  return primary_keys

  def find_vector_search_endpoint(
  self, predicate: Callable[[dict[str, Any]], bool]
  ) -> str | None:
- logger.debug("Finding vector search endpoint...")
+ logger.trace("Finding vector search endpoint")
  endpoint_name: str | None = None
  vector_search_endpoints: Sequence[dict[str, Any]] = (
  self.vsc.list_endpoints().get("endpoints", [])
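
Note: one genuine bug fix hides among the logging changes above. In `find_primary_key`, the old one-argument `next(...)` raised `StopIteration` for tables without a primary-key constraint; the two-argument form returns `None` instead. In miniature:

```python
# The failure mode fixed in find_primary_key, in isolation.
constraints: list[str | None] = [None, None]  # a table with no PK constraint

# Old form -- raises StopIteration when nothing matches:
#   next(c for c in constraints if c)

# New form -- falls back to a default:
pk = next((c for c in constraints if c), None)
assert pk is None
```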
@@ -726,11 +839,13 @@ class DatabricksProvider(ServiceProvider):
  if predicate(endpoint):
  endpoint_name = endpoint["name"]
  break
- logger.debug(f"Vector search endpoint found: {endpoint_name}")
+ logger.debug("Vector search endpoint found", endpoint_name=endpoint_name)
  return endpoint_name

  def find_endpoint_for_index(self, index_model: IndexModel) -> str | None:
- logger.debug(f"Finding vector search index: {index_model.full_name}")
+ logger.trace(
+ "Finding endpoint for vector search index", index_name=index_model.full_name
+ )
  all_endpoints: Sequence[dict[str, Any]] = self.vsc.list_endpoints().get(
  "endpoints", []
  )
@@ -740,14 +855,99 @@ class DatabricksProvider(ServiceProvider):
  endpoint_name: str = endpoint["name"]
  indexes = self.vsc.list_indexes(name=endpoint_name)
  vector_indexes: Sequence[dict[str, Any]] = indexes.get("vector_indexes", [])
- logger.trace(f"Endpoint: {endpoint_name}, vector_indexes: {vector_indexes}")
+ logger.trace(
+ "Checking endpoint for indexes",
+ endpoint_name=endpoint_name,
+ indexes_count=len(vector_indexes),
+ )
  index_names = [vector_index["name"] for vector_index in vector_indexes]
  if index_name in index_names:
  found_endpoint_name = endpoint_name
  break
- logger.debug(f"Vector search index found: {found_endpoint_name}")
+ logger.debug(
+ "Vector search index endpoint found",
+ index_name=index_model.full_name,
+ endpoint_name=found_endpoint_name,
+ )
  return found_endpoint_name

+ def _wait_for_database_available(
+ self,
+ workspace_client: WorkspaceClient,
+ instance_name: str,
+ max_wait_time: int = 600,
+ wait_interval: int = 10,
+ ) -> None:
+ """
+ Wait for a database instance to become AVAILABLE.
+
+ Args:
+ workspace_client: The Databricks workspace client
+ instance_name: Name of the database instance to wait for
+ max_wait_time: Maximum time to wait in seconds (default: 600 = 10 minutes)
+ wait_interval: Time between status checks in seconds (default: 10)
+
+ Raises:
+ TimeoutError: If the database doesn't become AVAILABLE within max_wait_time
+ RuntimeError: If the database enters a failed or deleted state
+ """
+ import time
+ from typing import Any
+
+ logger.info(
+ "Waiting for database instance to become AVAILABLE",
+ instance_name=instance_name,
+ )
+ elapsed: int = 0
+
+ while elapsed < max_wait_time:
+ try:
+ current_instance: Any = workspace_client.database.get_database_instance(
+ name=instance_name
+ )
+ current_state: str = current_instance.state
+ logger.trace(
+ "Database instance state checked",
+ instance_name=instance_name,
+ state=current_state,
+ )
+
+ if current_state == "AVAILABLE":
+ logger.success(
+ "Database instance is now AVAILABLE",
+ instance_name=instance_name,
+ )
+ return
+ elif current_state in ["STARTING", "UPDATING", "PROVISIONING"]:
+ logger.trace(
+ "Database instance not ready, waiting",
+ instance_name=instance_name,
+ state=current_state,
+ wait_seconds=wait_interval,
+ )
+ time.sleep(wait_interval)
+ elapsed += wait_interval
+ elif current_state in ["STOPPED", "DELETING", "FAILED"]:
+ raise RuntimeError(
+ f"Database instance {instance_name} entered unexpected state: {current_state}"
+ )
+ else:
+ logger.warning(
+ "Unknown database state, continuing to wait",
+ instance_name=instance_name,
+ state=current_state,
+ )
+ time.sleep(wait_interval)
+ elapsed += wait_interval
+ except NotFound:
+ raise RuntimeError(
+ f"Database instance {instance_name} was deleted while waiting for it to become AVAILABLE"
+ )
+
+ raise TimeoutError(
+ f"Timed out waiting for database instance {instance_name} to become AVAILABLE after {max_wait_time} seconds"
+ )
+
  def create_lakebase(self, database: DatabaseModel) -> None:
  """
  Create a Lakebase database instance using the Databricks workspace client.
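
Note: `_wait_for_database_available` lifts the wait loop that `create_lakebase` previously inlined into a reusable blocking helper that raises (`RuntimeError`/`TimeoutError`) instead of silently proceeding; the hunks below call it right after instance creation. The underlying pattern, distilled (names are illustrative, not dao-ai API):

```python
# Poll a fetch function until a readiness predicate holds or a deadline
# passes -- the same shape as _wait_for_database_available above.
import time
from typing import Callable, TypeVar

T = TypeVar("T")

def poll_until(
    fetch: Callable[[], T],
    ready: Callable[[T], bool],
    max_wait_time: int = 600,
    wait_interval: int = 10,
) -> T:
    elapsed = 0
    while elapsed < max_wait_time:
        current = fetch()
        if ready(current):
            return current
        time.sleep(wait_interval)
        elapsed += wait_interval
    raise TimeoutError(f"Not ready after {max_wait_time} seconds")
```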
@@ -778,13 +978,17 @@ class DatabricksProvider(ServiceProvider):

  if existing_instance:
  logger.debug(
- f"Database instance {database.instance_name} already exists with state: {existing_instance.state}"
+ "Database instance already exists",
+ instance_name=database.instance_name,
+ state=existing_instance.state,
  )

  # Check if database is in an intermediate state
  if existing_instance.state in ["STARTING", "UPDATING"]:
  logger.info(
- f"Database instance {database.instance_name} is in {existing_instance.state} state, waiting for it to become AVAILABLE..."
+ "Database instance in intermediate state, waiting",
+ instance_name=database.instance_name,
+ state=existing_instance.state,
  )

  # Wait for database to reach a stable state
@@ -800,65 +1004,87 @@ class DatabricksProvider(ServiceProvider):
  )
  )
  current_state: str = current_instance.state
- logger.debug(f"Database instance state: {current_state}")
+ logger.trace(
+ "Checking database instance state",
+ instance_name=database.instance_name,
+ state=current_state,
+ )

  if current_state == "AVAILABLE":
- logger.info(
- f"Database instance {database.instance_name} is now AVAILABLE"
+ logger.success(
+ "Database instance is now AVAILABLE",
+ instance_name=database.instance_name,
  )
  break
  elif current_state in ["STARTING", "UPDATING"]:
- logger.debug(
- f"Database instance still in {current_state} state, waiting {wait_interval} seconds..."
+ logger.trace(
+ "Database instance not ready, waiting",
+ instance_name=database.instance_name,
+ state=current_state,
+ wait_seconds=wait_interval,
  )
  time.sleep(wait_interval)
  elapsed += wait_interval
  elif current_state in ["STOPPED", "DELETING"]:
  logger.warning(
- f"Database instance {database.instance_name} is in unexpected state: {current_state}"
+ "Database instance in unexpected state",
+ instance_name=database.instance_name,
+ state=current_state,
  )
  break
  else:
  logger.warning(
- f"Unknown database state: {current_state}, proceeding anyway"
+ "Unknown database state, proceeding",
+ instance_name=database.instance_name,
+ state=current_state,
  )
  break
  except NotFound:
  logger.warning(
- f"Database instance {database.instance_name} no longer exists, will attempt to recreate"
+ "Database instance no longer exists, will recreate",
+ instance_name=database.instance_name,
  )
  break
  except Exception as state_error:
  logger.warning(
- f"Could not check database state: {state_error}, proceeding anyway"
+ "Could not check database state, proceeding",
+ instance_name=database.instance_name,
+ error=str(state_error),
  )
  break

  if elapsed >= max_wait_time:
  logger.warning(
- f"Timed out waiting for database instance {database.instance_name} to become AVAILABLE after {max_wait_time} seconds"
+ "Timed out waiting for database to become AVAILABLE",
+ instance_name=database.instance_name,
+ max_wait_seconds=max_wait_time,
  )

  elif existing_instance.state == "AVAILABLE":
  logger.info(
- f"Database instance {database.instance_name} already exists and is AVAILABLE"
+ "Database instance already exists and is AVAILABLE",
+ instance_name=database.instance_name,
  )
  return
  elif existing_instance.state in ["STOPPED", "DELETING"]:
  logger.warning(
- f"Database instance {database.instance_name} is in {existing_instance.state} state"
+ "Database instance in terminal state",
+ instance_name=database.instance_name,
+ state=existing_instance.state,
  )
  return
  else:
  logger.info(
- f"Database instance {database.instance_name} already exists with state: {existing_instance.state}"
+ "Database instance already exists",
+ instance_name=database.instance_name,
+ state=existing_instance.state,
  )
  return

  except NotFound:
  # Database doesn't exist, proceed with creation
- logger.debug(
- f"Database instance {database.instance_name} not found, creating new instance..."
+ logger.info(
+ "Creating new database instance", instance_name=database.instance_name
  )

  try:
@@ -878,10 +1104,17 @@ class DatabricksProvider(ServiceProvider):
  workspace_client.database.create_database_instance(
  database_instance=database_instance
  )
- logger.info(
- f"Successfully created database instance: {database.instance_name}"
+ logger.success(
+ "Database instance created successfully",
+ instance_name=database.instance_name,
  )

+ # Wait for the newly created database to become AVAILABLE
+ self._wait_for_database_available(
+ workspace_client, database.instance_name
+ )
+ return
+
  except Exception as create_error:
  error_msg: str = str(create_error)

@@ -891,13 +1124,20 @@ class DatabricksProvider(ServiceProvider):
  or "RESOURCE_ALREADY_EXISTS" in error_msg
  ):
  logger.info(
- f"Database instance {database.instance_name} was created concurrently by another process"
+ "Database instance was created concurrently",
+ instance_name=database.instance_name,
+ )
+ # Still need to wait for the database to become AVAILABLE
+ self._wait_for_database_available(
+ workspace_client, database.instance_name
  )
  return
  else:
  # Re-raise unexpected errors
  logger.error(
- f"Error creating database instance {database.instance_name}: {create_error}"
+ "Error creating database instance",
+ instance_name=database.instance_name,
+ error=str(create_error),
  )
  raise

@@ -911,12 +1151,15 @@ class DatabricksProvider(ServiceProvider):
  or "RESOURCE_ALREADY_EXISTS" in error_msg
  ):
  logger.info(
- f"Database instance {database.instance_name} already exists (detected via exception)"
+ "Database instance already exists (detected via exception)",
+ instance_name=database.instance_name,
  )
  return
  else:
  logger.error(
- f"Unexpected error while handling database {database.instance_name}: {e}"
+ "Unexpected error while handling database",
+ instance_name=database.instance_name,
+ error=str(e),
  )
  raise

@@ -924,7 +1167,9 @@ class DatabricksProvider(ServiceProvider):
  """
  Ask Databricks to mint a fresh DB credential for this instance.
  """
- logger.debug(f"Generating password for lakebase instance: {instance_name}")
+ logger.trace(
+ "Generating password for lakebase instance", instance_name=instance_name
+ )
  w: WorkspaceClient = self.w
  cred: DatabaseCredential = w.database.generate_database_credential(
  request_id=str(uuid.uuid4()),
@@ -960,7 +1205,8 @@ class DatabricksProvider(ServiceProvider):
  # Validate that client_id is provided
  if not database.client_id:
  logger.warning(
- f"client_id is required to create instance role for database {database.instance_name}"
+ "client_id required to create instance role",
+ instance_name=database.instance_name,
  )
  return

@@ -970,7 +1216,10 @@ class DatabricksProvider(ServiceProvider):
  instance_name: str = database.instance_name

  logger.debug(
- f"Creating instance role '{role_name}' for database {instance_name} with principal {client_id}"
+ "Creating instance role",
+ role_name=role_name,
+ instance_name=instance_name,
+ principal=client_id,
  )

  try:
@@ -981,13 +1230,15 @@ class DatabricksProvider(ServiceProvider):
  name=role_name,
  )
  logger.info(
- f"Instance role '{role_name}' already exists for database {instance_name}"
+ "Instance role already exists",
+ role_name=role_name,
+ instance_name=instance_name,
  )
  return
  except NotFound:
  # Role doesn't exist, proceed with creation
  logger.debug(
- f"Instance role '{role_name}' not found, creating new role..."
+ "Instance role not found, creating new role", role_name=role_name
  )

  # Create the database instance role
@@ -1003,8 +1254,10 @@ class DatabricksProvider(ServiceProvider):
  database_instance_role=role,
  )

- logger.info(
- f"Successfully created instance role '{role_name}' for database {instance_name}"
+ logger.success(
+ "Instance role created successfully",
+ role_name=role_name,
+ instance_name=instance_name,
  )

  except Exception as e:
@@ -1016,88 +1269,308 @@ class DatabricksProvider(ServiceProvider):
  or "RESOURCE_ALREADY_EXISTS" in error_msg
  ):
  logger.info(
- f"Instance role '{role_name}' was created concurrently for database {instance_name}"
+ "Instance role was created concurrently",
+ role_name=role_name,
+ instance_name=instance_name,
  )
  return

  # Re-raise unexpected errors
  logger.error(
- f"Error creating instance role '{role_name}' for database {instance_name}: {e}"
+ "Error creating instance role",
+ role_name=role_name,
+ instance_name=instance_name,
+ error=str(e),
  )
  raise

- def get_prompt(self, prompt_model: PromptModel) -> str:
- """Load prompt from MLflow Prompt Registry or fall back to default_template."""
- prompt_name: str = prompt_model.full_name
+ def get_prompt(self, prompt_model: PromptModel) -> PromptVersion:
+ """
+ Load prompt from MLflow Prompt Registry with fallback logic.
+
+ If an explicit version or alias is specified in the prompt_model, uses that directly.
+ Otherwise, tries to load prompts in this order:
+ 1. champion alias
+ 2. latest alias
+ 3. default alias
+ 4. Register default_template if provided (only if register_to_registry=True)
+ 5. Use default_template directly (fallback)
+
+ The auto_register field controls whether the default_template is automatically
+ synced to the prompt registry:
+ - If True (default): Auto-registers/updates the default_template in the registry
+ - If False: Never registers, but can still load existing prompts from registry
+ or use default_template directly as a local-only prompt

- # Build prompt URI based on alias, version, or default to latest
- if prompt_model.alias:
- prompt_uri = f"prompts:/{prompt_name}@{prompt_model.alias}"
- elif prompt_model.version:
- prompt_uri = f"prompts:/{prompt_name}/{prompt_model.version}"
- else:
- prompt_uri = f"prompts:/{prompt_name}@latest"
+ Args:
+ prompt_model: The prompt model configuration

- try:
- from mlflow.genai.prompts import Prompt
+ Returns:
+ PromptVersion: The loaded prompt version

- prompt_obj: Prompt = mlflow.genai.load_prompt(prompt_uri)
- return prompt_obj.to_single_brace_format()
+ Raises:
+ ValueError: If no prompt can be loaded from any source
+ """

- except Exception as e:
- logger.warning(f"Failed to load prompt '{prompt_name}' from registry: {e}")
+ prompt_name: str = prompt_model.full_name

- if prompt_model.default_template:
- logger.info(f"Using default_template for '{prompt_name}'")
- self._sync_default_template_to_registry(
- prompt_name, prompt_model.default_template, prompt_model.description
+ # If explicit version or alias is specified, use it directly
+ if prompt_model.version or prompt_model.alias:
+ try:
+ prompt_version: PromptVersion = prompt_model.as_prompt()
+ version_or_alias = (
+ f"version {prompt_model.version}"
+ if prompt_model.version
+ else f"alias {prompt_model.alias}"
  )
- return prompt_model.default_template
+ logger.debug(
+ "Loaded prompt with explicit version/alias",
+ prompt_name=prompt_name,
+ version_or_alias=version_or_alias,
+ )
+ return prompt_version
+ except Exception as e:
+ version_or_alias = (
+ f"version {prompt_model.version}"
+ if prompt_model.version
+ else f"alias {prompt_model.alias}"
+ )
+ logger.warning(
+ "Failed to load prompt with explicit version/alias",
+ prompt_name=prompt_name,
+ version_or_alias=version_or_alias,
+ error=str(e),
+ )
+ # Fall through to try other methods

- raise ValueError(
- f"Prompt '{prompt_name}' not found in registry and no default_template provided"
- ) from e
+ # Try to load in priority order: champion → default (with sync check)
+ logger.trace(
+ "Trying prompt fallback order",
+ prompt_name=prompt_name,
+ order="champion → default",
+ )

- def _sync_default_template_to_registry(
- self, prompt_name: str, default_template: str, description: str | None = None
- ) -> None:
- """Register default_template to prompt registry under 'default' alias if changed."""
- try:
- # Check if default alias already has the same template
+ # First, sync default alias if template has changed (even if champion exists)
+ # Only do this if auto_register is True
+ if prompt_model.default_template and prompt_model.auto_register:
  try:
- logger.debug(f"Loading prompt '{prompt_name}' from registry...")
- existing: PromptVersion = mlflow.genai.load_prompt(
- f"prompts:/{prompt_name}@default"
- )
+ # Try to load existing default
+ existing_default = load_prompt(f"prompts:/{prompt_name}@default")
+
+ # Check if champion exists and if it matches default
+ champion_matches_default = False
+ try:
+ existing_champion = load_prompt(f"prompts:/{prompt_name}@champion")
+ champion_matches_default = (
+ existing_champion.version == existing_default.version
+ )
+ status = (
+ "tracking" if champion_matches_default else "pinned separately"
+ )
+ logger.trace(
+ "Champion vs default version",
+ prompt_name=prompt_name,
+ champion_version=existing_champion.version,
+ default_version=existing_default.version,
+ status=status,
+ )
+ except Exception:
+ # No champion exists
+ logger.trace("No champion alias found", prompt_name=prompt_name)
+
+ # Check if default_template differs from existing default
  if (
- existing.to_single_brace_format().strip()
- == default_template.strip()
+ existing_default.template.strip()
+ != prompt_model.default_template.strip()
  ):
- logger.debug(f"Prompt '{prompt_name}' is already up-to-date")
- return # Already up-to-date
- except Exception:
- logger.debug(
- f"Default alias for prompt '{prompt_name}' doesn't exist yet"
+ logger.info(
+ "Default template changed, registering new version",
+ prompt_name=prompt_name,
+ )
+
+ # Only update champion if it was pointing to the old default
+ if champion_matches_default:
+ logger.info(
+ "Champion was tracking default, will update to new version",
+ prompt_name=prompt_name,
+ old_version=existing_default.version,
+ )
+ set_champion = True
+ else:
+ logger.info(
+ "Champion is pinned separately, preserving it",
+ prompt_name=prompt_name,
+ )
+ set_champion = False
+
+ self._register_default_template(
+ prompt_name,
+ prompt_model.default_template,
+ prompt_model.description,
+ set_champion=set_champion,
+ )
+ except Exception as e:
+ # No default exists yet, register it
+ logger.trace(
+ "No default alias found", prompt_name=prompt_name, error=str(e)
+ )
+ logger.info(
+ "Registering default template as default alias",
+ prompt_name=prompt_name,
+ )
+ # First registration - set both default and champion
+ self._register_default_template(
+ prompt_name,
+ prompt_model.default_template,
+ prompt_model.description,
+ set_champion=True,
+ )
+ elif prompt_model.default_template and not prompt_model.auto_register:
+ logger.trace(
+ "Prompt has auto_register=False, skipping registration",
+ prompt_name=prompt_name,
+ )
+
+ # 1. Try champion alias (highest priority for execution)
+ try:
+ prompt_version = load_prompt(f"prompts:/{prompt_name}@champion")
+ logger.info("Loaded prompt from champion alias", prompt_name=prompt_name)
+ return prompt_version
+ except Exception as e:
+ logger.trace(
+ "Champion alias not found", prompt_name=prompt_name, error=str(e)
+ )
+
+ # 2. Try default alias (already synced above)
+ if prompt_model.default_template:
+ try:
+ prompt_version = load_prompt(f"prompts:/{prompt_name}@default")
+ logger.info("Loaded prompt from default alias", prompt_name=prompt_name)
+ return prompt_version
+ except Exception as e:
+ # Should not happen since we just registered it above, but handle anyway
+ logger.trace(
+ "Default alias not found", prompt_name=prompt_name, error=str(e)
  )

- # Register new version and set as default alias
+ # 3. Try latest alias as final fallback
+ try:
+ prompt_version = load_prompt(f"prompts:/{prompt_name}@latest")
+ logger.info("Loaded prompt from latest alias", prompt_name=prompt_name)
+ return prompt_version
+ except Exception as e:
+ logger.trace(
+ "Latest alias not found", prompt_name=prompt_name, error=str(e)
+ )
+
+ # 4. Final fallback: use default_template directly if available
+ if prompt_model.default_template:
+ logger.warning(
+ "Could not load prompt from registry, using default_template directly",
+ prompt_name=prompt_name,
+ )
+ return PromptVersion(
+ name=prompt_name,
+ version=1,
+ template=prompt_model.default_template,
+ tags={"dao_ai": dao_ai_version()},
+ )
+
+ raise ValueError(
+ f"Prompt '{prompt_name}' not found in registry "
+ "(tried champion, default, latest aliases) "
+ "and no default_template provided"
+ )
+
+ def _register_default_template(
+ self,
+ prompt_name: str,
+ default_template: str,
+ description: str | None = None,
+ set_champion: bool = True,
+ ) -> PromptVersion:
+ """Register default_template as a new prompt version.
+
+ Registers the template and sets the 'default' alias.
+ Optionally sets 'champion' alias if no champion exists.
+
+ Args:
+ prompt_name: Full name of the prompt
+ default_template: The template content
+ description: Optional description for commit message
+ set_champion: Whether to also set champion alias (default: True)
+
+ If registration fails (e.g., in Model Serving with restricted permissions),
+ logs the error and raises.
+ """
+ logger.info(
+ "Registering default template",
+ prompt_name=prompt_name,
+ set_champion=set_champion,
+ )
+
+ try:
  commit_message = description or "Auto-synced from default_template"
- prompt_version: PromptVersion = mlflow.genai.register_prompt(
+ prompt_version = mlflow.genai.register_prompt(
  name=prompt_name,
  template=default_template,
  commit_message=commit_message,
+ tags={"dao_ai": dao_ai_version()},
  )

- logger.debug(f"Setting default alias for prompt '{prompt_name}'")
- mlflow.genai.set_prompt_alias(
- name=prompt_name,
- alias="default",
- version=prompt_version.version,
- )
+ # Always set default alias
+ try:
+ logger.debug(
+ "Setting default alias",
+ prompt_name=prompt_name,
+ version=prompt_version.version,
+ )
+ mlflow.genai.set_prompt_alias(
+ name=prompt_name, alias="default", version=prompt_version.version
+ )
+ logger.success(
+ "Set default alias for prompt",
+ prompt_name=prompt_name,
+ version=prompt_version.version,
+ )
+ except Exception as alias_error:
+ logger.warning(
+ "Could not set default alias",
+ prompt_name=prompt_name,
+ error=str(alias_error),
+ )

- logger.info(
- f"Synced prompt '{prompt_name}' v{prompt_version.version} to registry"
- )
+ # Optionally set champion alias (only if no champion exists or explicitly requested)
+ if set_champion:
+ try:
+ mlflow.genai.set_prompt_alias(
+ name=prompt_name,
+ alias="champion",
+ version=prompt_version.version,
+ )
+ logger.success(
+ "Set champion alias for prompt",
+ prompt_name=prompt_name,
+ version=prompt_version.version,
+ )
+ except Exception as alias_error:
+ logger.warning(
+ "Could not set champion alias",
+ prompt_name=prompt_name,
+ error=str(alias_error),
+ )

- except Exception as e:
- logger.warning(f"Failed to sync '{prompt_name}' to registry: {e}")
+ return prompt_version
+
+ except Exception as reg_error:
+ logger.error(
+ "Failed to register prompt - please register from notebook with write permissions",
+ prompt_name=prompt_name,
+ error=str(reg_error),
+ )
+ return PromptVersion(
+ name=prompt_name,
+ version=1,
+ template=default_template,
+ tags={"dao_ai": dao_ai_version()},
+ )
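
Note: the net effect of the rewritten `get_prompt` is alias-based resolution: `champion`, then `default`, then `latest`, with `default_template` auto-registered when `auto_register` is true and used as an unregistered `PromptVersion` as the last resort. (The docstring lists the order as champion → latest → default and mentions `register_to_registry`, but the implementation tries champion → default → latest and gates on `auto_register`.) A consumer-side sketch of the same fallback order, using the `load_prompt` import added at the top of this file (prompt name is a placeholder):

```python
# Sketch: alias-fallback prompt resolution mirroring get_prompt above.
from mlflow.genai.prompts import load_prompt

def resolve_prompt(name: str):
    for alias in ("champion", "default", "latest"):
        try:
            return load_prompt(f"prompts:/{name}@{alias}")
        except Exception:
            continue
    raise ValueError(f"Prompt '{name}' not found under any known alias")

prompt = resolve_prompt("main.agents.system_prompt")
```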