dao-ai 0.0.28__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +2 -5
- dao_ai/cli.py +245 -40
- dao_ai/config.py +1491 -370
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +232 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +245 -159
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +278 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +573 -601
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -294
- dao_ai/tools/mcp.py +223 -155
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +331 -221
- dao_ai/utils.py +166 -20
- dao_ai-0.1.2.dist-info/METADATA +455 -0
- dao_ai-0.1.2.dist-info/RECORD +64 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.28.dist-info/METADATA +0 -1168
- dao_ai-0.0.28.dist-info/RECORD +0 -41
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/providers/databricks.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import base64
|
|
2
|
-
import re
|
|
3
2
|
import uuid
|
|
4
3
|
from pathlib import Path
|
|
5
4
|
from typing import Any, Callable, Final, Sequence
|
|
@@ -32,14 +31,12 @@ from mlflow import MlflowClient
|
|
|
32
31
|
from mlflow.entities import Experiment
|
|
33
32
|
from mlflow.entities.model_registry import PromptVersion
|
|
34
33
|
from mlflow.entities.model_registry.model_version import ModelVersion
|
|
35
|
-
from mlflow.genai.datasets import EvaluationDataset, get_dataset
|
|
36
34
|
from mlflow.genai.prompts import load_prompt
|
|
37
35
|
from mlflow.models.auth_policy import AuthPolicy, SystemAuthPolicy, UserAuthPolicy
|
|
38
36
|
from mlflow.models.model import ModelInfo
|
|
39
37
|
from mlflow.models.resources import (
|
|
40
38
|
DatabricksResource,
|
|
41
39
|
)
|
|
42
|
-
from mlflow.pyfunc import ResponsesAgent
|
|
43
40
|
from pyspark.sql import SparkSession
|
|
44
41
|
from unitycatalog.ai.core.base import FunctionExecutionResult
|
|
45
42
|
from unitycatalog.ai.core.databricks import DatabricksFunctionClient
|
|
@@ -49,6 +46,7 @@ from dao_ai.config import (
|
|
|
49
46
|
AppConfig,
|
|
50
47
|
ConnectionModel,
|
|
51
48
|
DatabaseModel,
|
|
49
|
+
DatabricksAppModel,
|
|
52
50
|
DatasetModel,
|
|
53
51
|
FunctionModel,
|
|
54
52
|
GenieRoomModel,
|
|
@@ -57,7 +55,6 @@ from dao_ai.config import (
|
|
|
57
55
|
IsDatabricksResource,
|
|
58
56
|
LLMModel,
|
|
59
57
|
PromptModel,
|
|
60
|
-
PromptOptimizationModel,
|
|
61
58
|
SchemaModel,
|
|
62
59
|
TableModel,
|
|
63
60
|
UnityCatalogFunctionSqlModel,
|
|
@@ -73,6 +70,7 @@ from dao_ai.utils import (
|
|
|
73
70
|
get_installed_packages,
|
|
74
71
|
is_installed,
|
|
75
72
|
is_lib_provided,
|
|
73
|
+
normalize_host,
|
|
76
74
|
normalize_name,
|
|
77
75
|
)
|
|
78
76
|
from dao_ai.vector_search import endpoint_exists, index_exists
|
|
@@ -94,15 +92,18 @@ def _workspace_client(
|
|
|
94
92
|
Create a WorkspaceClient instance with the provided parameters.
|
|
95
93
|
If no parameters are provided, it will use the default configuration.
|
|
96
94
|
"""
|
|
97
|
-
|
|
95
|
+
# Normalize the workspace host to ensure it has https:// scheme
|
|
96
|
+
normalized_host = normalize_host(workspace_host)
|
|
97
|
+
|
|
98
|
+
if client_id and client_secret and normalized_host:
|
|
98
99
|
return WorkspaceClient(
|
|
99
|
-
host=
|
|
100
|
+
host=normalized_host,
|
|
100
101
|
client_id=client_id,
|
|
101
102
|
client_secret=client_secret,
|
|
102
103
|
auth_type="oauth-m2m",
|
|
103
104
|
)
|
|
104
105
|
elif pat:
|
|
105
|
-
return WorkspaceClient(host=
|
|
106
|
+
return WorkspaceClient(host=normalized_host, token=pat, auth_type="pat")
|
|
106
107
|
else:
|
|
107
108
|
return WorkspaceClient()
|
|
108
109
|
|
|
@@ -117,15 +118,18 @@ def _vector_search_client(
|
|
|
117
118
|
Create a VectorSearchClient instance with the provided parameters.
|
|
118
119
|
If no parameters are provided, it will use the default configuration.
|
|
119
120
|
"""
|
|
120
|
-
|
|
121
|
+
# Normalize the workspace host to ensure it has https:// scheme
|
|
122
|
+
normalized_host = normalize_host(workspace_host)
|
|
123
|
+
|
|
124
|
+
if client_id and client_secret and normalized_host:
|
|
121
125
|
return VectorSearchClient(
|
|
122
|
-
workspace_url=
|
|
126
|
+
workspace_url=normalized_host,
|
|
123
127
|
service_principal_client_id=client_id,
|
|
124
128
|
service_principal_client_secret=client_secret,
|
|
125
129
|
)
|
|
126
|
-
elif pat and
|
|
130
|
+
elif pat and normalized_host:
|
|
127
131
|
return VectorSearchClient(
|
|
128
|
-
workspace_url=
|
|
132
|
+
workspace_url=normalized_host,
|
|
129
133
|
personal_access_token=pat,
|
|
130
134
|
)
|
|
131
135
|
else:
|
|
@@ -177,15 +181,17 @@ class DatabricksProvider(ServiceProvider):
|
|
|
177
181
|
experiment: Experiment | None = mlflow.get_experiment_by_name(experiment_name)
|
|
178
182
|
if experiment is None:
|
|
179
183
|
experiment_id: str = mlflow.create_experiment(name=experiment_name)
|
|
180
|
-
logger.
|
|
181
|
-
|
|
184
|
+
logger.success(
|
|
185
|
+
"Created new MLflow experiment",
|
|
186
|
+
experiment_name=experiment_name,
|
|
187
|
+
experiment_id=experiment_id,
|
|
182
188
|
)
|
|
183
189
|
experiment = mlflow.get_experiment(experiment_id)
|
|
184
190
|
return experiment
|
|
185
191
|
|
|
186
192
|
def create_token(self) -> str:
|
|
187
193
|
current_user: User = self.w.current_user.me()
|
|
188
|
-
logger.debug(
|
|
194
|
+
logger.debug("Authenticated to Databricks", user=str(current_user))
|
|
189
195
|
headers: dict[str, str] = self.w.config.authenticate()
|
|
190
196
|
token: str = headers["Authorization"].replace("Bearer ", "")
|
|
191
197
|
return token
|
|
@@ -197,17 +203,24 @@ class DatabricksProvider(ServiceProvider):
|
|
|
197
203
|
secret_response: GetSecretResponse = self.w.secrets.get_secret(
|
|
198
204
|
secret_scope, secret_key
|
|
199
205
|
)
|
|
200
|
-
logger.
|
|
206
|
+
logger.trace(
|
|
207
|
+
"Retrieved secret", secret_key=secret_key, secret_scope=secret_scope
|
|
208
|
+
)
|
|
201
209
|
encoded_secret: str = secret_response.value
|
|
202
210
|
decoded_secret: str = base64.b64decode(encoded_secret).decode("utf-8")
|
|
203
211
|
return decoded_secret
|
|
204
212
|
except NotFound:
|
|
205
213
|
logger.warning(
|
|
206
|
-
|
|
214
|
+
"Secret not found, using default value",
|
|
215
|
+
secret_key=secret_key,
|
|
216
|
+
secret_scope=secret_scope,
|
|
207
217
|
)
|
|
208
218
|
except Exception as e:
|
|
209
219
|
logger.error(
|
|
210
|
-
|
|
220
|
+
"Error retrieving secret",
|
|
221
|
+
secret_key=secret_key,
|
|
222
|
+
secret_scope=secret_scope,
|
|
223
|
+
error=str(e),
|
|
211
224
|
)
|
|
212
225
|
|
|
213
226
|
return default_value
|
|
@@ -216,9 +229,18 @@ class DatabricksProvider(ServiceProvider):
|
|
|
216
229
|
self,
|
|
217
230
|
config: AppConfig,
|
|
218
231
|
) -> ModelInfo:
|
|
219
|
-
logger.
|
|
232
|
+
logger.info("Creating agent")
|
|
220
233
|
mlflow.set_registry_uri("databricks-uc")
|
|
221
234
|
|
|
235
|
+
# Set up experiment for proper tracking
|
|
236
|
+
experiment: Experiment = self.get_or_create_experiment(config)
|
|
237
|
+
mlflow.set_experiment(experiment_id=experiment.experiment_id)
|
|
238
|
+
logger.debug(
|
|
239
|
+
"Using MLflow experiment",
|
|
240
|
+
experiment_name=experiment.name,
|
|
241
|
+
experiment_id=experiment.experiment_id,
|
|
242
|
+
)
|
|
243
|
+
|
|
222
244
|
llms: Sequence[LLMModel] = list(config.resources.llms.values())
|
|
223
245
|
vector_indexes: Sequence[IndexModel] = list(
|
|
224
246
|
config.resources.vector_stores.values()
|
|
@@ -236,6 +258,7 @@ class DatabricksProvider(ServiceProvider):
|
|
|
236
258
|
)
|
|
237
259
|
databases: Sequence[DatabaseModel] = list(config.resources.databases.values())
|
|
238
260
|
volumes: Sequence[VolumeModel] = list(config.resources.volumes.values())
|
|
261
|
+
apps: Sequence[DatabricksAppModel] = list(config.resources.apps.values())
|
|
239
262
|
|
|
240
263
|
resources: Sequence[IsDatabricksResource] = (
|
|
241
264
|
llms
|
|
@@ -247,6 +270,7 @@ class DatabricksProvider(ServiceProvider):
|
|
|
247
270
|
+ connections
|
|
248
271
|
+ databases
|
|
249
272
|
+ volumes
|
|
273
|
+
+ apps
|
|
250
274
|
)
|
|
251
275
|
|
|
252
276
|
# Flatten all resources from all models into a single list
|
|
@@ -260,12 +284,16 @@ class DatabricksProvider(ServiceProvider):
|
|
|
260
284
|
for resource in r.as_resources()
|
|
261
285
|
if not r.on_behalf_of_user
|
|
262
286
|
]
|
|
263
|
-
logger.
|
|
287
|
+
logger.trace(
|
|
288
|
+
"System resources identified",
|
|
289
|
+
count=len(system_resources),
|
|
290
|
+
resources=[r.name for r in system_resources],
|
|
291
|
+
)
|
|
264
292
|
|
|
265
293
|
system_auth_policy: SystemAuthPolicy = SystemAuthPolicy(
|
|
266
294
|
resources=system_resources
|
|
267
295
|
)
|
|
268
|
-
logger.
|
|
296
|
+
logger.trace("System auth policy created", policy=str(system_auth_policy))
|
|
269
297
|
|
|
270
298
|
api_scopes: Sequence[str] = list(
|
|
271
299
|
set(
|
|
@@ -277,15 +305,19 @@ class DatabricksProvider(ServiceProvider):
|
|
|
277
305
|
]
|
|
278
306
|
)
|
|
279
307
|
)
|
|
280
|
-
logger.
|
|
308
|
+
logger.trace("API scopes identified", scopes=api_scopes)
|
|
281
309
|
|
|
282
310
|
user_auth_policy: UserAuthPolicy = UserAuthPolicy(api_scopes=api_scopes)
|
|
283
|
-
logger.
|
|
311
|
+
logger.trace("User auth policy created", policy=str(user_auth_policy))
|
|
284
312
|
|
|
285
313
|
auth_policy: AuthPolicy = AuthPolicy(
|
|
286
314
|
system_auth_policy=system_auth_policy, user_auth_policy=user_auth_policy
|
|
287
315
|
)
|
|
288
|
-
logger.debug(
|
|
316
|
+
logger.debug(
|
|
317
|
+
"Auth policy created",
|
|
318
|
+
has_system_auth=system_auth_policy is not None,
|
|
319
|
+
has_user_auth=user_auth_policy is not None,
|
|
320
|
+
)
|
|
289
321
|
|
|
290
322
|
code_paths: list[str] = config.app.code_paths
|
|
291
323
|
for path in code_paths:
|
|
@@ -312,18 +344,42 @@ class DatabricksProvider(ServiceProvider):
|
|
|
312
344
|
|
|
313
345
|
pip_requirements += get_installed_packages()
|
|
314
346
|
|
|
315
|
-
logger.
|
|
316
|
-
logger.
|
|
347
|
+
logger.trace("Pip requirements prepared", count=len(pip_requirements))
|
|
348
|
+
logger.trace("Code paths prepared", count=len(code_paths))
|
|
317
349
|
|
|
318
350
|
run_name: str = normalize_name(config.app.name)
|
|
319
|
-
logger.debug(
|
|
320
|
-
|
|
351
|
+
logger.debug(
|
|
352
|
+
"Agent run configuration",
|
|
353
|
+
run_name=run_name,
|
|
354
|
+
model_path=model_path.as_posix(),
|
|
355
|
+
)
|
|
321
356
|
|
|
322
357
|
input_example: dict[str, Any] = None
|
|
323
358
|
if config.app.input_example:
|
|
324
359
|
input_example = config.app.input_example.model_dump()
|
|
325
360
|
|
|
326
|
-
logger.
|
|
361
|
+
logger.trace("Input example configured", has_example=input_example is not None)
|
|
362
|
+
|
|
363
|
+
# Create conda environment with configured Python version
|
|
364
|
+
# This allows deploying from environments with different Python versions
|
|
365
|
+
# (e.g., Databricks Apps with Python 3.11 can deploy to Model Serving with 3.12)
|
|
366
|
+
target_python_version: str = config.app.python_version
|
|
367
|
+
logger.debug("Target Python version configured", version=target_python_version)
|
|
368
|
+
|
|
369
|
+
conda_env: dict[str, Any] = {
|
|
370
|
+
"name": "mlflow-env",
|
|
371
|
+
"channels": ["conda-forge"],
|
|
372
|
+
"dependencies": [
|
|
373
|
+
f"python={target_python_version}",
|
|
374
|
+
"pip",
|
|
375
|
+
{"pip": list(pip_requirements)},
|
|
376
|
+
],
|
|
377
|
+
}
|
|
378
|
+
logger.trace(
|
|
379
|
+
"Conda environment configured",
|
|
380
|
+
python_version=target_python_version,
|
|
381
|
+
pip_packages_count=len(pip_requirements),
|
|
382
|
+
)
|
|
327
383
|
|
|
328
384
|
with mlflow.start_run(run_name=run_name):
|
|
329
385
|
mlflow.set_tag("type", "agent")
|
|
@@ -333,7 +389,7 @@ class DatabricksProvider(ServiceProvider):
|
|
|
333
389
|
code_paths=code_paths,
|
|
334
390
|
model_config=config.model_dump(mode="json", by_alias=True),
|
|
335
391
|
name="agent",
|
|
336
|
-
|
|
392
|
+
conda_env=conda_env,
|
|
337
393
|
input_example=input_example,
|
|
338
394
|
# resources=all_resources,
|
|
339
395
|
auth_policy=auth_policy,
|
|
@@ -344,8 +400,10 @@ class DatabricksProvider(ServiceProvider):
|
|
|
344
400
|
model_version: ModelVersion = mlflow.register_model(
|
|
345
401
|
name=registered_model_name, model_uri=logged_agent_info.model_uri
|
|
346
402
|
)
|
|
347
|
-
logger.
|
|
348
|
-
|
|
403
|
+
logger.success(
|
|
404
|
+
"Model registered",
|
|
405
|
+
model_name=registered_model_name,
|
|
406
|
+
version=model_version.version,
|
|
349
407
|
)
|
|
350
408
|
|
|
351
409
|
client: MlflowClient = MlflowClient()
|
|
@@ -357,7 +415,7 @@ class DatabricksProvider(ServiceProvider):
|
|
|
357
415
|
key="dao_ai",
|
|
358
416
|
value=dao_ai_version(),
|
|
359
417
|
)
|
|
360
|
-
logger.
|
|
418
|
+
logger.trace("Set dao_ai tag on model version", version=model_version.version)
|
|
361
419
|
|
|
362
420
|
client.set_registered_model_alias(
|
|
363
421
|
name=registered_model_name,
|
|
@@ -374,12 +432,15 @@ class DatabricksProvider(ServiceProvider):
|
|
|
374
432
|
aliased_model: ModelVersion = client.get_model_version_by_alias(
|
|
375
433
|
registered_model_name, config.app.alias
|
|
376
434
|
)
|
|
377
|
-
logger.
|
|
378
|
-
|
|
435
|
+
logger.info(
|
|
436
|
+
"Model aliased",
|
|
437
|
+
model_name=registered_model_name,
|
|
438
|
+
alias=config.app.alias,
|
|
439
|
+
version=aliased_model.version,
|
|
379
440
|
)
|
|
380
441
|
|
|
381
442
|
def deploy_agent(self, config: AppConfig) -> None:
|
|
382
|
-
logger.
|
|
443
|
+
logger.info("Deploying agent", endpoint_name=config.app.endpoint_name)
|
|
383
444
|
mlflow.set_registry_uri("databricks-uc")
|
|
384
445
|
|
|
385
446
|
endpoint_name: str = config.app.endpoint_name
|
|
@@ -400,12 +461,10 @@ class DatabricksProvider(ServiceProvider):
|
|
|
400
461
|
agents.get_deployments(endpoint_name)
|
|
401
462
|
endpoint_exists = True
|
|
402
463
|
logger.debug(
|
|
403
|
-
|
|
464
|
+
"Endpoint already exists, updating", endpoint_name=endpoint_name
|
|
404
465
|
)
|
|
405
466
|
except Exception:
|
|
406
|
-
logger.debug(
|
|
407
|
-
f"Endpoint {endpoint_name} doesn't exist, creating new with tags..."
|
|
408
|
-
)
|
|
467
|
+
logger.debug("Creating new endpoint", endpoint_name=endpoint_name)
|
|
409
468
|
|
|
410
469
|
# Deploy - skip tags for existing endpoints to avoid conflicts
|
|
411
470
|
agents.deploy(
|
|
@@ -421,8 +480,11 @@ class DatabricksProvider(ServiceProvider):
|
|
|
421
480
|
registered_model_name: str = config.app.registered_model.full_name
|
|
422
481
|
permissions: Sequence[dict[str, Any]] = config.app.permissions
|
|
423
482
|
|
|
424
|
-
logger.debug(
|
|
425
|
-
|
|
483
|
+
logger.debug(
|
|
484
|
+
"Configuring model permissions",
|
|
485
|
+
model_name=registered_model_name,
|
|
486
|
+
permissions_count=len(permissions),
|
|
487
|
+
)
|
|
426
488
|
|
|
427
489
|
for permission in permissions:
|
|
428
490
|
principals: Sequence[str] = permission.principals
|
|
@@ -442,7 +504,7 @@ class DatabricksProvider(ServiceProvider):
|
|
|
442
504
|
try:
|
|
443
505
|
catalog_info = self.w.catalogs.get(name=schema.catalog_name)
|
|
444
506
|
except NotFound:
|
|
445
|
-
logger.
|
|
507
|
+
logger.info("Creating catalog", catalog_name=schema.catalog_name)
|
|
446
508
|
catalog_info = self.w.catalogs.create(name=schema.catalog_name)
|
|
447
509
|
return catalog_info
|
|
448
510
|
|
|
@@ -452,7 +514,7 @@ class DatabricksProvider(ServiceProvider):
|
|
|
452
514
|
try:
|
|
453
515
|
schema_info = self.w.schemas.get(full_name=schema.full_name)
|
|
454
516
|
except NotFound:
|
|
455
|
-
logger.
|
|
517
|
+
logger.info("Creating schema", schema_name=schema.full_name)
|
|
456
518
|
schema_info = self.w.schemas.create(
|
|
457
519
|
name=schema.schema_name, catalog_name=catalog_info.name
|
|
458
520
|
)
|
|
@@ -464,7 +526,7 @@ class DatabricksProvider(ServiceProvider):
|
|
|
464
526
|
try:
|
|
465
527
|
volume_info = self.w.volumes.read(name=volume.full_name)
|
|
466
528
|
except NotFound:
|
|
467
|
-
logger.
|
|
529
|
+
logger.info("Creating volume", volume_name=volume.full_name)
|
|
468
530
|
volume_info = self.w.volumes.create(
|
|
469
531
|
catalog_name=schema_info.catalog_name,
|
|
470
532
|
schema_name=schema_info.name,
|
|
@@ -475,7 +537,7 @@ class DatabricksProvider(ServiceProvider):
|
|
|
475
537
|
|
|
476
538
|
def create_path(self, volume_path: VolumePathModel) -> Path:
|
|
477
539
|
path: Path = volume_path.full_name
|
|
478
|
-
logger.info(
|
|
540
|
+
logger.info("Creating volume path", path=str(path))
|
|
479
541
|
self.w.files.create_directory(path)
|
|
480
542
|
return path
|
|
481
543
|
|
|
@@ -516,11 +578,12 @@ class DatabricksProvider(ServiceProvider):
|
|
|
516
578
|
|
|
517
579
|
if ddl:
|
|
518
580
|
ddl_path: Path = Path(ddl)
|
|
519
|
-
logger.debug(
|
|
581
|
+
logger.debug("Executing DDL", ddl_path=str(ddl_path))
|
|
520
582
|
statements: Sequence[str] = sqlparse.parse(ddl_path.read_text())
|
|
521
583
|
for statement in statements:
|
|
522
|
-
logger.
|
|
523
|
-
|
|
584
|
+
logger.trace(
|
|
585
|
+
"Executing DDL statement", statement=str(statement)[:100], args=args
|
|
586
|
+
)
|
|
524
587
|
spark.sql(
|
|
525
588
|
str(statement),
|
|
526
589
|
args=args,
|
|
@@ -529,20 +592,23 @@ class DatabricksProvider(ServiceProvider):
|
|
|
529
592
|
if data:
|
|
530
593
|
data_path: Path = Path(data)
|
|
531
594
|
if format == "sql":
|
|
532
|
-
logger.debug(
|
|
595
|
+
logger.debug("Executing SQL from file", data_path=str(data_path))
|
|
533
596
|
data_statements: Sequence[str] = sqlparse.parse(data_path.read_text())
|
|
534
597
|
for statement in data_statements:
|
|
535
|
-
logger.
|
|
536
|
-
|
|
598
|
+
logger.trace(
|
|
599
|
+
"Executing SQL statement",
|
|
600
|
+
statement=str(statement)[:100],
|
|
601
|
+
args=args,
|
|
602
|
+
)
|
|
537
603
|
spark.sql(
|
|
538
604
|
str(statement),
|
|
539
605
|
args=args,
|
|
540
606
|
)
|
|
541
607
|
else:
|
|
542
|
-
logger.debug(
|
|
608
|
+
logger.debug("Writing dataset to table", table=table)
|
|
543
609
|
if not data_path.is_absolute():
|
|
544
610
|
data_path = current_dir / data_path
|
|
545
|
-
logger.
|
|
611
|
+
logger.trace("Data path resolved", path=data_path.as_posix())
|
|
546
612
|
if format == "excel":
|
|
547
613
|
pdf = pd.read_excel(data_path.as_posix())
|
|
548
614
|
df = spark.createDataFrame(pdf, schema=dataset.table_schema)
|
|
@@ -566,13 +632,17 @@ class DatabricksProvider(ServiceProvider):
|
|
|
566
632
|
verbose=True,
|
|
567
633
|
)
|
|
568
634
|
|
|
569
|
-
logger.
|
|
635
|
+
logger.success(
|
|
636
|
+
"Vector search endpoint ready", endpoint_name=vector_store.endpoint.name
|
|
637
|
+
)
|
|
570
638
|
|
|
571
639
|
if not index_exists(
|
|
572
640
|
self.vsc, vector_store.endpoint.name, vector_store.index.full_name
|
|
573
641
|
):
|
|
574
|
-
logger.
|
|
575
|
-
|
|
642
|
+
logger.info(
|
|
643
|
+
"Creating vector search index",
|
|
644
|
+
index_name=vector_store.index.full_name,
|
|
645
|
+
endpoint_name=vector_store.endpoint.name,
|
|
576
646
|
)
|
|
577
647
|
self.vsc.create_delta_sync_index_and_wait(
|
|
578
648
|
endpoint_name=vector_store.endpoint.name,
|
|
@@ -586,7 +656,8 @@ class DatabricksProvider(ServiceProvider):
|
|
|
586
656
|
)
|
|
587
657
|
else:
|
|
588
658
|
logger.debug(
|
|
589
|
-
|
|
659
|
+
"Vector search index already exists, checking status",
|
|
660
|
+
index_name=vector_store.index.full_name,
|
|
590
661
|
)
|
|
591
662
|
index = self.vsc.get_index(
|
|
592
663
|
vector_store.endpoint.name, vector_store.index.full_name
|
|
@@ -609,54 +680,61 @@ class DatabricksProvider(ServiceProvider):
|
|
|
609
680
|
|
|
610
681
|
if pipeline_status in [
|
|
611
682
|
"COMPLETED",
|
|
683
|
+
"ONLINE",
|
|
612
684
|
"FAILED",
|
|
613
685
|
"CANCELED",
|
|
614
686
|
"ONLINE_PIPELINE_FAILED",
|
|
615
687
|
]:
|
|
616
|
-
logger.debug(
|
|
617
|
-
f"Index is ready to sync (status: {pipeline_status})"
|
|
618
|
-
)
|
|
688
|
+
logger.debug("Index ready to sync", status=pipeline_status)
|
|
619
689
|
break
|
|
620
690
|
elif pipeline_status in [
|
|
621
691
|
"WAITING_FOR_RESOURCES",
|
|
622
692
|
"PROVISIONING",
|
|
623
693
|
"INITIALIZING",
|
|
624
694
|
"INDEXING",
|
|
625
|
-
"ONLINE",
|
|
626
695
|
]:
|
|
627
|
-
logger.
|
|
628
|
-
|
|
696
|
+
logger.trace(
|
|
697
|
+
"Index not ready, waiting",
|
|
698
|
+
status=pipeline_status,
|
|
699
|
+
wait_seconds=wait_interval,
|
|
629
700
|
)
|
|
630
701
|
time.sleep(wait_interval)
|
|
631
702
|
elapsed += wait_interval
|
|
632
703
|
else:
|
|
633
704
|
logger.warning(
|
|
634
|
-
|
|
705
|
+
"Unknown pipeline status, attempting sync",
|
|
706
|
+
status=pipeline_status,
|
|
635
707
|
)
|
|
636
708
|
break
|
|
637
709
|
except Exception as status_error:
|
|
638
710
|
logger.warning(
|
|
639
|
-
|
|
711
|
+
"Could not check index status, attempting sync",
|
|
712
|
+
error=str(status_error),
|
|
640
713
|
)
|
|
641
714
|
break
|
|
642
715
|
|
|
643
716
|
if elapsed >= max_wait_time:
|
|
644
717
|
logger.warning(
|
|
645
|
-
|
|
718
|
+
"Timed out waiting for index to be ready",
|
|
719
|
+
max_wait_seconds=max_wait_time,
|
|
646
720
|
)
|
|
647
721
|
|
|
648
722
|
# Now attempt to sync
|
|
649
723
|
try:
|
|
650
724
|
index.sync()
|
|
651
|
-
logger.
|
|
725
|
+
logger.success("Index sync completed")
|
|
652
726
|
except Exception as sync_error:
|
|
653
727
|
if "not ready to sync yet" in str(sync_error).lower():
|
|
654
|
-
logger.warning(
|
|
728
|
+
logger.warning(
|
|
729
|
+
"Index still not ready to sync", error=str(sync_error)
|
|
730
|
+
)
|
|
655
731
|
else:
|
|
656
732
|
raise sync_error
|
|
657
733
|
|
|
658
|
-
logger.
|
|
659
|
-
|
|
734
|
+
logger.success(
|
|
735
|
+
"Vector search index ready",
|
|
736
|
+
index_name=vector_store.index.full_name,
|
|
737
|
+
source_table=vector_store.source_table.full_name,
|
|
660
738
|
)
|
|
661
739
|
|
|
662
740
|
def get_vector_index(self, vector_store: VectorStoreModel) -> None:
|
|
@@ -692,12 +770,16 @@ class DatabricksProvider(ServiceProvider):
|
|
|
692
770
|
# sql = sql.replace("{catalog_name}", schema.catalog_name)
|
|
693
771
|
# sql = sql.replace("{schema_name}", schema.schema_name)
|
|
694
772
|
|
|
695
|
-
logger.info(function.name)
|
|
696
|
-
logger.
|
|
773
|
+
logger.info("Creating SQL function", function_name=function.name)
|
|
774
|
+
logger.trace("SQL function body", sql=sql[:200])
|
|
697
775
|
_: FunctionInfo = self.dfs.create_function(sql_function_body=sql)
|
|
698
776
|
|
|
699
777
|
if unity_catalog_function.test:
|
|
700
|
-
logger.
|
|
778
|
+
logger.debug(
|
|
779
|
+
"Testing function",
|
|
780
|
+
function_name=function.full_name,
|
|
781
|
+
parameters=unity_catalog_function.test.parameters,
|
|
782
|
+
)
|
|
701
783
|
|
|
702
784
|
result: FunctionExecutionResult = self.dfs.execute_function(
|
|
703
785
|
function_name=function.full_name,
|
|
@@ -705,37 +787,50 @@ class DatabricksProvider(ServiceProvider):
|
|
|
705
787
|
)
|
|
706
788
|
|
|
707
789
|
if result.error:
|
|
708
|
-
logger.error(
|
|
790
|
+
logger.error(
|
|
791
|
+
"Function test failed",
|
|
792
|
+
function_name=function.full_name,
|
|
793
|
+
error=result.error,
|
|
794
|
+
)
|
|
709
795
|
else:
|
|
710
|
-
logger.
|
|
711
|
-
|
|
796
|
+
logger.success(
|
|
797
|
+
"Function test passed", function_name=function.full_name
|
|
798
|
+
)
|
|
799
|
+
logger.debug("Function test result", result=str(result))
|
|
712
800
|
|
|
713
801
|
def find_columns(self, table_model: TableModel) -> Sequence[str]:
|
|
714
|
-
logger.
|
|
802
|
+
logger.trace("Finding columns for table", table=table_model.full_name)
|
|
715
803
|
table_info: TableInfo = self.w.tables.get(full_name=table_model.full_name)
|
|
716
804
|
columns: Sequence[ColumnInfo] = table_info.columns
|
|
717
805
|
column_names: Sequence[str] = [c.name for c in columns]
|
|
718
|
-
logger.debug(
|
|
806
|
+
logger.debug(
|
|
807
|
+
"Columns found",
|
|
808
|
+
table=table_model.full_name,
|
|
809
|
+
columns_count=len(column_names),
|
|
810
|
+
)
|
|
719
811
|
return column_names
|
|
720
812
|
|
|
721
813
|
def find_primary_key(self, table_model: TableModel) -> Sequence[str] | None:
|
|
722
|
-
logger.
|
|
814
|
+
logger.trace("Finding primary key for table", table=table_model.full_name)
|
|
723
815
|
primary_keys: Sequence[str] | None = None
|
|
724
816
|
table_info: TableInfo = self.w.tables.get(full_name=table_model.full_name)
|
|
725
817
|
constraints: Sequence[TableConstraint] = table_info.table_constraints
|
|
726
818
|
primary_key_constraint: PrimaryKeyConstraint | None = next(
|
|
727
|
-
c.primary_key_constraint for c in constraints if c.primary_key_constraint
|
|
819
|
+
(c.primary_key_constraint for c in constraints if c.primary_key_constraint),
|
|
820
|
+
None,
|
|
728
821
|
)
|
|
729
822
|
if primary_key_constraint:
|
|
730
823
|
primary_keys = primary_key_constraint.child_columns
|
|
731
824
|
|
|
732
|
-
logger.debug(
|
|
825
|
+
logger.debug(
|
|
826
|
+
"Primary key found", table=table_model.full_name, primary_keys=primary_keys
|
|
827
|
+
)
|
|
733
828
|
return primary_keys
|
|
734
829
|
|
|
735
830
|
def find_vector_search_endpoint(
|
|
736
831
|
self, predicate: Callable[[dict[str, Any]], bool]
|
|
737
832
|
) -> str | None:
|
|
738
|
-
logger.
|
|
833
|
+
logger.trace("Finding vector search endpoint")
|
|
739
834
|
endpoint_name: str | None = None
|
|
740
835
|
vector_search_endpoints: Sequence[dict[str, Any]] = (
|
|
741
836
|
self.vsc.list_endpoints().get("endpoints", [])
|
|
@@ -744,11 +839,13 @@ class DatabricksProvider(ServiceProvider):
|
|
|
744
839
|
if predicate(endpoint):
|
|
745
840
|
endpoint_name = endpoint["name"]
|
|
746
841
|
break
|
|
747
|
-
logger.debug(
|
|
842
|
+
logger.debug("Vector search endpoint found", endpoint_name=endpoint_name)
|
|
748
843
|
return endpoint_name
|
|
749
844
|
|
|
750
845
|
def find_endpoint_for_index(self, index_model: IndexModel) -> str | None:
|
|
751
|
-
logger.
|
|
846
|
+
logger.trace(
|
|
847
|
+
"Finding endpoint for vector search index", index_name=index_model.full_name
|
|
848
|
+
)
|
|
752
849
|
all_endpoints: Sequence[dict[str, Any]] = self.vsc.list_endpoints().get(
|
|
753
850
|
"endpoints", []
|
|
754
851
|
)
|
|
@@ -758,14 +855,99 @@ class DatabricksProvider(ServiceProvider):
|
|
|
758
855
|
endpoint_name: str = endpoint["name"]
|
|
759
856
|
indexes = self.vsc.list_indexes(name=endpoint_name)
|
|
760
857
|
vector_indexes: Sequence[dict[str, Any]] = indexes.get("vector_indexes", [])
|
|
761
|
-
logger.trace(
|
|
858
|
+
logger.trace(
|
|
859
|
+
"Checking endpoint for indexes",
|
|
860
|
+
endpoint_name=endpoint_name,
|
|
861
|
+
indexes_count=len(vector_indexes),
|
|
862
|
+
)
|
|
762
863
|
index_names = [vector_index["name"] for vector_index in vector_indexes]
|
|
763
864
|
if index_name in index_names:
|
|
764
865
|
found_endpoint_name = endpoint_name
|
|
765
866
|
break
|
|
766
|
-
logger.debug(
|
|
867
|
+
logger.debug(
|
|
868
|
+
"Vector search index endpoint found",
|
|
869
|
+
index_name=index_model.full_name,
|
|
870
|
+
endpoint_name=found_endpoint_name,
|
|
871
|
+
)
|
|
767
872
|
return found_endpoint_name
|
|
768
873
|
|
|
874
|
+
def _wait_for_database_available(
|
|
875
|
+
self,
|
|
876
|
+
workspace_client: WorkspaceClient,
|
|
877
|
+
instance_name: str,
|
|
878
|
+
max_wait_time: int = 600,
|
|
879
|
+
wait_interval: int = 10,
|
|
880
|
+
) -> None:
|
|
881
|
+
"""
|
|
882
|
+
Wait for a database instance to become AVAILABLE.
|
|
883
|
+
|
|
884
|
+
Args:
|
|
885
|
+
workspace_client: The Databricks workspace client
|
|
886
|
+
instance_name: Name of the database instance to wait for
|
|
887
|
+
max_wait_time: Maximum time to wait in seconds (default: 600 = 10 minutes)
|
|
888
|
+
wait_interval: Time between status checks in seconds (default: 10)
|
|
889
|
+
|
|
890
|
+
Raises:
|
|
891
|
+
TimeoutError: If the database doesn't become AVAILABLE within max_wait_time
|
|
892
|
+
RuntimeError: If the database enters a failed or deleted state
|
|
893
|
+
"""
|
|
894
|
+
import time
|
|
895
|
+
from typing import Any
|
|
896
|
+
|
|
897
|
+
logger.info(
|
|
898
|
+
"Waiting for database instance to become AVAILABLE",
|
|
899
|
+
instance_name=instance_name,
|
|
900
|
+
)
|
|
901
|
+
elapsed: int = 0
|
|
902
|
+
|
|
903
|
+
while elapsed < max_wait_time:
|
|
904
|
+
try:
|
|
905
|
+
current_instance: Any = workspace_client.database.get_database_instance(
|
|
906
|
+
name=instance_name
|
|
907
|
+
)
|
|
908
|
+
current_state: str = current_instance.state
|
|
909
|
+
logger.trace(
|
|
910
|
+
"Database instance state checked",
|
|
911
|
+
instance_name=instance_name,
|
|
912
|
+
state=current_state,
|
|
913
|
+
)
|
|
914
|
+
|
|
915
|
+
if current_state == "AVAILABLE":
|
|
916
|
+
logger.success(
|
|
917
|
+
"Database instance is now AVAILABLE",
|
|
918
|
+
instance_name=instance_name,
|
|
919
|
+
)
|
|
920
|
+
return
|
|
921
|
+
elif current_state in ["STARTING", "UPDATING", "PROVISIONING"]:
|
|
922
|
+
logger.trace(
|
|
923
|
+
"Database instance not ready, waiting",
|
|
924
|
+
instance_name=instance_name,
|
|
925
|
+
state=current_state,
|
|
926
|
+
wait_seconds=wait_interval,
|
|
927
|
+
)
|
|
928
|
+
time.sleep(wait_interval)
|
|
929
|
+
elapsed += wait_interval
|
|
930
|
+
elif current_state in ["STOPPED", "DELETING", "FAILED"]:
|
|
931
|
+
raise RuntimeError(
|
|
932
|
+
f"Database instance {instance_name} entered unexpected state: {current_state}"
|
|
933
|
+
)
|
|
934
|
+
else:
|
|
935
|
+
logger.warning(
|
|
936
|
+
"Unknown database state, continuing to wait",
|
|
937
|
+
instance_name=instance_name,
|
|
938
|
+
state=current_state,
|
|
939
|
+
)
|
|
940
|
+
time.sleep(wait_interval)
|
|
941
|
+
elapsed += wait_interval
|
|
942
|
+
except NotFound:
|
|
943
|
+
raise RuntimeError(
|
|
944
|
+
f"Database instance {instance_name} was deleted while waiting for it to become AVAILABLE"
|
|
945
|
+
)
|
|
946
|
+
|
|
947
|
+
raise TimeoutError(
|
|
948
|
+
f"Timed out waiting for database instance {instance_name} to become AVAILABLE after {max_wait_time} seconds"
|
|
949
|
+
)
|
|
950
|
+
|
|
769
951
|
def create_lakebase(self, database: DatabaseModel) -> None:
|
|
770
952
|
"""
|
|
771
953
|
Create a Lakebase database instance using the Databricks workspace client.
|
|
@@ -796,13 +978,17 @@ class DatabricksProvider(ServiceProvider):
|
|
|
796
978
|
|
|
797
979
|
if existing_instance:
|
|
798
980
|
logger.debug(
|
|
799
|
-
|
|
981
|
+
"Database instance already exists",
|
|
982
|
+
instance_name=database.instance_name,
|
|
983
|
+
state=existing_instance.state,
|
|
800
984
|
)
|
|
801
985
|
|
|
802
986
|
# Check if database is in an intermediate state
|
|
803
987
|
if existing_instance.state in ["STARTING", "UPDATING"]:
|
|
804
988
|
logger.info(
|
|
805
|
-
|
|
989
|
+
"Database instance in intermediate state, waiting",
|
|
990
|
+
instance_name=database.instance_name,
|
|
991
|
+
state=existing_instance.state,
|
|
806
992
|
)
|
|
807
993
|
|
|
808
994
|
# Wait for database to reach a stable state
|
|
@@ -818,65 +1004,87 @@ class DatabricksProvider(ServiceProvider):
|
|
|
818
1004
|
)
|
|
819
1005
|
)
|
|
820
1006
|
current_state: str = current_instance.state
|
|
821
|
-
logger.
|
|
1007
|
+
logger.trace(
|
|
1008
|
+
"Checking database instance state",
|
|
1009
|
+
instance_name=database.instance_name,
|
|
1010
|
+
state=current_state,
|
|
1011
|
+
)
|
|
822
1012
|
|
|
823
1013
|
if current_state == "AVAILABLE":
|
|
824
|
-
logger.
|
|
825
|
-
|
|
1014
|
+
logger.success(
|
|
1015
|
+
"Database instance is now AVAILABLE",
|
|
1016
|
+
instance_name=database.instance_name,
|
|
826
1017
|
)
|
|
827
1018
|
break
|
|
828
1019
|
elif current_state in ["STARTING", "UPDATING"]:
|
|
829
|
-
logger.
|
|
830
|
-
|
|
1020
|
+
logger.trace(
|
|
1021
|
+
"Database instance not ready, waiting",
|
|
1022
|
+
instance_name=database.instance_name,
|
|
1023
|
+
state=current_state,
|
|
1024
|
+
wait_seconds=wait_interval,
|
|
831
1025
|
)
|
|
832
1026
|
time.sleep(wait_interval)
|
|
833
1027
|
elapsed += wait_interval
|
|
834
1028
|
elif current_state in ["STOPPED", "DELETING"]:
|
|
835
1029
|
logger.warning(
|
|
836
|
-
|
|
1030
|
+
"Database instance in unexpected state",
|
|
1031
|
+
instance_name=database.instance_name,
|
|
1032
|
+
state=current_state,
|
|
837
1033
|
)
|
|
838
1034
|
break
|
|
839
1035
|
else:
|
|
840
1036
|
logger.warning(
|
|
841
|
-
|
|
1037
|
+
"Unknown database state, proceeding",
|
|
1038
|
+
instance_name=database.instance_name,
|
|
1039
|
+
state=current_state,
|
|
842
1040
|
)
|
|
843
1041
|
break
|
|
844
1042
|
except NotFound:
|
|
845
1043
|
logger.warning(
|
|
846
|
-
|
|
1044
|
+
"Database instance no longer exists, will recreate",
|
|
1045
|
+
instance_name=database.instance_name,
|
|
847
1046
|
)
|
|
848
1047
|
break
|
|
849
1048
|
except Exception as state_error:
|
|
850
1049
|
logger.warning(
|
|
851
|
-
|
|
1050
|
+
"Could not check database state, proceeding",
|
|
1051
|
+
instance_name=database.instance_name,
|
|
1052
|
+
error=str(state_error),
|
|
852
1053
|
)
|
|
853
1054
|
break
|
|
854
1055
|
|
|
855
1056
|
if elapsed >= max_wait_time:
|
|
856
1057
|
logger.warning(
|
|
857
|
-
|
|
1058
|
+
"Timed out waiting for database to become AVAILABLE",
|
|
1059
|
+
instance_name=database.instance_name,
|
|
1060
|
+
max_wait_seconds=max_wait_time,
|
|
858
1061
|
)
|
|
859
1062
|
|
|
860
1063
|
elif existing_instance.state == "AVAILABLE":
|
|
861
1064
|
logger.info(
|
|
862
|
-
|
|
1065
|
+
"Database instance already exists and is AVAILABLE",
|
|
1066
|
+
instance_name=database.instance_name,
|
|
863
1067
|
)
|
|
864
1068
|
return
|
|
865
1069
|
elif existing_instance.state in ["STOPPED", "DELETING"]:
|
|
866
1070
|
logger.warning(
|
|
867
|
-
|
|
1071
|
+
"Database instance in terminal state",
|
|
1072
|
+
instance_name=database.instance_name,
|
|
1073
|
+
state=existing_instance.state,
|
|
868
1074
|
)
|
|
869
1075
|
return
|
|
870
1076
|
else:
|
|
871
1077
|
logger.info(
|
|
872
|
-
|
|
1078
|
+
"Database instance already exists",
|
|
1079
|
+
instance_name=database.instance_name,
|
|
1080
|
+
state=existing_instance.state,
|
|
873
1081
|
)
|
|
874
1082
|
return
|
|
875
1083
|
|
|
876
1084
|
except NotFound:
|
|
877
1085
|
# Database doesn't exist, proceed with creation
|
|
878
|
-
logger.
|
|
879
|
-
|
|
1086
|
+
logger.info(
|
|
1087
|
+
"Creating new database instance", instance_name=database.instance_name
|
|
880
1088
|
)
|
|
881
1089
|
|
|
882
1090
|
try:
|
|
@@ -896,10 +1104,17 @@ class DatabricksProvider(ServiceProvider):
|
|
|
896
1104
|
workspace_client.database.create_database_instance(
|
|
897
1105
|
database_instance=database_instance
|
|
898
1106
|
)
|
|
899
|
-
logger.
|
|
900
|
-
|
|
1107
|
+
logger.success(
|
|
1108
|
+
"Database instance created successfully",
|
|
1109
|
+
instance_name=database.instance_name,
|
|
901
1110
|
)
|
|
902
1111
|
|
|
1112
|
+
# Wait for the newly created database to become AVAILABLE
|
|
1113
|
+
self._wait_for_database_available(
|
|
1114
|
+
workspace_client, database.instance_name
|
|
1115
|
+
)
|
|
1116
|
+
return
|
|
1117
|
+
|
|
903
1118
|
except Exception as create_error:
|
|
904
1119
|
error_msg: str = str(create_error)
|
|
905
1120
|
|
|
@@ -909,13 +1124,20 @@ class DatabricksProvider(ServiceProvider):
|
|
|
909
1124
|
or "RESOURCE_ALREADY_EXISTS" in error_msg
|
|
910
1125
|
):
|
|
911
1126
|
logger.info(
|
|
912
|
-
|
|
1127
|
+
"Database instance was created concurrently",
|
|
1128
|
+
instance_name=database.instance_name,
|
|
1129
|
+
)
|
|
1130
|
+
# Still need to wait for the database to become AVAILABLE
|
|
1131
|
+
self._wait_for_database_available(
|
|
1132
|
+
workspace_client, database.instance_name
|
|
913
1133
|
)
|
|
914
1134
|
return
|
|
915
1135
|
else:
|
|
916
1136
|
# Re-raise unexpected errors
|
|
917
1137
|
logger.error(
|
|
918
|
-
|
|
1138
|
+
"Error creating database instance",
|
|
1139
|
+
instance_name=database.instance_name,
|
|
1140
|
+
error=str(create_error),
|
|
919
1141
|
)
|
|
920
1142
|
raise
|
|
921
1143
|
|
|
@@ -929,12 +1151,15 @@ class DatabricksProvider(ServiceProvider):
|
|
|
929
1151
|
or "RESOURCE_ALREADY_EXISTS" in error_msg
|
|
930
1152
|
):
|
|
931
1153
|
logger.info(
|
|
932
|
-
|
|
1154
|
+
"Database instance already exists (detected via exception)",
|
|
1155
|
+
instance_name=database.instance_name,
|
|
933
1156
|
)
|
|
934
1157
|
return
|
|
935
1158
|
else:
|
|
936
1159
|
logger.error(
|
|
937
|
-
|
|
1160
|
+
"Unexpected error while handling database",
|
|
1161
|
+
instance_name=database.instance_name,
|
|
1162
|
+
error=str(e),
|
|
938
1163
|
)
|
|
939
1164
|
raise
|
|
940
1165
|
|
|
@@ -942,7 +1167,9 @@ class DatabricksProvider(ServiceProvider):
|
|
|
942
1167
|
"""
|
|
943
1168
|
Ask Databricks to mint a fresh DB credential for this instance.
|
|
944
1169
|
"""
|
|
945
|
-
logger.
|
|
1170
|
+
logger.trace(
|
|
1171
|
+
"Generating password for lakebase instance", instance_name=instance_name
|
|
1172
|
+
)
|
|
946
1173
|
w: WorkspaceClient = self.w
|
|
947
1174
|
cred: DatabaseCredential = w.database.generate_database_credential(
|
|
948
1175
|
request_id=str(uuid.uuid4()),
|
|
@@ -978,7 +1205,8 @@ class DatabricksProvider(ServiceProvider):
|
|
|
978
1205
|
# Validate that client_id is provided
|
|
979
1206
|
if not database.client_id:
|
|
980
1207
|
logger.warning(
|
|
981
|
-
|
|
1208
|
+
"client_id required to create instance role",
|
|
1209
|
+
instance_name=database.instance_name,
|
|
982
1210
|
)
|
|
983
1211
|
return
|
|
984
1212
|
|
|
@@ -988,7 +1216,10 @@ class DatabricksProvider(ServiceProvider):
|
|
|
988
1216
|
instance_name: str = database.instance_name
|
|
989
1217
|
|
|
990
1218
|
logger.debug(
|
|
991
|
-
|
|
1219
|
+
"Creating instance role",
|
|
1220
|
+
role_name=role_name,
|
|
1221
|
+
instance_name=instance_name,
|
|
1222
|
+
principal=client_id,
|
|
992
1223
|
)
|
|
993
1224
|
|
|
994
1225
|
try:
|
|
@@ -999,13 +1230,15 @@ class DatabricksProvider(ServiceProvider):
|
|
|
999
1230
|
name=role_name,
|
|
1000
1231
|
)
|
|
1001
1232
|
logger.info(
|
|
1002
|
-
|
|
1233
|
+
"Instance role already exists",
|
|
1234
|
+
role_name=role_name,
|
|
1235
|
+
instance_name=instance_name,
|
|
1003
1236
|
)
|
|
1004
1237
|
return
|
|
1005
1238
|
except NotFound:
|
|
1006
1239
|
# Role doesn't exist, proceed with creation
|
|
1007
1240
|
logger.debug(
|
|
1008
|
-
|
|
1241
|
+
"Instance role not found, creating new role", role_name=role_name
|
|
1009
1242
|
)
|
|
1010
1243
|
|
|
1011
1244
|
# Create the database instance role
|
|
@@ -1021,8 +1254,10 @@ class DatabricksProvider(ServiceProvider):
|
|
|
1021
1254
|
database_instance_role=role,
|
|
1022
1255
|
)
|
|
1023
1256
|
|
|
1024
|
-
logger.
|
|
1025
|
-
|
|
1257
|
+
logger.success(
|
|
1258
|
+
"Instance role created successfully",
|
|
1259
|
+
role_name=role_name,
|
|
1260
|
+
instance_name=instance_name,
|
|
1026
1261
|
)
|
|
1027
1262
|
|
|
1028
1263
|
except Exception as e:
|
|
@@ -1034,13 +1269,18 @@ class DatabricksProvider(ServiceProvider):
|
|
|
1034
1269
|
or "RESOURCE_ALREADY_EXISTS" in error_msg
|
|
1035
1270
|
):
|
|
1036
1271
|
logger.info(
|
|
1037
|
-
|
|
1272
|
+
"Instance role was created concurrently",
|
|
1273
|
+
role_name=role_name,
|
|
1274
|
+
instance_name=instance_name,
|
|
1038
1275
|
)
|
|
1039
1276
|
return
|
|
1040
1277
|
|
|
1041
1278
|
# Re-raise unexpected errors
|
|
1042
1279
|
logger.error(
|
|
1043
|
-
|
|
1280
|
+
"Error creating instance role",
|
|
1281
|
+
role_name=role_name,
|
|
1282
|
+
instance_name=instance_name,
|
|
1283
|
+
error=str(e),
|
|
1044
1284
|
)
|
|
1045
1285
|
raise
|
|
1046
1286
|
|
|
@@ -1050,9 +1290,17 @@ class DatabricksProvider(ServiceProvider):
|
|
|
1050
1290
|
|
|
1051
1291
|
If an explicit version or alias is specified in the prompt_model, uses that directly.
|
|
1052
1292
|
Otherwise, tries to load prompts in this order:
|
|
1053
|
-
1. champion alias
|
|
1054
|
-
2. latest alias
|
|
1055
|
-
3.
|
|
1293
|
+
1. champion alias
|
|
1294
|
+
2. latest alias
|
|
1295
|
+
3. default alias
|
|
1296
|
+
4. Register default_template if provided (only if register_to_registry=True)
|
|
1297
|
+
5. Use default_template directly (fallback)
|
|
1298
|
+
|
|
1299
|
+
The auto_register field controls whether the default_template is automatically
|
|
1300
|
+
synced to the prompt registry:
|
|
1301
|
+
- If True (default): Auto-registers/updates the default_template in the registry
|
|
1302
|
+
- If False: Never registers, but can still load existing prompts from registry
|
|
1303
|
+
or use default_template directly as a local-only prompt
|
|
1056
1304
|
|
|
1057
1305
|
Args:
|
|
1058
1306
|
prompt_model: The prompt model configuration
|
|
@@ -1063,542 +1311,266 @@ class DatabricksProvider(ServiceProvider):
|
|
|
1063
1311
|
Raises:
|
|
1064
1312
|
ValueError: If no prompt can be loaded from any source
|
|
1065
1313
|
"""
|
|
1314
|
+
|
|
1066
1315
|
prompt_name: str = prompt_model.full_name
|
|
1067
1316
|
|
|
1068
|
-
# If explicit version or alias is specified, use it directly
|
|
1317
|
+
# If explicit version or alias is specified, use it directly
|
|
1069
1318
|
if prompt_model.version or prompt_model.alias:
|
|
1070
1319
|
try:
|
|
1071
1320
|
prompt_version: PromptVersion = prompt_model.as_prompt()
|
|
1321
|
+
version_or_alias = (
|
|
1322
|
+
f"version {prompt_model.version}"
|
|
1323
|
+
if prompt_model.version
|
|
1324
|
+
else f"alias {prompt_model.alias}"
|
|
1325
|
+
)
|
|
1072
1326
|
logger.debug(
|
|
1073
|
-
|
|
1074
|
-
|
|
1327
|
+
"Loaded prompt with explicit version/alias",
|
|
1328
|
+
prompt_name=prompt_name,
|
|
1329
|
+
version_or_alias=version_or_alias,
|
|
1075
1330
|
)
|
|
1076
1331
|
return prompt_version
|
|
1077
1332
|
except Exception as e:
|
|
1333
|
+
version_or_alias = (
|
|
1334
|
+
f"version {prompt_model.version}"
|
|
1335
|
+
if prompt_model.version
|
|
1336
|
+
else f"alias {prompt_model.alias}"
|
|
1337
|
+
)
|
|
1078
1338
|
logger.warning(
|
|
1079
|
-
|
|
1080
|
-
|
|
1339
|
+
"Failed to load prompt with explicit version/alias",
|
|
1340
|
+
prompt_name=prompt_name,
|
|
1341
|
+
version_or_alias=version_or_alias,
|
|
1342
|
+
error=str(e),
|
|
1081
1343
|
)
|
|
1082
|
-
# Fall through to
|
|
1083
|
-
else:
|
|
1084
|
-
# No explicit version/alias specified - check if default_template needs syncing first
|
|
1085
|
-
logger.debug(
|
|
1086
|
-
f"No explicit version/alias specified for '{prompt_name}', "
|
|
1087
|
-
"checking if default_template needs syncing"
|
|
1088
|
-
)
|
|
1089
|
-
|
|
1090
|
-
# If we have a default_template, check if it differs from what's in the registry
|
|
1091
|
-
# This ensures we always sync config changes before returning any alias
|
|
1092
|
-
if prompt_model.default_template:
|
|
1093
|
-
try:
|
|
1094
|
-
default_uri: str = f"prompts:/{prompt_name}@default"
|
|
1095
|
-
default_version: PromptVersion = load_prompt(default_uri)
|
|
1096
|
-
|
|
1097
|
-
if (
|
|
1098
|
-
default_version.to_single_brace_format().strip()
|
|
1099
|
-
!= prompt_model.default_template.strip()
|
|
1100
|
-
):
|
|
1101
|
-
logger.info(
|
|
1102
|
-
f"Config default_template for '{prompt_name}' differs from registry, syncing..."
|
|
1103
|
-
)
|
|
1104
|
-
return self._sync_default_template_to_registry(
|
|
1105
|
-
prompt_name,
|
|
1106
|
-
prompt_model.default_template,
|
|
1107
|
-
prompt_model.description,
|
|
1108
|
-
)
|
|
1109
|
-
except Exception as e:
|
|
1110
|
-
logger.debug(f"Could not check default alias for sync: {e}")
|
|
1344
|
+
# Fall through to try other methods
|
|
1111
1345
|
|
|
1112
|
-
|
|
1113
|
-
|
|
1114
|
-
|
|
1115
|
-
|
|
1116
|
-
|
|
1117
|
-
|
|
1118
|
-
try:
|
|
1119
|
-
champion_uri: str = f"prompts:/{prompt_name}@champion"
|
|
1120
|
-
prompt_version: PromptVersion = load_prompt(champion_uri)
|
|
1121
|
-
logger.info(f"Loaded prompt '{prompt_name}' from champion alias")
|
|
1122
|
-
return prompt_version
|
|
1123
|
-
except Exception as e:
|
|
1124
|
-
logger.debug(f"Champion alias not found for '{prompt_name}': {e}")
|
|
1125
|
-
|
|
1126
|
-
# Try latest alias next
|
|
1127
|
-
try:
|
|
1128
|
-
latest_uri: str = f"prompts:/{prompt_name}@latest"
|
|
1129
|
-
prompt_version: PromptVersion = load_prompt(latest_uri)
|
|
1130
|
-
logger.info(f"Loaded prompt '{prompt_name}' from latest alias")
|
|
1131
|
-
return prompt_version
|
|
1132
|
-
except Exception as e:
|
|
1133
|
-
logger.debug(f"Latest alias not found for '{prompt_name}': {e}")
|
|
1346
|
+
# Try to load in priority order: champion → default (with sync check)
|
|
1347
|
+
logger.trace(
|
|
1348
|
+
"Trying prompt fallback order",
|
|
1349
|
+
prompt_name=prompt_name,
|
|
1350
|
+
order="champion → default",
|
|
1351
|
+
)
|
|
1134
1352
|
|
|
1135
|
-
|
|
1353
|
+
# First, sync default alias if template has changed (even if champion exists)
|
|
1354
|
+
# Only do this if auto_register is True
|
|
1355
|
+
if prompt_model.default_template and prompt_model.auto_register:
|
|
1136
1356
|
try:
|
|
1137
|
-
|
|
1138
|
-
|
|
1139
|
-
logger.info(f"Loaded prompt '{prompt_name}' from default alias")
|
|
1140
|
-
return prompt_version
|
|
1141
|
-
except Exception as e:
|
|
1142
|
-
logger.debug(f"Default alias not found for '{prompt_name}': {e}")
|
|
1357
|
+
# Try to load existing default
|
|
1358
|
+
existing_default = load_prompt(f"prompts:/{prompt_name}@default")
|
|
1143
1359
|
|
|
1144
|
-
|
|
1145
|
-
|
|
1146
|
-
|
|
1147
|
-
|
|
1148
|
-
|
|
1149
|
-
|
|
1150
|
-
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1155
|
-
|
|
1156
|
-
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
|
|
1161
|
-
|
|
1162
|
-
|
|
1163
|
-
|
|
1360
|
+
# Check if champion exists and if it matches default
|
|
1361
|
+
champion_matches_default = False
|
|
1362
|
+
try:
|
|
1363
|
+
existing_champion = load_prompt(f"prompts:/{prompt_name}@champion")
|
|
1364
|
+
champion_matches_default = (
|
|
1365
|
+
existing_champion.version == existing_default.version
|
|
1366
|
+
)
|
|
1367
|
+
status = (
|
|
1368
|
+
"tracking" if champion_matches_default else "pinned separately"
|
|
1369
|
+
)
|
|
1370
|
+
logger.trace(
|
|
1371
|
+
"Champion vs default version",
|
|
1372
|
+
prompt_name=prompt_name,
|
|
1373
|
+
champion_version=existing_champion.version,
|
|
1374
|
+
default_version=existing_default.version,
|
|
1375
|
+
status=status,
|
|
1376
|
+
)
|
|
1377
|
+
except Exception:
|
|
1378
|
+
# No champion exists
|
|
1379
|
+
logger.trace("No champion alias found", prompt_name=prompt_name)
|
|
1164
1380
|
|
|
1165
|
-
|
|
1166
|
-
# Check if default alias already has the same template
|
|
1167
|
-
try:
|
|
1168
|
-
logger.debug(f"Loading prompt '{prompt_name}' from registry...")
|
|
1169
|
-
existing: PromptVersion = mlflow.genai.load_prompt(
|
|
1170
|
-
f"prompts:/{prompt_name}@default"
|
|
1171
|
-
)
|
|
1381
|
+
# Check if default_template differs from existing default
|
|
1172
1382
|
if (
|
|
1173
|
-
|
|
1174
|
-
|
|
1383
|
+
existing_default.template.strip()
|
|
1384
|
+
!= prompt_model.default_template.strip()
|
|
1175
1385
|
):
|
|
1176
|
-
logger.
|
|
1386
|
+
logger.info(
|
|
1387
|
+
"Default template changed, registering new version",
|
|
1388
|
+
prompt_name=prompt_name,
|
|
1389
|
+
)
|
|
1177
1390
|
|
|
1178
|
-
#
|
|
1179
|
-
|
|
1180
|
-
try:
|
|
1181
|
-
latest_version: PromptVersion = mlflow.genai.load_prompt(
|
|
1182
|
-
f"prompts:/{prompt_name}@latest"
|
|
1183
|
-
)
|
|
1184
|
-
logger.debug(
|
|
1185
|
-
f"Latest alias already exists for '{prompt_name}' pointing to version {latest_version.version}"
|
|
1186
|
-
)
|
|
1187
|
-
except Exception:
|
|
1391
|
+
# Only update champion if it was pointing to the old default
|
|
1392
|
+
if champion_matches_default:
|
|
1188
1393
|
logger.info(
|
|
1189
|
-
|
|
1190
|
-
|
|
1191
|
-
|
|
1192
|
-
name=prompt_name,
|
|
1193
|
-
alias="latest",
|
|
1194
|
-
version=existing.version,
|
|
1195
|
-
)
|
|
1196
|
-
|
|
1197
|
-
# Ensure champion alias exists for first-time deployments
|
|
1198
|
-
try:
|
|
1199
|
-
champion_version: PromptVersion = mlflow.genai.load_prompt(
|
|
1200
|
-
f"prompts:/{prompt_name}@champion"
|
|
1394
|
+
"Champion was tracking default, will update to new version",
|
|
1395
|
+
prompt_name=prompt_name,
|
|
1396
|
+
old_version=existing_default.version,
|
|
1201
1397
|
)
|
|
1202
|
-
|
|
1203
|
-
|
|
1204
|
-
)
|
|
1205
|
-
except Exception:
|
|
1398
|
+
set_champion = True
|
|
1399
|
+
else:
|
|
1206
1400
|
logger.info(
|
|
1207
|
-
|
|
1208
|
-
|
|
1209
|
-
mlflow.genai.set_prompt_alias(
|
|
1210
|
-
name=prompt_name,
|
|
1211
|
-
alias="champion",
|
|
1212
|
-
version=existing.version,
|
|
1401
|
+
"Champion is pinned separately, preserving it",
|
|
1402
|
+
prompt_name=prompt_name,
|
|
1213
1403
|
)
|
|
1404
|
+
set_champion = False
|
|
1214
1405
|
|
|
1215
|
-
|
|
1216
|
-
|
|
1217
|
-
|
|
1218
|
-
|
|
1406
|
+
self._register_default_template(
|
|
1407
|
+
prompt_name,
|
|
1408
|
+
prompt_model.default_template,
|
|
1409
|
+
prompt_model.description,
|
|
1410
|
+
set_champion=set_champion,
|
|
1411
|
+
)
|
|
1412
|
+
except Exception as e:
|
|
1413
|
+
# No default exists yet, register it
|
|
1414
|
+
logger.trace(
|
|
1415
|
+
"No default alias found", prompt_name=prompt_name, error=str(e)
|
|
1219
1416
|
)
|
|
1220
|
-
|
|
1221
|
-
|
|
1222
|
-
|
|
1223
|
-
|
|
1224
|
-
|
|
1225
|
-
|
|
1226
|
-
|
|
1227
|
-
|
|
1228
|
-
|
|
1229
|
-
|
|
1230
|
-
|
|
1231
|
-
|
|
1232
|
-
|
|
1233
|
-
|
|
1234
|
-
|
|
1235
|
-
alias="default",
|
|
1236
|
-
version=prompt_version.version,
|
|
1237
|
-
)
|
|
1238
|
-
mlflow.genai.set_prompt_alias(
|
|
1239
|
-
name=prompt_name,
|
|
1240
|
-
alias="latest",
|
|
1241
|
-
version=prompt_version.version,
|
|
1242
|
-
)
|
|
1243
|
-
mlflow.genai.set_prompt_alias(
|
|
1244
|
-
name=prompt_name,
|
|
1245
|
-
alias="champion",
|
|
1246
|
-
version=prompt_version.version,
|
|
1417
|
+
logger.info(
|
|
1418
|
+
"Registering default template as default alias",
|
|
1419
|
+
prompt_name=prompt_name,
|
|
1420
|
+
)
|
|
1421
|
+
# First registration - set both default and champion
|
|
1422
|
+
self._register_default_template(
|
|
1423
|
+
prompt_name,
|
|
1424
|
+
prompt_model.default_template,
|
|
1425
|
+
prompt_model.description,
|
|
1426
|
+
set_champion=True,
|
|
1427
|
+
)
|
|
1428
|
+
elif prompt_model.default_template and not prompt_model.auto_register:
|
|
1429
|
+
logger.trace(
|
|
1430
|
+
"Prompt has auto_register=False, skipping registration",
|
|
1431
|
+
prompt_name=prompt_name,
|
|
1247
1432
|
)
|
|
1248
1433
|
|
|
1249
|
-
|
|
1250
|
-
|
|
1251
|
-
)
|
|
1434
|
+
# 1. Try champion alias (highest priority for execution)
|
|
1435
|
+
try:
|
|
1436
|
+
prompt_version = load_prompt(f"prompts:/{prompt_name}@champion")
|
|
1437
|
+
logger.info("Loaded prompt from champion alias", prompt_name=prompt_name)
|
|
1252
1438
|
return prompt_version
|
|
1253
|
-
|
|
1254
1439
|
except Exception as e:
|
|
1255
|
-
logger.
|
|
1256
|
-
|
|
1257
|
-
|
|
1258
|
-
) from e
|
|
1259
|
-
|
|
1260
|
-
def optimize_prompt(self, optimization: PromptOptimizationModel) -> PromptModel:
|
|
1261
|
-
"""
|
|
1262
|
-
Optimize a prompt using MLflow's prompt optimization (MLflow 3.5+).
|
|
1263
|
-
|
|
1264
|
-
This uses the MLflow GenAI optimize_prompts API with GepaPromptOptimizer as documented at:
|
|
1265
|
-
https://mlflow.org/docs/latest/genai/prompt-registry/optimize-prompts/
|
|
1440
|
+
logger.trace(
|
|
1441
|
+
"Champion alias not found", prompt_name=prompt_name, error=str(e)
|
|
1442
|
+
)
|
|
1266
1443
|
|
|
1267
|
-
|
|
1268
|
-
|
|
1444
|
+
# 2. Try default alias (already synced above)
|
|
1445
|
+
if prompt_model.default_template:
|
|
1446
|
+
try:
|
|
1447
|
+
prompt_version = load_prompt(f"prompts:/{prompt_name}@default")
|
|
1448
|
+
logger.info("Loaded prompt from default alias", prompt_name=prompt_name)
|
|
1449
|
+
return prompt_version
|
|
1450
|
+
except Exception as e:
|
|
1451
|
+
# Should not happen since we just registered it above, but handle anyway
|
|
1452
|
+
logger.trace(
|
|
1453
|
+
"Default alias not found", prompt_name=prompt_name, error=str(e)
|
|
1454
|
+
)
|
|
1269
1455
|
|
|
1270
|
-
|
|
1271
|
-
PromptModel: The optimized prompt with new URI
|
|
1272
|
-
"""
|
|
1273
|
-
from mlflow.genai.optimize import GepaPromptOptimizer, optimize_prompts
|
|
1274
|
-
from mlflow.genai.scorers import Correctness
|
|
1275
|
-
|
|
1276
|
-
from dao_ai.config import AgentModel, PromptModel
|
|
1277
|
-
|
|
1278
|
-
logger.info(f"Optimizing prompt: {optimization.name}")
|
|
1279
|
-
|
|
1280
|
-
# Get agent and prompt (prompt is guaranteed to be set by validator)
|
|
1281
|
-
agent_model: AgentModel = optimization.agent
|
|
1282
|
-
prompt: PromptModel = optimization.prompt # type: ignore[assignment]
|
|
1283
|
-
agent_model.prompt = prompt.uri
|
|
1284
|
-
|
|
1285
|
-
print(f"prompt={agent_model.prompt}")
|
|
1286
|
-
# Log the prompt URI scheme being used
|
|
1287
|
-
# Supports three schemes:
|
|
1288
|
-
# 1. Specific version: "prompts:/qa/1" (when version is specified)
|
|
1289
|
-
# 2. Alias: "prompts:/qa@champion" (when alias is specified)
|
|
1290
|
-
# 3. Latest: "prompts:/qa@latest" (default when neither version nor alias specified)
|
|
1291
|
-
prompt_uri: str = prompt.uri
|
|
1292
|
-
logger.info(f"Using prompt URI for optimization: {prompt_uri}")
|
|
1293
|
-
|
|
1294
|
-
# Load the specific prompt version by URI for comparison
|
|
1295
|
-
# Try to load the exact version specified, but if it doesn't exist,
|
|
1296
|
-
# use get_prompt to create it from default_template
|
|
1297
|
-
prompt_version: PromptVersion
|
|
1456
|
+
# 3. Try latest alias as final fallback
|
|
1298
1457
|
try:
|
|
1299
|
-
prompt_version = load_prompt(
|
|
1300
|
-
logger.info(
|
|
1458
|
+
prompt_version = load_prompt(f"prompts:/{prompt_name}@latest")
|
|
1459
|
+
logger.info("Loaded prompt from latest alias", prompt_name=prompt_name)
|
|
1460
|
+
return prompt_version
|
|
1301
1461
|
except Exception as e:
|
|
1462
|
+
logger.trace(
|
|
1463
|
+
"Latest alias not found", prompt_name=prompt_name, error=str(e)
|
|
1464
|
+
)
|
|
1465
|
+
|
|
1466
|
+
# 4. Final fallback: use default_template directly if available
|
|
1467
|
+
if prompt_model.default_template:
|
|
1302
1468
|
logger.warning(
|
|
1303
|
-
|
|
1304
|
-
|
|
1469
|
+
"Could not load prompt from registry, using default_template directly",
|
|
1470
|
+
prompt_name=prompt_name,
|
|
1305
1471
|
)
|
|
1306
|
-
|
|
1307
|
-
|
|
1308
|
-
|
|
1309
|
-
|
|
1472
|
+
return PromptVersion(
|
|
1473
|
+
name=prompt_name,
|
|
1474
|
+
version=1,
|
|
1475
|
+
template=prompt_model.default_template,
|
|
1476
|
+
tags={"dao_ai": dao_ai_version()},
|
|
1310
1477
|
)
|
|
1311
1478
|
|
|
1312
|
-
|
|
1313
|
-
|
|
1314
|
-
|
|
1315
|
-
|
|
1316
|
-
dataset = get_dataset(name=optimization.dataset)
|
|
1317
|
-
else:
|
|
1318
|
-
dataset = optimization.dataset.as_dataset()
|
|
1319
|
-
|
|
1320
|
-
# Set up reflection model for the optimizer
|
|
1321
|
-
reflection_model_name: str
|
|
1322
|
-
if optimization.reflection_model:
|
|
1323
|
-
if isinstance(optimization.reflection_model, str):
|
|
1324
|
-
reflection_model_name = optimization.reflection_model
|
|
1325
|
-
else:
|
|
1326
|
-
reflection_model_name = optimization.reflection_model.uri
|
|
1327
|
-
else:
|
|
1328
|
-
reflection_model_name = agent_model.model.uri
|
|
1329
|
-
logger.debug(f"Using reflection model: {reflection_model_name}")
|
|
1330
|
-
|
|
1331
|
-
# Create the GepaPromptOptimizer
|
|
1332
|
-
optimizer: GepaPromptOptimizer = GepaPromptOptimizer(
|
|
1333
|
-
reflection_model=reflection_model_name,
|
|
1334
|
-
max_metric_calls=optimization.num_candidates,
|
|
1335
|
-
display_progress_bar=True,
|
|
1479
|
+
raise ValueError(
|
|
1480
|
+
f"Prompt '{prompt_name}' not found in registry "
|
|
1481
|
+
"(tried champion, default, latest aliases) "
|
|
1482
|
+
"and no default_template provided"
|
|
1336
1483
|
)
|
|
1337
1484
|
|
|
1338
|
-
|
|
1339
|
-
|
|
1340
|
-
|
|
1341
|
-
|
|
1342
|
-
|
|
1343
|
-
|
|
1344
|
-
|
|
1345
|
-
|
|
1346
|
-
scorer_model = agent_model.model.uri # Use Databricks default
|
|
1347
|
-
logger.debug(f"Using scorer with model: {scorer_model}")
|
|
1348
|
-
|
|
1349
|
-
scorers: list[Correctness] = [Correctness(model=scorer_model)]
|
|
1350
|
-
|
|
1351
|
-
# Use prompt_uri from line 1188 - already set to prompt.uri (configured URI)
|
|
1352
|
-
# DO NOT overwrite with prompt_version.uri as that uses fallback logic
|
|
1353
|
-
logger.debug(f"Optimizing prompt: {prompt_uri}")
|
|
1354
|
-
|
|
1355
|
-
agent: ResponsesAgent = agent_model.as_responses_agent()
|
|
1356
|
-
|
|
1357
|
-
# Create predict function that will be optimized
|
|
1358
|
-
def predict_fn(**inputs: dict[str, Any]) -> str:
|
|
1359
|
-
"""Prediction function that uses the ResponsesAgent with ChatPayload.
|
|
1360
|
-
|
|
1361
|
-
The agent already has the prompt referenced/applied, so we just need to
|
|
1362
|
-
convert the ChatPayload inputs to ResponsesAgentRequest format and call predict.
|
|
1363
|
-
|
|
1364
|
-
Args:
|
|
1365
|
-
**inputs: Dictionary containing ChatPayload fields (messages/input, custom_inputs)
|
|
1366
|
-
|
|
1367
|
-
Returns:
|
|
1368
|
-
str: The agent's response content
|
|
1369
|
-
"""
|
|
1370
|
-
from mlflow.types.responses import (
|
|
1371
|
-
ResponsesAgentRequest,
|
|
1372
|
-
ResponsesAgentResponse,
|
|
1373
|
-
)
|
|
1374
|
-
from mlflow.types.responses_helpers import Message
|
|
1375
|
-
|
|
1376
|
-
from dao_ai.config import ChatPayload
|
|
1377
|
-
|
|
1378
|
-
# Verify agent is accessible (should be captured from outer scope)
|
|
1379
|
-
if agent is None:
|
|
1380
|
-
raise RuntimeError(
|
|
1381
|
-
"Agent object is not available in predict_fn. "
|
|
1382
|
-
"This may indicate a serialization issue with the ResponsesAgent."
|
|
1383
|
-
)
|
|
1384
|
-
|
|
1385
|
-
# Convert inputs to ChatPayload
|
|
1386
|
-
chat_payload: ChatPayload = ChatPayload(**inputs)
|
|
1387
|
-
|
|
1388
|
-
# Convert ChatPayload messages to MLflow Message format
|
|
1389
|
-
mlflow_messages: list[Message] = [
|
|
1390
|
-
Message(role=msg.role, content=msg.content)
|
|
1391
|
-
for msg in chat_payload.messages
|
|
1392
|
-
]
|
|
1393
|
-
|
|
1394
|
-
# Create ResponsesAgentRequest
|
|
1395
|
-
request: ResponsesAgentRequest = ResponsesAgentRequest(
|
|
1396
|
-
input=mlflow_messages,
|
|
1397
|
-
custom_inputs=chat_payload.custom_inputs,
|
|
1398
|
-
)
|
|
1485
|
+
def _register_default_template(
|
|
1486
|
+
self,
|
|
1487
|
+
prompt_name: str,
|
|
1488
|
+
default_template: str,
|
|
1489
|
+
description: str | None = None,
|
|
1490
|
+
set_champion: bool = True,
|
|
1491
|
+
) -> PromptVersion:
|
|
1492
|
+
"""Register default_template as a new prompt version.
|
|
1399
1493
|
|
|
1400
|
-
|
|
1401
|
-
|
|
1402
|
-
|
|
1403
|
-
if response.output and len(response.output) > 0:
|
|
1404
|
-
content = response.output[0].content
|
|
1405
|
-
logger.debug(f"Response content type: {type(content)}")
|
|
1406
|
-
logger.debug(f"Response content: {content}")
|
|
1407
|
-
|
|
1408
|
-
# Extract text from content using same logic as LanggraphResponsesAgent._extract_text_from_content
|
|
1409
|
-
# Content can be:
|
|
1410
|
-
# - A string (return as is)
|
|
1411
|
-
# - A list of items with 'text' keys (extract and join)
|
|
1412
|
-
# - Other types (try to get 'text' attribute or convert to string)
|
|
1413
|
-
if isinstance(content, str):
|
|
1414
|
-
return content
|
|
1415
|
-
elif isinstance(content, list):
|
|
1416
|
-
text_parts = []
|
|
1417
|
-
for content_item in content:
|
|
1418
|
-
if isinstance(content_item, str):
|
|
1419
|
-
text_parts.append(content_item)
|
|
1420
|
-
elif isinstance(content_item, dict) and "text" in content_item:
|
|
1421
|
-
text_parts.append(content_item["text"])
|
|
1422
|
-
elif hasattr(content_item, "text"):
|
|
1423
|
-
text_parts.append(content_item.text)
|
|
1424
|
-
return "".join(text_parts) if text_parts else str(content)
|
|
1425
|
-
else:
|
|
1426
|
-
# Fallback for unknown types - try to extract text attribute
|
|
1427
|
-
return getattr(content, "text", str(content))
|
|
1428
|
-
else:
|
|
1429
|
-
return ""
|
|
1494
|
+
Registers the template and sets the 'default' alias.
|
|
1495
|
+
Optionally sets 'champion' alias if no champion exists.
|
|
1430
1496
|
|
|
1431
|
-
|
|
1432
|
-
|
|
1497
|
+
Args:
|
|
1498
|
+
prompt_name: Full name of the prompt
|
|
1499
|
+
default_template: The template content
|
|
1500
|
+
description: Optional description for commit message
|
|
1501
|
+
set_champion: Whether to also set champion alias (default: True)
|
|
1433
1502
|
|
|
1434
|
-
|
|
1435
|
-
|
|
1503
|
+
If registration fails (e.g., in Model Serving with restricted permissions),
|
|
1504
|
+
logs the error and raises.
|
|
1505
|
+
"""
|
|
1436
1506
|
logger.info(
|
|
1437
|
-
|
|
1507
|
+
"Registering default template",
|
|
1508
|
+
prompt_name=prompt_name,
|
|
1509
|
+
set_champion=set_champion,
|
|
1438
1510
|
)
|
|
1439
1511
|
|
|
1440
|
-
|
|
1441
|
-
|
|
1442
|
-
|
|
1443
|
-
|
|
1444
|
-
|
|
1445
|
-
|
|
1446
|
-
|
|
1447
|
-
prompt_uris=[prompt_uri], # Use the configured URI (version/alias/latest)
|
|
1448
|
-
optimizer=optimizer,
|
|
1449
|
-
scorers=scorers,
|
|
1450
|
-
enable_tracking=False, # Don't auto-register all candidates
|
|
1451
|
-
)
|
|
1452
|
-
|
|
1453
|
-
# Log the optimization results
|
|
1454
|
-
logger.info("Optimization complete!")
|
|
1455
|
-
logger.info(f"Optimizer used: {result.optimizer_name}")
|
|
1456
|
-
|
|
1457
|
-
if result.optimized_prompts:
|
|
1458
|
-
optimized_prompt_version: PromptVersion = result.optimized_prompts[0]
|
|
1459
|
-
|
|
1460
|
-
# Check if the optimized prompt is actually different from the original
|
|
1461
|
-
original_template: str = prompt_version.to_single_brace_format().strip()
|
|
1462
|
-
optimized_template: str = (
|
|
1463
|
-
optimized_prompt_version.to_single_brace_format().strip()
|
|
1464
|
-
)
|
|
1465
|
-
|
|
1466
|
-
# Normalize whitespace for more robust comparison
|
|
1467
|
-
original_normalized: str = re.sub(r"\s+", " ", original_template).strip()
|
|
1468
|
-
optimized_normalized: str = re.sub(r"\s+", " ", optimized_template).strip()
|
|
1469
|
-
|
|
1470
|
-
logger.debug(f"Original template length: {len(original_template)} chars")
|
|
1471
|
-
logger.debug(f"Optimized template length: {len(optimized_template)} chars")
|
|
1472
|
-
logger.debug(
|
|
1473
|
-
f"Templates identical: {original_normalized == optimized_normalized}"
|
|
1474
|
-
)
|
|
1475
|
-
|
|
1476
|
-
if original_normalized == optimized_normalized:
|
|
1477
|
-
logger.info(
|
|
1478
|
-
f"Optimized prompt is identical to original for '{prompt.full_name}'. "
|
|
1479
|
-
"No new version will be registered."
|
|
1480
|
-
)
|
|
1481
|
-
return prompt
|
|
1482
|
-
|
|
1483
|
-
logger.info("Optimized prompt is DIFFERENT from original")
|
|
1484
|
-
logger.info(
|
|
1485
|
-
f"Original length: {len(original_template)}, Optimized length: {len(optimized_template)}"
|
|
1486
|
-
)
|
|
1487
|
-
logger.debug(
|
|
1488
|
-
f"Original template (first 300 chars): {original_template[:300]}..."
|
|
1489
|
-
)
|
|
1490
|
-
logger.debug(
|
|
1491
|
-
f"Optimized template (first 300 chars): {optimized_template[:300]}..."
|
|
1512
|
+
try:
|
|
1513
|
+
commit_message = description or "Auto-synced from default_template"
|
|
1514
|
+
prompt_version = mlflow.genai.register_prompt(
|
|
1515
|
+
name=prompt_name,
|
|
1516
|
+
template=default_template,
|
|
1517
|
+
commit_message=commit_message,
|
|
1518
|
+
tags={"dao_ai": dao_ai_version()},
|
|
1492
1519
|
)
|
|
1493
1520
|
|
|
1494
|
-
#
|
|
1495
|
-
should_register: bool = False
|
|
1496
|
-
has_improvement: bool = False
|
|
1497
|
-
|
|
1498
|
-
if (
|
|
1499
|
-
result.initial_eval_score is not None
|
|
1500
|
-
and result.final_eval_score is not None
|
|
1501
|
-
):
|
|
1502
|
-
logger.info("Evaluation scores:")
|
|
1503
|
-
logger.info(f" Initial score: {result.initial_eval_score}")
|
|
1504
|
-
logger.info(f" Final score: {result.final_eval_score}")
|
|
1505
|
-
|
|
1506
|
-
# Only register if there's improvement
|
|
1507
|
-
if result.final_eval_score > result.initial_eval_score:
|
|
1508
|
-
improvement: float = (
|
|
1509
|
-
(result.final_eval_score - result.initial_eval_score)
|
|
1510
|
-
/ result.initial_eval_score
|
|
1511
|
-
) * 100
|
|
1512
|
-
logger.info(
|
|
1513
|
-
f"Optimized prompt improved by {improvement:.2f}% "
|
|
1514
|
-
f"(initial: {result.initial_eval_score}, final: {result.final_eval_score})"
|
|
1515
|
-
)
|
|
1516
|
-
should_register = True
|
|
1517
|
-
has_improvement = True
|
|
1518
|
-
else:
|
|
1519
|
-
logger.info(
|
|
1520
|
-
f"Optimized prompt (score: {result.final_eval_score}) did NOT improve over baseline (score: {result.initial_eval_score}). "
|
|
1521
|
-
"No new version will be registered."
|
|
1522
|
-
)
|
|
1523
|
-
else:
|
|
1524
|
-
# No scores available - register anyway but warn
|
|
1525
|
-
logger.warning(
|
|
1526
|
-
"No evaluation scores available to compare performance. "
|
|
1527
|
-
"Registering optimized prompt without performance validation."
|
|
1528
|
-
)
|
|
1529
|
-
should_register = True
|
|
1530
|
-
|
|
1531
|
-
if not should_register:
|
|
1532
|
-
logger.info(
|
|
1533
|
-
f"Skipping registration for '{prompt.full_name}' (no improvement)"
|
|
1534
|
-
)
|
|
1535
|
-
return prompt
|
|
1536
|
-
|
|
1537
|
-
# Register the optimized prompt manually
|
|
1521
|
+
# Always set default alias
|
|
1538
1522
|
try:
|
|
1539
|
-
logger.
|
|
1540
|
-
|
|
1541
|
-
|
|
1542
|
-
|
|
1543
|
-
commit_message=f"Optimized for {agent_model.model.uri} using GepaPromptOptimizer",
|
|
1544
|
-
tags={
|
|
1545
|
-
"dao_ai": dao_ai_version(),
|
|
1546
|
-
"target_model": agent_model.model.uri,
|
|
1547
|
-
},
|
|
1548
|
-
)
|
|
1549
|
-
logger.info(
|
|
1550
|
-
f"Registered optimized prompt as version {registered_version.version}"
|
|
1551
|
-
)
|
|
1552
|
-
|
|
1553
|
-
# Always set "latest" alias (represents most recently registered prompt)
|
|
1554
|
-
logger.info(
|
|
1555
|
-
f"Setting 'latest' alias for optimized prompt '{prompt.full_name}' version {registered_version.version}"
|
|
1523
|
+
logger.debug(
|
|
1524
|
+
"Setting default alias",
|
|
1525
|
+
prompt_name=prompt_name,
|
|
1526
|
+
version=prompt_version.version,
|
|
1556
1527
|
)
|
|
1557
1528
|
mlflow.genai.set_prompt_alias(
|
|
1558
|
-
name=
|
|
1559
|
-
alias="latest",
|
|
1560
|
-
version=registered_version.version,
|
|
1529
|
+
name=prompt_name, alias="default", version=prompt_version.version
|
|
1561
1530
|
)
|
|
1562
|
-
logger.
|
|
1563
|
-
|
|
1531
|
+
logger.success(
|
|
1532
|
+
"Set default alias for prompt",
|
|
1533
|
+
prompt_name=prompt_name,
|
|
1534
|
+
version=prompt_version.version,
|
|
1535
|
+
)
|
|
1536
|
+
except Exception as alias_error:
|
|
1537
|
+
logger.warning(
|
|
1538
|
+
"Could not set default alias",
|
|
1539
|
+
prompt_name=prompt_name,
|
|
1540
|
+
error=str(alias_error),
|
|
1564
1541
|
)
|
|
1565
1542
|
|
|
1566
|
-
|
|
1567
|
-
|
|
1568
|
-
|
|
1569
|
-
logger.info(
|
|
1570
|
-
f"Setting 'champion' alias for improved prompt '{prompt.full_name}' version {registered_version.version}"
|
|
1571
|
-
)
|
|
1543
|
+
# Optionally set champion alias (only if no champion exists or explicitly requested)
|
|
1544
|
+
if set_champion:
|
|
1545
|
+
try:
|
|
1572
1546
|
mlflow.genai.set_prompt_alias(
|
|
1573
|
-
name=
|
|
1547
|
+
name=prompt_name,
|
|
1574
1548
|
alias="champion",
|
|
1575
|
-
version=
|
|
1549
|
+
version=prompt_version.version,
|
|
1576
1550
|
)
|
|
1577
|
-
logger.
|
|
1578
|
-
|
|
1551
|
+
logger.success(
|
|
1552
|
+
"Set champion alias for prompt",
|
|
1553
|
+
prompt_name=prompt_name,
|
|
1554
|
+
version=prompt_version.version,
|
|
1555
|
+
)
|
|
1556
|
+
except Exception as alias_error:
|
|
1557
|
+
logger.warning(
|
|
1558
|
+
"Could not set champion alias",
|
|
1559
|
+
prompt_name=prompt_name,
|
|
1560
|
+
error=str(alias_error),
|
|
1579
1561
|
)
|
|
1580
1562
|
|
|
1581
|
-
|
|
1582
|
-
tags: dict[str, Any] = prompt.tags.copy() if prompt.tags else {}
|
|
1583
|
-
tags["target_model"] = agent_model.model.uri
|
|
1584
|
-
tags["dao_ai"] = dao_ai_version()
|
|
1585
|
-
|
|
1586
|
-
# Return the optimized prompt with the appropriate alias
|
|
1587
|
-
# Use "champion" if there was improvement, otherwise "latest"
|
|
1588
|
-
result_alias: str = "champion" if has_improvement else "latest"
|
|
1589
|
-
return PromptModel(
|
|
1590
|
-
name=prompt.name,
|
|
1591
|
-
schema=prompt.schema_model,
|
|
1592
|
-
description=f"Optimized version of {prompt.name} for {agent_model.model.uri}",
|
|
1593
|
-
alias=result_alias,
|
|
1594
|
-
tags=tags,
|
|
1595
|
-
)
|
|
1563
|
+
return prompt_version
|
|
1596
1564
|
|
|
1597
|
-
|
|
1598
|
-
|
|
1599
|
-
|
|
1600
|
-
|
|
1601
|
-
|
|
1602
|
-
|
|
1603
|
-
|
|
1604
|
-
|
|
1565
|
+
except Exception as reg_error:
|
|
1566
|
+
logger.error(
|
|
1567
|
+
"Failed to register prompt - please register from notebook with write permissions",
|
|
1568
|
+
prompt_name=prompt_name,
|
|
1569
|
+
error=str(reg_error),
|
|
1570
|
+
)
|
|
1571
|
+
return PromptVersion(
|
|
1572
|
+
name=prompt_name,
|
|
1573
|
+
version=1,
|
|
1574
|
+
template=default_template,
|
|
1575
|
+
tags={"dao_ai": dao_ai_version()},
|
|
1576
|
+
)
|