dao-ai 0.0.28__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +2 -5
- dao_ai/cli.py +245 -40
- dao_ai/config.py +1491 -370
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +232 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +245 -159
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +278 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +573 -601
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -294
- dao_ai/tools/mcp.py +223 -155
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +331 -221
- dao_ai/utils.py +166 -20
- dao_ai-0.1.2.dist-info/METADATA +455 -0
- dao_ai-0.1.2.dist-info/RECORD +64 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.28.dist-info/METADATA +0 -1168
- dao_ai-0.0.28.dist-info/RECORD +0 -41
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py
CHANGED
|
@@ -12,6 +12,7 @@ from typing import (
|
|
|
12
12
|
Iterator,
|
|
13
13
|
Literal,
|
|
14
14
|
Optional,
|
|
15
|
+
Self,
|
|
15
16
|
Sequence,
|
|
16
17
|
TypeAlias,
|
|
17
18
|
Union,
|
|
@@ -22,13 +23,20 @@ from databricks.sdk.credentials_provider import (
|
|
|
22
23
|
CredentialsStrategy,
|
|
23
24
|
ModelServingUserCredentials,
|
|
24
25
|
)
|
|
26
|
+
from databricks.sdk.errors.platform import NotFound
|
|
25
27
|
from databricks.sdk.service.catalog import FunctionInfo, TableInfo
|
|
28
|
+
from databricks.sdk.service.dashboards import GenieSpace
|
|
26
29
|
from databricks.sdk.service.database import DatabaseInstance
|
|
30
|
+
from databricks.sdk.service.sql import GetWarehouseResponse
|
|
27
31
|
from databricks.vector_search.client import VectorSearchClient
|
|
28
32
|
from databricks.vector_search.index import VectorSearchIndex
|
|
29
33
|
from databricks_langchain import (
|
|
34
|
+
ChatDatabricks,
|
|
35
|
+
DatabricksEmbeddings,
|
|
30
36
|
DatabricksFunctionClient,
|
|
31
37
|
)
|
|
38
|
+
from langchain.agents.structured_output import ProviderStrategy, ToolStrategy
|
|
39
|
+
from langchain_core.embeddings import Embeddings
|
|
32
40
|
from langchain_core.language_models import LanguageModelLike
|
|
33
41
|
from langchain_core.messages import BaseMessage, messages_from_dict
|
|
34
42
|
from langchain_core.runnables.base import RunnableLike
|
|
@@ -41,6 +49,7 @@ from mlflow.genai.datasets import EvaluationDataset, create_dataset, get_dataset
|
|
|
41
49
|
from mlflow.genai.prompts import PromptVersion, load_prompt
|
|
42
50
|
from mlflow.models import ModelConfig
|
|
43
51
|
from mlflow.models.resources import (
|
|
52
|
+
DatabricksApp,
|
|
44
53
|
DatabricksFunction,
|
|
45
54
|
DatabricksGenieSpace,
|
|
46
55
|
DatabricksLakebase,
|
|
@@ -59,10 +68,13 @@ from pydantic import (
|
|
|
59
68
|
BaseModel,
|
|
60
69
|
ConfigDict,
|
|
61
70
|
Field,
|
|
71
|
+
PrivateAttr,
|
|
62
72
|
field_serializer,
|
|
63
73
|
model_validator,
|
|
64
74
|
)
|
|
65
75
|
|
|
76
|
+
from dao_ai.utils import normalize_name
|
|
77
|
+
|
|
66
78
|
|
|
67
79
|
class HasValue(ABC):
|
|
68
80
|
@abstractmethod
|
|
@@ -81,27 +93,6 @@ class HasFullName(ABC):
|
|
|
81
93
|
def full_name(self) -> str: ...
|
|
82
94
|
|
|
83
95
|
|
|
84
|
-
class IsDatabricksResource(ABC):
|
|
85
|
-
on_behalf_of_user: Optional[bool] = False
|
|
86
|
-
|
|
87
|
-
@abstractmethod
|
|
88
|
-
def as_resources(self) -> Sequence[DatabricksResource]: ...
|
|
89
|
-
|
|
90
|
-
@property
|
|
91
|
-
@abstractmethod
|
|
92
|
-
def api_scopes(self) -> Sequence[str]: ...
|
|
93
|
-
|
|
94
|
-
@property
|
|
95
|
-
def workspace_client(self) -> WorkspaceClient:
|
|
96
|
-
credentials_strategy: CredentialsStrategy = None
|
|
97
|
-
if self.on_behalf_of_user:
|
|
98
|
-
credentials_strategy = ModelServingUserCredentials()
|
|
99
|
-
logger.debug(
|
|
100
|
-
f"Creating WorkspaceClient with credentials strategy: {credentials_strategy}"
|
|
101
|
-
)
|
|
102
|
-
return WorkspaceClient(credentials_strategy=credentials_strategy)
|
|
103
|
-
|
|
104
|
-
|
|
105
96
|
class EnvironmentVariableModel(BaseModel, HasValue):
|
|
106
97
|
model_config = ConfigDict(
|
|
107
98
|
frozen=True,
|
|
@@ -200,6 +191,162 @@ AnyVariable: TypeAlias = (
|
|
|
200
191
|
)
|
|
201
192
|
|
|
202
193
|
|
|
194
|
+
class ServicePrincipalModel(BaseModel):
|
|
195
|
+
model_config = ConfigDict(
|
|
196
|
+
frozen=True,
|
|
197
|
+
use_enum_values=True,
|
|
198
|
+
)
|
|
199
|
+
client_id: AnyVariable
|
|
200
|
+
client_secret: AnyVariable
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
class IsDatabricksResource(ABC, BaseModel):
|
|
204
|
+
"""
|
|
205
|
+
Base class for Databricks resources with authentication support.
|
|
206
|
+
|
|
207
|
+
Authentication Options:
|
|
208
|
+
----------------------
|
|
209
|
+
1. **On-Behalf-Of User (OBO)**: Set on_behalf_of_user=True to use the
|
|
210
|
+
calling user's identity via ModelServingUserCredentials.
|
|
211
|
+
|
|
212
|
+
2. **Service Principal (OAuth M2M)**: Provide service_principal or
|
|
213
|
+
(client_id + client_secret + workspace_host) for service principal auth.
|
|
214
|
+
|
|
215
|
+
3. **Personal Access Token (PAT)**: Provide pat (and optionally workspace_host)
|
|
216
|
+
to authenticate with a personal access token.
|
|
217
|
+
|
|
218
|
+
4. **Ambient Authentication**: If no credentials provided, uses SDK defaults
|
|
219
|
+
(environment variables, notebook context, etc.)
|
|
220
|
+
|
|
221
|
+
Authentication Priority:
|
|
222
|
+
1. OBO (on_behalf_of_user=True)
|
|
223
|
+
2. Service Principal (client_id + client_secret + workspace_host)
|
|
224
|
+
3. PAT (pat + workspace_host)
|
|
225
|
+
4. Ambient/default authentication
|
|
226
|
+
"""
|
|
227
|
+
|
|
228
|
+
model_config = ConfigDict(use_enum_values=True)
|
|
229
|
+
|
|
230
|
+
on_behalf_of_user: Optional[bool] = False
|
|
231
|
+
service_principal: Optional[ServicePrincipalModel] = None
|
|
232
|
+
client_id: Optional[AnyVariable] = None
|
|
233
|
+
client_secret: Optional[AnyVariable] = None
|
|
234
|
+
workspace_host: Optional[AnyVariable] = None
|
|
235
|
+
pat: Optional[AnyVariable] = None
|
|
236
|
+
|
|
237
|
+
# Private attribute to cache the workspace client (lazy instantiation)
|
|
238
|
+
_workspace_client: Optional[WorkspaceClient] = PrivateAttr(default=None)
|
|
239
|
+
|
|
240
|
+
@abstractmethod
|
|
241
|
+
def as_resources(self) -> Sequence[DatabricksResource]: ...
|
|
242
|
+
|
|
243
|
+
@property
|
|
244
|
+
@abstractmethod
|
|
245
|
+
def api_scopes(self) -> Sequence[str]: ...
|
|
246
|
+
|
|
247
|
+
@model_validator(mode="after")
|
|
248
|
+
def _expand_service_principal(self) -> Self:
|
|
249
|
+
"""Expand service_principal into client_id and client_secret if provided."""
|
|
250
|
+
if self.service_principal is not None:
|
|
251
|
+
if self.client_id is None:
|
|
252
|
+
self.client_id = self.service_principal.client_id
|
|
253
|
+
if self.client_secret is None:
|
|
254
|
+
self.client_secret = self.service_principal.client_secret
|
|
255
|
+
return self
|
|
256
|
+
|
|
257
|
+
@model_validator(mode="after")
|
|
258
|
+
def _validate_auth_not_mixed(self) -> Self:
|
|
259
|
+
"""Validate that OAuth and PAT authentication are not both provided."""
|
|
260
|
+
has_oauth: bool = self.client_id is not None and self.client_secret is not None
|
|
261
|
+
has_pat: bool = self.pat is not None
|
|
262
|
+
|
|
263
|
+
if has_oauth and has_pat:
|
|
264
|
+
raise ValueError(
|
|
265
|
+
"Cannot use both OAuth and user authentication methods. "
|
|
266
|
+
"Please provide either OAuth credentials or user credentials."
|
|
267
|
+
)
|
|
268
|
+
return self
|
|
269
|
+
|
|
270
|
+
@property
|
|
271
|
+
def workspace_client(self) -> WorkspaceClient:
|
|
272
|
+
"""
|
|
273
|
+
Get a WorkspaceClient configured with the appropriate authentication.
|
|
274
|
+
|
|
275
|
+
The client is lazily instantiated on first access and cached for subsequent calls.
|
|
276
|
+
|
|
277
|
+
Authentication priority:
|
|
278
|
+
1. If on_behalf_of_user is True, uses ModelServingUserCredentials (OBO)
|
|
279
|
+
2. If service principal credentials are configured (client_id, client_secret,
|
|
280
|
+
workspace_host), uses OAuth M2M
|
|
281
|
+
3. If PAT is configured, uses token authentication
|
|
282
|
+
4. Otherwise, uses default/ambient authentication
|
|
283
|
+
"""
|
|
284
|
+
# Return cached client if already instantiated
|
|
285
|
+
if self._workspace_client is not None:
|
|
286
|
+
return self._workspace_client
|
|
287
|
+
|
|
288
|
+
from dao_ai.utils import normalize_host
|
|
289
|
+
|
|
290
|
+
# Check for OBO first (highest priority)
|
|
291
|
+
if self.on_behalf_of_user:
|
|
292
|
+
credentials_strategy: CredentialsStrategy = ModelServingUserCredentials()
|
|
293
|
+
logger.debug(
|
|
294
|
+
f"Creating WorkspaceClient for {self.__class__.__name__} "
|
|
295
|
+
f"with OBO credentials strategy"
|
|
296
|
+
)
|
|
297
|
+
self._workspace_client = WorkspaceClient(
|
|
298
|
+
credentials_strategy=credentials_strategy
|
|
299
|
+
)
|
|
300
|
+
return self._workspace_client
|
|
301
|
+
|
|
302
|
+
# Check for service principal credentials
|
|
303
|
+
client_id_value: str | None = (
|
|
304
|
+
value_of(self.client_id) if self.client_id else None
|
|
305
|
+
)
|
|
306
|
+
client_secret_value: str | None = (
|
|
307
|
+
value_of(self.client_secret) if self.client_secret else None
|
|
308
|
+
)
|
|
309
|
+
workspace_host_value: str | None = (
|
|
310
|
+
normalize_host(value_of(self.workspace_host))
|
|
311
|
+
if self.workspace_host
|
|
312
|
+
else None
|
|
313
|
+
)
|
|
314
|
+
|
|
315
|
+
if client_id_value and client_secret_value and workspace_host_value:
|
|
316
|
+
logger.debug(
|
|
317
|
+
f"Creating WorkspaceClient for {self.__class__.__name__} with service principal: "
|
|
318
|
+
f"client_id={client_id_value}, host={workspace_host_value}"
|
|
319
|
+
)
|
|
320
|
+
self._workspace_client = WorkspaceClient(
|
|
321
|
+
host=workspace_host_value,
|
|
322
|
+
client_id=client_id_value,
|
|
323
|
+
client_secret=client_secret_value,
|
|
324
|
+
auth_type="oauth-m2m",
|
|
325
|
+
)
|
|
326
|
+
return self._workspace_client
|
|
327
|
+
|
|
328
|
+
# Check for PAT authentication
|
|
329
|
+
pat_value: str | None = value_of(self.pat) if self.pat else None
|
|
330
|
+
if pat_value:
|
|
331
|
+
logger.debug(
|
|
332
|
+
f"Creating WorkspaceClient for {self.__class__.__name__} with PAT"
|
|
333
|
+
)
|
|
334
|
+
self._workspace_client = WorkspaceClient(
|
|
335
|
+
host=workspace_host_value,
|
|
336
|
+
token=pat_value,
|
|
337
|
+
auth_type="pat",
|
|
338
|
+
)
|
|
339
|
+
return self._workspace_client
|
|
340
|
+
|
|
341
|
+
# Default: use ambient authentication
|
|
342
|
+
logger.debug(
|
|
343
|
+
f"Creating WorkspaceClient for {self.__class__.__name__} "
|
|
344
|
+
"with default/ambient authentication"
|
|
345
|
+
)
|
|
346
|
+
self._workspace_client = WorkspaceClient()
|
|
347
|
+
return self._workspace_client
|
|
348
|
+
|
|
349
|
+
|
|
203
350
|
class Privilege(str, Enum):
|
|
204
351
|
ALL_PRIVILEGES = "ALL_PRIVILEGES"
|
|
205
352
|
USE_CATALOG = "USE_CATALOG"
|
|
@@ -226,9 +373,21 @@ class Privilege(str, Enum):
|
|
|
226
373
|
|
|
227
374
|
class PermissionModel(BaseModel):
|
|
228
375
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
229
|
-
principals: list[str] = Field(default_factory=list)
|
|
376
|
+
principals: list[ServicePrincipalModel | str] = Field(default_factory=list)
|
|
230
377
|
privileges: list[Privilege]
|
|
231
378
|
|
|
379
|
+
@model_validator(mode="after")
|
|
380
|
+
def resolve_principals(self) -> Self:
|
|
381
|
+
"""Resolve ServicePrincipalModel objects to their client_id."""
|
|
382
|
+
resolved: list[str] = []
|
|
383
|
+
for principal in self.principals:
|
|
384
|
+
if isinstance(principal, ServicePrincipalModel):
|
|
385
|
+
resolved.append(value_of(principal.client_id))
|
|
386
|
+
else:
|
|
387
|
+
resolved.append(principal)
|
|
388
|
+
self.principals = resolved
|
|
389
|
+
return self
|
|
390
|
+
|
|
232
391
|
|
|
233
392
|
class SchemaModel(BaseModel, HasFullName):
|
|
234
393
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
@@ -248,7 +407,26 @@ class SchemaModel(BaseModel, HasFullName):
|
|
|
248
407
|
provider.create_schema(self)
|
|
249
408
|
|
|
250
409
|
|
|
251
|
-
class
|
|
410
|
+
class DatabricksAppModel(IsDatabricksResource, HasFullName):
|
|
411
|
+
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
412
|
+
name: str
|
|
413
|
+
url: str
|
|
414
|
+
|
|
415
|
+
@property
|
|
416
|
+
def full_name(self) -> str:
|
|
417
|
+
return self.name
|
|
418
|
+
|
|
419
|
+
@property
|
|
420
|
+
def api_scopes(self) -> Sequence[str]:
|
|
421
|
+
return ["apps.apps"]
|
|
422
|
+
|
|
423
|
+
def as_resources(self) -> Sequence[DatabricksResource]:
|
|
424
|
+
return [
|
|
425
|
+
DatabricksApp(app_name=self.name, on_behalf_of_user=self.on_behalf_of_user)
|
|
426
|
+
]
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
class TableModel(IsDatabricksResource, HasFullName):
|
|
252
430
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
253
431
|
schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
|
|
254
432
|
name: Optional[str] = None
|
|
@@ -274,6 +452,22 @@ class TableModel(BaseModel, HasFullName, IsDatabricksResource):
|
|
|
274
452
|
def api_scopes(self) -> Sequence[str]:
|
|
275
453
|
return []
|
|
276
454
|
|
|
455
|
+
def exists(self) -> bool:
|
|
456
|
+
"""Check if the table exists in Unity Catalog.
|
|
457
|
+
|
|
458
|
+
Returns:
|
|
459
|
+
True if the table exists, False otherwise.
|
|
460
|
+
"""
|
|
461
|
+
try:
|
|
462
|
+
self.workspace_client.tables.get(full_name=self.full_name)
|
|
463
|
+
return True
|
|
464
|
+
except NotFound:
|
|
465
|
+
logger.debug(f"Table not found: {self.full_name}")
|
|
466
|
+
return False
|
|
467
|
+
except Exception as e:
|
|
468
|
+
logger.warning(f"Error checking table existence for {self.full_name}: {e}")
|
|
469
|
+
return False
|
|
470
|
+
|
|
277
471
|
def as_resources(self) -> Sequence[DatabricksResource]:
|
|
278
472
|
resources: list[DatabricksResource] = []
|
|
279
473
|
|
|
@@ -317,12 +511,17 @@ class TableModel(BaseModel, HasFullName, IsDatabricksResource):
|
|
|
317
511
|
return resources
|
|
318
512
|
|
|
319
513
|
|
|
320
|
-
class LLMModel(
|
|
514
|
+
class LLMModel(IsDatabricksResource):
|
|
321
515
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
322
516
|
name: str
|
|
517
|
+
description: Optional[str] = None
|
|
323
518
|
temperature: Optional[float] = 0.1
|
|
324
519
|
max_tokens: Optional[int] = 8192
|
|
325
520
|
fallbacks: Optional[list[Union[str, "LLMModel"]]] = Field(default_factory=list)
|
|
521
|
+
use_responses_api: Optional[bool] = Field(
|
|
522
|
+
default=False,
|
|
523
|
+
description="Use Responses API for ResponsesAgent endpoints",
|
|
524
|
+
)
|
|
326
525
|
|
|
327
526
|
@property
|
|
328
527
|
def api_scopes(self) -> Sequence[str]:
|
|
@@ -342,19 +541,12 @@ class LLMModel(BaseModel, IsDatabricksResource):
|
|
|
342
541
|
]
|
|
343
542
|
|
|
344
543
|
def as_chat_model(self) -> LanguageModelLike:
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
from dao_ai.chat_models import ChatDatabricksFiltered
|
|
351
|
-
|
|
352
|
-
chat_client: LanguageModelLike = ChatDatabricksFiltered(
|
|
353
|
-
model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
|
|
544
|
+
chat_client: LanguageModelLike = ChatDatabricks(
|
|
545
|
+
model=self.name,
|
|
546
|
+
temperature=self.temperature,
|
|
547
|
+
max_tokens=self.max_tokens,
|
|
548
|
+
use_responses_api=self.use_responses_api,
|
|
354
549
|
)
|
|
355
|
-
# chat_client: LanguageModelLike = ChatDatabricks(
|
|
356
|
-
# model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
|
|
357
|
-
# )
|
|
358
550
|
|
|
359
551
|
fallbacks: Sequence[LanguageModelLike] = []
|
|
360
552
|
for fallback in self.fallbacks:
|
|
@@ -386,6 +578,9 @@ class LLMModel(BaseModel, IsDatabricksResource):
|
|
|
386
578
|
|
|
387
579
|
return chat_client
|
|
388
580
|
|
|
581
|
+
def as_embeddings_model(self) -> Embeddings:
|
|
582
|
+
return DatabricksEmbeddings(endpoint=self.name)
|
|
583
|
+
|
|
389
584
|
|
|
390
585
|
class VectorSearchEndpointType(str, Enum):
|
|
391
586
|
STANDARD = "STANDARD"
|
|
@@ -405,7 +600,7 @@ class VectorSearchEndpoint(BaseModel):
|
|
|
405
600
|
return str(value)
|
|
406
601
|
|
|
407
602
|
|
|
408
|
-
class IndexModel(
|
|
603
|
+
class IndexModel(IsDatabricksResource, HasFullName):
|
|
409
604
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
410
605
|
schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
|
|
411
606
|
name: str
|
|
@@ -430,12 +625,297 @@ class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
|
|
|
430
625
|
]
|
|
431
626
|
|
|
432
627
|
|
|
433
|
-
class
|
|
628
|
+
class FunctionModel(IsDatabricksResource, HasFullName):
|
|
629
|
+
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
630
|
+
schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
|
|
631
|
+
name: Optional[str] = None
|
|
632
|
+
|
|
633
|
+
@model_validator(mode="after")
|
|
634
|
+
def validate_name_or_schema_required(self) -> Self:
|
|
635
|
+
if not self.name and not self.schema_model:
|
|
636
|
+
raise ValueError(
|
|
637
|
+
"Either 'name' or 'schema_model' must be provided for FunctionModel"
|
|
638
|
+
)
|
|
639
|
+
return self
|
|
640
|
+
|
|
641
|
+
@property
|
|
642
|
+
def full_name(self) -> str:
|
|
643
|
+
if self.schema_model:
|
|
644
|
+
name: str = ""
|
|
645
|
+
if self.name:
|
|
646
|
+
name = f".{self.name}"
|
|
647
|
+
return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
|
|
648
|
+
return self.name
|
|
649
|
+
|
|
650
|
+
def exists(self) -> bool:
|
|
651
|
+
"""Check if the function exists in Unity Catalog.
|
|
652
|
+
|
|
653
|
+
Returns:
|
|
654
|
+
True if the function exists, False otherwise.
|
|
655
|
+
"""
|
|
656
|
+
try:
|
|
657
|
+
self.workspace_client.functions.get(name=self.full_name)
|
|
658
|
+
return True
|
|
659
|
+
except NotFound:
|
|
660
|
+
logger.debug(f"Function not found: {self.full_name}")
|
|
661
|
+
return False
|
|
662
|
+
except Exception as e:
|
|
663
|
+
logger.warning(
|
|
664
|
+
f"Error checking function existence for {self.full_name}: {e}"
|
|
665
|
+
)
|
|
666
|
+
return False
|
|
667
|
+
|
|
668
|
+
def as_resources(self) -> Sequence[DatabricksResource]:
|
|
669
|
+
resources: list[DatabricksResource] = []
|
|
670
|
+
if self.name:
|
|
671
|
+
resources.append(
|
|
672
|
+
DatabricksFunction(
|
|
673
|
+
function_name=self.full_name,
|
|
674
|
+
on_behalf_of_user=self.on_behalf_of_user,
|
|
675
|
+
)
|
|
676
|
+
)
|
|
677
|
+
else:
|
|
678
|
+
w: WorkspaceClient = self.workspace_client
|
|
679
|
+
schema_full_name: str = self.schema_model.full_name
|
|
680
|
+
functions: Iterator[FunctionInfo] = w.functions.list(
|
|
681
|
+
catalog_name=self.schema_model.catalog_name,
|
|
682
|
+
schema_name=self.schema_model.schema_name,
|
|
683
|
+
)
|
|
684
|
+
resources.extend(
|
|
685
|
+
[
|
|
686
|
+
DatabricksFunction(
|
|
687
|
+
function_name=f"{schema_full_name}.{function.name}",
|
|
688
|
+
on_behalf_of_user=self.on_behalf_of_user,
|
|
689
|
+
)
|
|
690
|
+
for function in functions
|
|
691
|
+
]
|
|
692
|
+
)
|
|
693
|
+
|
|
694
|
+
return resources
|
|
695
|
+
|
|
696
|
+
@property
|
|
697
|
+
def api_scopes(self) -> Sequence[str]:
|
|
698
|
+
return ["sql.statement-execution"]
|
|
699
|
+
|
|
700
|
+
|
|
701
|
+
class WarehouseModel(IsDatabricksResource):
|
|
702
|
+
model_config = ConfigDict()
|
|
703
|
+
name: str
|
|
704
|
+
description: Optional[str] = None
|
|
705
|
+
warehouse_id: AnyVariable
|
|
706
|
+
|
|
707
|
+
@property
|
|
708
|
+
def api_scopes(self) -> Sequence[str]:
|
|
709
|
+
return [
|
|
710
|
+
"sql.warehouses",
|
|
711
|
+
"sql.statement-execution",
|
|
712
|
+
]
|
|
713
|
+
|
|
714
|
+
def as_resources(self) -> Sequence[DatabricksResource]:
|
|
715
|
+
return [
|
|
716
|
+
DatabricksSQLWarehouse(
|
|
717
|
+
warehouse_id=value_of(self.warehouse_id),
|
|
718
|
+
on_behalf_of_user=self.on_behalf_of_user,
|
|
719
|
+
)
|
|
720
|
+
]
|
|
721
|
+
|
|
722
|
+
@model_validator(mode="after")
|
|
723
|
+
def update_warehouse_id(self) -> Self:
|
|
724
|
+
self.warehouse_id = value_of(self.warehouse_id)
|
|
725
|
+
return self
|
|
726
|
+
|
|
727
|
+
|
|
728
|
+
class GenieRoomModel(IsDatabricksResource):
|
|
434
729
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
435
730
|
name: str
|
|
436
731
|
description: Optional[str] = None
|
|
437
732
|
space_id: AnyVariable
|
|
438
733
|
|
|
734
|
+
_space_details: Optional[GenieSpace] = PrivateAttr(default=None)
|
|
735
|
+
|
|
736
|
+
def _get_space_details(self) -> GenieSpace:
|
|
737
|
+
if self._space_details is None:
|
|
738
|
+
self._space_details = self.workspace_client.genie.get_space(
|
|
739
|
+
space_id=self.space_id, include_serialized_space=True
|
|
740
|
+
)
|
|
741
|
+
return self._space_details
|
|
742
|
+
|
|
743
|
+
def _parse_serialized_space(self) -> dict[str, Any]:
|
|
744
|
+
"""Parse the serialized_space JSON string and return the parsed data."""
|
|
745
|
+
import json
|
|
746
|
+
|
|
747
|
+
space_details = self._get_space_details()
|
|
748
|
+
if not space_details.serialized_space:
|
|
749
|
+
return {}
|
|
750
|
+
|
|
751
|
+
try:
|
|
752
|
+
return json.loads(space_details.serialized_space)
|
|
753
|
+
except json.JSONDecodeError as e:
|
|
754
|
+
logger.warning(f"Failed to parse serialized_space: {e}")
|
|
755
|
+
return {}
|
|
756
|
+
|
|
757
|
+
@property
|
|
758
|
+
def warehouse(self) -> Optional[WarehouseModel]:
|
|
759
|
+
"""Extract warehouse information from the Genie space.
|
|
760
|
+
|
|
761
|
+
Returns:
|
|
762
|
+
WarehouseModel instance if warehouse_id is available, None otherwise.
|
|
763
|
+
"""
|
|
764
|
+
space_details: GenieSpace = self._get_space_details()
|
|
765
|
+
|
|
766
|
+
if not space_details.warehouse_id:
|
|
767
|
+
return None
|
|
768
|
+
|
|
769
|
+
try:
|
|
770
|
+
response: GetWarehouseResponse = self.workspace_client.warehouses.get(
|
|
771
|
+
space_details.warehouse_id
|
|
772
|
+
)
|
|
773
|
+
warehouse_name: str = response.name or space_details.warehouse_id
|
|
774
|
+
|
|
775
|
+
warehouse_model = WarehouseModel(
|
|
776
|
+
name=warehouse_name,
|
|
777
|
+
warehouse_id=space_details.warehouse_id,
|
|
778
|
+
on_behalf_of_user=self.on_behalf_of_user,
|
|
779
|
+
service_principal=self.service_principal,
|
|
780
|
+
client_id=self.client_id,
|
|
781
|
+
client_secret=self.client_secret,
|
|
782
|
+
workspace_host=self.workspace_host,
|
|
783
|
+
pat=self.pat,
|
|
784
|
+
)
|
|
785
|
+
|
|
786
|
+
# Share the cached workspace client if available
|
|
787
|
+
if self._workspace_client is not None:
|
|
788
|
+
warehouse_model._workspace_client = self._workspace_client
|
|
789
|
+
|
|
790
|
+
return warehouse_model
|
|
791
|
+
except Exception as e:
|
|
792
|
+
logger.warning(
|
|
793
|
+
f"Failed to fetch warehouse details for {space_details.warehouse_id}: {e}"
|
|
794
|
+
)
|
|
795
|
+
return None
|
|
796
|
+
|
|
797
|
+
@property
|
|
798
|
+
def tables(self) -> list[TableModel]:
|
|
799
|
+
"""Extract tables from the serialized Genie space.
|
|
800
|
+
|
|
801
|
+
Databricks Genie stores tables in: data_sources.tables[].identifier
|
|
802
|
+
Only includes tables that actually exist in Unity Catalog.
|
|
803
|
+
"""
|
|
804
|
+
parsed_space = self._parse_serialized_space()
|
|
805
|
+
tables_list: list[TableModel] = []
|
|
806
|
+
|
|
807
|
+
# Primary structure: data_sources.tables with 'identifier' field
|
|
808
|
+
if "data_sources" in parsed_space:
|
|
809
|
+
data_sources = parsed_space["data_sources"]
|
|
810
|
+
if isinstance(data_sources, dict) and "tables" in data_sources:
|
|
811
|
+
tables_data = data_sources["tables"]
|
|
812
|
+
if isinstance(tables_data, list):
|
|
813
|
+
for table_item in tables_data:
|
|
814
|
+
table_name: str | None = None
|
|
815
|
+
if isinstance(table_item, dict):
|
|
816
|
+
# Standard Databricks structure uses 'identifier'
|
|
817
|
+
table_name = table_item.get("identifier") or table_item.get(
|
|
818
|
+
"name"
|
|
819
|
+
)
|
|
820
|
+
elif isinstance(table_item, str):
|
|
821
|
+
table_name = table_item
|
|
822
|
+
|
|
823
|
+
if table_name:
|
|
824
|
+
table_model = TableModel(
|
|
825
|
+
name=table_name,
|
|
826
|
+
on_behalf_of_user=self.on_behalf_of_user,
|
|
827
|
+
service_principal=self.service_principal,
|
|
828
|
+
client_id=self.client_id,
|
|
829
|
+
client_secret=self.client_secret,
|
|
830
|
+
workspace_host=self.workspace_host,
|
|
831
|
+
pat=self.pat,
|
|
832
|
+
)
|
|
833
|
+
# Share the cached workspace client if available
|
|
834
|
+
if self._workspace_client is not None:
|
|
835
|
+
table_model._workspace_client = self._workspace_client
|
|
836
|
+
|
|
837
|
+
# Verify the table exists before adding
|
|
838
|
+
if not table_model.exists():
|
|
839
|
+
continue
|
|
840
|
+
|
|
841
|
+
tables_list.append(table_model)
|
|
842
|
+
|
|
843
|
+
return tables_list
|
|
844
|
+
|
|
845
|
+
@property
|
|
846
|
+
def functions(self) -> list[FunctionModel]:
|
|
847
|
+
"""Extract functions from the serialized Genie space.
|
|
848
|
+
|
|
849
|
+
Databricks Genie stores functions in multiple locations:
|
|
850
|
+
- instructions.sql_functions[].identifier (SQL functions)
|
|
851
|
+
- data_sources.functions[].identifier (other functions)
|
|
852
|
+
Only includes functions that actually exist in Unity Catalog.
|
|
853
|
+
"""
|
|
854
|
+
parsed_space = self._parse_serialized_space()
|
|
855
|
+
functions_list: list[FunctionModel] = []
|
|
856
|
+
seen_functions: set[str] = set()
|
|
857
|
+
|
|
858
|
+
def add_function_if_exists(function_name: str) -> None:
|
|
859
|
+
"""Helper to add a function if it exists and hasn't been added."""
|
|
860
|
+
if function_name in seen_functions:
|
|
861
|
+
return
|
|
862
|
+
|
|
863
|
+
seen_functions.add(function_name)
|
|
864
|
+
function_model = FunctionModel(
|
|
865
|
+
name=function_name,
|
|
866
|
+
on_behalf_of_user=self.on_behalf_of_user,
|
|
867
|
+
service_principal=self.service_principal,
|
|
868
|
+
client_id=self.client_id,
|
|
869
|
+
client_secret=self.client_secret,
|
|
870
|
+
workspace_host=self.workspace_host,
|
|
871
|
+
pat=self.pat,
|
|
872
|
+
)
|
|
873
|
+
# Share the cached workspace client if available
|
|
874
|
+
if self._workspace_client is not None:
|
|
875
|
+
function_model._workspace_client = self._workspace_client
|
|
876
|
+
|
|
877
|
+
# Verify the function exists before adding
|
|
878
|
+
if not function_model.exists():
|
|
879
|
+
return
|
|
880
|
+
|
|
881
|
+
functions_list.append(function_model)
|
|
882
|
+
|
|
883
|
+
# Primary structure: instructions.sql_functions with 'identifier' field
|
|
884
|
+
if "instructions" in parsed_space:
|
|
885
|
+
instructions = parsed_space["instructions"]
|
|
886
|
+
if isinstance(instructions, dict) and "sql_functions" in instructions:
|
|
887
|
+
sql_functions_data = instructions["sql_functions"]
|
|
888
|
+
if isinstance(sql_functions_data, list):
|
|
889
|
+
for function_item in sql_functions_data:
|
|
890
|
+
if isinstance(function_item, dict):
|
|
891
|
+
# SQL functions use 'identifier' field
|
|
892
|
+
function_name = function_item.get(
|
|
893
|
+
"identifier"
|
|
894
|
+
) or function_item.get("name")
|
|
895
|
+
if function_name:
|
|
896
|
+
add_function_if_exists(function_name)
|
|
897
|
+
|
|
898
|
+
# Secondary structure: data_sources.functions with 'identifier' field
|
|
899
|
+
if "data_sources" in parsed_space:
|
|
900
|
+
data_sources = parsed_space["data_sources"]
|
|
901
|
+
if isinstance(data_sources, dict) and "functions" in data_sources:
|
|
902
|
+
functions_data = data_sources["functions"]
|
|
903
|
+
if isinstance(functions_data, list):
|
|
904
|
+
for function_item in functions_data:
|
|
905
|
+
function_name: str | None = None
|
|
906
|
+
if isinstance(function_item, dict):
|
|
907
|
+
# Standard Databricks structure uses 'identifier'
|
|
908
|
+
function_name = function_item.get(
|
|
909
|
+
"identifier"
|
|
910
|
+
) or function_item.get("name")
|
|
911
|
+
elif isinstance(function_item, str):
|
|
912
|
+
function_name = function_item
|
|
913
|
+
|
|
914
|
+
if function_name:
|
|
915
|
+
add_function_if_exists(function_name)
|
|
916
|
+
|
|
917
|
+
return functions_list
|
|
918
|
+
|
|
439
919
|
@property
|
|
440
920
|
def api_scopes(self) -> Sequence[str]:
|
|
441
921
|
return [
|
|
@@ -451,12 +931,24 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
|
|
|
451
931
|
]
|
|
452
932
|
|
|
453
933
|
@model_validator(mode="after")
|
|
454
|
-
def update_space_id(self):
|
|
934
|
+
def update_space_id(self) -> Self:
|
|
455
935
|
self.space_id = value_of(self.space_id)
|
|
456
936
|
return self
|
|
457
937
|
|
|
938
|
+
@model_validator(mode="after")
|
|
939
|
+
def update_description_from_space(self) -> Self:
|
|
940
|
+
"""Populate description from GenieSpace if not provided."""
|
|
941
|
+
if not self.description:
|
|
942
|
+
try:
|
|
943
|
+
space_details = self._get_space_details()
|
|
944
|
+
if space_details.description:
|
|
945
|
+
self.description = space_details.description
|
|
946
|
+
except Exception as e:
|
|
947
|
+
logger.debug(f"Could not fetch description from Genie space: {e}")
|
|
948
|
+
return self
|
|
949
|
+
|
|
458
950
|
|
|
459
|
-
class VolumeModel(
|
|
951
|
+
class VolumeModel(IsDatabricksResource, HasFullName):
|
|
460
952
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
461
953
|
schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
|
|
462
954
|
name: str
|
|
@@ -516,7 +1008,7 @@ class VolumePathModel(BaseModel, HasFullName):
|
|
|
516
1008
|
provider.create_path(self)
|
|
517
1009
|
|
|
518
1010
|
|
|
519
|
-
class VectorStoreModel(
|
|
1011
|
+
class VectorStoreModel(IsDatabricksResource):
|
|
520
1012
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
521
1013
|
embedding_model: Optional[LLMModel] = None
|
|
522
1014
|
index: Optional[IndexModel] = None
|
|
@@ -530,13 +1022,13 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
|
|
|
530
1022
|
embedding_source_column: str
|
|
531
1023
|
|
|
532
1024
|
@model_validator(mode="after")
|
|
533
|
-
def set_default_embedding_model(self):
|
|
1025
|
+
def set_default_embedding_model(self) -> Self:
|
|
534
1026
|
if not self.embedding_model:
|
|
535
1027
|
self.embedding_model = LLMModel(name="databricks-gte-large-en")
|
|
536
1028
|
return self
|
|
537
1029
|
|
|
538
1030
|
@model_validator(mode="after")
|
|
539
|
-
def set_default_primary_key(self):
|
|
1031
|
+
def set_default_primary_key(self) -> Self:
|
|
540
1032
|
if self.primary_key is None:
|
|
541
1033
|
from dao_ai.providers.databricks import DatabricksProvider
|
|
542
1034
|
|
|
@@ -557,14 +1049,14 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
|
|
|
557
1049
|
return self
|
|
558
1050
|
|
|
559
1051
|
@model_validator(mode="after")
|
|
560
|
-
def set_default_index(self):
|
|
1052
|
+
def set_default_index(self) -> Self:
|
|
561
1053
|
if self.index is None:
|
|
562
1054
|
name: str = f"{self.source_table.name}_index"
|
|
563
1055
|
self.index = IndexModel(schema=self.source_table.schema_model, name=name)
|
|
564
1056
|
return self
|
|
565
1057
|
|
|
566
1058
|
@model_validator(mode="after")
|
|
567
|
-
def set_default_endpoint(self):
|
|
1059
|
+
def set_default_endpoint(self) -> Self:
|
|
568
1060
|
if self.endpoint is None:
|
|
569
1061
|
from dao_ai.providers.databricks import (
|
|
570
1062
|
DatabricksProvider,
|
|
@@ -615,62 +1107,7 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
|
|
|
615
1107
|
provider.create_vector_store(self)
|
|
616
1108
|
|
|
617
1109
|
|
|
618
|
-
class
|
|
619
|
-
model_config = ConfigDict()
|
|
620
|
-
schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
|
|
621
|
-
name: Optional[str] = None
|
|
622
|
-
|
|
623
|
-
@model_validator(mode="after")
|
|
624
|
-
def validate_name_or_schema_required(self) -> "FunctionModel":
|
|
625
|
-
if not self.name and not self.schema_model:
|
|
626
|
-
raise ValueError(
|
|
627
|
-
"Either 'name' or 'schema_model' must be provided for FunctionModel"
|
|
628
|
-
)
|
|
629
|
-
return self
|
|
630
|
-
|
|
631
|
-
@property
|
|
632
|
-
def full_name(self) -> str:
|
|
633
|
-
if self.schema_model:
|
|
634
|
-
name: str = ""
|
|
635
|
-
if self.name:
|
|
636
|
-
name = f".{self.name}"
|
|
637
|
-
return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
|
|
638
|
-
return self.name
|
|
639
|
-
|
|
640
|
-
def as_resources(self) -> Sequence[DatabricksResource]:
|
|
641
|
-
resources: list[DatabricksResource] = []
|
|
642
|
-
if self.name:
|
|
643
|
-
resources.append(
|
|
644
|
-
DatabricksFunction(
|
|
645
|
-
function_name=self.full_name,
|
|
646
|
-
on_behalf_of_user=self.on_behalf_of_user,
|
|
647
|
-
)
|
|
648
|
-
)
|
|
649
|
-
else:
|
|
650
|
-
w: WorkspaceClient = self.workspace_client
|
|
651
|
-
schema_full_name: str = self.schema_model.full_name
|
|
652
|
-
functions: Iterator[FunctionInfo] = w.functions.list(
|
|
653
|
-
catalog_name=self.schema_model.catalog_name,
|
|
654
|
-
schema_name=self.schema_model.schema_name,
|
|
655
|
-
)
|
|
656
|
-
resources.extend(
|
|
657
|
-
[
|
|
658
|
-
DatabricksFunction(
|
|
659
|
-
function_name=f"{schema_full_name}.{function.name}",
|
|
660
|
-
on_behalf_of_user=self.on_behalf_of_user,
|
|
661
|
-
)
|
|
662
|
-
for function in functions
|
|
663
|
-
]
|
|
664
|
-
)
|
|
665
|
-
|
|
666
|
-
return resources
|
|
667
|
-
|
|
668
|
-
@property
|
|
669
|
-
def api_scopes(self) -> Sequence[str]:
|
|
670
|
-
return ["sql.statement-execution"]
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
|
|
1110
|
+
class ConnectionModel(IsDatabricksResource, HasFullName):
|
|
674
1111
|
model_config = ConfigDict()
|
|
675
1112
|
name: str
|
|
676
1113
|
|
|
@@ -697,34 +1134,58 @@ class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
|
|
|
697
1134
|
]
|
|
698
1135
|
|
|
699
1136
|
|
|
700
|
-
class
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
1137
|
+
class DatabaseModel(IsDatabricksResource):
|
|
1138
|
+
"""
|
|
1139
|
+
Configuration for database connections supporting both Databricks Lakebase and standard PostgreSQL.
|
|
1140
|
+
|
|
1141
|
+
Authentication is inherited from IsDatabricksResource. Additionally supports:
|
|
1142
|
+
- user/password: For user-based database authentication
|
|
1143
|
+
|
|
1144
|
+
Connection Types (determined by fields provided):
|
|
1145
|
+
- Databricks Lakebase: Provide `instance_name` (authentication optional, supports ambient auth)
|
|
1146
|
+
- Standard PostgreSQL: Provide `host` (authentication required via user/password)
|
|
1147
|
+
|
|
1148
|
+
Note: `instance_name` and `host` are mutually exclusive. Provide one or the other.
|
|
1149
|
+
|
|
1150
|
+
Example Databricks Lakebase with Service Principal:
|
|
1151
|
+
```yaml
|
|
1152
|
+
databases:
|
|
1153
|
+
my_lakebase:
|
|
1154
|
+
name: my-database
|
|
1155
|
+
instance_name: my-lakebase-instance
|
|
1156
|
+
service_principal:
|
|
1157
|
+
client_id:
|
|
1158
|
+
env: SERVICE_PRINCIPAL_CLIENT_ID
|
|
1159
|
+
client_secret:
|
|
1160
|
+
scope: my-scope
|
|
1161
|
+
secret: sp-client-secret
|
|
1162
|
+
workspace_host:
|
|
1163
|
+
env: DATABRICKS_HOST
|
|
1164
|
+
```
|
|
1165
|
+
|
|
1166
|
+
Example Databricks Lakebase with Ambient Authentication:
|
|
1167
|
+
```yaml
|
|
1168
|
+
databases:
|
|
1169
|
+
my_lakebase:
|
|
1170
|
+
name: my-database
|
|
1171
|
+
instance_name: my-lakebase-instance
|
|
1172
|
+
on_behalf_of_user: true
|
|
1173
|
+
```
|
|
1174
|
+
|
|
1175
|
+
Example Standard PostgreSQL:
|
|
1176
|
+
```yaml
|
|
1177
|
+
databases:
|
|
1178
|
+
my_postgres:
|
|
1179
|
+
name: my-database
|
|
1180
|
+
host: my-postgres-host.example.com
|
|
1181
|
+
port: 5432
|
|
1182
|
+
database: my_db
|
|
1183
|
+
user: my_user
|
|
1184
|
+
password:
|
|
1185
|
+
env: PGPASSWORD
|
|
1186
|
+
```
|
|
1187
|
+
"""
|
|
726
1188
|
|
|
727
|
-
class DatabaseModel(BaseModel, IsDatabricksResource):
|
|
728
1189
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
729
1190
|
name: str
|
|
730
1191
|
instance_name: Optional[str] = None
|
|
@@ -737,80 +1198,137 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
|
|
|
737
1198
|
timeout_seconds: Optional[int] = 10
|
|
738
1199
|
capacity: Optional[Literal["CU_1", "CU_2"]] = "CU_2"
|
|
739
1200
|
node_count: Optional[int] = None
|
|
1201
|
+
# Database-specific auth (user identity for DB connection)
|
|
740
1202
|
user: Optional[AnyVariable] = None
|
|
741
1203
|
password: Optional[AnyVariable] = None
|
|
742
|
-
client_id: Optional[AnyVariable] = None
|
|
743
|
-
client_secret: Optional[AnyVariable] = None
|
|
744
|
-
workspace_host: Optional[AnyVariable] = None
|
|
745
1204
|
|
|
746
1205
|
@property
|
|
747
1206
|
def api_scopes(self) -> Sequence[str]:
|
|
748
|
-
return []
|
|
1207
|
+
return ["database.database-instances"]
|
|
1208
|
+
|
|
1209
|
+
@property
|
|
1210
|
+
def is_lakebase(self) -> bool:
|
|
1211
|
+
"""Returns True if this is a Databricks Lakebase connection (instance_name provided)."""
|
|
1212
|
+
return self.instance_name is not None
|
|
749
1213
|
|
|
750
1214
|
def as_resources(self) -> Sequence[DatabricksResource]:
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
1215
|
+
if self.is_lakebase:
|
|
1216
|
+
return [
|
|
1217
|
+
DatabricksLakebase(
|
|
1218
|
+
database_instance_name=self.instance_name,
|
|
1219
|
+
on_behalf_of_user=self.on_behalf_of_user,
|
|
1220
|
+
)
|
|
1221
|
+
]
|
|
1222
|
+
return []
|
|
757
1223
|
|
|
758
1224
|
@model_validator(mode="after")
|
|
759
|
-
def
|
|
760
|
-
|
|
761
|
-
self.instance_name = self.name
|
|
1225
|
+
def validate_connection_type(self) -> Self:
|
|
1226
|
+
"""Validate connection configuration based on type.
|
|
762
1227
|
|
|
1228
|
+
- If instance_name is provided: Databricks Lakebase connection
|
|
1229
|
+
(host is optional - will be fetched from API if not provided)
|
|
1230
|
+
- If only host is provided: Standard PostgreSQL connection
|
|
1231
|
+
(must not have instance_name)
|
|
1232
|
+
"""
|
|
1233
|
+
if not self.instance_name and not self.host:
|
|
1234
|
+
raise ValueError(
|
|
1235
|
+
"Either instance_name (Databricks Lakebase) or host (PostgreSQL) must be provided."
|
|
1236
|
+
)
|
|
763
1237
|
return self
|
|
764
1238
|
|
|
765
1239
|
@model_validator(mode="after")
|
|
766
|
-
def update_user(self):
|
|
767
|
-
if
|
|
1240
|
+
def update_user(self) -> Self:
|
|
1241
|
+
# Skip if using OBO (passive auth), explicit credentials, or explicit user
|
|
1242
|
+
if self.on_behalf_of_user or self.client_id or self.user or self.pat:
|
|
768
1243
|
return self
|
|
769
1244
|
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
1245
|
+
# For standard PostgreSQL, we need explicit user credentials
|
|
1246
|
+
# For Lakebase with no auth, ambient auth is allowed
|
|
1247
|
+
if not self.is_lakebase:
|
|
1248
|
+
# Standard PostgreSQL - try to determine current user for local development
|
|
1249
|
+
try:
|
|
1250
|
+
self.user = self.workspace_client.current_user.me().user_name
|
|
1251
|
+
except Exception as e:
|
|
1252
|
+
logger.warning(
|
|
1253
|
+
f"Could not determine current user for PostgreSQL database: {e}. "
|
|
1254
|
+
f"Please provide explicit user credentials."
|
|
1255
|
+
)
|
|
1256
|
+
else:
|
|
1257
|
+
# For Lakebase, try to determine current user but don't fail if we can't
|
|
1258
|
+
try:
|
|
1259
|
+
self.user = self.workspace_client.current_user.me().user_name
|
|
1260
|
+
except Exception:
|
|
1261
|
+
# If we can't determine user and no explicit auth, that's okay
|
|
1262
|
+
# for Lakebase with ambient auth - credentials will be injected at runtime
|
|
1263
|
+
pass
|
|
775
1264
|
|
|
776
1265
|
return self
|
|
777
1266
|
|
|
778
1267
|
@model_validator(mode="after")
|
|
779
|
-
def update_host(self):
|
|
1268
|
+
def update_host(self) -> Self:
|
|
780
1269
|
if self.host is not None:
|
|
781
1270
|
return self
|
|
782
1271
|
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
1272
|
+
# If instance_name is provided (Lakebase), try to fetch host from existing instance
|
|
1273
|
+
# This may fail for OBO/ambient auth during model logging (before deployment)
|
|
1274
|
+
if self.is_lakebase:
|
|
1275
|
+
try:
|
|
1276
|
+
existing_instance: DatabaseInstance = (
|
|
1277
|
+
self.workspace_client.database.get_database_instance(
|
|
1278
|
+
name=self.instance_name
|
|
1279
|
+
)
|
|
1280
|
+
)
|
|
1281
|
+
self.host = existing_instance.read_write_dns
|
|
1282
|
+
except Exception as e:
|
|
1283
|
+
# For Lakebase with OBO/ambient auth, we can't fetch at config time
|
|
1284
|
+
# The host will need to be provided explicitly or fetched at runtime
|
|
1285
|
+
if self.on_behalf_of_user:
|
|
1286
|
+
logger.debug(
|
|
1287
|
+
f"Could not fetch host for database {self.instance_name} "
|
|
1288
|
+
f"(Lakebase with OBO mode - will be resolved at runtime): {e}"
|
|
1289
|
+
)
|
|
1290
|
+
else:
|
|
1291
|
+
raise ValueError(
|
|
1292
|
+
f"Could not fetch host for database {self.instance_name}. "
|
|
1293
|
+
f"Please provide the 'host' explicitly or ensure the instance exists: {e}"
|
|
1294
|
+
)
|
|
789
1295
|
return self
|
|
790
1296
|
|
|
791
1297
|
@model_validator(mode="after")
|
|
792
|
-
def validate_auth_methods(self):
|
|
1298
|
+
def validate_auth_methods(self) -> Self:
|
|
793
1299
|
oauth_fields: Sequence[Any] = [
|
|
794
1300
|
self.workspace_host,
|
|
795
1301
|
self.client_id,
|
|
796
1302
|
self.client_secret,
|
|
797
1303
|
]
|
|
798
1304
|
has_oauth: bool = all(field is not None for field in oauth_fields)
|
|
1305
|
+
has_user_auth: bool = self.user is not None
|
|
1306
|
+
has_obo: bool = self.on_behalf_of_user is True
|
|
1307
|
+
has_pat: bool = self.pat is not None
|
|
799
1308
|
|
|
800
|
-
|
|
801
|
-
|
|
1309
|
+
# Count how many auth methods are configured
|
|
1310
|
+
auth_methods_count: int = sum([has_oauth, has_user_auth, has_obo, has_pat])
|
|
802
1311
|
|
|
803
|
-
if
|
|
1312
|
+
if auth_methods_count > 1:
|
|
804
1313
|
raise ValueError(
|
|
805
|
-
"Cannot
|
|
806
|
-
"Please provide
|
|
1314
|
+
"Cannot mix authentication methods. "
|
|
1315
|
+
"Please provide exactly one of: "
|
|
1316
|
+
"on_behalf_of_user=true (for passive auth in model serving), "
|
|
1317
|
+
"OAuth credentials (service_principal or client_id + client_secret + workspace_host), "
|
|
1318
|
+
"PAT (personal access token), "
|
|
1319
|
+
"or user credentials (user)."
|
|
807
1320
|
)
|
|
808
1321
|
|
|
809
|
-
|
|
1322
|
+
# For standard PostgreSQL (host-based), at least one auth method must be configured
|
|
1323
|
+
# For Lakebase (instance_name-based), auth is optional (supports ambient authentication)
|
|
1324
|
+
if not self.is_lakebase and auth_methods_count == 0:
|
|
810
1325
|
raise ValueError(
|
|
811
|
-
"
|
|
812
|
-
"
|
|
813
|
-
"
|
|
1326
|
+
"PostgreSQL databases require explicit authentication. "
|
|
1327
|
+
"Please provide one of: "
|
|
1328
|
+
"OAuth credentials (workspace_host, client_id, client_secret), "
|
|
1329
|
+
"service_principal with workspace_host, "
|
|
1330
|
+
"PAT (personal access token), "
|
|
1331
|
+
"or user credentials (user)."
|
|
814
1332
|
)
|
|
815
1333
|
|
|
816
1334
|
return self
|
|
@@ -821,38 +1339,76 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
|
|
|
821
1339
|
Get database connection parameters as a dictionary.
|
|
822
1340
|
|
|
823
1341
|
Returns a dict with connection parameters suitable for psycopg ConnectionPool.
|
|
824
|
-
|
|
825
|
-
|
|
1342
|
+
|
|
1343
|
+
For Lakebase: Uses Databricks-generated credentials (token-based auth).
|
|
1344
|
+
For standard PostgreSQL: Uses provided user/password credentials.
|
|
826
1345
|
"""
|
|
827
|
-
|
|
828
|
-
|
|
1346
|
+
import uuid as _uuid
|
|
1347
|
+
|
|
1348
|
+
from databricks.sdk.service.database import DatabaseCredential
|
|
829
1349
|
|
|
1350
|
+
host: str
|
|
1351
|
+
port: int
|
|
1352
|
+
database: str
|
|
830
1353
|
username: str | None = None
|
|
1354
|
+
password_value: str | None = None
|
|
1355
|
+
|
|
1356
|
+
# Resolve host - may need to fetch at runtime for OBO mode
|
|
1357
|
+
host_value: Any = self.host
|
|
1358
|
+
if host_value is None and self.is_lakebase and self.on_behalf_of_user:
|
|
1359
|
+
# Fetch host at runtime for OBO mode
|
|
1360
|
+
existing_instance: DatabaseInstance = (
|
|
1361
|
+
self.workspace_client.database.get_database_instance(
|
|
1362
|
+
name=self.instance_name
|
|
1363
|
+
)
|
|
1364
|
+
)
|
|
1365
|
+
host_value = existing_instance.read_write_dns
|
|
831
1366
|
|
|
832
|
-
if
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
1367
|
+
if host_value is None:
|
|
1368
|
+
instance_or_name = self.instance_name if self.is_lakebase else self.name
|
|
1369
|
+
raise ValueError(
|
|
1370
|
+
f"Database host not configured for {instance_or_name}. "
|
|
1371
|
+
"Please provide 'host' explicitly."
|
|
1372
|
+
)
|
|
836
1373
|
|
|
837
|
-
host
|
|
838
|
-
port
|
|
839
|
-
database
|
|
1374
|
+
host = value_of(host_value)
|
|
1375
|
+
port = value_of(self.port)
|
|
1376
|
+
database = value_of(self.database)
|
|
840
1377
|
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
client_secret
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
1378
|
+
if self.is_lakebase:
|
|
1379
|
+
# Lakebase: Use Databricks-generated credentials
|
|
1380
|
+
if self.client_id and self.client_secret and self.workspace_host:
|
|
1381
|
+
username = value_of(self.client_id)
|
|
1382
|
+
elif self.user:
|
|
1383
|
+
username = value_of(self.user)
|
|
1384
|
+
# For OBO mode, no username is needed - the token identity is used
|
|
1385
|
+
|
|
1386
|
+
# Generate Databricks database credential (token)
|
|
1387
|
+
w: WorkspaceClient = self.workspace_client
|
|
1388
|
+
cred: DatabaseCredential = w.database.generate_database_credential(
|
|
1389
|
+
request_id=str(_uuid.uuid4()),
|
|
1390
|
+
instance_names=[self.instance_name],
|
|
1391
|
+
)
|
|
1392
|
+
password_value = cred.token
|
|
1393
|
+
else:
|
|
1394
|
+
# Standard PostgreSQL: Use provided credentials
|
|
1395
|
+
if self.user:
|
|
1396
|
+
username = value_of(self.user)
|
|
1397
|
+
if self.password:
|
|
1398
|
+
password_value = value_of(self.password)
|
|
847
1399
|
|
|
848
|
-
|
|
1400
|
+
if not username or not password_value:
|
|
1401
|
+
raise ValueError(
|
|
1402
|
+
f"Standard PostgreSQL databases require both 'user' and 'password'. "
|
|
1403
|
+
f"Database: {self.name}"
|
|
1404
|
+
)
|
|
849
1405
|
|
|
850
1406
|
# Build connection parameters dictionary
|
|
851
1407
|
params: dict[str, Any] = {
|
|
852
1408
|
"dbname": database,
|
|
853
1409
|
"host": host,
|
|
854
1410
|
"port": port,
|
|
855
|
-
"password":
|
|
1411
|
+
"password": password_value,
|
|
856
1412
|
"sslmode": "require",
|
|
857
1413
|
}
|
|
858
1414
|
|
|
@@ -883,11 +1439,86 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
|
|
|
883
1439
|
def create(self, w: WorkspaceClient | None = None) -> None:
|
|
884
1440
|
from dao_ai.providers.databricks import DatabricksProvider
|
|
885
1441
|
|
|
886
|
-
|
|
1442
|
+
# Use provided workspace client or fall back to resource's own workspace_client
|
|
1443
|
+
if w is None:
|
|
1444
|
+
w = self.workspace_client
|
|
1445
|
+
provider: DatabricksProvider = DatabricksProvider(w=w)
|
|
887
1446
|
provider.create_lakebase(self)
|
|
888
1447
|
provider.create_lakebase_instance_role(self)
|
|
889
1448
|
|
|
890
1449
|
|
|
1450
|
+
class GenieLRUCacheParametersModel(BaseModel):
|
|
1451
|
+
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1452
|
+
capacity: int = 1000
|
|
1453
|
+
time_to_live_seconds: int | None = (
|
|
1454
|
+
60 * 60 * 24
|
|
1455
|
+
) # 1 day default, None or negative = never expires
|
|
1456
|
+
warehouse: WarehouseModel
|
|
1457
|
+
|
|
1458
|
+
|
|
1459
|
+
class GenieSemanticCacheParametersModel(BaseModel):
|
|
1460
|
+
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1461
|
+
time_to_live_seconds: int | None = (
|
|
1462
|
+
60 * 60 * 24
|
|
1463
|
+
) # 1 day default, None or negative = never expires
|
|
1464
|
+
similarity_threshold: float = 0.85 # Minimum similarity for question matching (L2 distance converted to 0-1 scale)
|
|
1465
|
+
context_similarity_threshold: float = 0.80 # Minimum similarity for context matching (L2 distance converted to 0-1 scale)
|
|
1466
|
+
question_weight: Optional[float] = (
|
|
1467
|
+
0.6 # Weight for question similarity in combined score (0-1). If not provided, computed as 1 - context_weight
|
|
1468
|
+
)
|
|
1469
|
+
context_weight: Optional[float] = (
|
|
1470
|
+
None # Weight for context similarity in combined score (0-1). If not provided, computed as 1 - question_weight
|
|
1471
|
+
)
|
|
1472
|
+
embedding_model: str | LLMModel = "databricks-gte-large-en"
|
|
1473
|
+
embedding_dims: int | None = None # Auto-detected if None
|
|
1474
|
+
database: DatabaseModel
|
|
1475
|
+
warehouse: WarehouseModel
|
|
1476
|
+
table_name: str = "genie_semantic_cache"
|
|
1477
|
+
context_window_size: int = 3 # Number of previous turns to include for context
|
|
1478
|
+
max_context_tokens: int = (
|
|
1479
|
+
2000 # Maximum context length to prevent extremely long embeddings
|
|
1480
|
+
)
|
|
1481
|
+
|
|
1482
|
+
@model_validator(mode="after")
|
|
1483
|
+
def compute_and_validate_weights(self) -> Self:
|
|
1484
|
+
"""
|
|
1485
|
+
Compute missing weight and validate that question_weight + context_weight = 1.0.
|
|
1486
|
+
|
|
1487
|
+
Either question_weight or context_weight (or both) can be provided.
|
|
1488
|
+
The missing one will be computed as 1.0 - provided_weight.
|
|
1489
|
+
If both are provided, they must sum to 1.0.
|
|
1490
|
+
"""
|
|
1491
|
+
if self.question_weight is None and self.context_weight is None:
|
|
1492
|
+
# Both missing - use defaults
|
|
1493
|
+
self.question_weight = 0.6
|
|
1494
|
+
self.context_weight = 0.4
|
|
1495
|
+
elif self.question_weight is None:
|
|
1496
|
+
# Compute question_weight from context_weight
|
|
1497
|
+
if not (0.0 <= self.context_weight <= 1.0):
|
|
1498
|
+
raise ValueError(
|
|
1499
|
+
f"context_weight must be between 0.0 and 1.0, got {self.context_weight}"
|
|
1500
|
+
)
|
|
1501
|
+
self.question_weight = 1.0 - self.context_weight
|
|
1502
|
+
elif self.context_weight is None:
|
|
1503
|
+
# Compute context_weight from question_weight
|
|
1504
|
+
if not (0.0 <= self.question_weight <= 1.0):
|
|
1505
|
+
raise ValueError(
|
|
1506
|
+
f"question_weight must be between 0.0 and 1.0, got {self.question_weight}"
|
|
1507
|
+
)
|
|
1508
|
+
self.context_weight = 1.0 - self.question_weight
|
|
1509
|
+
else:
|
|
1510
|
+
# Both provided - validate they sum to 1.0
|
|
1511
|
+
total_weight = self.question_weight + self.context_weight
|
|
1512
|
+
if not abs(total_weight - 1.0) < 0.0001: # Allow small floating point error
|
|
1513
|
+
raise ValueError(
|
|
1514
|
+
f"question_weight ({self.question_weight}) + context_weight ({self.context_weight}) "
|
|
1515
|
+
f"must equal 1.0 (got {total_weight}). These weights determine the relative importance "
|
|
1516
|
+
f"of question vs context similarity in the combined score."
|
|
1517
|
+
)
|
|
1518
|
+
|
|
1519
|
+
return self
|
|
1520
|
+
|
|
1521
|
+
|
|
891
1522
|
class SearchParametersModel(BaseModel):
|
|
892
1523
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
893
1524
|
num_results: Optional[int] = 10
|
|
@@ -936,8 +1567,8 @@ class RerankParametersModel(BaseModel):
|
|
|
936
1567
|
description="Number of documents to return after reranking. If None, uses search_parameters.num_results.",
|
|
937
1568
|
)
|
|
938
1569
|
cache_dir: Optional[str] = Field(
|
|
939
|
-
default="/
|
|
940
|
-
description="Directory to cache downloaded model weights.",
|
|
1570
|
+
default="~/.dao_ai/cache/flashrank",
|
|
1571
|
+
description="Directory to cache downloaded model weights. Supports tilde expansion (e.g., ~/.dao_ai).",
|
|
941
1572
|
)
|
|
942
1573
|
columns: Optional[list[str]] = Field(
|
|
943
1574
|
default_factory=list, description="Columns to rerank using DatabricksReranker"
|
|
@@ -957,14 +1588,14 @@ class RetrieverModel(BaseModel):
|
|
|
957
1588
|
)
|
|
958
1589
|
|
|
959
1590
|
@model_validator(mode="after")
|
|
960
|
-
def set_default_columns(self):
|
|
1591
|
+
def set_default_columns(self) -> Self:
|
|
961
1592
|
if not self.columns:
|
|
962
1593
|
columns: Sequence[str] = self.vector_store.columns
|
|
963
1594
|
self.columns = columns
|
|
964
1595
|
return self
|
|
965
1596
|
|
|
966
1597
|
@model_validator(mode="after")
|
|
967
|
-
def set_default_reranker(self):
|
|
1598
|
+
def set_default_reranker(self) -> Self:
|
|
968
1599
|
"""Convert bool to ReRankParametersModel with defaults."""
|
|
969
1600
|
if isinstance(self.rerank, bool) and self.rerank:
|
|
970
1601
|
self.rerank = RerankParametersModel()
|
|
@@ -978,28 +1609,47 @@ class FunctionType(str, Enum):
|
|
|
978
1609
|
MCP = "mcp"
|
|
979
1610
|
|
|
980
1611
|
|
|
981
|
-
class
|
|
982
|
-
"""
|
|
1612
|
+
class HumanInTheLoopModel(BaseModel):
|
|
1613
|
+
"""
|
|
1614
|
+
Configuration for Human-in-the-Loop tool approval.
|
|
983
1615
|
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
RESPONSE = "response"
|
|
987
|
-
DECLINE = "decline"
|
|
1616
|
+
This model configures when and how tools require human approval before execution.
|
|
1617
|
+
It maps to LangChain's HumanInTheLoopMiddleware.
|
|
988
1618
|
|
|
1619
|
+
LangChain supports three decision types:
|
|
1620
|
+
- "approve": Execute tool with original arguments
|
|
1621
|
+
- "edit": Modify arguments before execution
|
|
1622
|
+
- "reject": Skip execution with optional feedback message
|
|
1623
|
+
"""
|
|
989
1624
|
|
|
990
|
-
class HumanInTheLoopModel(BaseModel):
|
|
991
1625
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
|
|
1626
|
+
|
|
1627
|
+
review_prompt: Optional[str] = Field(
|
|
1628
|
+
default=None,
|
|
1629
|
+
description="Message shown to the reviewer when approval is requested",
|
|
1630
|
+
)
|
|
1631
|
+
|
|
1632
|
+
allowed_decisions: list[Literal["approve", "edit", "reject"]] = Field(
|
|
1633
|
+
default_factory=lambda: ["approve", "edit", "reject"],
|
|
1634
|
+
description="List of allowed decision types for this tool",
|
|
1000
1635
|
)
|
|
1001
|
-
|
|
1002
|
-
|
|
1636
|
+
|
|
1637
|
+
@model_validator(mode="after")
|
|
1638
|
+
def validate_and_normalize_decisions(self) -> Self:
|
|
1639
|
+
"""Validate and normalize allowed decisions."""
|
|
1640
|
+
if not self.allowed_decisions:
|
|
1641
|
+
raise ValueError("At least one decision type must be allowed")
|
|
1642
|
+
|
|
1643
|
+
# Remove duplicates while preserving order
|
|
1644
|
+
seen = set()
|
|
1645
|
+
unique_decisions = []
|
|
1646
|
+
for decision in self.allowed_decisions:
|
|
1647
|
+
if decision not in seen:
|
|
1648
|
+
seen.add(decision)
|
|
1649
|
+
unique_decisions.append(decision)
|
|
1650
|
+
self.allowed_decisions = unique_decisions
|
|
1651
|
+
|
|
1652
|
+
return self
|
|
1003
1653
|
|
|
1004
1654
|
|
|
1005
1655
|
class BaseFunctionModel(ABC, BaseModel):
|
|
@@ -1008,7 +1658,6 @@ class BaseFunctionModel(ABC, BaseModel):
|
|
|
1008
1658
|
discriminator="type",
|
|
1009
1659
|
)
|
|
1010
1660
|
type: FunctionType
|
|
1011
|
-
name: str
|
|
1012
1661
|
human_in_the_loop: Optional[HumanInTheLoopModel] = None
|
|
1013
1662
|
|
|
1014
1663
|
@abstractmethod
|
|
@@ -1025,6 +1674,7 @@ class BaseFunctionModel(ABC, BaseModel):
|
|
|
1025
1674
|
class PythonFunctionModel(BaseFunctionModel, HasFullName):
|
|
1026
1675
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1027
1676
|
type: Literal[FunctionType.PYTHON] = FunctionType.PYTHON
|
|
1677
|
+
name: str
|
|
1028
1678
|
|
|
1029
1679
|
@property
|
|
1030
1680
|
def full_name(self) -> str:
|
|
@@ -1038,8 +1688,9 @@ class PythonFunctionModel(BaseFunctionModel, HasFullName):
|
|
|
1038
1688
|
|
|
1039
1689
|
class FactoryFunctionModel(BaseFunctionModel, HasFullName):
|
|
1040
1690
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1041
|
-
args: Optional[dict[str, Any]] = Field(default_factory=dict)
|
|
1042
1691
|
type: Literal[FunctionType.FACTORY] = FunctionType.FACTORY
|
|
1692
|
+
name: str
|
|
1693
|
+
args: Optional[dict[str, Any]] = Field(default_factory=dict)
|
|
1043
1694
|
|
|
1044
1695
|
@property
|
|
1045
1696
|
def full_name(self) -> str:
|
|
@@ -1051,7 +1702,7 @@ class FactoryFunctionModel(BaseFunctionModel, HasFullName):
|
|
|
1051
1702
|
return [create_factory_tool(self, **kwargs)]
|
|
1052
1703
|
|
|
1053
1704
|
@model_validator(mode="after")
|
|
1054
|
-
def update_args(self):
|
|
1705
|
+
def update_args(self) -> Self:
|
|
1055
1706
|
for key, value in self.args.items():
|
|
1056
1707
|
self.args[key] = value_of(value)
|
|
1057
1708
|
return self
|
|
@@ -1062,7 +1713,16 @@ class TransportType(str, Enum):
|
|
|
1062
1713
|
STDIO = "stdio"
|
|
1063
1714
|
|
|
1064
1715
|
|
|
1065
|
-
class McpFunctionModel(BaseFunctionModel,
|
|
1716
|
+
class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
|
|
1717
|
+
"""
|
|
1718
|
+
MCP Function Model with authentication inherited from IsDatabricksResource.
|
|
1719
|
+
|
|
1720
|
+
Authentication for MCP connections uses the same options as other resources:
|
|
1721
|
+
- Service Principal (client_id + client_secret + workspace_host)
|
|
1722
|
+
- PAT (pat + workspace_host)
|
|
1723
|
+
- OBO (on_behalf_of_user)
|
|
1724
|
+
"""
|
|
1725
|
+
|
|
1066
1726
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1067
1727
|
type: Literal[FunctionType.MCP] = FunctionType.MCP
|
|
1068
1728
|
transport: TransportType = TransportType.STREAMABLE_HTTP
|
|
@@ -1070,10 +1730,7 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
|
|
|
1070
1730
|
url: Optional[AnyVariable] = None
|
|
1071
1731
|
headers: dict[str, AnyVariable] = Field(default_factory=dict)
|
|
1072
1732
|
args: list[str] = Field(default_factory=list)
|
|
1073
|
-
|
|
1074
|
-
client_id: Optional[AnyVariable] = None
|
|
1075
|
-
client_secret: Optional[AnyVariable] = None
|
|
1076
|
-
workspace_host: Optional[AnyVariable] = None
|
|
1733
|
+
# MCP-specific fields
|
|
1077
1734
|
connection: Optional[ConnectionModel] = None
|
|
1078
1735
|
functions: Optional[SchemaModel] = None
|
|
1079
1736
|
genie_room: Optional[GenieRoomModel] = None
|
|
@@ -1081,35 +1738,55 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
|
|
|
1081
1738
|
vector_search: Optional[VectorStoreModel] = None
|
|
1082
1739
|
|
|
1083
1740
|
@property
|
|
1084
|
-
def
|
|
1085
|
-
|
|
1741
|
+
def api_scopes(self) -> Sequence[str]:
|
|
1742
|
+
"""API scopes for MCP connections."""
|
|
1743
|
+
return [
|
|
1744
|
+
"serving.serving-endpoints",
|
|
1745
|
+
"mcp.genie",
|
|
1746
|
+
"mcp.functions",
|
|
1747
|
+
"mcp.vectorsearch",
|
|
1748
|
+
"mcp.external",
|
|
1749
|
+
]
|
|
1750
|
+
|
|
1751
|
+
def as_resources(self) -> Sequence[DatabricksResource]:
|
|
1752
|
+
"""MCP functions don't declare static resources."""
|
|
1753
|
+
return []
|
|
1086
1754
|
|
|
1087
1755
|
def _get_workspace_host(self) -> str:
|
|
1088
1756
|
"""
|
|
1089
1757
|
Get the workspace host, either from config or from workspace client.
|
|
1090
1758
|
|
|
1091
1759
|
If connection is provided, uses its workspace client.
|
|
1092
|
-
Otherwise, falls back to
|
|
1760
|
+
Otherwise, falls back to the default Databricks host.
|
|
1093
1761
|
|
|
1094
1762
|
Returns:
|
|
1095
|
-
str: The workspace host URL without trailing slash
|
|
1763
|
+
str: The workspace host URL with https:// scheme and without trailing slash
|
|
1096
1764
|
"""
|
|
1097
|
-
from
|
|
1765
|
+
from dao_ai.utils import get_default_databricks_host, normalize_host
|
|
1098
1766
|
|
|
1099
1767
|
# Try to get workspace_host from config
|
|
1100
1768
|
workspace_host: str | None = (
|
|
1101
|
-
value_of(self.workspace_host)
|
|
1769
|
+
normalize_host(value_of(self.workspace_host))
|
|
1770
|
+
if self.workspace_host
|
|
1771
|
+
else None
|
|
1102
1772
|
)
|
|
1103
1773
|
|
|
1104
1774
|
# If no workspace_host in config, get it from workspace client
|
|
1105
1775
|
if not workspace_host:
|
|
1106
1776
|
# Use connection's workspace client if available
|
|
1107
1777
|
if self.connection:
|
|
1108
|
-
workspace_host =
|
|
1778
|
+
workspace_host = normalize_host(
|
|
1779
|
+
self.connection.workspace_client.config.host
|
|
1780
|
+
)
|
|
1109
1781
|
else:
|
|
1110
|
-
#
|
|
1111
|
-
|
|
1112
|
-
|
|
1782
|
+
# get_default_databricks_host already normalizes the host
|
|
1783
|
+
workspace_host = get_default_databricks_host()
|
|
1784
|
+
|
|
1785
|
+
if not workspace_host:
|
|
1786
|
+
raise ValueError(
|
|
1787
|
+
"Could not determine workspace host. "
|
|
1788
|
+
"Please set workspace_host in config or DATABRICKS_HOST environment variable."
|
|
1789
|
+
)
|
|
1113
1790
|
|
|
1114
1791
|
# Remove trailing slash
|
|
1115
1792
|
return workspace_host.rstrip("/")
|
|
@@ -1234,74 +1911,132 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
|
|
|
1234
1911
|
self.headers[key] = value_of(value)
|
|
1235
1912
|
return self
|
|
1236
1913
|
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
oauth_fields: Sequence[Any] = [
|
|
1240
|
-
self.client_id,
|
|
1241
|
-
self.client_secret,
|
|
1242
|
-
]
|
|
1243
|
-
has_oauth: bool = all(field is not None for field in oauth_fields)
|
|
1244
|
-
|
|
1245
|
-
pat_fields: Sequence[Any] = [self.pat]
|
|
1246
|
-
has_user_auth: bool = all(field is not None for field in pat_fields)
|
|
1914
|
+
def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
|
|
1915
|
+
from dao_ai.tools import create_mcp_tools
|
|
1247
1916
|
|
|
1248
|
-
|
|
1249
|
-
raise ValueError(
|
|
1250
|
-
"Cannot use both OAuth and user authentication methods. "
|
|
1251
|
-
"Please provide either OAuth credentials or user credentials."
|
|
1252
|
-
)
|
|
1917
|
+
return create_mcp_tools(self)
|
|
1253
1918
|
|
|
1254
|
-
# Note: workspace_host is optional - it will be derived from workspace client if not provided
|
|
1255
1919
|
|
|
1256
|
-
|
|
1920
|
+
class UnityCatalogFunctionModel(BaseFunctionModel):
|
|
1921
|
+
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1922
|
+
type: Literal[FunctionType.UNITY_CATALOG] = FunctionType.UNITY_CATALOG
|
|
1923
|
+
resource: FunctionModel
|
|
1924
|
+
partial_args: Optional[dict[str, AnyVariable]] = Field(default_factory=dict)
|
|
1257
1925
|
|
|
1258
1926
|
def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
|
|
1259
|
-
from dao_ai.tools import
|
|
1927
|
+
from dao_ai.tools import create_uc_tools
|
|
1260
1928
|
|
|
1261
|
-
return
|
|
1929
|
+
return create_uc_tools(self)
|
|
1930
|
+
|
|
1931
|
+
|
|
1932
|
+
AnyTool: TypeAlias = (
|
|
1933
|
+
Union[
|
|
1934
|
+
PythonFunctionModel,
|
|
1935
|
+
FactoryFunctionModel,
|
|
1936
|
+
UnityCatalogFunctionModel,
|
|
1937
|
+
McpFunctionModel,
|
|
1938
|
+
]
|
|
1939
|
+
| str
|
|
1940
|
+
)
|
|
1941
|
+
|
|
1942
|
+
|
|
1943
|
+
class ToolModel(BaseModel):
|
|
1944
|
+
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1945
|
+
name: str
|
|
1946
|
+
function: AnyTool
|
|
1262
1947
|
|
|
1263
1948
|
|
|
1264
|
-
class
|
|
1949
|
+
class PromptModel(BaseModel, HasFullName):
|
|
1265
1950
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1266
1951
|
schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
|
|
1267
|
-
|
|
1268
|
-
|
|
1952
|
+
name: str
|
|
1953
|
+
description: Optional[str] = None
|
|
1954
|
+
default_template: Optional[str] = None
|
|
1955
|
+
alias: Optional[str] = None
|
|
1956
|
+
version: Optional[int] = None
|
|
1957
|
+
tags: Optional[dict[str, Any]] = Field(default_factory=dict)
|
|
1958
|
+
auto_register: bool = Field(
|
|
1959
|
+
default=False,
|
|
1960
|
+
description="Whether to automatically register the default_template to the prompt registry. "
|
|
1961
|
+
"If False, the prompt will only be loaded from the registry (never created/updated). "
|
|
1962
|
+
"Defaults to True for backward compatibility.",
|
|
1963
|
+
)
|
|
1964
|
+
|
|
1965
|
+
@property
|
|
1966
|
+
def template(self) -> str:
|
|
1967
|
+
from dao_ai.providers.databricks import DatabricksProvider
|
|
1968
|
+
|
|
1969
|
+
provider: DatabricksProvider = DatabricksProvider()
|
|
1970
|
+
prompt_version = provider.get_prompt(self)
|
|
1971
|
+
return prompt_version.to_single_brace_format()
|
|
1972
|
+
|
|
1973
|
+
@property
|
|
1974
|
+
def full_name(self) -> str:
|
|
1975
|
+
prompt_name: str = self.name
|
|
1976
|
+
if self.schema_model:
|
|
1977
|
+
prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
|
|
1978
|
+
return prompt_name
|
|
1269
1979
|
|
|
1270
1980
|
@property
|
|
1271
|
-
def
|
|
1272
|
-
|
|
1273
|
-
return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}.{self.name}"
|
|
1274
|
-
return self.name
|
|
1981
|
+
def uri(self) -> str:
|
|
1982
|
+
prompt_uri: str = f"prompts:/{self.full_name}"
|
|
1275
1983
|
|
|
1276
|
-
|
|
1277
|
-
|
|
1984
|
+
if self.alias:
|
|
1985
|
+
prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
|
|
1986
|
+
elif self.version:
|
|
1987
|
+
prompt_uri = f"prompts:/{self.full_name}/{self.version}"
|
|
1988
|
+
else:
|
|
1989
|
+
prompt_uri = f"prompts:/{self.full_name}@latest"
|
|
1278
1990
|
|
|
1279
|
-
return
|
|
1991
|
+
return prompt_uri
|
|
1280
1992
|
|
|
1993
|
+
def as_prompt(self) -> PromptVersion:
|
|
1994
|
+
prompt_version: PromptVersion = load_prompt(self.uri)
|
|
1995
|
+
return prompt_version
|
|
1281
1996
|
|
|
1282
|
-
|
|
1283
|
-
|
|
1284
|
-
|
|
1285
|
-
|
|
1286
|
-
|
|
1287
|
-
McpFunctionModel,
|
|
1288
|
-
]
|
|
1289
|
-
| str
|
|
1290
|
-
)
|
|
1997
|
+
@model_validator(mode="after")
|
|
1998
|
+
def validate_mutually_exclusive(self) -> Self:
|
|
1999
|
+
if self.alias and self.version:
|
|
2000
|
+
raise ValueError("Cannot specify both alias and version")
|
|
2001
|
+
return self
|
|
1291
2002
|
|
|
1292
2003
|
|
|
1293
|
-
class
|
|
2004
|
+
class GuardrailModel(BaseModel):
|
|
1294
2005
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1295
2006
|
name: str
|
|
1296
|
-
|
|
2007
|
+
model: str | LLMModel
|
|
2008
|
+
prompt: str | PromptModel
|
|
2009
|
+
num_retries: Optional[int] = 3
|
|
2010
|
+
|
|
2011
|
+
@model_validator(mode="after")
|
|
2012
|
+
def validate_llm_model(self) -> Self:
|
|
2013
|
+
if isinstance(self.model, str):
|
|
2014
|
+
self.model = LLMModel(name=self.model)
|
|
2015
|
+
return self
|
|
1297
2016
|
|
|
1298
2017
|
|
|
1299
|
-
class
|
|
2018
|
+
class MiddlewareModel(BaseModel):
|
|
2019
|
+
"""Configuration for middleware that can be applied to agents.
|
|
2020
|
+
|
|
2021
|
+
Middleware is defined at the AppConfig level and can be referenced by name
|
|
2022
|
+
in agent configurations using YAML anchors for reusability.
|
|
2023
|
+
"""
|
|
2024
|
+
|
|
1300
2025
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1301
|
-
name: str
|
|
1302
|
-
|
|
1303
|
-
|
|
1304
|
-
|
|
2026
|
+
name: str = Field(
|
|
2027
|
+
description="Fully qualified name of the middleware factory function"
|
|
2028
|
+
)
|
|
2029
|
+
args: dict[str, Any] = Field(
|
|
2030
|
+
default_factory=dict,
|
|
2031
|
+
description="Arguments to pass to the middleware factory function",
|
|
2032
|
+
)
|
|
2033
|
+
|
|
2034
|
+
@model_validator(mode="after")
|
|
2035
|
+
def resolve_args(self) -> Self:
|
|
2036
|
+
"""Resolve any variable references in args."""
|
|
2037
|
+
for key, value in self.args.items():
|
|
2038
|
+
self.args[key] = value_of(value)
|
|
2039
|
+
return self
|
|
1305
2040
|
|
|
1306
2041
|
|
|
1307
2042
|
class StorageType(str, Enum):
|
|
@@ -1312,14 +2047,12 @@ class StorageType(str, Enum):
|
|
|
1312
2047
|
class CheckpointerModel(BaseModel):
|
|
1313
2048
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1314
2049
|
name: str
|
|
1315
|
-
type: Optional[StorageType] = StorageType.MEMORY
|
|
1316
2050
|
database: Optional[DatabaseModel] = None
|
|
1317
2051
|
|
|
1318
|
-
@
|
|
1319
|
-
def
|
|
1320
|
-
|
|
1321
|
-
|
|
1322
|
-
return self
|
|
2052
|
+
@property
|
|
2053
|
+
def storage_type(self) -> StorageType:
|
|
2054
|
+
"""Infer storage type from database presence."""
|
|
2055
|
+
return StorageType.POSTGRES if self.database else StorageType.MEMORY
|
|
1323
2056
|
|
|
1324
2057
|
def as_checkpointer(self) -> BaseCheckpointSaver:
|
|
1325
2058
|
from dao_ai.memory import CheckpointManager
|
|
@@ -1335,16 +2068,14 @@ class StoreModel(BaseModel):
|
|
|
1335
2068
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1336
2069
|
name: str
|
|
1337
2070
|
embedding_model: Optional[LLMModel] = None
|
|
1338
|
-
type: Optional[StorageType] = StorageType.MEMORY
|
|
1339
2071
|
dims: Optional[int] = 1536
|
|
1340
2072
|
database: Optional[DatabaseModel] = None
|
|
1341
2073
|
namespace: Optional[str] = None
|
|
1342
2074
|
|
|
1343
|
-
@
|
|
1344
|
-
def
|
|
1345
|
-
|
|
1346
|
-
|
|
1347
|
-
return self
|
|
2075
|
+
@property
|
|
2076
|
+
def storage_type(self) -> StorageType:
|
|
2077
|
+
"""Infer storage type from database presence."""
|
|
2078
|
+
return StorageType.POSTGRES if self.database else StorageType.MEMORY
|
|
1348
2079
|
|
|
1349
2080
|
def as_store(self) -> BaseStore:
|
|
1350
2081
|
from dao_ai.memory import StoreManager
|
|
@@ -1362,56 +2093,158 @@ class MemoryModel(BaseModel):
|
|
|
1362
2093
|
FunctionHook: TypeAlias = PythonFunctionModel | FactoryFunctionModel | str
|
|
1363
2094
|
|
|
1364
2095
|
|
|
1365
|
-
class
|
|
2096
|
+
class ResponseFormatModel(BaseModel):
|
|
2097
|
+
"""
|
|
2098
|
+
Configuration for structured response formats.
|
|
2099
|
+
|
|
2100
|
+
The response_schema field accepts either a type or a string:
|
|
2101
|
+
- Type (Pydantic model, dataclass, etc.): Used directly for structured output
|
|
2102
|
+
- String: First attempts to load as a fully qualified type name, falls back to JSON schema string
|
|
2103
|
+
|
|
2104
|
+
This unified approach simplifies the API while maintaining flexibility.
|
|
2105
|
+
"""
|
|
2106
|
+
|
|
1366
2107
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1367
|
-
|
|
1368
|
-
|
|
1369
|
-
|
|
1370
|
-
|
|
1371
|
-
|
|
1372
|
-
|
|
1373
|
-
|
|
2108
|
+
use_tool: Optional[bool] = Field(
|
|
2109
|
+
default=None,
|
|
2110
|
+
description=(
|
|
2111
|
+
"Strategy for structured output: "
|
|
2112
|
+
"None (default) = auto-detect from model capabilities, "
|
|
2113
|
+
"False = force ProviderStrategy (native), "
|
|
2114
|
+
"True = force ToolStrategy (function calling)"
|
|
2115
|
+
),
|
|
2116
|
+
)
|
|
2117
|
+
response_schema: Optional[str | type] = Field(
|
|
2118
|
+
default=None,
|
|
2119
|
+
description="Type or string for response format. String attempts FQN import, falls back to JSON schema.",
|
|
2120
|
+
)
|
|
1374
2121
|
|
|
1375
|
-
|
|
1376
|
-
|
|
1377
|
-
|
|
2122
|
+
def as_strategy(self) -> ProviderStrategy | ToolStrategy:
|
|
2123
|
+
"""
|
|
2124
|
+
Convert response_schema to appropriate LangChain strategy.
|
|
1378
2125
|
|
|
1379
|
-
|
|
1380
|
-
|
|
1381
|
-
|
|
2126
|
+
Returns:
|
|
2127
|
+
- None if no response_schema configured
|
|
2128
|
+
- Raw schema/type for auto-detection (when use_tool=None)
|
|
2129
|
+
- ToolStrategy wrapping the schema (when use_tool=True)
|
|
2130
|
+
- ProviderStrategy wrapping the schema (when use_tool=False)
|
|
1382
2131
|
|
|
1383
|
-
|
|
1384
|
-
|
|
1385
|
-
|
|
1386
|
-
if self.schema_model:
|
|
1387
|
-
prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
|
|
1388
|
-
return prompt_name
|
|
2132
|
+
Raises:
|
|
2133
|
+
ValueError: If response_schema is a JSON schema string that cannot be parsed
|
|
2134
|
+
"""
|
|
1389
2135
|
|
|
1390
|
-
|
|
1391
|
-
|
|
1392
|
-
prompt_uri: str = f"prompts:/{self.full_name}"
|
|
2136
|
+
if self.response_schema is None:
|
|
2137
|
+
return None
|
|
1393
2138
|
|
|
1394
|
-
|
|
1395
|
-
prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
|
|
1396
|
-
elif self.version:
|
|
1397
|
-
prompt_uri = f"prompts:/{self.full_name}/{self.version}"
|
|
1398
|
-
else:
|
|
1399
|
-
prompt_uri = f"prompts:/{self.full_name}@latest"
|
|
2139
|
+
schema = self.response_schema
|
|
1400
2140
|
|
|
1401
|
-
|
|
2141
|
+
# Handle type schemas (Pydantic, dataclass, etc.)
|
|
2142
|
+
if self.is_type_schema:
|
|
2143
|
+
if self.use_tool is None:
|
|
2144
|
+
# Auto-detect: Pass schema directly, let LangChain decide
|
|
2145
|
+
return schema
|
|
2146
|
+
elif self.use_tool is True:
|
|
2147
|
+
# Force ToolStrategy (function calling)
|
|
2148
|
+
return ToolStrategy(schema)
|
|
2149
|
+
else: # use_tool is False
|
|
2150
|
+
# Force ProviderStrategy (native structured output)
|
|
2151
|
+
return ProviderStrategy(schema)
|
|
1402
2152
|
|
|
1403
|
-
|
|
1404
|
-
|
|
1405
|
-
|
|
2153
|
+
# Handle JSON schema strings
|
|
2154
|
+
elif self.is_json_schema:
|
|
2155
|
+
import json
|
|
2156
|
+
|
|
2157
|
+
try:
|
|
2158
|
+
schema_dict = json.loads(schema)
|
|
2159
|
+
except json.JSONDecodeError as e:
|
|
2160
|
+
raise ValueError(f"Invalid JSON schema string: {e}") from e
|
|
2161
|
+
|
|
2162
|
+
# Apply same use_tool logic as type schemas
|
|
2163
|
+
if self.use_tool is None:
|
|
2164
|
+
# Auto-detect
|
|
2165
|
+
return schema_dict
|
|
2166
|
+
elif self.use_tool is True:
|
|
2167
|
+
# Force ToolStrategy
|
|
2168
|
+
return ToolStrategy(schema_dict)
|
|
2169
|
+
else: # use_tool is False
|
|
2170
|
+
# Force ProviderStrategy
|
|
2171
|
+
return ProviderStrategy(schema_dict)
|
|
2172
|
+
|
|
2173
|
+
return None
|
|
1406
2174
|
|
|
1407
2175
|
@model_validator(mode="after")
|
|
1408
|
-
def
|
|
1409
|
-
|
|
1410
|
-
|
|
1411
|
-
|
|
2176
|
+
def validate_response_schema(self) -> Self:
|
|
2177
|
+
"""
|
|
2178
|
+
Validate and convert response_schema.
|
|
2179
|
+
|
|
2180
|
+
Processing logic:
|
|
2181
|
+
1. If None: no response format specified
|
|
2182
|
+
2. If type: use directly as structured output type
|
|
2183
|
+
3. If str: try to load as FQN using type_from_fqn
|
|
2184
|
+
- Success: response_schema becomes the loaded type
|
|
2185
|
+
- Failure: keep as string (treated as JSON schema)
|
|
2186
|
+
|
|
2187
|
+
After validation, response_schema is one of:
|
|
2188
|
+
- None (no schema)
|
|
2189
|
+
- type (use for structured output)
|
|
2190
|
+
- str (JSON schema)
|
|
2191
|
+
|
|
2192
|
+
Returns:
|
|
2193
|
+
Self with validated response_schema
|
|
2194
|
+
"""
|
|
2195
|
+
if self.response_schema is None:
|
|
2196
|
+
return self
|
|
2197
|
+
|
|
2198
|
+
# If already a type, return
|
|
2199
|
+
if isinstance(self.response_schema, type):
|
|
2200
|
+
return self
|
|
2201
|
+
|
|
2202
|
+
# If it's a string, try to load as type, fallback to json_schema
|
|
2203
|
+
if isinstance(self.response_schema, str):
|
|
2204
|
+
from dao_ai.utils import type_from_fqn
|
|
2205
|
+
|
|
2206
|
+
try:
|
|
2207
|
+
resolved_type = type_from_fqn(self.response_schema)
|
|
2208
|
+
self.response_schema = resolved_type
|
|
2209
|
+
logger.debug(
|
|
2210
|
+
f"Resolved response_schema string to type: {resolved_type}"
|
|
2211
|
+
)
|
|
2212
|
+
return self
|
|
2213
|
+
except (ValueError, ImportError, AttributeError, TypeError) as e:
|
|
2214
|
+
# Keep as string - it's a JSON schema
|
|
2215
|
+
logger.debug(
|
|
2216
|
+
f"Could not resolve '{self.response_schema}' as type: {e}. "
|
|
2217
|
+
f"Treating as JSON schema string."
|
|
2218
|
+
)
|
|
2219
|
+
return self
|
|
2220
|
+
|
|
2221
|
+
# Invalid type
|
|
2222
|
+
raise ValueError(
|
|
2223
|
+
f"response_schema must be None, type, or str, got {type(self.response_schema)}"
|
|
2224
|
+
)
|
|
2225
|
+
|
|
2226
|
+
@property
|
|
2227
|
+
def is_type_schema(self) -> bool:
|
|
2228
|
+
"""Returns True if response_schema is a type (not JSON schema string)."""
|
|
2229
|
+
return isinstance(self.response_schema, type)
|
|
2230
|
+
|
|
2231
|
+
@property
|
|
2232
|
+
def is_json_schema(self) -> bool:
|
|
2233
|
+
"""Returns True if response_schema is a JSON schema string (not a type)."""
|
|
2234
|
+
return isinstance(self.response_schema, str)
|
|
1412
2235
|
|
|
1413
2236
|
|
|
1414
2237
|
class AgentModel(BaseModel):
|
|
2238
|
+
"""
|
|
2239
|
+
Configuration model for an agent in the DAO AI framework.
|
|
2240
|
+
|
|
2241
|
+
Agents combine an LLM with tools and middleware to create systems that can
|
|
2242
|
+
reason about tasks, decide which tools to use, and iteratively work towards solutions.
|
|
2243
|
+
|
|
2244
|
+
Middleware replaces the previous pre_agent_hook and post_agent_hook patterns,
|
|
2245
|
+
providing a more flexible and composable way to customize agent behavior.
|
|
2246
|
+
"""
|
|
2247
|
+
|
|
1415
2248
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1416
2249
|
name: str
|
|
1417
2250
|
description: Optional[str] = None
|
|
@@ -1420,9 +2253,43 @@ class AgentModel(BaseModel):
|
|
|
1420
2253
|
guardrails: list[GuardrailModel] = Field(default_factory=list)
|
|
1421
2254
|
prompt: Optional[str | PromptModel] = None
|
|
1422
2255
|
handoff_prompt: Optional[str] = None
|
|
1423
|
-
|
|
1424
|
-
|
|
1425
|
-
|
|
2256
|
+
middleware: list[MiddlewareModel] = Field(
|
|
2257
|
+
default_factory=list,
|
|
2258
|
+
description="List of middleware to apply to this agent",
|
|
2259
|
+
)
|
|
2260
|
+
response_format: Optional[ResponseFormatModel | type | str] = None
|
|
2261
|
+
|
|
2262
|
+
@model_validator(mode="after")
|
|
2263
|
+
def validate_response_format(self) -> Self:
|
|
2264
|
+
"""
|
|
2265
|
+
Validate and normalize response_format.
|
|
2266
|
+
|
|
2267
|
+
Accepts:
|
|
2268
|
+
- None (no response format)
|
|
2269
|
+
- ResponseFormatModel (already validated)
|
|
2270
|
+
- type (Pydantic model, dataclass, etc.) - converts to ResponseFormatModel
|
|
2271
|
+
- str (FQN or json_schema) - converts to ResponseFormatModel (smart fallback)
|
|
2272
|
+
|
|
2273
|
+
ResponseFormatModel handles the logic of trying FQN import and falling back to JSON schema.
|
|
2274
|
+
"""
|
|
2275
|
+
if self.response_format is None or isinstance(
|
|
2276
|
+
self.response_format, ResponseFormatModel
|
|
2277
|
+
):
|
|
2278
|
+
return self
|
|
2279
|
+
|
|
2280
|
+
# Convert type or str to ResponseFormatModel
|
|
2281
|
+
# ResponseFormatModel's validator will handle the smart type loading and fallback
|
|
2282
|
+
if isinstance(self.response_format, (type, str)):
|
|
2283
|
+
self.response_format = ResponseFormatModel(
|
|
2284
|
+
response_schema=self.response_format
|
|
2285
|
+
)
|
|
2286
|
+
return self
|
|
2287
|
+
|
|
2288
|
+
# Invalid type
|
|
2289
|
+
raise ValueError(
|
|
2290
|
+
f"response_format must be None, ResponseFormatModel, type, or str, "
|
|
2291
|
+
f"got {type(self.response_format)}"
|
|
2292
|
+
)
|
|
1426
2293
|
|
|
1427
2294
|
def as_runnable(self) -> RunnableLike:
|
|
1428
2295
|
from dao_ai.nodes import create_agent_node
|
|
@@ -1441,12 +2308,20 @@ class SupervisorModel(BaseModel):
|
|
|
1441
2308
|
model: LLMModel
|
|
1442
2309
|
tools: list[ToolModel] = Field(default_factory=list)
|
|
1443
2310
|
prompt: Optional[str] = None
|
|
2311
|
+
middleware: list[MiddlewareModel] = Field(
|
|
2312
|
+
default_factory=list,
|
|
2313
|
+
description="List of middleware to apply to the supervisor",
|
|
2314
|
+
)
|
|
1444
2315
|
|
|
1445
2316
|
|
|
1446
2317
|
class SwarmModel(BaseModel):
|
|
1447
2318
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1448
2319
|
model: LLMModel
|
|
1449
2320
|
default_agent: Optional[AgentModel | str] = None
|
|
2321
|
+
middleware: list[MiddlewareModel] = Field(
|
|
2322
|
+
default_factory=list,
|
|
2323
|
+
description="List of middleware to apply to all agents in the swarm",
|
|
2324
|
+
)
|
|
1450
2325
|
handoffs: Optional[dict[str, Optional[list[AgentModel | str]]]] = Field(
|
|
1451
2326
|
default_factory=dict
|
|
1452
2327
|
)
|
|
@@ -1459,7 +2334,7 @@ class OrchestrationModel(BaseModel):
|
|
|
1459
2334
|
memory: Optional[MemoryModel] = None
|
|
1460
2335
|
|
|
1461
2336
|
@model_validator(mode="after")
|
|
1462
|
-
def validate_mutually_exclusive(self):
|
|
2337
|
+
def validate_mutually_exclusive(self) -> Self:
|
|
1463
2338
|
if self.supervisor is not None and self.swarm is not None:
|
|
1464
2339
|
raise ValueError("Cannot specify both supervisor and swarm")
|
|
1465
2340
|
if self.supervisor is None and self.swarm is None:
|
|
@@ -1489,9 +2364,21 @@ class Entitlement(str, Enum):
|
|
|
1489
2364
|
|
|
1490
2365
|
class AppPermissionModel(BaseModel):
|
|
1491
2366
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1492
|
-
principals: list[str] = Field(default_factory=list)
|
|
2367
|
+
principals: list[ServicePrincipalModel | str] = Field(default_factory=list)
|
|
1493
2368
|
entitlements: list[Entitlement]
|
|
1494
2369
|
|
|
2370
|
+
@model_validator(mode="after")
|
|
2371
|
+
def resolve_principals(self) -> Self:
|
|
2372
|
+
"""Resolve ServicePrincipalModel objects to their client_id."""
|
|
2373
|
+
resolved: list[str] = []
|
|
2374
|
+
for principal in self.principals:
|
|
2375
|
+
if isinstance(principal, ServicePrincipalModel):
|
|
2376
|
+
resolved.append(value_of(principal.client_id))
|
|
2377
|
+
else:
|
|
2378
|
+
resolved.append(principal)
|
|
2379
|
+
self.principals = resolved
|
|
2380
|
+
return self
|
|
2381
|
+
|
|
1495
2382
|
|
|
1496
2383
|
class LogLevel(str, Enum):
|
|
1497
2384
|
TRACE = "TRACE"
|
|
@@ -1552,6 +2439,28 @@ class ChatPayload(BaseModel):
|
|
|
1552
2439
|
|
|
1553
2440
|
return self
|
|
1554
2441
|
|
|
2442
|
+
@model_validator(mode="after")
|
|
2443
|
+
def ensure_thread_id(self) -> "ChatPayload":
|
|
2444
|
+
"""Ensure thread_id or conversation_id is present in configurable, generating UUID if needed."""
|
|
2445
|
+
import uuid
|
|
2446
|
+
|
|
2447
|
+
if self.custom_inputs is None:
|
|
2448
|
+
self.custom_inputs = {}
|
|
2449
|
+
|
|
2450
|
+
# Get or create configurable section
|
|
2451
|
+
configurable: dict[str, Any] = self.custom_inputs.get("configurable", {})
|
|
2452
|
+
|
|
2453
|
+
# Check if thread_id or conversation_id exists
|
|
2454
|
+
has_thread_id = configurable.get("thread_id") is not None
|
|
2455
|
+
has_conversation_id = configurable.get("conversation_id") is not None
|
|
2456
|
+
|
|
2457
|
+
# If neither is provided, generate a UUID for conversation_id
|
|
2458
|
+
if not has_thread_id and not has_conversation_id:
|
|
2459
|
+
configurable["conversation_id"] = str(uuid.uuid4())
|
|
2460
|
+
self.custom_inputs["configurable"] = configurable
|
|
2461
|
+
|
|
2462
|
+
return self
|
|
2463
|
+
|
|
1555
2464
|
def as_messages(self) -> Sequence[BaseMessage]:
|
|
1556
2465
|
return messages_from_dict(
|
|
1557
2466
|
[{"type": m.role, "content": m.content} for m in self.messages]
|
|
@@ -1567,25 +2476,44 @@ class ChatPayload(BaseModel):
|
|
|
1567
2476
|
|
|
1568
2477
|
|
|
1569
2478
|
class ChatHistoryModel(BaseModel):
|
|
2479
|
+
"""
|
|
2480
|
+
Configuration for chat history summarization.
|
|
2481
|
+
|
|
2482
|
+
Attributes:
|
|
2483
|
+
model: The LLM to use for generating summaries.
|
|
2484
|
+
max_tokens: Maximum tokens to keep after summarization (the "keep" threshold).
|
|
2485
|
+
After summarization, recent messages totaling up to this many tokens are preserved.
|
|
2486
|
+
max_tokens_before_summary: Token threshold that triggers summarization.
|
|
2487
|
+
When conversation exceeds this, summarization runs. Mutually exclusive with
|
|
2488
|
+
max_messages_before_summary. If neither is set, defaults to max_tokens * 10.
|
|
2489
|
+
max_messages_before_summary: Message count threshold that triggers summarization.
|
|
2490
|
+
When conversation exceeds this many messages, summarization runs.
|
|
2491
|
+
Mutually exclusive with max_tokens_before_summary.
|
|
2492
|
+
"""
|
|
2493
|
+
|
|
1570
2494
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1571
2495
|
model: LLMModel
|
|
1572
|
-
max_tokens: int =
|
|
1573
|
-
|
|
1574
|
-
|
|
1575
|
-
|
|
1576
|
-
|
|
1577
|
-
|
|
1578
|
-
|
|
1579
|
-
|
|
1580
|
-
|
|
1581
|
-
|
|
1582
|
-
|
|
1583
|
-
|
|
2496
|
+
max_tokens: int = Field(
|
|
2497
|
+
default=2048,
|
|
2498
|
+
gt=0,
|
|
2499
|
+
description="Maximum tokens to keep after summarization",
|
|
2500
|
+
)
|
|
2501
|
+
max_tokens_before_summary: Optional[int] = Field(
|
|
2502
|
+
default=None,
|
|
2503
|
+
gt=0,
|
|
2504
|
+
description="Token threshold that triggers summarization",
|
|
2505
|
+
)
|
|
2506
|
+
max_messages_before_summary: Optional[int] = Field(
|
|
2507
|
+
default=None,
|
|
2508
|
+
gt=0,
|
|
2509
|
+
description="Message count threshold that triggers summarization",
|
|
2510
|
+
)
|
|
1584
2511
|
|
|
1585
2512
|
|
|
1586
2513
|
class AppModel(BaseModel):
|
|
1587
2514
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1588
2515
|
name: str
|
|
2516
|
+
service_principal: Optional[ServicePrincipalModel] = None
|
|
1589
2517
|
description: Optional[str] = None
|
|
1590
2518
|
log_level: Optional[LogLevel] = "WARNING"
|
|
1591
2519
|
registered_model: RegisteredModelModel
|
|
@@ -1606,23 +2534,54 @@ class AppModel(BaseModel):
|
|
|
1606
2534
|
shutdown_hooks: Optional[FunctionHook | list[FunctionHook]] = Field(
|
|
1607
2535
|
default_factory=list
|
|
1608
2536
|
)
|
|
1609
|
-
message_hooks: Optional[FunctionHook | list[FunctionHook]] = Field(
|
|
1610
|
-
default_factory=list
|
|
1611
|
-
)
|
|
1612
2537
|
input_example: Optional[ChatPayload] = None
|
|
1613
2538
|
chat_history: Optional[ChatHistoryModel] = None
|
|
1614
2539
|
code_paths: list[str] = Field(default_factory=list)
|
|
1615
2540
|
pip_requirements: list[str] = Field(default_factory=list)
|
|
2541
|
+
python_version: Optional[str] = Field(
|
|
2542
|
+
default="3.12",
|
|
2543
|
+
description="Python version for Model Serving deployment. Defaults to 3.12 "
|
|
2544
|
+
"which is supported by Databricks Model Serving. This allows deploying from "
|
|
2545
|
+
"environments with different Python versions (e.g., Databricks Apps with 3.11).",
|
|
2546
|
+
)
|
|
2547
|
+
|
|
2548
|
+
@model_validator(mode="after")
|
|
2549
|
+
def set_databricks_env_vars(self) -> Self:
|
|
2550
|
+
"""Set Databricks environment variables for Model Serving.
|
|
2551
|
+
|
|
2552
|
+
Sets DATABRICKS_HOST, DATABRICKS_CLIENT_ID, and DATABRICKS_CLIENT_SECRET.
|
|
2553
|
+
Values explicitly provided in environment_vars take precedence.
|
|
2554
|
+
"""
|
|
2555
|
+
from dao_ai.utils import get_default_databricks_host
|
|
2556
|
+
|
|
2557
|
+
# Set DATABRICKS_HOST if not already provided
|
|
2558
|
+
if "DATABRICKS_HOST" not in self.environment_vars:
|
|
2559
|
+
host: str | None = get_default_databricks_host()
|
|
2560
|
+
if host:
|
|
2561
|
+
self.environment_vars["DATABRICKS_HOST"] = host
|
|
2562
|
+
|
|
2563
|
+
# Set service principal credentials if provided
|
|
2564
|
+
if self.service_principal is not None:
|
|
2565
|
+
if "DATABRICKS_CLIENT_ID" not in self.environment_vars:
|
|
2566
|
+
self.environment_vars["DATABRICKS_CLIENT_ID"] = (
|
|
2567
|
+
self.service_principal.client_id
|
|
2568
|
+
)
|
|
2569
|
+
if "DATABRICKS_CLIENT_SECRET" not in self.environment_vars:
|
|
2570
|
+
self.environment_vars["DATABRICKS_CLIENT_SECRET"] = (
|
|
2571
|
+
self.service_principal.client_secret
|
|
2572
|
+
)
|
|
2573
|
+
return self
|
|
1616
2574
|
|
|
1617
2575
|
@model_validator(mode="after")
|
|
1618
|
-
def validate_agents_not_empty(self):
|
|
2576
|
+
def validate_agents_not_empty(self) -> Self:
|
|
1619
2577
|
if not self.agents:
|
|
1620
2578
|
raise ValueError("At least one agent must be specified")
|
|
1621
2579
|
return self
|
|
1622
2580
|
|
|
1623
2581
|
@model_validator(mode="after")
|
|
1624
|
-
def
|
|
2582
|
+
def resolve_environment_vars(self) -> Self:
|
|
1625
2583
|
for key, value in self.environment_vars.items():
|
|
2584
|
+
updated_value: str
|
|
1626
2585
|
if isinstance(value, SecretVariableModel):
|
|
1627
2586
|
updated_value = str(value)
|
|
1628
2587
|
else:
|
|
@@ -1632,7 +2591,7 @@ class AppModel(BaseModel):
|
|
|
1632
2591
|
return self
|
|
1633
2592
|
|
|
1634
2593
|
@model_validator(mode="after")
|
|
1635
|
-
def set_default_orchestration(self):
|
|
2594
|
+
def set_default_orchestration(self) -> Self:
|
|
1636
2595
|
if self.orchestration is None:
|
|
1637
2596
|
if len(self.agents) > 1:
|
|
1638
2597
|
default_agent: AgentModel = self.agents[0]
|
|
@@ -1652,14 +2611,14 @@ class AppModel(BaseModel):
|
|
|
1652
2611
|
return self
|
|
1653
2612
|
|
|
1654
2613
|
@model_validator(mode="after")
|
|
1655
|
-
def set_default_endpoint_name(self):
|
|
2614
|
+
def set_default_endpoint_name(self) -> Self:
|
|
1656
2615
|
if self.endpoint_name is None:
|
|
1657
2616
|
self.endpoint_name = self.name
|
|
1658
2617
|
return self
|
|
1659
2618
|
|
|
1660
2619
|
@model_validator(mode="after")
|
|
1661
|
-
def set_default_agent(self):
|
|
1662
|
-
default_agent_name = self.agents[0].name
|
|
2620
|
+
def set_default_agent(self) -> Self:
|
|
2621
|
+
default_agent_name: str = self.agents[0].name
|
|
1663
2622
|
|
|
1664
2623
|
if self.orchestration.swarm and not self.orchestration.swarm.default_agent:
|
|
1665
2624
|
self.orchestration.swarm.default_agent = default_agent_name
|
|
@@ -1667,7 +2626,7 @@ class AppModel(BaseModel):
|
|
|
1667
2626
|
return self
|
|
1668
2627
|
|
|
1669
2628
|
@model_validator(mode="after")
|
|
1670
|
-
def add_code_paths_to_sys_path(self):
|
|
2629
|
+
def add_code_paths_to_sys_path(self) -> Self:
|
|
1671
2630
|
for code_path in self.code_paths:
|
|
1672
2631
|
parent_path: str = str(Path(code_path).parent)
|
|
1673
2632
|
if parent_path not in sys.path:
|
|
@@ -1700,7 +2659,7 @@ class EvaluationDatasetExpectationsModel(BaseModel):
|
|
|
1700
2659
|
expected_facts: Optional[list[str]] = None
|
|
1701
2660
|
|
|
1702
2661
|
@model_validator(mode="after")
|
|
1703
|
-
def validate_mutually_exclusive(self):
|
|
2662
|
+
def validate_mutually_exclusive(self) -> Self:
|
|
1704
2663
|
if self.expected_response is not None and self.expected_facts is not None:
|
|
1705
2664
|
raise ValueError("Cannot specify both expected_response and expected_facts")
|
|
1706
2665
|
return self
|
|
@@ -1779,36 +2738,70 @@ class EvaluationDatasetModel(BaseModel, HasFullName):
|
|
|
1779
2738
|
|
|
1780
2739
|
|
|
1781
2740
|
class PromptOptimizationModel(BaseModel):
|
|
2741
|
+
"""Configuration for prompt optimization using GEPA.
|
|
2742
|
+
|
|
2743
|
+
GEPA (Generative Evolution of Prompts and Agents) is an evolutionary
|
|
2744
|
+
optimizer that uses reflective mutation to improve prompts based on
|
|
2745
|
+
evaluation feedback.
|
|
2746
|
+
|
|
2747
|
+
Example:
|
|
2748
|
+
prompt_optimization:
|
|
2749
|
+
name: optimize_my_prompt
|
|
2750
|
+
prompt: *my_prompt
|
|
2751
|
+
agent: *my_agent
|
|
2752
|
+
dataset: *my_training_dataset
|
|
2753
|
+
reflection_model: databricks-meta-llama-3-3-70b-instruct
|
|
2754
|
+
num_candidates: 50
|
|
2755
|
+
"""
|
|
2756
|
+
|
|
1782
2757
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1783
2758
|
name: str
|
|
1784
2759
|
prompt: Optional[PromptModel] = None
|
|
1785
2760
|
agent: AgentModel
|
|
1786
|
-
dataset:
|
|
1787
|
-
EvaluationDatasetModel | str
|
|
1788
|
-
) # Reference to dataset name (looked up in OptimizationsModel.training_datasets or MLflow)
|
|
2761
|
+
dataset: EvaluationDatasetModel # Training dataset with examples
|
|
1789
2762
|
reflection_model: Optional[LLMModel | str] = None
|
|
1790
2763
|
num_candidates: Optional[int] = 50
|
|
1791
|
-
scorer_model: Optional[LLMModel | str] = None
|
|
1792
2764
|
|
|
1793
2765
|
def optimize(self, w: WorkspaceClient | None = None) -> PromptModel:
|
|
1794
2766
|
"""
|
|
1795
|
-
Optimize the prompt using
|
|
2767
|
+
Optimize the prompt using GEPA.
|
|
1796
2768
|
|
|
1797
2769
|
Args:
|
|
1798
|
-
w: Optional WorkspaceClient for
|
|
2770
|
+
w: Optional WorkspaceClient (not used, kept for API compatibility)
|
|
1799
2771
|
|
|
1800
2772
|
Returns:
|
|
1801
|
-
PromptModel: The optimized prompt model
|
|
2773
|
+
PromptModel: The optimized prompt model
|
|
1802
2774
|
"""
|
|
1803
|
-
from dao_ai.
|
|
1804
|
-
from dao_ai.providers.databricks import DatabricksProvider
|
|
2775
|
+
from dao_ai.optimization import OptimizationResult, optimize_prompt
|
|
1805
2776
|
|
|
1806
|
-
|
|
1807
|
-
|
|
1808
|
-
|
|
2777
|
+
# Get reflection model name
|
|
2778
|
+
reflection_model_name: str | None = None
|
|
2779
|
+
if self.reflection_model:
|
|
2780
|
+
if isinstance(self.reflection_model, str):
|
|
2781
|
+
reflection_model_name = self.reflection_model
|
|
2782
|
+
else:
|
|
2783
|
+
reflection_model_name = self.reflection_model.uri
|
|
2784
|
+
|
|
2785
|
+
# Ensure prompt is set
|
|
2786
|
+
prompt = self.prompt
|
|
2787
|
+
if prompt is None:
|
|
2788
|
+
raise ValueError(
|
|
2789
|
+
f"Prompt optimization '{self.name}' requires a prompt to be set"
|
|
2790
|
+
)
|
|
2791
|
+
|
|
2792
|
+
result: OptimizationResult = optimize_prompt(
|
|
2793
|
+
prompt=prompt,
|
|
2794
|
+
agent=self.agent,
|
|
2795
|
+
dataset=self.dataset,
|
|
2796
|
+
reflection_model=reflection_model_name,
|
|
2797
|
+
num_candidates=self.num_candidates or 50,
|
|
2798
|
+
register_if_improved=True,
|
|
2799
|
+
)
|
|
2800
|
+
|
|
2801
|
+
return result.optimized_prompt
|
|
1809
2802
|
|
|
1810
2803
|
@model_validator(mode="after")
|
|
1811
|
-
def set_defaults(self):
|
|
2804
|
+
def set_defaults(self) -> Self:
|
|
1812
2805
|
# If no prompt is specified, try to use the agent's prompt
|
|
1813
2806
|
if self.prompt is None:
|
|
1814
2807
|
if isinstance(self.agent.prompt, PromptModel):
|
|
@@ -1819,12 +2812,6 @@ class PromptOptimizationModel(BaseModel):
|
|
|
1819
2812
|
f"or an agent with a prompt configured"
|
|
1820
2813
|
)
|
|
1821
2814
|
|
|
1822
|
-
if self.reflection_model is None:
|
|
1823
|
-
self.reflection_model = self.agent.model
|
|
1824
|
-
|
|
1825
|
-
if self.scorer_model is None:
|
|
1826
|
-
self.scorer_model = self.agent.model
|
|
1827
|
-
|
|
1828
2815
|
return self
|
|
1829
2816
|
|
|
1830
2817
|
|
|
@@ -1897,7 +2884,7 @@ class UnityCatalogFunctionSqlTestModel(BaseModel):
|
|
|
1897
2884
|
|
|
1898
2885
|
class UnityCatalogFunctionSqlModel(BaseModel):
|
|
1899
2886
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1900
|
-
function:
|
|
2887
|
+
function: FunctionModel
|
|
1901
2888
|
ddl: str
|
|
1902
2889
|
parameters: Optional[dict[str, Any]] = Field(default_factory=dict)
|
|
1903
2890
|
test: Optional[UnityCatalogFunctionSqlTestModel] = None
|
|
@@ -1925,16 +2912,126 @@ class ResourcesModel(BaseModel):
|
|
|
1925
2912
|
warehouses: dict[str, WarehouseModel] = Field(default_factory=dict)
|
|
1926
2913
|
databases: dict[str, DatabaseModel] = Field(default_factory=dict)
|
|
1927
2914
|
connections: dict[str, ConnectionModel] = Field(default_factory=dict)
|
|
2915
|
+
apps: dict[str, DatabricksAppModel] = Field(default_factory=dict)
|
|
2916
|
+
|
|
2917
|
+
@model_validator(mode="after")
def update_genie_warehouses(self) -> Self:
    """Merge warehouses referenced by Genie rooms into the resource set.

    Each Genie room may carry a warehouse; any warehouse whose
    warehouse_id is not already registered is added under a key derived
    from the room name and warehouse id.
    """
    if not self.genie_rooms:
        return self

    for room in self.genie_rooms.values():
        room: GenieRoomModel
        wh: Optional[WarehouseModel] = room.warehouse
        if wh is None:
            continue

        # De-duplicate on warehouse_id against everything already registered.
        already_known: bool = any(
            known.warehouse_id == wh.warehouse_id
            for known in self.warehouses.values()
        )
        if already_known:
            continue

        key: str = normalize_name(f"{room.name}_{wh.warehouse_id}")
        self.warehouses[key] = wh
        logger.trace(
            "Added warehouse from Genie room",
            room=room.name,
            warehouse=wh.warehouse_id,
            key=key,
        )

    return self
|
|
2955
|
+
|
|
2956
|
+
@model_validator(mode="after")
def update_genie_tables(self) -> Self:
    """Merge tables referenced by Genie rooms into the resource set.

    Each table found in a Genie room is added under a key derived from
    the room name and the table's full_name, unless a table with the
    same full_name is already registered.
    """
    if not self.genie_rooms:
        return self

    for room in self.genie_rooms.values():
        room: GenieRoomModel
        for tbl in room.tables:
            tbl: TableModel
            # Skip tables whose full_name is already registered.
            if any(
                known.full_name == tbl.full_name
                for known in self.tables.values()
            ):
                continue

            key: str = normalize_name(f"{room.name}_{tbl.full_name}")
            self.tables[key] = tbl
            logger.trace(
                "Added table from Genie room",
                room=room.name,
                table=tbl.name,
                key=key,
            )

    return self
|
|
2989
|
+
|
|
2990
|
+
@model_validator(mode="after")
def update_genie_functions(self) -> Self:
    """Merge functions referenced by Genie rooms into the resource set.

    Each function found in a Genie room is added under a key derived
    from the room name and the function's full_name, unless a function
    with the same full_name is already registered.
    """
    if not self.genie_rooms:
        return self

    for room in self.genie_rooms.values():
        room: GenieRoomModel
        for fn in room.functions:
            fn: FunctionModel
            # Skip functions whose full_name is already registered.
            if any(
                known.full_name == fn.full_name
                for known in self.functions.values()
            ):
                continue

            key: str = normalize_name(f"{room.name}_{fn.full_name}")
            self.functions[key] = fn
            logger.trace(
                "Added function from Genie room",
                room=room.name,
                function=fn.name,
                key=key,
            )

    return self
|
|
1928
3023
|
|
|
1929
3024
|
|
|
1930
3025
|
class AppConfig(BaseModel):
|
|
1931
3026
|
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
1932
3027
|
variables: dict[str, AnyVariable] = Field(default_factory=dict)
|
|
3028
|
+
service_principals: dict[str, ServicePrincipalModel] = Field(default_factory=dict)
|
|
1933
3029
|
schemas: dict[str, SchemaModel] = Field(default_factory=dict)
|
|
1934
3030
|
resources: Optional[ResourcesModel] = None
|
|
1935
3031
|
retrievers: dict[str, RetrieverModel] = Field(default_factory=dict)
|
|
1936
3032
|
tools: dict[str, ToolModel] = Field(default_factory=dict)
|
|
1937
3033
|
guardrails: dict[str, GuardrailModel] = Field(default_factory=dict)
|
|
3034
|
+
middleware: dict[str, MiddlewareModel] = Field(default_factory=dict)
|
|
1938
3035
|
memory: Optional[MemoryModel] = None
|
|
1939
3036
|
prompts: dict[str, PromptModel] = Field(default_factory=dict)
|
|
1940
3037
|
agents: dict[str, AgentModel] = Field(default_factory=dict)
|
|
@@ -1962,10 +3059,10 @@ class AppConfig(BaseModel):
|
|
|
1962
3059
|
|
|
1963
3060
|
def initialize(self) -> None:
|
|
1964
3061
|
from dao_ai.hooks.core import create_hooks
|
|
3062
|
+
from dao_ai.logging import configure_logging
|
|
1965
3063
|
|
|
1966
3064
|
if self.app and self.app.log_level:
|
|
1967
|
-
|
|
1968
|
-
logger.add(sys.stderr, level=self.app.log_level)
|
|
3065
|
+
configure_logging(level=self.app.log_level)
|
|
1969
3066
|
|
|
1970
3067
|
logger.debug("Calling initialization hooks...")
|
|
1971
3068
|
initialization_functions: Sequence[Callable[..., Any]] = create_hooks(
|
|
@@ -2009,21 +3106,45 @@ class AppConfig(BaseModel):
|
|
|
2009
3106
|
def create_agent(
    self,
    w: WorkspaceClient | None = None,
    vsc: "VectorSearchClient | None" = None,
    pat: str | None = None,
    client_id: str | None = None,
    client_secret: str | None = None,
    workspace_host: str | None = None,
) -> None:
    """Create the agent via the Databricks provider.

    All credential arguments are forwarded verbatim to
    DatabricksProvider; any of them may be omitted.
    """
    from dao_ai.providers.base import ServiceProvider
    from dao_ai.providers.databricks import DatabricksProvider

    # Collect the forwarded credentials once, then hand the whole
    # bundle to the provider.
    credentials = dict(
        w=w,
        vsc=vsc,
        pat=pat,
        client_id=client_id,
        client_secret=client_secret,
        workspace_host=workspace_host,
    )
    provider: ServiceProvider = DatabricksProvider(**credentials)
    provider.create_agent(self)
|
|
2018
3127
|
|
|
2019
3128
|
def deploy_agent(
    self,
    w: WorkspaceClient | None = None,
    vsc: "VectorSearchClient | None" = None,
    pat: str | None = None,
    client_id: str | None = None,
    client_secret: str | None = None,
    workspace_host: str | None = None,
) -> None:
    """Deploy the agent via the Databricks provider.

    All credential arguments are forwarded verbatim to
    DatabricksProvider; any of them may be omitted.
    """
    from dao_ai.providers.base import ServiceProvider
    from dao_ai.providers.databricks import DatabricksProvider

    # Collect the forwarded credentials once, then hand the whole
    # bundle to the provider.
    credentials = dict(
        w=w,
        vsc=vsc,
        pat=pat,
        client_id=client_id,
        client_secret=client_secret,
        workspace_host=workspace_host,
    )
    provider: ServiceProvider = DatabricksProvider(**credentials)
    provider.deploy_agent(self)
|
|
2028
3149
|
|
|
2029
3150
|
def find_agents(
|