dao-ai 0.0.28__py3-none-any.whl → 0.1.5__py3-none-any.whl
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +2 -5
- dao_ai/cli.py +342 -58
- dao_ai/config.py +1610 -380
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +158 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/context_editing.py +230 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +233 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/model_call_limit.py +77 -0
- dao_ai/middleware/model_retry.py +121 -0
- dao_ai/middleware/pii.py +157 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/middleware/tool_call_limit.py +210 -0
- dao_ai/middleware/tool_retry.py +174 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +240 -161
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +279 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +584 -601
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -294
- dao_ai/tools/mcp.py +223 -155
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +331 -221
- dao_ai/utils.py +166 -20
- dao_ai/vector_search.py +37 -0
- dao_ai-0.1.5.dist-info/METADATA +489 -0
- dao_ai-0.1.5.dist-info/RECORD +70 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.28.dist-info/METADATA +0 -1168
- dao_ai-0.0.28.dist-info/RECORD +0 -41
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py
CHANGED
@@ -12,6 +12,7 @@ from typing import (
     Iterator,
     Literal,
     Optional,
+    Self,
     Sequence,
     TypeAlias,
     Union,
@@ -22,13 +23,20 @@ from databricks.sdk.credentials_provider import (
     CredentialsStrategy,
     ModelServingUserCredentials,
 )
+from databricks.sdk.errors.platform import NotFound
 from databricks.sdk.service.catalog import FunctionInfo, TableInfo
+from databricks.sdk.service.dashboards import GenieSpace
 from databricks.sdk.service.database import DatabaseInstance
+from databricks.sdk.service.sql import GetWarehouseResponse
 from databricks.vector_search.client import VectorSearchClient
 from databricks.vector_search.index import VectorSearchIndex
 from databricks_langchain import (
+    ChatDatabricks,
+    DatabricksEmbeddings,
     DatabricksFunctionClient,
 )
+from langchain.agents.structured_output import ProviderStrategy, ToolStrategy
+from langchain_core.embeddings import Embeddings
 from langchain_core.language_models import LanguageModelLike
 from langchain_core.messages import BaseMessage, messages_from_dict
 from langchain_core.runnables.base import RunnableLike
@@ -41,6 +49,7 @@ from mlflow.genai.datasets import EvaluationDataset, create_dataset, get_dataset
 from mlflow.genai.prompts import PromptVersion, load_prompt
 from mlflow.models import ModelConfig
 from mlflow.models.resources import (
+    DatabricksApp,
     DatabricksFunction,
     DatabricksGenieSpace,
     DatabricksLakebase,
@@ -59,10 +68,13 @@ from pydantic import (
     BaseModel,
     ConfigDict,
     Field,
+    PrivateAttr,
     field_serializer,
     model_validator,
 )
 
+from dao_ai.utils import normalize_name
+
 
 class HasValue(ABC):
     @abstractmethod
@@ -81,27 +93,6 @@ class HasFullName(ABC):
     def full_name(self) -> str: ...
 
 
-class IsDatabricksResource(ABC):
-    on_behalf_of_user: Optional[bool] = False
-
-    @abstractmethod
-    def as_resources(self) -> Sequence[DatabricksResource]: ...
-
-    @property
-    @abstractmethod
-    def api_scopes(self) -> Sequence[str]: ...
-
-    @property
-    def workspace_client(self) -> WorkspaceClient:
-        credentials_strategy: CredentialsStrategy = None
-        if self.on_behalf_of_user:
-            credentials_strategy = ModelServingUserCredentials()
-        logger.debug(
-            f"Creating WorkspaceClient with credentials strategy: {credentials_strategy}"
-        )
-        return WorkspaceClient(credentials_strategy=credentials_strategy)
-
-
 class EnvironmentVariableModel(BaseModel, HasValue):
     model_config = ConfigDict(
         frozen=True,
@@ -200,6 +191,162 @@ AnyVariable: TypeAlias = (
 )
 
 
+class ServicePrincipalModel(BaseModel):
+    model_config = ConfigDict(
+        frozen=True,
+        use_enum_values=True,
+    )
+    client_id: AnyVariable
+    client_secret: AnyVariable
+
+
+class IsDatabricksResource(ABC, BaseModel):
+    """
+    Base class for Databricks resources with authentication support.
+
+    Authentication Options:
+    ----------------------
+    1. **On-Behalf-Of User (OBO)**: Set on_behalf_of_user=True to use the
+       calling user's identity via ModelServingUserCredentials.
+
+    2. **Service Principal (OAuth M2M)**: Provide service_principal or
+       (client_id + client_secret + workspace_host) for service principal auth.
+
+    3. **Personal Access Token (PAT)**: Provide pat (and optionally workspace_host)
+       to authenticate with a personal access token.
+
+    4. **Ambient Authentication**: If no credentials provided, uses SDK defaults
+       (environment variables, notebook context, etc.)
+
+    Authentication Priority:
+    1. OBO (on_behalf_of_user=True)
+    2. Service Principal (client_id + client_secret + workspace_host)
+    3. PAT (pat + workspace_host)
+    4. Ambient/default authentication
+    """
+
+    model_config = ConfigDict(use_enum_values=True)
+
+    on_behalf_of_user: Optional[bool] = False
+    service_principal: Optional[ServicePrincipalModel] = None
+    client_id: Optional[AnyVariable] = None
+    client_secret: Optional[AnyVariable] = None
+    workspace_host: Optional[AnyVariable] = None
+    pat: Optional[AnyVariable] = None
+
+    # Private attribute to cache the workspace client (lazy instantiation)
+    _workspace_client: Optional[WorkspaceClient] = PrivateAttr(default=None)
+
+    @abstractmethod
+    def as_resources(self) -> Sequence[DatabricksResource]: ...
+
+    @property
+    @abstractmethod
+    def api_scopes(self) -> Sequence[str]: ...
+
+    @model_validator(mode="after")
+    def _expand_service_principal(self) -> Self:
+        """Expand service_principal into client_id and client_secret if provided."""
+        if self.service_principal is not None:
+            if self.client_id is None:
+                self.client_id = self.service_principal.client_id
+            if self.client_secret is None:
+                self.client_secret = self.service_principal.client_secret
+        return self
+
+    @model_validator(mode="after")
+    def _validate_auth_not_mixed(self) -> Self:
+        """Validate that OAuth and PAT authentication are not both provided."""
+        has_oauth: bool = self.client_id is not None and self.client_secret is not None
+        has_pat: bool = self.pat is not None
+
+        if has_oauth and has_pat:
+            raise ValueError(
+                "Cannot use both OAuth and user authentication methods. "
+                "Please provide either OAuth credentials or user credentials."
+            )
+        return self
+
+    @property
+    def workspace_client(self) -> WorkspaceClient:
+        """
+        Get a WorkspaceClient configured with the appropriate authentication.
+
+        The client is lazily instantiated on first access and cached for subsequent calls.
+
+        Authentication priority:
+        1. If on_behalf_of_user is True, uses ModelServingUserCredentials (OBO)
+        2. If service principal credentials are configured (client_id, client_secret,
+           workspace_host), uses OAuth M2M
+        3. If PAT is configured, uses token authentication
+        4. Otherwise, uses default/ambient authentication
+        """
+        # Return cached client if already instantiated
+        if self._workspace_client is not None:
+            return self._workspace_client
+
+        from dao_ai.utils import normalize_host
+
+        # Check for OBO first (highest priority)
+        if self.on_behalf_of_user:
+            credentials_strategy: CredentialsStrategy = ModelServingUserCredentials()
+            logger.debug(
+                f"Creating WorkspaceClient for {self.__class__.__name__} "
+                f"with OBO credentials strategy"
+            )
+            self._workspace_client = WorkspaceClient(
+                credentials_strategy=credentials_strategy
+            )
+            return self._workspace_client
+
+        # Check for service principal credentials
+        client_id_value: str | None = (
+            value_of(self.client_id) if self.client_id else None
+        )
+        client_secret_value: str | None = (
+            value_of(self.client_secret) if self.client_secret else None
+        )
+        workspace_host_value: str | None = (
+            normalize_host(value_of(self.workspace_host))
+            if self.workspace_host
+            else None
+        )
+
+        if client_id_value and client_secret_value and workspace_host_value:
+            logger.debug(
+                f"Creating WorkspaceClient for {self.__class__.__name__} with service principal: "
+                f"client_id={client_id_value}, host={workspace_host_value}"
+            )
+            self._workspace_client = WorkspaceClient(
+                host=workspace_host_value,
+                client_id=client_id_value,
+                client_secret=client_secret_value,
+                auth_type="oauth-m2m",
+            )
+            return self._workspace_client
+
+        # Check for PAT authentication
+        pat_value: str | None = value_of(self.pat) if self.pat else None
+        if pat_value:
+            logger.debug(
+                f"Creating WorkspaceClient for {self.__class__.__name__} with PAT"
+            )
+            self._workspace_client = WorkspaceClient(
+                host=workspace_host_value,
+                token=pat_value,
+                auth_type="pat",
+            )
+            return self._workspace_client
+
+        # Default: use ambient authentication
+        logger.debug(
+            f"Creating WorkspaceClient for {self.__class__.__name__} "
+            "with default/ambient authentication"
+        )
+        self._workspace_client = WorkspaceClient()
+        return self._workspace_client
+
+
 class Privilege(str, Enum):
     ALL_PRIVILEGES = "ALL_PRIVILEGES"
     USE_CATALOG = "USE_CATALOG"
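The authentication priority above can be exercised on any concrete resource model. A minimal sketch, assuming the `WarehouseModel` subclass added later in this diff and placeholder credential values (a real workspace is needed for any of these clients to actually authenticate):

```python
from dao_ai.config import WarehouseModel

# OBO wins over everything else when on_behalf_of_user=True.
obo = WarehouseModel(name="wh", warehouse_id="abc123", on_behalf_of_user=True)

# Service principal (OAuth M2M) takes effect only when client_id,
# client_secret, and workspace_host are all present.
sp = WarehouseModel(
    name="wh",
    warehouse_id="abc123",
    client_id="my-sp-client-id",          # placeholder
    client_secret="my-sp-secret",         # placeholder
    workspace_host="https://my-workspace.cloud.databricks.com",
)

# PAT auth; combining pat with client_id/client_secret raises ValueError
# via the _validate_auth_not_mixed validator.
pat = WarehouseModel(name="wh", warehouse_id="abc123", pat="dapi-placeholder")

# With no credentials at all, the SDK's ambient/default auth is used.
ambient = WarehouseModel(name="wh", warehouse_id="abc123")

# The WorkspaceClient is built lazily on first access and then cached,
# so repeated accesses return the same instance (requires real credentials):
client = sp.workspace_client
assert sp.workspace_client is client
```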
@@ -226,9 +373,21 @@ class Privilege(str, Enum):
 
 class PermissionModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    principals: list[str] = Field(default_factory=list)
+    principals: list[ServicePrincipalModel | str] = Field(default_factory=list)
     privileges: list[Privilege]
 
+    @model_validator(mode="after")
+    def resolve_principals(self) -> Self:
+        """Resolve ServicePrincipalModel objects to their client_id."""
+        resolved: list[str] = []
+        for principal in self.principals:
+            if isinstance(principal, ServicePrincipalModel):
+                resolved.append(value_of(principal.client_id))
+            else:
+                resolved.append(principal)
+        self.principals = resolved
+        return self
+
 
 class SchemaModel(BaseModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
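A short sketch of the new `principals` behavior above: plain strings pass through unchanged, while `ServicePrincipalModel` entries are collapsed to their `client_id` by the `resolve_principals` validator (the principal names and privilege are illustrative):

```python
from dao_ai.config import PermissionModel, Privilege, ServicePrincipalModel

perm = PermissionModel(
    principals=[
        "user@example.com",  # kept as-is
        ServicePrincipalModel(client_id="sp-client-id", client_secret="sp-secret"),
    ],
    privileges=[Privilege.USE_CATALOG],
)
print(perm.principals)  # ["user@example.com", "sp-client-id"]
```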
@@ -248,7 +407,26 @@ class SchemaModel(BaseModel, HasFullName):
         provider.create_schema(self)
 
 
-class TableModel(BaseModel, HasFullName, IsDatabricksResource):
+class DatabricksAppModel(IsDatabricksResource, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    url: str
+
+    @property
+    def full_name(self) -> str:
+        return self.name
+
+    @property
+    def api_scopes(self) -> Sequence[str]:
+        return ["apps.apps"]
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksApp(app_name=self.name, on_behalf_of_user=self.on_behalf_of_user)
+        ]
+
+
+class TableModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: Optional[str] = None
@@ -274,6 +452,22 @@ class TableModel(BaseModel, HasFullName, IsDatabricksResource):
     def api_scopes(self) -> Sequence[str]:
         return []
 
+    def exists(self) -> bool:
+        """Check if the table exists in Unity Catalog.
+
+        Returns:
+            True if the table exists, False otherwise.
+        """
+        try:
+            self.workspace_client.tables.get(full_name=self.full_name)
+            return True
+        except NotFound:
+            logger.debug(f"Table not found: {self.full_name}")
+            return False
+        except Exception as e:
+            logger.warning(f"Error checking table existence for {self.full_name}: {e}")
+            return False
+
     def as_resources(self) -> Sequence[DatabricksResource]:
         resources: list[DatabricksResource] = []
 
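A sketch of how the new existence check behaves: `NotFound` maps to `False`, and any other error is logged as a warning and also returns `False` (the table name below is a placeholder):

```python
from dao_ai.config import TableModel

orders = TableModel(name="main.sales.orders")
if orders.exists():  # resolves via workspace_client.tables.get(...)
    print(orders.as_resources())
```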
@@ -317,12 +511,17 @@ class TableModel(BaseModel, HasFullName, IsDatabricksResource):
         return resources
 
 
-class LLMModel(BaseModel, IsDatabricksResource):
+class LLMModel(IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
+    description: Optional[str] = None
     temperature: Optional[float] = 0.1
     max_tokens: Optional[int] = 8192
     fallbacks: Optional[list[Union[str, "LLMModel"]]] = Field(default_factory=list)
+    use_responses_api: Optional[bool] = Field(
+        default=False,
+        description="Use Responses API for ResponsesAgent endpoints",
+    )
 
     @property
     def api_scopes(self) -> Sequence[str]:
@@ -342,19 +541,12 @@ class LLMModel(BaseModel, IsDatabricksResource):
     ]
 
     def as_chat_model(self) -> LanguageModelLike:
-
-
-
-
-
-        from dao_ai.chat_models import ChatDatabricksFiltered
-
-        chat_client: LanguageModelLike = ChatDatabricksFiltered(
-            model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
+        chat_client: LanguageModelLike = ChatDatabricks(
+            model=self.name,
+            temperature=self.temperature,
+            max_tokens=self.max_tokens,
+            use_responses_api=self.use_responses_api,
         )
-        # chat_client: LanguageModelLike = ChatDatabricks(
-        #     model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
-        # )
 
         fallbacks: Sequence[LanguageModelLike] = []
         for fallback in self.fallbacks:
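A sketch of the new `as_chat_model` path: it now builds a plain `databricks_langchain.ChatDatabricks` (the filtered wrapper is gone) and forwards the new `use_responses_api` flag; the fallback chaining is assumed from the surrounding loop over `self.fallbacks`, and the endpoint names are illustrative:

```python
from dao_ai.config import LLMModel

llm = LLMModel(
    name="databricks-claude-sonnet-4",  # endpoint name is a placeholder
    temperature=0.0,
    max_tokens=4096,
    use_responses_api=True,
    fallbacks=["databricks-meta-llama-3-3-70b-instruct"],
)
chat = llm.as_chat_model()  # ChatDatabricks, wrapped with fallbacks if any
```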
@@ -386,6 +578,9 @@ class LLMModel(BaseModel, IsDatabricksResource):
 
         return chat_client
 
+    def as_embeddings_model(self) -> Embeddings:
+        return DatabricksEmbeddings(endpoint=self.name)
+
 
 class VectorSearchEndpointType(str, Enum):
     STANDARD = "STANDARD"
@@ -405,7 +600,9 @@ class VectorSearchEndpoint(BaseModel):
         return str(value)
 
 
-class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
+class IndexModel(IsDatabricksResource, HasFullName):
+    """Model representing a Databricks Vector Search index."""
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: str
@@ -429,13 +626,314 @@ class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
         )
     ]
 
+    def exists(self) -> bool:
+        """Check if this vector search index exists.
+
+        Returns:
+            True if the index exists, False otherwise.
+        """
+        try:
+            self.workspace_client.vector_search_indexes.get_index(self.full_name)
+            return True
+        except NotFound:
+            logger.debug(f"Index not found: {self.full_name}")
+            return False
+        except Exception as e:
+            logger.warning(f"Error checking index existence for {self.full_name}: {e}")
+            return False
+
 
-class GenieRoomModel(BaseModel, IsDatabricksResource):
+class FunctionModel(IsDatabricksResource, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
+    name: Optional[str] = None
+
+    @model_validator(mode="after")
+    def validate_name_or_schema_required(self) -> Self:
+        if not self.name and not self.schema_model:
+            raise ValueError(
+                "Either 'name' or 'schema_model' must be provided for FunctionModel"
+            )
+        return self
+
+    @property
+    def full_name(self) -> str:
+        if self.schema_model:
+            name: str = ""
+            if self.name:
+                name = f".{self.name}"
+            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
+        return self.name
+
+    def exists(self) -> bool:
+        """Check if the function exists in Unity Catalog.
+
+        Returns:
+            True if the function exists, False otherwise.
+        """
+        try:
+            self.workspace_client.functions.get(name=self.full_name)
+            return True
+        except NotFound:
+            logger.debug(f"Function not found: {self.full_name}")
+            return False
+        except Exception as e:
+            logger.warning(
+                f"Error checking function existence for {self.full_name}: {e}"
+            )
+            return False
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        resources: list[DatabricksResource] = []
+        if self.name:
+            resources.append(
+                DatabricksFunction(
+                    function_name=self.full_name,
+                    on_behalf_of_user=self.on_behalf_of_user,
+                )
+            )
+        else:
+            w: WorkspaceClient = self.workspace_client
+            schema_full_name: str = self.schema_model.full_name
+            functions: Iterator[FunctionInfo] = w.functions.list(
+                catalog_name=self.schema_model.catalog_name,
+                schema_name=self.schema_model.schema_name,
+            )
+            resources.extend(
+                [
+                    DatabricksFunction(
+                        function_name=f"{schema_full_name}.{function.name}",
+                        on_behalf_of_user=self.on_behalf_of_user,
+                    )
+                    for function in functions
+                ]
+            )
+
+        return resources
+
+    @property
+    def api_scopes(self) -> Sequence[str]:
+        return ["sql.statement-execution"]
+
+
+class WarehouseModel(IsDatabricksResource):
+    model_config = ConfigDict()
+    name: str
+    description: Optional[str] = None
+    warehouse_id: AnyVariable
+
+    @property
+    def api_scopes(self) -> Sequence[str]:
+        return [
+            "sql.warehouses",
+            "sql.statement-execution",
+        ]
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksSQLWarehouse(
+                warehouse_id=value_of(self.warehouse_id),
+                on_behalf_of_user=self.on_behalf_of_user,
+            )
+        ]
+
+    @model_validator(mode="after")
+    def update_warehouse_id(self) -> Self:
+        self.warehouse_id = value_of(self.warehouse_id)
+        return self
+
+
+class GenieRoomModel(IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     description: Optional[str] = None
     space_id: AnyVariable
 
+    _space_details: Optional[GenieSpace] = PrivateAttr(default=None)
+
+    def _get_space_details(self) -> GenieSpace:
+        if self._space_details is None:
+            self._space_details = self.workspace_client.genie.get_space(
+                space_id=self.space_id, include_serialized_space=True
+            )
+        return self._space_details
+
+    def _parse_serialized_space(self) -> dict[str, Any]:
+        """Parse the serialized_space JSON string and return the parsed data."""
+        import json
+
+        space_details = self._get_space_details()
+        if not space_details.serialized_space:
+            return {}
+
+        try:
+            return json.loads(space_details.serialized_space)
+        except json.JSONDecodeError as e:
+            logger.warning(f"Failed to parse serialized_space: {e}")
+            return {}
+
+    @property
+    def warehouse(self) -> Optional[WarehouseModel]:
+        """Extract warehouse information from the Genie space.
+
+        Returns:
+            WarehouseModel instance if warehouse_id is available, None otherwise.
+        """
+        space_details: GenieSpace = self._get_space_details()
+
+        if not space_details.warehouse_id:
+            return None
+
+        try:
+            response: GetWarehouseResponse = self.workspace_client.warehouses.get(
+                space_details.warehouse_id
+            )
+            warehouse_name: str = response.name or space_details.warehouse_id
+
+            warehouse_model = WarehouseModel(
+                name=warehouse_name,
+                warehouse_id=space_details.warehouse_id,
+                on_behalf_of_user=self.on_behalf_of_user,
+                service_principal=self.service_principal,
+                client_id=self.client_id,
+                client_secret=self.client_secret,
+                workspace_host=self.workspace_host,
+                pat=self.pat,
+            )
+
+            # Share the cached workspace client if available
+            if self._workspace_client is not None:
+                warehouse_model._workspace_client = self._workspace_client
+
+            return warehouse_model
+        except Exception as e:
+            logger.warning(
+                f"Failed to fetch warehouse details for {space_details.warehouse_id}: {e}"
+            )
+            return None
+
+    @property
+    def tables(self) -> list[TableModel]:
+        """Extract tables from the serialized Genie space.
+
+        Databricks Genie stores tables in: data_sources.tables[].identifier
+        Only includes tables that actually exist in Unity Catalog.
+        """
+        parsed_space = self._parse_serialized_space()
+        tables_list: list[TableModel] = []
+
+        # Primary structure: data_sources.tables with 'identifier' field
+        if "data_sources" in parsed_space:
+            data_sources = parsed_space["data_sources"]
+            if isinstance(data_sources, dict) and "tables" in data_sources:
+                tables_data = data_sources["tables"]
+                if isinstance(tables_data, list):
+                    for table_item in tables_data:
+                        table_name: str | None = None
+                        if isinstance(table_item, dict):
+                            # Standard Databricks structure uses 'identifier'
+                            table_name = table_item.get("identifier") or table_item.get(
+                                "name"
+                            )
+                        elif isinstance(table_item, str):
+                            table_name = table_item
+
+                        if table_name:
+                            table_model = TableModel(
+                                name=table_name,
+                                on_behalf_of_user=self.on_behalf_of_user,
+                                service_principal=self.service_principal,
+                                client_id=self.client_id,
+                                client_secret=self.client_secret,
+                                workspace_host=self.workspace_host,
+                                pat=self.pat,
+                            )
+                            # Share the cached workspace client if available
+                            if self._workspace_client is not None:
+                                table_model._workspace_client = self._workspace_client
+
+                            # Verify the table exists before adding
+                            if not table_model.exists():
+                                continue
+
+                            tables_list.append(table_model)
+
+        return tables_list
+
+    @property
+    def functions(self) -> list[FunctionModel]:
+        """Extract functions from the serialized Genie space.
+
+        Databricks Genie stores functions in multiple locations:
+        - instructions.sql_functions[].identifier (SQL functions)
+        - data_sources.functions[].identifier (other functions)
+        Only includes functions that actually exist in Unity Catalog.
+        """
+        parsed_space = self._parse_serialized_space()
+        functions_list: list[FunctionModel] = []
+        seen_functions: set[str] = set()
+
+        def add_function_if_exists(function_name: str) -> None:
+            """Helper to add a function if it exists and hasn't been added."""
+            if function_name in seen_functions:
+                return
+
+            seen_functions.add(function_name)
+            function_model = FunctionModel(
+                name=function_name,
+                on_behalf_of_user=self.on_behalf_of_user,
+                service_principal=self.service_principal,
+                client_id=self.client_id,
+                client_secret=self.client_secret,
+                workspace_host=self.workspace_host,
+                pat=self.pat,
+            )
+            # Share the cached workspace client if available
+            if self._workspace_client is not None:
+                function_model._workspace_client = self._workspace_client
+
+            # Verify the function exists before adding
+            if not function_model.exists():
+                return
+
+            functions_list.append(function_model)
+
+        # Primary structure: instructions.sql_functions with 'identifier' field
+        if "instructions" in parsed_space:
+            instructions = parsed_space["instructions"]
+            if isinstance(instructions, dict) and "sql_functions" in instructions:
+                sql_functions_data = instructions["sql_functions"]
+                if isinstance(sql_functions_data, list):
+                    for function_item in sql_functions_data:
+                        if isinstance(function_item, dict):
+                            # SQL functions use 'identifier' field
+                            function_name = function_item.get(
+                                "identifier"
+                            ) or function_item.get("name")
+                            if function_name:
+                                add_function_if_exists(function_name)
+
+        # Secondary structure: data_sources.functions with 'identifier' field
+        if "data_sources" in parsed_space:
+            data_sources = parsed_space["data_sources"]
+            if isinstance(data_sources, dict) and "functions" in data_sources:
+                functions_data = data_sources["functions"]
+                if isinstance(functions_data, list):
+                    for function_item in functions_data:
+                        function_name: str | None = None
+                        if isinstance(function_item, dict):
+                            # Standard Databricks structure uses 'identifier'
+                            function_name = function_item.get(
+                                "identifier"
+                            ) or function_item.get("name")
+                        elif isinstance(function_item, str):
+                            function_name = function_item
+
+                        if function_name:
+                            add_function_if_exists(function_name)
+
+        return functions_list
+
     @property
     def api_scopes(self) -> Sequence[str]:
         return [
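The `tables` and `functions` accessors above walk the parsed `serialized_space` JSON. A sketch of the shape they expect, inferred from the lookups in this hunk (only the `identifier`/`name` keys and the bare-string fallback are confirmed by the code; the concrete values are illustrative):

```python
serialized_space = {
    "data_sources": {
        "tables": [
            {"identifier": "main.sales.orders"},
            "main.sales.customers",  # bare strings are accepted too
        ],
        "functions": [{"identifier": "main.sales.revenue_by_region"}],
    },
    "instructions": {
        "sql_functions": [{"identifier": "main.sales.top_customers"}],
    },
}
```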
@@ -451,12 +949,24 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
         ]
 
     @model_validator(mode="after")
-    def update_space_id(self):
+    def update_space_id(self) -> Self:
         self.space_id = value_of(self.space_id)
         return self
 
+    @model_validator(mode="after")
+    def update_description_from_space(self) -> Self:
+        """Populate description from GenieSpace if not provided."""
+        if not self.description:
+            try:
+                space_details = self._get_space_details()
+                if space_details.description:
+                    self.description = space_details.description
+            except Exception as e:
+                logger.debug(f"Could not fetch description from Genie space: {e}")
+        return self
 
-class VolumeModel(BaseModel, HasFullName, IsDatabricksResource):
+
+class VolumeModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: str
@@ -516,28 +1026,93 @@ class VolumePathModel(BaseModel, HasFullName):
         provider.create_path(self)
 
 
-class VectorStoreModel(BaseModel, IsDatabricksResource):
+class VectorStoreModel(IsDatabricksResource):
+    """
+    Configuration model for a Databricks Vector Search store.
+
+    Supports two modes:
+    1. **Use Existing Index**: Provide only `index` (fully qualified name).
+       Used for querying an existing vector search index at runtime.
+    2. **Provisioning Mode**: Provide `source_table` + `embedding_source_column`.
+       Used for creating a new vector search index.
+
+    Examples:
+        Minimal configuration (use existing index):
+        ```yaml
+        vector_stores:
+          products_search:
+            index:
+              name: catalog.schema.my_index
+        ```
+
+        Full provisioning configuration:
+        ```yaml
+        vector_stores:
+          products_search:
+            source_table:
+              schema: *my_schema
+              name: products
+            embedding_source_column: description
+            endpoint:
+              name: my_endpoint
+        ```
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-
+
+    # RUNTIME: Only index is truly required for querying existing indexes
     index: Optional[IndexModel] = None
+
+    # PROVISIONING ONLY: Required when creating a new index
+    source_table: Optional[TableModel] = None
+    embedding_source_column: Optional[str] = None
+    embedding_model: Optional[LLMModel] = None
     endpoint: Optional[VectorSearchEndpoint] = None
-
+
+    # OPTIONAL: For both modes
     source_path: Optional[VolumePathModel] = None
     checkpoint_path: Optional[VolumePathModel] = None
     primary_key: Optional[str] = None
     columns: Optional[list[str]] = Field(default_factory=list)
     doc_uri: Optional[str] = None
-    embedding_source_column: str
 
     @model_validator(mode="after")
-    def set_default_embedding_model(self):
-        if not self.embedding_model:
+    def validate_configuration_mode(self) -> Self:
+        """
+        Validate that configuration is valid for either:
+        - Use existing mode: index is provided
+        - Provisioning mode: source_table + embedding_source_column provided
+        """
+        has_index = self.index is not None
+        has_source_table = self.source_table is not None
+        has_embedding_col = self.embedding_source_column is not None
+
+        # Must have at least index OR source_table
+        if not has_index and not has_source_table:
+            raise ValueError(
+                "Either 'index' (for existing indexes) or 'source_table' "
+                "(for provisioning) must be provided"
+            )
+
+        # If provisioning mode, need embedding_source_column
+        if has_source_table and not has_embedding_col:
+            raise ValueError(
+                "embedding_source_column is required when source_table is provided (provisioning mode)"
+            )
+
+        return self
+
+    @model_validator(mode="after")
+    def set_default_embedding_model(self) -> Self:
+        # Only set default embedding model in provisioning mode
+        if self.source_table is not None and not self.embedding_model:
             self.embedding_model = LLMModel(name="databricks-gte-large-en")
         return self
 
     @model_validator(mode="after")
-    def set_default_primary_key(self):
-        if self.primary_key is None:
+    def set_default_primary_key(self) -> Self:
+        # Only auto-discover primary key in provisioning mode
+        if self.primary_key is None and self.source_table is not None:
             from dao_ai.providers.databricks import DatabricksProvider
 
             provider: DatabricksProvider = DatabricksProvider()
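A sketch of the two configuration modes enforced by `validate_configuration_mode` (all names are placeholders; the provisioning form is left commented out because its defaulting validators reach out to a live workspace):

```python
from dao_ai.config import IndexModel, VectorStoreModel

# Use-existing mode: only the index is required.
existing = VectorStoreModel(index=IndexModel(name="catalog.schema.my_index"))

# Provisioning mode: source_table plus embedding_source_column; index name,
# endpoint, primary key, and embedding model are defaulted by the validators
# above when omitted.
# provisioned = VectorStoreModel(
#     source_table=TableModel(name="catalog.schema.products"),
#     embedding_source_column="description",
# )

# Omitting both index and source_table raises ValueError.
```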
@@ -557,15 +1132,17 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
         return self
 
     @model_validator(mode="after")
-    def set_default_index(self):
-        if self.index is None:
+    def set_default_index(self) -> Self:
+        # Only generate index from source_table in provisioning mode
+        if self.index is None and self.source_table is not None:
             name: str = f"{self.source_table.name}_index"
             self.index = IndexModel(schema=self.source_table.schema_model, name=name)
         return self
 
     @model_validator(mode="after")
-    def set_default_endpoint(self):
-        if self.endpoint is None:
+    def set_default_endpoint(self) -> Self:
+        # Only find/create endpoint in provisioning mode
+        if self.endpoint is None and self.source_table is not None:
             from dao_ai.providers.databricks import (
                 DatabricksProvider,
                 with_available_indexes,
@@ -600,77 +1177,64 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
         return self.index.as_resources()
 
     def as_index(self, vsc: VectorSearchClient | None = None) -> VectorSearchIndex:
-        from dao_ai.providers.base import ServiceProvider
         from dao_ai.providers.databricks import DatabricksProvider
 
-        provider: ServiceProvider = DatabricksProvider(vsc=vsc)
+        provider: DatabricksProvider = DatabricksProvider(vsc=vsc)
         index: VectorSearchIndex = provider.get_vector_index(self)
         return index
 
-    def create(self, vsc: VectorSearchClient | None = None) -> None:
-
-
+    def create(self, vsc: VectorSearchClient | None = None) -> None:
+        """
+        Create or validate the vector search index.
+
+        Behavior depends on configuration mode:
+        - **Provisioning Mode** (source_table provided): Creates the index
+        - **Use Existing Mode** (only index provided): Validates the index exists
 
-
-
+        Args:
+            vsc: Optional VectorSearchClient instance
 
+        Raises:
+            ValueError: If configuration is invalid or index doesn't exist
+        """
+        from dao_ai.providers.databricks import DatabricksProvider
 
-
-class FunctionModel(BaseModel, HasFullName, IsDatabricksResource):
-    model_config = ConfigDict()
-    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
-    name: Optional[str] = None
+        provider: DatabricksProvider = DatabricksProvider(vsc=vsc)
 
-    @model_validator(mode="after")
-    def validate_name_or_schema_required(self):
-        if not self.name and not self.schema_model:
-            raise ValueError(
-                "Either 'name' or 'schema_model' must be provided for FunctionModel"
-            )
-        return self
+        if self.source_table is not None:
+            self._create_new_index(provider)
+        else:
+            self._validate_existing_index(provider)
 
-    @property
-    def full_name(self) -> str:
-        if self.schema_model:
-            name: str = ""
-            if self.name:
-                name = f".{self.name}"
-            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
-        return self.name
+    def _validate_existing_index(self, provider: Any) -> None:
+        """Validate that an existing index is accessible."""
+        if self.index is None:
+            raise ValueError("index is required for 'use existing' mode")
 
-    def as_resources(self) -> Sequence[DatabricksResource]:
-        resources: list[DatabricksResource] = []
-        if self.name:
-            resources.append(
-                DatabricksFunction(
-                    function_name=self.full_name,
-                    on_behalf_of_user=self.on_behalf_of_user,
-                )
+        if self.index.exists():
+            logger.info(
+                "Vector search index exists and ready",
+                index_name=self.index.full_name,
             )
         else:
-            w: WorkspaceClient = self.workspace_client
-            schema_full_name: str = self.schema_model.full_name
-            functions: Iterator[FunctionInfo] = w.functions.list(
-                catalog_name=self.schema_model.catalog_name,
-                schema_name=self.schema_model.schema_name,
-            )
-            resources.extend(
-                [
-                    DatabricksFunction(
-                        function_name=f"{schema_full_name}.{function.name}",
-                        on_behalf_of_user=self.on_behalf_of_user,
-                    )
-                    for function in functions
-                ]
+            raise ValueError(
+                f"Index '{self.index.full_name}' does not exist. "
+                "Provide 'source_table' to provision it."
             )
 
-        return resources
+    def _create_new_index(self, provider: Any) -> None:
+        """Create a new vector search index from source table."""
+        if self.embedding_source_column is None:
+            raise ValueError("embedding_source_column is required for provisioning")
+        if self.endpoint is None:
+            raise ValueError("endpoint is required for provisioning")
+        if self.index is None:
+            raise ValueError("index is required for provisioning")
 
-    @property
-    def api_scopes(self) -> Sequence[str]:
-        return ["sql.statement-execution"]
+        provider.create_vector_store(self)
 
 
-class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
+class ConnectionModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict()
     name: str
 
@@ -697,34 +1261,58 @@ class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
     ]
 
 
-class
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+class DatabaseModel(IsDatabricksResource):
+    """
+    Configuration for database connections supporting both Databricks Lakebase and standard PostgreSQL.
+
+    Authentication is inherited from IsDatabricksResource. Additionally supports:
+    - user/password: For user-based database authentication
+
+    Connection Types (determined by fields provided):
+    - Databricks Lakebase: Provide `instance_name` (authentication optional, supports ambient auth)
+    - Standard PostgreSQL: Provide `host` (authentication required via user/password)
+
+    Note: `instance_name` and `host` are mutually exclusive. Provide one or the other.
+
+    Example Databricks Lakebase with Service Principal:
+    ```yaml
+    databases:
+      my_lakebase:
+        name: my-database
+        instance_name: my-lakebase-instance
+        service_principal:
+          client_id:
+            env: SERVICE_PRINCIPAL_CLIENT_ID
+          client_secret:
+            scope: my-scope
+            secret: sp-client-secret
+        workspace_host:
+          env: DATABRICKS_HOST
+    ```
+
+    Example Databricks Lakebase with Ambient Authentication:
+    ```yaml
+    databases:
+      my_lakebase:
+        name: my-database
+        instance_name: my-lakebase-instance
+        on_behalf_of_user: true
+    ```
+
+    Example Standard PostgreSQL:
+    ```yaml
+    databases:
+      my_postgres:
+        name: my-database
+        host: my-postgres-host.example.com
+        port: 5432
+        database: my_db
+        user: my_user
+        password:
+          env: PGPASSWORD
+    ```
+    """
 
-class DatabaseModel(BaseModel, IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     instance_name: Optional[str] = None
@@ -737,80 +1325,117 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
     timeout_seconds: Optional[int] = 10
     capacity: Optional[Literal["CU_1", "CU_2"]] = "CU_2"
     node_count: Optional[int] = None
+    # Database-specific auth (user identity for DB connection)
     user: Optional[AnyVariable] = None
     password: Optional[AnyVariable] = None
-    client_id: Optional[AnyVariable] = None
-    client_secret: Optional[AnyVariable] = None
-    workspace_host: Optional[AnyVariable] = None
 
     @property
     def api_scopes(self) -> Sequence[str]:
-        return []
+        return ["database.database-instances"]
+
+    @property
+    def is_lakebase(self) -> bool:
+        """Returns True if this is a Databricks Lakebase connection (instance_name provided)."""
+        return self.instance_name is not None
 
     def as_resources(self) -> Sequence[DatabricksResource]:
-
-
-
-
-
-
+        if self.is_lakebase:
+            return [
+                DatabricksLakebase(
+                    database_instance_name=self.instance_name,
+                    on_behalf_of_user=self.on_behalf_of_user,
+                )
+            ]
+        return []
 
     @model_validator(mode="after")
-    def
-
-            self.instance_name = self.name
+    def validate_connection_type(self) -> Self:
+        """Validate connection configuration based on type.
 
+        - If instance_name is provided: Databricks Lakebase connection
+          (host is optional - will be fetched from API if not provided)
+        - If only host is provided: Standard PostgreSQL connection
+          (must not have instance_name)
+        """
+        if not self.instance_name and not self.host:
+            raise ValueError(
+                "Either instance_name (Databricks Lakebase) or host (PostgreSQL) must be provided."
+            )
         return self
 
     @model_validator(mode="after")
-    def update_user(self):
-        if
+    def update_user(self) -> Self:
+        # Skip if using OBO (passive auth), explicit credentials, or explicit user
+        if self.on_behalf_of_user or self.client_id or self.user or self.pat:
             return self
 
-
-
-
-
-
+        # For standard PostgreSQL, we need explicit user credentials
+        # For Lakebase with no auth, ambient auth is allowed
+        if not self.is_lakebase:
+            # Standard PostgreSQL - try to determine current user for local development
+            try:
+                self.user = self.workspace_client.current_user.me().user_name
+            except Exception as e:
+                logger.warning(
+                    f"Could not determine current user for PostgreSQL database: {e}. "
+                    f"Please provide explicit user credentials."
+                )
+        else:
+            # For Lakebase, try to determine current user but don't fail if we can't
+            try:
+                self.user = self.workspace_client.current_user.me().user_name
+            except Exception:
+                # If we can't determine user and no explicit auth, that's okay
+                # for Lakebase with ambient auth - credentials will be injected at runtime
+                pass
 
         return self
 
     @model_validator(mode="after")
-    def update_host(self):
-
+    def update_host(self) -> Self:
+        # Lakebase uses instance_name directly via databricks_langchain - host not needed
+        if self.is_lakebase:
             return self
 
-
-
-                name=self.instance_name
-            )
-        )
-        self.host = existing_instance.read_write_dns
+        # For standard PostgreSQL, host must be provided by the user
+        # (enforced by validate_connection_type)
         return self
 
     @model_validator(mode="after")
-    def validate_auth_methods(self):
+    def validate_auth_methods(self) -> Self:
         oauth_fields: Sequence[Any] = [
             self.workspace_host,
             self.client_id,
             self.client_secret,
         ]
         has_oauth: bool = all(field is not None for field in oauth_fields)
+        has_user_auth: bool = self.user is not None
+        has_obo: bool = self.on_behalf_of_user is True
+        has_pat: bool = self.pat is not None
 
-
-
+        # Count how many auth methods are configured
+        auth_methods_count: int = sum([has_oauth, has_user_auth, has_obo, has_pat])
 
-        if
+        if auth_methods_count > 1:
             raise ValueError(
-                "Cannot
-                "Please provide
+                "Cannot mix authentication methods. "
+                "Please provide exactly one of: "
+                "on_behalf_of_user=true (for passive auth in model serving), "
+                "OAuth credentials (service_principal or client_id + client_secret + workspace_host), "
+                "PAT (personal access token), "
+                "or user credentials (user)."
             )
 
-
+        # For standard PostgreSQL (host-based), at least one auth method must be configured
+        # For Lakebase (instance_name-based), auth is optional (supports ambient authentication)
+        if not self.is_lakebase and auth_methods_count == 0:
             raise ValueError(
-                "
-                "
-                "
+                "PostgreSQL databases require explicit authentication. "
+                "Please provide one of: "
+                "OAuth credentials (workspace_host, client_id, client_secret), "
+                "service_principal with workspace_host, "
+                "PAT (personal access token), "
+                "or user credentials (user)."
             )
 
         return self
@@ -821,38 +1446,76 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
         Get database connection parameters as a dictionary.
 
         Returns a dict with connection parameters suitable for psycopg ConnectionPool.
-
-
+
+        For Lakebase: Uses Databricks-generated credentials (token-based auth).
+        For standard PostgreSQL: Uses provided user/password credentials.
         """
-
-        from dao_ai.providers.databricks import DatabricksProvider
+        import uuid as _uuid
 
+        from databricks.sdk.service.database import DatabaseCredential
+
+        host: str
+        port: int
+        database: str
         username: str | None = None
+        password_value: str | None = None
+
+        # Resolve host - may need to fetch at runtime for OBO mode
+        host_value: Any = self.host
+        if host_value is None and self.is_lakebase and self.on_behalf_of_user:
+            # Fetch host at runtime for OBO mode
+            existing_instance: DatabaseInstance = (
+                self.workspace_client.database.get_database_instance(
+                    name=self.instance_name
+                )
+            )
+            host_value = existing_instance.read_write_dns
+
+        if host_value is None:
+            instance_or_name = self.instance_name if self.is_lakebase else self.name
+            raise ValueError(
+                f"Database host not configured for {instance_or_name}. "
+                "Please provide 'host' explicitly."
+            )
 
-
-
-
-        username = value_of(self.user)
+        host = value_of(host_value)
+        port = value_of(self.port)
+        database = value_of(self.database)
 
-
-
-
+        if self.is_lakebase:
+            # Lakebase: Use Databricks-generated credentials
+            if self.client_id and self.client_secret and self.workspace_host:
+                username = value_of(self.client_id)
+            elif self.user:
+                username = value_of(self.user)
+            # For OBO mode, no username is needed - the token identity is used
 
-
-
-
-
-
-
+            # Generate Databricks database credential (token)
+            w: WorkspaceClient = self.workspace_client
+            cred: DatabaseCredential = w.database.generate_database_credential(
+                request_id=str(_uuid.uuid4()),
+                instance_names=[self.instance_name],
+            )
+            password_value = cred.token
+        else:
+            # Standard PostgreSQL: Use provided credentials
+            if self.user:
+                username = value_of(self.user)
+            if self.password:
+                password_value = value_of(self.password)
 
-
+            if not username or not password_value:
+                raise ValueError(
+                    f"Standard PostgreSQL databases require both 'user' and 'password'. "
+                    f"Database: {self.name}"
+                )
 
         # Build connection parameters dictionary
         params: dict[str, Any] = {
             "dbname": database,
             "host": host,
             "port": port,
-            "password":
+            "password": password_value,
             "sslmode": "require",
         }
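Sketch of wiring the parameter dict into a psycopg connection pool, which is what the docstring above targets; the `connection_params` property name is assumed (the method signature sits above this hunk), and the connection values are placeholders:

```python
from psycopg_pool import ConnectionPool

from dao_ai.config import DatabaseModel

db = DatabaseModel(
    name="my-database",
    host="my-postgres-host.example.com",  # standard-PostgreSQL mode
    database="my_db",
    user="my_user",
    password="secret",  # real configs would pull this from an env/secret variable
)
pool = ConnectionPool(kwargs=db.connection_params, min_size=1, max_size=4)
with pool.connection() as conn:
    conn.execute("SELECT 1")
```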
@@ -883,11 +1546,86 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
     def create(self, w: WorkspaceClient | None = None) -> None:
         from dao_ai.providers.databricks import DatabricksProvider
 
-
+        # Use provided workspace client or fall back to resource's own workspace_client
+        if w is None:
+            w = self.workspace_client
+        provider: DatabricksProvider = DatabricksProvider(w=w)
         provider.create_lakebase(self)
         provider.create_lakebase_instance_role(self)
 
 
+class GenieLRUCacheParametersModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    capacity: int = 1000
+    time_to_live_seconds: int | None = (
+        60 * 60 * 24
+    )  # 1 day default, None or negative = never expires
+    warehouse: WarehouseModel
+
+
+class GenieSemanticCacheParametersModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    time_to_live_seconds: int | None = (
+        60 * 60 * 24
+    )  # 1 day default, None or negative = never expires
+    similarity_threshold: float = 0.85  # Minimum similarity for question matching (L2 distance converted to 0-1 scale)
+    context_similarity_threshold: float = 0.80  # Minimum similarity for context matching (L2 distance converted to 0-1 scale)
+    question_weight: Optional[float] = (
+        0.6  # Weight for question similarity in combined score (0-1). If not provided, computed as 1 - context_weight
+    )
+    context_weight: Optional[float] = (
+        None  # Weight for context similarity in combined score (0-1). If not provided, computed as 1 - question_weight
+    )
+    embedding_model: str | LLMModel = "databricks-gte-large-en"
+    embedding_dims: int | None = None  # Auto-detected if None
+    database: DatabaseModel
+    warehouse: WarehouseModel
+    table_name: str = "genie_semantic_cache"
+    context_window_size: int = 3  # Number of previous turns to include for context
+    max_context_tokens: int = (
+        2000  # Maximum context length to prevent extremely long embeddings
+    )
+
+    @model_validator(mode="after")
+    def compute_and_validate_weights(self) -> Self:
+        """
+        Compute missing weight and validate that question_weight + context_weight = 1.0.
+
+        Either question_weight or context_weight (or both) can be provided.
+        The missing one will be computed as 1.0 - provided_weight.
+        If both are provided, they must sum to 1.0.
+        """
+        if self.question_weight is None and self.context_weight is None:
+            # Both missing - use defaults
+            self.question_weight = 0.6
+            self.context_weight = 0.4
+        elif self.question_weight is None:
+            # Compute question_weight from context_weight
+            if not (0.0 <= self.context_weight <= 1.0):
+                raise ValueError(
+                    f"context_weight must be between 0.0 and 1.0, got {self.context_weight}"
+                )
+            self.question_weight = 1.0 - self.context_weight
+        elif self.context_weight is None:
+            # Compute context_weight from question_weight
+            if not (0.0 <= self.question_weight <= 1.0):
+                raise ValueError(
+                    f"question_weight must be between 0.0 and 1.0, got {self.question_weight}"
+                )
+            self.context_weight = 1.0 - self.question_weight
+        else:
+            # Both provided - validate they sum to 1.0
+            total_weight = self.question_weight + self.context_weight
+            if not abs(total_weight - 1.0) < 0.0001:  # Allow small floating point error
+                raise ValueError(
+                    f"question_weight ({self.question_weight}) + context_weight ({self.context_weight}) "
+                    f"must equal 1.0 (got {total_weight}). These weights determine the relative importance "
+                    f"of question vs context similarity in the combined score."
+                )
+
+        return self
+
+
 class SearchParametersModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     num_results: Optional[int] = 10
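The new semantic cache blends two similarity signals into one score, and the validator above only guarantees the weights sum to 1.0. A small sketch of how the weight resolution and a combined score would behave under those constraints; the `combined_score` helper is illustrative (the cache implementation lives in `dao_ai/genie/cache/semantic.py`):

```python
def resolve_weights(
    question_weight: float | None, context_weight: float | None
) -> tuple[float, float]:
    """Mirror of compute_and_validate_weights: fill in the missing weight."""
    if question_weight is None and context_weight is None:
        return 0.6, 0.4  # defaults from the model above
    if question_weight is None:
        return 1.0 - context_weight, context_weight
    if context_weight is None:
        return question_weight, 1.0 - question_weight
    if abs(question_weight + context_weight - 1.0) >= 0.0001:
        raise ValueError("question_weight + context_weight must equal 1.0")
    return question_weight, context_weight


def combined_score(
    question_sim: float,
    context_sim: float,
    question_weight: float | None = None,
    context_weight: float | None = None,
) -> float:
    qw, cw = resolve_weights(question_weight, context_weight)
    return qw * question_sim + cw * context_sim


# Providing only context_weight implies question_weight = 0.75
assert resolve_weights(None, 0.25) == (0.75, 0.25)
assert abs(combined_score(0.9, 0.5) - (0.6 * 0.9 + 0.4 * 0.5)) < 1e-9
```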
@@ -918,11 +1656,13 @@ class RerankParametersModel(BaseModel):
         top_n: 5  # Return top 5 after reranking
         ```
 
-    Available models (
-    - "ms-marco-TinyBERT-L-2-v2" (
-    - "ms-marco-MiniLM-L-
-    - "
-    - "
+    Available models (see https://github.com/PrithivirajDamodaran/FlashRank):
+    - "ms-marco-TinyBERT-L-2-v2" (~4MB, fastest)
+    - "ms-marco-MiniLM-L-12-v2" (~34MB, best cross-encoder, default)
+    - "rank-T5-flan" (~110MB, best non cross-encoder)
+    - "ms-marco-MultiBERT-L-12" (~150MB, multilingual 100+ languages)
+    - "ce-esci-MiniLM-L12-v2" (e-commerce optimized, Amazon ESCI)
+    - "miniReranker_arabic_v1" (Arabic language)
     """
 
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -936,8 +1676,8 @@ class RerankParametersModel(BaseModel):
         description="Number of documents to return after reranking. If None, uses search_parameters.num_results.",
     )
     cache_dir: Optional[str] = Field(
-        default="/
-        description="Directory to cache downloaded model weights.",
+        default="~/.dao_ai/cache/flashrank",
+        description="Directory to cache downloaded model weights. Supports tilde expansion (e.g., ~/.dao_ai).",
     )
     columns: Optional[list[str]] = Field(
         default_factory=list, description="Columns to rerank using DatabricksReranker"
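The default `cache_dir` now uses a home-relative path, and the description promises tilde expansion. A quick sketch of how consuming code would expand it before handing it to the reranker, assuming `Path.expanduser` is the mechanism (the field itself stores the raw string):

```python
from pathlib import Path

cache_dir = "~/.dao_ai/cache/flashrank"  # the new default above
resolved = Path(cache_dir).expanduser()

# "~" resolves against the current user's home directory
assert not str(resolved).startswith("~")
print(resolved)  # e.g. /home/alice/.dao_ai/cache/flashrank
```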
@@ -957,14 +1697,14 @@ class RetrieverModel(BaseModel):
     )
 
     @model_validator(mode="after")
-    def set_default_columns(self):
+    def set_default_columns(self) -> Self:
         if not self.columns:
             columns: Sequence[str] = self.vector_store.columns
             self.columns = columns
         return self
 
     @model_validator(mode="after")
-    def set_default_reranker(self):
+    def set_default_reranker(self) -> Self:
         """Convert bool to ReRankParametersModel with defaults."""
         if isinstance(self.rerank, bool) and self.rerank:
             self.rerank = RerankParametersModel()
@@ -978,28 +1718,47 @@ class FunctionType(str, Enum):
     MCP = "mcp"
 
 
-class 
-    """
+class HumanInTheLoopModel(BaseModel):
+    """
+    Configuration for Human-in-the-Loop tool approval.
 
-
-
-    RESPONSE = "response"
-    DECLINE = "decline"
+    This model configures when and how tools require human approval before execution.
+    It maps to LangChain's HumanInTheLoopMiddleware.
 
+    LangChain supports three decision types:
+    - "approve": Execute tool with original arguments
+    - "edit": Modify arguments before execution
+    - "reject": Skip execution with optional feedback message
+    """
 
-class HumanInTheLoopModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-
-
-
-
-
-
-
-
+
+    review_prompt: Optional[str] = Field(
+        default=None,
+        description="Message shown to the reviewer when approval is requested",
+    )
+
+    allowed_decisions: list[Literal["approve", "edit", "reject"]] = Field(
+        default_factory=lambda: ["approve", "edit", "reject"],
+        description="List of allowed decision types for this tool",
     )
-
-
+
+    @model_validator(mode="after")
+    def validate_and_normalize_decisions(self) -> Self:
+        """Validate and normalize allowed decisions."""
+        if not self.allowed_decisions:
+            raise ValueError("At least one decision type must be allowed")
+
+        # Remove duplicates while preserving order
+        seen = set()
+        unique_decisions = []
+        for decision in self.allowed_decisions:
+            if decision not in seen:
+                seen.add(decision)
+                unique_decisions.append(decision)
+        self.allowed_decisions = unique_decisions
+
+        return self
 
 
 class BaseFunctionModel(ABC, BaseModel):
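The `validate_and_normalize_decisions` validator above de-duplicates `allowed_decisions` while keeping each decision's first position. The same order-preserving dedup in isolation (plain Python, not the pydantic model):

```python
def normalize_decisions(decisions: list[str]) -> list[str]:
    if not decisions:
        raise ValueError("At least one decision type must be allowed")
    seen: set[str] = set()
    unique: list[str] = []
    for decision in decisions:
        if decision not in seen:
            seen.add(decision)
            unique.append(decision)
    return unique


assert normalize_decisions(["approve", "edit", "approve", "reject"]) == [
    "approve", "edit", "reject",
]
```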
@@ -1008,7 +1767,6 @@ class BaseFunctionModel(ABC, BaseModel):
         discriminator="type",
     )
     type: FunctionType
-    name: str
     human_in_the_loop: Optional[HumanInTheLoopModel] = None
 
     @abstractmethod
@@ -1025,6 +1783,7 @@ class BaseFunctionModel(ABC, BaseModel):
 class PythonFunctionModel(BaseFunctionModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     type: Literal[FunctionType.PYTHON] = FunctionType.PYTHON
+    name: str
 
     @property
     def full_name(self) -> str:
@@ -1038,8 +1797,9 @@ class PythonFunctionModel(BaseFunctionModel, HasFullName):
 
 class FactoryFunctionModel(BaseFunctionModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    args: Optional[dict[str, Any]] = Field(default_factory=dict)
     type: Literal[FunctionType.FACTORY] = FunctionType.FACTORY
+    name: str
+    args: Optional[dict[str, Any]] = Field(default_factory=dict)
 
     @property
     def full_name(self) -> str:
@@ -1051,7 +1811,7 @@ class FactoryFunctionModel(BaseFunctionModel, HasFullName):
         return [create_factory_tool(self, **kwargs)]
 
     @model_validator(mode="after")
-    def update_args(self):
+    def update_args(self) -> Self:
         for key, value in self.args.items():
             self.args[key] = value_of(value)
         return self
@@ -1062,7 +1822,16 @@ class TransportType(str, Enum):
     STDIO = "stdio"
 
 
-class McpFunctionModel(BaseFunctionModel, 
+class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
+    """
+    MCP Function Model with authentication inherited from IsDatabricksResource.
+
+    Authentication for MCP connections uses the same options as other resources:
+    - Service Principal (client_id + client_secret + workspace_host)
+    - PAT (pat + workspace_host)
+    - OBO (on_behalf_of_user)
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     type: Literal[FunctionType.MCP] = FunctionType.MCP
     transport: TransportType = TransportType.STREAMABLE_HTTP
@@ -1070,10 +1839,7 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
     url: Optional[AnyVariable] = None
     headers: dict[str, AnyVariable] = Field(default_factory=dict)
     args: list[str] = Field(default_factory=list)
-
-    client_id: Optional[AnyVariable] = None
-    client_secret: Optional[AnyVariable] = None
-    workspace_host: Optional[AnyVariable] = None
+    # MCP-specific fields
     connection: Optional[ConnectionModel] = None
     functions: Optional[SchemaModel] = None
     genie_room: Optional[GenieRoomModel] = None
@@ -1081,35 +1847,55 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
     vector_search: Optional[VectorStoreModel] = None
 
     @property
-    def 
-
+    def api_scopes(self) -> Sequence[str]:
+        """API scopes for MCP connections."""
+        return [
+            "serving.serving-endpoints",
+            "mcp.genie",
+            "mcp.functions",
+            "mcp.vectorsearch",
+            "mcp.external",
+        ]
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        """MCP functions don't declare static resources."""
+        return []
 
     def _get_workspace_host(self) -> str:
         """
         Get the workspace host, either from config or from workspace client.
 
         If connection is provided, uses its workspace client.
-        Otherwise, falls back to 
+        Otherwise, falls back to the default Databricks host.
 
         Returns:
-            str: The workspace host URL without trailing slash
+            str: The workspace host URL with https:// scheme and without trailing slash
         """
-        from 
+        from dao_ai.utils import get_default_databricks_host, normalize_host
 
         # Try to get workspace_host from config
         workspace_host: str | None = (
-            value_of(self.workspace_host)
+            normalize_host(value_of(self.workspace_host))
+            if self.workspace_host
+            else None
         )
 
         # If no workspace_host in config, get it from workspace client
        if not workspace_host:
            # Use connection's workspace client if available
            if self.connection:
-                workspace_host = 
+                workspace_host = normalize_host(
+                    self.connection.workspace_client.config.host
+                )
            else:
-                # 
-
-
+                # get_default_databricks_host already normalizes the host
+                workspace_host = get_default_databricks_host()
+
+        if not workspace_host:
+            raise ValueError(
+                "Could not determine workspace host. "
+                "Please set workspace_host in config or DATABRICKS_HOST environment variable."
+            )
 
         # Remove trailing slash
         return workspace_host.rstrip("/")
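`_get_workspace_host` now normalizes whatever it finds and fails loudly when nothing is configured. A standalone sketch of that precedence chain; the `normalize` helper here only approximates `dao_ai.utils.normalize_host` (adding a scheme and trimming the trailing slash), which is an assumption based on the updated docstring:

```python
def normalize(host: str) -> str:
    # Assumed behavior: ensure https:// scheme, drop trailing slash
    if not host.startswith(("http://", "https://")):
        host = f"https://{host}"
    return host.rstrip("/")


def resolve_workspace_host(
    configured: str | None,
    connection_host: str | None,
    default_host: str | None,
) -> str:
    host = normalize(configured) if configured else None
    if not host:
        host = normalize(connection_host) if connection_host else default_host
    if not host:
        raise ValueError(
            "Could not determine workspace host. "
            "Please set workspace_host in config or DATABRICKS_HOST environment variable."
        )
    return host.rstrip("/")


assert resolve_workspace_host(
    "my-workspace.cloud.databricks.com/", None, None
) == "https://my-workspace.cloud.databricks.com"
```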
@@ -1234,74 +2020,132 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
             self.headers[key] = value_of(value)
         return self
 
-
-
-        oauth_fields: Sequence[Any] = [
-            self.client_id,
-            self.client_secret,
-        ]
-        has_oauth: bool = all(field is not None for field in oauth_fields)
-
-        pat_fields: Sequence[Any] = [self.pat]
-        has_user_auth: bool = all(field is not None for field in pat_fields)
+    def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
+        from dao_ai.tools import create_mcp_tools
 
-
-            raise ValueError(
-                "Cannot use both OAuth and user authentication methods. "
-                "Please provide either OAuth credentials or user credentials."
-            )
+        return create_mcp_tools(self)
 
-        # Note: workspace_host is optional - it will be derived from workspace client if not provided
 
-
+class UnityCatalogFunctionModel(BaseFunctionModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    type: Literal[FunctionType.UNITY_CATALOG] = FunctionType.UNITY_CATALOG
+    resource: FunctionModel
+    partial_args: Optional[dict[str, AnyVariable]] = Field(default_factory=dict)
 
     def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
-        from dao_ai.tools import 
+        from dao_ai.tools import create_uc_tools
 
-        return 
+        return create_uc_tools(self)
+
+
+AnyTool: TypeAlias = (
+    Union[
+        PythonFunctionModel,
+        FactoryFunctionModel,
+        UnityCatalogFunctionModel,
+        McpFunctionModel,
+    ]
+    | str
+)
+
+
+class ToolModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    function: AnyTool
 
 
-class 
+class PromptModel(BaseModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
-
-
+    name: str
+    description: Optional[str] = None
+    default_template: Optional[str] = None
+    alias: Optional[str] = None
+    version: Optional[int] = None
+    tags: Optional[dict[str, Any]] = Field(default_factory=dict)
+    auto_register: bool = Field(
+        default=False,
+        description="Whether to automatically register the default_template to the prompt registry. "
+        "If False, the prompt will only be loaded from the registry (never created/updated). "
+        "Defaults to False.",
+    )
+
+    @property
+    def template(self) -> str:
+        from dao_ai.providers.databricks import DatabricksProvider
+
+        provider: DatabricksProvider = DatabricksProvider()
+        prompt_version = provider.get_prompt(self)
+        return prompt_version.to_single_brace_format()
+
+    @property
+    def full_name(self) -> str:
+        prompt_name: str = self.name
+        if self.schema_model:
+            prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
+        return prompt_name
 
     @property
-    def 
-
-        return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}.{self.name}"
-        return self.name
+    def uri(self) -> str:
+        prompt_uri: str = f"prompts:/{self.full_name}"
 
-
-
+        if self.alias:
+            prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
+        elif self.version:
+            prompt_uri = f"prompts:/{self.full_name}/{self.version}"
+        else:
+            prompt_uri = f"prompts:/{self.full_name}@latest"
 
-        return 
+        return prompt_uri
 
+    def as_prompt(self) -> PromptVersion:
+        prompt_version: PromptVersion = load_prompt(self.uri)
+        return prompt_version
 
-
-
-
-
-
-        McpFunctionModel,
-    ]
-    | str
-)
+    @model_validator(mode="after")
+    def validate_mutually_exclusive(self) -> Self:
+        if self.alias and self.version:
+            raise ValueError("Cannot specify both alias and version")
+        return self
 
 
-class 
+class GuardrailModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
-
+    model: str | LLMModel
+    prompt: str | PromptModel
+    num_retries: Optional[int] = 3
+
+    @model_validator(mode="after")
+    def validate_llm_model(self) -> Self:
+        if isinstance(self.model, str):
+            self.model = LLMModel(name=self.model)
+        return self
 
 
-class 
+class MiddlewareModel(BaseModel):
+    """Configuration for middleware that can be applied to agents.
+
+    Middleware is defined at the AppConfig level and can be referenced by name
+    in agent configurations using YAML anchors for reusability.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    name: str
-
-
-
+    name: str = Field(
+        description="Fully qualified name of the middleware factory function"
+    )
+    args: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Arguments to pass to the middleware factory function",
+    )
+
+    @model_validator(mode="after")
+    def resolve_args(self) -> Self:
+        """Resolve any variable references in args."""
+        for key, value in self.args.items():
+            self.args[key] = value_of(value)
+        return self
 
 
 class StorageType(str, Enum):
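`PromptModel.uri` resolves to one of three MLflow prompt-registry URI shapes depending on whether an alias or a pinned version is set, with the two made mutually exclusive by the validator. The same selection logic in isolation (names are illustrative):

```python
def prompt_uri(
    full_name: str, alias: str | None = None, version: int | None = None
) -> str:
    if alias and version:
        raise ValueError("Cannot specify both alias and version")
    if alias:
        return f"prompts:/{full_name}@{alias}"
    if version:
        return f"prompts:/{full_name}/{version}"
    return f"prompts:/{full_name}@latest"


assert prompt_uri("main.agents.greeting") == "prompts:/main.agents.greeting@latest"
assert prompt_uri("main.agents.greeting", alias="prod") == "prompts:/main.agents.greeting@prod"
assert prompt_uri("main.agents.greeting", version=3) == "prompts:/main.agents.greeting/3"
```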
@@ -1312,14 +2156,12 @@ class StorageType(str, Enum):
 class CheckpointerModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
-    type: Optional[StorageType] = StorageType.MEMORY
     database: Optional[DatabaseModel] = None
 
-    @
-    def 
-
-
-        return self
+    @property
+    def storage_type(self) -> StorageType:
+        """Infer storage type from database presence."""
+        return StorageType.POSTGRES if self.database else StorageType.MEMORY
 
     def as_checkpointer(self) -> BaseCheckpointSaver:
         from dao_ai.memory import CheckpointManager
@@ -1335,16 +2177,14 @@ class StoreModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     embedding_model: Optional[LLMModel] = None
-    type: Optional[StorageType] = StorageType.MEMORY
     dims: Optional[int] = 1536
     database: Optional[DatabaseModel] = None
     namespace: Optional[str] = None
 
-    @
-    def 
-
-
-        return self
+    @property
+    def storage_type(self) -> StorageType:
+        """Infer storage type from database presence."""
+        return StorageType.POSTGRES if self.database else StorageType.MEMORY
 
     def as_store(self) -> BaseStore:
         from dao_ai.memory import StoreManager
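Both `CheckpointerModel` and `StoreModel` drop their explicit `type` field: the storage backend is now inferred purely from whether a `database` is attached. A minimal pydantic sketch of the same property; the field names mirror the models above, and the enum values "memory"/"postgres" are assumptions:

```python
from enum import Enum
from typing import Optional

from pydantic import BaseModel


class StorageType(str, Enum):
    MEMORY = "memory"      # assumed value
    POSTGRES = "postgres"  # assumed value


class CheckpointerSketch(BaseModel):
    name: str
    database: Optional[dict] = None  # stand-in for DatabaseModel

    @property
    def storage_type(self) -> StorageType:
        # Same truthiness test as the property above
        return StorageType.POSTGRES if self.database else StorageType.MEMORY


assert CheckpointerSketch(name="cp").storage_type is StorageType.MEMORY
assert CheckpointerSketch(
    name="cp", database={"instance": "lakebase"}
).storage_type is StorageType.POSTGRES
```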
@@ -1362,56 +2202,158 @@ class MemoryModel(BaseModel):
 FunctionHook: TypeAlias = PythonFunctionModel | FactoryFunctionModel | str
 
 
-class 
+class ResponseFormatModel(BaseModel):
+    """
+    Configuration for structured response formats.
+
+    The response_schema field accepts either a type or a string:
+    - Type (Pydantic model, dataclass, etc.): Used directly for structured output
+    - String: First attempts to load as a fully qualified type name, falls back to JSON schema string
+
+    This unified approach simplifies the API while maintaining flexibility.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-
-
-
-
-
-
-
+    use_tool: Optional[bool] = Field(
+        default=None,
+        description=(
+            "Strategy for structured output: "
+            "None (default) = auto-detect from model capabilities, "
+            "False = force ProviderStrategy (native), "
+            "True = force ToolStrategy (function calling)"
+        ),
+    )
+    response_schema: Optional[str | type] = Field(
+        default=None,
+        description="Type or string for response format. String attempts FQN import, falls back to JSON schema.",
+    )
 
-
-
-
+    def as_strategy(self) -> ProviderStrategy | ToolStrategy | None:
+        """
+        Convert response_schema to the appropriate LangChain strategy.
 
-
-
-
+        Returns:
+        - None if no response_schema configured
+        - Raw schema/type for auto-detection (when use_tool=None)
+        - ToolStrategy wrapping the schema (when use_tool=True)
+        - ProviderStrategy wrapping the schema (when use_tool=False)
 
-
-
-
-        if self.schema_model:
-            prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
-        return prompt_name
+        Raises:
+            ValueError: If response_schema is a JSON schema string that cannot be parsed
+        """
 
-
-
-        prompt_uri: str = f"prompts:/{self.full_name}"
+        if self.response_schema is None:
+            return None
 
-
-            prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
-        elif self.version:
-            prompt_uri = f"prompts:/{self.full_name}/{self.version}"
-        else:
-            prompt_uri = f"prompts:/{self.full_name}@latest"
+        schema = self.response_schema
 
-
+        # Handle type schemas (Pydantic, dataclass, etc.)
+        if self.is_type_schema:
+            if self.use_tool is None:
+                # Auto-detect: Pass schema directly, let LangChain decide
+                return schema
+            elif self.use_tool is True:
+                # Force ToolStrategy (function calling)
+                return ToolStrategy(schema)
+            else:  # use_tool is False
+                # Force ProviderStrategy (native structured output)
+                return ProviderStrategy(schema)
 
-
-
-
+        # Handle JSON schema strings
+        elif self.is_json_schema:
+            import json
+
+            try:
+                schema_dict = json.loads(schema)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Invalid JSON schema string: {e}") from e
+
+            # Apply same use_tool logic as type schemas
+            if self.use_tool is None:
+                # Auto-detect
+                return schema_dict
+            elif self.use_tool is True:
+                # Force ToolStrategy
+                return ToolStrategy(schema_dict)
+            else:  # use_tool is False
+                # Force ProviderStrategy
+                return ProviderStrategy(schema_dict)
+
+        return None
 
     @model_validator(mode="after")
-    def 
-
-
-
+    def validate_response_schema(self) -> Self:
+        """
+        Validate and convert response_schema.
+
+        Processing logic:
+        1. If None: no response format specified
+        2. If type: use directly as structured output type
+        3. If str: try to load as FQN using type_from_fqn
+           - Success: response_schema becomes the loaded type
+           - Failure: keep as string (treated as JSON schema)
+
+        After validation, response_schema is one of:
+        - None (no schema)
+        - type (use for structured output)
+        - str (JSON schema)
+
+        Returns:
+            Self with validated response_schema
+        """
+        if self.response_schema is None:
+            return self
+
+        # If already a type, return
+        if isinstance(self.response_schema, type):
+            return self
+
+        # If it's a string, try to load as type, fallback to json_schema
+        if isinstance(self.response_schema, str):
+            from dao_ai.utils import type_from_fqn
+
+            try:
+                resolved_type = type_from_fqn(self.response_schema)
+                self.response_schema = resolved_type
+                logger.debug(
+                    f"Resolved response_schema string to type: {resolved_type}"
+                )
+                return self
+            except (ValueError, ImportError, AttributeError, TypeError) as e:
+                # Keep as string - it's a JSON schema
+                logger.debug(
+                    f"Could not resolve '{self.response_schema}' as type: {e}. "
+                    f"Treating as JSON schema string."
+                )
+                return self
+
+        # Invalid type
+        raise ValueError(
+            f"response_schema must be None, type, or str, got {type(self.response_schema)}"
+        )
+
+    @property
+    def is_type_schema(self) -> bool:
+        """Returns True if response_schema is a type (not JSON schema string)."""
+        return isinstance(self.response_schema, type)
+
+    @property
+    def is_json_schema(self) -> bool:
+        """Returns True if response_schema is a JSON schema string (not a type)."""
+        return isinstance(self.response_schema, str)
 
 
 class AgentModel(BaseModel):
+    """
+    Configuration model for an agent in the DAO AI framework.
+
+    Agents combine an LLM with tools and middleware to create systems that can
+    reason about tasks, decide which tools to use, and iteratively work towards solutions.
+
+    Middleware replaces the previous pre_agent_hook and post_agent_hook patterns,
+    providing a more flexible and composable way to customize agent behavior.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     description: Optional[str] = None
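`as_strategy` maps the (`use_tool`, schema kind) pair onto LangChain strategies. A compact sketch of the same dispatch, with stub dataclasses instead of LangChain's `ToolStrategy`/`ProviderStrategy` so the example stays dependency-free:

```python
import json
from dataclasses import dataclass
from typing import Any


@dataclass
class ToolStrategyStub:
    schema: Any


@dataclass
class ProviderStrategyStub:
    schema: Any


def as_strategy(schema: type | str | None, use_tool: bool | None) -> Any:
    if schema is None:
        return None
    if isinstance(schema, str):
        # JSON schema string: parse first, then apply the same dispatch
        schema = json.loads(schema)
    if use_tool is None:
        return schema  # auto-detect: let the framework decide
    if use_tool:
        return ToolStrategyStub(schema)
    return ProviderStrategyStub(schema)


parsed = as_strategy('{"type": "object"}', use_tool=True)
assert isinstance(parsed, ToolStrategyStub) and parsed.schema == {"type": "object"}
assert as_strategy(None, use_tool=True) is None
```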
@@ -1420,9 +2362,43 @@ class AgentModel(BaseModel):
     guardrails: list[GuardrailModel] = Field(default_factory=list)
     prompt: Optional[str | PromptModel] = None
     handoff_prompt: Optional[str] = None
-
-
-
+    middleware: list[MiddlewareModel] = Field(
+        default_factory=list,
+        description="List of middleware to apply to this agent",
+    )
+    response_format: Optional[ResponseFormatModel | type | str] = None
+
+    @model_validator(mode="after")
+    def validate_response_format(self) -> Self:
+        """
+        Validate and normalize response_format.
+
+        Accepts:
+        - None (no response format)
+        - ResponseFormatModel (already validated)
+        - type (Pydantic model, dataclass, etc.) - converts to ResponseFormatModel
+        - str (FQN or json_schema) - converts to ResponseFormatModel (smart fallback)
+
+        ResponseFormatModel handles the logic of trying FQN import and falling back to JSON schema.
+        """
+        if self.response_format is None or isinstance(
+            self.response_format, ResponseFormatModel
+        ):
+            return self
+
+        # Convert type or str to ResponseFormatModel
+        # ResponseFormatModel's validator will handle the smart type loading and fallback
+        if isinstance(self.response_format, (type, str)):
+            self.response_format = ResponseFormatModel(
+                response_schema=self.response_format
+            )
+            return self
+
+        # Invalid type
+        raise ValueError(
+            f"response_format must be None, ResponseFormatModel, type, or str, "
+            f"got {type(self.response_format)}"
+        )
 
     def as_runnable(self) -> RunnableLike:
         from dao_ai.nodes import create_agent_node
@@ -1441,12 +2417,20 @@ class SupervisorModel(BaseModel):
     model: LLMModel
     tools: list[ToolModel] = Field(default_factory=list)
     prompt: Optional[str] = None
+    middleware: list[MiddlewareModel] = Field(
+        default_factory=list,
+        description="List of middleware to apply to the supervisor",
+    )
 
 
 class SwarmModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     model: LLMModel
     default_agent: Optional[AgentModel | str] = None
+    middleware: list[MiddlewareModel] = Field(
+        default_factory=list,
+        description="List of middleware to apply to all agents in the swarm",
+    )
     handoffs: Optional[dict[str, Optional[list[AgentModel | str]]]] = Field(
         default_factory=dict
     )
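`SupervisorModel` and `SwarmModel` gain the same `middleware` list that `AgentModel` received above, so one `MiddlewareModel` entry can be shared across all three. A sketch of how a factory-style entry would resolve by dotted name, mirroring `MiddlewareModel(name=..., args=...)`; the example target is a stand-in, not a real dao_ai middleware factory:

```python
import importlib
from typing import Any, Callable


def resolve_middleware(name: str, args: dict[str, Any]) -> Any:
    """Import a factory by fully qualified dotted name and call it with args."""
    module_name, _, attr = name.rpartition(".")
    factory: Callable[..., Any] = getattr(importlib.import_module(module_name), attr)
    return factory(**args)


# Any importable callable works; "fractions.Fraction" is just a stand-in
# for a real factory such as one exported by dao_ai.middleware.
instance = resolve_middleware("fractions.Fraction", {"numerator": 1, "denominator": 3})
print(instance)  # Fraction(1, 3)
```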
@@ -1459,7 +2443,7 @@ class OrchestrationModel(BaseModel):
     memory: Optional[MemoryModel] = None
 
     @model_validator(mode="after")
-    def validate_mutually_exclusive(self):
+    def validate_mutually_exclusive(self) -> Self:
         if self.supervisor is not None and self.swarm is not None:
             raise ValueError("Cannot specify both supervisor and swarm")
         if self.supervisor is None and self.swarm is None:
@@ -1489,9 +2473,21 @@ class Entitlement(str, Enum):
 
 class AppPermissionModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    principals: list[str] = Field(default_factory=list)
+    principals: list[ServicePrincipalModel | str] = Field(default_factory=list)
     entitlements: list[Entitlement]
 
+    @model_validator(mode="after")
+    def resolve_principals(self) -> Self:
+        """Resolve ServicePrincipalModel objects to their client_id."""
+        resolved: list[str] = []
+        for principal in self.principals:
+            if isinstance(principal, ServicePrincipalModel):
+                resolved.append(value_of(principal.client_id))
+            else:
+                resolved.append(principal)
+        self.principals = resolved
+        return self
+
 
 class LogLevel(str, Enum):
     TRACE = "TRACE"
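`AppPermissionModel.principals` now accepts full `ServicePrincipalModel` objects and flattens them to client-id strings at validation time. The flattening in isolation (the dataclass is a stand-in for the pydantic model):

```python
from dataclasses import dataclass


@dataclass
class ServicePrincipal:
    client_id: str


def resolve_principals(principals: list[ServicePrincipal | str]) -> list[str]:
    resolved: list[str] = []
    for principal in principals:
        if isinstance(principal, ServicePrincipal):
            resolved.append(principal.client_id)
        else:
            resolved.append(principal)
    return resolved


assert resolve_principals([ServicePrincipal("sp-123"), "user@example.com"]) == [
    "sp-123", "user@example.com",
]
```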
@@ -1552,6 +2548,28 @@ class ChatPayload(BaseModel):
 
         return self
 
+    @model_validator(mode="after")
+    def ensure_thread_id(self) -> "ChatPayload":
+        """Ensure thread_id or conversation_id is present in configurable, generating UUID if needed."""
+        import uuid
+
+        if self.custom_inputs is None:
+            self.custom_inputs = {}
+
+        # Get or create configurable section
+        configurable: dict[str, Any] = self.custom_inputs.get("configurable", {})
+
+        # Check if thread_id or conversation_id exists
+        has_thread_id = configurable.get("thread_id") is not None
+        has_conversation_id = configurable.get("conversation_id") is not None
+
+        # If neither is provided, generate a UUID for conversation_id
+        if not has_thread_id and not has_conversation_id:
+            configurable["conversation_id"] = str(uuid.uuid4())
+            self.custom_inputs["configurable"] = configurable
+
+        return self
+
     def as_messages(self) -> Sequence[BaseMessage]:
         return messages_from_dict(
             [{"type": m.role, "content": m.content} for m in self.messages]
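`ensure_thread_id` guarantees every payload can be checkpointed: if the caller supplies neither `thread_id` nor `conversation_id`, a fresh UUID is generated. The same defaulting logic on a plain dict:

```python
import uuid
from typing import Any


def ensure_conversation_id(custom_inputs: dict[str, Any] | None) -> dict[str, Any]:
    custom_inputs = custom_inputs or {}
    configurable: dict[str, Any] = custom_inputs.get("configurable", {})
    if configurable.get("thread_id") is None and configurable.get("conversation_id") is None:
        configurable["conversation_id"] = str(uuid.uuid4())
        custom_inputs["configurable"] = configurable
    return custom_inputs


inputs = ensure_conversation_id(None)
assert "conversation_id" in inputs["configurable"]

# An explicit thread_id is left untouched
explicit = ensure_conversation_id({"configurable": {"thread_id": "t-1"}})
assert "conversation_id" not in explicit["configurable"]
```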
@@ -1567,25 +2585,44 @@ class ChatPayload(BaseModel):
 
 
 class ChatHistoryModel(BaseModel):
+    """
+    Configuration for chat history summarization.
+
+    Attributes:
+        model: The LLM to use for generating summaries.
+        max_tokens: Maximum tokens to keep after summarization (the "keep" threshold).
+            After summarization, recent messages totaling up to this many tokens are preserved.
+        max_tokens_before_summary: Token threshold that triggers summarization.
+            When conversation exceeds this, summarization runs. Mutually exclusive with
+            max_messages_before_summary. If neither is set, defaults to max_tokens * 10.
+        max_messages_before_summary: Message count threshold that triggers summarization.
+            When conversation exceeds this many messages, summarization runs.
+            Mutually exclusive with max_tokens_before_summary.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     model: LLMModel
-    max_tokens: int = 
-
-
-
-
-
-
-
-
-
-
-
+    max_tokens: int = Field(
+        default=2048,
+        gt=0,
+        description="Maximum tokens to keep after summarization",
+    )
+    max_tokens_before_summary: Optional[int] = Field(
+        default=None,
+        gt=0,
+        description="Token threshold that triggers summarization",
+    )
+    max_messages_before_summary: Optional[int] = Field(
+        default=None,
+        gt=0,
+        description="Message count threshold that triggers summarization",
+    )
 
 
 class AppModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
+    service_principal: Optional[ServicePrincipalModel] = None
     description: Optional[str] = None
     log_level: Optional[LogLevel] = "WARNING"
     registered_model: RegisteredModelModel
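The docstring above fixes the semantics of the three knobs: `max_tokens` is the keep budget after summarization, while the two `*_before_summary` fields are mutually exclusive triggers, with a `max_tokens * 10` fallback. A sketch of the trigger decision under those documented rules:

```python
def should_summarize(
    token_count: int,
    message_count: int,
    max_tokens: int = 2048,
    max_tokens_before_summary: int | None = None,
    max_messages_before_summary: int | None = None,
) -> bool:
    if max_tokens_before_summary and max_messages_before_summary:
        raise ValueError("The two *_before_summary thresholds are mutually exclusive")
    if max_messages_before_summary is not None:
        return message_count > max_messages_before_summary
    # Documented fallback: trigger threshold defaults to max_tokens * 10
    threshold = max_tokens_before_summary or max_tokens * 10
    return token_count > threshold


assert should_summarize(token_count=30_000, message_count=12) is True   # 30000 > 20480
assert should_summarize(token_count=10_000, message_count=12) is False  # 10000 <= 20480
```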
@@ -1606,23 +2643,54 @@ class AppModel(BaseModel):
     shutdown_hooks: Optional[FunctionHook | list[FunctionHook]] = Field(
         default_factory=list
     )
-    message_hooks: Optional[FunctionHook | list[FunctionHook]] = Field(
-        default_factory=list
-    )
     input_example: Optional[ChatPayload] = None
     chat_history: Optional[ChatHistoryModel] = None
     code_paths: list[str] = Field(default_factory=list)
     pip_requirements: list[str] = Field(default_factory=list)
+    python_version: Optional[str] = Field(
+        default="3.12",
+        description="Python version for Model Serving deployment. Defaults to 3.12 "
+        "which is supported by Databricks Model Serving. This allows deploying from "
+        "environments with different Python versions (e.g., Databricks Apps with 3.11).",
+    )
+
+    @model_validator(mode="after")
+    def set_databricks_env_vars(self) -> Self:
+        """Set Databricks environment variables for Model Serving.
+
+        Sets DATABRICKS_HOST, DATABRICKS_CLIENT_ID, and DATABRICKS_CLIENT_SECRET.
+        Values explicitly provided in environment_vars take precedence.
+        """
+        from dao_ai.utils import get_default_databricks_host
+
+        # Set DATABRICKS_HOST if not already provided
+        if "DATABRICKS_HOST" not in self.environment_vars:
+            host: str | None = get_default_databricks_host()
+            if host:
+                self.environment_vars["DATABRICKS_HOST"] = host
+
+        # Set service principal credentials if provided
+        if self.service_principal is not None:
+            if "DATABRICKS_CLIENT_ID" not in self.environment_vars:
+                self.environment_vars["DATABRICKS_CLIENT_ID"] = (
+                    self.service_principal.client_id
+                )
+            if "DATABRICKS_CLIENT_SECRET" not in self.environment_vars:
+                self.environment_vars["DATABRICKS_CLIENT_SECRET"] = (
+                    self.service_principal.client_secret
+                )
+        return self
 
     @model_validator(mode="after")
-    def validate_agents_not_empty(self):
+    def validate_agents_not_empty(self) -> Self:
         if not self.agents:
             raise ValueError("At least one agent must be specified")
         return self
 
     @model_validator(mode="after")
-    def 
+    def resolve_environment_vars(self) -> Self:
         for key, value in self.environment_vars.items():
+            updated_value: str
             if isinstance(value, SecretVariableModel):
                 updated_value = str(value)
             else:
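`set_databricks_env_vars` seeds `environment_vars` for Model Serving but never overwrites explicit entries. The precedence rule in isolation; `dict.setdefault` expresses exactly that "only if absent" behavior:

```python
def seed_env_vars(
    environment_vars: dict[str, str],
    default_host: str | None,
    client_id: str | None = None,
    client_secret: str | None = None,
) -> dict[str, str]:
    if default_host:
        environment_vars.setdefault("DATABRICKS_HOST", default_host)
    if client_id:
        environment_vars.setdefault("DATABRICKS_CLIENT_ID", client_id)
    if client_secret:
        environment_vars.setdefault("DATABRICKS_CLIENT_SECRET", client_secret)
    return environment_vars


# An explicitly configured host wins over the detected default
env = seed_env_vars({"DATABRICKS_HOST": "https://explicit"}, "https://detected")
assert env["DATABRICKS_HOST"] == "https://explicit"
```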
@@ -1632,7 +2700,7 @@ class AppModel(BaseModel):
         return self
 
     @model_validator(mode="after")
-    def set_default_orchestration(self):
+    def set_default_orchestration(self) -> Self:
         if self.orchestration is None:
             if len(self.agents) > 1:
                 default_agent: AgentModel = self.agents[0]
@@ -1652,14 +2720,14 @@ class AppModel(BaseModel):
         return self
 
     @model_validator(mode="after")
-    def set_default_endpoint_name(self):
+    def set_default_endpoint_name(self) -> Self:
         if self.endpoint_name is None:
             self.endpoint_name = self.name
         return self
 
     @model_validator(mode="after")
-    def set_default_agent(self):
-        default_agent_name = self.agents[0].name
+    def set_default_agent(self) -> Self:
+        default_agent_name: str = self.agents[0].name
 
         if self.orchestration.swarm and not self.orchestration.swarm.default_agent:
             self.orchestration.swarm.default_agent = default_agent_name
@@ -1667,7 +2735,7 @@ class AppModel(BaseModel):
         return self
 
     @model_validator(mode="after")
-    def add_code_paths_to_sys_path(self):
+    def add_code_paths_to_sys_path(self) -> Self:
         for code_path in self.code_paths:
             parent_path: str = str(Path(code_path).parent)
             if parent_path not in sys.path:
@@ -1700,7 +2768,7 @@ class EvaluationDatasetExpectationsModel(BaseModel):
     expected_facts: Optional[list[str]] = None
 
     @model_validator(mode="after")
-    def validate_mutually_exclusive(self):
+    def validate_mutually_exclusive(self) -> Self:
         if self.expected_response is not None and self.expected_facts is not None:
             raise ValueError("Cannot specify both expected_response and expected_facts")
         return self
@@ -1779,36 +2847,70 @@ class EvaluationDatasetModel(BaseModel, HasFullName):
 
 
 class PromptOptimizationModel(BaseModel):
+    """Configuration for prompt optimization using GEPA.
+
+    GEPA (Generative Evolution of Prompts and Agents) is an evolutionary
+    optimizer that uses reflective mutation to improve prompts based on
+    evaluation feedback.
+
+    Example:
+        prompt_optimization:
+          name: optimize_my_prompt
+          prompt: *my_prompt
+          agent: *my_agent
+          dataset: *my_training_dataset
+          reflection_model: databricks-meta-llama-3-3-70b-instruct
+          num_candidates: 50
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     prompt: Optional[PromptModel] = None
     agent: AgentModel
-    dataset: 
-        EvaluationDatasetModel | str
-    )  # Reference to dataset name (looked up in OptimizationsModel.training_datasets or MLflow)
+    dataset: EvaluationDatasetModel  # Training dataset with examples
     reflection_model: Optional[LLMModel | str] = None
     num_candidates: Optional[int] = 50
-    scorer_model: Optional[LLMModel | str] = None
 
     def optimize(self, w: WorkspaceClient | None = None) -> PromptModel:
         """
-        Optimize the prompt using 
+        Optimize the prompt using GEPA.
 
         Args:
-            w: Optional WorkspaceClient for 
+            w: Optional WorkspaceClient (not used, kept for API compatibility)
 
         Returns:
-            PromptModel: The optimized prompt model
+            PromptModel: The optimized prompt model
         """
-        from dao_ai.
-        from dao_ai.providers.databricks import DatabricksProvider
+        from dao_ai.optimization import OptimizationResult, optimize_prompt
 
-
-
-
+        # Get reflection model name
+        reflection_model_name: str | None = None
+        if self.reflection_model:
+            if isinstance(self.reflection_model, str):
+                reflection_model_name = self.reflection_model
+            else:
+                reflection_model_name = self.reflection_model.uri
+
+        # Ensure prompt is set
+        prompt = self.prompt
+        if prompt is None:
+            raise ValueError(
+                f"Prompt optimization '{self.name}' requires a prompt to be set"
+            )
+
+        result: OptimizationResult = optimize_prompt(
+            prompt=prompt,
+            agent=self.agent,
+            dataset=self.dataset,
+            reflection_model=reflection_model_name,
+            num_candidates=self.num_candidates or 50,
+            register_if_improved=True,
+        )
+
+        return result.optimized_prompt
 
     @model_validator(mode="after")
-    def set_defaults(self):
+    def set_defaults(self) -> Self:
         # If no prompt is specified, try to use the agent's prompt
         if self.prompt is None:
             if isinstance(self.agent.prompt, PromptModel):
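`optimize()` now normalizes `reflection_model` before delegating to `dao_ai.optimization.optimize_prompt`: a raw string passes through, while a model object is collapsed to its `.uri`. The same normalization in isolation; the URI shape in the stub is a hypothetical example, not the package's actual format:

```python
def reflection_model_name(reflection_model) -> str | None:
    """Mirror of optimize(): accept a raw endpoint name or a model object with .uri."""
    if reflection_model is None:
        return None
    if isinstance(reflection_model, str):
        return reflection_model
    return reflection_model.uri


class _Model:
    # Hypothetical URI format for illustration only
    uri = "endpoints:/databricks-meta-llama-3-3-70b-instruct"


assert reflection_model_name("databricks-meta-llama-3-3-70b-instruct") == \
    "databricks-meta-llama-3-3-70b-instruct"
assert reflection_model_name(_Model()) == "endpoints:/databricks-meta-llama-3-3-70b-instruct"
assert reflection_model_name(None) is None
```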
@@ -1819,12 +2921,6 @@ class PromptOptimizationModel(BaseModel):
             f"or an agent with a prompt configured"
         )
 
-        if self.reflection_model is None:
-            self.reflection_model = self.agent.model
-
-        if self.scorer_model is None:
-            self.scorer_model = self.agent.model
-
         return self
 
 
@@ -1897,7 +2993,7 @@ class UnityCatalogFunctionSqlTestModel(BaseModel):
 
 class UnityCatalogFunctionSqlModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    function: 
+    function: FunctionModel
     ddl: str
     parameters: Optional[dict[str, Any]] = Field(default_factory=dict)
     test: Optional[UnityCatalogFunctionSqlTestModel] = None
@@ -1925,16 +3021,126 @@ class ResourcesModel(BaseModel):
     warehouses: dict[str, WarehouseModel] = Field(default_factory=dict)
     databases: dict[str, DatabaseModel] = Field(default_factory=dict)
     connections: dict[str, ConnectionModel] = Field(default_factory=dict)
+    apps: dict[str, DatabricksAppModel] = Field(default_factory=dict)
+
+    @model_validator(mode="after")
+    def update_genie_warehouses(self) -> Self:
+        """
+        Automatically populate warehouses from genie_rooms.
+
+        Warehouses are extracted from each Genie room and added to the
+        resources if they don't already exist (based on warehouse_id).
+        """
+        if not self.genie_rooms:
+            return self
+
+        # Process warehouses from all genie rooms
+        for genie_room in self.genie_rooms.values():
+            genie_room: GenieRoomModel
+            warehouse: Optional[WarehouseModel] = genie_room.warehouse
+
+            if warehouse is None:
+                continue
+
+            # Check if warehouse already exists based on warehouse_id
+            warehouse_exists: bool = any(
+                existing_warehouse.warehouse_id == warehouse.warehouse_id
+                for existing_warehouse in self.warehouses.values()
+            )
+
+            if not warehouse_exists:
+                warehouse_key: str = normalize_name(
+                    "_".join([genie_room.name, warehouse.warehouse_id])
+                )
+                self.warehouses[warehouse_key] = warehouse
+                logger.trace(
+                    "Added warehouse from Genie room",
+                    room=genie_room.name,
+                    warehouse=warehouse.warehouse_id,
+                    key=warehouse_key,
+                )
+
+        return self
+
+    @model_validator(mode="after")
+    def update_genie_tables(self) -> Self:
+        """
+        Automatically populate tables from genie_rooms.
+
+        Tables are extracted from each Genie room and added to the
+        resources if they don't already exist (based on full_name).
+        """
+        if not self.genie_rooms:
+            return self
+
+        # Process tables from all genie rooms
+        for genie_room in self.genie_rooms.values():
+            genie_room: GenieRoomModel
+            for table in genie_room.tables:
+                table: TableModel
+                table_exists: bool = any(
+                    existing_table.full_name == table.full_name
+                    for existing_table in self.tables.values()
+                )
+                if not table_exists:
+                    table_key: str = normalize_name(
+                        "_".join([genie_room.name, table.full_name])
+                    )
+                    self.tables[table_key] = table
+                    logger.trace(
+                        "Added table from Genie room",
+                        room=genie_room.name,
+                        table=table.name,
+                        key=table_key,
+                    )
+
+        return self
+
+    @model_validator(mode="after")
+    def update_genie_functions(self) -> Self:
+        """
+        Automatically populate functions from genie_rooms.
+
+        Functions are extracted from each Genie room and added to the
+        resources if they don't already exist (based on full_name).
+        """
+        if not self.genie_rooms:
+            return self
+
+        # Process functions from all genie rooms
+        for genie_room in self.genie_rooms.values():
+            genie_room: GenieRoomModel
+            for function in genie_room.functions:
+                function: FunctionModel
+                function_exists: bool = any(
+                    existing_function.full_name == function.full_name
+                    for existing_function in self.functions.values()
+                )
+                if not function_exists:
+                    function_key: str = normalize_name(
+                        "_".join([genie_room.name, function.full_name])
+                    )
+                    self.functions[function_key] = function
+                    logger.trace(
+                        "Added function from Genie room",
+                        room=genie_room.name,
+                        function=function.name,
+                        key=function_key,
+                    )
+
+        return self
 
 
 class AppConfig(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     variables: dict[str, AnyVariable] = Field(default_factory=dict)
+    service_principals: dict[str, ServicePrincipalModel] = Field(default_factory=dict)
     schemas: dict[str, SchemaModel] = Field(default_factory=dict)
     resources: Optional[ResourcesModel] = None
     retrievers: dict[str, RetrieverModel] = Field(default_factory=dict)
     tools: dict[str, ToolModel] = Field(default_factory=dict)
     guardrails: dict[str, GuardrailModel] = Field(default_factory=dict)
+    middleware: dict[str, MiddlewareModel] = Field(default_factory=dict)
     memory: Optional[MemoryModel] = None
     prompts: dict[str, PromptModel] = Field(default_factory=dict)
     agents: dict[str, AgentModel] = Field(default_factory=dict)
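All three `update_genie_*` validators share the same pattern: scan each Genie room, skip resources whose identity is already registered, and add the rest under a normalized `room_identifier` key. The shared dedup-and-register step in isolation; the `normalize_name` approximation (lowercase, non-alphanumerics collapsed to underscores) is an assumption about the helper's behavior:

```python
import re


def normalize_name(name: str) -> str:
    # Assumed normalization: lowercase, non-alphanumerics collapsed to "_"
    return re.sub(r"[^a-z0-9]+", "_", name.lower()).strip("_")


def register_unique(registry: dict[str, str], room: str, identity: str) -> dict[str, str]:
    """Add identity under a room-scoped key unless some entry already has it."""
    if identity not in registry.values():
        registry[normalize_name(f"{room}_{identity}")] = identity
    return registry


warehouses: dict[str, str] = {}
register_unique(warehouses, "Sales Genie", "wh-123")
register_unique(warehouses, "Support Genie", "wh-123")  # duplicate id: skipped
assert warehouses == {"sales_genie_wh_123": "wh-123"}
```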
@@ -1962,10 +3168,10 @@ class AppConfig(BaseModel):
 
     def initialize(self) -> None:
         from dao_ai.hooks.core import create_hooks
+        from dao_ai.logging import configure_logging
 
         if self.app and self.app.log_level:
-
-            logger.add(sys.stderr, level=self.app.log_level)
+            configure_logging(level=self.app.log_level)
 
         logger.debug("Calling initialization hooks...")
         initialization_functions: Sequence[Callable[..., Any]] = create_hooks(
@@ -2009,21 +3215,45 @@ class AppConfig(BaseModel):
     def create_agent(
         self,
         w: WorkspaceClient | None = None,
+        vsc: "VectorSearchClient | None" = None,
+        pat: str | None = None,
+        client_id: str | None = None,
+        client_secret: str | None = None,
+        workspace_host: str | None = None,
     ) -> None:
         from dao_ai.providers.base import ServiceProvider
         from dao_ai.providers.databricks import DatabricksProvider
 
-        provider: ServiceProvider = DatabricksProvider(
+        provider: ServiceProvider = DatabricksProvider(
+            w=w,
+            vsc=vsc,
+            pat=pat,
+            client_id=client_id,
+            client_secret=client_secret,
+            workspace_host=workspace_host,
+        )
         provider.create_agent(self)
 
     def deploy_agent(
         self,
         w: WorkspaceClient | None = None,
+        vsc: "VectorSearchClient | None" = None,
+        pat: str | None = None,
+        client_id: str | None = None,
+        client_secret: str | None = None,
+        workspace_host: str | None = None,
    ) -> None:
         from dao_ai.providers.base import ServiceProvider
         from dao_ai.providers.databricks import DatabricksProvider
 
-        provider: ServiceProvider = DatabricksProvider(
+        provider: ServiceProvider = DatabricksProvider(
+            w=w,
+            vsc=vsc,
+            pat=pat,
+            client_id=client_id,
+            client_secret=client_secret,
+            workspace_host=workspace_host,
+        )
         provider.deploy_agent(self)
 
     def find_agents(