dao-ai 0.0.35__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/cli.py +195 -30
- dao_ai/config.py +797 -242
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +75 -0
- dao_ai/genie/cache/lru.py +329 -0
- dao_ai/genie/cache/semantic.py +919 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +22 -190
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +23 -5
- dao_ai/memory/databricks.py +389 -0
- dao_ai/memory/postgres.py +2 -2
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +778 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +61 -0
- dao_ai/middleware/guardrails.py +415 -0
- dao_ai/middleware/human_in_the_loop.py +228 -0
- dao_ai/middleware/message_validation.py +554 -0
- dao_ai/middleware/summarization.py +192 -0
- dao_ai/models.py +1177 -108
- dao_ai/nodes.py +118 -161
- dao_ai/optimization.py +664 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +287 -0
- dao_ai/orchestration/supervisor.py +264 -0
- dao_ai/orchestration/swarm.py +226 -0
- dao_ai/prompts.py +126 -29
- dao_ai/providers/databricks.py +126 -381
- dao_ai/state.py +139 -21
- dao_ai/tools/__init__.py +11 -5
- dao_ai/tools/core.py +57 -4
- dao_ai/tools/email.py +280 -0
- dao_ai/tools/genie.py +108 -35
- dao_ai/tools/mcp.py +4 -3
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +4 -12
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +1 -1
- dao_ai/tools/unity_catalog.py +8 -6
- dao_ai/tools/vector_search.py +16 -9
- dao_ai/utils.py +72 -8
- dao_ai-0.1.0.dist-info/METADATA +1878 -0
- dao_ai-0.1.0.dist-info/RECORD +62 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.35.dist-info/METADATA +0 -1169
- dao_ai-0.0.35.dist-info/RECORD +0 -41
- {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py
CHANGED
````diff
@@ -28,8 +28,12 @@ from databricks.sdk.service.database import DatabaseInstance
 from databricks.vector_search.client import VectorSearchClient
 from databricks.vector_search.index import VectorSearchIndex
 from databricks_langchain import (
+    ChatDatabricks,
+    DatabricksEmbeddings,
     DatabricksFunctionClient,
 )
+from langchain.agents.structured_output import ProviderStrategy, ToolStrategy
+from langchain_core.embeddings import Embeddings
 from langchain_core.language_models import LanguageModelLike
 from langchain_core.messages import BaseMessage, messages_from_dict
 from langchain_core.runnables.base import RunnableLike
````
````diff
@@ -42,6 +46,7 @@ from mlflow.genai.datasets import EvaluationDataset, create_dataset, get_dataset
 from mlflow.genai.prompts import PromptVersion, load_prompt
 from mlflow.models import ModelConfig
 from mlflow.models.resources import (
+    DatabricksApp,
     DatabricksFunction,
     DatabricksGenieSpace,
     DatabricksLakebase,
````
````diff
@@ -82,27 +87,6 @@ class HasFullName(ABC):
     def full_name(self) -> str: ...


-class IsDatabricksResource(ABC):
-    on_behalf_of_user: Optional[bool] = False
-
-    @abstractmethod
-    def as_resources(self) -> Sequence[DatabricksResource]: ...
-
-    @property
-    @abstractmethod
-    def api_scopes(self) -> Sequence[str]: ...
-
-    @property
-    def workspace_client(self) -> WorkspaceClient:
-        credentials_strategy: CredentialsStrategy = None
-        if self.on_behalf_of_user:
-            credentials_strategy = ModelServingUserCredentials()
-        logger.debug(
-            f"Creating WorkspaceClient with credentials strategy: {credentials_strategy}"
-        )
-        return WorkspaceClient(credentials_strategy=credentials_strategy)
-
-
 class EnvironmentVariableModel(BaseModel, HasValue):
     model_config = ConfigDict(
         frozen=True,
````
````diff
@@ -210,6 +194,138 @@ class ServicePrincipalModel(BaseModel):
     client_secret: AnyVariable


+class IsDatabricksResource(ABC, BaseModel):
+    """
+    Base class for Databricks resources with authentication support.
+
+    Authentication Options:
+    ----------------------
+    1. **On-Behalf-Of User (OBO)**: Set on_behalf_of_user=True to use the
+       calling user's identity via ModelServingUserCredentials.
+
+    2. **Service Principal (OAuth M2M)**: Provide service_principal or
+       (client_id + client_secret + workspace_host) for service principal auth.
+
+    3. **Personal Access Token (PAT)**: Provide pat (and optionally workspace_host)
+       to authenticate with a personal access token.
+
+    4. **Ambient Authentication**: If no credentials provided, uses SDK defaults
+       (environment variables, notebook context, etc.)
+
+    Authentication Priority:
+    1. OBO (on_behalf_of_user=True)
+    2. Service Principal (client_id + client_secret + workspace_host)
+    3. PAT (pat + workspace_host)
+    4. Ambient/default authentication
+    """
+
+    model_config = ConfigDict(use_enum_values=True)
+
+    on_behalf_of_user: Optional[bool] = False
+    service_principal: Optional[ServicePrincipalModel] = None
+    client_id: Optional[AnyVariable] = None
+    client_secret: Optional[AnyVariable] = None
+    workspace_host: Optional[AnyVariable] = None
+    pat: Optional[AnyVariable] = None
+
+    @abstractmethod
+    def as_resources(self) -> Sequence[DatabricksResource]: ...
+
+    @property
+    @abstractmethod
+    def api_scopes(self) -> Sequence[str]: ...
+
+    @model_validator(mode="after")
+    def _expand_service_principal(self) -> Self:
+        """Expand service_principal into client_id and client_secret if provided."""
+        if self.service_principal is not None:
+            if self.client_id is None:
+                self.client_id = self.service_principal.client_id
+            if self.client_secret is None:
+                self.client_secret = self.service_principal.client_secret
+        return self
+
+    @model_validator(mode="after")
+    def _validate_auth_not_mixed(self) -> Self:
+        """Validate that OAuth and PAT authentication are not both provided."""
+        has_oauth: bool = self.client_id is not None and self.client_secret is not None
+        has_pat: bool = self.pat is not None
+
+        if has_oauth and has_pat:
+            raise ValueError(
+                "Cannot use both OAuth and user authentication methods. "
+                "Please provide either OAuth credentials or user credentials."
+            )
+        return self
+
+    @property
+    def workspace_client(self) -> WorkspaceClient:
+        """
+        Get a WorkspaceClient configured with the appropriate authentication.
+
+        Authentication priority:
+        1. If on_behalf_of_user is True, uses ModelServingUserCredentials (OBO)
+        2. If service principal credentials are configured (client_id, client_secret,
+           workspace_host), uses OAuth M2M
+        3. If PAT is configured, uses token authentication
+        4. Otherwise, uses default/ambient authentication
+        """
+        from dao_ai.utils import normalize_host
+
+        # Check for OBO first (highest priority)
+        if self.on_behalf_of_user:
+            credentials_strategy: CredentialsStrategy = ModelServingUserCredentials()
+            logger.debug(
+                f"Creating WorkspaceClient for {self.__class__.__name__} "
+                f"with OBO credentials strategy"
+            )
+            return WorkspaceClient(credentials_strategy=credentials_strategy)
+
+        # Check for service principal credentials
+        client_id_value: str | None = (
+            value_of(self.client_id) if self.client_id else None
+        )
+        client_secret_value: str | None = (
+            value_of(self.client_secret) if self.client_secret else None
+        )
+        workspace_host_value: str | None = (
+            normalize_host(value_of(self.workspace_host))
+            if self.workspace_host
+            else None
+        )
+
+        if client_id_value and client_secret_value and workspace_host_value:
+            logger.debug(
+                f"Creating WorkspaceClient for {self.__class__.__name__} with service principal: "
+                f"client_id={client_id_value}, host={workspace_host_value}"
+            )
+            return WorkspaceClient(
+                host=workspace_host_value,
+                client_id=client_id_value,
+                client_secret=client_secret_value,
+                auth_type="oauth-m2m",
+            )
+
+        # Check for PAT authentication
+        pat_value: str | None = value_of(self.pat) if self.pat else None
+        if pat_value:
+            logger.debug(
+                f"Creating WorkspaceClient for {self.__class__.__name__} with PAT"
+            )
+            return WorkspaceClient(
+                host=workspace_host_value,
+                token=pat_value,
+                auth_type="pat",
+            )
+
+        # Default: use ambient authentication
+        logger.debug(
+            f"Creating WorkspaceClient for {self.__class__.__name__} "
+            "with default/ambient authentication"
+        )
+        return WorkspaceClient()
+
+
 class Privilege(str, Enum):
     ALL_PRIVILEGES = "ALL_PRIVILEGES"
     USE_CATALOG = "USE_CATALOG"
````
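The resolution order matters in practice: OBO short-circuits everything, and PAT only applies when no service-principal triple is present. A minimal sketch of that order, using only databricks-sdk calls that appear in the diff (the import path for `ModelServingUserCredentials` is an assumption):

```python
from databricks.sdk import WorkspaceClient
from databricks.sdk.credentials_provider import ModelServingUserCredentials  # assumed path


def resolve_client(obo: bool, client_id: str | None, client_secret: str | None,
                   host: str | None, pat: str | None) -> WorkspaceClient:
    if obo:  # 1. on-behalf-of-user wins unconditionally
        return WorkspaceClient(credentials_strategy=ModelServingUserCredentials())
    if client_id and client_secret and host:  # 2. OAuth M2M (service principal)
        return WorkspaceClient(host=host, client_id=client_id,
                               client_secret=client_secret, auth_type="oauth-m2m")
    if pat:  # 3. personal access token
        return WorkspaceClient(host=host, token=pat, auth_type="pat")
    return WorkspaceClient()  # 4. ambient SDK defaults (env vars, notebook context)
```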
````diff
@@ -270,7 +386,26 @@ class SchemaModel(BaseModel, HasFullName):
         provider.create_schema(self)


-class TableModel(BaseModel, HasFullName, IsDatabricksResource):
+class DatabricksAppModel(IsDatabricksResource, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    url: str
+
+    @property
+    def full_name(self) -> str:
+        return self.name
+
+    @property
+    def api_scopes(self) -> Sequence[str]:
+        return ["apps.apps"]
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksApp(app_name=self.name, on_behalf_of_user=self.on_behalf_of_user)
+        ]
+
+
+class TableModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: Optional[str] = None
````
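A hypothetical usage of the new app resource type; the URL is a placeholder:

```python
from dao_ai.config import DatabricksAppModel

app = DatabricksAppModel(
    name="dao-ai-app",
    url="https://dao-ai-app.example.databricksapps.com",  # placeholder URL
)
print(app.api_scopes)      # ['apps.apps']
print(app.as_resources())  # [DatabricksApp(app_name='dao-ai-app', on_behalf_of_user=False)]
```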
````diff
@@ -339,12 +474,16 @@ class TableModel(BaseModel, HasFullName, IsDatabricksResource):
         return resources


-class LLMModel(BaseModel, IsDatabricksResource):
+class LLMModel(IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     temperature: Optional[float] = 0.1
     max_tokens: Optional[int] = 8192
     fallbacks: Optional[list[Union[str, "LLMModel"]]] = Field(default_factory=list)
+    use_responses_api: Optional[bool] = Field(
+        default=False,
+        description="Use Responses API for ResponsesAgent endpoints",
+    )

     @property
     def api_scopes(self) -> Sequence[str]:
````
````diff
@@ -364,19 +503,12 @@ class LLMModel(BaseModel, IsDatabricksResource):
         ]

     def as_chat_model(self) -> LanguageModelLike:
-
-
-
-
-
-        from dao_ai.chat_models import ChatDatabricksFiltered
-
-        chat_client: LanguageModelLike = ChatDatabricksFiltered(
-            model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
+        chat_client: LanguageModelLike = ChatDatabricks(
+            model=self.name,
+            temperature=self.temperature,
+            max_tokens=self.max_tokens,
+            use_responses_api=self.use_responses_api,
         )
-        # chat_client: LanguageModelLike = ChatDatabricks(
-        #     model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
-        # )

         fallbacks: Sequence[LanguageModelLike] = []
         for fallback in self.fallbacks:
````
````diff
@@ -408,6 +540,9 @@ class LLMModel(BaseModel, IsDatabricksResource):

         return chat_client

+    def as_embeddings_model(self) -> Embeddings:
+        return DatabricksEmbeddings(endpoint=self.name)
+

 class VectorSearchEndpointType(str, Enum):
     STANDARD = "STANDARD"
````
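Both factory methods are now thin wrappers over databricks-langchain. A hedged usage sketch (endpoint names are placeholders):

```python
from dao_ai.config import LLMModel

chat_llm = LLMModel(name="databricks-claude-sonnet-4", use_responses_api=True)
chat = chat_llm.as_chat_model()  # ChatDatabricks(..., use_responses_api=True)

embedding_llm = LLMModel(name="databricks-gte-large-en")
embeddings = embedding_llm.as_embeddings_model()  # DatabricksEmbeddings(endpoint="databricks-gte-large-en")
```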
````diff
@@ -427,7 +562,7 @@ class VectorSearchEndpoint(BaseModel):
         return str(value)


-class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
+class IndexModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: str
````
````diff
@@ -452,7 +587,7 @@ class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
         ]


-class GenieRoomModel(BaseModel, IsDatabricksResource):
+class GenieRoomModel(IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     description: Optional[str] = None
````
````diff
@@ -478,7 +613,7 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
         return self


-class VolumeModel(BaseModel, HasFullName, IsDatabricksResource):
+class VolumeModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: str
````
````diff
@@ -538,7 +673,7 @@ class VolumePathModel(BaseModel, HasFullName):
         provider.create_path(self)


-class VectorStoreModel(BaseModel, IsDatabricksResource):
+class VectorStoreModel(IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     embedding_model: Optional[LLMModel] = None
     index: Optional[IndexModel] = None
````
````diff
@@ -637,7 +772,7 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
         provider.create_vector_store(self)


-class FunctionModel(BaseModel, HasFullName, IsDatabricksResource):
+class FunctionModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict()
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: Optional[str] = None
````
````diff
@@ -692,7 +827,7 @@ class FunctionModel(BaseModel, HasFullName, IsDatabricksResource):
         return ["sql.statement-execution"]


-class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
+class ConnectionModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict()
     name: str

````
````diff
@@ -719,7 +854,7 @@ class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
         ]


-class WarehouseModel(BaseModel, IsDatabricksResource):
+class WarehouseModel(IsDatabricksResource):
     model_config = ConfigDict()
     name: str
     description: Optional[str] = None
````
````diff
@@ -746,30 +881,28 @@ class WarehouseModel(BaseModel, IsDatabricksResource):
         return self


-class DatabaseModel(BaseModel, IsDatabricksResource):
+class DatabaseType(str, Enum):
+    POSTGRES = "postgres"
+    LAKEBASE = "lakebase"
+
+
+class DatabaseModel(IsDatabricksResource):
     """
     Configuration for a Databricks Lakebase (PostgreSQL) database instance.

-    Authentication
-
-    This model uses TWO separate authentication contexts:
-
-    1. **Workspace API Authentication** (inherited from IsDatabricksResource):
-       - Uses ambient/default authentication (environment variables, notebook context, app service principal)
-       - Used for: discovering database instance, getting host DNS, checking instance status
-       - Controlled by: DATABRICKS_HOST, DATABRICKS_TOKEN env vars, or SDK default config
-
-
-
-    - OAuth M2M: Set client_id, client_secret, workspace_host to connect as a service principal
-    - User Auth: Set user (and optionally password) to connect as a user identity
+    Authentication is inherited from IsDatabricksResource. Additionally supports:
+    - user/password: For user-based database authentication
+
+    Database Type:
+    - lakebase: Databricks-managed Lakebase instance (authentication optional, supports ambient auth)
+    - postgres: Standard PostgreSQL database (authentication required)

     Example Service Principal Configuration:
     ```yaml
     databases:
       my_lakebase:
         name: my-database
+        type: lakebase
         service_principal:
           client_id:
             env: SERVICE_PRINCIPAL_CLIENT_ID
````
````diff
@@ -780,31 +913,28 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
             env: DATABRICKS_HOST
     ```

-    Example
+    Example User Configuration:
     ```yaml
     databases:
       my_lakebase:
         name: my-database
-
-
-        client_secret:
-          scope: my-scope
-          secret: sp-client-secret
-        workspace_host:
-          env: DATABRICKS_HOST
+        type: lakebase
+        user: my-user@databricks.com
     ```

-    Example
+    Example Ambient Authentication (Lakebase only):
     ```yaml
     databases:
       my_lakebase:
         name: my-database
-
+        type: lakebase
+        on_behalf_of_user: true
     ```
     """

     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
+    type: Optional[DatabaseType] = DatabaseType.LAKEBASE
     instance_name: Optional[str] = None
     description: Optional[str] = None
     host: Optional[AnyVariable] = None
````
````diff
@@ -815,16 +945,18 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
     timeout_seconds: Optional[int] = 10
     capacity: Optional[Literal["CU_1", "CU_2"]] = "CU_2"
     node_count: Optional[int] = None
+    # Database-specific auth (user identity for DB connection)
     user: Optional[AnyVariable] = None
     password: Optional[AnyVariable] = None
-
-
-
-
+
+    @field_serializer("type")
+    def serialize_type(self, value: DatabaseType | None) -> str | None:
+        """Serialize the database type enum to its string value."""
+        return value.value if value is not None else None

     @property
     def api_scopes(self) -> Sequence[str]:
-        return []
+        return ["database.database-instances"]

     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
````
````diff
@@ -838,29 +970,33 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
     def update_instance_name(self) -> Self:
         if self.instance_name is None:
             self.instance_name = self.name
-
-        return self
-
-    @model_validator(mode="after")
-    def expand_service_principal(self) -> Self:
-        """Expand service_principal into client_id and client_secret if provided."""
-        if self.service_principal is not None:
-            if self.client_id is None:
-                self.client_id = self.service_principal.client_id
-            if self.client_secret is None:
-                self.client_secret = self.service_principal.client_secret
         return self

     @model_validator(mode="after")
     def update_user(self) -> Self:
-        if
+        # Skip if using OBO (passive auth), explicit credentials, or explicit user
+        if self.on_behalf_of_user or self.client_id or self.user or self.pat:
             return self

-
-
-
-
-
+        # For postgres, we need explicit user credentials
+        # For lakebase with no auth, ambient auth is allowed
+        if self.type == DatabaseType.POSTGRES:
+            # Try to determine current user for local development
+            try:
+                self.user = self.workspace_client.current_user.me().user_name
+            except Exception as e:
+                logger.warning(
+                    f"Could not determine current user for PostgreSQL database: {e}. "
+                    f"Please provide explicit user credentials."
+                )
+        else:
+            # For lakebase, try to determine current user but don't fail if we can't
+            try:
+                self.user = self.workspace_client.current_user.me().user_name
+            except Exception:
+                # If we can't determine user and no explicit auth, that's okay
+                # for lakebase with ambient auth - credentials will be injected at runtime
+                pass

         return self

````
````diff
@@ -869,12 +1005,28 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
         if self.host is not None:
             return self

-
-
-
+        # Try to fetch host from existing instance
+        # This may fail for OBO/ambient auth during model logging (before deployment)
+        try:
+            existing_instance: DatabaseInstance = (
+                self.workspace_client.database.get_database_instance(
+                    name=self.instance_name
+                )
             )
-
-
+            self.host = existing_instance.read_write_dns
+        except Exception as e:
+            # For lakebase with OBO/ambient auth, we can't fetch at config time
+            # The host will need to be provided explicitly or fetched at runtime
+            if self.type == DatabaseType.LAKEBASE and self.on_behalf_of_user:
+                logger.debug(
+                    f"Could not fetch host for database {self.instance_name} "
+                    f"(Lakebase with OBO mode - will be resolved at runtime): {e}"
+                )
+            else:
+                raise ValueError(
+                    f"Could not fetch host for database {self.instance_name}. "
+                    f"Please provide the 'host' explicitly or ensure the instance exists: {e}"
+                )
         return self

     @model_validator(mode="after")
````
````diff
@@ -885,21 +1037,33 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
             self.client_secret,
         ]
         has_oauth: bool = all(field is not None for field in oauth_fields)
+        has_user_auth: bool = self.user is not None
+        has_obo: bool = self.on_behalf_of_user is True
+        has_pat: bool = self.pat is not None

-
-
+        # Count how many auth methods are configured
+        auth_methods_count: int = sum([has_oauth, has_user_auth, has_obo, has_pat])

-        if
+        if auth_methods_count > 1:
             raise ValueError(
-                "Cannot
-                "Please provide
+                "Cannot mix authentication methods. "
+                "Please provide exactly one of: "
+                "on_behalf_of_user=true (for passive auth in model serving), "
+                "OAuth credentials (service_principal or client_id + client_secret + workspace_host), "
+                "PAT (personal access token), "
+                "or user credentials (user)."
             )

-
+        # For postgres type, at least one auth method must be configured
+        # For lakebase type, auth is optional (supports ambient authentication)
+        if self.type == DatabaseType.POSTGRES and auth_methods_count == 0:
             raise ValueError(
-                "
-                "
-                "
+                "PostgreSQL databases require explicit authentication. "
+                "Please provide one of: "
+                "OAuth credentials (workspace_host, client_id, client_secret), "
+                "service_principal with workspace_host, "
+                "PAT (personal access token), "
+                "or user credentials (user)."
             )

         return self
````
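Under the new rules, exactly one authentication method may be configured, and postgres requires one while lakebase may fall through to ambient auth. A sketch of configs that pass and fail validation (host is supplied explicitly so the instance-lookup validator does not fire; all values are placeholders, and the exact error message depends on which validator runs first):

```python
from dao_ai.config import DatabaseModel

# OK: lakebase with no credentials at all (ambient auth)
DatabaseModel(name="db", type="lakebase", host="db.example.cloud.databricks.com")

# OK: postgres with explicit user auth
DatabaseModel(name="db", type="postgres", host="db.example.com", user="me@example.com")

# Fails: OAuth credentials and a PAT at the same time
try:
    DatabaseModel(name="db", host="db.example.com", pat="example-token",
                  client_id="sp-id", client_secret="sp-secret",
                  workspace_host="https://example.cloud.databricks.com")
except ValueError as e:
    print(e)  # mixing authentication methods is rejected
```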
````diff
@@ -913,8 +1077,9 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
         If username is configured, it will be included; otherwise it will be omitted
         to allow Lakebase to authenticate using the token's identity.
         """
-
-
+        import uuid as _uuid
+
+        from databricks.sdk.service.database import DatabaseCredential

         username: str | None = None

````
````diff
@@ -922,19 +1087,36 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
             username = value_of(self.client_id)
         elif self.user:
             username = value_of(self.user)
+        # For OBO mode, no username is needed - the token identity is used
+
+        # Resolve host - may need to fetch at runtime for OBO mode
+        host_value: Any = self.host
+        if host_value is None and self.on_behalf_of_user:
+            # Fetch host at runtime for OBO mode
+            existing_instance: DatabaseInstance = (
+                self.workspace_client.database.get_database_instance(
+                    name=self.instance_name
+                )
+            )
+            host_value = existing_instance.read_write_dns
+
+        if host_value is None:
+            raise ValueError(
+                f"Database host not configured for {self.instance_name}. "
+                "Please provide 'host' explicitly."
+            )

-        host: str = value_of(self.host)
+        host: str = value_of(host_value)
         port: int = value_of(self.port)
         database: str = value_of(self.database)

-
-
-
-
-
+        # Use the resource's own workspace_client to generate the database credential
+        w: WorkspaceClient = self.workspace_client
+        cred: DatabaseCredential = w.database.generate_database_credential(
+            request_id=str(_uuid.uuid4()),
+            instance_names=[self.instance_name],
         )
-
-        token: str = provider.lakebase_password_provider(self.instance_name)
+        token: str = cred.token

         # Build connection parameters dictionary
         params: dict[str, Any] = {
````
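The provider indirection is gone; the model now mints the short-lived Postgres credential itself. The underlying SDK call, shown standalone (requires ambient Databricks auth; the instance name is a placeholder):

```python
import uuid

from databricks.sdk import WorkspaceClient

w = WorkspaceClient()
cred = w.database.generate_database_credential(
    request_id=str(uuid.uuid4()),
    instance_names=["my-database"],  # placeholder instance name
)
password = cred.token  # used as the password in the connection params
```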
````diff
@@ -972,11 +1154,86 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
     def create(self, w: WorkspaceClient | None = None) -> None:
         from dao_ai.providers.databricks import DatabricksProvider

+        # Use provided workspace client or fall back to resource's own workspace_client
+        if w is None:
+            w = self.workspace_client
         provider: DatabricksProvider = DatabricksProvider(w=w)
         provider.create_lakebase(self)
         provider.create_lakebase_instance_role(self)


+class GenieLRUCacheParametersModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    capacity: int = 1000
+    time_to_live_seconds: int | None = (
+        60 * 60 * 24
+    )  # 1 day default, None or negative = never expires
+    warehouse: WarehouseModel
+
+
+class GenieSemanticCacheParametersModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    time_to_live_seconds: int | None = (
+        60 * 60 * 24
+    )  # 1 day default, None or negative = never expires
+    similarity_threshold: float = 0.85  # Minimum similarity for question matching (L2 distance converted to 0-1 scale)
+    context_similarity_threshold: float = 0.80  # Minimum similarity for context matching (L2 distance converted to 0-1 scale)
+    question_weight: Optional[float] = (
+        0.6  # Weight for question similarity in combined score (0-1). If not provided, computed as 1 - context_weight
+    )
+    context_weight: Optional[float] = (
+        None  # Weight for context similarity in combined score (0-1). If not provided, computed as 1 - question_weight
+    )
+    embedding_model: str | LLMModel = "databricks-gte-large-en"
+    embedding_dims: int | None = None  # Auto-detected if None
+    database: DatabaseModel
+    warehouse: WarehouseModel
+    table_name: str = "genie_semantic_cache"
+    context_window_size: int = 3  # Number of previous turns to include for context
+    max_context_tokens: int = (
+        2000  # Maximum context length to prevent extremely long embeddings
+    )
+
+    @model_validator(mode="after")
+    def compute_and_validate_weights(self) -> Self:
+        """
+        Compute missing weight and validate that question_weight + context_weight = 1.0.
+
+        Either question_weight or context_weight (or both) can be provided.
+        The missing one will be computed as 1.0 - provided_weight.
+        If both are provided, they must sum to 1.0.
+        """
+        if self.question_weight is None and self.context_weight is None:
+            # Both missing - use defaults
+            self.question_weight = 0.6
+            self.context_weight = 0.4
+        elif self.question_weight is None:
+            # Compute question_weight from context_weight
+            if not (0.0 <= self.context_weight <= 1.0):
+                raise ValueError(
+                    f"context_weight must be between 0.0 and 1.0, got {self.context_weight}"
+                )
+            self.question_weight = 1.0 - self.context_weight
+        elif self.context_weight is None:
+            # Compute context_weight from question_weight
+            if not (0.0 <= self.question_weight <= 1.0):
+                raise ValueError(
+                    f"question_weight must be between 0.0 and 1.0, got {self.question_weight}"
+                )
+            self.context_weight = 1.0 - self.question_weight
+        else:
+            # Both provided - validate they sum to 1.0
+            total_weight = self.question_weight + self.context_weight
+            if not abs(total_weight - 1.0) < 0.0001:  # Allow small floating point error
+                raise ValueError(
+                    f"question_weight ({self.question_weight}) + context_weight ({self.context_weight}) "
+                    f"must equal 1.0 (got {total_weight}). These weights determine the relative importance "
+                    f"of question vs context similarity in the combined score."
+                )
+
+        return self
+
+
 class SearchParametersModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     num_results: Optional[int] = 10
````
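The two weights always normalize to a sum of 1.0, so the combined score is a simple convex blend. A worked example of the assumed scoring formula (the actual blending lives in genie/cache/semantic.py, which this diff does not show):

```python
def combined_score(question_sim: float, context_sim: float,
                   question_weight: float = 0.6) -> float:
    context_weight = 1.0 - question_weight  # mirrors compute_and_validate_weights
    return question_weight * question_sim + context_weight * context_sim

print(combined_score(0.90, 0.70))  # 0.82 with the default 0.6/0.4 split
```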
````diff
@@ -1067,28 +1324,47 @@ class FunctionType(str, Enum):
     MCP = "mcp"


-class
-    """
+class HumanInTheLoopModel(BaseModel):
+    """
+    Configuration for Human-in-the-Loop tool approval.

-
-
-    RESPONSE = "response"
-    DECLINE = "decline"
+    This model configures when and how tools require human approval before execution.
+    It maps to LangChain's HumanInTheLoopMiddleware.

+    LangChain supports three decision types:
+    - "approve": Execute tool with original arguments
+    - "edit": Modify arguments before execution
+    - "reject": Skip execution with optional feedback message
+    """

-class HumanInTheLoopModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-
-
-
-
-
-
-
-
+
+    review_prompt: Optional[str] = Field(
+        default=None,
+        description="Message shown to the reviewer when approval is requested",
+    )
+
+    allowed_decisions: list[Literal["approve", "edit", "reject"]] = Field(
+        default_factory=lambda: ["approve", "edit", "reject"],
+        description="List of allowed decision types for this tool",
     )
-
-
+
+    @model_validator(mode="after")
+    def validate_and_normalize_decisions(self) -> Self:
+        """Validate and normalize allowed decisions."""
+        if not self.allowed_decisions:
+            raise ValueError("At least one decision type must be allowed")
+
+        # Remove duplicates while preserving order
+        seen = set()
+        unique_decisions = []
+        for decision in self.allowed_decisions:
+            if decision not in seen:
+                seen.add(decision)
+                unique_decisions.append(decision)
+        self.allowed_decisions = unique_decisions
+
+        return self


 class BaseFunctionModel(ABC, BaseModel):
````
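A small sketch of the validator's behavior; note the duplicate decision being dropped while order is preserved:

```python
from dao_ai.config import HumanInTheLoopModel

hitl = HumanInTheLoopModel(
    review_prompt="Approve sending this email?",
    allowed_decisions=["approve", "reject", "approve"],
)
print(hitl.allowed_decisions)  # ['approve', 'reject']

try:
    HumanInTheLoopModel(allowed_decisions=[])
except ValueError as e:
    print(e)  # at least one decision type must be allowed
```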
````diff
@@ -1151,7 +1427,16 @@ class TransportType(str, Enum):
     STDIO = "stdio"


-class McpFunctionModel(BaseFunctionModel, HasFullName):
+class McpFunctionModel(BaseFunctionModel, IsDatabricksResource, HasFullName):
+    """
+    MCP Function Model with authentication inherited from IsDatabricksResource.
+
+    Authentication for MCP connections uses the same options as other resources:
+    - Service Principal (client_id + client_secret + workspace_host)
+    - PAT (pat + workspace_host)
+    - OBO (on_behalf_of_user)
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     type: Literal[FunctionType.MCP] = FunctionType.MCP
     transport: TransportType = TransportType.STREAMABLE_HTTP
````
````diff
@@ -1159,26 +1444,27 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
     url: Optional[AnyVariable] = None
     headers: dict[str, AnyVariable] = Field(default_factory=dict)
     args: list[str] = Field(default_factory=list)
-
-    service_principal: Optional[ServicePrincipalModel] = None
-    client_id: Optional[AnyVariable] = None
-    client_secret: Optional[AnyVariable] = None
-    workspace_host: Optional[AnyVariable] = None
+    # MCP-specific fields
     connection: Optional[ConnectionModel] = None
     functions: Optional[SchemaModel] = None
     genie_room: Optional[GenieRoomModel] = None
     sql: Optional[bool] = None
     vector_search: Optional[VectorStoreModel] = None

-    @
-    def
-    """
-
-
-
-
-
-
+    @property
+    def api_scopes(self) -> Sequence[str]:
+        """API scopes for MCP connections."""
+        return [
+            "serving.serving-endpoints",
+            "mcp.genie",
+            "mcp.functions",
+            "mcp.vectorsearch",
+            "mcp.external",
+        ]
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        """MCP functions don't declare static resources."""
+        return []

     @property
     def full_name(self) -> str:
````
````diff
@@ -1343,27 +1629,6 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
             self.headers[key] = value_of(value)
         return self

-    @model_validator(mode="after")
-    def validate_auth_methods(self) -> "McpFunctionModel":
-        oauth_fields: Sequence[Any] = [
-            self.client_id,
-            self.client_secret,
-        ]
-        has_oauth: bool = all(field is not None for field in oauth_fields)
-
-        pat_fields: Sequence[Any] = [self.pat]
-        has_user_auth: bool = all(field is not None for field in pat_fields)
-
-        if has_oauth and has_user_auth:
-            raise ValueError(
-                "Cannot use both OAuth and user authentication methods. "
-                "Please provide either OAuth credentials or user credentials."
-            )
-
-        # Note: workspace_host is optional - it will be derived from workspace client if not provided
-
-        return self
-
     def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
         from dao_ai.tools import create_mcp_tools

````
````diff
@@ -1405,17 +1670,97 @@ class ToolModel(BaseModel):
     function: AnyTool


+class PromptModel(BaseModel, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
+    name: str
+    description: Optional[str] = None
+    default_template: Optional[str] = None
+    alias: Optional[str] = None
+    version: Optional[int] = None
+    tags: Optional[dict[str, Any]] = Field(default_factory=dict)
+
+    @property
+    def template(self) -> str:
+        from dao_ai.providers.databricks import DatabricksProvider
+
+        provider: DatabricksProvider = DatabricksProvider()
+        prompt_version = provider.get_prompt(self)
+        return prompt_version.to_single_brace_format()
+
+    @property
+    def full_name(self) -> str:
+        prompt_name: str = self.name
+        if self.schema_model:
+            prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
+        return prompt_name
+
+    @property
+    def uri(self) -> str:
+        prompt_uri: str = f"prompts:/{self.full_name}"
+
+        if self.alias:
+            prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
+        elif self.version:
+            prompt_uri = f"prompts:/{self.full_name}/{self.version}"
+        else:
+            prompt_uri = f"prompts:/{self.full_name}@latest"
+
+        return prompt_uri
+
+    def as_prompt(self) -> PromptVersion:
+        prompt_version: PromptVersion = load_prompt(self.uri)
+        return prompt_version
+
+    @model_validator(mode="after")
+    def validate_mutually_exclusive(self) -> Self:
+        if self.alias and self.version:
+            raise ValueError("Cannot specify both alias and version")
+        return self
+
+
 class GuardrailModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
-    model: LLMModel
-    prompt: str
+    model: str | LLMModel
+    prompt: str | PromptModel
     num_retries: Optional[int] = 3

+    @model_validator(mode="after")
+    def validate_llm_model(self) -> Self:
+        if isinstance(self.model, str):
+            self.model = LLMModel(name=self.model)
+        return self
+
+
+class MiddlewareModel(BaseModel):
+    """Configuration for middleware that can be applied to agents.
+
+    Middleware is defined at the AppConfig level and can be referenced by name
+    in agent configurations using YAML anchors for reusability.
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str = Field(
+        description="Fully qualified name of the middleware factory function"
+    )
+    args: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Arguments to pass to the middleware factory function",
+    )
+
+    @model_validator(mode="after")
+    def resolve_args(self) -> Self:
+        """Resolve any variable references in args."""
+        for key, value in self.args.items():
+            self.args[key] = value_of(value)
+        return self
+

 class StorageType(str, Enum):
     POSTGRES = "postgres"
     MEMORY = "memory"
+    LAKEBASE = "lakebase"


 class CheckpointerModel(BaseModel):
````
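PromptModel resolves to MLflow prompt-registry URIs, with alias and version mutually exclusive. Illustrative values:

```python
from dao_ai.config import PromptModel

print(PromptModel(name="tone").uri)                # prompts:/tone@latest
print(PromptModel(name="tone", alias="prod").uri)  # prompts:/tone@prod
print(PromptModel(name="tone", version=3).uri)     # prompts:/tone/3

# PromptModel(name="tone", alias="prod", version=3)  # would raise: alias and version are mutually exclusive
```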
````diff
@@ -1425,8 +1770,11 @@ class CheckpointerModel(BaseModel):
     database: Optional[DatabaseModel] = None

     @model_validator(mode="after")
-    def
-        if
+    def validate_storage_requires_database(self) -> Self:
+        if (
+            self.type in [StorageType.POSTGRES, StorageType.LAKEBASE]
+            and not self.database
+        ):
             raise ValueError("Database must be provided when storage type is POSTGRES")
         return self

````
````diff
@@ -1471,56 +1819,158 @@ class MemoryModel(BaseModel):
 FunctionHook: TypeAlias = PythonFunctionModel | FactoryFunctionModel | str


-class
+class ResponseFormatModel(BaseModel):
+    """
+    Configuration for structured response formats.
+
+    The response_schema field accepts either a type or a string:
+    - Type (Pydantic model, dataclass, etc.): Used directly for structured output
+    - String: First attempts to load as a fully qualified type name, falls back to JSON schema string
+
+    This unified approach simplifies the API while maintaining flexibility.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-
-
-
-
-
-
+    use_tool: Optional[bool] = Field(
+        default=None,
+        description=(
+            "Strategy for structured output: "
+            "None (default) = auto-detect from model capabilities, "
+            "False = force ProviderStrategy (native), "
+            "True = force ToolStrategy (function calling)"
+        ),
+    )
+    response_schema: Optional[str | type] = Field(
+        default=None,
+        description="Type or string for response format. String attempts FQN import, falls back to JSON schema.",
+    )

-
-
+    def as_strategy(self) -> ProviderStrategy | ToolStrategy:
+        """
+        Convert response_schema to appropriate LangChain strategy.

-
-
+        Returns:
+        - None if no response_schema configured
+        - Raw schema/type for auto-detection (when use_tool=None)
+        - ToolStrategy wrapping the schema (when use_tool=True)
+        - ProviderStrategy wrapping the schema (when use_tool=False)

-
-
-
-        if self.schema_model:
-            prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
-        return prompt_name
+        Raises:
+            ValueError: If response_schema is a JSON schema string that cannot be parsed
+        """

-
-
-        prompt_uri: str = f"prompts:/{self.full_name}"
+        if self.response_schema is None:
+            return None

-
-            prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
-        elif self.version:
-            prompt_uri = f"prompts:/{self.full_name}/{self.version}"
-        else:
-            prompt_uri = f"prompts:/{self.full_name}@latest"
+        schema = self.response_schema

-
+        # Handle type schemas (Pydantic, dataclass, etc.)
+        if self.is_type_schema:
+            if self.use_tool is None:
+                # Auto-detect: Pass schema directly, let LangChain decide
+                return schema
+            elif self.use_tool is True:
+                # Force ToolStrategy (function calling)
+                return ToolStrategy(schema)
+            else:  # use_tool is False
+                # Force ProviderStrategy (native structured output)
+                return ProviderStrategy(schema)

-
-
-
+        # Handle JSON schema strings
+        elif self.is_json_schema:
+            import json
+
+            try:
+                schema_dict = json.loads(schema)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Invalid JSON schema string: {e}") from e
+
+            # Apply same use_tool logic as type schemas
+            if self.use_tool is None:
+                # Auto-detect
+                return schema_dict
+            elif self.use_tool is True:
+                # Force ToolStrategy
+                return ToolStrategy(schema_dict)
+            else:  # use_tool is False
+                # Force ProviderStrategy
+                return ProviderStrategy(schema_dict)
+
+        return None

     @model_validator(mode="after")
-    def
-
-
-
+    def validate_response_schema(self) -> Self:
+        """
+        Validate and convert response_schema.
+
+        Processing logic:
+        1. If None: no response format specified
+        2. If type: use directly as structured output type
+        3. If str: try to load as FQN using type_from_fqn
+           - Success: response_schema becomes the loaded type
+           - Failure: keep as string (treated as JSON schema)
+
+        After validation, response_schema is one of:
+        - None (no schema)
+        - type (use for structured output)
+        - str (JSON schema)
+
+        Returns:
+            Self with validated response_schema
+        """
+        if self.response_schema is None:
+            return self
+
+        # If already a type, return
+        if isinstance(self.response_schema, type):
+            return self
+
+        # If it's a string, try to load as type, fallback to json_schema
+        if isinstance(self.response_schema, str):
+            from dao_ai.utils import type_from_fqn
+
+            try:
+                resolved_type = type_from_fqn(self.response_schema)
+                self.response_schema = resolved_type
+                logger.debug(
+                    f"Resolved response_schema string to type: {resolved_type}"
+                )
+                return self
+            except (ValueError, ImportError, AttributeError, TypeError) as e:
+                # Keep as string - it's a JSON schema
+                logger.debug(
+                    f"Could not resolve '{self.response_schema}' as type: {e}. "
+                    f"Treating as JSON schema string."
+                )
+                return self
+
+        # Invalid type
+        raise ValueError(
+            f"response_schema must be None, type, or str, got {type(self.response_schema)}"
+        )
+
+    @property
+    def is_type_schema(self) -> bool:
+        """Returns True if response_schema is a type (not JSON schema string)."""
+        return isinstance(self.response_schema, type)
+
+    @property
+    def is_json_schema(self) -> bool:
+        """Returns True if response_schema is a JSON schema string (not a type)."""
+        return isinstance(self.response_schema, str)


 class AgentModel(BaseModel):
+    """
+    Configuration model for an agent in the DAO AI framework.
+
+    Agents combine an LLM with tools and middleware to create systems that can
+    reason about tasks, decide which tools to use, and iteratively work towards solutions.
+
+    Middleware replaces the previous pre_agent_hook and post_agent_hook patterns,
+    providing a more flexible and composable way to customize agent behavior.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     description: Optional[str] = None
@@ -1529,9 +1979,43 @@ class AgentModel(BaseModel):
     guardrails: list[GuardrailModel] = Field(default_factory=list)
     prompt: Optional[str | PromptModel] = None
     handoff_prompt: Optional[str] = None
-
-
-
+    middleware: list[MiddlewareModel] = Field(
+        default_factory=list,
+        description="List of middleware to apply to this agent",
+    )
+    response_format: Optional[ResponseFormatModel | type | str] = None
+
+    @model_validator(mode="after")
+    def validate_response_format(self) -> Self:
+        """
+        Validate and normalize response_format.
+
+        Accepts:
+        - None (no response format)
+        - ResponseFormatModel (already validated)
+        - type (Pydantic model, dataclass, etc.) - converts to ResponseFormatModel
+        - str (FQN or json_schema) - converts to ResponseFormatModel (smart fallback)
+
+        ResponseFormatModel handles the logic of trying FQN import and falling back to JSON schema.
+        """
+        if self.response_format is None or isinstance(
+            self.response_format, ResponseFormatModel
+        ):
+            return self
+
+        # Convert type or str to ResponseFormatModel
+        # ResponseFormatModel's validator will handle the smart type loading and fallback
+        if isinstance(self.response_format, (type, str)):
+            self.response_format = ResponseFormatModel(
+                response_schema=self.response_format
+            )
+            return self
+
+        # Invalid type
+        raise ValueError(
+            f"response_format must be None, ResponseFormatModel, type, or str, "
+            f"got {type(self.response_format)}"
+        )

     def as_runnable(self) -> RunnableLike:
         from dao_ai.nodes import create_agent_node
````
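The three accepted spellings of response_format, sketched with a hypothetical Pydantic schema (the FQN string in the comment is illustrative):

```python
from pydantic import BaseModel

from dao_ai.config import ResponseFormatModel


class Verdict(BaseModel):
    label: str
    confidence: float


# 1. A type: auto-detected strategy when use_tool is None, or forced explicitly
fmt = ResponseFormatModel(response_schema=Verdict, use_tool=True)
strategy = fmt.as_strategy()  # ToolStrategy(Verdict)

# 2. A string FQN such as "my_pkg.schemas.Verdict" is imported via type_from_fqn
# 3. Anything else stringy is kept and parsed as a JSON schema in as_strategy()
json_fmt = ResponseFormatModel(
    response_schema='{"type": "object", "properties": {"label": {"type": "string"}}}'
)
```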
````diff
@@ -1550,6 +2034,10 @@ class SupervisorModel(BaseModel):
     model: LLMModel
     tools: list[ToolModel] = Field(default_factory=list)
     prompt: Optional[str] = None
+    middleware: list[MiddlewareModel] = Field(
+        default_factory=list,
+        description="List of middleware to apply to the supervisor",
+    )


 class SwarmModel(BaseModel):
````
````diff
@@ -1673,6 +2161,28 @@ class ChatPayload(BaseModel):

         return self

+    @model_validator(mode="after")
+    def ensure_thread_id(self) -> "ChatPayload":
+        """Ensure thread_id or conversation_id is present in configurable, generating UUID if needed."""
+        import uuid
+
+        if self.custom_inputs is None:
+            self.custom_inputs = {}
+
+        # Get or create configurable section
+        configurable: dict[str, Any] = self.custom_inputs.get("configurable", {})
+
+        # Check if thread_id or conversation_id exists
+        has_thread_id = configurable.get("thread_id") is not None
+        has_conversation_id = configurable.get("conversation_id") is not None
+
+        # If neither is provided, generate a UUID for conversation_id
+        if not has_thread_id and not has_conversation_id:
+            configurable["conversation_id"] = str(uuid.uuid4())
+            self.custom_inputs["configurable"] = configurable
+
+        return self
+
     def as_messages(self) -> Sequence[BaseMessage]:
         return messages_from_dict(
             [{"type": m.role, "content": m.content} for m in self.messages]
````
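Assumed payload behavior: when neither thread_id nor conversation_id is supplied, a UUID is minted so checkpointers always have a stable key (the message shape here is an assumption about ChatPayload's schema):

```python
from dao_ai.config import ChatPayload

payload = ChatPayload(messages=[{"role": "user", "content": "hi"}])  # assumed message shape
print(payload.custom_inputs["configurable"]["conversation_id"])  # e.g. a UUID4 string
```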
````diff
@@ -1688,20 +2198,38 @@ class ChatPayload(BaseModel):


 class ChatHistoryModel(BaseModel):
+    """
+    Configuration for chat history summarization.
+
+    Attributes:
+        model: The LLM to use for generating summaries.
+        max_tokens: Maximum tokens to keep after summarization (the "keep" threshold).
+            After summarization, recent messages totaling up to this many tokens are preserved.
+        max_tokens_before_summary: Token threshold that triggers summarization.
+            When conversation exceeds this, summarization runs. Mutually exclusive with
+            max_messages_before_summary. If neither is set, defaults to max_tokens * 10.
+        max_messages_before_summary: Message count threshold that triggers summarization.
+            When conversation exceeds this many messages, summarization runs.
+            Mutually exclusive with max_tokens_before_summary.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     model: LLMModel
-    max_tokens: int =
-
-
-
-
-
-
-
-
-
-
+    max_tokens: int = Field(
+        default=2048,
+        gt=0,
+        description="Maximum tokens to keep after summarization",
+    )
+    max_tokens_before_summary: Optional[int] = Field(
+        default=None,
+        gt=0,
+        description="Token threshold that triggers summarization",
+    )
+    max_messages_before_summary: Optional[int] = Field(
+        default=None,
+        gt=0,
+        description="Message count threshold that triggers summarization",
+    )


 class AppModel(BaseModel):
````
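The two trigger knobs are mutually exclusive; when neither is set, the token trigger defaults to max_tokens * 10 per the docstring. Worked numbers:

```python
max_tokens = 2048                 # "keep" budget after summarization
max_tokens_before_summary = None  # no explicit trigger configured
trigger = max_tokens_before_summary or max_tokens * 10
print(trigger)  # 20480 tokens: summarize, then keep ~2048 tokens of recent messages
```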
````diff
@@ -1728,9 +2256,6 @@ class AppModel(BaseModel):
     shutdown_hooks: Optional[FunctionHook | list[FunctionHook]] = Field(
         default_factory=list
     )
-    message_hooks: Optional[FunctionHook | list[FunctionHook]] = Field(
-        default_factory=list
-    )
     input_example: Optional[ChatPayload] = None
     chat_history: Optional[ChatHistoryModel] = None
     code_paths: list[str] = Field(default_factory=list)
````
````diff
@@ -1935,33 +2460,67 @@ class EvaluationDatasetModel(BaseModel, HasFullName):


 class PromptOptimizationModel(BaseModel):
+    """Configuration for prompt optimization using GEPA.
+
+    GEPA (Generative Evolution of Prompts and Agents) is an evolutionary
+    optimizer that uses reflective mutation to improve prompts based on
+    evaluation feedback.
+
+    Example:
+        prompt_optimization:
+          name: optimize_my_prompt
+          prompt: *my_prompt
+          agent: *my_agent
+          dataset: *my_training_dataset
+          reflection_model: databricks-meta-llama-3-3-70b-instruct
+          num_candidates: 50
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     prompt: Optional[PromptModel] = None
     agent: AgentModel
-    dataset: (
-        EvaluationDatasetModel | str
-    )  # Reference to dataset name (looked up in OptimizationsModel.training_datasets or MLflow)
+    dataset: EvaluationDatasetModel  # Training dataset with examples
     reflection_model: Optional[LLMModel | str] = None
     num_candidates: Optional[int] = 50
-    scorer_model: Optional[LLMModel | str] = None

     def optimize(self, w: WorkspaceClient | None = None) -> PromptModel:
         """
-        Optimize the prompt using
+        Optimize the prompt using GEPA.

         Args:
-            w: Optional WorkspaceClient for
+            w: Optional WorkspaceClient (not used, kept for API compatibility)

         Returns:
-            PromptModel: The optimized prompt model
+            PromptModel: The optimized prompt model
         """
-        from dao_ai.
-        from dao_ai.providers.databricks import DatabricksProvider
+        from dao_ai.optimization import OptimizationResult, optimize_prompt

-
-
-
+        # Get reflection model name
+        reflection_model_name: str | None = None
+        if self.reflection_model:
+            if isinstance(self.reflection_model, str):
+                reflection_model_name = self.reflection_model
+            else:
+                reflection_model_name = self.reflection_model.uri
+
+        # Ensure prompt is set
+        prompt = self.prompt
+        if prompt is None:
+            raise ValueError(
+                f"Prompt optimization '{self.name}' requires a prompt to be set"
+            )
+
+        result: OptimizationResult = optimize_prompt(
+            prompt=prompt,
+            agent=self.agent,
+            dataset=self.dataset,
+            reflection_model=reflection_model_name,
+            num_candidates=self.num_candidates or 50,
+            register_if_improved=True,
+        )
+
+        return result.optimized_prompt

     @model_validator(mode="after")
     def set_defaults(self) -> Self:
````
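A standalone restatement of the reflection-model resolution inside optimize(), with the dao_ai objects stubbed so the control flow runs on its own (the stub's uri value is illustrative):

```python
class FakeLLMModel:
    """Stand-in for dao_ai.config.LLMModel with a uri property (illustrative)."""
    uri = "models:/databricks-meta-llama-3-3-70b-instruct"


def resolve_reflection_model(reflection_model) -> str | None:
    # Mirrors optimize(): strings pass through, LLMModel contributes its URI
    if reflection_model is None:
        return None
    if isinstance(reflection_model, str):
        return reflection_model
    return reflection_model.uri


assert resolve_reflection_model(None) is None
assert resolve_reflection_model("databricks-meta-llama-3-3-70b-instruct")
assert resolve_reflection_model(FakeLLMModel()).startswith("models:/")
```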
````diff
@@ -1975,12 +2534,6 @@ class PromptOptimizationModel(BaseModel):
                 f"or an agent with a prompt configured"
             )

-        if self.reflection_model is None:
-            self.reflection_model = self.agent.model
-
-        if self.scorer_model is None:
-            self.scorer_model = self.agent.model
-
         return self

````
````diff
@@ -2081,6 +2634,7 @@ class ResourcesModel(BaseModel):
     warehouses: dict[str, WarehouseModel] = Field(default_factory=dict)
     databases: dict[str, DatabaseModel] = Field(default_factory=dict)
     connections: dict[str, ConnectionModel] = Field(default_factory=dict)
+    apps: dict[str, DatabricksAppModel] = Field(default_factory=dict)
````
````diff
@@ -2092,6 +2646,7 @@ class AppConfig(BaseModel):
     retrievers: dict[str, RetrieverModel] = Field(default_factory=dict)
     tools: dict[str, ToolModel] = Field(default_factory=dict)
     guardrails: dict[str, GuardrailModel] = Field(default_factory=dict)
+    middleware: dict[str, MiddlewareModel] = Field(default_factory=dict)
     memory: Optional[MemoryModel] = None
     prompts: dict[str, PromptModel] = Field(default_factory=dict)
     agents: dict[str, AgentModel] = Field(default_factory=dict)
````
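Closing the loop on the new middleware plumbing: entries are declared once on AppConfig and referenced from agents. A hypothetical entry (the factory FQN is invented for illustration; dao_ai/middleware/summarization.py exists per the file list above, but its factory names are not shown in this diff):

```python
from dao_ai.config import MiddlewareModel

mw = MiddlewareModel(
    name="dao_ai.middleware.summarization.create_summarization_middleware",  # hypothetical FQN
    args={"max_tokens": 2048},  # resolved through value_of at validation time
)
```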