dao-ai 0.0.25__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +5 -5
- dao_ai/cli.py +245 -40
- dao_ai/config.py +1863 -338
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -228
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +232 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +261 -166
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +278 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +645 -172
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -295
- dao_ai/tools/mcp.py +220 -133
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +360 -40
- dao_ai/utils.py +218 -16
- dao_ai-0.1.2.dist-info/METADATA +455 -0
- dao_ai-0.1.2.dist-info/RECORD +64 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +1 -1
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.25.dist-info/METADATA +0 -1165
- dao_ai-0.0.25.dist-info/RECORD +0 -41
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py
CHANGED
@@ -12,6 +12,7 @@ from typing import (
     Iterator,
     Literal,
     Optional,
+    Self,
     Sequence,
     TypeAlias,
     Union,
@@ -22,22 +23,33 @@ from databricks.sdk.credentials_provider import (
     CredentialsStrategy,
     ModelServingUserCredentials,
 )
+from databricks.sdk.errors.platform import NotFound
 from databricks.sdk.service.catalog import FunctionInfo, TableInfo
+from databricks.sdk.service.dashboards import GenieSpace
 from databricks.sdk.service.database import DatabaseInstance
+from databricks.sdk.service.sql import GetWarehouseResponse
 from databricks.vector_search.client import VectorSearchClient
 from databricks.vector_search.index import VectorSearchIndex
 from databricks_langchain import (
+    ChatDatabricks,
+    DatabricksEmbeddings,
     DatabricksFunctionClient,
 )
+from langchain.agents.structured_output import ProviderStrategy, ToolStrategy
+from langchain_core.embeddings import Embeddings
 from langchain_core.language_models import LanguageModelLike
+from langchain_core.messages import BaseMessage, messages_from_dict
 from langchain_core.runnables.base import RunnableLike
 from langchain_openai import ChatOpenAI
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.graph.state import CompiledStateGraph
 from langgraph.store.base import BaseStore
 from loguru import logger
+from mlflow.genai.datasets import EvaluationDataset, create_dataset, get_dataset
+from mlflow.genai.prompts import PromptVersion, load_prompt
 from mlflow.models import ModelConfig
 from mlflow.models.resources import (
+    DatabricksApp,
     DatabricksFunction,
     DatabricksGenieSpace,
     DatabricksLakebase,
@@ -49,14 +61,20 @@ from mlflow.models.resources import (
     DatabricksVectorSearchIndex,
 )
 from mlflow.pyfunc import ChatModel, ResponsesAgent
+from mlflow.types.responses import (
+    ResponsesAgentRequest,
+)
 from pydantic import (
     BaseModel,
     ConfigDict,
     Field,
+    PrivateAttr,
     field_serializer,
     model_validator,
 )
 
+from dao_ai.utils import normalize_name
+
 
 class HasValue(ABC):
     @abstractmethod
@@ -75,27 +93,6 @@ class HasFullName(ABC):
     def full_name(self) -> str: ...
 
 
-class IsDatabricksResource(ABC):
-    on_behalf_of_user: Optional[bool] = False
-
-    @abstractmethod
-    def as_resources(self) -> Sequence[DatabricksResource]: ...
-
-    @property
-    @abstractmethod
-    def api_scopes(self) -> Sequence[str]: ...
-
-    @property
-    def workspace_client(self) -> WorkspaceClient:
-        credentials_strategy: CredentialsStrategy = None
-        if self.on_behalf_of_user:
-            credentials_strategy = ModelServingUserCredentials()
-        logger.debug(
-            f"Creating WorkspaceClient with credentials strategy: {credentials_strategy}"
-        )
-        return WorkspaceClient(credentials_strategy=credentials_strategy)
-
-
 class EnvironmentVariableModel(BaseModel, HasValue):
     model_config = ConfigDict(
         frozen=True,
@@ -194,6 +191,162 @@ AnyVariable: TypeAlias = (
 )
 
 
+class ServicePrincipalModel(BaseModel):
+    model_config = ConfigDict(
+        frozen=True,
+        use_enum_values=True,
+    )
+    client_id: AnyVariable
+    client_secret: AnyVariable
+
+
+class IsDatabricksResource(ABC, BaseModel):
+    """
+    Base class for Databricks resources with authentication support.
+
+    Authentication Options:
+    ----------------------
+    1. **On-Behalf-Of User (OBO)**: Set on_behalf_of_user=True to use the
+       calling user's identity via ModelServingUserCredentials.
+
+    2. **Service Principal (OAuth M2M)**: Provide service_principal or
+       (client_id + client_secret + workspace_host) for service principal auth.
+
+    3. **Personal Access Token (PAT)**: Provide pat (and optionally workspace_host)
+       to authenticate with a personal access token.
+
+    4. **Ambient Authentication**: If no credentials provided, uses SDK defaults
+       (environment variables, notebook context, etc.)
+
+    Authentication Priority:
+    1. OBO (on_behalf_of_user=True)
+    2. Service Principal (client_id + client_secret + workspace_host)
+    3. PAT (pat + workspace_host)
+    4. Ambient/default authentication
+    """
+
+    model_config = ConfigDict(use_enum_values=True)
+
+    on_behalf_of_user: Optional[bool] = False
+    service_principal: Optional[ServicePrincipalModel] = None
+    client_id: Optional[AnyVariable] = None
+    client_secret: Optional[AnyVariable] = None
+    workspace_host: Optional[AnyVariable] = None
+    pat: Optional[AnyVariable] = None
+
+    # Private attribute to cache the workspace client (lazy instantiation)
+    _workspace_client: Optional[WorkspaceClient] = PrivateAttr(default=None)
+
+    @abstractmethod
+    def as_resources(self) -> Sequence[DatabricksResource]: ...
+
+    @property
+    @abstractmethod
+    def api_scopes(self) -> Sequence[str]: ...
+
+    @model_validator(mode="after")
+    def _expand_service_principal(self) -> Self:
+        """Expand service_principal into client_id and client_secret if provided."""
+        if self.service_principal is not None:
+            if self.client_id is None:
+                self.client_id = self.service_principal.client_id
+            if self.client_secret is None:
+                self.client_secret = self.service_principal.client_secret
+        return self
+
+    @model_validator(mode="after")
+    def _validate_auth_not_mixed(self) -> Self:
+        """Validate that OAuth and PAT authentication are not both provided."""
+        has_oauth: bool = self.client_id is not None and self.client_secret is not None
+        has_pat: bool = self.pat is not None
+
+        if has_oauth and has_pat:
+            raise ValueError(
+                "Cannot use both OAuth and user authentication methods. "
+                "Please provide either OAuth credentials or user credentials."
+            )
+        return self
+
+    @property
+    def workspace_client(self) -> WorkspaceClient:
+        """
+        Get a WorkspaceClient configured with the appropriate authentication.
+
+        The client is lazily instantiated on first access and cached for subsequent calls.
+
+        Authentication priority:
+        1. If on_behalf_of_user is True, uses ModelServingUserCredentials (OBO)
+        2. If service principal credentials are configured (client_id, client_secret,
+           workspace_host), uses OAuth M2M
+        3. If PAT is configured, uses token authentication
+        4. Otherwise, uses default/ambient authentication
+        """
+        # Return cached client if already instantiated
+        if self._workspace_client is not None:
+            return self._workspace_client
+
+        from dao_ai.utils import normalize_host
+
+        # Check for OBO first (highest priority)
+        if self.on_behalf_of_user:
+            credentials_strategy: CredentialsStrategy = ModelServingUserCredentials()
+            logger.debug(
+                f"Creating WorkspaceClient for {self.__class__.__name__} "
+                f"with OBO credentials strategy"
+            )
+            self._workspace_client = WorkspaceClient(
+                credentials_strategy=credentials_strategy
+            )
+            return self._workspace_client
+
+        # Check for service principal credentials
+        client_id_value: str | None = (
+            value_of(self.client_id) if self.client_id else None
+        )
+        client_secret_value: str | None = (
+            value_of(self.client_secret) if self.client_secret else None
+        )
+        workspace_host_value: str | None = (
+            normalize_host(value_of(self.workspace_host))
+            if self.workspace_host
+            else None
+        )
+
+        if client_id_value and client_secret_value and workspace_host_value:
+            logger.debug(
+                f"Creating WorkspaceClient for {self.__class__.__name__} with service principal: "
+                f"client_id={client_id_value}, host={workspace_host_value}"
+            )
+            self._workspace_client = WorkspaceClient(
+                host=workspace_host_value,
+                client_id=client_id_value,
+                client_secret=client_secret_value,
+                auth_type="oauth-m2m",
+            )
+            return self._workspace_client
+
+        # Check for PAT authentication
+        pat_value: str | None = value_of(self.pat) if self.pat else None
+        if pat_value:
+            logger.debug(
+                f"Creating WorkspaceClient for {self.__class__.__name__} with PAT"
+            )
+            self._workspace_client = WorkspaceClient(
+                host=workspace_host_value,
+                token=pat_value,
+                auth_type="pat",
+            )
+            return self._workspace_client
+
+        # Default: use ambient authentication
+        logger.debug(
+            f"Creating WorkspaceClient for {self.__class__.__name__} "
+            "with default/ambient authentication"
+        )
+        self._workspace_client = WorkspaceClient()
+        return self._workspace_client
+
+
 class Privilege(str, Enum):
     ALL_PRIVILEGES = "ALL_PRIVILEGES"
     USE_CATALOG = "USE_CATALOG"
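Every model deriving from the new `IsDatabricksResource` base now resolves its own `WorkspaceClient` in the priority order documented above. A minimal sketch of the three explicit modes, using `WarehouseModel` from later in this diff (IDs, hosts, and secrets are placeholders, and it assumes plain strings are valid `AnyVariable` values that `value_of` returns unchanged):

    from dao_ai.config import WarehouseModel

    # 1. OBO: resolved via ModelServingUserCredentials at serving time (highest priority).
    obo = WarehouseModel(name="wh", warehouse_id="<warehouse-id>", on_behalf_of_user=True)

    # 2. Service principal (OAuth M2M): needs client_id + client_secret + workspace_host.
    sp = WarehouseModel(
        name="wh",
        warehouse_id="<warehouse-id>",
        client_id="<client-id>",
        client_secret="<client-secret>",
        workspace_host="https://<workspace>.cloud.databricks.com",
    )

    # 3. PAT; with none of these fields set, the SDK's ambient auth is used instead.
    pat = WarehouseModel(name="wh", warehouse_id="<warehouse-id>", pat="<token>")

    client = sp.workspace_client  # built lazily on first access, then cached on the model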
@@ -220,9 +373,21 @@ class Privilege(str, Enum):
 
 class PermissionModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    principals: list[str] = Field(default_factory=list)
+    principals: list[ServicePrincipalModel | str] = Field(default_factory=list)
     privileges: list[Privilege]
 
+    @model_validator(mode="after")
+    def resolve_principals(self) -> Self:
+        """Resolve ServicePrincipalModel objects to their client_id."""
+        resolved: list[str] = []
+        for principal in self.principals:
+            if isinstance(principal, ServicePrincipalModel):
+                resolved.append(value_of(principal.client_id))
+            else:
+                resolved.append(principal)
+        self.principals = resolved
+        return self
+
 
 class SchemaModel(BaseModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
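With the widened `principals` type, a permission entry can mix service principals and plain principal names; after validation the list is flattened to strings. A small illustration (hypothetical values, again assuming `value_of` returns literal strings unchanged):

    perm = PermissionModel(
        principals=[
            ServicePrincipalModel(client_id="abc-123", client_secret="dummy"),
            "alice@example.com",
        ],
        privileges=[Privilege.USE_CATALOG],
    )
    assert perm.principals == ["abc-123", "alice@example.com"]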
@@ -242,7 +407,26 @@ class SchemaModel(BaseModel, HasFullName):
         provider.create_schema(self)
 
 
-class TableModel(BaseModel, HasFullName, IsDatabricksResource):
+class DatabricksAppModel(IsDatabricksResource, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    url: str
+
+    @property
+    def full_name(self) -> str:
+        return self.name
+
+    @property
+    def api_scopes(self) -> Sequence[str]:
+        return ["apps.apps"]
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksApp(app_name=self.name, on_behalf_of_user=self.on_behalf_of_user)
+        ]
+
+
+class TableModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: Optional[str] = None
@@ -268,6 +452,22 @@ class TableModel(BaseModel, HasFullName, IsDatabricksResource):
     def api_scopes(self) -> Sequence[str]:
         return []
 
+    def exists(self) -> bool:
+        """Check if the table exists in Unity Catalog.
+
+        Returns:
+            True if the table exists, False otherwise.
+        """
+        try:
+            self.workspace_client.tables.get(full_name=self.full_name)
+            return True
+        except NotFound:
+            logger.debug(f"Table not found: {self.full_name}")
+            return False
+        except Exception as e:
+            logger.warning(f"Error checking table existence for {self.full_name}: {e}")
+            return False
+
     def as_resources(self) -> Sequence[DatabricksResource]:
         resources: list[DatabricksResource] = []
 
@@ -311,12 +511,17 @@ class TableModel(BaseModel, HasFullName, IsDatabricksResource):
         return resources
 
 
-class LLMModel(BaseModel, IsDatabricksResource):
+class LLMModel(IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
+    description: Optional[str] = None
     temperature: Optional[float] = 0.1
     max_tokens: Optional[int] = 8192
     fallbacks: Optional[list[Union[str, "LLMModel"]]] = Field(default_factory=list)
+    use_responses_api: Optional[bool] = Field(
+        default=False,
+        description="Use Responses API for ResponsesAgent endpoints",
+    )
 
     @property
     def api_scopes(self) -> Sequence[str]:
@@ -324,6 +529,10 @@ class LLMModel(BaseModel, IsDatabricksResource):
         "serving.serving-endpoints",
     ]
 
+    @property
+    def uri(self) -> str:
+        return f"databricks:/{self.name}"
+
     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
             DatabricksServingEndpoint(
@@ -332,19 +541,12 @@ class LLMModel(BaseModel, IsDatabricksResource):
     ]
 
     def as_chat_model(self) -> LanguageModelLike:
-        … (5 removed lines; content not preserved in the source rendering)
-        from dao_ai.chat_models import ChatDatabricksFiltered
-
-        chat_client: LanguageModelLike = ChatDatabricksFiltered(
-            model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
+        chat_client: LanguageModelLike = ChatDatabricks(
+            model=self.name,
+            temperature=self.temperature,
+            max_tokens=self.max_tokens,
+            use_responses_api=self.use_responses_api,
         )
-        # chat_client: LanguageModelLike = ChatDatabricks(
-        #     model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
-        # )
 
         fallbacks: Sequence[LanguageModelLike] = []
         for fallback in self.fallbacks:
@@ -376,6 +578,9 @@ class LLMModel(BaseModel, IsDatabricksResource):
 
         return chat_client
 
+    def as_embeddings_model(self) -> Embeddings:
+        return DatabricksEmbeddings(endpoint=self.name)
+
 
 class VectorSearchEndpointType(str, Enum):
     STANDARD = "STANDARD"
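`LLMModel.as_chat_model` now builds a plain `ChatDatabricks` (the `ChatDatabricksFiltered` wrapper and the whole `dao_ai/chat_models.py` module are removed in this release), and the same config class can hand back an embeddings client. A rough sketch (the chat endpoint name is a placeholder; `databricks-gte-large-en` is the embedding endpoint this diff uses as a default elsewhere):

    from dao_ai.config import LLMModel

    llm = LLMModel(
        name="<serving-endpoint-name>",
        temperature=0.0,
        fallbacks=["<fallback-endpoint-name>"],  # tried in order if the primary fails
        use_responses_api=True,                  # route through the Responses API
    )
    chat = llm.as_chat_model()  # ChatDatabricks, wrapped with fallbacks when configured

    emb = LLMModel(name="databricks-gte-large-en")
    vectors = emb.as_embeddings_model()  # DatabricksEmbeddings(endpoint=emb.name)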
@@ -387,8 +592,15 @@ class VectorSearchEndpoint(BaseModel):
     name: str
     type: VectorSearchEndpointType = VectorSearchEndpointType.STANDARD
 
+    @field_serializer("type")
+    def serialize_type(self, value: VectorSearchEndpointType) -> str:
+        """Ensure enum is serialized to string value."""
+        if isinstance(value, VectorSearchEndpointType):
+            return value.value
+        return str(value)
+
 
-class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
+class IndexModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: str
@@ -413,12 +625,297 @@ class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
     ]
 
 
-class GenieRoomModel(BaseModel, IsDatabricksResource):
+class FunctionModel(IsDatabricksResource, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
+    name: Optional[str] = None
+
+    @model_validator(mode="after")
+    def validate_name_or_schema_required(self) -> Self:
+        if not self.name and not self.schema_model:
+            raise ValueError(
+                "Either 'name' or 'schema_model' must be provided for FunctionModel"
+            )
+        return self
+
+    @property
+    def full_name(self) -> str:
+        if self.schema_model:
+            name: str = ""
+            if self.name:
+                name = f".{self.name}"
+            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
+        return self.name
+
+    def exists(self) -> bool:
+        """Check if the function exists in Unity Catalog.
+
+        Returns:
+            True if the function exists, False otherwise.
+        """
+        try:
+            self.workspace_client.functions.get(name=self.full_name)
+            return True
+        except NotFound:
+            logger.debug(f"Function not found: {self.full_name}")
+            return False
+        except Exception as e:
+            logger.warning(
+                f"Error checking function existence for {self.full_name}: {e}"
+            )
+            return False
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        resources: list[DatabricksResource] = []
+        if self.name:
+            resources.append(
+                DatabricksFunction(
+                    function_name=self.full_name,
+                    on_behalf_of_user=self.on_behalf_of_user,
+                )
+            )
+        else:
+            w: WorkspaceClient = self.workspace_client
+            schema_full_name: str = self.schema_model.full_name
+            functions: Iterator[FunctionInfo] = w.functions.list(
+                catalog_name=self.schema_model.catalog_name,
+                schema_name=self.schema_model.schema_name,
+            )
+            resources.extend(
+                [
+                    DatabricksFunction(
+                        function_name=f"{schema_full_name}.{function.name}",
+                        on_behalf_of_user=self.on_behalf_of_user,
+                    )
+                    for function in functions
+                ]
+            )
+
+        return resources
+
+    @property
+    def api_scopes(self) -> Sequence[str]:
+        return ["sql.statement-execution"]
+
+
+class WarehouseModel(IsDatabricksResource):
+    model_config = ConfigDict()
+    name: str
+    description: Optional[str] = None
+    warehouse_id: AnyVariable
+
+    @property
+    def api_scopes(self) -> Sequence[str]:
+        return [
+            "sql.warehouses",
+            "sql.statement-execution",
+        ]
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksSQLWarehouse(
+                warehouse_id=value_of(self.warehouse_id),
+                on_behalf_of_user=self.on_behalf_of_user,
+            )
+        ]
+
+    @model_validator(mode="after")
+    def update_warehouse_id(self) -> Self:
+        self.warehouse_id = value_of(self.warehouse_id)
+        return self
+
+
+class GenieRoomModel(IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     description: Optional[str] = None
     space_id: AnyVariable
 
+    _space_details: Optional[GenieSpace] = PrivateAttr(default=None)
+
+    def _get_space_details(self) -> GenieSpace:
+        if self._space_details is None:
+            self._space_details = self.workspace_client.genie.get_space(
+                space_id=self.space_id, include_serialized_space=True
+            )
+        return self._space_details
+
+    def _parse_serialized_space(self) -> dict[str, Any]:
+        """Parse the serialized_space JSON string and return the parsed data."""
+        import json
+
+        space_details = self._get_space_details()
+        if not space_details.serialized_space:
+            return {}
+
+        try:
+            return json.loads(space_details.serialized_space)
+        except json.JSONDecodeError as e:
+            logger.warning(f"Failed to parse serialized_space: {e}")
+            return {}
+
+    @property
+    def warehouse(self) -> Optional[WarehouseModel]:
+        """Extract warehouse information from the Genie space.
+
+        Returns:
+            WarehouseModel instance if warehouse_id is available, None otherwise.
+        """
+        space_details: GenieSpace = self._get_space_details()
+
+        if not space_details.warehouse_id:
+            return None
+
+        try:
+            response: GetWarehouseResponse = self.workspace_client.warehouses.get(
+                space_details.warehouse_id
+            )
+            warehouse_name: str = response.name or space_details.warehouse_id
+
+            warehouse_model = WarehouseModel(
+                name=warehouse_name,
+                warehouse_id=space_details.warehouse_id,
+                on_behalf_of_user=self.on_behalf_of_user,
+                service_principal=self.service_principal,
+                client_id=self.client_id,
+                client_secret=self.client_secret,
+                workspace_host=self.workspace_host,
+                pat=self.pat,
+            )
+
+            # Share the cached workspace client if available
+            if self._workspace_client is not None:
+                warehouse_model._workspace_client = self._workspace_client
+
+            return warehouse_model
+        except Exception as e:
+            logger.warning(
+                f"Failed to fetch warehouse details for {space_details.warehouse_id}: {e}"
+            )
+            return None
+
+    @property
+    def tables(self) -> list[TableModel]:
+        """Extract tables from the serialized Genie space.
+
+        Databricks Genie stores tables in: data_sources.tables[].identifier
+        Only includes tables that actually exist in Unity Catalog.
+        """
+        parsed_space = self._parse_serialized_space()
+        tables_list: list[TableModel] = []
+
+        # Primary structure: data_sources.tables with 'identifier' field
+        if "data_sources" in parsed_space:
+            data_sources = parsed_space["data_sources"]
+            if isinstance(data_sources, dict) and "tables" in data_sources:
+                tables_data = data_sources["tables"]
+                if isinstance(tables_data, list):
+                    for table_item in tables_data:
+                        table_name: str | None = None
+                        if isinstance(table_item, dict):
+                            # Standard Databricks structure uses 'identifier'
+                            table_name = table_item.get("identifier") or table_item.get(
+                                "name"
+                            )
+                        elif isinstance(table_item, str):
+                            table_name = table_item
+
+                        if table_name:
+                            table_model = TableModel(
+                                name=table_name,
+                                on_behalf_of_user=self.on_behalf_of_user,
+                                service_principal=self.service_principal,
+                                client_id=self.client_id,
+                                client_secret=self.client_secret,
+                                workspace_host=self.workspace_host,
+                                pat=self.pat,
+                            )
+                            # Share the cached workspace client if available
+                            if self._workspace_client is not None:
+                                table_model._workspace_client = self._workspace_client
+
+                            # Verify the table exists before adding
+                            if not table_model.exists():
+                                continue
+
+                            tables_list.append(table_model)
+
+        return tables_list
+
+    @property
+    def functions(self) -> list[FunctionModel]:
+        """Extract functions from the serialized Genie space.
+
+        Databricks Genie stores functions in multiple locations:
+        - instructions.sql_functions[].identifier (SQL functions)
+        - data_sources.functions[].identifier (other functions)
+        Only includes functions that actually exist in Unity Catalog.
+        """
+        parsed_space = self._parse_serialized_space()
+        functions_list: list[FunctionModel] = []
+        seen_functions: set[str] = set()
+
+        def add_function_if_exists(function_name: str) -> None:
+            """Helper to add a function if it exists and hasn't been added."""
+            if function_name in seen_functions:
+                return
+
+            seen_functions.add(function_name)
+            function_model = FunctionModel(
+                name=function_name,
+                on_behalf_of_user=self.on_behalf_of_user,
+                service_principal=self.service_principal,
+                client_id=self.client_id,
+                client_secret=self.client_secret,
+                workspace_host=self.workspace_host,
+                pat=self.pat,
+            )
+            # Share the cached workspace client if available
+            if self._workspace_client is not None:
+                function_model._workspace_client = self._workspace_client
+
+            # Verify the function exists before adding
+            if not function_model.exists():
+                return
+
+            functions_list.append(function_model)
+
+        # Primary structure: instructions.sql_functions with 'identifier' field
+        if "instructions" in parsed_space:
+            instructions = parsed_space["instructions"]
+            if isinstance(instructions, dict) and "sql_functions" in instructions:
+                sql_functions_data = instructions["sql_functions"]
+                if isinstance(sql_functions_data, list):
+                    for function_item in sql_functions_data:
+                        if isinstance(function_item, dict):
+                            # SQL functions use 'identifier' field
+                            function_name = function_item.get(
+                                "identifier"
+                            ) or function_item.get("name")
+                            if function_name:
+                                add_function_if_exists(function_name)
+
+        # Secondary structure: data_sources.functions with 'identifier' field
+        if "data_sources" in parsed_space:
+            data_sources = parsed_space["data_sources"]
+            if isinstance(data_sources, dict) and "functions" in data_sources:
+                functions_data = data_sources["functions"]
+                if isinstance(functions_data, list):
+                    for function_item in functions_data:
+                        function_name: str | None = None
+                        if isinstance(function_item, dict):
+                            # Standard Databricks structure uses 'identifier'
+                            function_name = function_item.get(
+                                "identifier"
+                            ) or function_item.get("name")
+                        elif isinstance(function_item, str):
+                            function_name = function_item
+
+                        if function_name:
+                            add_function_if_exists(function_name)
+
+        return functions_list
+
     @property
     def api_scopes(self) -> Sequence[str]:
         return [
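The `tables` and `functions` properties above walk a parsed `serialized_space` document of roughly this shape (inferred from the key lookups in this diff; the real Genie payload carries more fields):

    parsed_space = {
        "data_sources": {
            "tables": [{"identifier": "main.sales.orders"}],
            "functions": [{"identifier": "main.sales.revenue_by_region"}],
        },
        "instructions": {
            "sql_functions": [{"identifier": "main.sales.top_customers"}],
        },
    }

Each `identifier` (with `"name"` as a fallback key, or a bare string) becomes a `TableModel` or `FunctionModel` that inherits the Genie room's credentials, and an entry is kept only if `exists()` confirms it in Unity Catalog.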
@@ -434,12 +931,24 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
     ]
 
     @model_validator(mode="after")
-    def update_space_id(self):
+    def update_space_id(self) -> Self:
         self.space_id = value_of(self.space_id)
         return self
 
+    @model_validator(mode="after")
+    def update_description_from_space(self) -> Self:
+        """Populate description from GenieSpace if not provided."""
+        if not self.description:
+            try:
+                space_details = self._get_space_details()
+                if space_details.description:
+                    self.description = space_details.description
+            except Exception as e:
+                logger.debug(f"Could not fetch description from Genie space: {e}")
+        return self
 
-class VolumeModel(BaseModel, HasFullName, IsDatabricksResource):
+
+class VolumeModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: str
@@ -499,7 +1008,7 @@ class VolumePathModel(BaseModel, HasFullName):
         provider.create_path(self)
 
 
-class VectorStoreModel(BaseModel, IsDatabricksResource):
+class VectorStoreModel(IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     embedding_model: Optional[LLMModel] = None
     index: Optional[IndexModel] = None
@@ -513,13 +1022,13 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
     embedding_source_column: str
 
     @model_validator(mode="after")
-    def set_default_embedding_model(self):
+    def set_default_embedding_model(self) -> Self:
         if not self.embedding_model:
             self.embedding_model = LLMModel(name="databricks-gte-large-en")
         return self
 
     @model_validator(mode="after")
-    def set_default_primary_key(self):
+    def set_default_primary_key(self) -> Self:
         if self.primary_key is None:
             from dao_ai.providers.databricks import DatabricksProvider
 
@@ -540,14 +1049,14 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
         return self
 
     @model_validator(mode="after")
-    def set_default_index(self):
+    def set_default_index(self) -> Self:
         if self.index is None:
             name: str = f"{self.source_table.name}_index"
             self.index = IndexModel(schema=self.source_table.schema_model, name=name)
         return self
 
     @model_validator(mode="after")
-    def set_default_endpoint(self):
+    def set_default_endpoint(self) -> Self:
         if self.endpoint is None:
             from dao_ai.providers.databricks import (
                 DatabricksProvider,
@@ -598,64 +1107,9 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
         provider.create_vector_store(self)
 
 
-class FunctionModel(BaseModel, HasFullName, IsDatabricksResource):
+class ConnectionModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict()
-
-    name: Optional[str] = None
-
-    @model_validator(mode="after")
-    def validate_name_or_schema_required(self) -> "FunctionModel":
-        if not self.name and not self.schema_model:
-            raise ValueError(
-                "Either 'name' or 'schema_model' must be provided for FunctionModel"
-            )
-        return self
-
-    @property
-    def full_name(self) -> str:
-        if self.schema_model:
-            name: str = ""
-            if self.name:
-                name = f".{self.name}"
-            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
-        return self.name
-
-    def as_resources(self) -> Sequence[DatabricksResource]:
-        resources: list[DatabricksResource] = []
-        if self.name:
-            resources.append(
-                DatabricksFunction(
-                    function_name=self.full_name,
-                    on_behalf_of_user=self.on_behalf_of_user,
-                )
-            )
-        else:
-            w: WorkspaceClient = self.workspace_client
-            schema_full_name: str = self.schema_model.full_name
-            functions: Iterator[FunctionInfo] = w.functions.list(
-                catalog_name=self.schema_model.catalog_name,
-                schema_name=self.schema_model.schema_name,
-            )
-            resources.extend(
-                [
-                    DatabricksFunction(
-                        function_name=f"{schema_full_name}.{function.name}",
-                        on_behalf_of_user=self.on_behalf_of_user,
-                    )
-                    for function in functions
-                ]
-            )
-
-        return resources
-
-    @property
-    def api_scopes(self) -> Sequence[str]:
-        return ["sql.statement-execution"]
-
-
-class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
-    model_config = ConfigDict()
-    name: str
+    name: str
 
     @property
     def full_name(self) -> str:
@@ -680,34 +1134,58 @@ class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
     ]
 
 
-class …
-    … (25 further removed lines; content not preserved in the source rendering)
+class DatabaseModel(IsDatabricksResource):
+    """
+    Configuration for database connections supporting both Databricks Lakebase and standard PostgreSQL.
+
+    Authentication is inherited from IsDatabricksResource. Additionally supports:
+    - user/password: For user-based database authentication
+
+    Connection Types (determined by fields provided):
+    - Databricks Lakebase: Provide `instance_name` (authentication optional, supports ambient auth)
+    - Standard PostgreSQL: Provide `host` (authentication required via user/password)
+
+    Note: `instance_name` and `host` are mutually exclusive. Provide one or the other.
+
+    Example Databricks Lakebase with Service Principal:
+    ```yaml
+    databases:
+      my_lakebase:
+        name: my-database
+        instance_name: my-lakebase-instance
+        service_principal:
+          client_id:
+            env: SERVICE_PRINCIPAL_CLIENT_ID
+          client_secret:
+            scope: my-scope
+            secret: sp-client-secret
+        workspace_host:
+          env: DATABRICKS_HOST
+    ```
+
+    Example Databricks Lakebase with Ambient Authentication:
+    ```yaml
+    databases:
+      my_lakebase:
+        name: my-database
+        instance_name: my-lakebase-instance
+        on_behalf_of_user: true
+    ```
+
+    Example Standard PostgreSQL:
+    ```yaml
+    databases:
+      my_postgres:
+        name: my-database
+        host: my-postgres-host.example.com
+        port: 5432
+        database: my_db
+        user: my_user
+        password:
+          env: PGPASSWORD
+    ```
+    """
 
-class DatabaseModel(BaseModel, IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     instance_name: Optional[str] = None
@@ -720,80 +1198,137 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
     timeout_seconds: Optional[int] = 10
     capacity: Optional[Literal["CU_1", "CU_2"]] = "CU_2"
    node_count: Optional[int] = None
+    # Database-specific auth (user identity for DB connection)
     user: Optional[AnyVariable] = None
     password: Optional[AnyVariable] = None
-    client_id: Optional[AnyVariable] = None
-    client_secret: Optional[AnyVariable] = None
-    workspace_host: Optional[AnyVariable] = None
 
     @property
     def api_scopes(self) -> Sequence[str]:
-        return []
+        return ["database.database-instances"]
+
+    @property
+    def is_lakebase(self) -> bool:
+        """Returns True if this is a Databricks Lakebase connection (instance_name provided)."""
+        return self.instance_name is not None
 
     def as_resources(self) -> Sequence[DatabricksResource]:
-        … (6 removed lines; content not preserved in the source rendering)
+        if self.is_lakebase:
+            return [
+                DatabricksLakebase(
+                    database_instance_name=self.instance_name,
+                    on_behalf_of_user=self.on_behalf_of_user,
+                )
+            ]
+        return []
 
     @model_validator(mode="after")
-    def … (signature not preserved in the source rendering)
-        …
-            self.instance_name = self.name
+    def validate_connection_type(self) -> Self:
+        """Validate connection configuration based on type.
 
+        - If instance_name is provided: Databricks Lakebase connection
+          (host is optional - will be fetched from API if not provided)
+        - If only host is provided: Standard PostgreSQL connection
+          (must not have instance_name)
+        """
+        if not self.instance_name and not self.host:
+            raise ValueError(
+                "Either instance_name (Databricks Lakebase) or host (PostgreSQL) must be provided."
+            )
         return self
 
     @model_validator(mode="after")
-    def update_user(self):
-        if …
+    def update_user(self) -> Self:
+        # Skip if using OBO (passive auth), explicit credentials, or explicit user
+        if self.on_behalf_of_user or self.client_id or self.user or self.pat:
             return self
 
-        … (5 removed lines; content not preserved in the source rendering)
+        # For standard PostgreSQL, we need explicit user credentials
+        # For Lakebase with no auth, ambient auth is allowed
+        if not self.is_lakebase:
+            # Standard PostgreSQL - try to determine current user for local development
+            try:
+                self.user = self.workspace_client.current_user.me().user_name
+            except Exception as e:
+                logger.warning(
+                    f"Could not determine current user for PostgreSQL database: {e}. "
+                    f"Please provide explicit user credentials."
+                )
+        else:
+            # For Lakebase, try to determine current user but don't fail if we can't
+            try:
+                self.user = self.workspace_client.current_user.me().user_name
+            except Exception:
+                # If we can't determine user and no explicit auth, that's okay
+                # for Lakebase with ambient auth - credentials will be injected at runtime
+                pass
 
         return self
 
     @model_validator(mode="after")
-    def update_host(self):
+    def update_host(self) -> Self:
         if self.host is not None:
             return self
 
-        … (6 removed lines; content not preserved in the source rendering)
+        # If instance_name is provided (Lakebase), try to fetch host from existing instance
+        # This may fail for OBO/ambient auth during model logging (before deployment)
+        if self.is_lakebase:
+            try:
+                existing_instance: DatabaseInstance = (
+                    self.workspace_client.database.get_database_instance(
+                        name=self.instance_name
+                    )
+                )
+                self.host = existing_instance.read_write_dns
+            except Exception as e:
+                # For Lakebase with OBO/ambient auth, we can't fetch at config time
+                # The host will need to be provided explicitly or fetched at runtime
+                if self.on_behalf_of_user:
+                    logger.debug(
+                        f"Could not fetch host for database {self.instance_name} "
+                        f"(Lakebase with OBO mode - will be resolved at runtime): {e}"
+                    )
+                else:
+                    raise ValueError(
+                        f"Could not fetch host for database {self.instance_name}. "
+                        f"Please provide the 'host' explicitly or ensure the instance exists: {e}"
+                    )
         return self
 
     @model_validator(mode="after")
-    def validate_auth_methods(self):
+    def validate_auth_methods(self) -> Self:
         oauth_fields: Sequence[Any] = [
             self.workspace_host,
             self.client_id,
             self.client_secret,
         ]
         has_oauth: bool = all(field is not None for field in oauth_fields)
+        has_user_auth: bool = self.user is not None
+        has_obo: bool = self.on_behalf_of_user is True
+        has_pat: bool = self.pat is not None
 
-        …
-        …
+        # Count how many auth methods are configured
+        auth_methods_count: int = sum([has_oauth, has_user_auth, has_obo, has_pat])
 
-        if …
+        if auth_methods_count > 1:
             raise ValueError(
-                "Cannot …
-                "Please provide …
+                "Cannot mix authentication methods. "
+                "Please provide exactly one of: "
+                "on_behalf_of_user=true (for passive auth in model serving), "
+                "OAuth credentials (service_principal or client_id + client_secret + workspace_host), "
+                "PAT (personal access token), "
+                "or user credentials (user)."
             )
 
-        …
+        # For standard PostgreSQL (host-based), at least one auth method must be configured
+        # For Lakebase (instance_name-based), auth is optional (supports ambient authentication)
+        if not self.is_lakebase and auth_methods_count == 0:
             raise ValueError(
-                "…
-                "…
-                "…
+                "PostgreSQL databases require explicit authentication. "
+                "Please provide one of: "
+                "OAuth credentials (workspace_host, client_id, client_secret), "
+                "service_principal with workspace_host, "
+                "PAT (personal access token), "
+                "or user credentials (user)."
             )
 
         return self
@@ -804,38 +1339,76 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
         Get database connection parameters as a dictionary.
 
         Returns a dict with connection parameters suitable for psycopg ConnectionPool.
-        … (2 removed lines; content not preserved in the source rendering)
+
+        For Lakebase: Uses Databricks-generated credentials (token-based auth).
+        For standard PostgreSQL: Uses provided user/password credentials.
         """
-        …
-        from dao_ai.providers.databricks import DatabricksProvider
+        import uuid as _uuid
 
+        from databricks.sdk.service.database import DatabaseCredential
+
+        host: str
+        port: int
+        database: str
         username: str | None = None
+        password_value: str | None = None
+
+        # Resolve host - may need to fetch at runtime for OBO mode
+        host_value: Any = self.host
+        if host_value is None and self.is_lakebase and self.on_behalf_of_user:
+            # Fetch host at runtime for OBO mode
+            existing_instance: DatabaseInstance = (
+                self.workspace_client.database.get_database_instance(
+                    name=self.instance_name
+                )
+            )
+            host_value = existing_instance.read_write_dns
 
-        if …
-        … (3 removed lines; content not preserved in the source rendering)
+        if host_value is None:
+            instance_or_name = self.instance_name if self.is_lakebase else self.name
+            raise ValueError(
+                f"Database host not configured for {instance_or_name}. "
+                "Please provide 'host' explicitly."
+            )
 
-        host …
-        port …
-        database …
+        host = value_of(host_value)
+        port = value_of(self.port)
+        database = value_of(self.database)
 
-        …
-        …
-        client_secret …
-        … (3 removed lines; content not preserved in the source rendering)
+        if self.is_lakebase:
+            # Lakebase: Use Databricks-generated credentials
+            if self.client_id and self.client_secret and self.workspace_host:
+                username = value_of(self.client_id)
+            elif self.user:
+                username = value_of(self.user)
+            # For OBO mode, no username is needed - the token identity is used
 
-        …
+            # Generate Databricks database credential (token)
+            w: WorkspaceClient = self.workspace_client
+            cred: DatabaseCredential = w.database.generate_database_credential(
+                request_id=str(_uuid.uuid4()),
+                instance_names=[self.instance_name],
+            )
+            password_value = cred.token
+        else:
+            # Standard PostgreSQL: Use provided credentials
+            if self.user:
+                username = value_of(self.user)
+            if self.password:
+                password_value = value_of(self.password)
+
+            if not username or not password_value:
+                raise ValueError(
+                    f"Standard PostgreSQL databases require both 'user' and 'password'. "
+                    f"Database: {self.name}"
+                )
 
         # Build connection parameters dictionary
         params: dict[str, Any] = {
             "dbname": database,
             "host": host,
             "port": port,
-            "password": …
+            "password": password_value,
             "sslmode": "require",
         }
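As its docstring says, the connection-parameter accessor is aimed at psycopg's pool API. A rough usage sketch (the `def` line falls outside the excerpted hunks, so the property name `connection_params` is assumed; `db` is a configured `DatabaseModel`):

    from psycopg_pool import ConnectionPool

    params = db.connection_params  # {"dbname": ..., "host": ..., "port": ..., "password": ..., "sslmode": "require"}
    conninfo = " ".join(f"{k}={v}" for k, v in params.items() if v is not None)
    with ConnectionPool(conninfo) as pool:
        with pool.connection() as conn:
            conn.execute("SELECT 1")

For Lakebase the password is a short-lived token from `generate_database_credential`, so long-lived pools need to refresh credentials before the token expires.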
@@ -866,11 +1439,86 @@
     def create(self, w: WorkspaceClient | None = None) -> None:
         from dao_ai.providers.databricks import DatabricksProvider
 
-        … (content not preserved in the source rendering)
+        # Use provided workspace client or fall back to resource's own workspace_client
+        if w is None:
+            w = self.workspace_client
+        provider: DatabricksProvider = DatabricksProvider(w=w)
         provider.create_lakebase(self)
         provider.create_lakebase_instance_role(self)
 
 
+class GenieLRUCacheParametersModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    capacity: int = 1000
+    time_to_live_seconds: int | None = (
+        60 * 60 * 24
+    )  # 1 day default, None or negative = never expires
+    warehouse: WarehouseModel
+
+
+class GenieSemanticCacheParametersModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    time_to_live_seconds: int | None = (
+        60 * 60 * 24
+    )  # 1 day default, None or negative = never expires
+    similarity_threshold: float = 0.85  # Minimum similarity for question matching (L2 distance converted to 0-1 scale)
+    context_similarity_threshold: float = 0.80  # Minimum similarity for context matching (L2 distance converted to 0-1 scale)
+    question_weight: Optional[float] = (
+        0.6  # Weight for question similarity in combined score (0-1). If not provided, computed as 1 - context_weight
+    )
+    context_weight: Optional[float] = (
+        None  # Weight for context similarity in combined score (0-1). If not provided, computed as 1 - question_weight
+    )
+    embedding_model: str | LLMModel = "databricks-gte-large-en"
+    embedding_dims: int | None = None  # Auto-detected if None
+    database: DatabaseModel
+    warehouse: WarehouseModel
+    table_name: str = "genie_semantic_cache"
+    context_window_size: int = 3  # Number of previous turns to include for context
+    max_context_tokens: int = (
+        2000  # Maximum context length to prevent extremely long embeddings
+    )
+
+    @model_validator(mode="after")
+    def compute_and_validate_weights(self) -> Self:
+        """
+        Compute missing weight and validate that question_weight + context_weight = 1.0.
+
+        Either question_weight or context_weight (or both) can be provided.
+        The missing one will be computed as 1.0 - provided_weight.
+        If both are provided, they must sum to 1.0.
+        """
+        if self.question_weight is None and self.context_weight is None:
+            # Both missing - use defaults
+            self.question_weight = 0.6
+            self.context_weight = 0.4
+        elif self.question_weight is None:
+            # Compute question_weight from context_weight
+            if not (0.0 <= self.context_weight <= 1.0):
+                raise ValueError(
+                    f"context_weight must be between 0.0 and 1.0, got {self.context_weight}"
+                )
+            self.question_weight = 1.0 - self.context_weight
+        elif self.context_weight is None:
+            # Compute context_weight from question_weight
+            if not (0.0 <= self.question_weight <= 1.0):
+                raise ValueError(
+                    f"question_weight must be between 0.0 and 1.0, got {self.question_weight}"
+                )
+            self.context_weight = 1.0 - self.question_weight
+        else:
+            # Both provided - validate they sum to 1.0
+            total_weight = self.question_weight + self.context_weight
+            if not abs(total_weight - 1.0) < 0.0001:  # Allow small floating point error
+                raise ValueError(
+                    f"question_weight ({self.question_weight}) + context_weight ({self.context_weight}) "
+                    f"must equal 1.0 (got {total_weight}). These weights determine the relative importance "
+                    f"of question vs context similarity in the combined score."
+                )
+
+        return self
+
+
 class SearchParametersModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     num_results: Optional[int] = 10
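The weight validator guarantees `question_weight + context_weight == 1.0`, so only one of the two ever needs to be set: `context_weight: 0.3` yields `question_weight = 0.7`, while supplying, say, 0.7 and 0.4 together raises a `ValueError` because they sum to 1.1. The combined score implied by the field comments (the actual scoring lives in `dao_ai/genie/cache/semantic.py`, not shown here) is a weighted sum:

    # combined = question_weight * question_similarity + context_weight * context_similarity
    combined = 0.6 * 0.92 + 0.4 * 0.75  # = 0.852 with the default 0.6/0.4 split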
@@ -878,6 +1526,55 @@ class SearchParametersModel(BaseModel):
     query_type: Optional[str] = "ANN"
 
 
+class RerankParametersModel(BaseModel):
+    """
+    Configuration for reranking retrieved documents using FlashRank.
+
+    FlashRank provides fast, local reranking without API calls using lightweight
+    cross-encoder models. Reranking improves retrieval quality by reordering results
+    based on semantic relevance to the query.
+
+    Typical workflow:
+    1. Retrieve more documents than needed (e.g., 50 via num_results)
+    2. Rerank all retrieved documents
+    3. Return top_n best matches (e.g., 5)
+
+    Example:
+        ```yaml
+        retriever:
+          search_parameters:
+            num_results: 50  # Retrieve more candidates
+          rerank:
+            model: ms-marco-MiniLM-L-12-v2
+            top_n: 5  # Return top 5 after reranking
+        ```
+
+    Available models (from fastest to most accurate):
+    - "ms-marco-TinyBERT-L-2-v2" (fastest, smallest)
+    - "ms-marco-MiniLM-L-6-v2"
+    - "ms-marco-MiniLM-L-12-v2" (default, good balance)
+    - "rank-T5-flan" (most accurate, slower)
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+
+    model: str = Field(
+        default="ms-marco-MiniLM-L-12-v2",
+        description="FlashRank model name. Default provides good balance of speed and accuracy.",
+    )
+    top_n: Optional[int] = Field(
+        default=None,
+        description="Number of documents to return after reranking. If None, uses search_parameters.num_results.",
+    )
+    cache_dir: Optional[str] = Field(
+        default="~/.dao_ai/cache/flashrank",
+        description="Directory to cache downloaded model weights. Supports tilde expansion (e.g., ~/.dao_ai).",
+    )
+    columns: Optional[list[str]] = Field(
+        default_factory=list, description="Columns to rerank using DatabricksReranker"
+    )
+
+
 class RetrieverModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     vector_store: VectorStoreModel
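The retrieve-then-rerank workflow that docstring describes can be sketched with the flashrank package directly. This is a minimal illustration, assuming flashrank's `Ranker`/`RerankRequest` API and a local cache directory; it is not how dao_ai wires the reranker internally:

```python
from flashrank import Ranker, RerankRequest

# Step 1 happens in the retriever: fetch a wide candidate set (num_results: 50).
candidates = [
    {"id": i, "text": doc}
    for i, doc in enumerate(
        [
            "Returns are accepted within 30 days.",
            "Our stores are open 9am-9pm.",
            "Refunds are issued to the original payment method.",
        ]
    )
]

# Step 2: rerank all candidates with a local cross-encoder (no API calls).
ranker = Ranker(model_name="ms-marco-MiniLM-L-12-v2", cache_dir="/tmp/flashrank")
request = RerankRequest(query="How do refunds work?", passages=candidates)
reranked = ranker.rerank(request)

# Step 3: keep only the top_n best matches.
top_n = 2
for passage in reranked[:top_n]:
    print(passage["score"], passage["text"])
```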
@@ -885,14 +1582,25 @@ class RetrieverModel(BaseModel):
     search_parameters: SearchParametersModel = Field(
         default_factory=SearchParametersModel
     )
+    rerank: Optional[RerankParametersModel | bool] = Field(
+        default=None,
+        description="Optional reranking configuration. Set to true for defaults, or provide ReRankParametersModel for custom settings.",
+    )
 
     @model_validator(mode="after")
-    def set_default_columns(self):
+    def set_default_columns(self) -> Self:
         if not self.columns:
             columns: Sequence[str] = self.vector_store.columns
             self.columns = columns
         return self
 
+    @model_validator(mode="after")
+    def set_default_reranker(self) -> Self:
+        """Convert bool to ReRankParametersModel with defaults."""
+        if isinstance(self.rerank, bool) and self.rerank:
+            self.rerank = RerankParametersModel()
+        return self
+
 
 class FunctionType(str, Enum):
     PYTHON = "python"
@@ -901,28 +1609,47 @@ class FunctionType(str, Enum):
     MCP = "mcp"
 
 
-class
-    """
+class HumanInTheLoopModel(BaseModel):
+    """
+    Configuration for Human-in-the-Loop tool approval.
 
-
-
-    RESPONSE = "response"
-    DECLINE = "decline"
+    This model configures when and how tools require human approval before execution.
+    It maps to LangChain's HumanInTheLoopMiddleware.
 
+    LangChain supports three decision types:
+    - "approve": Execute tool with original arguments
+    - "edit": Modify arguments before execution
+    - "reject": Skip execution with optional feedback message
+    """
 
-class HumanInTheLoopModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-
-
-
-
-
-
-
-
+
+    review_prompt: Optional[str] = Field(
+        default=None,
+        description="Message shown to the reviewer when approval is requested",
+    )
+
+    allowed_decisions: list[Literal["approve", "edit", "reject"]] = Field(
+        default_factory=lambda: ["approve", "edit", "reject"],
+        description="List of allowed decision types for this tool",
     )
-
-
+
+    @model_validator(mode="after")
+    def validate_and_normalize_decisions(self) -> Self:
+        """Validate and normalize allowed decisions."""
+        if not self.allowed_decisions:
+            raise ValueError("At least one decision type must be allowed")
+
+        # Remove duplicates while preserving order
+        seen = set()
+        unique_decisions = []
+        for decision in self.allowed_decisions:
+            if decision not in seen:
+                seen.add(decision)
+                unique_decisions.append(decision)
+        self.allowed_decisions = unique_decisions
+
+        return self
 
 
 class BaseFunctionModel(ABC, BaseModel):
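A short sketch of that validator's behavior as declared above: duplicates are dropped while preserving order, and an empty list is rejected (Pydantic wraps the `ValueError` in a `ValidationError`, which is itself a `ValueError` subclass):

```python
from dao_ai.config import HumanInTheLoopModel

hitl = HumanInTheLoopModel(
    review_prompt="Approve this tool call before it runs?",
    allowed_decisions=["approve", "approve", "reject"],
)
# Duplicates removed, order preserved.
assert hitl.allowed_decisions == ["approve", "reject"]

try:
    HumanInTheLoopModel(allowed_decisions=[])
except ValueError as exc:
    print(exc)  # mentions "At least one decision type must be allowed"
```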
@@ -931,7 +1658,6 @@ class BaseFunctionModel(ABC, BaseModel):
         discriminator="type",
     )
     type: FunctionType
-    name: str
     human_in_the_loop: Optional[HumanInTheLoopModel] = None
 
     @abstractmethod
@@ -948,6 +1674,7 @@ class BaseFunctionModel(ABC, BaseModel):
 class PythonFunctionModel(BaseFunctionModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     type: Literal[FunctionType.PYTHON] = FunctionType.PYTHON
+    name: str
 
     @property
     def full_name(self) -> str:
@@ -961,8 +1688,9 @@ class PythonFunctionModel(BaseFunctionModel, HasFullName):
 
 class FactoryFunctionModel(BaseFunctionModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    args: Optional[dict[str, Any]] = Field(default_factory=dict)
     type: Literal[FunctionType.FACTORY] = FunctionType.FACTORY
+    name: str
+    args: Optional[dict[str, Any]] = Field(default_factory=dict)
 
     @property
     def full_name(self) -> str:
@@ -974,7 +1702,7 @@ class FactoryFunctionModel(BaseFunctionModel, HasFullName):
         return [create_factory_tool(self, **kwargs)]
 
     @model_validator(mode="after")
-    def update_args(self):
+    def update_args(self) -> Self:
         for key, value in self.args.items():
             self.args[key] = value_of(value)
         return self
@@ -985,24 +1713,148 @@ class TransportType(str, Enum):
     STDIO = "stdio"
 
 
-class McpFunctionModel(BaseFunctionModel,
+class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
+    """
+    MCP Function Model with authentication inherited from IsDatabricksResource.
+
+    Authentication for MCP connections uses the same options as other resources:
+    - Service Principal (client_id + client_secret + workspace_host)
+    - PAT (pat + workspace_host)
+    - OBO (on_behalf_of_user)
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     type: Literal[FunctionType.MCP] = FunctionType.MCP
-
     transport: TransportType = TransportType.STREAMABLE_HTTP
     command: Optional[str] = "python"
     url: Optional[AnyVariable] = None
-    connection: Optional[ConnectionModel] = None
     headers: dict[str, AnyVariable] = Field(default_factory=dict)
     args: list[str] = Field(default_factory=list)
-
-
-
-
+    # MCP-specific fields
+    connection: Optional[ConnectionModel] = None
+    functions: Optional[SchemaModel] = None
+    genie_room: Optional[GenieRoomModel] = None
+    sql: Optional[bool] = None
+    vector_search: Optional[VectorStoreModel] = None
 
     @property
-    def
-
+    def api_scopes(self) -> Sequence[str]:
+        """API scopes for MCP connections."""
+        return [
+            "serving.serving-endpoints",
+            "mcp.genie",
+            "mcp.functions",
+            "mcp.vectorsearch",
+            "mcp.external",
+        ]
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        """MCP functions don't declare static resources."""
+        return []
+
+    def _get_workspace_host(self) -> str:
+        """
+        Get the workspace host, either from config or from workspace client.
+
+        If connection is provided, uses its workspace client.
+        Otherwise, falls back to the default Databricks host.
+
+        Returns:
+            str: The workspace host URL with https:// scheme and without trailing slash
+        """
+        from dao_ai.utils import get_default_databricks_host, normalize_host
+
+        # Try to get workspace_host from config
+        workspace_host: str | None = (
+            normalize_host(value_of(self.workspace_host))
+            if self.workspace_host
+            else None
+        )
+
+        # If no workspace_host in config, get it from workspace client
+        if not workspace_host:
+            # Use connection's workspace client if available
+            if self.connection:
+                workspace_host = normalize_host(
+                    self.connection.workspace_client.config.host
+                )
+            else:
+                # get_default_databricks_host already normalizes the host
+                workspace_host = get_default_databricks_host()
+
+        if not workspace_host:
+            raise ValueError(
+                "Could not determine workspace host. "
+                "Please set workspace_host in config or DATABRICKS_HOST environment variable."
+            )
+
+        # Remove trailing slash
+        return workspace_host.rstrip("/")
+
+    @property
+    def mcp_url(self) -> str:
+        """
+        Get the MCP URL for this function.
+
+        Returns the URL based on the configured source:
+        - If url is set, returns it directly
+        - If connection is set, constructs URL from connection
+        - If genie_room is set, constructs Genie MCP URL
+        - If sql is set, constructs DBSQL MCP URL (serverless)
+        - If vector_search is set, constructs Vector Search MCP URL
+        - If functions is set, constructs UC Functions MCP URL
+
+        URL patterns (per https://docs.databricks.com/aws/en/generative-ai/mcp/managed-mcp):
+        - Genie: https://{host}/api/2.0/mcp/genie/{space_id}
+        - DBSQL: https://{host}/api/2.0/mcp/sql (serverless, workspace-level)
+        - Vector Search: https://{host}/api/2.0/mcp/vector-search/{catalog}/{schema}
+        - UC Functions: https://{host}/api/2.0/mcp/functions/{catalog}/{schema}
+        - Connection: https://{host}/api/2.0/mcp/external/{connection_name}
+        """
+        # Direct URL provided
+        if self.url:
+            return self.url
+
+        # Get workspace host (from config, connection, or default workspace client)
+        workspace_host: str = self._get_workspace_host()
+
+        # UC Connection
+        if self.connection:
+            connection_name: str = self.connection.name
+            return f"{workspace_host}/api/2.0/mcp/external/{connection_name}"
+
+        # Genie Room
+        if self.genie_room:
+            space_id: str = value_of(self.genie_room.space_id)
+            return f"{workspace_host}/api/2.0/mcp/genie/{space_id}"
+
+        # DBSQL MCP server (serverless, workspace-level)
+        if self.sql:
+            return f"{workspace_host}/api/2.0/mcp/sql"
+
+        # Vector Search
+        if self.vector_search:
+            if (
+                not self.vector_search.index
+                or not self.vector_search.index.schema_model
+            ):
+                raise ValueError(
+                    "vector_search must have an index with a schema (catalog/schema) configured"
+                )
+            catalog: str = self.vector_search.index.schema_model.catalog_name
+            schema: str = self.vector_search.index.schema_model.schema_name
+            return f"{workspace_host}/api/2.0/mcp/vector-search/{catalog}/{schema}"
+
+        # UC Functions MCP server
+        if self.functions:
+            catalog: str = self.functions.catalog_name
+            schema: str = self.functions.schema_name
+            return f"{workspace_host}/api/2.0/mcp/functions/{catalog}/{schema}"
+
+        raise ValueError(
+            "No URL source configured. Provide one of: url, connection, genie_room, "
+            "sql, vector_search, or functions"
+        )
 
     @field_serializer("transport")
     def serialize_transport(self, value) -> str:
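To make the URL construction concrete, a hedged sketch: the `GenieRoomModel` constructor arguments shown here (`name`, `space_id`) are assumptions based on how this file reads the model, and resolving `mcp_url` requires a reachable workspace host (for example, `DATABRICKS_HOST` set in the environment):

```python
from dao_ai.config import GenieRoomModel, McpFunctionModel

# Exactly one URL source: here, a Genie room (field names assumed).
fn = McpFunctionModel(
    genie_room=GenieRoomModel(name="sales_genie", space_id="01ef0123456789ab"),
)

# mcp_url resolves the host (config value, connection client, or
# DATABRICKS_HOST) and then applies the Genie pattern:
#   {host}/api/2.0/mcp/genie/01ef0123456789ab
print(fn.mcp_url)
```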
@@ -1011,71 +1863,65 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
         return str(value)
 
     @model_validator(mode="after")
-    def validate_mutually_exclusive(self):
-
-
-
-
-        )
-
-
-
-
+    def validate_mutually_exclusive(self) -> "McpFunctionModel":
+        """Validate that exactly one URL source is provided."""
+        # Count how many URL sources are provided
+        url_sources: list[tuple[str, Any]] = [
+            ("url", self.url),
+            ("connection", self.connection),
+            ("genie_room", self.genie_room),
+            ("sql", self.sql),
+            ("vector_search", self.vector_search),
+            ("functions", self.functions),
+        ]
+
+        provided_sources: list[str] = [
+            name for name, value in url_sources if value is not None
+        ]
+
+        if self.transport == TransportType.STREAMABLE_HTTP:
+            if len(provided_sources) == 0:
+                raise ValueError(
+                    "For STREAMABLE_HTTP transport, exactly one of the following must be provided: "
+                    "url, connection, genie_room, sql, vector_search, or functions"
+                )
+            if len(provided_sources) > 1:
+                raise ValueError(
+                    f"For STREAMABLE_HTTP transport, only one URL source can be provided. "
+                    f"Found: {', '.join(provided_sources)}. "
+                    f"Please provide only one of: url, connection, genie_room, sql, vector_search, or functions"
+                )
+
+        if self.transport == TransportType.STDIO:
+            if not self.command:
+                raise ValueError("command must be provided for STDIO transport")
+            if not self.args:
+                raise ValueError("args must be provided for STDIO transport")
+
         return self
 
     @model_validator(mode="after")
-    def update_url(self):
+    def update_url(self) -> "McpFunctionModel":
         self.url = value_of(self.url)
         return self
 
     @model_validator(mode="after")
-    def update_headers(self):
+    def update_headers(self) -> "McpFunctionModel":
         for key, value in self.headers.items():
             self.headers[key] = value_of(value)
         return self
 
-    @model_validator(mode="after")
-    def validate_auth_methods(self):
-        oauth_fields: Sequence[Any] = [
-            self.client_id,
-            self.client_secret,
-        ]
-        has_oauth: bool = all(field is not None for field in oauth_fields)
-
-        pat_fields: Sequence[Any] = [self.pat]
-        has_user_auth: bool = all(field is not None for field in pat_fields)
-
-        if has_oauth and has_user_auth:
-            raise ValueError(
-                "Cannot use both OAuth and user authentication methods. "
-                "Please provide either OAuth credentials or user credentials."
-            )
-
-        if (has_oauth or has_user_auth) and not self.workspace_host:
-            raise ValueError(
-                "Workspace host must be provided when using OAuth or user credentials."
-            )
-
-        return self
-
     def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
         from dao_ai.tools import create_mcp_tools
 
         return create_mcp_tools(self)
 
 
-class UnityCatalogFunctionModel(BaseFunctionModel
+class UnityCatalogFunctionModel(BaseFunctionModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
-    partial_args: Optional[dict[str, AnyVariable]] = Field(default_factory=dict)
     type: Literal[FunctionType.UNITY_CATALOG] = FunctionType.UNITY_CATALOG
-
-
-    def full_name(self) -> str:
-        if self.schema_model:
-            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}.{self.name}"
-        return self.name
+    resource: FunctionModel
+    partial_args: Optional[dict[str, AnyVariable]] = Field(default_factory=dict)
 
     def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
         from dao_ai.tools import create_uc_tools
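And a sketch of the exactly-one-source rule enforced above — supplying two URL sources for the default STREAMABLE_HTTP transport is rejected before any host resolution happens (same `GenieRoomModel` field assumptions as before):

```python
from dao_ai.config import GenieRoomModel, McpFunctionModel

try:
    McpFunctionModel(
        url="https://example.com/mcp",
        genie_room=GenieRoomModel(name="sales_genie", space_id="01ef0123456789ab"),
    )
except ValueError as exc:
    print(exc)  # "...only one URL source can be provided. Found: url, genie_room..."
```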
@@ -1100,12 +1946,97 @@ class ToolModel(BaseModel):
     function: AnyTool
 
 
-class
-    model_config = ConfigDict(use_enum_values=True, extra="forbid")
-
-
-
-
+class PromptModel(BaseModel, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
+    name: str
+    description: Optional[str] = None
+    default_template: Optional[str] = None
+    alias: Optional[str] = None
+    version: Optional[int] = None
+    tags: Optional[dict[str, Any]] = Field(default_factory=dict)
+    auto_register: bool = Field(
+        default=False,
+        description="Whether to automatically register the default_template to the prompt registry. "
+        "If False, the prompt will only be loaded from the registry (never created/updated). "
+        "Defaults to True for backward compatibility.",
+    )
+
+    @property
+    def template(self) -> str:
+        from dao_ai.providers.databricks import DatabricksProvider
+
+        provider: DatabricksProvider = DatabricksProvider()
+        prompt_version = provider.get_prompt(self)
+        return prompt_version.to_single_brace_format()
+
+    @property
+    def full_name(self) -> str:
+        prompt_name: str = self.name
+        if self.schema_model:
+            prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
+        return prompt_name
+
+    @property
+    def uri(self) -> str:
+        prompt_uri: str = f"prompts:/{self.full_name}"
+
+        if self.alias:
+            prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
+        elif self.version:
+            prompt_uri = f"prompts:/{self.full_name}/{self.version}"
+        else:
+            prompt_uri = f"prompts:/{self.full_name}@latest"
+
+        return prompt_uri
+
+    def as_prompt(self) -> PromptVersion:
+        prompt_version: PromptVersion = load_prompt(self.uri)
+        return prompt_version
+
+    @model_validator(mode="after")
+    def validate_mutually_exclusive(self) -> Self:
+        if self.alias and self.version:
+            raise ValueError("Cannot specify both alias and version")
+        return self
+
+
+class GuardrailModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    model: str | LLMModel
+    prompt: str | PromptModel
+    num_retries: Optional[int] = 3
+
+    @model_validator(mode="after")
+    def validate_llm_model(self) -> Self:
+        if isinstance(self.model, str):
+            self.model = LLMModel(name=self.model)
+        return self
+
+
+class MiddlewareModel(BaseModel):
+    """Configuration for middleware that can be applied to agents.
+
+    Middleware is defined at the AppConfig level and can be referenced by name
+    in agent configurations using YAML anchors for reusability.
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str = Field(
+        description="Fully qualified name of the middleware factory function"
+    )
+    args: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Arguments to pass to the middleware factory function",
+    )
+
+    @model_validator(mode="after")
+    def resolve_args(self) -> Self:
+        """Resolve any variable references in args."""
+        for key, value in self.args.items():
+            self.args[key] = value_of(value)
+        return self
 
 
 class StorageType(str, Enum):
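The `uri` property encodes the alias/version/latest precedence above. A small sketch, assuming `SchemaModel` takes `catalog_name` and `schema_name` and exposes a dotted `full_name`:

```python
from dao_ai.config import PromptModel, SchemaModel

schema = SchemaModel(catalog_name="main", schema_name="agents")

print(PromptModel(schema=schema, name="greeting").uri)
# prompts:/main.agents.greeting@latest
print(PromptModel(schema=schema, name="greeting", alias="prod").uri)
# prompts:/main.agents.greeting@prod
print(PromptModel(schema=schema, name="greeting", version=3).uri)
# prompts:/main.agents.greeting/3
```

Note that `alias` and `version` cannot be combined; the validator rejects that pairing outright.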
@@ -1116,14 +2047,12 @@ class StorageType(str, Enum):
 class CheckpointerModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
-    type: Optional[StorageType] = StorageType.MEMORY
     database: Optional[DatabaseModel] = None
 
-    @
-    def
-
-
-        return self
+    @property
+    def storage_type(self) -> StorageType:
+        """Infer storage type from database presence."""
+        return StorageType.POSTGRES if self.database else StorageType.MEMORY
 
     def as_checkpointer(self) -> BaseCheckpointSaver:
         from dao_ai.memory import CheckpointManager
@@ -1139,16 +2068,14 @@ class StoreModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     embedding_model: Optional[LLMModel] = None
-    type: Optional[StorageType] = StorageType.MEMORY
     dims: Optional[int] = 1536
     database: Optional[DatabaseModel] = None
     namespace: Optional[str] = None
 
-    @
-    def
-
-
-        return self
+    @property
+    def storage_type(self) -> StorageType:
+        """Infer storage type from database presence."""
+        return StorageType.POSTGRES if self.database else StorageType.MEMORY
 
     def as_store(self) -> BaseStore:
         from dao_ai.memory import StoreManager
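A quick sketch of the inference: storage is now derived from whether a `database` is configured, replacing the removed explicit `type` field. The `DatabaseModel(name=...)` constructor here is an assumption about its minimal required fields:

```python
from dao_ai.config import CheckpointerModel, DatabaseModel, StorageType

# No database configured -> ephemeral in-memory storage.
assert CheckpointerModel(name="scratch").storage_type == StorageType.MEMORY

# A database present -> Postgres-backed storage.
durable = CheckpointerModel(name="durable", database=DatabaseModel(name="agents-db"))
assert durable.storage_type == StorageType.POSTGRES
```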
@@ -1166,41 +2093,158 @@ class MemoryModel(BaseModel):
 FunctionHook: TypeAlias = PythonFunctionModel | FactoryFunctionModel | str
 
 
-class
+class ResponseFormatModel(BaseModel):
+    """
+    Configuration for structured response formats.
+
+    The response_schema field accepts either a type or a string:
+    - Type (Pydantic model, dataclass, etc.): Used directly for structured output
+    - String: First attempts to load as a fully qualified type name, falls back to JSON schema string
+
+    This unified approach simplifies the API while maintaining flexibility.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-
-
-
-
-
-
-
+    use_tool: Optional[bool] = Field(
+        default=None,
+        description=(
+            "Strategy for structured output: "
+            "None (default) = auto-detect from model capabilities, "
+            "False = force ProviderStrategy (native), "
+            "True = force ToolStrategy (function calling)"
+        ),
+    )
+    response_schema: Optional[str | type] = Field(
+        default=None,
+        description="Type or string for response format. String attempts FQN import, falls back to JSON schema.",
+    )
 
-
-
-
+    def as_strategy(self) -> ProviderStrategy | ToolStrategy:
+        """
+        Convert response_schema to appropriate LangChain strategy.
 
-
-
-
+        Returns:
+        - None if no response_schema configured
+        - Raw schema/type for auto-detection (when use_tool=None)
+        - ToolStrategy wrapping the schema (when use_tool=True)
+        - ProviderStrategy wrapping the schema (when use_tool=False)
 
-
-
-
-
-
-
-
-
+        Raises:
+            ValueError: If response_schema is a JSON schema string that cannot be parsed
+        """
+
+        if self.response_schema is None:
+            return None
+
+        schema = self.response_schema
+
+        # Handle type schemas (Pydantic, dataclass, etc.)
+        if self.is_type_schema:
+            if self.use_tool is None:
+                # Auto-detect: Pass schema directly, let LangChain decide
+                return schema
+            elif self.use_tool is True:
+                # Force ToolStrategy (function calling)
+                return ToolStrategy(schema)
+            else:  # use_tool is False
+                # Force ProviderStrategy (native structured output)
+                return ProviderStrategy(schema)
+
+        # Handle JSON schema strings
+        elif self.is_json_schema:
+            import json
+
+            try:
+                schema_dict = json.loads(schema)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Invalid JSON schema string: {e}") from e
+
+            # Apply same use_tool logic as type schemas
+            if self.use_tool is None:
+                # Auto-detect
+                return schema_dict
+            elif self.use_tool is True:
+                # Force ToolStrategy
+                return ToolStrategy(schema_dict)
+            else:  # use_tool is False
+                # Force ProviderStrategy
+                return ProviderStrategy(schema_dict)
+
+        return None
 
     @model_validator(mode="after")
-    def
-
-
-
+    def validate_response_schema(self) -> Self:
+        """
+        Validate and convert response_schema.
+
+        Processing logic:
+        1. If None: no response format specified
+        2. If type: use directly as structured output type
+        3. If str: try to load as FQN using type_from_fqn
+           - Success: response_schema becomes the loaded type
+           - Failure: keep as string (treated as JSON schema)
+
+        After validation, response_schema is one of:
+        - None (no schema)
+        - type (use for structured output)
+        - str (JSON schema)
+
+        Returns:
+            Self with validated response_schema
+        """
+        if self.response_schema is None:
+            return self
+
+        # If already a type, return
+        if isinstance(self.response_schema, type):
+            return self
+
+        # If it's a string, try to load as type, fallback to json_schema
+        if isinstance(self.response_schema, str):
+            from dao_ai.utils import type_from_fqn
+
+            try:
+                resolved_type = type_from_fqn(self.response_schema)
+                self.response_schema = resolved_type
+                logger.debug(
+                    f"Resolved response_schema string to type: {resolved_type}"
+                )
+                return self
+            except (ValueError, ImportError, AttributeError, TypeError) as e:
+                # Keep as string - it's a JSON schema
+                logger.debug(
+                    f"Could not resolve '{self.response_schema}' as type: {e}. "
+                    f"Treating as JSON schema string."
+                )
+                return self
+
+        # Invalid type
+        raise ValueError(
+            f"response_schema must be None, type, or str, got {type(self.response_schema)}"
+        )
+
+    @property
+    def is_type_schema(self) -> bool:
+        """Returns True if response_schema is a type (not JSON schema string)."""
+        return isinstance(self.response_schema, type)
+
+    @property
+    def is_json_schema(self) -> bool:
+        """Returns True if response_schema is a JSON schema string (not a type)."""
+        return isinstance(self.response_schema, str)
 
 
 class AgentModel(BaseModel):
+    """
+    Configuration model for an agent in the DAO AI framework.
+
+    Agents combine an LLM with tools and middleware to create systems that can
+    reason about tasks, decide which tools to use, and iteratively work towards solutions.
+
+    Middleware replaces the previous pre_agent_hook and post_agent_hook patterns,
+    providing a more flexible and composable way to customize agent behavior.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     description: Optional[str] = None
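A short sketch of how `response_schema` and `use_tool` combine, using the model as declared above; only standard Pydantic is assumed beyond this file:

```python
from pydantic import BaseModel
from dao_ai.config import ResponseFormatModel


class Answer(BaseModel):
    text: str
    confidence: float


auto = ResponseFormatModel(response_schema=Answer)              # auto-detect
tool = ResponseFormatModel(response_schema=Answer, use_tool=True)
native = ResponseFormatModel(response_schema=Answer, use_tool=False)

assert auto.is_type_schema
assert auto.as_strategy() is Answer  # raw type: LangChain picks the strategy
# tool.as_strategy() -> ToolStrategy(Answer)
# native.as_strategy() -> ProviderStrategy(Answer)

# A string that is not an importable FQN is kept as a JSON schema string.
js = ResponseFormatModel(
    response_schema='{"type": "object", "properties": {"text": {"type": "string"}}}'
)
assert js.is_json_schema
```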
@@ -1209,9 +2253,54 @@ class AgentModel(BaseModel):
     guardrails: list[GuardrailModel] = Field(default_factory=list)
     prompt: Optional[str | PromptModel] = None
     handoff_prompt: Optional[str] = None
-
-
-
+    middleware: list[MiddlewareModel] = Field(
+        default_factory=list,
+        description="List of middleware to apply to this agent",
+    )
+    response_format: Optional[ResponseFormatModel | type | str] = None
+
+    @model_validator(mode="after")
+    def validate_response_format(self) -> Self:
+        """
+        Validate and normalize response_format.
+
+        Accepts:
+        - None (no response format)
+        - ResponseFormatModel (already validated)
+        - type (Pydantic model, dataclass, etc.) - converts to ResponseFormatModel
+        - str (FQN or json_schema) - converts to ResponseFormatModel (smart fallback)
+
+        ResponseFormatModel handles the logic of trying FQN import and falling back to JSON schema.
+        """
+        if self.response_format is None or isinstance(
+            self.response_format, ResponseFormatModel
+        ):
+            return self
+
+        # Convert type or str to ResponseFormatModel
+        # ResponseFormatModel's validator will handle the smart type loading and fallback
+        if isinstance(self.response_format, (type, str)):
+            self.response_format = ResponseFormatModel(
+                response_schema=self.response_format
+            )
+            return self
+
+        # Invalid type
+        raise ValueError(
+            f"response_format must be None, ResponseFormatModel, type, or str, "
+            f"got {type(self.response_format)}"
+        )
+
+    def as_runnable(self) -> RunnableLike:
+        from dao_ai.nodes import create_agent_node
+
+        return create_agent_node(self)
+
+    def as_responses_agent(self) -> ResponsesAgent:
+        from dao_ai.models import create_responses_agent
+
+        graph: CompiledStateGraph = self.as_runnable()
+        return create_responses_agent(graph)
 
 
 class SupervisorModel(BaseModel):
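Middleware entries pair a factory FQN with resolved args, and agents (as well as the supervisor and swarm models below) reference them via their `middleware` lists. A sketch — the factory path here is illustrative, not a function the package is known to ship:

```python
from dao_ai.config import MiddlewareModel

# `name` is the fully qualified factory function; `args` pass through to it
# (variable references in args are resolved by the validator).
summarization = MiddlewareModel(
    name="my_project.middleware.create_summarizer",
    args={"max_tokens": 2048},
)

# Agents, supervisors, and swarms then list it in their `middleware` field,
# typically shared across agents with a YAML anchor in the config file.
print(summarization.args)  # {'max_tokens': 2048}
```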
@@ -1219,12 +2308,20 @@ class SupervisorModel(BaseModel):
     model: LLMModel
     tools: list[ToolModel] = Field(default_factory=list)
     prompt: Optional[str] = None
+    middleware: list[MiddlewareModel] = Field(
+        default_factory=list,
+        description="List of middleware to apply to the supervisor",
+    )
 
 
 class SwarmModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     model: LLMModel
     default_agent: Optional[AgentModel | str] = None
+    middleware: list[MiddlewareModel] = Field(
+        default_factory=list,
+        description="List of middleware to apply to all agents in the swarm",
+    )
     handoffs: Optional[dict[str, Optional[list[AgentModel | str]]]] = Field(
         default_factory=dict
     )
@@ -1237,7 +2334,7 @@ class OrchestrationModel(BaseModel):
     memory: Optional[MemoryModel] = None
 
     @model_validator(mode="after")
-    def validate_mutually_exclusive(self):
+    def validate_mutually_exclusive(self) -> Self:
         if self.supervisor is not None and self.swarm is not None:
             raise ValueError("Cannot specify both supervisor and swarm")
         if self.supervisor is None and self.swarm is None:
@@ -1267,9 +2364,21 @@ class Entitlement(str, Enum):
 
 class AppPermissionModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    principals: list[str] = Field(default_factory=list)
+    principals: list[ServicePrincipalModel | str] = Field(default_factory=list)
     entitlements: list[Entitlement]
 
+    @model_validator(mode="after")
+    def resolve_principals(self) -> Self:
+        """Resolve ServicePrincipalModel objects to their client_id."""
+        resolved: list[str] = []
+        for principal in self.principals:
+            if isinstance(principal, ServicePrincipalModel):
+                resolved.append(value_of(principal.client_id))
+            else:
+                resolved.append(principal)
+        self.principals = resolved
+        return self
+
 
 class LogLevel(str, Enum):
     TRACE = "TRACE"
@@ -1330,27 +2439,81 @@ class ChatPayload(BaseModel):
 
         return self
 
+    @model_validator(mode="after")
+    def ensure_thread_id(self) -> "ChatPayload":
+        """Ensure thread_id or conversation_id is present in configurable, generating UUID if needed."""
+        import uuid
+
+        if self.custom_inputs is None:
+            self.custom_inputs = {}
+
+        # Get or create configurable section
+        configurable: dict[str, Any] = self.custom_inputs.get("configurable", {})
+
+        # Check if thread_id or conversation_id exists
+        has_thread_id = configurable.get("thread_id") is not None
+        has_conversation_id = configurable.get("conversation_id") is not None
+
+        # If neither is provided, generate a UUID for conversation_id
+        if not has_thread_id and not has_conversation_id:
+            configurable["conversation_id"] = str(uuid.uuid4())
+            self.custom_inputs["configurable"] = configurable
+
+        return self
+
+    def as_messages(self) -> Sequence[BaseMessage]:
+        return messages_from_dict(
+            [{"type": m.role, "content": m.content} for m in self.messages]
+        )
+
+    def as_agent_request(self) -> ResponsesAgentRequest:
+        from mlflow.types.responses_helpers import Message as _Message
+
+        return ResponsesAgentRequest(
+            input=[_Message(role=m.role, content=m.content) for m in self.messages],
+            custom_inputs=self.custom_inputs,
+        )
+
 
 class ChatHistoryModel(BaseModel):
+    """
+    Configuration for chat history summarization.
+
+    Attributes:
+        model: The LLM to use for generating summaries.
+        max_tokens: Maximum tokens to keep after summarization (the "keep" threshold).
+            After summarization, recent messages totaling up to this many tokens are preserved.
+        max_tokens_before_summary: Token threshold that triggers summarization.
+            When conversation exceeds this, summarization runs. Mutually exclusive with
+            max_messages_before_summary. If neither is set, defaults to max_tokens * 10.
+        max_messages_before_summary: Message count threshold that triggers summarization.
+            When conversation exceeds this many messages, summarization runs.
+            Mutually exclusive with max_tokens_before_summary.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     model: LLMModel
-    max_tokens: int =
-
-
-
-
-
-
-
-
-
-
+    max_tokens: int = Field(
+        default=2048,
+        gt=0,
+        description="Maximum tokens to keep after summarization",
+    )
+    max_tokens_before_summary: Optional[int] = Field(
+        default=None,
+        gt=0,
+        description="Token threshold that triggers summarization",
+    )
+    max_messages_before_summary: Optional[int] = Field(
+        default=None,
+        gt=0,
+        description="Message count threshold that triggers summarization",
+    )
 
 
 class AppModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
+    service_principal: Optional[ServicePrincipalModel] = None
     description: Optional[str] = None
     log_level: Optional[LogLevel] = "WARNING"
     registered_model: RegisteredModelModel
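A sketch of `ensure_thread_id` in action, assuming the message items validate from plain role/content dicts (as `as_messages()` above suggests):

```python
from dao_ai.config import ChatPayload

payload = ChatPayload.model_validate(
    {"messages": [{"role": "user", "content": "hello"}]}
)
# No thread_id/conversation_id was supplied, so one was generated.
print(payload.custom_inputs["configurable"]["conversation_id"])  # a UUID4 string

explicit = ChatPayload.model_validate(
    {
        "messages": [{"role": "user", "content": "hi again"}],
        "custom_inputs": {"configurable": {"thread_id": "abc-123"}},
    }
)
# An existing thread_id is respected; no conversation_id is added.
assert "conversation_id" not in explicit.custom_inputs["configurable"]
```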
@@ -1371,23 +2534,54 @@ class AppModel(BaseModel):
     shutdown_hooks: Optional[FunctionHook | list[FunctionHook]] = Field(
         default_factory=list
     )
-    message_hooks: Optional[FunctionHook | list[FunctionHook]] = Field(
-        default_factory=list
-    )
     input_example: Optional[ChatPayload] = None
     chat_history: Optional[ChatHistoryModel] = None
     code_paths: list[str] = Field(default_factory=list)
     pip_requirements: list[str] = Field(default_factory=list)
+    python_version: Optional[str] = Field(
+        default="3.12",
+        description="Python version for Model Serving deployment. Defaults to 3.12 "
+        "which is supported by Databricks Model Serving. This allows deploying from "
+        "environments with different Python versions (e.g., Databricks Apps with 3.11).",
+    )
 
     @model_validator(mode="after")
-    def
+    def set_databricks_env_vars(self) -> Self:
+        """Set Databricks environment variables for Model Serving.
+
+        Sets DATABRICKS_HOST, DATABRICKS_CLIENT_ID, and DATABRICKS_CLIENT_SECRET.
+        Values explicitly provided in environment_vars take precedence.
+        """
+        from dao_ai.utils import get_default_databricks_host
+
+        # Set DATABRICKS_HOST if not already provided
+        if "DATABRICKS_HOST" not in self.environment_vars:
+            host: str | None = get_default_databricks_host()
+            if host:
+                self.environment_vars["DATABRICKS_HOST"] = host
+
+        # Set service principal credentials if provided
+        if self.service_principal is not None:
+            if "DATABRICKS_CLIENT_ID" not in self.environment_vars:
+                self.environment_vars["DATABRICKS_CLIENT_ID"] = (
+                    self.service_principal.client_id
+                )
+            if "DATABRICKS_CLIENT_SECRET" not in self.environment_vars:
+                self.environment_vars["DATABRICKS_CLIENT_SECRET"] = (
+                    self.service_principal.client_secret
+                )
+        return self
+
+    @model_validator(mode="after")
+    def validate_agents_not_empty(self) -> Self:
         if not self.agents:
             raise ValueError("At least one agent must be specified")
         return self
 
     @model_validator(mode="after")
-    def
+    def resolve_environment_vars(self) -> Self:
         for key, value in self.environment_vars.items():
+            updated_value: str
             if isinstance(value, SecretVariableModel):
                 updated_value = str(value)
             else:
@@ -1397,7 +2591,7 @@ class AppModel(BaseModel):
         return self
 
     @model_validator(mode="after")
-    def set_default_orchestration(self):
+    def set_default_orchestration(self) -> Self:
         if self.orchestration is None:
             if len(self.agents) > 1:
                 default_agent: AgentModel = self.agents[0]
@@ -1417,14 +2611,14 @@ class AppModel(BaseModel):
         return self
 
     @model_validator(mode="after")
-    def set_default_endpoint_name(self):
+    def set_default_endpoint_name(self) -> Self:
         if self.endpoint_name is None:
             self.endpoint_name = self.name
         return self
 
     @model_validator(mode="after")
-    def set_default_agent(self):
-        default_agent_name = self.agents[0].name
+    def set_default_agent(self) -> Self:
+        default_agent_name: str = self.agents[0].name
 
         if self.orchestration.swarm and not self.orchestration.swarm.default_agent:
             self.orchestration.swarm.default_agent = default_agent_name
@@ -1432,7 +2626,7 @@ class AppModel(BaseModel):
         return self
 
     @model_validator(mode="after")
-    def add_code_paths_to_sys_path(self):
+    def add_code_paths_to_sys_path(self) -> Self:
         for code_path in self.code_paths:
             parent_path: str = str(Path(code_path).parent)
             if parent_path not in sys.path:
@@ -1459,6 +2653,202 @@ class EvaluationModel(BaseModel):
     guidelines: list[GuidelineModel] = Field(default_factory=list)
 
 
+class EvaluationDatasetExpectationsModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    expected_response: Optional[str] = None
+    expected_facts: Optional[list[str]] = None
+
+    @model_validator(mode="after")
+    def validate_mutually_exclusive(self) -> Self:
+        if self.expected_response is not None and self.expected_facts is not None:
+            raise ValueError("Cannot specify both expected_response and expected_facts")
+        return self
+
+
+class EvaluationDatasetEntryModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    inputs: ChatPayload
+    expectations: EvaluationDatasetExpectationsModel
+
+    def to_mlflow_format(self) -> dict[str, Any]:
+        """
+        Convert to MLflow evaluation dataset format.
+
+        Flattens the expectations fields to the top level alongside inputs,
+        which is the format expected by MLflow's Correctness scorer.
+
+        Returns:
+            dict: Flattened dictionary with inputs and expectation fields at top level
+        """
+        result: dict[str, Any] = {"inputs": self.inputs.model_dump()}
+
+        # Flatten expectations to top level for MLflow compatibility
+        if self.expectations.expected_response is not None:
+            result["expected_response"] = self.expectations.expected_response
+        if self.expectations.expected_facts is not None:
+            result["expected_facts"] = self.expectations.expected_facts
+
+        return result
+
+
+class EvaluationDatasetModel(BaseModel, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
+    name: str
+    data: Optional[list[EvaluationDatasetEntryModel]] = Field(default_factory=list)
+    overwrite: Optional[bool] = False
+
+    def as_dataset(self, w: WorkspaceClient | None = None) -> EvaluationDataset:
+        evaluation_dataset: EvaluationDataset
+        needs_creation: bool = False
+
+        try:
+            evaluation_dataset = get_dataset(name=self.full_name)
+            if self.overwrite:
+                logger.warning(f"Overwriting dataset {self.full_name}")
+                workspace_client: WorkspaceClient = w if w else WorkspaceClient()
+                logger.debug(f"Dropping table: {self.full_name}")
+                workspace_client.tables.delete(full_name=self.full_name)
+                needs_creation = True
+        except Exception:
+            logger.warning(
+                f"Dataset {self.full_name} not found, will create new dataset"
+            )
+            needs_creation = True
+
+        # Create dataset if needed (either new or after overwrite)
+        if needs_creation:
+            evaluation_dataset = create_dataset(name=self.full_name)
+        if self.data:
+            logger.debug(
+                f"Merging {len(self.data)} entries into dataset {self.full_name}"
+            )
+            # Use to_mlflow_format() to flatten expectations for MLflow compatibility
+            evaluation_dataset.merge_records(
+                [e.to_mlflow_format() for e in self.data]
+            )
+
+        return evaluation_dataset
+
+    @property
+    def full_name(self) -> str:
+        if self.schema_model:
+            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}.{self.name}"
+        return self.name
+
+
+class PromptOptimizationModel(BaseModel):
+    """Configuration for prompt optimization using GEPA.
+
+    GEPA (Generative Evolution of Prompts and Agents) is an evolutionary
+    optimizer that uses reflective mutation to improve prompts based on
+    evaluation feedback.
+
+    Example:
+        prompt_optimization:
+          name: optimize_my_prompt
+          prompt: *my_prompt
+          agent: *my_agent
+          dataset: *my_training_dataset
+          reflection_model: databricks-meta-llama-3-3-70b-instruct
+          num_candidates: 50
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    prompt: Optional[PromptModel] = None
+    agent: AgentModel
+    dataset: EvaluationDatasetModel  # Training dataset with examples
+    reflection_model: Optional[LLMModel | str] = None
+    num_candidates: Optional[int] = 50
+
+    def optimize(self, w: WorkspaceClient | None = None) -> PromptModel:
+        """
+        Optimize the prompt using GEPA.
+
+        Args:
+            w: Optional WorkspaceClient (not used, kept for API compatibility)
+
+        Returns:
+            PromptModel: The optimized prompt model
+        """
+        from dao_ai.optimization import OptimizationResult, optimize_prompt
+
+        # Get reflection model name
+        reflection_model_name: str | None = None
+        if self.reflection_model:
+            if isinstance(self.reflection_model, str):
+                reflection_model_name = self.reflection_model
+            else:
+                reflection_model_name = self.reflection_model.uri
+
+        # Ensure prompt is set
+        prompt = self.prompt
+        if prompt is None:
+            raise ValueError(
+                f"Prompt optimization '{self.name}' requires a prompt to be set"
+            )
+
+        result: OptimizationResult = optimize_prompt(
+            prompt=prompt,
+            agent=self.agent,
+            dataset=self.dataset,
+            reflection_model=reflection_model_name,
+            num_candidates=self.num_candidates or 50,
+            register_if_improved=True,
+        )
+
+        return result.optimized_prompt
+
+    @model_validator(mode="after")
+    def set_defaults(self) -> Self:
+        # If no prompt is specified, try to use the agent's prompt
+        if self.prompt is None:
+            if isinstance(self.agent.prompt, PromptModel):
+                self.prompt = self.agent.prompt
+            else:
+                raise ValueError(
+                    f"Prompt optimization '{self.name}' requires either an explicit prompt "
+                    f"or an agent with a prompt configured"
+                )
+
+        return self
+
+
+class OptimizationsModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    training_datasets: dict[str, EvaluationDatasetModel] = Field(default_factory=dict)
+    prompt_optimizations: dict[str, PromptOptimizationModel] = Field(
+        default_factory=dict
+    )
+
+    def optimize(self, w: WorkspaceClient | None = None) -> dict[str, PromptModel]:
+        """
+        Optimize all prompts in this configuration.
+
+        This method:
+        1. Ensures all training datasets are created/registered in MLflow
+        2. Runs each prompt optimization
+
+        Args:
+            w: Optional WorkspaceClient for Databricks operations
+
+        Returns:
+            dict[str, PromptModel]: Dictionary mapping optimization names to optimized prompts
+        """
+        # First, ensure all training datasets are created/registered in MLflow
+        logger.info(f"Ensuring {len(self.training_datasets)} training datasets exist")
+        for dataset_name, dataset_model in self.training_datasets.items():
+            logger.debug(f"Creating/updating dataset: {dataset_name}")
+            dataset_model.as_dataset()
+
+        # Run optimizations
+        results: dict[str, PromptModel] = {}
+        for name, optimization in self.prompt_optimizations.items():
+            results[name] = optimization.optimize(w)
+        return results
+
+
 class DatasetFormat(str, Enum):
     CSV = "csv"
     DELTA = "delta"
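A sketch of the flattening that `to_mlflow_format()` performs, using the models as declared above (same role/content message assumption as earlier):

```python
from dao_ai.config import (
    ChatPayload,
    EvaluationDatasetEntryModel,
    EvaluationDatasetExpectationsModel,
)

entry = EvaluationDatasetEntryModel(
    inputs=ChatPayload.model_validate(
        {"messages": [{"role": "user", "content": "What is the return window?"}]}
    ),
    expectations=EvaluationDatasetExpectationsModel(
        expected_facts=["returns are accepted within 30 days"]
    ),
)

record = entry.to_mlflow_format()
# Expectations sit next to inputs, as MLflow's Correctness scorer expects:
# {"inputs": {...}, "expected_facts": ["returns are accepted within 30 days"]}
assert record["expected_facts"] == ["returns are accepted within 30 days"]
```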
@@ -1494,7 +2884,7 @@ class UnityCatalogFunctionSqlTestModel(BaseModel):
 
 class UnityCatalogFunctionSqlModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    function:
+    function: FunctionModel
     ddl: str
     parameters: Optional[dict[str, Any]] = Field(default_factory=dict)
     test: Optional[UnityCatalogFunctionSqlTestModel] = None
@@ -1522,21 +2912,132 @@ class ResourcesModel(BaseModel):
|
|
|
1522
2912
|
warehouses: dict[str, WarehouseModel] = Field(default_factory=dict)
|
|
1523
2913
|
databases: dict[str, DatabaseModel] = Field(default_factory=dict)
|
|
1524
2914
|
connections: dict[str, ConnectionModel] = Field(default_factory=dict)
|
+    apps: dict[str, DatabricksAppModel] = Field(default_factory=dict)
+
+    @model_validator(mode="after")
+    def update_genie_warehouses(self) -> Self:
+        """
+        Automatically populate warehouses from genie_rooms.
+
+        Warehouses are extracted from each Genie room and added to the
+        resources if they don't already exist (based on warehouse_id).
+        """
+        if not self.genie_rooms:
+            return self
+
+        # Process warehouses from all genie rooms
+        for genie_room in self.genie_rooms.values():
+            genie_room: GenieRoomModel
+            warehouse: Optional[WarehouseModel] = genie_room.warehouse
+
+            if warehouse is None:
+                continue
+
+            # Check if warehouse already exists based on warehouse_id
+            warehouse_exists: bool = any(
+                existing_warehouse.warehouse_id == warehouse.warehouse_id
+                for existing_warehouse in self.warehouses.values()
+            )
+
+            if not warehouse_exists:
+                warehouse_key: str = normalize_name(
+                    "_".join([genie_room.name, warehouse.warehouse_id])
+                )
+                self.warehouses[warehouse_key] = warehouse
+                logger.trace(
+                    "Added warehouse from Genie room",
+                    room=genie_room.name,
+                    warehouse=warehouse.warehouse_id,
+                    key=warehouse_key,
+                )
+
+        return self
+
+    @model_validator(mode="after")
+    def update_genie_tables(self) -> Self:
+        """
+        Automatically populate tables from genie_rooms.
+
+        Tables are extracted from each Genie room and added to the
+        resources if they don't already exist (based on full_name).
+        """
+        if not self.genie_rooms:
+            return self
+
+        # Process tables from all genie rooms
+        for genie_room in self.genie_rooms.values():
+            genie_room: GenieRoomModel
+            for table in genie_room.tables:
+                table: TableModel
+                table_exists: bool = any(
+                    existing_table.full_name == table.full_name
+                    for existing_table in self.tables.values()
+                )
+                if not table_exists:
+                    table_key: str = normalize_name(
+                        "_".join([genie_room.name, table.full_name])
+                    )
+                    self.tables[table_key] = table
+                    logger.trace(
+                        "Added table from Genie room",
+                        room=genie_room.name,
+                        table=table.name,
+                        key=table_key,
+                    )
+
+        return self
+
+    @model_validator(mode="after")
+    def update_genie_functions(self) -> Self:
+        """
+        Automatically populate functions from genie_rooms.
+
+        Functions are extracted from each Genie room and added to the
+        resources if they don't already exist (based on full_name).
+        """
+        if not self.genie_rooms:
+            return self
+
+        # Process functions from all genie rooms
+        for genie_room in self.genie_rooms.values():
+            genie_room: GenieRoomModel
+            for function in genie_room.functions:
+                function: FunctionModel
+                function_exists: bool = any(
+                    existing_function.full_name == function.full_name
+                    for existing_function in self.functions.values()
+                )
+                if not function_exists:
+                    function_key: str = normalize_name(
+                        "_".join([genie_room.name, function.full_name])
+                    )
+                    self.functions[function_key] = function
+                    logger.trace(
+                        "Added function from Genie room",
+                        room=genie_room.name,
+                        function=function.name,
+                        key=function_key,
+                    )
+
+        return self


 class AppConfig(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     variables: dict[str, AnyVariable] = Field(default_factory=dict)
+    service_principals: dict[str, ServicePrincipalModel] = Field(default_factory=dict)
     schemas: dict[str, SchemaModel] = Field(default_factory=dict)
     resources: Optional[ResourcesModel] = None
     retrievers: dict[str, RetrieverModel] = Field(default_factory=dict)
     tools: dict[str, ToolModel] = Field(default_factory=dict)
     guardrails: dict[str, GuardrailModel] = Field(default_factory=dict)
+    middleware: dict[str, MiddlewareModel] = Field(default_factory=dict)
     memory: Optional[MemoryModel] = None
     prompts: dict[str, PromptModel] = Field(default_factory=dict)
     agents: dict[str, AgentModel] = Field(default_factory=dict)
     app: Optional[AppModel] = None
     evaluation: Optional[EvaluationModel] = None
+    optimizations: Optional[OptimizationsModel] = None
     datasets: Optional[list[DatasetModel]] = Field(default_factory=list)
     unity_catalog_functions: Optional[list[UnityCatalogFunctionSqlModel]] = Field(
         default_factory=list
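The three `update_genie_*` validators above share one shape: after the model is built, walk every Genie room, skip resources whose identity key (`warehouse_id` for warehouses, `full_name` for tables and functions) is already registered, and file the rest under a normalized `<room name>_<identifier>` key. Below is a minimal, self-contained sketch of the warehouse case; `normalize_name` is a hypothetical stand-in (the real helper is defined elsewhere in config.py) and the models are stripped down to only the fields the pattern needs:

import re
from typing import Optional, Self  # typing.Self requires Python 3.11+

from pydantic import BaseModel, Field, model_validator


def normalize_name(name: str) -> str:
    # Hypothetical stand-in for the real helper: lowercase and
    # collapse runs of non-alphanumerics into single underscores.
    return re.sub(r"[^a-z0-9]+", "_", name.lower()).strip("_")


class Warehouse(BaseModel):
    warehouse_id: str


class GenieRoom(BaseModel):
    name: str
    warehouse: Optional[Warehouse] = None


class Resources(BaseModel):
    genie_rooms: dict[str, GenieRoom] = Field(default_factory=dict)
    warehouses: dict[str, Warehouse] = Field(default_factory=dict)

    @model_validator(mode="after")
    def update_genie_warehouses(self) -> Self:
        for room in self.genie_rooms.values():
            if room.warehouse is None:
                continue
            # Dedup keys on warehouse_id, not on the dict key, mirroring
            # the real validator above.
            if any(
                existing.warehouse_id == room.warehouse.warehouse_id
                for existing in self.warehouses.values()
            ):
                continue
            key = normalize_name("_".join([room.name, room.warehouse.warehouse_id]))
            self.warehouses[key] = room.warehouse
        return self


resources = Resources(
    genie_rooms={
        "sales": GenieRoom(name="Sales Room", warehouse=Warehouse(warehouse_id="abc123"))
    }
)
print(list(resources.warehouses))  # ['sales_room_abc123']

Because the existence test compares identity keys rather than dict keys, the same warehouse referenced from two rooms is registered only once, whichever room is processed first.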
@@ -1558,10 +3059,10 @@ class AppConfig(BaseModel):
 
     def initialize(self) -> None:
         from dao_ai.hooks.core import create_hooks
+        from dao_ai.logging import configure_logging
 
         if self.app and self.app.log_level:
-
-            logger.add(sys.stderr, level=self.app.log_level)
+            configure_logging(level=self.app.log_level)
 
         logger.debug("Calling initialization hooks...")
         initialization_functions: Sequence[Callable[..., Any]] = create_hooks(
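This hunk swaps a direct `logger.add(sys.stderr, ...)` call for a `configure_logging` helper from the new `dao_ai/logging.py` module. One practical benefit: calling `logger.add` on every `initialize()` stacks an extra stderr sink each time, duplicating output, whereas a central helper can reset sinks first. A plausible minimal version, assuming the loguru logger used throughout this file (the actual implementation may differ):

import sys

from loguru import logger


def configure_logging(level: str = "INFO") -> None:
    # Drop any previously registered sinks so repeated initialization
    # does not emit duplicate log lines, then install one stderr sink.
    logger.remove()
    logger.add(sys.stderr, level=level.upper())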
@@ -1605,21 +3106,45 @@ class AppConfig(BaseModel):
     def create_agent(
         self,
         w: WorkspaceClient | None = None,
+        vsc: "VectorSearchClient | None" = None,
+        pat: str | None = None,
+        client_id: str | None = None,
+        client_secret: str | None = None,
+        workspace_host: str | None = None,
     ) -> None:
         from dao_ai.providers.base import ServiceProvider
         from dao_ai.providers.databricks import DatabricksProvider
 
-        provider: ServiceProvider = DatabricksProvider(
+        provider: ServiceProvider = DatabricksProvider(
+            w=w,
+            vsc=vsc,
+            pat=pat,
+            client_id=client_id,
+            client_secret=client_secret,
+            workspace_host=workspace_host,
+        )
         provider.create_agent(self)
 
     def deploy_agent(
         self,
         w: WorkspaceClient | None = None,
+        vsc: "VectorSearchClient | None" = None,
+        pat: str | None = None,
+        client_id: str | None = None,
+        client_secret: str | None = None,
+        workspace_host: str | None = None,
     ) -> None:
         from dao_ai.providers.base import ServiceProvider
         from dao_ai.providers.databricks import DatabricksProvider
 
-        provider: ServiceProvider = DatabricksProvider(
+        provider: ServiceProvider = DatabricksProvider(
+            w=w,
+            vsc=vsc,
+            pat=pat,
+            client_id=client_id,
+            client_secret=client_secret,
+            workspace_host=workspace_host,
+        )
         provider.deploy_agent(self)
 
     def find_agents(