dao-ai 0.0.28__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +2 -5
- dao_ai/cli.py +342 -58
- dao_ai/config.py +1610 -380
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +158 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/context_editing.py +230 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +233 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/model_call_limit.py +77 -0
- dao_ai/middleware/model_retry.py +121 -0
- dao_ai/middleware/pii.py +157 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/middleware/tool_call_limit.py +210 -0
- dao_ai/middleware/tool_retry.py +174 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +240 -161
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +279 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +584 -601
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -294
- dao_ai/tools/mcp.py +223 -155
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +331 -221
- dao_ai/utils.py +166 -20
- dao_ai/vector_search.py +37 -0
- dao_ai-0.1.5.dist-info/METADATA +489 -0
- dao_ai-0.1.5.dist-info/RECORD +70 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.28.dist-info/METADATA +0 -1168
- dao_ai-0.0.28.dist-info/RECORD +0 -41
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/licenses/LICENSE +0 -0
dao_ai/providers/databricks.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import base64
|
|
2
|
-
import re
|
|
3
2
|
import uuid
|
|
4
3
|
from pathlib import Path
|
|
5
4
|
from typing import Any, Callable, Final, Sequence
|
|
@@ -32,14 +31,12 @@ from mlflow import MlflowClient
|
|
|
32
31
|
from mlflow.entities import Experiment
|
|
33
32
|
from mlflow.entities.model_registry import PromptVersion
|
|
34
33
|
from mlflow.entities.model_registry.model_version import ModelVersion
|
|
35
|
-
from mlflow.genai.datasets import EvaluationDataset, get_dataset
|
|
36
34
|
from mlflow.genai.prompts import load_prompt
|
|
37
35
|
from mlflow.models.auth_policy import AuthPolicy, SystemAuthPolicy, UserAuthPolicy
|
|
38
36
|
from mlflow.models.model import ModelInfo
|
|
39
37
|
from mlflow.models.resources import (
|
|
40
38
|
DatabricksResource,
|
|
41
39
|
)
|
|
42
|
-
from mlflow.pyfunc import ResponsesAgent
|
|
43
40
|
from pyspark.sql import SparkSession
|
|
44
41
|
from unitycatalog.ai.core.base import FunctionExecutionResult
|
|
45
42
|
from unitycatalog.ai.core.databricks import DatabricksFunctionClient
|
|
@@ -49,6 +46,7 @@ from dao_ai.config import (
|
|
|
49
46
|
AppConfig,
|
|
50
47
|
ConnectionModel,
|
|
51
48
|
DatabaseModel,
|
|
49
|
+
DatabricksAppModel,
|
|
52
50
|
DatasetModel,
|
|
53
51
|
FunctionModel,
|
|
54
52
|
GenieRoomModel,
|
|
@@ -57,7 +55,6 @@ from dao_ai.config import (
|
|
|
57
55
|
IsDatabricksResource,
|
|
58
56
|
LLMModel,
|
|
59
57
|
PromptModel,
|
|
60
|
-
PromptOptimizationModel,
|
|
61
58
|
SchemaModel,
|
|
62
59
|
TableModel,
|
|
63
60
|
UnityCatalogFunctionSqlModel,
|
|
@@ -73,6 +70,7 @@ from dao_ai.utils import (
|
|
|
73
70
|
get_installed_packages,
|
|
74
71
|
is_installed,
|
|
75
72
|
is_lib_provided,
|
|
73
|
+
normalize_host,
|
|
76
74
|
normalize_name,
|
|
77
75
|
)
|
|
78
76
|
from dao_ai.vector_search import endpoint_exists, index_exists
|
|
@@ -94,15 +92,18 @@ def _workspace_client(
|
|
|
94
92
|
Create a WorkspaceClient instance with the provided parameters.
|
|
95
93
|
If no parameters are provided, it will use the default configuration.
|
|
96
94
|
"""
|
|
97
|
-
|
|
95
|
+
# Normalize the workspace host to ensure it has https:// scheme
|
|
96
|
+
normalized_host = normalize_host(workspace_host)
|
|
97
|
+
|
|
98
|
+
if client_id and client_secret and normalized_host:
|
|
98
99
|
return WorkspaceClient(
|
|
99
|
-
host=
|
|
100
|
+
host=normalized_host,
|
|
100
101
|
client_id=client_id,
|
|
101
102
|
client_secret=client_secret,
|
|
102
103
|
auth_type="oauth-m2m",
|
|
103
104
|
)
|
|
104
105
|
elif pat:
|
|
105
|
-
return WorkspaceClient(host=
|
|
106
|
+
return WorkspaceClient(host=normalized_host, token=pat, auth_type="pat")
|
|
106
107
|
else:
|
|
107
108
|
return WorkspaceClient()
|
|
108
109
|
|
|
@@ -117,15 +118,18 @@ def _vector_search_client(
|
|
|
117
118
|
Create a VectorSearchClient instance with the provided parameters.
|
|
118
119
|
If no parameters are provided, it will use the default configuration.
|
|
119
120
|
"""
|
|
120
|
-
|
|
121
|
+
# Normalize the workspace host to ensure it has https:// scheme
|
|
122
|
+
normalized_host = normalize_host(workspace_host)
|
|
123
|
+
|
|
124
|
+
if client_id and client_secret and normalized_host:
|
|
121
125
|
return VectorSearchClient(
|
|
122
|
-
workspace_url=
|
|
126
|
+
workspace_url=normalized_host,
|
|
123
127
|
service_principal_client_id=client_id,
|
|
124
128
|
service_principal_client_secret=client_secret,
|
|
125
129
|
)
|
|
126
|
-
elif pat and
|
|
130
|
+
elif pat and normalized_host:
|
|
127
131
|
return VectorSearchClient(
|
|
128
|
-
workspace_url=
|
|
132
|
+
workspace_url=normalized_host,
|
|
129
133
|
personal_access_token=pat,
|
|
130
134
|
)
|
|
131
135
|
else:
|
|
@@ -177,15 +181,17 @@ class DatabricksProvider(ServiceProvider):
|
|
|
177
181
|
experiment: Experiment | None = mlflow.get_experiment_by_name(experiment_name)
|
|
178
182
|
if experiment is None:
|
|
179
183
|
experiment_id: str = mlflow.create_experiment(name=experiment_name)
|
|
180
|
-
logger.
|
|
181
|
-
|
|
184
|
+
logger.success(
|
|
185
|
+
"Created new MLflow experiment",
|
|
186
|
+
experiment_name=experiment_name,
|
|
187
|
+
experiment_id=experiment_id,
|
|
182
188
|
)
|
|
183
189
|
experiment = mlflow.get_experiment(experiment_id)
|
|
184
190
|
return experiment
|
|
185
191
|
|
|
186
192
|
def create_token(self) -> str:
|
|
187
193
|
current_user: User = self.w.current_user.me()
|
|
188
|
-
logger.debug(
|
|
194
|
+
logger.debug("Authenticated to Databricks", user=str(current_user))
|
|
189
195
|
headers: dict[str, str] = self.w.config.authenticate()
|
|
190
196
|
token: str = headers["Authorization"].replace("Bearer ", "")
|
|
191
197
|
return token
|
|
@@ -197,17 +203,24 @@ class DatabricksProvider(ServiceProvider):
|
|
|
197
203
|
secret_response: GetSecretResponse = self.w.secrets.get_secret(
|
|
198
204
|
secret_scope, secret_key
|
|
199
205
|
)
|
|
200
|
-
logger.
|
|
206
|
+
logger.trace(
|
|
207
|
+
"Retrieved secret", secret_key=secret_key, secret_scope=secret_scope
|
|
208
|
+
)
|
|
201
209
|
encoded_secret: str = secret_response.value
|
|
202
210
|
decoded_secret: str = base64.b64decode(encoded_secret).decode("utf-8")
|
|
203
211
|
return decoded_secret
|
|
204
212
|
except NotFound:
|
|
205
213
|
logger.warning(
|
|
206
|
-
|
|
214
|
+
"Secret not found, using default value",
|
|
215
|
+
secret_key=secret_key,
|
|
216
|
+
secret_scope=secret_scope,
|
|
207
217
|
)
|
|
208
218
|
except Exception as e:
|
|
209
219
|
logger.error(
|
|
210
|
-
|
|
220
|
+
"Error retrieving secret",
|
|
221
|
+
secret_key=secret_key,
|
|
222
|
+
secret_scope=secret_scope,
|
|
223
|
+
error=str(e),
|
|
211
224
|
)
|
|
212
225
|
|
|
213
226
|
return default_value
|
|
@@ -216,9 +229,18 @@ class DatabricksProvider(ServiceProvider):
|
|
|
216
229
|
self,
|
|
217
230
|
config: AppConfig,
|
|
218
231
|
) -> ModelInfo:
|
|
219
|
-
logger.
|
|
232
|
+
logger.info("Creating agent")
|
|
220
233
|
mlflow.set_registry_uri("databricks-uc")
|
|
221
234
|
|
|
235
|
+
# Set up experiment for proper tracking
|
|
236
|
+
experiment: Experiment = self.get_or_create_experiment(config)
|
|
237
|
+
mlflow.set_experiment(experiment_id=experiment.experiment_id)
|
|
238
|
+
logger.debug(
|
|
239
|
+
"Using MLflow experiment",
|
|
240
|
+
experiment_name=experiment.name,
|
|
241
|
+
experiment_id=experiment.experiment_id,
|
|
242
|
+
)
|
|
243
|
+
|
|
222
244
|
llms: Sequence[LLMModel] = list(config.resources.llms.values())
|
|
223
245
|
vector_indexes: Sequence[IndexModel] = list(
|
|
224
246
|
config.resources.vector_stores.values()
|
|
@@ -236,6 +258,7 @@ class DatabricksProvider(ServiceProvider):
|
|
|
236
258
|
)
|
|
237
259
|
databases: Sequence[DatabaseModel] = list(config.resources.databases.values())
|
|
238
260
|
volumes: Sequence[VolumeModel] = list(config.resources.volumes.values())
|
|
261
|
+
apps: Sequence[DatabricksAppModel] = list(config.resources.apps.values())
|
|
239
262
|
|
|
240
263
|
resources: Sequence[IsDatabricksResource] = (
|
|
241
264
|
llms
|
|
@@ -247,6 +270,7 @@ class DatabricksProvider(ServiceProvider):
|
|
|
247
270
|
+ connections
|
|
248
271
|
+ databases
|
|
249
272
|
+ volumes
|
|
273
|
+
+ apps
|
|
250
274
|
)
|
|
251
275
|
|
|
252
276
|
# Flatten all resources from all models into a single list
|
|
@@ -260,12 +284,16 @@ class DatabricksProvider(ServiceProvider):
|
|
|
260
284
|
for resource in r.as_resources()
|
|
261
285
|
if not r.on_behalf_of_user
|
|
262
286
|
]
|
|
263
|
-
logger.
|
|
287
|
+
logger.trace(
|
|
288
|
+
"System resources identified",
|
|
289
|
+
count=len(system_resources),
|
|
290
|
+
resources=[r.name for r in system_resources],
|
|
291
|
+
)
|
|
264
292
|
|
|
265
293
|
system_auth_policy: SystemAuthPolicy = SystemAuthPolicy(
|
|
266
294
|
resources=system_resources
|
|
267
295
|
)
|
|
268
|
-
logger.
|
|
296
|
+
logger.trace("System auth policy created", policy=str(system_auth_policy))
|
|
269
297
|
|
|
270
298
|
api_scopes: Sequence[str] = list(
|
|
271
299
|
set(
|
|
@@ -277,15 +305,19 @@ class DatabricksProvider(ServiceProvider):
|
|
|
277
305
|
]
|
|
278
306
|
)
|
|
279
307
|
)
|
|
280
|
-
logger.
|
|
308
|
+
logger.trace("API scopes identified", scopes=api_scopes)
|
|
281
309
|
|
|
282
310
|
user_auth_policy: UserAuthPolicy = UserAuthPolicy(api_scopes=api_scopes)
|
|
283
|
-
logger.
|
|
311
|
+
logger.trace("User auth policy created", policy=str(user_auth_policy))
|
|
284
312
|
|
|
285
313
|
auth_policy: AuthPolicy = AuthPolicy(
|
|
286
314
|
system_auth_policy=system_auth_policy, user_auth_policy=user_auth_policy
|
|
287
315
|
)
|
|
288
|
-
logger.debug(
|
|
316
|
+
logger.debug(
|
|
317
|
+
"Auth policy created",
|
|
318
|
+
has_system_auth=system_auth_policy is not None,
|
|
319
|
+
has_user_auth=user_auth_policy is not None,
|
|
320
|
+
)
|
|
289
321
|
|
|
290
322
|
code_paths: list[str] = config.app.code_paths
|
|
291
323
|
for path in code_paths:
|
|
@@ -312,18 +344,42 @@ class DatabricksProvider(ServiceProvider):
|
|
|
312
344
|
|
|
313
345
|
pip_requirements += get_installed_packages()
|
|
314
346
|
|
|
315
|
-
logger.
|
|
316
|
-
logger.
|
|
347
|
+
logger.trace("Pip requirements prepared", count=len(pip_requirements))
|
|
348
|
+
logger.trace("Code paths prepared", count=len(code_paths))
|
|
317
349
|
|
|
318
350
|
run_name: str = normalize_name(config.app.name)
|
|
319
|
-
logger.debug(
|
|
320
|
-
|
|
351
|
+
logger.debug(
|
|
352
|
+
"Agent run configuration",
|
|
353
|
+
run_name=run_name,
|
|
354
|
+
model_path=model_path.as_posix(),
|
|
355
|
+
)
|
|
321
356
|
|
|
322
357
|
input_example: dict[str, Any] = None
|
|
323
358
|
if config.app.input_example:
|
|
324
359
|
input_example = config.app.input_example.model_dump()
|
|
325
360
|
|
|
326
|
-
logger.
|
|
361
|
+
logger.trace("Input example configured", has_example=input_example is not None)
|
|
362
|
+
|
|
363
|
+
# Create conda environment with configured Python version
|
|
364
|
+
# This allows deploying from environments with different Python versions
|
|
365
|
+
# (e.g., Databricks Apps with Python 3.11 can deploy to Model Serving with 3.12)
|
|
366
|
+
target_python_version: str = config.app.python_version
|
|
367
|
+
logger.debug("Target Python version configured", version=target_python_version)
|
|
368
|
+
|
|
369
|
+
conda_env: dict[str, Any] = {
|
|
370
|
+
"name": "mlflow-env",
|
|
371
|
+
"channels": ["conda-forge"],
|
|
372
|
+
"dependencies": [
|
|
373
|
+
f"python={target_python_version}",
|
|
374
|
+
"pip",
|
|
375
|
+
{"pip": list(pip_requirements)},
|
|
376
|
+
],
|
|
377
|
+
}
|
|
378
|
+
logger.trace(
|
|
379
|
+
"Conda environment configured",
|
|
380
|
+
python_version=target_python_version,
|
|
381
|
+
pip_packages_count=len(pip_requirements),
|
|
382
|
+
)
|
|
327
383
|
|
|
328
384
|
with mlflow.start_run(run_name=run_name):
|
|
329
385
|
mlflow.set_tag("type", "agent")
|
|
@@ -333,7 +389,7 @@ class DatabricksProvider(ServiceProvider):
|
|
|
333
389
|
code_paths=code_paths,
|
|
334
390
|
model_config=config.model_dump(mode="json", by_alias=True),
|
|
335
391
|
name="agent",
|
|
336
|
-
|
|
392
|
+
conda_env=conda_env,
|
|
337
393
|
input_example=input_example,
|
|
338
394
|
# resources=all_resources,
|
|
339
395
|
auth_policy=auth_policy,
|
|
@@ -344,8 +400,10 @@ class DatabricksProvider(ServiceProvider):
|
|
|
344
400
|
model_version: ModelVersion = mlflow.register_model(
|
|
345
401
|
name=registered_model_name, model_uri=logged_agent_info.model_uri
|
|
346
402
|
)
|
|
347
|
-
logger.
|
|
348
|
-
|
|
403
|
+
logger.success(
|
|
404
|
+
"Model registered",
|
|
405
|
+
model_name=registered_model_name,
|
|
406
|
+
version=model_version.version,
|
|
349
407
|
)
|
|
350
408
|
|
|
351
409
|
client: MlflowClient = MlflowClient()
|
|
@@ -357,7 +415,7 @@ class DatabricksProvider(ServiceProvider):
|
|
|
357
415
|
key="dao_ai",
|
|
358
416
|
value=dao_ai_version(),
|
|
359
417
|
)
|
|
360
|
-
logger.
|
|
418
|
+
logger.trace("Set dao_ai tag on model version", version=model_version.version)
|
|
361
419
|
|
|
362
420
|
client.set_registered_model_alias(
|
|
363
421
|
name=registered_model_name,
|
|
@@ -374,12 +432,15 @@ class DatabricksProvider(ServiceProvider):
|
|
|
374
432
|
aliased_model: ModelVersion = client.get_model_version_by_alias(
|
|
375
433
|
registered_model_name, config.app.alias
|
|
376
434
|
)
|
|
377
|
-
logger.
|
|
378
|
-
|
|
435
|
+
logger.info(
|
|
436
|
+
"Model aliased",
|
|
437
|
+
model_name=registered_model_name,
|
|
438
|
+
alias=config.app.alias,
|
|
439
|
+
version=aliased_model.version,
|
|
379
440
|
)
|
|
380
441
|
|
|
381
442
|
def deploy_agent(self, config: AppConfig) -> None:
|
|
382
|
-
logger.
|
|
443
|
+
logger.info("Deploying agent", endpoint_name=config.app.endpoint_name)
|
|
383
444
|
mlflow.set_registry_uri("databricks-uc")
|
|
384
445
|
|
|
385
446
|
endpoint_name: str = config.app.endpoint_name
|
|
@@ -400,12 +461,10 @@ class DatabricksProvider(ServiceProvider):
|
|
|
400
461
|
agents.get_deployments(endpoint_name)
|
|
401
462
|
endpoint_exists = True
|
|
402
463
|
logger.debug(
|
|
403
|
-
|
|
464
|
+
"Endpoint already exists, updating", endpoint_name=endpoint_name
|
|
404
465
|
)
|
|
405
466
|
except Exception:
|
|
406
|
-
logger.debug(
|
|
407
|
-
f"Endpoint {endpoint_name} doesn't exist, creating new with tags..."
|
|
408
|
-
)
|
|
467
|
+
logger.debug("Creating new endpoint", endpoint_name=endpoint_name)
|
|
409
468
|
|
|
410
469
|
# Deploy - skip tags for existing endpoints to avoid conflicts
|
|
411
470
|
agents.deploy(
|
|
@@ -421,8 +480,11 @@ class DatabricksProvider(ServiceProvider):
|
|
|
421
480
|
registered_model_name: str = config.app.registered_model.full_name
|
|
422
481
|
permissions: Sequence[dict[str, Any]] = config.app.permissions
|
|
423
482
|
|
|
424
|
-
logger.debug(
|
|
425
|
-
|
|
483
|
+
logger.debug(
|
|
484
|
+
"Configuring model permissions",
|
|
485
|
+
model_name=registered_model_name,
|
|
486
|
+
permissions_count=len(permissions),
|
|
487
|
+
)
|
|
426
488
|
|
|
427
489
|
for permission in permissions:
|
|
428
490
|
principals: Sequence[str] = permission.principals
|
|
@@ -442,7 +504,7 @@ class DatabricksProvider(ServiceProvider):
|
|
|
442
504
|
try:
|
|
443
505
|
catalog_info = self.w.catalogs.get(name=schema.catalog_name)
|
|
444
506
|
except NotFound:
|
|
445
|
-
logger.
|
|
507
|
+
logger.info("Creating catalog", catalog_name=schema.catalog_name)
|
|
446
508
|
catalog_info = self.w.catalogs.create(name=schema.catalog_name)
|
|
447
509
|
return catalog_info
|
|
448
510
|
|
|
@@ -452,7 +514,7 @@ class DatabricksProvider(ServiceProvider):
|
|
|
452
514
|
try:
|
|
453
515
|
schema_info = self.w.schemas.get(full_name=schema.full_name)
|
|
454
516
|
except NotFound:
|
|
455
|
-
logger.
|
|
517
|
+
logger.info("Creating schema", schema_name=schema.full_name)
|
|
456
518
|
schema_info = self.w.schemas.create(
|
|
457
519
|
name=schema.schema_name, catalog_name=catalog_info.name
|
|
458
520
|
)
|
|
@@ -464,7 +526,7 @@ class DatabricksProvider(ServiceProvider):
|
|
|
464
526
|
try:
|
|
465
527
|
volume_info = self.w.volumes.read(name=volume.full_name)
|
|
466
528
|
except NotFound:
|
|
467
|
-
logger.
|
|
529
|
+
logger.info("Creating volume", volume_name=volume.full_name)
|
|
468
530
|
volume_info = self.w.volumes.create(
|
|
469
531
|
catalog_name=schema_info.catalog_name,
|
|
470
532
|
schema_name=schema_info.name,
|
|
@@ -475,7 +537,7 @@ class DatabricksProvider(ServiceProvider):
|
|
|
475
537
|
|
|
476
538
|
def create_path(self, volume_path: VolumePathModel) -> Path:
|
|
477
539
|
path: Path = volume_path.full_name
|
|
478
|
-
logger.info(
|
|
540
|
+
logger.info("Creating volume path", path=str(path))
|
|
479
541
|
self.w.files.create_directory(path)
|
|
480
542
|
return path
|
|
481
543
|
|
|
@@ -516,11 +578,12 @@ class DatabricksProvider(ServiceProvider):
|
|
|
516
578
|
|
|
517
579
|
if ddl:
|
|
518
580
|
ddl_path: Path = Path(ddl)
|
|
519
|
-
logger.debug(
|
|
581
|
+
logger.debug("Executing DDL", ddl_path=str(ddl_path))
|
|
520
582
|
statements: Sequence[str] = sqlparse.parse(ddl_path.read_text())
|
|
521
583
|
for statement in statements:
|
|
522
|
-
logger.
|
|
523
|
-
|
|
584
|
+
logger.trace(
|
|
585
|
+
"Executing DDL statement", statement=str(statement)[:100], args=args
|
|
586
|
+
)
|
|
524
587
|
spark.sql(
|
|
525
588
|
str(statement),
|
|
526
589
|
args=args,
|
|
@@ -529,20 +592,23 @@ class DatabricksProvider(ServiceProvider):
|
|
|
529
592
|
if data:
|
|
530
593
|
data_path: Path = Path(data)
|
|
531
594
|
if format == "sql":
|
|
532
|
-
logger.debug(
|
|
595
|
+
logger.debug("Executing SQL from file", data_path=str(data_path))
|
|
533
596
|
data_statements: Sequence[str] = sqlparse.parse(data_path.read_text())
|
|
534
597
|
for statement in data_statements:
|
|
535
|
-
logger.
|
|
536
|
-
|
|
598
|
+
logger.trace(
|
|
599
|
+
"Executing SQL statement",
|
|
600
|
+
statement=str(statement)[:100],
|
|
601
|
+
args=args,
|
|
602
|
+
)
|
|
537
603
|
spark.sql(
|
|
538
604
|
str(statement),
|
|
539
605
|
args=args,
|
|
540
606
|
)
|
|
541
607
|
else:
|
|
542
|
-
logger.debug(
|
|
608
|
+
logger.debug("Writing dataset to table", table=table)
|
|
543
609
|
if not data_path.is_absolute():
|
|
544
610
|
data_path = current_dir / data_path
|
|
545
|
-
logger.
|
|
611
|
+
logger.trace("Data path resolved", path=data_path.as_posix())
|
|
546
612
|
if format == "excel":
|
|
547
613
|
pdf = pd.read_excel(data_path.as_posix())
|
|
548
614
|
df = spark.createDataFrame(pdf, schema=dataset.table_schema)
|
|
@@ -559,6 +625,17 @@ class DatabricksProvider(ServiceProvider):
|
|
|
559
625
|
df.write.mode("overwrite").saveAsTable(table)
|
|
560
626
|
|
|
561
627
|
def create_vector_store(self, vector_store: VectorStoreModel) -> None:
|
|
628
|
+
"""
|
|
629
|
+
Create a vector search index from a source table.
|
|
630
|
+
|
|
631
|
+
This method expects a VectorStoreModel in provisioning mode with all
|
|
632
|
+
required fields validated. Use VectorStoreModel.create() which handles
|
|
633
|
+
mode detection and validation.
|
|
634
|
+
|
|
635
|
+
Args:
|
|
636
|
+
vector_store: VectorStoreModel configured for provisioning
|
|
637
|
+
"""
|
|
638
|
+
# Ensure endpoint exists
|
|
562
639
|
if not endpoint_exists(self.vsc, vector_store.endpoint.name):
|
|
563
640
|
self.vsc.create_endpoint_and_wait(
|
|
564
641
|
name=vector_store.endpoint.name,
|
|
@@ -566,13 +643,17 @@ class DatabricksProvider(ServiceProvider):
|
|
|
566
643
|
verbose=True,
|
|
567
644
|
)
|
|
568
645
|
|
|
569
|
-
logger.
|
|
646
|
+
logger.success(
|
|
647
|
+
"Vector search endpoint ready", endpoint_name=vector_store.endpoint.name
|
|
648
|
+
)
|
|
570
649
|
|
|
571
650
|
if not index_exists(
|
|
572
651
|
self.vsc, vector_store.endpoint.name, vector_store.index.full_name
|
|
573
652
|
):
|
|
574
|
-
logger.
|
|
575
|
-
|
|
653
|
+
logger.info(
|
|
654
|
+
"Creating vector search index",
|
|
655
|
+
index_name=vector_store.index.full_name,
|
|
656
|
+
endpoint_name=vector_store.endpoint.name,
|
|
576
657
|
)
|
|
577
658
|
self.vsc.create_delta_sync_index_and_wait(
|
|
578
659
|
endpoint_name=vector_store.endpoint.name,
|
|
@@ -586,7 +667,8 @@ class DatabricksProvider(ServiceProvider):
|
|
|
586
667
|
)
|
|
587
668
|
else:
|
|
588
669
|
logger.debug(
|
|
589
|
-
|
|
670
|
+
"Vector search index already exists, checking status",
|
|
671
|
+
index_name=vector_store.index.full_name,
|
|
590
672
|
)
|
|
591
673
|
index = self.vsc.get_index(
|
|
592
674
|
vector_store.endpoint.name, vector_store.index.full_name
|
|
@@ -609,54 +691,61 @@ class DatabricksProvider(ServiceProvider):
|
|
|
609
691
|
|
|
610
692
|
if pipeline_status in [
|
|
611
693
|
"COMPLETED",
|
|
694
|
+
"ONLINE",
|
|
612
695
|
"FAILED",
|
|
613
696
|
"CANCELED",
|
|
614
697
|
"ONLINE_PIPELINE_FAILED",
|
|
615
698
|
]:
|
|
616
|
-
logger.debug(
|
|
617
|
-
f"Index is ready to sync (status: {pipeline_status})"
|
|
618
|
-
)
|
|
699
|
+
logger.debug("Index ready to sync", status=pipeline_status)
|
|
619
700
|
break
|
|
620
701
|
elif pipeline_status in [
|
|
621
702
|
"WAITING_FOR_RESOURCES",
|
|
622
703
|
"PROVISIONING",
|
|
623
704
|
"INITIALIZING",
|
|
624
705
|
"INDEXING",
|
|
625
|
-
"ONLINE",
|
|
626
706
|
]:
|
|
627
|
-
logger.
|
|
628
|
-
|
|
707
|
+
logger.trace(
|
|
708
|
+
"Index not ready, waiting",
|
|
709
|
+
status=pipeline_status,
|
|
710
|
+
wait_seconds=wait_interval,
|
|
629
711
|
)
|
|
630
712
|
time.sleep(wait_interval)
|
|
631
713
|
elapsed += wait_interval
|
|
632
714
|
else:
|
|
633
715
|
logger.warning(
|
|
634
|
-
|
|
716
|
+
"Unknown pipeline status, attempting sync",
|
|
717
|
+
status=pipeline_status,
|
|
635
718
|
)
|
|
636
719
|
break
|
|
637
720
|
except Exception as status_error:
|
|
638
721
|
logger.warning(
|
|
639
|
-
|
|
722
|
+
"Could not check index status, attempting sync",
|
|
723
|
+
error=str(status_error),
|
|
640
724
|
)
|
|
641
725
|
break
|
|
642
726
|
|
|
643
727
|
if elapsed >= max_wait_time:
|
|
644
728
|
logger.warning(
|
|
645
|
-
|
|
729
|
+
"Timed out waiting for index to be ready",
|
|
730
|
+
max_wait_seconds=max_wait_time,
|
|
646
731
|
)
|
|
647
732
|
|
|
648
733
|
# Now attempt to sync
|
|
649
734
|
try:
|
|
650
735
|
index.sync()
|
|
651
|
-
logger.
|
|
736
|
+
logger.success("Index sync completed")
|
|
652
737
|
except Exception as sync_error:
|
|
653
738
|
if "not ready to sync yet" in str(sync_error).lower():
|
|
654
|
-
logger.warning(
|
|
739
|
+
logger.warning(
|
|
740
|
+
"Index still not ready to sync", error=str(sync_error)
|
|
741
|
+
)
|
|
655
742
|
else:
|
|
656
743
|
raise sync_error
|
|
657
744
|
|
|
658
|
-
logger.
|
|
659
|
-
|
|
745
|
+
logger.success(
|
|
746
|
+
"Vector search index ready",
|
|
747
|
+
index_name=vector_store.index.full_name,
|
|
748
|
+
source_table=vector_store.source_table.full_name,
|
|
660
749
|
)
|
|
661
750
|
|
|
662
751
|
def get_vector_index(self, vector_store: VectorStoreModel) -> None:
|
|
@@ -692,12 +781,16 @@ class DatabricksProvider(ServiceProvider):
|
|
|
692
781
|
# sql = sql.replace("{catalog_name}", schema.catalog_name)
|
|
693
782
|
# sql = sql.replace("{schema_name}", schema.schema_name)
|
|
694
783
|
|
|
695
|
-
logger.info(function.name)
|
|
696
|
-
logger.
|
|
784
|
+
logger.info("Creating SQL function", function_name=function.name)
|
|
785
|
+
logger.trace("SQL function body", sql=sql[:200])
|
|
697
786
|
_: FunctionInfo = self.dfs.create_function(sql_function_body=sql)
|
|
698
787
|
|
|
699
788
|
if unity_catalog_function.test:
|
|
700
|
-
logger.
|
|
789
|
+
logger.debug(
|
|
790
|
+
"Testing function",
|
|
791
|
+
function_name=function.full_name,
|
|
792
|
+
parameters=unity_catalog_function.test.parameters,
|
|
793
|
+
)
|
|
701
794
|
|
|
702
795
|
result: FunctionExecutionResult = self.dfs.execute_function(
|
|
703
796
|
function_name=function.full_name,
|
|
@@ -705,37 +798,50 @@ class DatabricksProvider(ServiceProvider):
|
|
|
705
798
|
)
|
|
706
799
|
|
|
707
800
|
if result.error:
|
|
708
|
-
logger.error(
|
|
801
|
+
logger.error(
|
|
802
|
+
"Function test failed",
|
|
803
|
+
function_name=function.full_name,
|
|
804
|
+
error=result.error,
|
|
805
|
+
)
|
|
709
806
|
else:
|
|
710
|
-
logger.
|
|
711
|
-
|
|
807
|
+
logger.success(
|
|
808
|
+
"Function test passed", function_name=function.full_name
|
|
809
|
+
)
|
|
810
|
+
logger.debug("Function test result", result=str(result))
|
|
712
811
|
|
|
713
812
|
def find_columns(self, table_model: TableModel) -> Sequence[str]:
|
|
714
|
-
logger.
|
|
813
|
+
logger.trace("Finding columns for table", table=table_model.full_name)
|
|
715
814
|
table_info: TableInfo = self.w.tables.get(full_name=table_model.full_name)
|
|
716
815
|
columns: Sequence[ColumnInfo] = table_info.columns
|
|
717
816
|
column_names: Sequence[str] = [c.name for c in columns]
|
|
718
|
-
logger.debug(
|
|
817
|
+
logger.debug(
|
|
818
|
+
"Columns found",
|
|
819
|
+
table=table_model.full_name,
|
|
820
|
+
columns_count=len(column_names),
|
|
821
|
+
)
|
|
719
822
|
return column_names
|
|
720
823
|
|
|
721
824
|
def find_primary_key(self, table_model: TableModel) -> Sequence[str] | None:
|
|
722
|
-
logger.
|
|
825
|
+
logger.trace("Finding primary key for table", table=table_model.full_name)
|
|
723
826
|
primary_keys: Sequence[str] | None = None
|
|
724
827
|
table_info: TableInfo = self.w.tables.get(full_name=table_model.full_name)
|
|
725
828
|
constraints: Sequence[TableConstraint] = table_info.table_constraints
|
|
726
829
|
primary_key_constraint: PrimaryKeyConstraint | None = next(
|
|
727
|
-
c.primary_key_constraint for c in constraints if c.primary_key_constraint
|
|
830
|
+
(c.primary_key_constraint for c in constraints if c.primary_key_constraint),
|
|
831
|
+
None,
|
|
728
832
|
)
|
|
729
833
|
if primary_key_constraint:
|
|
730
834
|
primary_keys = primary_key_constraint.child_columns
|
|
731
835
|
|
|
732
|
-
logger.debug(
|
|
836
|
+
logger.debug(
|
|
837
|
+
"Primary key found", table=table_model.full_name, primary_keys=primary_keys
|
|
838
|
+
)
|
|
733
839
|
return primary_keys
|
|
734
840
|
|
|
735
841
|
def find_vector_search_endpoint(
|
|
736
842
|
self, predicate: Callable[[dict[str, Any]], bool]
|
|
737
843
|
) -> str | None:
|
|
738
|
-
logger.
|
|
844
|
+
logger.trace("Finding vector search endpoint")
|
|
739
845
|
endpoint_name: str | None = None
|
|
740
846
|
vector_search_endpoints: Sequence[dict[str, Any]] = (
|
|
741
847
|
self.vsc.list_endpoints().get("endpoints", [])
|
|
@@ -744,11 +850,13 @@ class DatabricksProvider(ServiceProvider):
|
|
|
744
850
|
if predicate(endpoint):
|
|
745
851
|
endpoint_name = endpoint["name"]
|
|
746
852
|
break
|
|
747
|
-
logger.debug(
|
|
853
|
+
logger.debug("Vector search endpoint found", endpoint_name=endpoint_name)
|
|
748
854
|
return endpoint_name
|
|
749
855
|
|
|
750
856
|
def find_endpoint_for_index(self, index_model: IndexModel) -> str | None:
|
|
751
|
-
logger.
|
|
857
|
+
logger.trace(
|
|
858
|
+
"Finding endpoint for vector search index", index_name=index_model.full_name
|
|
859
|
+
)
|
|
752
860
|
all_endpoints: Sequence[dict[str, Any]] = self.vsc.list_endpoints().get(
|
|
753
861
|
"endpoints", []
|
|
754
862
|
)
|
|
@@ -758,14 +866,99 @@ class DatabricksProvider(ServiceProvider):
|
|
|
758
866
|
endpoint_name: str = endpoint["name"]
|
|
759
867
|
indexes = self.vsc.list_indexes(name=endpoint_name)
|
|
760
868
|
vector_indexes: Sequence[dict[str, Any]] = indexes.get("vector_indexes", [])
|
|
761
|
-
logger.trace(
|
|
869
|
+
logger.trace(
|
|
870
|
+
"Checking endpoint for indexes",
|
|
871
|
+
endpoint_name=endpoint_name,
|
|
872
|
+
indexes_count=len(vector_indexes),
|
|
873
|
+
)
|
|
762
874
|
index_names = [vector_index["name"] for vector_index in vector_indexes]
|
|
763
875
|
if index_name in index_names:
|
|
764
876
|
found_endpoint_name = endpoint_name
|
|
765
877
|
break
|
|
766
|
-
logger.debug(
|
|
878
|
+
logger.debug(
|
|
879
|
+
"Vector search index endpoint found",
|
|
880
|
+
index_name=index_model.full_name,
|
|
881
|
+
endpoint_name=found_endpoint_name,
|
|
882
|
+
)
|
|
767
883
|
return found_endpoint_name
|
|
768
884
|
|
|
885
|
+
def _wait_for_database_available(
|
|
886
|
+
self,
|
|
887
|
+
workspace_client: WorkspaceClient,
|
|
888
|
+
instance_name: str,
|
|
889
|
+
max_wait_time: int = 600,
|
|
890
|
+
wait_interval: int = 10,
|
|
891
|
+
) -> None:
|
|
892
|
+
"""
|
|
893
|
+
Wait for a database instance to become AVAILABLE.
|
|
894
|
+
|
|
895
|
+
Args:
|
|
896
|
+
workspace_client: The Databricks workspace client
|
|
897
|
+
instance_name: Name of the database instance to wait for
|
|
898
|
+
max_wait_time: Maximum time to wait in seconds (default: 600 = 10 minutes)
|
|
899
|
+
wait_interval: Time between status checks in seconds (default: 10)
|
|
900
|
+
|
|
901
|
+
Raises:
|
|
902
|
+
TimeoutError: If the database doesn't become AVAILABLE within max_wait_time
|
|
903
|
+
RuntimeError: If the database enters a failed or deleted state
|
|
904
|
+
"""
|
|
905
|
+
import time
|
|
906
|
+
from typing import Any
|
|
907
|
+
|
|
908
|
+
logger.info(
|
|
909
|
+
"Waiting for database instance to become AVAILABLE",
|
|
910
|
+
instance_name=instance_name,
|
|
911
|
+
)
|
|
912
|
+
elapsed: int = 0
|
|
913
|
+
|
|
914
|
+
while elapsed < max_wait_time:
|
|
915
|
+
try:
|
|
916
|
+
current_instance: Any = workspace_client.database.get_database_instance(
|
|
917
|
+
name=instance_name
|
|
918
|
+
)
|
|
919
|
+
current_state: str = current_instance.state
|
|
920
|
+
logger.trace(
|
|
921
|
+
"Database instance state checked",
|
|
922
|
+
instance_name=instance_name,
|
|
923
|
+
state=current_state,
|
|
924
|
+
)
|
|
925
|
+
|
|
926
|
+
if current_state == "AVAILABLE":
|
|
927
|
+
logger.success(
|
|
928
|
+
"Database instance is now AVAILABLE",
|
|
929
|
+
instance_name=instance_name,
|
|
930
|
+
)
|
|
931
|
+
return
|
|
932
|
+
elif current_state in ["STARTING", "UPDATING", "PROVISIONING"]:
|
|
933
|
+
logger.trace(
|
|
934
|
+
"Database instance not ready, waiting",
|
|
935
|
+
instance_name=instance_name,
|
|
936
|
+
state=current_state,
|
|
937
|
+
wait_seconds=wait_interval,
|
|
938
|
+
)
|
|
939
|
+
time.sleep(wait_interval)
|
|
940
|
+
elapsed += wait_interval
|
|
941
|
+
elif current_state in ["STOPPED", "DELETING", "FAILED"]:
|
|
942
|
+
raise RuntimeError(
|
|
943
|
+
f"Database instance {instance_name} entered unexpected state: {current_state}"
|
|
944
|
+
)
|
|
945
|
+
else:
|
|
946
|
+
logger.warning(
|
|
947
|
+
"Unknown database state, continuing to wait",
|
|
948
|
+
instance_name=instance_name,
|
|
949
|
+
state=current_state,
|
|
950
|
+
)
|
|
951
|
+
time.sleep(wait_interval)
|
|
952
|
+
elapsed += wait_interval
|
|
953
|
+
except NotFound:
|
|
954
|
+
raise RuntimeError(
|
|
955
|
+
f"Database instance {instance_name} was deleted while waiting for it to become AVAILABLE"
|
|
956
|
+
)
|
|
957
|
+
|
|
958
|
+
raise TimeoutError(
|
|
959
|
+
f"Timed out waiting for database instance {instance_name} to become AVAILABLE after {max_wait_time} seconds"
|
|
960
|
+
)
|
|
961
|
+
|
|
769
962
|
def create_lakebase(self, database: DatabaseModel) -> None:
|
|
770
963
|
"""
|
|
771
964
|
Create a Lakebase database instance using the Databricks workspace client.
|
|
@@ -796,13 +989,17 @@ class DatabricksProvider(ServiceProvider):
|
|
|
796
989
|
|
|
797
990
|
if existing_instance:
|
|
798
991
|
logger.debug(
|
|
799
|
-
|
|
992
|
+
"Database instance already exists",
|
|
993
|
+
instance_name=database.instance_name,
|
|
994
|
+
state=existing_instance.state,
|
|
800
995
|
)
|
|
801
996
|
|
|
802
997
|
# Check if database is in an intermediate state
|
|
803
998
|
if existing_instance.state in ["STARTING", "UPDATING"]:
|
|
804
999
|
logger.info(
|
|
805
|
-
|
|
1000
|
+
"Database instance in intermediate state, waiting",
|
|
1001
|
+
instance_name=database.instance_name,
|
|
1002
|
+
state=existing_instance.state,
|
|
806
1003
|
)
|
|
807
1004
|
|
|
808
1005
|
# Wait for database to reach a stable state
|
|
@@ -818,65 +1015,87 @@ class DatabricksProvider(ServiceProvider):
|
|
|
818
1015
|
)
|
|
819
1016
|
)
|
|
820
1017
|
current_state: str = current_instance.state
|
|
821
|
-
logger.
|
|
1018
|
+
logger.trace(
|
|
1019
|
+
"Checking database instance state",
|
|
1020
|
+
instance_name=database.instance_name,
|
|
1021
|
+
state=current_state,
|
|
1022
|
+
)
|
|
822
1023
|
|
|
823
1024
|
if current_state == "AVAILABLE":
|
|
824
|
-
logger.
|
|
825
|
-
|
|
1025
|
+
logger.success(
|
|
1026
|
+
"Database instance is now AVAILABLE",
|
|
1027
|
+
instance_name=database.instance_name,
|
|
826
1028
|
)
|
|
827
1029
|
break
|
|
828
1030
|
elif current_state in ["STARTING", "UPDATING"]:
|
|
829
|
-
logger.
|
|
830
|
-
|
|
1031
|
+
logger.trace(
|
|
1032
|
+
"Database instance not ready, waiting",
|
|
1033
|
+
instance_name=database.instance_name,
|
|
1034
|
+
state=current_state,
|
|
1035
|
+
wait_seconds=wait_interval,
|
|
831
1036
|
)
|
|
832
1037
|
time.sleep(wait_interval)
|
|
833
1038
|
elapsed += wait_interval
|
|
834
1039
|
elif current_state in ["STOPPED", "DELETING"]:
|
|
835
1040
|
logger.warning(
|
|
836
|
-
|
|
1041
|
+
"Database instance in unexpected state",
|
|
1042
|
+
instance_name=database.instance_name,
|
|
1043
|
+
state=current_state,
|
|
837
1044
|
)
|
|
838
1045
|
break
|
|
839
1046
|
else:
|
|
840
1047
|
logger.warning(
|
|
841
|
-
|
|
1048
|
+
"Unknown database state, proceeding",
|
|
1049
|
+
instance_name=database.instance_name,
|
|
1050
|
+
state=current_state,
|
|
842
1051
|
)
|
|
843
1052
|
break
|
|
844
1053
|
except NotFound:
|
|
845
1054
|
logger.warning(
|
|
846
|
-
|
|
1055
|
+
"Database instance no longer exists, will recreate",
|
|
1056
|
+
instance_name=database.instance_name,
|
|
847
1057
|
)
|
|
848
1058
|
break
|
|
849
1059
|
except Exception as state_error:
|
|
850
1060
|
logger.warning(
|
|
851
|
-
|
|
1061
|
+
"Could not check database state, proceeding",
|
|
1062
|
+
instance_name=database.instance_name,
|
|
1063
|
+
error=str(state_error),
|
|
852
1064
|
)
|
|
853
1065
|
break
|
|
854
1066
|
|
|
855
1067
|
if elapsed >= max_wait_time:
|
|
856
1068
|
logger.warning(
|
|
857
|
-
|
|
1069
|
+
"Timed out waiting for database to become AVAILABLE",
|
|
1070
|
+
instance_name=database.instance_name,
|
|
1071
|
+
max_wait_seconds=max_wait_time,
|
|
858
1072
|
)
|
|
859
1073
|
|
|
860
1074
|
elif existing_instance.state == "AVAILABLE":
|
|
861
1075
|
logger.info(
|
|
862
|
-
|
|
1076
|
+
"Database instance already exists and is AVAILABLE",
|
|
1077
|
+
instance_name=database.instance_name,
|
|
863
1078
|
)
|
|
864
1079
|
return
|
|
865
1080
|
elif existing_instance.state in ["STOPPED", "DELETING"]:
|
|
866
1081
|
logger.warning(
|
|
867
|
-
|
|
1082
|
+
"Database instance in terminal state",
|
|
1083
|
+
instance_name=database.instance_name,
|
|
1084
|
+
state=existing_instance.state,
|
|
868
1085
|
)
|
|
869
1086
|
return
|
|
870
1087
|
else:
|
|
871
1088
|
logger.info(
|
|
872
|
-
|
|
1089
|
+
"Database instance already exists",
|
|
1090
|
+
instance_name=database.instance_name,
|
|
1091
|
+
state=existing_instance.state,
|
|
873
1092
|
)
|
|
874
1093
|
return
|
|
875
1094
|
|
|
876
1095
|
except NotFound:
|
|
877
1096
|
# Database doesn't exist, proceed with creation
|
|
878
|
-
logger.
|
|
879
|
-
|
|
1097
|
+
logger.info(
|
|
1098
|
+
"Creating new database instance", instance_name=database.instance_name
|
|
880
1099
|
)
|
|
881
1100
|
|
|
882
1101
|
try:
|
|
@@ -896,10 +1115,17 @@ class DatabricksProvider(ServiceProvider):
|
|
|
896
1115
|
workspace_client.database.create_database_instance(
|
|
897
1116
|
database_instance=database_instance
|
|
898
1117
|
)
|
|
899
|
-
logger.
|
|
900
|
-
|
|
1118
|
+
logger.success(
|
|
1119
|
+
"Database instance created successfully",
|
|
1120
|
+
instance_name=database.instance_name,
|
|
901
1121
|
)
|
|
902
1122
|
|
|
1123
|
+
# Wait for the newly created database to become AVAILABLE
|
|
1124
|
+
self._wait_for_database_available(
|
|
1125
|
+
workspace_client, database.instance_name
|
|
1126
|
+
)
|
|
1127
|
+
return
|
|
1128
|
+
|
|
903
1129
|
except Exception as create_error:
|
|
904
1130
|
error_msg: str = str(create_error)
|
|
905
1131
|
|
|
@@ -909,13 +1135,20 @@ class DatabricksProvider(ServiceProvider):
|
|
|
909
1135
|
or "RESOURCE_ALREADY_EXISTS" in error_msg
|
|
910
1136
|
):
|
|
911
1137
|
logger.info(
|
|
912
|
-
|
|
1138
|
+
"Database instance was created concurrently",
|
|
1139
|
+
instance_name=database.instance_name,
|
|
1140
|
+
)
|
|
1141
|
+
# Still need to wait for the database to become AVAILABLE
|
|
1142
|
+
self._wait_for_database_available(
|
|
1143
|
+
workspace_client, database.instance_name
|
|
913
1144
|
)
|
|
914
1145
|
return
|
|
915
1146
|
else:
|
|
916
1147
|
# Re-raise unexpected errors
|
|
917
1148
|
logger.error(
|
|
918
|
-
|
|
1149
|
+
"Error creating database instance",
|
|
1150
|
+
instance_name=database.instance_name,
|
|
1151
|
+
error=str(create_error),
|
|
919
1152
|
)
|
|
920
1153
|
raise
|
|
921
1154
|
|
|
@@ -929,12 +1162,15 @@ class DatabricksProvider(ServiceProvider):
|
|
|
929
1162
|
or "RESOURCE_ALREADY_EXISTS" in error_msg
|
|
930
1163
|
):
|
|
931
1164
|
logger.info(
|
|
932
|
-
|
|
1165
|
+
"Database instance already exists (detected via exception)",
|
|
1166
|
+
instance_name=database.instance_name,
|
|
933
1167
|
)
|
|
934
1168
|
return
|
|
935
1169
|
else:
|
|
936
1170
|
logger.error(
|
|
937
|
-
|
|
1171
|
+
"Unexpected error while handling database",
|
|
1172
|
+
instance_name=database.instance_name,
|
|
1173
|
+
error=str(e),
|
|
938
1174
|
)
|
|
939
1175
|
raise
|
|
940
1176
|
|
|
@@ -942,7 +1178,9 @@ class DatabricksProvider(ServiceProvider):
|
|
|
942
1178
|
"""
|
|
943
1179
|
Ask Databricks to mint a fresh DB credential for this instance.
|
|
944
1180
|
"""
|
|
945
|
-
logger.
|
|
1181
|
+
logger.trace(
|
|
1182
|
+
"Generating password for lakebase instance", instance_name=instance_name
|
|
1183
|
+
)
|
|
946
1184
|
w: WorkspaceClient = self.w
|
|
947
1185
|
cred: DatabaseCredential = w.database.generate_database_credential(
|
|
948
1186
|
request_id=str(uuid.uuid4()),
|
|
@@ -978,7 +1216,8 @@ class DatabricksProvider(ServiceProvider):
|
|
|
978
1216
|
# Validate that client_id is provided
|
|
979
1217
|
if not database.client_id:
|
|
980
1218
|
logger.warning(
|
|
981
|
-
|
|
1219
|
+
"client_id required to create instance role",
|
|
1220
|
+
instance_name=database.instance_name,
|
|
982
1221
|
)
|
|
983
1222
|
return
|
|
984
1223
|
|
|
@@ -988,7 +1227,10 @@ class DatabricksProvider(ServiceProvider):
|
|
|
988
1227
|
instance_name: str = database.instance_name
|
|
989
1228
|
|
|
990
1229
|
logger.debug(
|
|
991
|
-
|
|
1230
|
+
"Creating instance role",
|
|
1231
|
+
role_name=role_name,
|
|
1232
|
+
instance_name=instance_name,
|
|
1233
|
+
principal=client_id,
|
|
992
1234
|
)
|
|
993
1235
|
|
|
994
1236
|
try:
|
|
@@ -999,13 +1241,15 @@ class DatabricksProvider(ServiceProvider):
|
|
|
999
1241
|
name=role_name,
|
|
1000
1242
|
)
|
|
1001
1243
|
logger.info(
|
|
1002
|
-
|
|
1244
|
+
"Instance role already exists",
|
|
1245
|
+
role_name=role_name,
|
|
1246
|
+
instance_name=instance_name,
|
|
1003
1247
|
)
|
|
1004
1248
|
return
|
|
1005
1249
|
except NotFound:
|
|
1006
1250
|
# Role doesn't exist, proceed with creation
|
|
1007
1251
|
logger.debug(
|
|
1008
|
-
|
|
1252
|
+
"Instance role not found, creating new role", role_name=role_name
|
|
1009
1253
|
)
|
|
1010
1254
|
|
|
1011
1255
|
# Create the database instance role
|
|
@@ -1021,8 +1265,10 @@ class DatabricksProvider(ServiceProvider):
|
|
|
1021
1265
|
database_instance_role=role,
|
|
1022
1266
|
)
|
|
1023
1267
|
|
|
1024
|
-
logger.
|
|
1025
|
-
|
|
1268
|
+
logger.success(
|
|
1269
|
+
"Instance role created successfully",
|
|
1270
|
+
role_name=role_name,
|
|
1271
|
+
instance_name=instance_name,
|
|
1026
1272
|
)
|
|
1027
1273
|
|
|
1028
1274
|
except Exception as e:
|
|
@@ -1034,13 +1280,18 @@ class DatabricksProvider(ServiceProvider):
|
|
|
1034
1280
|
or "RESOURCE_ALREADY_EXISTS" in error_msg
|
|
1035
1281
|
):
|
|
1036
1282
|
logger.info(
|
|
1037
|
-
|
|
1283
|
+
"Instance role was created concurrently",
|
|
1284
|
+
role_name=role_name,
|
|
1285
|
+
instance_name=instance_name,
|
|
1038
1286
|
)
|
|
1039
1287
|
return
|
|
1040
1288
|
|
|
1041
1289
|
# Re-raise unexpected errors
|
|
1042
1290
|
logger.error(
|
|
1043
|
-
|
|
1291
|
+
"Error creating instance role",
|
|
1292
|
+
role_name=role_name,
|
|
1293
|
+
instance_name=instance_name,
|
|
1294
|
+
error=str(e),
|
|
1044
1295
|
)
|
|
1045
1296
|
raise
|
|
1046
1297
|
|
|
@@ -1050,9 +1301,17 @@ class DatabricksProvider(ServiceProvider):
|
|
|
1050
1301
|
|
|
1051
1302
|
If an explicit version or alias is specified in the prompt_model, uses that directly.
|
|
1052
1303
|
Otherwise, tries to load prompts in this order:
|
|
1053
|
-
1. champion alias
|
|
1054
|
-
2. latest alias
|
|
1055
|
-
3.
|
|
1304
|
+
1. champion alias
|
|
1305
|
+
2. latest alias
|
|
1306
|
+
3. default alias
|
|
1307
|
+
4. Register default_template if provided (only if register_to_registry=True)
|
|
1308
|
+
5. Use default_template directly (fallback)
|
|
1309
|
+
|
|
1310
|
+
The auto_register field controls whether the default_template is automatically
|
|
1311
|
+
synced to the prompt registry:
|
|
1312
|
+
- If True (default): Auto-registers/updates the default_template in the registry
|
|
1313
|
+
- If False: Never registers, but can still load existing prompts from registry
|
|
1314
|
+
or use default_template directly as a local-only prompt
|
|
1056
1315
|
|
|
1057
1316
|
Args:
|
|
1058
1317
|
prompt_model: The prompt model configuration
|
|
@@ -1063,542 +1322,266 @@ class DatabricksProvider(ServiceProvider):
|
|
|
1063
1322
|
Raises:
|
|
1064
1323
|
ValueError: If no prompt can be loaded from any source
|
|
1065
1324
|
"""
|
|
1325
|
+
|
|
1066
1326
|
prompt_name: str = prompt_model.full_name
|
|
1067
1327
|
|
|
1068
|
-
# If explicit version or alias is specified, use it directly
|
|
1328
|
+
# If explicit version or alias is specified, use it directly
|
|
1069
1329
|
if prompt_model.version or prompt_model.alias:
|
|
1070
1330
|
try:
|
|
1071
1331
|
prompt_version: PromptVersion = prompt_model.as_prompt()
|
|
1332
|
+
version_or_alias = (
|
|
1333
|
+
f"version {prompt_model.version}"
|
|
1334
|
+
if prompt_model.version
|
|
1335
|
+
else f"alias {prompt_model.alias}"
|
|
1336
|
+
)
|
|
1072
1337
|
logger.debug(
|
|
1073
|
-
|
|
1074
|
-
|
|
1338
|
+
"Loaded prompt with explicit version/alias",
|
|
1339
|
+
prompt_name=prompt_name,
|
|
1340
|
+
version_or_alias=version_or_alias,
|
|
1075
1341
|
)
|
|
1076
1342
|
return prompt_version
|
|
1077
1343
|
except Exception as e:
|
|
1344
|
+
version_or_alias = (
|
|
1345
|
+
f"version {prompt_model.version}"
|
|
1346
|
+
if prompt_model.version
|
|
1347
|
+
else f"alias {prompt_model.alias}"
|
|
1348
|
+
)
|
|
1078
1349
|
logger.warning(
|
|
1079
|
-
|
|
1080
|
-
|
|
1350
|
+
"Failed to load prompt with explicit version/alias",
|
|
1351
|
+
prompt_name=prompt_name,
|
|
1352
|
+
version_or_alias=version_or_alias,
|
|
1353
|
+
error=str(e),
|
|
1081
1354
|
)
|
|
1082
|
-
# Fall through to
|
|
1083
|
-
else:
|
|
1084
|
-
# No explicit version/alias specified - check if default_template needs syncing first
|
|
1085
|
-
logger.debug(
|
|
1086
|
-
f"No explicit version/alias specified for '{prompt_name}', "
|
|
1087
|
-
"checking if default_template needs syncing"
|
|
1088
|
-
)
|
|
1089
|
-
|
|
1090
|
-
# If we have a default_template, check if it differs from what's in the registry
|
|
1091
|
-
# This ensures we always sync config changes before returning any alias
|
|
1092
|
-
if prompt_model.default_template:
|
|
1093
|
-
try:
|
|
1094
|
-
default_uri: str = f"prompts:/{prompt_name}@default"
|
|
1095
|
-
default_version: PromptVersion = load_prompt(default_uri)
|
|
1355
|
+
# Fall through to try other methods
|
|
1096
1356
|
|
|
1097
|
-
|
|
1098
|
-
|
|
1099
|
-
|
|
1100
|
-
|
|
1101
|
-
|
|
1102
|
-
|
|
1103
|
-
)
|
|
1104
|
-
return self._sync_default_template_to_registry(
|
|
1105
|
-
prompt_name,
|
|
1106
|
-
prompt_model.default_template,
|
|
1107
|
-
prompt_model.description,
|
|
1108
|
-
)
|
|
1109
|
-
except Exception as e:
|
|
1110
|
-
logger.debug(f"Could not check default alias for sync: {e}")
|
|
1111
|
-
|
|
1112
|
-
# Now try aliases in order: champion → latest → default
|
|
1113
|
-
logger.debug(
|
|
1114
|
-
f"Trying fallback order for '{prompt_name}': champion → latest → default"
|
|
1115
|
-
)
|
|
1116
|
-
|
|
1117
|
-
# Try champion alias first
|
|
1118
|
-
try:
|
|
1119
|
-
champion_uri: str = f"prompts:/{prompt_name}@champion"
|
|
1120
|
-
prompt_version: PromptVersion = load_prompt(champion_uri)
|
|
1121
|
-
logger.info(f"Loaded prompt '{prompt_name}' from champion alias")
|
|
1122
|
-
return prompt_version
|
|
1123
|
-
except Exception as e:
|
|
1124
|
-
logger.debug(f"Champion alias not found for '{prompt_name}': {e}")
|
|
1125
|
-
|
|
1126
|
-
# Try latest alias next
|
|
1127
|
-
try:
|
|
1128
|
-
latest_uri: str = f"prompts:/{prompt_name}@latest"
|
|
1129
|
-
prompt_version: PromptVersion = load_prompt(latest_uri)
|
|
1130
|
-
logger.info(f"Loaded prompt '{prompt_name}' from latest alias")
|
|
1131
|
-
return prompt_version
|
|
1132
|
-
except Exception as e:
|
|
1133
|
-
logger.debug(f"Latest alias not found for '{prompt_name}': {e}")
|
|
1357
|
+
# Try to load in priority order: champion → default (with sync check)
|
|
1358
|
+
logger.trace(
|
|
1359
|
+
"Trying prompt fallback order",
|
|
1360
|
+
prompt_name=prompt_name,
|
|
1361
|
+
order="champion → default",
|
|
1362
|
+
)
|
|
1134
1363
|
|
|
1135
|
-
|
|
1364
|
+
# First, sync default alias if template has changed (even if champion exists)
|
|
1365
|
+
# Only do this if auto_register is True
|
|
1366
|
+
if prompt_model.default_template and prompt_model.auto_register:
|
|
1136
1367
|
try:
|
|
1137
|
-
|
|
1138
|
-
|
|
1139
|
-
logger.info(f"Loaded prompt '{prompt_name}' from default alias")
|
|
1140
|
-
return prompt_version
|
|
1141
|
-
except Exception as e:
|
|
1142
|
-
logger.debug(f"Default alias not found for '{prompt_name}': {e}")
|
|
1368
|
+
# Try to load existing default
|
|
1369
|
+
existing_default = load_prompt(f"prompts:/{prompt_name}@default")
|
|
1143
1370
|
|
|
1144
|
-
|
|
1145
|
-
|
|
1146
|
-
|
|
1147
|
-
|
|
1148
|
-
|
|
1149
|
-
|
|
1150
|
-
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1155
|
-
|
|
1156
|
-
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
|
|
1161
|
-
|
|
1162
|
-
|
|
1163
|
-
|
|
1371
|
+
# Check if champion exists and if it matches default
|
|
1372
|
+
champion_matches_default = False
|
|
1373
|
+
try:
|
|
1374
|
+
existing_champion = load_prompt(f"prompts:/{prompt_name}@champion")
|
|
1375
|
+
champion_matches_default = (
|
|
1376
|
+
existing_champion.version == existing_default.version
|
|
1377
|
+
)
|
|
1378
|
+
status = (
|
|
1379
|
+
"tracking" if champion_matches_default else "pinned separately"
|
|
1380
|
+
)
|
|
1381
|
+
logger.trace(
|
|
1382
|
+
"Champion vs default version",
|
|
1383
|
+
prompt_name=prompt_name,
|
|
1384
|
+
champion_version=existing_champion.version,
|
|
1385
|
+
default_version=existing_default.version,
|
|
1386
|
+
status=status,
|
|
1387
|
+
)
|
|
1388
|
+
except Exception:
|
|
1389
|
+
# No champion exists
|
|
1390
|
+
logger.trace("No champion alias found", prompt_name=prompt_name)
|
|
1164
1391
|
|
|
1165
|
-
|
|
1166
|
-
# Check if default alias already has the same template
|
|
1167
|
-
try:
|
|
1168
|
-
logger.debug(f"Loading prompt '{prompt_name}' from registry...")
|
|
1169
|
-
existing: PromptVersion = mlflow.genai.load_prompt(
|
|
1170
|
-
f"prompts:/{prompt_name}@default"
|
|
1171
|
-
)
|
|
1392
|
+
# Check if default_template differs from existing default
|
|
1172
1393
|
if (
|
|
1173
|
-
|
|
1174
|
-
|
|
1394
|
+
existing_default.template.strip()
|
|
1395
|
+
!= prompt_model.default_template.strip()
|
|
1175
1396
|
):
|
|
1176
|
-
logger.
|
|
1397
|
+
logger.info(
|
|
1398
|
+
"Default template changed, registering new version",
|
|
1399
|
+
prompt_name=prompt_name,
|
|
1400
|
+
)
|
|
1177
1401
|
|
|
1178
|
-
#
|
|
1179
|
-
|
|
1180
|
-
try:
|
|
1181
|
-
latest_version: PromptVersion = mlflow.genai.load_prompt(
|
|
1182
|
-
f"prompts:/{prompt_name}@latest"
|
|
1183
|
-
)
|
|
1184
|
-
logger.debug(
|
|
1185
|
-
f"Latest alias already exists for '{prompt_name}' pointing to version {latest_version.version}"
|
|
1186
|
-
)
|
|
1187
|
-
except Exception:
|
|
1402
|
+
# Only update champion if it was pointing to the old default
|
|
1403
|
+
if champion_matches_default:
|
|
1188
1404
|
logger.info(
|
|
1189
|
-
|
|
1405
|
+
"Champion was tracking default, will update to new version",
|
|
1406
|
+
prompt_name=prompt_name,
|
|
1407
|
+
old_version=existing_default.version,
|
|
1190
1408
|
)
|
|
1191
|
-
|
|
1192
|
-
|
|
1193
|
-
alias="latest",
|
|
1194
|
-
version=existing.version,
|
|
1195
|
-
)
|
|
1196
|
-
|
|
1197
|
-
# Ensure champion alias exists for first-time deployments
|
|
1198
|
-
try:
|
|
1199
|
-
champion_version: PromptVersion = mlflow.genai.load_prompt(
|
|
1200
|
-
f"prompts:/{prompt_name}@champion"
|
|
1201
|
-
)
|
|
1202
|
-
logger.debug(
|
|
1203
|
-
f"Champion alias already exists for '{prompt_name}' pointing to version {champion_version.version}"
|
|
1204
|
-
)
|
|
1205
|
-
except Exception:
|
|
1409
|
+
set_champion = True
|
|
1410
|
+
else:
|
|
1206
1411
|
logger.info(
|
|
1207
|
-
|
|
1208
|
-
|
|
1209
|
-
mlflow.genai.set_prompt_alias(
|
|
1210
|
-
name=prompt_name,
|
|
1211
|
-
alias="champion",
|
|
1212
|
-
version=existing.version,
|
|
1412
|
+
"Champion is pinned separately, preserving it",
|
|
1413
|
+
prompt_name=prompt_name,
|
|
1213
1414
|
)
|
|
1415
|
+
set_champion = False
|
|
1214
1416
|
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
|
|
1218
|
-
|
|
1417
|
+
self._register_default_template(
|
|
1418
|
+
prompt_name,
|
|
1419
|
+
prompt_model.default_template,
|
|
1420
|
+
prompt_model.description,
|
|
1421
|
+
set_champion=set_champion,
|
|
1422
|
+
)
|
|
1423
|
+
except Exception as e:
|
|
1424
|
+
# No default exists yet, register it
|
|
1425
|
+
logger.trace(
|
|
1426
|
+
"No default alias found", prompt_name=prompt_name, error=str(e)
|
|
1219
1427
|
)
|
|
1220
|
-
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
|
|
1224
|
-
|
|
1225
|
-
|
|
1226
|
-
|
|
1227
|
-
|
|
1228
|
-
|
|
1229
|
-
|
|
1230
|
-
|
|
1231
|
-
|
|
1232
|
-
|
|
1233
|
-
|
|
1234
|
-
|
|
1235
|
-
alias="default",
|
|
1236
|
-
version=prompt_version.version,
|
|
1237
|
-
)
|
|
1238
|
-
mlflow.genai.set_prompt_alias(
|
|
1239
|
-
name=prompt_name,
|
|
1240
|
-
alias="latest",
|
|
1241
|
-
version=prompt_version.version,
|
|
1242
|
-
)
|
|
1243
|
-
mlflow.genai.set_prompt_alias(
|
|
1244
|
-
name=prompt_name,
|
|
1245
|
-
alias="champion",
|
|
1246
|
-
version=prompt_version.version,
|
|
1428
|
+
logger.info(
|
|
1429
|
+
"Registering default template as default alias",
|
|
1430
|
+
prompt_name=prompt_name,
|
|
1431
|
+
)
|
|
1432
|
+
# First registration - set both default and champion
|
|
1433
|
+
self._register_default_template(
|
|
1434
|
+
prompt_name,
|
|
1435
|
+
prompt_model.default_template,
|
|
1436
|
+
prompt_model.description,
|
|
1437
|
+
set_champion=True,
|
|
1438
|
+
)
|
|
1439
|
+
elif prompt_model.default_template and not prompt_model.auto_register:
+            logger.trace(
+                "Prompt has auto_register=False, skipping registration",
+                prompt_name=prompt_name,
             )

-
-
-            )
+        # 1. Try champion alias (highest priority for execution)
+        try:
+            prompt_version = load_prompt(f"prompts:/{prompt_name}@champion")
+            logger.info("Loaded prompt from champion alias", prompt_name=prompt_name)
             return prompt_version
-
         except Exception as e:
-            logger.
-
-
-            ) from e
-
-    def optimize_prompt(self, optimization: PromptOptimizationModel) -> PromptModel:
-        """
-        Optimize a prompt using MLflow's prompt optimization (MLflow 3.5+).
-
-        This uses the MLflow GenAI optimize_prompts API with GepaPromptOptimizer as documented at:
-        https://mlflow.org/docs/latest/genai/prompt-registry/optimize-prompts/
+            logger.trace(
+                "Champion alias not found", prompt_name=prompt_name, error=str(e)
+            )

-
-
+        # 2. Try default alias (already synced above)
+        if prompt_model.default_template:
+            try:
+                prompt_version = load_prompt(f"prompts:/{prompt_name}@default")
+                logger.info("Loaded prompt from default alias", prompt_name=prompt_name)
+                return prompt_version
+            except Exception as e:
+                # Should not happen since we just registered it above, but handle anyway
+                logger.trace(
+                    "Default alias not found", prompt_name=prompt_name, error=str(e)
+                )

-
-            PromptModel: The optimized prompt with new URI
-        """
-        from mlflow.genai.optimize import GepaPromptOptimizer, optimize_prompts
-        from mlflow.genai.scorers import Correctness
-
-        from dao_ai.config import AgentModel, PromptModel
-
-        logger.info(f"Optimizing prompt: {optimization.name}")
-
-        # Get agent and prompt (prompt is guaranteed to be set by validator)
-        agent_model: AgentModel = optimization.agent
-        prompt: PromptModel = optimization.prompt  # type: ignore[assignment]
-        agent_model.prompt = prompt.uri
-
-        print(f"prompt={agent_model.prompt}")
-        # Log the prompt URI scheme being used
-        # Supports three schemes:
-        # 1. Specific version: "prompts:/qa/1" (when version is specified)
-        # 2. Alias: "prompts:/qa@champion" (when alias is specified)
-        # 3. Latest: "prompts:/qa@latest" (default when neither version nor alias specified)
-        prompt_uri: str = prompt.uri
-        logger.info(f"Using prompt URI for optimization: {prompt_uri}")
-
-        # Load the specific prompt version by URI for comparison
-        # Try to load the exact version specified, but if it doesn't exist,
-        # use get_prompt to create it from default_template
-        prompt_version: PromptVersion
+        # 3. Try latest alias as final fallback
         try:
-            prompt_version = load_prompt(
-            logger.info(
+            prompt_version = load_prompt(f"prompts:/{prompt_name}@latest")
+            logger.info("Loaded prompt from latest alias", prompt_name=prompt_name)
+            return prompt_version
         except Exception as e:
+            logger.trace(
+                "Latest alias not found", prompt_name=prompt_name, error=str(e)
+            )
+
+        # 4. Final fallback: use default_template directly if available
+        if prompt_model.default_template:
             logger.warning(
-
-
+                "Could not load prompt from registry, using default_template directly",
+                prompt_name=prompt_name,
             )
-
-
-
-
+            return PromptVersion(
+                name=prompt_name,
+                version=1,
+                template=prompt_model.default_template,
+                tags={"dao_ai": dao_ai_version()},
             )

-
-
-
-
-            dataset = get_dataset(name=optimization.dataset)
-        else:
-            dataset = optimization.dataset.as_dataset()
-
-        # Set up reflection model for the optimizer
-        reflection_model_name: str
-        if optimization.reflection_model:
-            if isinstance(optimization.reflection_model, str):
-                reflection_model_name = optimization.reflection_model
-            else:
-                reflection_model_name = optimization.reflection_model.uri
-        else:
-            reflection_model_name = agent_model.model.uri
-        logger.debug(f"Using reflection model: {reflection_model_name}")
-
-        # Create the GepaPromptOptimizer
-        optimizer: GepaPromptOptimizer = GepaPromptOptimizer(
-            reflection_model=reflection_model_name,
-            max_metric_calls=optimization.num_candidates,
-            display_progress_bar=True,
+        raise ValueError(
+            f"Prompt '{prompt_name}' not found in registry "
+            "(tried champion, default, latest aliases) "
+            "and no default_template provided"
         )

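For reference, the lookup order introduced above (champion alias, then default, then latest, then the inline default_template) can be exercised directly against the MLflow prompt registry this provider calls. A minimal sketch, not part of the package diff; the helper name and prompt name are illustrative:

import mlflow.genai

def load_with_fallback(prompt_name: str, default_template: str | None = None):
    # Mirror the provider's priority: champion -> default -> latest -> inline template.
    for alias in ("champion", "default", "latest"):
        try:
            return mlflow.genai.load_prompt(f"prompts:/{prompt_name}@{alias}")
        except Exception:
            continue  # alias missing or prompt not registered; try the next one
    if default_template is not None:
        return default_template  # last resort, as the provider does with an in-memory PromptVersion
    raise ValueError(f"Prompt '{prompt_name}' not found and no default_template provided")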
-
-
-
-
-
-
-
-
-        scorer_model = agent_model.model.uri  # Use Databricks default
-        logger.debug(f"Using scorer with model: {scorer_model}")
-
-        scorers: list[Correctness] = [Correctness(model=scorer_model)]
-
-        # Use prompt_uri from line 1188 - already set to prompt.uri (configured URI)
-        # DO NOT overwrite with prompt_version.uri as that uses fallback logic
-        logger.debug(f"Optimizing prompt: {prompt_uri}")
-
-        agent: ResponsesAgent = agent_model.as_responses_agent()
-
-        # Create predict function that will be optimized
-        def predict_fn(**inputs: dict[str, Any]) -> str:
-            """Prediction function that uses the ResponsesAgent with ChatPayload.
-
-            The agent already has the prompt referenced/applied, so we just need to
-            convert the ChatPayload inputs to ResponsesAgentRequest format and call predict.
-
-            Args:
-                **inputs: Dictionary containing ChatPayload fields (messages/input, custom_inputs)
-
-            Returns:
-                str: The agent's response content
-            """
-            from mlflow.types.responses import (
-                ResponsesAgentRequest,
-                ResponsesAgentResponse,
-            )
-            from mlflow.types.responses_helpers import Message
-
-            from dao_ai.config import ChatPayload
-
-            # Verify agent is accessible (should be captured from outer scope)
-            if agent is None:
-                raise RuntimeError(
-                    "Agent object is not available in predict_fn. "
-                    "This may indicate a serialization issue with the ResponsesAgent."
-                )
-
-            # Convert inputs to ChatPayload
-            chat_payload: ChatPayload = ChatPayload(**inputs)
-
-            # Convert ChatPayload messages to MLflow Message format
-            mlflow_messages: list[Message] = [
-                Message(role=msg.role, content=msg.content)
-                for msg in chat_payload.messages
-            ]
-
-            # Create ResponsesAgentRequest
-            request: ResponsesAgentRequest = ResponsesAgentRequest(
-                input=mlflow_messages,
-                custom_inputs=chat_payload.custom_inputs,
-            )
+    def _register_default_template(
+        self,
+        prompt_name: str,
+        default_template: str,
+        description: str | None = None,
+        set_champion: bool = True,
+    ) -> PromptVersion:
+        """Register default_template as a new prompt version.

-
-
-
-            if response.output and len(response.output) > 0:
-                content = response.output[0].content
-                logger.debug(f"Response content type: {type(content)}")
-                logger.debug(f"Response content: {content}")
-
-                # Extract text from content using same logic as LanggraphResponsesAgent._extract_text_from_content
-                # Content can be:
-                # - A string (return as is)
-                # - A list of items with 'text' keys (extract and join)
-                # - Other types (try to get 'text' attribute or convert to string)
-                if isinstance(content, str):
-                    return content
-                elif isinstance(content, list):
-                    text_parts = []
-                    for content_item in content:
-                        if isinstance(content_item, str):
-                            text_parts.append(content_item)
-                        elif isinstance(content_item, dict) and "text" in content_item:
-                            text_parts.append(content_item["text"])
-                        elif hasattr(content_item, "text"):
-                            text_parts.append(content_item.text)
-                    return "".join(text_parts) if text_parts else str(content)
-                else:
-                    # Fallback for unknown types - try to extract text attribute
-                    return getattr(content, "text", str(content))
-            else:
-                return ""
+        Registers the template and sets the 'default' alias.
+        Optionally sets 'champion' alias if no champion exists.

-
-
+        Args:
+            prompt_name: Full name of the prompt
+            default_template: The template content
+            description: Optional description for commit message
+            set_champion: Whether to also set champion alias (default: True)

-
-
+        If registration fails (e.g., in Model Serving with restricted permissions),
+        logs the error and raises.
+        """
         logger.info(
-
-
-
-        from mlflow.genai.optimize.types import (
-            PromptOptimizationResult,
-        )
-
-        result: PromptOptimizationResult = optimize_prompts(
-            predict_fn=predict_fn,
-            train_data=dataset,
-            prompt_uris=[prompt_uri],  # Use the configured URI (version/alias/latest)
-            optimizer=optimizer,
-            scorers=scorers,
-            enable_tracking=False,  # Don't auto-register all candidates
+            "Registering default template",
+            prompt_name=prompt_name,
+            set_champion=set_champion,
         )

-
-
-
-
-
-
-
-        # Check if the optimized prompt is actually different from the original
-        original_template: str = prompt_version.to_single_brace_format().strip()
-        optimized_template: str = (
-            optimized_prompt_version.to_single_brace_format().strip()
-        )
-
-        # Normalize whitespace for more robust comparison
-        original_normalized: str = re.sub(r"\s+", " ", original_template).strip()
-        optimized_normalized: str = re.sub(r"\s+", " ", optimized_template).strip()
-
-        logger.debug(f"Original template length: {len(original_template)} chars")
-        logger.debug(f"Optimized template length: {len(optimized_template)} chars")
-        logger.debug(
-            f"Templates identical: {original_normalized == optimized_normalized}"
-        )
-
-        if original_normalized == optimized_normalized:
-            logger.info(
-                f"Optimized prompt is identical to original for '{prompt.full_name}'. "
-                "No new version will be registered."
-            )
-            return prompt
-
-        logger.info("Optimized prompt is DIFFERENT from original")
-        logger.info(
-            f"Original length: {len(original_template)}, Optimized length: {len(optimized_template)}"
-        )
-        logger.debug(
-            f"Original template (first 300 chars): {original_template[:300]}..."
-        )
-        logger.debug(
-            f"Optimized template (first 300 chars): {optimized_template[:300]}..."
+        try:
+            commit_message = description or "Auto-synced from default_template"
+            prompt_version = mlflow.genai.register_prompt(
+                name=prompt_name,
+                template=default_template,
+                commit_message=commit_message,
+                tags={"dao_ai": dao_ai_version()},
             )

-        #
-        should_register: bool = False
-        has_improvement: bool = False
-
-        if (
-            result.initial_eval_score is not None
-            and result.final_eval_score is not None
-        ):
-            logger.info("Evaluation scores:")
-            logger.info(f"  Initial score: {result.initial_eval_score}")
-            logger.info(f"  Final score: {result.final_eval_score}")
-
-            # Only register if there's improvement
-            if result.final_eval_score > result.initial_eval_score:
-                improvement: float = (
-                    (result.final_eval_score - result.initial_eval_score)
-                    / result.initial_eval_score
-                ) * 100
-                logger.info(
-                    f"Optimized prompt improved by {improvement:.2f}% "
-                    f"(initial: {result.initial_eval_score}, final: {result.final_eval_score})"
-                )
-                should_register = True
-                has_improvement = True
-            else:
-                logger.info(
-                    f"Optimized prompt (score: {result.final_eval_score}) did NOT improve over baseline (score: {result.initial_eval_score}). "
-                    "No new version will be registered."
-                )
-        else:
-            # No scores available - register anyway but warn
-            logger.warning(
-                "No evaluation scores available to compare performance. "
-                "Registering optimized prompt without performance validation."
-            )
-            should_register = True
-
-        if not should_register:
-            logger.info(
-                f"Skipping registration for '{prompt.full_name}' (no improvement)"
-            )
-            return prompt
-
-        # Register the optimized prompt manually
+            # Always set default alias
            try:
-            logger.
-
-
-
-                commit_message=f"Optimized for {agent_model.model.uri} using GepaPromptOptimizer",
-                tags={
-                    "dao_ai": dao_ai_version(),
-                    "target_model": agent_model.model.uri,
-                },
-            )
-            logger.info(
-                f"Registered optimized prompt as version {registered_version.version}"
-            )
-
-            # Always set "latest" alias (represents most recently registered prompt)
-            logger.info(
-                f"Setting 'latest' alias for optimized prompt '{prompt.full_name}' version {registered_version.version}"
+                logger.debug(
+                    "Setting default alias",
+                    prompt_name=prompt_name,
+                    version=prompt_version.version,
                )
                mlflow.genai.set_prompt_alias(
-                name=
-                alias="latest",
-                version=registered_version.version,
+                    name=prompt_name, alias="default", version=prompt_version.version
                )
-            logger.
-
+                logger.success(
+                    "Set default alias for prompt",
+                    prompt_name=prompt_name,
+                    version=prompt_version.version,
+                )
+            except Exception as alias_error:
+                logger.warning(
+                    "Could not set default alias",
+                    prompt_name=prompt_name,
+                    error=str(alias_error),
                )

-
-
-
-            logger.info(
-                f"Setting 'champion' alias for improved prompt '{prompt.full_name}' version {registered_version.version}"
-            )
+            # Optionally set champion alias (only if no champion exists or explicitly requested)
+            if set_champion:
+                try:
                    mlflow.genai.set_prompt_alias(
-                name=
+                        name=prompt_name,
                        alias="champion",
-                version=
+                        version=prompt_version.version,
                    )
-            logger.
-
+                    logger.success(
+                        "Set champion alias for prompt",
+                        prompt_name=prompt_name,
+                        version=prompt_version.version,
+                    )
+                except Exception as alias_error:
+                    logger.warning(
+                        "Could not set champion alias",
+                        prompt_name=prompt_name,
+                        error=str(alias_error),
                    )

-
-        tags: dict[str, Any] = prompt.tags.copy() if prompt.tags else {}
-        tags["target_model"] = agent_model.model.uri
-        tags["dao_ai"] = dao_ai_version()
-
-        # Return the optimized prompt with the appropriate alias
-        # Use "champion" if there was improvement, otherwise "latest"
-        result_alias: str = "champion" if has_improvement else "latest"
-        return PromptModel(
-            name=prompt.name,
-            schema=prompt.schema_model,
-            description=f"Optimized version of {prompt.name} for {agent_model.model.uri}",
-            alias=result_alias,
-            tags=tags,
-        )
+            return prompt_version

-
-
-
-
-
-
-
-
+        except Exception as reg_error:
+            logger.error(
+                "Failed to register prompt - please register from notebook with write permissions",
+                prompt_name=prompt_name,
+                error=str(reg_error),
+            )
+            return PromptVersion(
+                name=prompt_name,
+                version=1,
+                template=default_template,
+                tags={"dao_ai": dao_ai_version()},
+            )
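The registration path added in _register_default_template reduces to two registry calls that also appear verbatim in the diff: mlflow.genai.register_prompt to create a new version from the inline template, and mlflow.genai.set_prompt_alias to point the 'default' (and optionally 'champion') alias at it. A minimal usage sketch with an illustrative prompt name:

import mlflow.genai

version = mlflow.genai.register_prompt(
    name="main.default.qa_prompt",  # illustrative Unity Catalog prompt name
    template="Answer the question: {{question}}",
    commit_message="Auto-synced from default_template",
)
# Point the aliases at the freshly registered version.
mlflow.genai.set_prompt_alias(name="main.default.qa_prompt", alias="default", version=version.version)
mlflow.genai.set_prompt_alias(name="main.default.qa_prompt", alias="champion", version=version.version)

If an alias call fails (for example under restricted Model Serving permissions), the provider logs a warning and continues; a registration failure is logged and the provider falls back to an in-memory PromptVersion built from the template.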