dao-ai 0.0.25__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +5 -5
- dao_ai/cli.py +245 -40
- dao_ai/config.py +1863 -338
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -228
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +232 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +261 -166
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +278 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +645 -172
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -295
- dao_ai/tools/mcp.py +220 -133
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +360 -40
- dao_ai/utils.py +218 -16
- dao_ai-0.1.2.dist-info/METADATA +455 -0
- dao_ai-0.1.2.dist-info/RECORD +64 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +1 -1
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.25.dist-info/METADATA +0 -1165
- dao_ai-0.0.25.dist-info/RECORD +0 -41
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/providers/databricks.py
CHANGED
@@ -1,6 +1,5 @@
 import base64
 import uuid
-from importlib.metadata import version
 from pathlib import Path
 from typing import Any, Callable, Final, Sequence
 
@@ -32,6 +31,7 @@ from mlflow import MlflowClient
 from mlflow.entities import Experiment
 from mlflow.entities.model_registry import PromptVersion
 from mlflow.entities.model_registry.model_version import ModelVersion
+from mlflow.genai.prompts import load_prompt
 from mlflow.models.auth_policy import AuthPolicy, SystemAuthPolicy, UserAuthPolicy
 from mlflow.models.model import ModelInfo
 from mlflow.models.resources import (
@@ -46,6 +46,7 @@ from dao_ai.config import (
     AppConfig,
     ConnectionModel,
     DatabaseModel,
+    DatabricksAppModel,
     DatasetModel,
     FunctionModel,
     GenieRoomModel,
@@ -65,9 +66,11 @@ from dao_ai.config import (
 from dao_ai.models import get_latest_model_version
 from dao_ai.providers.base import ServiceProvider
 from dao_ai.utils import (
+    dao_ai_version,
     get_installed_packages,
     is_installed,
     is_lib_provided,
+    normalize_host,
     normalize_name,
 )
 from dao_ai.vector_search import endpoint_exists, index_exists
@@ -89,15 +92,18 @@ def _workspace_client(
     Create a WorkspaceClient instance with the provided parameters.
     If no parameters are provided, it will use the default configuration.
     """
-
+    # Normalize the workspace host to ensure it has https:// scheme
+    normalized_host = normalize_host(workspace_host)
+
+    if client_id and client_secret and normalized_host:
         return WorkspaceClient(
-            host=
+            host=normalized_host,
             client_id=client_id,
             client_secret=client_secret,
             auth_type="oauth-m2m",
         )
     elif pat:
-        return WorkspaceClient(host=
+        return WorkspaceClient(host=normalized_host, token=pat, auth_type="pat")
     else:
         return WorkspaceClient()
 
@@ -112,15 +118,18 @@ def _vector_search_client(
     Create a VectorSearchClient instance with the provided parameters.
     If no parameters are provided, it will use the default configuration.
     """
-
+    # Normalize the workspace host to ensure it has https:// scheme
+    normalized_host = normalize_host(workspace_host)
+
+    if client_id and client_secret and normalized_host:
         return VectorSearchClient(
-            workspace_url=
+            workspace_url=normalized_host,
             service_principal_client_id=client_id,
             service_principal_client_secret=client_secret,
         )
-    elif pat and
+    elif pat and normalized_host:
         return VectorSearchClient(
-            workspace_url=
+            workspace_url=normalized_host,
             personal_access_token=pat,
         )
     else:
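Both client factories now run the configured host through the new `normalize_host` helper from `dao_ai.utils`. Its body is not shown in this diff; going by the inline comment ("ensure it has https:// scheme"), it plausibly behaves like this hypothetical sketch — the real implementation may differ:

```python
# Hypothetical sketch of normalize_host, inferred from the comment in the hunks above;
# the actual dao_ai.utils implementation is not part of this diff.
def normalize_host(host: str | None) -> str | None:
    """Return the host with an https:// scheme, or None if no host was given."""
    if not host:
        return None
    host = host.strip().rstrip("/")
    if not host.startswith(("http://", "https://")):
        host = f"https://{host}"
    return host
```

A helper like this lets users configure `my-workspace.cloud.databricks.com` without a scheme and still get a valid `WorkspaceClient`/`VectorSearchClient` URL.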
@@ -172,15 +181,17 @@ class DatabricksProvider(ServiceProvider):
         experiment: Experiment | None = mlflow.get_experiment_by_name(experiment_name)
         if experiment is None:
             experiment_id: str = mlflow.create_experiment(name=experiment_name)
-            logger.
-
+            logger.success(
+                "Created new MLflow experiment",
+                experiment_name=experiment_name,
+                experiment_id=experiment_id,
             )
             experiment = mlflow.get_experiment(experiment_id)
         return experiment
 
     def create_token(self) -> str:
         current_user: User = self.w.current_user.me()
-        logger.debug(
+        logger.debug("Authenticated to Databricks", user=str(current_user))
         headers: dict[str, str] = self.w.config.authenticate()
         token: str = headers["Authorization"].replace("Bearer ", "")
         return token
@@ -192,17 +203,24 @@ class DatabricksProvider(ServiceProvider):
             secret_response: GetSecretResponse = self.w.secrets.get_secret(
                 secret_scope, secret_key
             )
-            logger.
+            logger.trace(
+                "Retrieved secret", secret_key=secret_key, secret_scope=secret_scope
+            )
             encoded_secret: str = secret_response.value
             decoded_secret: str = base64.b64decode(encoded_secret).decode("utf-8")
             return decoded_secret
         except NotFound:
             logger.warning(
-
+                "Secret not found, using default value",
+                secret_key=secret_key,
+                secret_scope=secret_scope,
             )
         except Exception as e:
             logger.error(
-
+                "Error retrieving secret",
+                secret_key=secret_key,
+                secret_scope=secret_scope,
+                error=str(e),
             )
 
         return default_value
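Throughout this file, 0.1.2 replaces f-string log messages with an event string plus structured key/value fields, and adds `logger.success`/`logger.trace` levels. The new `dao_ai/logging.py` module (+56 lines in this release) is not shown here, so the exact wrapper is an assumption; the style matches a loguru-based logger, roughly like this sketch:

```python
# A minimal sketch of this logging style using loguru — an assumption, since the
# actual dao_ai/logging.py wrapper is not included in this diff.
import sys

from loguru import logger

# loguru ships trace/debug/info/success/warning/error levels; the default sink
# starts at DEBUG, so TRACE output needs an explicit sink level.
logger.remove()
logger.add(sys.stderr, level="TRACE")

# bind() attaches structured key/value context to every record this logger emits.
log = logger.bind(component="providers.databricks")
log.success("Created new MLflow experiment")
log.trace("Retrieved secret")
```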
@@ -211,9 +229,18 @@ class DatabricksProvider(ServiceProvider):
         self,
         config: AppConfig,
     ) -> ModelInfo:
-        logger.
+        logger.info("Creating agent")
         mlflow.set_registry_uri("databricks-uc")
 
+        # Set up experiment for proper tracking
+        experiment: Experiment = self.get_or_create_experiment(config)
+        mlflow.set_experiment(experiment_id=experiment.experiment_id)
+        logger.debug(
+            "Using MLflow experiment",
+            experiment_name=experiment.name,
+            experiment_id=experiment.experiment_id,
+        )
+
         llms: Sequence[LLMModel] = list(config.resources.llms.values())
         vector_indexes: Sequence[IndexModel] = list(
             config.resources.vector_stores.values()
@@ -231,6 +258,7 @@ class DatabricksProvider(ServiceProvider):
         )
         databases: Sequence[DatabaseModel] = list(config.resources.databases.values())
         volumes: Sequence[VolumeModel] = list(config.resources.volumes.values())
+        apps: Sequence[DatabricksAppModel] = list(config.resources.apps.values())
 
         resources: Sequence[IsDatabricksResource] = (
             llms
@@ -242,6 +270,7 @@ class DatabricksProvider(ServiceProvider):
             + connections
             + databases
             + volumes
+            + apps
         )
 
         # Flatten all resources from all models into a single list
@@ -255,12 +284,16 @@ class DatabricksProvider(ServiceProvider):
             for resource in r.as_resources()
             if not r.on_behalf_of_user
         ]
-        logger.
+        logger.trace(
+            "System resources identified",
+            count=len(system_resources),
+            resources=[r.name for r in system_resources],
+        )
 
         system_auth_policy: SystemAuthPolicy = SystemAuthPolicy(
             resources=system_resources
         )
-        logger.
+        logger.trace("System auth policy created", policy=str(system_auth_policy))
 
         api_scopes: Sequence[str] = list(
             set(
@@ -272,15 +305,19 @@ class DatabricksProvider(ServiceProvider):
                 ]
             )
         )
-        logger.
+        logger.trace("API scopes identified", scopes=api_scopes)
 
         user_auth_policy: UserAuthPolicy = UserAuthPolicy(api_scopes=api_scopes)
-        logger.
+        logger.trace("User auth policy created", policy=str(user_auth_policy))
 
         auth_policy: AuthPolicy = AuthPolicy(
             system_auth_policy=system_auth_policy, user_auth_policy=user_auth_policy
         )
-        logger.debug(
+        logger.debug(
+            "Auth policy created",
+            has_system_auth=system_auth_policy is not None,
+            has_user_auth=user_auth_policy is not None,
+        )
 
         code_paths: list[str] = config.app.code_paths
         for path in code_paths:
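The hunks above split credentials along MLflow's agent auth-policy model: resources not marked `on_behalf_of_user` are embedded in a `SystemAuthPolicy` (accessed with the deployment's own credential), while on-behalf-of-user access is expressed as API scopes in a `UserAuthPolicy`. A condensed sketch of that assembly, using the `mlflow.models` classes this file already imports (the endpoint name and scope value are illustrative only):

```python
from mlflow.models.auth_policy import AuthPolicy, SystemAuthPolicy, UserAuthPolicy
from mlflow.models.resources import DatabricksServingEndpoint

# Hypothetical resource and scope values for illustration.
system_auth = SystemAuthPolicy(
    resources=[DatabricksServingEndpoint(endpoint_name="my-llm-endpoint")]
)
user_auth = UserAuthPolicy(api_scopes=["serving.serving-endpoints"])

auth_policy = AuthPolicy(system_auth_policy=system_auth, user_auth_policy=user_auth)
```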
@@ -296,7 +333,7 @@ class DatabricksProvider(ServiceProvider):
         if is_installed():
             if not is_lib_provided("dao-ai", pip_requirements):
                 pip_requirements += [
-                    f"dao-ai=={
+                    f"dao-ai=={dao_ai_version()}",
                 ]
         else:
             src_path: Path = model_root_path.parent
@@ -307,27 +344,52 @@ class DatabricksProvider(ServiceProvider):
 
         pip_requirements += get_installed_packages()
 
-        logger.
-        logger.
+        logger.trace("Pip requirements prepared", count=len(pip_requirements))
+        logger.trace("Code paths prepared", count=len(code_paths))
 
         run_name: str = normalize_name(config.app.name)
-        logger.debug(
-
+        logger.debug(
+            "Agent run configuration",
+            run_name=run_name,
+            model_path=model_path.as_posix(),
+        )
 
         input_example: dict[str, Any] = None
         if config.app.input_example:
             input_example = config.app.input_example.model_dump()
 
-        logger.
+        logger.trace("Input example configured", has_example=input_example is not None)
+
+        # Create conda environment with configured Python version
+        # This allows deploying from environments with different Python versions
+        # (e.g., Databricks Apps with Python 3.11 can deploy to Model Serving with 3.12)
+        target_python_version: str = config.app.python_version
+        logger.debug("Target Python version configured", version=target_python_version)
+
+        conda_env: dict[str, Any] = {
+            "name": "mlflow-env",
+            "channels": ["conda-forge"],
+            "dependencies": [
+                f"python={target_python_version}",
+                "pip",
+                {"pip": list(pip_requirements)},
+            ],
+        }
+        logger.trace(
+            "Conda environment configured",
+            python_version=target_python_version,
+            pip_packages_count=len(pip_requirements),
+        )
 
         with mlflow.start_run(run_name=run_name):
             mlflow.set_tag("type", "agent")
+            mlflow.set_tag("dao_ai", dao_ai_version())
             logged_agent_info: ModelInfo = mlflow.pyfunc.log_model(
                 python_model=model_path.as_posix(),
                 code_paths=code_paths,
-                model_config=config.model_dump(by_alias=True),
+                model_config=config.model_dump(mode="json", by_alias=True),
                 name="agent",
-
+                conda_env=conda_env,
                 input_example=input_example,
                 # resources=all_resources,
                 auth_policy=auth_policy,
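The `conda_env` dict follows MLflow's standard conda environment format; pinning `python=<version>` decouples the serving environment's interpreter from the caller's, which is the point of the new `config.app.python_version` field. A minimal sketch of logging a model this way, mirroring the hunk above (the model path and requirements list are illustrative):

```python
# A minimal sketch of mlflow.pyfunc.log_model with an explicit conda environment,
# assuming MLflow 3.x models-from-code; paths and package pins are hypothetical.
import mlflow

conda_env = {
    "name": "mlflow-env",
    "channels": ["conda-forge"],
    "dependencies": [
        "python=3.12",               # serving-side interpreter, not the caller's
        "pip",
        {"pip": ["dao-ai==0.1.2"]},  # pip packages nest under the "pip" entry
    ],
}

with mlflow.start_run():
    mlflow.pyfunc.log_model(
        python_model="models/agent.py",  # hypothetical path to the model-as-code file
        name="agent",
        conda_env=conda_env,
    )
```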
@@ -338,15 +400,26 @@ class DatabricksProvider(ServiceProvider):
         model_version: ModelVersion = mlflow.register_model(
             name=registered_model_name, model_uri=logged_agent_info.model_uri
         )
-        logger.
-
+        logger.success(
+            "Model registered",
+            model_name=registered_model_name,
+            version=model_version.version,
         )
 
         client: MlflowClient = MlflowClient()
 
+        # Set tags on the model version
+        client.set_model_version_tag(
+            name=registered_model_name,
+            version=model_version.version,
+            key="dao_ai",
+            value=dao_ai_version(),
+        )
+        logger.trace("Set dao_ai tag on model version", version=model_version.version)
+
         client.set_registered_model_alias(
             name=registered_model_name,
-            alias="
+            alias="Champion",
             version=model_version.version,
         )
 
@@ -359,12 +432,15 @@ class DatabricksProvider(ServiceProvider):
         aliased_model: ModelVersion = client.get_model_version_by_alias(
             registered_model_name, config.app.alias
         )
-        logger.
-
+        logger.info(
+            "Model aliased",
+            model_name=registered_model_name,
+            alias=config.app.alias,
+            version=aliased_model.version,
         )
 
     def deploy_agent(self, config: AppConfig) -> None:
-        logger.
+        logger.info("Deploying agent", endpoint_name=config.app.endpoint_name)
         mlflow.set_registry_uri("databricks-uc")
 
         endpoint_name: str = config.app.endpoint_name
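Every registered version is now pinned to the "Champion" alias and tagged with the dao-ai version, so downstream code can resolve the serving candidate by alias instead of a hard-coded version number. A short sketch of resolving that alias (the model name here is hypothetical):

```python
# Resolving the alias set above via the registry client; model name is illustrative.
from mlflow import MlflowClient

client = MlflowClient(registry_uri="databricks-uc")
mv = client.get_model_version_by_alias("main.default.my_agent", "Champion")
print(mv.version, mv.tags.get("dao_ai"))
```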
@@ -372,7 +448,10 @@ class DatabricksProvider(ServiceProvider):
         scale_to_zero: bool = config.app.scale_to_zero
         environment_vars: dict[str, str] = config.app.environment_vars
         workload_size: str = config.app.workload_size
-        tags: dict[str, str] = config.app.tags
+        tags: dict[str, str] = config.app.tags.copy() if config.app.tags else {}
+
+        # Add dao_ai framework tag
+        tags["dao_ai"] = dao_ai_version()
 
         latest_version: int = get_latest_model_version(registered_model_name)
 
@@ -382,12 +461,10 @@ class DatabricksProvider(ServiceProvider):
             agents.get_deployments(endpoint_name)
             endpoint_exists = True
             logger.debug(
-
+                "Endpoint already exists, updating", endpoint_name=endpoint_name
             )
         except Exception:
-            logger.debug(
-                f"Endpoint {endpoint_name} doesn't exist, creating new with tags..."
-            )
+            logger.debug("Creating new endpoint", endpoint_name=endpoint_name)
 
         # Deploy - skip tags for existing endpoints to avoid conflicts
         agents.deploy(
@@ -403,8 +480,11 @@ class DatabricksProvider(ServiceProvider):
         registered_model_name: str = config.app.registered_model.full_name
         permissions: Sequence[dict[str, Any]] = config.app.permissions
 
-        logger.debug(
-
+        logger.debug(
+            "Configuring model permissions",
+            model_name=registered_model_name,
+            permissions_count=len(permissions),
+        )
 
         for permission in permissions:
             principals: Sequence[str] = permission.principals
@@ -424,7 +504,7 @@ class DatabricksProvider(ServiceProvider):
         try:
             catalog_info = self.w.catalogs.get(name=schema.catalog_name)
         except NotFound:
-            logger.
+            logger.info("Creating catalog", catalog_name=schema.catalog_name)
             catalog_info = self.w.catalogs.create(name=schema.catalog_name)
         return catalog_info
 
@@ -434,7 +514,7 @@ class DatabricksProvider(ServiceProvider):
         try:
             schema_info = self.w.schemas.get(full_name=schema.full_name)
         except NotFound:
-            logger.
+            logger.info("Creating schema", schema_name=schema.full_name)
             schema_info = self.w.schemas.create(
                 name=schema.schema_name, catalog_name=catalog_info.name
             )
@@ -446,7 +526,7 @@ class DatabricksProvider(ServiceProvider):
         try:
             volume_info = self.w.volumes.read(name=volume.full_name)
         except NotFound:
-            logger.
+            logger.info("Creating volume", volume_name=volume.full_name)
             volume_info = self.w.volumes.create(
                 catalog_name=schema_info.catalog_name,
                 schema_name=schema_info.name,
@@ -457,7 +537,7 @@ class DatabricksProvider(ServiceProvider):
 
     def create_path(self, volume_path: VolumePathModel) -> Path:
         path: Path = volume_path.full_name
-        logger.info(
+        logger.info("Creating volume path", path=str(path))
         self.w.files.create_directory(path)
         return path
 
@@ -498,11 +578,12 @@ class DatabricksProvider(ServiceProvider):
 
         if ddl:
             ddl_path: Path = Path(ddl)
-            logger.debug(
+            logger.debug("Executing DDL", ddl_path=str(ddl_path))
             statements: Sequence[str] = sqlparse.parse(ddl_path.read_text())
             for statement in statements:
-                logger.
-
+                logger.trace(
+                    "Executing DDL statement", statement=str(statement)[:100], args=args
+                )
                 spark.sql(
                     str(statement),
                     args=args,
@@ -511,20 +592,23 @@ class DatabricksProvider(ServiceProvider):
         if data:
             data_path: Path = Path(data)
             if format == "sql":
-                logger.debug(
+                logger.debug("Executing SQL from file", data_path=str(data_path))
                 data_statements: Sequence[str] = sqlparse.parse(data_path.read_text())
                 for statement in data_statements:
-                    logger.
-
+                    logger.trace(
+                        "Executing SQL statement",
+                        statement=str(statement)[:100],
+                        args=args,
+                    )
                     spark.sql(
                         str(statement),
                         args=args,
                     )
             else:
-                logger.debug(
+                logger.debug("Writing dataset to table", table=table)
                 if not data_path.is_absolute():
                     data_path = current_dir / data_path
-                logger.
+                logger.trace("Data path resolved", path=data_path.as_posix())
                 if format == "excel":
                     pdf = pd.read_excel(data_path.as_posix())
                     df = spark.createDataFrame(pdf, schema=dataset.table_schema)
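The DDL/data paths above rely on `sqlparse.parse` splitting a script into individual `Statement` objects, each re-rendered with `str()` before being handed to `spark.sql` with named `args`. A minimal sketch of that pattern, assuming a live SparkSession (file path and parameter names are illustrative):

```python
# A reduced sketch of the statement-by-statement execution above; assumes Spark 3.4+
# (named-parameter args) and a hypothetical DDL file.
from pathlib import Path

import sqlparse
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
ddl_path = Path("schemas/init.sql")  # hypothetical DDL file

# sqlparse splits the file into Statement objects; str() re-renders each as SQL text.
for statement in sqlparse.parse(ddl_path.read_text()):
    spark.sql(str(statement), args={"catalog_name": "main"})
```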
@@ -548,13 +632,17 @@ class DatabricksProvider(ServiceProvider):
             verbose=True,
         )
 
-        logger.
+        logger.success(
+            "Vector search endpoint ready", endpoint_name=vector_store.endpoint.name
+        )
 
         if not index_exists(
             self.vsc, vector_store.endpoint.name, vector_store.index.full_name
         ):
-            logger.
-
+            logger.info(
+                "Creating vector search index",
+                index_name=vector_store.index.full_name,
+                endpoint_name=vector_store.endpoint.name,
             )
             self.vsc.create_delta_sync_index_and_wait(
                 endpoint_name=vector_store.endpoint.name,
@@ -568,7 +656,8 @@ class DatabricksProvider(ServiceProvider):
             )
         else:
             logger.debug(
-
+                "Vector search index already exists, checking status",
+                index_name=vector_store.index.full_name,
             )
             index = self.vsc.get_index(
                 vector_store.endpoint.name, vector_store.index.full_name
@@ -591,54 +680,61 @@ class DatabricksProvider(ServiceProvider):
 
                 if pipeline_status in [
                     "COMPLETED",
+                    "ONLINE",
                     "FAILED",
                     "CANCELED",
                     "ONLINE_PIPELINE_FAILED",
                 ]:
-                    logger.debug(
-                        f"Index is ready to sync (status: {pipeline_status})"
-                    )
+                    logger.debug("Index ready to sync", status=pipeline_status)
                     break
                 elif pipeline_status in [
                     "WAITING_FOR_RESOURCES",
                     "PROVISIONING",
                     "INITIALIZING",
                     "INDEXING",
-                    "ONLINE",
                 ]:
-                    logger.
-
+                    logger.trace(
+                        "Index not ready, waiting",
+                        status=pipeline_status,
+                        wait_seconds=wait_interval,
                     )
                     time.sleep(wait_interval)
                     elapsed += wait_interval
                 else:
                     logger.warning(
-
+                        "Unknown pipeline status, attempting sync",
+                        status=pipeline_status,
                     )
                     break
             except Exception as status_error:
                 logger.warning(
-
+                    "Could not check index status, attempting sync",
+                    error=str(status_error),
                 )
                 break
 
         if elapsed >= max_wait_time:
             logger.warning(
-
+                "Timed out waiting for index to be ready",
+                max_wait_seconds=max_wait_time,
             )
 
         # Now attempt to sync
         try:
             index.sync()
-            logger.
+            logger.success("Index sync completed")
         except Exception as sync_error:
             if "not ready to sync yet" in str(sync_error).lower():
-                logger.warning(
+                logger.warning(
+                    "Index still not ready to sync", error=str(sync_error)
+                )
             else:
                 raise sync_error
 
-        logger.
-
+        logger.success(
+            "Vector search index ready",
+            index_name=vector_store.index.full_name,
+            source_table=vector_store.source_table.full_name,
        )
 
     def get_vector_index(self, vector_store: VectorStoreModel) -> None:
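Note the behavioral fix in the status lists: 0.1.2 moves "ONLINE" from the waiting set to the ready-to-sync set, so an index that is already online syncs immediately instead of spinning until the timeout. A reduced, hedged sketch of the readiness poll (names and the status sets mirror the hunk above; `get_status` is a stand-in for the real pipeline-status lookup):

```python
# A reduced sketch of the readiness poll above. With 0.1.2, ONLINE counts as ready
# to sync; in 0.0.25 it sat in the waiting set. get_status is hypothetical.
import time

READY = {"COMPLETED", "ONLINE", "FAILED", "CANCELED", "ONLINE_PIPELINE_FAILED"}
WAITING = {"WAITING_FOR_RESOURCES", "PROVISIONING", "INITIALIZING", "INDEXING"}

def wait_until_syncable(get_status, max_wait: int = 600, interval: int = 10):
    elapsed = 0
    while elapsed < max_wait:
        status = get_status()
        if status in READY:
            return status
        if status not in WAITING:
            return status  # unknown state: attempt the sync anyway
        time.sleep(interval)
        elapsed += interval
    return None  # timed out
```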
@@ -674,12 +770,16 @@ class DatabricksProvider(ServiceProvider):
         # sql = sql.replace("{catalog_name}", schema.catalog_name)
         # sql = sql.replace("{schema_name}", schema.schema_name)
 
-        logger.info(function.name)
-        logger.
+        logger.info("Creating SQL function", function_name=function.name)
+        logger.trace("SQL function body", sql=sql[:200])
         _: FunctionInfo = self.dfs.create_function(sql_function_body=sql)
 
         if unity_catalog_function.test:
-            logger.
+            logger.debug(
+                "Testing function",
+                function_name=function.full_name,
+                parameters=unity_catalog_function.test.parameters,
+            )
 
             result: FunctionExecutionResult = self.dfs.execute_function(
                 function_name=function.full_name,
@@ -687,37 +787,50 @@ class DatabricksProvider(ServiceProvider):
             )
 
             if result.error:
-                logger.error(
+                logger.error(
+                    "Function test failed",
+                    function_name=function.full_name,
+                    error=result.error,
+                )
             else:
-                logger.
-
+                logger.success(
+                    "Function test passed", function_name=function.full_name
+                )
+                logger.debug("Function test result", result=str(result))
 
     def find_columns(self, table_model: TableModel) -> Sequence[str]:
-        logger.
+        logger.trace("Finding columns for table", table=table_model.full_name)
         table_info: TableInfo = self.w.tables.get(full_name=table_model.full_name)
         columns: Sequence[ColumnInfo] = table_info.columns
         column_names: Sequence[str] = [c.name for c in columns]
-        logger.debug(
+        logger.debug(
+            "Columns found",
+            table=table_model.full_name,
+            columns_count=len(column_names),
+        )
         return column_names
 
     def find_primary_key(self, table_model: TableModel) -> Sequence[str] | None:
-        logger.
+        logger.trace("Finding primary key for table", table=table_model.full_name)
         primary_keys: Sequence[str] | None = None
         table_info: TableInfo = self.w.tables.get(full_name=table_model.full_name)
         constraints: Sequence[TableConstraint] = table_info.table_constraints
         primary_key_constraint: PrimaryKeyConstraint | None = next(
-            c.primary_key_constraint for c in constraints if c.primary_key_constraint
+            (c.primary_key_constraint for c in constraints if c.primary_key_constraint),
+            None,
         )
         if primary_key_constraint:
             primary_keys = primary_key_constraint.child_columns
 
-        logger.debug(
+        logger.debug(
+            "Primary key found", table=table_model.full_name, primary_keys=primary_keys
+        )
         return primary_keys
 
     def find_vector_search_endpoint(
         self, predicate: Callable[[dict[str, Any]], bool]
     ) -> str | None:
-        logger.
+        logger.trace("Finding vector search endpoint")
         endpoint_name: str | None = None
         vector_search_endpoints: Sequence[dict[str, Any]] = (
             self.vsc.list_endpoints().get("endpoints", [])
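The `find_primary_key` change is a genuine bug fix: wrapping the generator in parentheses and supplying `None` as the default means tables without a primary-key constraint no longer raise `StopIteration`. A tiny self-contained demonstration of the difference:

```python
# Why the next() change matters: without a default, next() raises StopIteration
# when no element matches (e.g., a table with no primary-key constraint).
constraints: list = []  # stand-in for a table with no constraints

# 0.0.25 behavior: next(c for c in constraints if c) -> raises StopIteration

# 0.1.2 behavior: returns None instead
result = next((c for c in constraints if c), None)
assert result is None
```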
@@ -726,11 +839,13 @@ class DatabricksProvider(ServiceProvider):
             if predicate(endpoint):
                 endpoint_name = endpoint["name"]
                 break
-        logger.debug(
+        logger.debug("Vector search endpoint found", endpoint_name=endpoint_name)
         return endpoint_name
 
     def find_endpoint_for_index(self, index_model: IndexModel) -> str | None:
-        logger.
+        logger.trace(
+            "Finding endpoint for vector search index", index_name=index_model.full_name
+        )
         all_endpoints: Sequence[dict[str, Any]] = self.vsc.list_endpoints().get(
             "endpoints", []
         )
@@ -740,14 +855,99 @@ class DatabricksProvider(ServiceProvider):
             endpoint_name: str = endpoint["name"]
             indexes = self.vsc.list_indexes(name=endpoint_name)
             vector_indexes: Sequence[dict[str, Any]] = indexes.get("vector_indexes", [])
-            logger.trace(
+            logger.trace(
+                "Checking endpoint for indexes",
+                endpoint_name=endpoint_name,
+                indexes_count=len(vector_indexes),
+            )
             index_names = [vector_index["name"] for vector_index in vector_indexes]
             if index_name in index_names:
                 found_endpoint_name = endpoint_name
                 break
-        logger.debug(
+        logger.debug(
+            "Vector search index endpoint found",
+            index_name=index_model.full_name,
+            endpoint_name=found_endpoint_name,
+        )
         return found_endpoint_name
 
+    def _wait_for_database_available(
+        self,
+        workspace_client: WorkspaceClient,
+        instance_name: str,
+        max_wait_time: int = 600,
+        wait_interval: int = 10,
+    ) -> None:
+        """
+        Wait for a database instance to become AVAILABLE.
+
+        Args:
+            workspace_client: The Databricks workspace client
+            instance_name: Name of the database instance to wait for
+            max_wait_time: Maximum time to wait in seconds (default: 600 = 10 minutes)
+            wait_interval: Time between status checks in seconds (default: 10)
+
+        Raises:
+            TimeoutError: If the database doesn't become AVAILABLE within max_wait_time
+            RuntimeError: If the database enters a failed or deleted state
+        """
+        import time
+        from typing import Any
+
+        logger.info(
+            "Waiting for database instance to become AVAILABLE",
+            instance_name=instance_name,
+        )
+        elapsed: int = 0
+
+        while elapsed < max_wait_time:
+            try:
+                current_instance: Any = workspace_client.database.get_database_instance(
+                    name=instance_name
+                )
+                current_state: str = current_instance.state
+                logger.trace(
+                    "Database instance state checked",
+                    instance_name=instance_name,
+                    state=current_state,
+                )
+
+                if current_state == "AVAILABLE":
+                    logger.success(
+                        "Database instance is now AVAILABLE",
+                        instance_name=instance_name,
+                    )
+                    return
+                elif current_state in ["STARTING", "UPDATING", "PROVISIONING"]:
+                    logger.trace(
+                        "Database instance not ready, waiting",
+                        instance_name=instance_name,
+                        state=current_state,
+                        wait_seconds=wait_interval,
+                    )
+                    time.sleep(wait_interval)
+                    elapsed += wait_interval
+                elif current_state in ["STOPPED", "DELETING", "FAILED"]:
+                    raise RuntimeError(
+                        f"Database instance {instance_name} entered unexpected state: {current_state}"
+                    )
+                else:
+                    logger.warning(
+                        "Unknown database state, continuing to wait",
+                        instance_name=instance_name,
+                        state=current_state,
+                    )
+                    time.sleep(wait_interval)
+                    elapsed += wait_interval
+            except NotFound:
+                raise RuntimeError(
+                    f"Database instance {instance_name} was deleted while waiting for it to become AVAILABLE"
+                )
+
+        raise TimeoutError(
+            f"Timed out waiting for database instance {instance_name} to become AVAILABLE after {max_wait_time} seconds"
+        )
+
     def create_lakebase(self, database: DatabaseModel) -> None:
         """
         Create a Lakebase database instance using the Databricks workspace client.
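The new `_wait_for_database_available` helper blocks until a Lakebase instance reports AVAILABLE, and the hunks further down call it after both a fresh creation and a detected concurrent creation. A hypothetical call site, assuming `DatabricksProvider` is default-constructible (its constructor is not shown in this diff):

```python
# Hypothetical usage of the new wait helper; constructor arguments and the
# instance name are assumptions, not taken from this diff.
from databricks.sdk import WorkspaceClient
from dao_ai.providers.databricks import DatabricksProvider

provider = DatabricksProvider()
provider._wait_for_database_available(
    WorkspaceClient(),
    "my-lakebase-instance",  # hypothetical instance name
    max_wait_time=600,
    wait_interval=10,
)
```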
@@ -778,13 +978,17 @@ class DatabricksProvider(ServiceProvider):
 
             if existing_instance:
                 logger.debug(
-
+                    "Database instance already exists",
+                    instance_name=database.instance_name,
+                    state=existing_instance.state,
                 )
 
                 # Check if database is in an intermediate state
                 if existing_instance.state in ["STARTING", "UPDATING"]:
                     logger.info(
-
+                        "Database instance in intermediate state, waiting",
+                        instance_name=database.instance_name,
+                        state=existing_instance.state,
                     )
 
                     # Wait for database to reach a stable state
@@ -800,65 +1004,87 @@ class DatabricksProvider(ServiceProvider):
                                 )
                             )
                             current_state: str = current_instance.state
-                            logger.
+                            logger.trace(
+                                "Checking database instance state",
+                                instance_name=database.instance_name,
+                                state=current_state,
+                            )
 
                             if current_state == "AVAILABLE":
-                                logger.
-
+                                logger.success(
+                                    "Database instance is now AVAILABLE",
+                                    instance_name=database.instance_name,
                                 )
                                 break
                             elif current_state in ["STARTING", "UPDATING"]:
-                                logger.
-
+                                logger.trace(
+                                    "Database instance not ready, waiting",
+                                    instance_name=database.instance_name,
+                                    state=current_state,
+                                    wait_seconds=wait_interval,
                                 )
                                 time.sleep(wait_interval)
                                 elapsed += wait_interval
                             elif current_state in ["STOPPED", "DELETING"]:
                                 logger.warning(
-
+                                    "Database instance in unexpected state",
+                                    instance_name=database.instance_name,
+                                    state=current_state,
                                 )
                                 break
                             else:
                                 logger.warning(
-
+                                    "Unknown database state, proceeding",
+                                    instance_name=database.instance_name,
+                                    state=current_state,
                                 )
                                 break
                         except NotFound:
                             logger.warning(
-
+                                "Database instance no longer exists, will recreate",
+                                instance_name=database.instance_name,
                             )
                             break
                         except Exception as state_error:
                             logger.warning(
-
+                                "Could not check database state, proceeding",
+                                instance_name=database.instance_name,
+                                error=str(state_error),
                             )
                             break
 
                     if elapsed >= max_wait_time:
                         logger.warning(
-
+                            "Timed out waiting for database to become AVAILABLE",
+                            instance_name=database.instance_name,
+                            max_wait_seconds=max_wait_time,
                         )
 
                 elif existing_instance.state == "AVAILABLE":
                     logger.info(
-
+                        "Database instance already exists and is AVAILABLE",
+                        instance_name=database.instance_name,
                     )
                     return
                 elif existing_instance.state in ["STOPPED", "DELETING"]:
                     logger.warning(
-
+                        "Database instance in terminal state",
+                        instance_name=database.instance_name,
+                        state=existing_instance.state,
                    )
                     return
                 else:
                     logger.info(
-
+                        "Database instance already exists",
+                        instance_name=database.instance_name,
+                        state=existing_instance.state,
                     )
                     return
 
         except NotFound:
             # Database doesn't exist, proceed with creation
-            logger.
-
+            logger.info(
+                "Creating new database instance", instance_name=database.instance_name
             )
 
             try:
@@ -878,10 +1104,17 @@ class DatabricksProvider(ServiceProvider):
                 workspace_client.database.create_database_instance(
                     database_instance=database_instance
                 )
-                logger.
-
+                logger.success(
+                    "Database instance created successfully",
+                    instance_name=database.instance_name,
                 )
 
+                # Wait for the newly created database to become AVAILABLE
+                self._wait_for_database_available(
+                    workspace_client, database.instance_name
+                )
+                return
+
             except Exception as create_error:
                 error_msg: str = str(create_error)
 
@@ -891,13 +1124,20 @@ class DatabricksProvider(ServiceProvider):
                     or "RESOURCE_ALREADY_EXISTS" in error_msg
                 ):
                     logger.info(
-
+                        "Database instance was created concurrently",
+                        instance_name=database.instance_name,
+                    )
+                    # Still need to wait for the database to become AVAILABLE
+                    self._wait_for_database_available(
+                        workspace_client, database.instance_name
                     )
                     return
                 else:
                     # Re-raise unexpected errors
                     logger.error(
-
+                        "Error creating database instance",
+                        instance_name=database.instance_name,
+                        error=str(create_error),
                     )
                     raise
 
@@ -911,12 +1151,15 @@ class DatabricksProvider(ServiceProvider):
             or "RESOURCE_ALREADY_EXISTS" in error_msg
         ):
             logger.info(
-
+                "Database instance already exists (detected via exception)",
+                instance_name=database.instance_name,
            )
             return
         else:
             logger.error(
-
+                "Unexpected error while handling database",
+                instance_name=database.instance_name,
+                error=str(e),
             )
             raise
 
@@ -924,7 +1167,9 @@ class DatabricksProvider(ServiceProvider):
         """
         Ask Databricks to mint a fresh DB credential for this instance.
        """
-        logger.
+        logger.trace(
+            "Generating password for lakebase instance", instance_name=instance_name
+        )
         w: WorkspaceClient = self.w
         cred: DatabaseCredential = w.database.generate_database_credential(
             request_id=str(uuid.uuid4()),
@@ -960,7 +1205,8 @@ class DatabricksProvider(ServiceProvider):
         # Validate that client_id is provided
         if not database.client_id:
             logger.warning(
-
+                "client_id required to create instance role",
+                instance_name=database.instance_name,
             )
             return
 
@@ -970,7 +1216,10 @@ class DatabricksProvider(ServiceProvider):
         instance_name: str = database.instance_name
 
         logger.debug(
-
+            "Creating instance role",
+            role_name=role_name,
+            instance_name=instance_name,
+            principal=client_id,
         )
 
         try:
@@ -981,13 +1230,15 @@ class DatabricksProvider(ServiceProvider):
                 name=role_name,
             )
             logger.info(
-
+                "Instance role already exists",
+                role_name=role_name,
+                instance_name=instance_name,
             )
             return
         except NotFound:
             # Role doesn't exist, proceed with creation
             logger.debug(
-
+                "Instance role not found, creating new role", role_name=role_name
             )
 
         # Create the database instance role
@@ -1003,8 +1254,10 @@ class DatabricksProvider(ServiceProvider):
             database_instance_role=role,
         )
 
-        logger.
-
+        logger.success(
+            "Instance role created successfully",
+            role_name=role_name,
+            instance_name=instance_name,
         )
 
     except Exception as e:
@@ -1016,88 +1269,308 @@ class DatabricksProvider(ServiceProvider):
             or "RESOURCE_ALREADY_EXISTS" in error_msg
         ):
             logger.info(
-
+                "Instance role was created concurrently",
+                role_name=role_name,
+                instance_name=instance_name,
             )
             return
 
         # Re-raise unexpected errors
         logger.error(
-
+            "Error creating instance role",
+            role_name=role_name,
+            instance_name=instance_name,
+            error=str(e),
         )
         raise
 
-    def get_prompt(self, prompt_model: PromptModel) ->
-        """
-
+    def get_prompt(self, prompt_model: PromptModel) -> PromptVersion:
+        """
+        Load prompt from MLflow Prompt Registry with fallback logic.
+
+        If an explicit version or alias is specified in the prompt_model, uses that directly.
+        Otherwise, tries to load prompts in this order:
+        1. champion alias
+        2. latest alias
+        3. default alias
+        4. Register default_template if provided (only if register_to_registry=True)
+        5. Use default_template directly (fallback)
+
+        The auto_register field controls whether the default_template is automatically
+        synced to the prompt registry:
+        - If True (default): Auto-registers/updates the default_template in the registry
+        - If False: Never registers, but can still load existing prompts from registry
+          or use default_template directly as a local-only prompt
 
-
-
-            prompt_uri = f"prompts:/{prompt_name}@{prompt_model.alias}"
-        elif prompt_model.version:
-            prompt_uri = f"prompts:/{prompt_name}/{prompt_model.version}"
-        else:
-            prompt_uri = f"prompts:/{prompt_name}@latest"
+        Args:
+            prompt_model: The prompt model configuration
 
-
-
+        Returns:
+            PromptVersion: The loaded prompt version
 
-
-
+        Raises:
+            ValueError: If no prompt can be loaded from any source
+        """
 
-
-            logger.warning(f"Failed to load prompt '{prompt_name}' from registry: {e}")
+        prompt_name: str = prompt_model.full_name
 
-
-
-
-
+        # If explicit version or alias is specified, use it directly
+        if prompt_model.version or prompt_model.alias:
+            try:
+                prompt_version: PromptVersion = prompt_model.as_prompt()
+                version_or_alias = (
+                    f"version {prompt_model.version}"
+                    if prompt_model.version
+                    else f"alias {prompt_model.alias}"
                )
-
+                logger.debug(
+                    "Loaded prompt with explicit version/alias",
+                    prompt_name=prompt_name,
+                    version_or_alias=version_or_alias,
+                )
+                return prompt_version
+            except Exception as e:
+                version_or_alias = (
+                    f"version {prompt_model.version}"
+                    if prompt_model.version
+                    else f"alias {prompt_model.alias}"
+                )
+                logger.warning(
+                    "Failed to load prompt with explicit version/alias",
+                    prompt_name=prompt_name,
+                    version_or_alias=version_or_alias,
+                    error=str(e),
+                )
+                # Fall through to try other methods
 
-
-
-
+        # Try to load in priority order: champion → default (with sync check)
+        logger.trace(
+            "Trying prompt fallback order",
+            prompt_name=prompt_name,
+            order="champion → default",
+        )
 
-
-
-
-        """Register default_template to prompt registry under 'default' alias if changed."""
-        try:
-            # Check if default alias already has the same template
+        # First, sync default alias if template has changed (even if champion exists)
+        # Only do this if auto_register is True
+        if prompt_model.default_template and prompt_model.auto_register:
             try:
-
-
-
-
+                # Try to load existing default
+                existing_default = load_prompt(f"prompts:/{prompt_name}@default")
+
+                # Check if champion exists and if it matches default
+                champion_matches_default = False
+                try:
+                    existing_champion = load_prompt(f"prompts:/{prompt_name}@champion")
+                    champion_matches_default = (
+                        existing_champion.version == existing_default.version
+                    )
+                    status = (
+                        "tracking" if champion_matches_default else "pinned separately"
+                    )
+                    logger.trace(
+                        "Champion vs default version",
+                        prompt_name=prompt_name,
+                        champion_version=existing_champion.version,
+                        default_version=existing_default.version,
+                        status=status,
+                    )
+                except Exception:
+                    # No champion exists
+                    logger.trace("No champion alias found", prompt_name=prompt_name)
+
+                # Check if default_template differs from existing default
                 if (
-
-
+                    existing_default.template.strip()
+                    != prompt_model.default_template.strip()
                 ):
-                    logger.
-
-
-
-
+                    logger.info(
+                        "Default template changed, registering new version",
+                        prompt_name=prompt_name,
+                    )
+
+                    # Only update champion if it was pointing to the old default
+                    if champion_matches_default:
+                        logger.info(
+                            "Champion was tracking default, will update to new version",
+                            prompt_name=prompt_name,
+                            old_version=existing_default.version,
+                        )
+                        set_champion = True
+                    else:
+                        logger.info(
+                            "Champion is pinned separately, preserving it",
+                            prompt_name=prompt_name,
+                        )
+                        set_champion = False
+
+                    self._register_default_template(
+                        prompt_name,
+                        prompt_model.default_template,
+                        prompt_model.description,
+                        set_champion=set_champion,
+                    )
+            except Exception as e:
+                # No default exists yet, register it
+                logger.trace(
+                    "No default alias found", prompt_name=prompt_name, error=str(e)
+                )
+                logger.info(
+                    "Registering default template as default alias",
+                    prompt_name=prompt_name,
+                )
+                # First registration - set both default and champion
+                self._register_default_template(
+                    prompt_name,
+                    prompt_model.default_template,
+                    prompt_model.description,
+                    set_champion=True,
+                )
+        elif prompt_model.default_template and not prompt_model.auto_register:
+            logger.trace(
+                "Prompt has auto_register=False, skipping registration",
+                prompt_name=prompt_name,
+            )
+
+        # 1. Try champion alias (highest priority for execution)
+        try:
+            prompt_version = load_prompt(f"prompts:/{prompt_name}@champion")
+            logger.info("Loaded prompt from champion alias", prompt_name=prompt_name)
+            return prompt_version
+        except Exception as e:
+            logger.trace(
+                "Champion alias not found", prompt_name=prompt_name, error=str(e)
+            )
+
+        # 2. Try default alias (already synced above)
+        if prompt_model.default_template:
+            try:
+                prompt_version = load_prompt(f"prompts:/{prompt_name}@default")
+                logger.info("Loaded prompt from default alias", prompt_name=prompt_name)
+                return prompt_version
+            except Exception as e:
+                # Should not happen since we just registered it above, but handle anyway
+                logger.trace(
+                    "Default alias not found", prompt_name=prompt_name, error=str(e)
                )
 
-
+        # 3. Try latest alias as final fallback
+        try:
+            prompt_version = load_prompt(f"prompts:/{prompt_name}@latest")
+            logger.info("Loaded prompt from latest alias", prompt_name=prompt_name)
+            return prompt_version
+        except Exception as e:
+            logger.trace(
+                "Latest alias not found", prompt_name=prompt_name, error=str(e)
+            )
+
+        # 4. Final fallback: use default_template directly if available
+        if prompt_model.default_template:
+            logger.warning(
+                "Could not load prompt from registry, using default_template directly",
+                prompt_name=prompt_name,
+            )
+            return PromptVersion(
+                name=prompt_name,
+                version=1,
+                template=prompt_model.default_template,
+                tags={"dao_ai": dao_ai_version()},
+            )
+
+        raise ValueError(
+            f"Prompt '{prompt_name}' not found in registry "
+            "(tried champion, default, latest aliases) "
+            "and no default_template provided"
+        )
+
+    def _register_default_template(
+        self,
+        prompt_name: str,
+        default_template: str,
+        description: str | None = None,
+        set_champion: bool = True,
+    ) -> PromptVersion:
+        """Register default_template as a new prompt version.
+
+        Registers the template and sets the 'default' alias.
+        Optionally sets 'champion' alias if no champion exists.
+
+        Args:
+            prompt_name: Full name of the prompt
+            default_template: The template content
+            description: Optional description for commit message
+            set_champion: Whether to also set champion alias (default: True)
+
+        If registration fails (e.g., in Model Serving with restricted permissions),
+        logs the error and raises.
+        """
+        logger.info(
+            "Registering default template",
+            prompt_name=prompt_name,
+            set_champion=set_champion,
+        )
+
+        try:
             commit_message = description or "Auto-synced from default_template"
-            prompt_version
+            prompt_version = mlflow.genai.register_prompt(
                 name=prompt_name,
                 template=default_template,
                 commit_message=commit_message,
+                tags={"dao_ai": dao_ai_version()},
             )
 
-
-
-
-
-
-
+            # Always set default alias
+            try:
+                logger.debug(
+                    "Setting default alias",
+                    prompt_name=prompt_name,
+                    version=prompt_version.version,
+                )
+                mlflow.genai.set_prompt_alias(
+                    name=prompt_name, alias="default", version=prompt_version.version
+                )
+                logger.success(
+                    "Set default alias for prompt",
+                    prompt_name=prompt_name,
+                    version=prompt_version.version,
+                )
+            except Exception as alias_error:
+                logger.warning(
+                    "Could not set default alias",
+                    prompt_name=prompt_name,
+                    error=str(alias_error),
+                )
 
-
-
-
+            # Optionally set champion alias (only if no champion exists or explicitly requested)
+            if set_champion:
+                try:
+                    mlflow.genai.set_prompt_alias(
+                        name=prompt_name,
+                        alias="champion",
+                        version=prompt_version.version,
+                    )
+                    logger.success(
+                        "Set champion alias for prompt",
+                        prompt_name=prompt_name,
+                        version=prompt_version.version,
+                    )
+                except Exception as alias_error:
+                    logger.warning(
+                        "Could not set champion alias",
+                        prompt_name=prompt_name,
+                        error=str(alias_error),
+                    )
 
-
-
+            return prompt_version
+
+        except Exception as reg_error:
+            logger.error(
+                "Failed to register prompt - please register from notebook with write permissions",
+                prompt_name=prompt_name,
+                error=str(reg_error),
+            )
+            return PromptVersion(
+                name=prompt_name,
+                version=1,
+                template=default_template,
+                tags={"dao_ai": dao_ai_version()},
+            )