dao-ai 0.0.36__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/cli.py +195 -30
- dao_ai/config.py +770 -244
- dao_ai/genie/__init__.py +1 -22
- dao_ai/genie/cache/__init__.py +1 -2
- dao_ai/genie/cache/base.py +20 -70
- dao_ai/genie/cache/core.py +75 -0
- dao_ai/genie/cache/lru.py +44 -21
- dao_ai/genie/cache/semantic.py +390 -109
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +22 -190
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +23 -5
- dao_ai/memory/databricks.py +389 -0
- dao_ai/memory/postgres.py +2 -2
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +778 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +61 -0
- dao_ai/middleware/guardrails.py +415 -0
- dao_ai/middleware/human_in_the_loop.py +228 -0
- dao_ai/middleware/message_validation.py +554 -0
- dao_ai/middleware/summarization.py +192 -0
- dao_ai/models.py +1177 -108
- dao_ai/nodes.py +118 -161
- dao_ai/optimization.py +664 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +287 -0
- dao_ai/orchestration/supervisor.py +264 -0
- dao_ai/orchestration/swarm.py +226 -0
- dao_ai/prompts.py +126 -29
- dao_ai/providers/databricks.py +126 -381
- dao_ai/state.py +139 -21
- dao_ai/tools/__init__.py +8 -5
- dao_ai/tools/core.py +57 -4
- dao_ai/tools/email.py +280 -0
- dao_ai/tools/genie.py +47 -24
- dao_ai/tools/mcp.py +4 -3
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +4 -12
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +1 -1
- dao_ai/tools/unity_catalog.py +8 -6
- dao_ai/tools/vector_search.py +16 -9
- dao_ai/utils.py +72 -8
- dao_ai-0.1.1.dist-info/METADATA +1878 -0
- dao_ai-0.1.1.dist-info/RECORD +62 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/genie/__init__.py +0 -236
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.36.dist-info/METADATA +0 -951
- dao_ai-0.0.36.dist-info/RECORD +0 -47
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.1.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.1.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.1.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py
CHANGED
@@ -28,9 +28,11 @@ from databricks.sdk.service.database import DatabaseInstance
 from databricks.vector_search.client import VectorSearchClient
 from databricks.vector_search.index import VectorSearchIndex
 from databricks_langchain import (
+    ChatDatabricks,
     DatabricksEmbeddings,
     DatabricksFunctionClient,
 )
+from langchain.agents.structured_output import ProviderStrategy, ToolStrategy
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models import LanguageModelLike
 from langchain_core.messages import BaseMessage, messages_from_dict
@@ -44,6 +46,7 @@ from mlflow.genai.datasets import EvaluationDataset, create_dataset, get_dataset
 from mlflow.genai.prompts import PromptVersion, load_prompt
 from mlflow.models import ModelConfig
 from mlflow.models.resources import (
+    DatabricksApp,
     DatabricksFunction,
     DatabricksGenieSpace,
     DatabricksLakebase,
@@ -84,27 +87,6 @@ class HasFullName(ABC):
     def full_name(self) -> str: ...
 
 
-class IsDatabricksResource(ABC):
-    on_behalf_of_user: Optional[bool] = False
-
-    @abstractmethod
-    def as_resources(self) -> Sequence[DatabricksResource]: ...
-
-    @property
-    @abstractmethod
-    def api_scopes(self) -> Sequence[str]: ...
-
-    @property
-    def workspace_client(self) -> WorkspaceClient:
-        credentials_strategy: CredentialsStrategy = None
-        if self.on_behalf_of_user:
-            credentials_strategy = ModelServingUserCredentials()
-        logger.debug(
-            f"Creating WorkspaceClient with credentials strategy: {credentials_strategy}"
-        )
-        return WorkspaceClient(credentials_strategy=credentials_strategy)
-
-
 class EnvironmentVariableModel(BaseModel, HasValue):
     model_config = ConfigDict(
         frozen=True,
@@ -212,6 +194,138 @@ class ServicePrincipalModel(BaseModel):
     client_secret: AnyVariable
 
 
+class IsDatabricksResource(ABC, BaseModel):
+    """
+    Base class for Databricks resources with authentication support.
+
+    Authentication Options:
+    ----------------------
+    1. **On-Behalf-Of User (OBO)**: Set on_behalf_of_user=True to use the
+       calling user's identity via ModelServingUserCredentials.
+
+    2. **Service Principal (OAuth M2M)**: Provide service_principal or
+       (client_id + client_secret + workspace_host) for service principal auth.
+
+    3. **Personal Access Token (PAT)**: Provide pat (and optionally workspace_host)
+       to authenticate with a personal access token.
+
+    4. **Ambient Authentication**: If no credentials provided, uses SDK defaults
+       (environment variables, notebook context, etc.)
+
+    Authentication Priority:
+    1. OBO (on_behalf_of_user=True)
+    2. Service Principal (client_id + client_secret + workspace_host)
+    3. PAT (pat + workspace_host)
+    4. Ambient/default authentication
+    """
+
+    model_config = ConfigDict(use_enum_values=True)
+
+    on_behalf_of_user: Optional[bool] = False
+    service_principal: Optional[ServicePrincipalModel] = None
+    client_id: Optional[AnyVariable] = None
+    client_secret: Optional[AnyVariable] = None
+    workspace_host: Optional[AnyVariable] = None
+    pat: Optional[AnyVariable] = None
+
+    @abstractmethod
+    def as_resources(self) -> Sequence[DatabricksResource]: ...
+
+    @property
+    @abstractmethod
+    def api_scopes(self) -> Sequence[str]: ...
+
+    @model_validator(mode="after")
+    def _expand_service_principal(self) -> Self:
+        """Expand service_principal into client_id and client_secret if provided."""
+        if self.service_principal is not None:
+            if self.client_id is None:
+                self.client_id = self.service_principal.client_id
+            if self.client_secret is None:
+                self.client_secret = self.service_principal.client_secret
+        return self
+
+    @model_validator(mode="after")
+    def _validate_auth_not_mixed(self) -> Self:
+        """Validate that OAuth and PAT authentication are not both provided."""
+        has_oauth: bool = self.client_id is not None and self.client_secret is not None
+        has_pat: bool = self.pat is not None
+
+        if has_oauth and has_pat:
+            raise ValueError(
+                "Cannot use both OAuth and user authentication methods. "
+                "Please provide either OAuth credentials or user credentials."
+            )
+        return self
+
+    @property
+    def workspace_client(self) -> WorkspaceClient:
+        """
+        Get a WorkspaceClient configured with the appropriate authentication.
+
+        Authentication priority:
+        1. If on_behalf_of_user is True, uses ModelServingUserCredentials (OBO)
+        2. If service principal credentials are configured (client_id, client_secret,
+           workspace_host), uses OAuth M2M
+        3. If PAT is configured, uses token authentication
+        4. Otherwise, uses default/ambient authentication
+        """
+        from dao_ai.utils import normalize_host
+
+        # Check for OBO first (highest priority)
+        if self.on_behalf_of_user:
+            credentials_strategy: CredentialsStrategy = ModelServingUserCredentials()
+            logger.debug(
+                f"Creating WorkspaceClient for {self.__class__.__name__} "
+                f"with OBO credentials strategy"
+            )
+            return WorkspaceClient(credentials_strategy=credentials_strategy)
+
+        # Check for service principal credentials
+        client_id_value: str | None = (
+            value_of(self.client_id) if self.client_id else None
+        )
+        client_secret_value: str | None = (
+            value_of(self.client_secret) if self.client_secret else None
+        )
+        workspace_host_value: str | None = (
+            normalize_host(value_of(self.workspace_host))
+            if self.workspace_host
+            else None
+        )
+
+        if client_id_value and client_secret_value and workspace_host_value:
+            logger.debug(
+                f"Creating WorkspaceClient for {self.__class__.__name__} with service principal: "
+                f"client_id={client_id_value}, host={workspace_host_value}"
+            )
+            return WorkspaceClient(
+                host=workspace_host_value,
+                client_id=client_id_value,
+                client_secret=client_secret_value,
+                auth_type="oauth-m2m",
+            )
+
+        # Check for PAT authentication
+        pat_value: str | None = value_of(self.pat) if self.pat else None
+        if pat_value:
+            logger.debug(
+                f"Creating WorkspaceClient for {self.__class__.__name__} with PAT"
+            )
+            return WorkspaceClient(
+                host=workspace_host_value,
+                token=pat_value,
+                auth_type="pat",
+            )
+
+        # Default: use ambient authentication
+        logger.debug(
+            f"Creating WorkspaceClient for {self.__class__.__name__} "
+            "with default/ambient authentication"
+        )
+        return WorkspaceClient()
+
+
 class Privilege(str, Enum):
     ALL_PRIVILEGES = "ALL_PRIVILEGES"
     USE_CATALOG = "USE_CATALOG"
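The new workspace_client property resolves credentials in a fixed priority order. Below is a minimal standalone sketch of that selection logic (plain Python, no databricks-sdk or dao_ai imports; the field names mirror IsDatabricksResource, and the returned strings are only labels for the strategy the real property would pick):

```python
# Standalone sketch of the auth priority in IsDatabricksResource.workspace_client:
# OBO > service principal (OAuth M2M) > PAT > ambient SDK defaults.
from dataclasses import dataclass
from typing import Optional


@dataclass
class AuthConfig:
    on_behalf_of_user: bool = False
    client_id: Optional[str] = None
    client_secret: Optional[str] = None
    workspace_host: Optional[str] = None
    pat: Optional[str] = None


def resolve_auth(cfg: AuthConfig) -> str:
    """Return a label for the auth strategy the resource would select."""
    if cfg.on_behalf_of_user:
        return "obo"  # ModelServingUserCredentials
    if cfg.client_id and cfg.client_secret and cfg.workspace_host:
        return "oauth-m2m"  # service principal
    if cfg.pat:
        return "pat"  # personal access token
    return "ambient"  # SDK default config


assert resolve_auth(AuthConfig(on_behalf_of_user=True, pat="x")) == "obo"
assert resolve_auth(
    AuthConfig(client_id="id", client_secret="s", workspace_host="https://ws")
) == "oauth-m2m"
assert resolve_auth(AuthConfig()) == "ambient"
```

Note that OBO wins even when a PAT is also set, which is why the separate _validate_auth_not_mixed validator rejects only the OAuth-plus-PAT combination.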
@@ -272,7 +386,26 @@ class SchemaModel(BaseModel, HasFullName):
         provider.create_schema(self)
 
 
-class TableModel(BaseModel, HasFullName, IsDatabricksResource):
+class DatabricksAppModel(IsDatabricksResource, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    url: str
+
+    @property
+    def full_name(self) -> str:
+        return self.name
+
+    @property
+    def api_scopes(self) -> Sequence[str]:
+        return ["apps.apps"]
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksApp(app_name=self.name, on_behalf_of_user=self.on_behalf_of_user)
+        ]
+
+
+class TableModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: Optional[str] = None
@@ -341,12 +474,16 @@ class TableModel(BaseModel, HasFullName, IsDatabricksResource):
     return resources
 
 
-class LLMModel(BaseModel, IsDatabricksResource):
+class LLMModel(IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     temperature: Optional[float] = 0.1
     max_tokens: Optional[int] = 8192
     fallbacks: Optional[list[Union[str, "LLMModel"]]] = Field(default_factory=list)
+    use_responses_api: Optional[bool] = Field(
+        default=False,
+        description="Use Responses API for ResponsesAgent endpoints",
+    )
 
     @property
     def api_scopes(self) -> Sequence[str]:
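LLMModel now builds a plain ChatDatabricks client and gains a use_responses_api flag, and its fallbacks list suggests a try-primary-then-fallbacks chain. A hedged, standalone sketch of that fallback pattern follows (pure Python callables stand in for chat models; the real class composes LangChain runnables, not these helpers):

```python
# Illustrative sketch of the fallback pattern implied by LLMModel.fallbacks:
# try the primary model, then each fallback in declared order.
from typing import Callable, Sequence


def invoke_with_fallbacks(
    primary: Callable[[str], str],
    fallbacks: Sequence[Callable[[str], str]],
    prompt: str,
) -> str:
    for model in (primary, *fallbacks):
        try:
            return model(prompt)
        except Exception:
            continue  # try the next model in the chain
    raise RuntimeError("all models failed")


def flaky(_: str) -> str:
    raise TimeoutError("endpoint overloaded")


print(invoke_with_fallbacks(flaky, [lambda p: f"ok: {p}"], "hello"))
```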
@@ -366,19 +503,12 @@ class LLMModel(BaseModel, IsDatabricksResource):
         ]
 
     def as_chat_model(self) -> LanguageModelLike:
-
-
-
-
-
-        from dao_ai.chat_models import ChatDatabricksFiltered
-
-        chat_client: LanguageModelLike = ChatDatabricksFiltered(
-            model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
+        chat_client: LanguageModelLike = ChatDatabricks(
+            model=self.name,
+            temperature=self.temperature,
+            max_tokens=self.max_tokens,
+            use_responses_api=self.use_responses_api,
         )
-        # chat_client: LanguageModelLike = ChatDatabricks(
-        #     model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
-        # )
 
         fallbacks: Sequence[LanguageModelLike] = []
         for fallback in self.fallbacks:
@@ -432,7 +562,7 @@ class VectorSearchEndpoint(BaseModel):
         return str(value)
 
 
-class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
+class IndexModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: str
@@ -457,7 +587,7 @@ class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
         ]
 
 
-class GenieRoomModel(BaseModel, IsDatabricksResource):
+class GenieRoomModel(IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     description: Optional[str] = None
@@ -483,7 +613,7 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
         return self
 
 
-class VolumeModel(…
+class VolumeModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: str
@@ -543,7 +673,7 @@ class VolumePathModel(BaseModel, HasFullName):
         provider.create_path(self)
 
 
-class VectorStoreModel(BaseModel, IsDatabricksResource):
+class VectorStoreModel(IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     embedding_model: Optional[LLMModel] = None
     index: Optional[IndexModel] = None
@@ -642,7 +772,7 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
         provider.create_vector_store(self)
 
 
-class FunctionModel(BaseModel, HasFullName, IsDatabricksResource):
+class FunctionModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict()
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: Optional[str] = None
@@ -697,7 +827,7 @@ class FunctionModel(BaseModel, HasFullName, IsDatabricksResource):
         return ["sql.statement-execution"]
 
 
-class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
+class ConnectionModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict()
     name: str
@@ -724,7 +854,7 @@ class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
         ]
 
 
-class WarehouseModel(BaseModel, IsDatabricksResource):
+class WarehouseModel(IsDatabricksResource):
     model_config = ConfigDict()
     name: str
     description: Optional[str] = None
@@ -751,30 +881,28 @@ class WarehouseModel(BaseModel, IsDatabricksResource):
         return self
 
 
-class DatabaseModel(BaseModel, IsDatabricksResource):
+class DatabaseType(str, Enum):
+    POSTGRES = "postgres"
+    LAKEBASE = "lakebase"
+
+
+class DatabaseModel(IsDatabricksResource):
     """
     Configuration for a Databricks Lakebase (PostgreSQL) database instance.
 
-    Authentication …
-
-    This model uses TWO separate authentication contexts:
-
-    1. **Workspace API Authentication** (inherited from IsDatabricksResource):
-       - Uses ambient/default authentication (environment variables, notebook context, app service principal)
-       - Used for: discovering database instance, getting host DNS, checking instance status
-       - Controlled by: DATABRICKS_HOST, DATABRICKS_TOKEN env vars, or SDK default config
+    Authentication is inherited from IsDatabricksResource. Additionally supports:
+    - user/password: For user-based database authentication
 
-
-
-
-       - OAuth M2M: Set client_id, client_secret, workspace_host to connect as a service principal
-       - User Auth: Set user (and optionally password) to connect as a user identity
+    Database Type:
+    - lakebase: Databricks-managed Lakebase instance (authentication optional, supports ambient auth)
+    - postgres: Standard PostgreSQL database (authentication required)
 
     Example Service Principal Configuration:
     ```yaml
     databases:
       my_lakebase:
         name: my-database
+        type: lakebase
         service_principal:
           client_id:
             env: SERVICE_PRINCIPAL_CLIENT_ID
@@ -785,31 +913,28 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
             env: DATABRICKS_HOST
     ```
 
-    Example …
+    Example User Configuration:
     ```yaml
     databases:
       my_lakebase:
         name: my-database
-
-
-        client_secret:
-          scope: my-scope
-          secret: sp-client-secret
-        workspace_host:
-          env: DATABRICKS_HOST
+        type: lakebase
+        user: my-user@databricks.com
     ```
 
-    Example …
+    Example Ambient Authentication (Lakebase only):
     ```yaml
    databases:
      my_lakebase:
        name: my-database
-
+        type: lakebase
+        on_behalf_of_user: true
    ```
    """
 
    model_config = ConfigDict(use_enum_values=True, extra="forbid")
    name: str
+    type: Optional[DatabaseType] = DatabaseType.LAKEBASE
    instance_name: Optional[str] = None
    description: Optional[str] = None
    host: Optional[AnyVariable] = None
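The new type field is an enum that must serialize to its plain string when configs are dumped, which is what the field_serializer added in the next hunk does. A minimal pydantic v2 sketch of that serializer, runnable on its own (class trimmed to the relevant fields; assumes pydantic v2 is installed):

```python
# Sketch of DatabaseModel's type serializer: the DatabaseType enum member
# is emitted as its plain string value in model_dump() output.
from enum import Enum
from typing import Optional

from pydantic import BaseModel, field_serializer


class DatabaseType(str, Enum):
    POSTGRES = "postgres"
    LAKEBASE = "lakebase"


class DatabaseConfig(BaseModel):
    name: str
    type: Optional[DatabaseType] = DatabaseType.LAKEBASE

    @field_serializer("type")
    def serialize_type(self, value: Optional[DatabaseType]) -> Optional[str]:
        return value.value if value is not None else None


print(DatabaseConfig(name="my-database").model_dump())
# {'name': 'my-database', 'type': 'lakebase'}
```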
@@ -820,16 +945,18 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
     timeout_seconds: Optional[int] = 10
     capacity: Optional[Literal["CU_1", "CU_2"]] = "CU_2"
     node_count: Optional[int] = None
+    # Database-specific auth (user identity for DB connection)
     user: Optional[AnyVariable] = None
     password: Optional[AnyVariable] = None
-
-
-
-
+
+    @field_serializer("type")
+    def serialize_type(self, value: DatabaseType | None) -> str | None:
+        """Serialize the database type enum to its string value."""
+        return value.value if value is not None else None
 
     @property
     def api_scopes(self) -> Sequence[str]:
-        return []
+        return ["database.database-instances"]
 
     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
@@ -843,29 +970,33 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
     def update_instance_name(self) -> Self:
         if self.instance_name is None:
             self.instance_name = self.name
-
-        return self
-
-    @model_validator(mode="after")
-    def expand_service_principal(self) -> Self:
-        """Expand service_principal into client_id and client_secret if provided."""
-        if self.service_principal is not None:
-            if self.client_id is None:
-                self.client_id = self.service_principal.client_id
-            if self.client_secret is None:
-                self.client_secret = self.service_principal.client_secret
         return self
 
     @model_validator(mode="after")
     def update_user(self) -> Self:
-        if …
+        # Skip if using OBO (passive auth), explicit credentials, or explicit user
+        if self.on_behalf_of_user or self.client_id or self.user or self.pat:
             return self
 
-
-
-
-
-
+        # For postgres, we need explicit user credentials
+        # For lakebase with no auth, ambient auth is allowed
+        if self.type == DatabaseType.POSTGRES:
+            # Try to determine current user for local development
+            try:
+                self.user = self.workspace_client.current_user.me().user_name
+            except Exception as e:
+                logger.warning(
+                    f"Could not determine current user for PostgreSQL database: {e}. "
+                    f"Please provide explicit user credentials."
+                )
+        else:
+            # For lakebase, try to determine current user but don't fail if we can't
+            try:
+                self.user = self.workspace_client.current_user.me().user_name
+            except Exception:
+                # If we can't determine user and no explicit auth, that's okay
+                # for lakebase with ambient auth - credentials will be injected at runtime
+                pass
 
         return self
 
@@ -874,12 +1005,28 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
         if self.host is not None:
             return self
 
-
-
-
+        # Try to fetch host from existing instance
+        # This may fail for OBO/ambient auth during model logging (before deployment)
+        try:
+            existing_instance: DatabaseInstance = (
+                self.workspace_client.database.get_database_instance(
+                    name=self.instance_name
+                )
             )
-
-
+            self.host = existing_instance.read_write_dns
+        except Exception as e:
+            # For lakebase with OBO/ambient auth, we can't fetch at config time
+            # The host will need to be provided explicitly or fetched at runtime
+            if self.type == DatabaseType.LAKEBASE and self.on_behalf_of_user:
+                logger.debug(
+                    f"Could not fetch host for database {self.instance_name} "
+                    f"(Lakebase with OBO mode - will be resolved at runtime): {e}"
+                )
+            else:
+                raise ValueError(
+                    f"Could not fetch host for database {self.instance_name}. "
+                    f"Please provide the 'host' explicitly or ensure the instance exists: {e}"
+                )
         return self
 
     @model_validator(mode="after")
@@ -890,21 +1037,33 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
             self.client_secret,
         ]
         has_oauth: bool = all(field is not None for field in oauth_fields)
+        has_user_auth: bool = self.user is not None
+        has_obo: bool = self.on_behalf_of_user is True
+        has_pat: bool = self.pat is not None
 
-
-
+        # Count how many auth methods are configured
+        auth_methods_count: int = sum([has_oauth, has_user_auth, has_obo, has_pat])
 
-        if …
+        if auth_methods_count > 1:
             raise ValueError(
-                "Cannot …
-                "Please provide …
+                "Cannot mix authentication methods. "
+                "Please provide exactly one of: "
+                "on_behalf_of_user=true (for passive auth in model serving), "
+                "OAuth credentials (service_principal or client_id + client_secret + workspace_host), "
+                "PAT (personal access token), "
+                "or user credentials (user)."
             )
 
-
+        # For postgres type, at least one auth method must be configured
+        # For lakebase type, auth is optional (supports ambient authentication)
+        if self.type == DatabaseType.POSTGRES and auth_methods_count == 0:
             raise ValueError(
-                "…
-                "…
-                "…
+                "PostgreSQL databases require explicit authentication. "
+                "Please provide one of: "
+                "OAuth credentials (workspace_host, client_id, client_secret), "
+                "service_principal with workspace_host, "
+                "PAT (personal access token), "
+                "or user credentials (user)."
             )
 
         return self
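A standalone sketch of the validator's counting logic added above: OBO, OAuth, PAT, and user auth each count as one method, more than one is rejected, and only postgres requires at least one (lakebase may fall back to ambient auth). Function and return labels here are illustrative, not the model's API:

```python
# Sketch of DatabaseModel's "exactly one auth method" validation.
from typing import Optional


def validate_auth(
    type_: str,
    *,
    obo: bool = False,
    client_id: Optional[str] = None,
    client_secret: Optional[str] = None,
    pat: Optional[str] = None,
    user: Optional[str] = None,
) -> str:
    has_oauth = client_id is not None and client_secret is not None
    methods = sum([has_oauth, user is not None, obo, pat is not None])
    if methods > 1:
        raise ValueError("Cannot mix authentication methods.")
    if type_ == "postgres" and methods == 0:
        raise ValueError("PostgreSQL databases require explicit authentication.")
    return "ambient" if methods == 0 else "explicit"


print(validate_auth("lakebase"))                    # ambient
print(validate_auth("postgres", user="me@x.com"))   # explicit
```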
@@ -918,8 +1077,9 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
         If username is configured, it will be included; otherwise it will be omitted
         to allow Lakebase to authenticate using the token's identity.
         """
-
-
+        import uuid as _uuid
+
+        from databricks.sdk.service.database import DatabaseCredential
 
         username: str | None = None
 
@@ -927,19 +1087,36 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
             username = value_of(self.client_id)
         elif self.user:
             username = value_of(self.user)
+        # For OBO mode, no username is needed - the token identity is used
+
+        # Resolve host - may need to fetch at runtime for OBO mode
+        host_value: Any = self.host
+        if host_value is None and self.on_behalf_of_user:
+            # Fetch host at runtime for OBO mode
+            existing_instance: DatabaseInstance = (
+                self.workspace_client.database.get_database_instance(
+                    name=self.instance_name
+                )
+            )
+            host_value = existing_instance.read_write_dns
 
-
+        if host_value is None:
+            raise ValueError(
+                f"Database host not configured for {self.instance_name}. "
+                "Please provide 'host' explicitly."
+            )
+
+        host: str = value_of(host_value)
         port: int = value_of(self.port)
         database: str = value_of(self.database)
 
-
-
-
-
-
+        # Use the resource's own workspace_client to generate the database credential
+        w: WorkspaceClient = self.workspace_client
+        cred: DatabaseCredential = w.database.generate_database_credential(
+            request_id=str(_uuid.uuid4()),
+            instance_names=[self.instance_name],
         )
-
-        token: str = provider.lakebase_password_provider(self.instance_name)
+        token: str = cred.token
 
         # Build connection parameters dictionary
         params: dict[str, Any] = {
@@ -977,6 +1154,9 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
     def create(self, w: WorkspaceClient | None = None) -> None:
         from dao_ai.providers.databricks import DatabricksProvider
 
+        # Use provided workspace client or fall back to resource's own workspace_client
+        if w is None:
+            w = self.workspace_client
         provider: DatabricksProvider = DatabricksProvider(w=w)
         provider.create_lakebase(self)
         provider.create_lakebase_instance_role(self)
@@ -996,14 +1176,62 @@ class GenieSemanticCacheParametersModel(BaseModel):
     time_to_live_seconds: int | None = (
         60 * 60 * 24
     )  # 1 day default, None or negative = never expires
-    similarity_threshold: float = (
-
+    similarity_threshold: float = 0.85  # Minimum similarity for question matching (L2 distance converted to 0-1 scale)
+    context_similarity_threshold: float = 0.80  # Minimum similarity for context matching (L2 distance converted to 0-1 scale)
+    question_weight: Optional[float] = (
+        0.6  # Weight for question similarity in combined score (0-1). If not provided, computed as 1 - context_weight
+    )
+    context_weight: Optional[float] = (
+        None  # Weight for context similarity in combined score (0-1). If not provided, computed as 1 - question_weight
     )
     embedding_model: str | LLMModel = "databricks-gte-large-en"
     embedding_dims: int | None = None  # Auto-detected if None
     database: DatabaseModel
     warehouse: WarehouseModel
     table_name: str = "genie_semantic_cache"
+    context_window_size: int = 3  # Number of previous turns to include for context
+    max_context_tokens: int = (
+        2000  # Maximum context length to prevent extremely long embeddings
+    )
+
+    @model_validator(mode="after")
+    def compute_and_validate_weights(self) -> Self:
+        """
+        Compute missing weight and validate that question_weight + context_weight = 1.0.
+
+        Either question_weight or context_weight (or both) can be provided.
+        The missing one will be computed as 1.0 - provided_weight.
+        If both are provided, they must sum to 1.0.
+        """
+        if self.question_weight is None and self.context_weight is None:
+            # Both missing - use defaults
+            self.question_weight = 0.6
+            self.context_weight = 0.4
+        elif self.question_weight is None:
+            # Compute question_weight from context_weight
+            if not (0.0 <= self.context_weight <= 1.0):
+                raise ValueError(
+                    f"context_weight must be between 0.0 and 1.0, got {self.context_weight}"
+                )
+            self.question_weight = 1.0 - self.context_weight
+        elif self.context_weight is None:
+            # Compute context_weight from question_weight
+            if not (0.0 <= self.question_weight <= 1.0):
+                raise ValueError(
+                    f"question_weight must be between 0.0 and 1.0, got {self.question_weight}"
+                )
+            self.context_weight = 1.0 - self.question_weight
+        else:
+            # Both provided - validate they sum to 1.0
+            total_weight = self.question_weight + self.context_weight
+            if not abs(total_weight - 1.0) < 0.0001:  # Allow small floating point error
+                raise ValueError(
+                    f"question_weight ({self.question_weight}) + context_weight ({self.context_weight}) "
+                    f"must equal 1.0 (got {total_weight}). These weights determine the relative importance "
+                    f"of question vs context similarity in the combined score."
+                )
+
+        return self
 
 
 class SearchParametersModel(BaseModel):
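The question/context weights must form a convex combination. A standalone sketch of the weight resolution and the combined score it feeds (the 0.92 and 0.71 similarities below are made-up inputs, purely for illustration):

```python
# Sketch of the weight logic in GenieSemanticCacheParametersModel:
# supply either weight and the other is derived as 1 - w; supplying
# both requires them to sum to 1.0 (within a small float tolerance).
from typing import Optional, Tuple


def resolve_weights(
    question_weight: Optional[float], context_weight: Optional[float]
) -> Tuple[float, float]:
    if question_weight is None and context_weight is None:
        return 0.6, 0.4  # defaults from the model
    if question_weight is None:
        return 1.0 - context_weight, context_weight
    if context_weight is None:
        return question_weight, 1.0 - question_weight
    if abs(question_weight + context_weight - 1.0) >= 1e-4:
        raise ValueError("question_weight + context_weight must equal 1.0")
    return question_weight, context_weight


qw, cw = resolve_weights(None, 0.25)  # qw=0.75, cw=0.25
combined = qw * 0.92 + cw * 0.71      # weighted similarity for a cache hit
print(f"{qw=} {cw=} combined={combined:.3f}")
```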
@@ -1096,28 +1324,47 @@ class FunctionType(str, Enum):
     MCP = "mcp"
 
 
-class …
-    """
+class HumanInTheLoopModel(BaseModel):
+    """
+    Configuration for Human-in-the-Loop tool approval.
 
-
-
-    RESPONSE = "response"
-    DECLINE = "decline"
+    This model configures when and how tools require human approval before execution.
+    It maps to LangChain's HumanInTheLoopMiddleware.
 
+    LangChain supports three decision types:
+    - "approve": Execute tool with original arguments
+    - "edit": Modify arguments before execution
+    - "reject": Skip execution with optional feedback message
+    """
 
-class HumanInTheLoopModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-
-
-
-
-
-
-
-
+
+    review_prompt: Optional[str] = Field(
+        default=None,
+        description="Message shown to the reviewer when approval is requested",
+    )
+
+    allowed_decisions: list[Literal["approve", "edit", "reject"]] = Field(
+        default_factory=lambda: ["approve", "edit", "reject"],
+        description="List of allowed decision types for this tool",
     )
-
-
+
+    @model_validator(mode="after")
+    def validate_and_normalize_decisions(self) -> Self:
+        """Validate and normalize allowed decisions."""
+        if not self.allowed_decisions:
+            raise ValueError("At least one decision type must be allowed")
+
+        # Remove duplicates while preserving order
+        seen = set()
+        unique_decisions = []
+        for decision in self.allowed_decisions:
+            if decision not in seen:
+                seen.add(decision)
+                unique_decisions.append(decision)
+        self.allowed_decisions = unique_decisions
+
+        return self
 
 
 class BaseFunctionModel(ABC, BaseModel):
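A standalone sketch of the decision normalization in validate_and_normalize_decisions: order-preserving de-duplication, with an empty list rejected up front:

```python
# Sketch of HumanInTheLoopModel's allowed_decisions normalization.
from typing import List


def normalize_decisions(decisions: List[str]) -> List[str]:
    if not decisions:
        raise ValueError("At least one decision type must be allowed")
    seen = set()
    unique = []
    for decision in decisions:
        if decision not in seen:
            seen.add(decision)
            unique.append(decision)
    return unique


print(normalize_decisions(["approve", "edit", "approve", "reject"]))
# ['approve', 'edit', 'reject']
```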
@@ -1180,7 +1427,16 @@ class TransportType(str, Enum):
     STDIO = "stdio"
 
 
-class McpFunctionModel(BaseFunctionModel, HasFullName):
+class McpFunctionModel(BaseFunctionModel, IsDatabricksResource, HasFullName):
+    """
+    MCP Function Model with authentication inherited from IsDatabricksResource.
+
+    Authentication for MCP connections uses the same options as other resources:
+    - Service Principal (client_id + client_secret + workspace_host)
+    - PAT (pat + workspace_host)
+    - OBO (on_behalf_of_user)
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     type: Literal[FunctionType.MCP] = FunctionType.MCP
     transport: TransportType = TransportType.STREAMABLE_HTTP
@@ -1188,26 +1444,27 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
     url: Optional[AnyVariable] = None
     headers: dict[str, AnyVariable] = Field(default_factory=dict)
     args: list[str] = Field(default_factory=list)
-
-    service_principal: Optional[ServicePrincipalModel] = None
-    client_id: Optional[AnyVariable] = None
-    client_secret: Optional[AnyVariable] = None
-    workspace_host: Optional[AnyVariable] = None
+    # MCP-specific fields
     connection: Optional[ConnectionModel] = None
     functions: Optional[SchemaModel] = None
     genie_room: Optional[GenieRoomModel] = None
     sql: Optional[bool] = None
     vector_search: Optional[VectorStoreModel] = None
 
-    @…
-    def …
-    """…
-
-
-
-
-
-
+    @property
+    def api_scopes(self) -> Sequence[str]:
+        """API scopes for MCP connections."""
+        return [
+            "serving.serving-endpoints",
+            "mcp.genie",
+            "mcp.functions",
+            "mcp.vectorsearch",
+            "mcp.external",
+        ]
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        """MCP functions don't declare static resources."""
+        return []
 
     @property
     def full_name(self) -> str:
@@ -1372,27 +1629,6 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
             self.headers[key] = value_of(value)
         return self
 
-    @model_validator(mode="after")
-    def validate_auth_methods(self) -> "McpFunctionModel":
-        oauth_fields: Sequence[Any] = [
-            self.client_id,
-            self.client_secret,
-        ]
-        has_oauth: bool = all(field is not None for field in oauth_fields)
-
-        pat_fields: Sequence[Any] = [self.pat]
-        has_user_auth: bool = all(field is not None for field in pat_fields)
-
-        if has_oauth and has_user_auth:
-            raise ValueError(
-                "Cannot use both OAuth and user authentication methods. "
-                "Please provide either OAuth credentials or user credentials."
-            )
-
-        # Note: workspace_host is optional - it will be derived from workspace client if not provided
-
-        return self
-
     def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
         from dao_ai.tools import create_mcp_tools
 
@@ -1434,17 +1670,97 @@ class ToolModel(BaseModel):
     function: AnyTool
 
 
+class PromptModel(BaseModel, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
+    name: str
+    description: Optional[str] = None
+    default_template: Optional[str] = None
+    alias: Optional[str] = None
+    version: Optional[int] = None
+    tags: Optional[dict[str, Any]] = Field(default_factory=dict)
+
+    @property
+    def template(self) -> str:
+        from dao_ai.providers.databricks import DatabricksProvider
+
+        provider: DatabricksProvider = DatabricksProvider()
+        prompt_version = provider.get_prompt(self)
+        return prompt_version.to_single_brace_format()
+
+    @property
+    def full_name(self) -> str:
+        prompt_name: str = self.name
+        if self.schema_model:
+            prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
+        return prompt_name
+
+    @property
+    def uri(self) -> str:
+        prompt_uri: str = f"prompts:/{self.full_name}"
+
+        if self.alias:
+            prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
+        elif self.version:
+            prompt_uri = f"prompts:/{self.full_name}/{self.version}"
+        else:
+            prompt_uri = f"prompts:/{self.full_name}@latest"
+
+        return prompt_uri
+
+    def as_prompt(self) -> PromptVersion:
+        prompt_version: PromptVersion = load_prompt(self.uri)
+        return prompt_version
+
+    @model_validator(mode="after")
+    def validate_mutually_exclusive(self) -> Self:
+        if self.alias and self.version:
+            raise ValueError("Cannot specify both alias and version")
+        return self
+
+
 class GuardrailModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
-    model: LLMModel
-    prompt: str
+    model: str | LLMModel
+    prompt: str | PromptModel
     num_retries: Optional[int] = 3
+
+    @model_validator(mode="after")
+    def validate_llm_model(self) -> Self:
+        if isinstance(self.model, str):
+            self.model = LLMModel(name=self.model)
+        return self
+
+
+class MiddlewareModel(BaseModel):
+    """Configuration for middleware that can be applied to agents.
+
+    Middleware is defined at the AppConfig level and can be referenced by name
+    in agent configurations using YAML anchors for reusability.
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str = Field(
+        description="Fully qualified name of the middleware factory function"
+    )
+    args: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Arguments to pass to the middleware factory function",
+    )
+
+    @model_validator(mode="after")
+    def resolve_args(self) -> Self:
+        """Resolve any variable references in args."""
+        for key, value in self.args.items():
+            self.args[key] = value_of(value)
+        return self
+
 
 class StorageType(str, Enum):
     POSTGRES = "postgres"
     MEMORY = "memory"
+    LAKEBASE = "lakebase"
 
 
 class CheckpointerModel(BaseModel):
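PromptModel.uri prefers an alias over a pinned version and defaults to the @latest alias. A standalone sketch of that resolution (the prompt name below is hypothetical, not one shipped with the package):

```python
# Sketch of PromptModel's MLflow prompt URI construction:
# alias > version > "@latest", with alias+version rejected.
from typing import Optional


def prompt_uri(full_name: str, alias: Optional[str] = None,
               version: Optional[int] = None) -> str:
    if alias and version:
        raise ValueError("Cannot specify both alias and version")
    if alias:
        return f"prompts:/{full_name}@{alias}"
    if version:
        return f"prompts:/{full_name}/{version}"
    return f"prompts:/{full_name}@latest"


print(prompt_uri("main.agents.router_prompt", alias="production"))
print(prompt_uri("main.agents.router_prompt", version=7))
print(prompt_uri("main.agents.router_prompt"))
```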
@@ -1454,8 +1770,11 @@ class CheckpointerModel(BaseModel):
     database: Optional[DatabaseModel] = None
 
     @model_validator(mode="after")
-    def …
-        if …
+    def validate_storage_requires_database(self) -> Self:
+        if (
+            self.type in [StorageType.POSTGRES, StorageType.LAKEBASE]
+            and not self.database
+        ):
             raise ValueError("Database must be provided when storage type is POSTGRES")
         return self
 
@@ -1500,56 +1819,158 @@ class MemoryModel(BaseModel):
 FunctionHook: TypeAlias = PythonFunctionModel | FactoryFunctionModel | str
 
 
-class …
+class ResponseFormatModel(BaseModel):
+    """
+    Configuration for structured response formats.
+
+    The response_schema field accepts either a type or a string:
+    - Type (Pydantic model, dataclass, etc.): Used directly for structured output
+    - String: First attempts to load as a fully qualified type name, falls back to JSON schema string
+
+    This unified approach simplifies the API while maintaining flexibility.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-
-
-
-
-
-
-
+    use_tool: Optional[bool] = Field(
+        default=None,
+        description=(
+            "Strategy for structured output: "
+            "None (default) = auto-detect from model capabilities, "
+            "False = force ProviderStrategy (native), "
+            "True = force ToolStrategy (function calling)"
+        ),
+    )
+    response_schema: Optional[str | type] = Field(
+        default=None,
+        description="Type or string for response format. String attempts FQN import, falls back to JSON schema.",
+    )
 
-
-
-
+    def as_strategy(self) -> ProviderStrategy | ToolStrategy:
+        """
+        Convert response_schema to appropriate LangChain strategy.
 
-
-
-
+        Returns:
+            - None if no response_schema configured
+            - Raw schema/type for auto-detection (when use_tool=None)
+            - ToolStrategy wrapping the schema (when use_tool=True)
+            - ProviderStrategy wrapping the schema (when use_tool=False)
 
-
-
-
-        if self.schema_model:
-            prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
-        return prompt_name
+        Raises:
+            ValueError: If response_schema is a JSON schema string that cannot be parsed
+        """
 
-
-
-        prompt_uri: str = f"prompts:/{self.full_name}"
+        if self.response_schema is None:
+            return None
 
-
-            prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
-        elif self.version:
-            prompt_uri = f"prompts:/{self.full_name}/{self.version}"
-        else:
-            prompt_uri = f"prompts:/{self.full_name}@latest"
+        schema = self.response_schema
 
-
+        # Handle type schemas (Pydantic, dataclass, etc.)
+        if self.is_type_schema:
+            if self.use_tool is None:
+                # Auto-detect: Pass schema directly, let LangChain decide
+                return schema
+            elif self.use_tool is True:
+                # Force ToolStrategy (function calling)
+                return ToolStrategy(schema)
+            else:  # use_tool is False
+                # Force ProviderStrategy (native structured output)
+                return ProviderStrategy(schema)
 
-
-
-
+        # Handle JSON schema strings
+        elif self.is_json_schema:
+            import json
+
+            try:
+                schema_dict = json.loads(schema)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Invalid JSON schema string: {e}") from e
+
+            # Apply same use_tool logic as type schemas
+            if self.use_tool is None:
+                # Auto-detect
+                return schema_dict
+            elif self.use_tool is True:
+                # Force ToolStrategy
+                return ToolStrategy(schema_dict)
+            else:  # use_tool is False
+                # Force ProviderStrategy
+                return ProviderStrategy(schema_dict)
+
+        return None
 
     @model_validator(mode="after")
-    def …
-
-
-
+    def validate_response_schema(self) -> Self:
+        """
+        Validate and convert response_schema.
+
+        Processing logic:
+        1. If None: no response format specified
+        2. If type: use directly as structured output type
+        3. If str: try to load as FQN using type_from_fqn
+           - Success: response_schema becomes the loaded type
+           - Failure: keep as string (treated as JSON schema)
+
+        After validation, response_schema is one of:
+        - None (no schema)
+        - type (use for structured output)
+        - str (JSON schema)
+
+        Returns:
+            Self with validated response_schema
+        """
+        if self.response_schema is None:
+            return self
+
+        # If already a type, return
+        if isinstance(self.response_schema, type):
+            return self
+
+        # If it's a string, try to load as type, fallback to json_schema
+        if isinstance(self.response_schema, str):
+            from dao_ai.utils import type_from_fqn
+
+            try:
+                resolved_type = type_from_fqn(self.response_schema)
+                self.response_schema = resolved_type
+                logger.debug(
+                    f"Resolved response_schema string to type: {resolved_type}"
+                )
+                return self
+            except (ValueError, ImportError, AttributeError, TypeError) as e:
+                # Keep as string - it's a JSON schema
+                logger.debug(
+                    f"Could not resolve '{self.response_schema}' as type: {e}. "
+                    f"Treating as JSON schema string."
+                )
+                return self
+
+        # Invalid type
+        raise ValueError(
+            f"response_schema must be None, type, or str, got {type(self.response_schema)}"
+        )
+
+    @property
+    def is_type_schema(self) -> bool:
+        """Returns True if response_schema is a type (not JSON schema string)."""
+        return isinstance(self.response_schema, type)
+
+    @property
+    def is_json_schema(self) -> bool:
+        """Returns True if response_schema is a JSON schema string (not a type)."""
+        return isinstance(self.response_schema, str)
 
 
 class AgentModel(BaseModel):
+    """
+    Configuration model for an agent in the DAO AI framework.
+
+    Agents combine an LLM with tools and middleware to create systems that can
+    reason about tasks, decide which tools to use, and iteratively work towards solutions.
+
+    Middleware replaces the previous pre_agent_hook and post_agent_hook patterns,
+    providing a more flexible and composable way to customize agent behavior.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     description: Optional[str] = None
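ResponseFormatModel resolves a string schema by first attempting a fully qualified type import and only then treating it as a JSON schema string. A standalone sketch of that fallback, with dao_ai's type_from_fqn stubbed via importlib (an assumption about its behavior, not its actual implementation):

```python
# Sketch of the response_schema resolution order: FQN import first,
# JSON-schema string as the fallback when the import fails.
import importlib
from typing import Union


def resolve_schema(schema: Union[str, type]) -> Union[type, str]:
    if isinstance(schema, type):
        return schema
    module_name, _, attr = schema.rpartition(".")
    try:
        return getattr(importlib.import_module(module_name), attr)
    except (ImportError, AttributeError, ValueError):
        return schema  # keep as a JSON schema string


print(resolve_schema("decimal.Decimal"))       # <class 'decimal.Decimal'>
print(resolve_schema('{"type": "object"}'))    # unchanged JSON schema string
```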
@@ -1558,9 +1979,43 @@ class AgentModel(BaseModel):
     guardrails: list[GuardrailModel] = Field(default_factory=list)
     prompt: Optional[str | PromptModel] = None
     handoff_prompt: Optional[str] = None
-
-
-
+    middleware: list[MiddlewareModel] = Field(
+        default_factory=list,
+        description="List of middleware to apply to this agent",
+    )
+    response_format: Optional[ResponseFormatModel | type | str] = None
+
+    @model_validator(mode="after")
+    def validate_response_format(self) -> Self:
+        """
+        Validate and normalize response_format.
+
+        Accepts:
+        - None (no response format)
+        - ResponseFormatModel (already validated)
+        - type (Pydantic model, dataclass, etc.) - converts to ResponseFormatModel
+        - str (FQN or json_schema) - converts to ResponseFormatModel (smart fallback)
+
+        ResponseFormatModel handles the logic of trying FQN import and falling back to JSON schema.
+        """
+        if self.response_format is None or isinstance(
+            self.response_format, ResponseFormatModel
+        ):
+            return self
+
+        # Convert type or str to ResponseFormatModel
+        # ResponseFormatModel's validator will handle the smart type loading and fallback
+        if isinstance(self.response_format, (type, str)):
+            self.response_format = ResponseFormatModel(
+                response_schema=self.response_format
+            )
+            return self
+
+        # Invalid type
+        raise ValueError(
+            f"response_format must be None, ResponseFormatModel, type, or str, "
+            f"got {type(self.response_format)}"
+        )
 
     def as_runnable(self) -> RunnableLike:
         from dao_ai.nodes import create_agent_node
@@ -1579,6 +2034,10 @@ class SupervisorModel(BaseModel):
     model: LLMModel
     tools: list[ToolModel] = Field(default_factory=list)
     prompt: Optional[str] = None
+    middleware: list[MiddlewareModel] = Field(
+        default_factory=list,
+        description="List of middleware to apply to the supervisor",
+    )
 
 
 class SwarmModel(BaseModel):
@@ -1702,6 +2161,28 @@ class ChatPayload(BaseModel):
 
         return self
 
+    @model_validator(mode="after")
+    def ensure_thread_id(self) -> "ChatPayload":
+        """Ensure thread_id or conversation_id is present in configurable, generating UUID if needed."""
+        import uuid
+
+        if self.custom_inputs is None:
+            self.custom_inputs = {}
+
+        # Get or create configurable section
+        configurable: dict[str, Any] = self.custom_inputs.get("configurable", {})
+
+        # Check if thread_id or conversation_id exists
+        has_thread_id = configurable.get("thread_id") is not None
+        has_conversation_id = configurable.get("conversation_id") is not None
+
+        # If neither is provided, generate a UUID for conversation_id
+        if not has_thread_id and not has_conversation_id:
+            configurable["conversation_id"] = str(uuid.uuid4())
+            self.custom_inputs["configurable"] = configurable
+
+        return self
+
     def as_messages(self) -> Sequence[BaseMessage]:
         return messages_from_dict(
             [{"type": m.role, "content": m.content} for m in self.messages]
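A standalone sketch of ensure_thread_id: when neither thread_id nor conversation_id is present under custom_inputs.configurable, a fresh UUID is written to conversation_id, and an existing identifier is left untouched:

```python
# Sketch of ChatPayload.ensure_thread_id as a plain function.
import uuid
from typing import Any, Dict, Optional


def ensure_thread_id(custom_inputs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    custom_inputs = custom_inputs or {}
    configurable = custom_inputs.get("configurable", {})
    if (configurable.get("thread_id") is None
            and configurable.get("conversation_id") is None):
        configurable["conversation_id"] = str(uuid.uuid4())
        custom_inputs["configurable"] = configurable
    return custom_inputs


print(ensure_thread_id(None))                               # generated UUID
print(ensure_thread_id({"configurable": {"thread_id": "t-1"}}))  # unchanged
```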
@@ -1717,20 +2198,38 @@ class ChatPayload(BaseModel):
 
 
 class ChatHistoryModel(BaseModel):
+    """
+    Configuration for chat history summarization.
+
+    Attributes:
+        model: The LLM to use for generating summaries.
+        max_tokens: Maximum tokens to keep after summarization (the "keep" threshold).
+            After summarization, recent messages totaling up to this many tokens are preserved.
+        max_tokens_before_summary: Token threshold that triggers summarization.
+            When conversation exceeds this, summarization runs. Mutually exclusive with
+            max_messages_before_summary. If neither is set, defaults to max_tokens * 10.
+        max_messages_before_summary: Message count threshold that triggers summarization.
+            When conversation exceeds this many messages, summarization runs.
+            Mutually exclusive with max_tokens_before_summary.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     model: LLMModel
-    max_tokens: int = …
-
-
-
-
-
-
-
-
-
-
+    max_tokens: int = Field(
+        default=2048,
+        gt=0,
+        description="Maximum tokens to keep after summarization",
+    )
+    max_tokens_before_summary: Optional[int] = Field(
+        default=None,
+        gt=0,
+        description="Token threshold that triggers summarization",
+    )
+    max_messages_before_summary: Optional[int] = Field(
+        default=None,
+        gt=0,
+        description="Message count threshold that triggers summarization",
+    )
 
 
 class AppModel(BaseModel):
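A hedged sketch of the trigger logic these thresholds imply, per the docstring above (the summarization itself lives in dao_ai/middleware/summarization.py; this only models when it would run, and assumes the message-count threshold governs when it is the one configured, since the two triggers are mutually exclusive):

```python
# Sketch of the summarization trigger implied by ChatHistoryModel:
# fire when the configured threshold is exceeded; the token threshold
# defaults to max_tokens * 10 when neither trigger is set.
from typing import Optional


def should_summarize(tokens: int, messages: int, max_tokens: int = 2048,
                     max_tokens_before: Optional[int] = None,
                     max_messages_before: Optional[int] = None) -> bool:
    if max_messages_before is not None:
        return messages > max_messages_before
    threshold = (max_tokens_before if max_tokens_before is not None
                 else max_tokens * 10)
    return tokens > threshold


print(should_summarize(tokens=25000, messages=40))              # True (> 20480)
print(should_summarize(tokens=500, messages=40,
                       max_messages_before=30))                 # True
```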
@@ -1757,9 +2256,6 @@ class AppModel(BaseModel):
     shutdown_hooks: Optional[FunctionHook | list[FunctionHook]] = Field(
         default_factory=list
     )
-    message_hooks: Optional[FunctionHook | list[FunctionHook]] = Field(
-        default_factory=list
-    )
     input_example: Optional[ChatPayload] = None
     chat_history: Optional[ChatHistoryModel] = None
     code_paths: list[str] = Field(default_factory=list)
@@ -1964,33 +2460,67 @@ class EvaluationDatasetModel(BaseModel, HasFullName):
 
 
 class PromptOptimizationModel(BaseModel):
+    """Configuration for prompt optimization using GEPA.
+
+    GEPA (Generative Evolution of Prompts and Agents) is an evolutionary
+    optimizer that uses reflective mutation to improve prompts based on
+    evaluation feedback.
+
+    Example:
+        prompt_optimization:
+          name: optimize_my_prompt
+          prompt: *my_prompt
+          agent: *my_agent
+          dataset: *my_training_dataset
+          reflection_model: databricks-meta-llama-3-3-70b-instruct
+          num_candidates: 50
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     prompt: Optional[PromptModel] = None
     agent: AgentModel
-    dataset: (
-        EvaluationDatasetModel | str
-    )  # Reference to dataset name (looked up in OptimizationsModel.training_datasets or MLflow)
+    dataset: EvaluationDatasetModel  # Training dataset with examples
     reflection_model: Optional[LLMModel | str] = None
     num_candidates: Optional[int] = 50
-    scorer_model: Optional[LLMModel | str] = None
 
     def optimize(self, w: WorkspaceClient | None = None) -> PromptModel:
         """
-        Optimize the prompt using …
+        Optimize the prompt using GEPA.
 
         Args:
-            w: Optional WorkspaceClient for …
+            w: Optional WorkspaceClient (not used, kept for API compatibility)
 
         Returns:
-            PromptModel: The optimized prompt model
+            PromptModel: The optimized prompt model
         """
-        from dao_ai.…
-        from dao_ai.providers.databricks import DatabricksProvider
+        from dao_ai.optimization import OptimizationResult, optimize_prompt
 
-
-
-
+        # Get reflection model name
+        reflection_model_name: str | None = None
+        if self.reflection_model:
+            if isinstance(self.reflection_model, str):
+                reflection_model_name = self.reflection_model
+            else:
+                reflection_model_name = self.reflection_model.uri
+
+        # Ensure prompt is set
+        prompt = self.prompt
+        if prompt is None:
+            raise ValueError(
+                f"Prompt optimization '{self.name}' requires a prompt to be set"
+            )
+
+        result: OptimizationResult = optimize_prompt(
+            prompt=prompt,
+            agent=self.agent,
+            dataset=self.dataset,
+            reflection_model=reflection_model_name,
+            num_candidates=self.num_candidates or 50,
+            register_if_improved=True,
+        )
+
+        return result.optimized_prompt
 
     @model_validator(mode="after")
     def set_defaults(self) -> Self:
@@ -2004,12 +2534,6 @@ class PromptOptimizationModel(BaseModel):
             f"or an agent with a prompt configured"
         )
 
-        if self.reflection_model is None:
-            self.reflection_model = self.agent.model
-
-        if self.scorer_model is None:
-            self.scorer_model = self.agent.model
-
         return self
 
 
@@ -2110,6 +2634,7 @@ class ResourcesModel(BaseModel):
     warehouses: dict[str, WarehouseModel] = Field(default_factory=dict)
     databases: dict[str, DatabaseModel] = Field(default_factory=dict)
     connections: dict[str, ConnectionModel] = Field(default_factory=dict)
+    apps: dict[str, DatabricksAppModel] = Field(default_factory=dict)
 
 
 class AppConfig(BaseModel):
@@ -2121,6 +2646,7 @@ class AppConfig(BaseModel):
     retrievers: dict[str, RetrieverModel] = Field(default_factory=dict)
     tools: dict[str, ToolModel] = Field(default_factory=dict)
     guardrails: dict[str, GuardrailModel] = Field(default_factory=dict)
+    middleware: dict[str, MiddlewareModel] = Field(default_factory=dict)
     memory: Optional[MemoryModel] = None
     prompts: dict[str, PromptModel] = Field(default_factory=dict)
     agents: dict[str, AgentModel] = Field(default_factory=dict)