dao-ai 0.0.25__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +5 -5
  3. dao_ai/cli.py +245 -40
  4. dao_ai/config.py +1863 -338
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -228
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +125 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/core.py +67 -0
  25. dao_ai/middleware/guardrails.py +420 -0
  26. dao_ai/middleware/human_in_the_loop.py +232 -0
  27. dao_ai/middleware/message_validation.py +586 -0
  28. dao_ai/middleware/summarization.py +197 -0
  29. dao_ai/models.py +1306 -114
  30. dao_ai/nodes.py +261 -166
  31. dao_ai/optimization.py +674 -0
  32. dao_ai/orchestration/__init__.py +52 -0
  33. dao_ai/orchestration/core.py +294 -0
  34. dao_ai/orchestration/supervisor.py +278 -0
  35. dao_ai/orchestration/swarm.py +271 -0
  36. dao_ai/prompts.py +128 -31
  37. dao_ai/providers/databricks.py +645 -172
  38. dao_ai/state.py +157 -21
  39. dao_ai/tools/__init__.py +13 -5
  40. dao_ai/tools/agent.py +1 -3
  41. dao_ai/tools/core.py +64 -11
  42. dao_ai/tools/email.py +232 -0
  43. dao_ai/tools/genie.py +144 -295
  44. dao_ai/tools/mcp.py +220 -133
  45. dao_ai/tools/memory.py +50 -0
  46. dao_ai/tools/python.py +9 -14
  47. dao_ai/tools/search.py +14 -0
  48. dao_ai/tools/slack.py +22 -10
  49. dao_ai/tools/sql.py +202 -0
  50. dao_ai/tools/time.py +30 -7
  51. dao_ai/tools/unity_catalog.py +165 -88
  52. dao_ai/tools/vector_search.py +360 -40
  53. dao_ai/utils.py +218 -16
  54. dao_ai-0.1.2.dist-info/METADATA +455 -0
  55. dao_ai-0.1.2.dist-info/RECORD +64 -0
  56. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +1 -1
  57. dao_ai/chat_models.py +0 -204
  58. dao_ai/guardrails.py +0 -112
  59. dao_ai/tools/human_in_the_loop.py +0 -100
  60. dao_ai-0.0.25.dist-info/METADATA +0 -1165
  61. dao_ai-0.0.25.dist-info/RECORD +0 -41
  62. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
  63. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py CHANGED
@@ -12,6 +12,7 @@ from typing import (
    Iterator,
    Literal,
    Optional,
+   Self,
    Sequence,
    TypeAlias,
    Union,
@@ -22,22 +23,33 @@ from databricks.sdk.credentials_provider import (
    CredentialsStrategy,
    ModelServingUserCredentials,
)
+from databricks.sdk.errors.platform import NotFound
from databricks.sdk.service.catalog import FunctionInfo, TableInfo
+from databricks.sdk.service.dashboards import GenieSpace
from databricks.sdk.service.database import DatabaseInstance
+from databricks.sdk.service.sql import GetWarehouseResponse
from databricks.vector_search.client import VectorSearchClient
from databricks.vector_search.index import VectorSearchIndex
from databricks_langchain import (
+    ChatDatabricks,
+    DatabricksEmbeddings,
    DatabricksFunctionClient,
)
+from langchain.agents.structured_output import ProviderStrategy, ToolStrategy
+from langchain_core.embeddings import Embeddings
from langchain_core.language_models import LanguageModelLike
+from langchain_core.messages import BaseMessage, messages_from_dict
from langchain_core.runnables.base import RunnableLike
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.graph.state import CompiledStateGraph
from langgraph.store.base import BaseStore
from loguru import logger
+from mlflow.genai.datasets import EvaluationDataset, create_dataset, get_dataset
+from mlflow.genai.prompts import PromptVersion, load_prompt
from mlflow.models import ModelConfig
from mlflow.models.resources import (
+    DatabricksApp,
    DatabricksFunction,
    DatabricksGenieSpace,
    DatabricksLakebase,
@@ -49,14 +61,20 @@ from mlflow.models.resources import (
    DatabricksVectorSearchIndex,
)
from mlflow.pyfunc import ChatModel, ResponsesAgent
+from mlflow.types.responses import (
+    ResponsesAgentRequest,
+)
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
+    PrivateAttr,
    field_serializer,
    model_validator,
)

+from dao_ai.utils import normalize_name
+

class HasValue(ABC):
    @abstractmethod
@@ -75,27 +93,6 @@ class HasFullName(ABC):
    def full_name(self) -> str: ...


-class IsDatabricksResource(ABC):
-    on_behalf_of_user: Optional[bool] = False
-
-    @abstractmethod
-    def as_resources(self) -> Sequence[DatabricksResource]: ...
-
-    @property
-    @abstractmethod
-    def api_scopes(self) -> Sequence[str]: ...
-
-    @property
-    def workspace_client(self) -> WorkspaceClient:
-        credentials_strategy: CredentialsStrategy = None
-        if self.on_behalf_of_user:
-            credentials_strategy = ModelServingUserCredentials()
-        logger.debug(
-            f"Creating WorkspaceClient with credentials strategy: {credentials_strategy}"
-        )
-        return WorkspaceClient(credentials_strategy=credentials_strategy)
-
-
class EnvironmentVariableModel(BaseModel, HasValue):
    model_config = ConfigDict(
        frozen=True,
@@ -194,6 +191,162 @@ AnyVariable: TypeAlias = (
)


+class ServicePrincipalModel(BaseModel):
+    model_config = ConfigDict(
+        frozen=True,
+        use_enum_values=True,
+    )
+    client_id: AnyVariable
+    client_secret: AnyVariable
+
+
+class IsDatabricksResource(ABC, BaseModel):
+    """
+    Base class for Databricks resources with authentication support.
+
+    Authentication Options:
+    ----------------------
+    1. **On-Behalf-Of User (OBO)**: Set on_behalf_of_user=True to use the
+       calling user's identity via ModelServingUserCredentials.
+
+    2. **Service Principal (OAuth M2M)**: Provide service_principal or
+       (client_id + client_secret + workspace_host) for service principal auth.
+
+    3. **Personal Access Token (PAT)**: Provide pat (and optionally workspace_host)
+       to authenticate with a personal access token.
+
+    4. **Ambient Authentication**: If no credentials are provided, uses SDK defaults
+       (environment variables, notebook context, etc.)
+
+    Authentication Priority:
+    1. OBO (on_behalf_of_user=True)
+    2. Service Principal (client_id + client_secret + workspace_host)
+    3. PAT (pat + workspace_host)
+    4. Ambient/default authentication
+    """
+
+    model_config = ConfigDict(use_enum_values=True)
+
+    on_behalf_of_user: Optional[bool] = False
+    service_principal: Optional[ServicePrincipalModel] = None
+    client_id: Optional[AnyVariable] = None
+    client_secret: Optional[AnyVariable] = None
+    workspace_host: Optional[AnyVariable] = None
+    pat: Optional[AnyVariable] = None
+
+    # Private attribute to cache the workspace client (lazy instantiation)
+    _workspace_client: Optional[WorkspaceClient] = PrivateAttr(default=None)
+
+    @abstractmethod
+    def as_resources(self) -> Sequence[DatabricksResource]: ...
+
+    @property
+    @abstractmethod
+    def api_scopes(self) -> Sequence[str]: ...
+
+    @model_validator(mode="after")
+    def _expand_service_principal(self) -> Self:
+        """Expand service_principal into client_id and client_secret if provided."""
+        if self.service_principal is not None:
+            if self.client_id is None:
+                self.client_id = self.service_principal.client_id
+            if self.client_secret is None:
+                self.client_secret = self.service_principal.client_secret
+        return self
+
+    @model_validator(mode="after")
+    def _validate_auth_not_mixed(self) -> Self:
+        """Validate that OAuth and PAT authentication are not both provided."""
+        has_oauth: bool = self.client_id is not None and self.client_secret is not None
+        has_pat: bool = self.pat is not None
+
+        if has_oauth and has_pat:
+            raise ValueError(
+                "Cannot use both OAuth and PAT authentication methods. "
+                "Please provide either OAuth credentials or a personal access token."
+            )
+        return self
+
+    @property
+    def workspace_client(self) -> WorkspaceClient:
+        """
+        Get a WorkspaceClient configured with the appropriate authentication.
+
+        The client is lazily instantiated on first access and cached for subsequent calls.
+
+        Authentication priority:
+        1. If on_behalf_of_user is True, uses ModelServingUserCredentials (OBO)
+        2. If service principal credentials are configured (client_id, client_secret,
+           workspace_host), uses OAuth M2M
+        3. If PAT is configured, uses token authentication
+        4. Otherwise, uses default/ambient authentication
+        """
+        # Return cached client if already instantiated
+        if self._workspace_client is not None:
+            return self._workspace_client
+
+        from dao_ai.utils import normalize_host
+
+        # Check for OBO first (highest priority)
+        if self.on_behalf_of_user:
+            credentials_strategy: CredentialsStrategy = ModelServingUserCredentials()
+            logger.debug(
+                f"Creating WorkspaceClient for {self.__class__.__name__} "
+                f"with OBO credentials strategy"
+            )
+            self._workspace_client = WorkspaceClient(
+                credentials_strategy=credentials_strategy
+            )
+            return self._workspace_client
+
+        # Check for service principal credentials
+        client_id_value: str | None = (
+            value_of(self.client_id) if self.client_id else None
+        )
+        client_secret_value: str | None = (
+            value_of(self.client_secret) if self.client_secret else None
+        )
+        workspace_host_value: str | None = (
+            normalize_host(value_of(self.workspace_host))
+            if self.workspace_host
+            else None
+        )
+
+        if client_id_value and client_secret_value and workspace_host_value:
+            logger.debug(
+                f"Creating WorkspaceClient for {self.__class__.__name__} with service principal: "
+                f"client_id={client_id_value}, host={workspace_host_value}"
+            )
+            self._workspace_client = WorkspaceClient(
+                host=workspace_host_value,
+                client_id=client_id_value,
+                client_secret=client_secret_value,
+                auth_type="oauth-m2m",
+            )
+            return self._workspace_client
+
+        # Check for PAT authentication
+        pat_value: str | None = value_of(self.pat) if self.pat else None
+        if pat_value:
+            logger.debug(
+                f"Creating WorkspaceClient for {self.__class__.__name__} with PAT"
+            )
+            self._workspace_client = WorkspaceClient(
+                host=workspace_host_value,
+                token=pat_value,
+                auth_type="pat",
+            )
+            return self._workspace_client
+
+        # Default: use ambient authentication
+        logger.debug(
+            f"Creating WorkspaceClient for {self.__class__.__name__} "
+            "with default/ambient authentication"
+        )
+        self._workspace_client = WorkspaceClient()
+        return self._workspace_client
+
+
class Privilege(str, Enum):
    ALL_PRIVILEGES = "ALL_PRIVILEGES"
    USE_CATALOG = "USE_CATALOG"
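The refactored `IsDatabricksResource` is now a Pydantic base model, so every resource in this file (tables, indexes, warehouses, Genie rooms, MCP functions) shares the same credential fields and a lazily cached `WorkspaceClient`. A minimal sketch of the documented priority order, using the `GenieRoomModel` defined later in this diff; all IDs, hosts, and secrets below are placeholders:

```python
# Sketch only: exercises the auth priority documented on IsDatabricksResource.
from dao_ai.config import GenieRoomModel

# 1. OBO wins over everything else (model serving forwards the caller's identity):
obo_room = GenieRoomModel(
    name="sales_genie",
    space_id="01ef000000000000000000000000",
    description="set explicitly so validators skip the Genie lookup",
    on_behalf_of_user=True,
)

# 2. Service principal (OAuth M2M) when client_id + client_secret + workspace_host are set:
sp_room = GenieRoomModel(
    name="sales_genie",
    space_id="01ef000000000000000000000000",
    description="placeholder",
    client_id="my-sp-client-id",
    client_secret="my-sp-secret",
    workspace_host="https://my-workspace.cloud.databricks.com",
)

# 3. PAT (pat + workspace_host) and 4. ambient SDK defaults follow, in that order.
client = sp_room.workspace_client  # built on first access, cached in _workspace_client
```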
@@ -220,9 +373,21 @@ class Privilege(str, Enum):

class PermissionModel(BaseModel):
    model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    principals: list[str] = Field(default_factory=list)
+    principals: list[ServicePrincipalModel | str] = Field(default_factory=list)
    privileges: list[Privilege]

+    @model_validator(mode="after")
+    def resolve_principals(self) -> Self:
+        """Resolve ServicePrincipalModel objects to their client_id."""
+        resolved: list[str] = []
+        for principal in self.principals:
+            if isinstance(principal, ServicePrincipalModel):
+                resolved.append(value_of(principal.client_id))
+            else:
+                resolved.append(principal)
+        self.principals = resolved
+        return self
+

class SchemaModel(BaseModel, HasFullName):
    model_config = ConfigDict(use_enum_values=True, extra="forbid")
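Since `principals` now accepts `ServicePrincipalModel` entries alongside plain strings, a grant can reuse the same service-principal object used for authentication; `resolve_principals` flattens it to its client_id. A small sketch, assuming `value_of` passes plain strings through unchanged:

```python
from dao_ai.config import PermissionModel, Privilege, ServicePrincipalModel

perm = PermissionModel(
    principals=[
        "data-team@example.com",  # plain principal string, kept as-is
        ServicePrincipalModel(client_id="sp-app-1234", client_secret="placeholder"),
    ],
    privileges=[Privilege.USE_CATALOG],
)
# After validation, principals is all strings:
# ["data-team@example.com", "sp-app-1234"]
```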
@@ -242,7 +407,26 @@ class SchemaModel(BaseModel, HasFullName):
        provider.create_schema(self)


-class TableModel(BaseModel, HasFullName, IsDatabricksResource):
+class DatabricksAppModel(IsDatabricksResource, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    url: str
+
+    @property
+    def full_name(self) -> str:
+        return self.name
+
+    @property
+    def api_scopes(self) -> Sequence[str]:
+        return ["apps.apps"]
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksApp(app_name=self.name, on_behalf_of_user=self.on_behalf_of_user)
+        ]
+
+
+class TableModel(IsDatabricksResource, HasFullName):
    model_config = ConfigDict(use_enum_values=True, extra="forbid")
    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
    name: Optional[str] = None
@@ -268,6 +452,22 @@ class TableModel(BaseModel, HasFullName, IsDatabricksResource):
    def api_scopes(self) -> Sequence[str]:
        return []

+    def exists(self) -> bool:
+        """Check if the table exists in Unity Catalog.
+
+        Returns:
+            True if the table exists, False otherwise.
+        """
+        try:
+            self.workspace_client.tables.get(full_name=self.full_name)
+            return True
+        except NotFound:
+            logger.debug(f"Table not found: {self.full_name}")
+            return False
+        except Exception as e:
+            logger.warning(f"Error checking table existence for {self.full_name}: {e}")
+            return False
+
    def as_resources(self) -> Sequence[DatabricksResource]:
        resources: list[DatabricksResource] = []
@@ -311,12 +511,17 @@ class TableModel(BaseModel, HasFullName, IsDatabricksResource):
        return resources


-class LLMModel(BaseModel, IsDatabricksResource):
+class LLMModel(IsDatabricksResource):
    model_config = ConfigDict(use_enum_values=True, extra="forbid")
    name: str
+    description: Optional[str] = None
    temperature: Optional[float] = 0.1
    max_tokens: Optional[int] = 8192
    fallbacks: Optional[list[Union[str, "LLMModel"]]] = Field(default_factory=list)
+    use_responses_api: Optional[bool] = Field(
+        default=False,
+        description="Use Responses API for ResponsesAgent endpoints",
+    )

    @property
    def api_scopes(self) -> Sequence[str]:
@@ -324,6 +529,10 @@ class LLMModel(BaseModel, IsDatabricksResource):
            "serving.serving-endpoints",
        ]

+    @property
+    def uri(self) -> str:
+        return f"databricks:/{self.name}"
+
    def as_resources(self) -> Sequence[DatabricksResource]:
        return [
            DatabricksServingEndpoint(
@@ -332,19 +541,12 @@
            )
        ]

    def as_chat_model(self) -> LanguageModelLike:
-        # Retrieve langchain chat client from workspace client to enable OBO
-        # ChatOpenAI does not allow additional inputs at the moment, so we cannot use it directly
-        # chat_client: LanguageModelLike = self.as_open_ai_client()
-
-        # Create ChatDatabricksWrapper instance directly
-        from dao_ai.chat_models import ChatDatabricksFiltered
-
-        chat_client: LanguageModelLike = ChatDatabricksFiltered(
-            model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
+        chat_client: LanguageModelLike = ChatDatabricks(
+            model=self.name,
+            temperature=self.temperature,
+            max_tokens=self.max_tokens,
+            use_responses_api=self.use_responses_api,
        )
-        # chat_client: LanguageModelLike = ChatDatabricks(
-        #     model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
-        # )

        fallbacks: Sequence[LanguageModelLike] = []
        for fallback in self.fallbacks:
@@ -376,6 +578,9 @@

        return chat_client

+    def as_embeddings_model(self) -> Embeddings:
+        return DatabricksEmbeddings(endpoint=self.name)
+

class VectorSearchEndpointType(str, Enum):
    STANDARD = "STANDARD"
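Together these hunks finish the `LLMModel` migration: the `ChatDatabricksFiltered` wrapper (and the deleted `dao_ai/chat_models.py`, per the file list) gives way to a plain `ChatDatabricks`, with the new `use_responses_api` flag threaded through and embeddings exposed alongside chat. A configuration sketch with placeholder endpoint names:

```python
from dao_ai.config import LLMModel

llm = LLMModel(
    name="databricks-claude-sonnet-4",  # placeholder serving endpoint
    temperature=0.0,
    max_tokens=4096,
    use_responses_api=True,  # new in 0.1.2
    fallbacks=["databricks-meta-llama-3-3-70b-instruct"],  # placeholder fallback
)

chat = llm.as_chat_model()              # ChatDatabricks with fallbacks attached
embeddings = llm.as_embeddings_model()  # DatabricksEmbeddings(endpoint=llm.name)
print(llm.uri)                          # "databricks:/databricks-claude-sonnet-4"
```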
@@ -387,8 +592,15 @@ class VectorSearchEndpoint(BaseModel):
    name: str
    type: VectorSearchEndpointType = VectorSearchEndpointType.STANDARD

+    @field_serializer("type")
+    def serialize_type(self, value: VectorSearchEndpointType) -> str:
+        """Ensure enum is serialized to string value."""
+        if isinstance(value, VectorSearchEndpointType):
+            return value.value
+        return str(value)
+

-class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
+class IndexModel(IsDatabricksResource, HasFullName):
    model_config = ConfigDict(use_enum_values=True, extra="forbid")
    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
    name: str
@@ -413,12 +625,297 @@ class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
    ]


-class GenieRoomModel(BaseModel, IsDatabricksResource):
+class FunctionModel(IsDatabricksResource, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
+    name: Optional[str] = None
+
+    @model_validator(mode="after")
+    def validate_name_or_schema_required(self) -> Self:
+        if not self.name and not self.schema_model:
+            raise ValueError(
+                "Either 'name' or 'schema_model' must be provided for FunctionModel"
+            )
+        return self
+
+    @property
+    def full_name(self) -> str:
+        if self.schema_model:
+            name: str = ""
+            if self.name:
+                name = f".{self.name}"
+            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
+        return self.name
+
+    def exists(self) -> bool:
+        """Check if the function exists in Unity Catalog.
+
+        Returns:
+            True if the function exists, False otherwise.
+        """
+        try:
+            self.workspace_client.functions.get(name=self.full_name)
+            return True
+        except NotFound:
+            logger.debug(f"Function not found: {self.full_name}")
+            return False
+        except Exception as e:
+            logger.warning(
+                f"Error checking function existence for {self.full_name}: {e}"
+            )
+            return False
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        resources: list[DatabricksResource] = []
+        if self.name:
+            resources.append(
+                DatabricksFunction(
+                    function_name=self.full_name,
+                    on_behalf_of_user=self.on_behalf_of_user,
+                )
+            )
+        else:
+            w: WorkspaceClient = self.workspace_client
+            schema_full_name: str = self.schema_model.full_name
+            functions: Iterator[FunctionInfo] = w.functions.list(
+                catalog_name=self.schema_model.catalog_name,
+                schema_name=self.schema_model.schema_name,
+            )
+            resources.extend(
+                [
+                    DatabricksFunction(
+                        function_name=f"{schema_full_name}.{function.name}",
+                        on_behalf_of_user=self.on_behalf_of_user,
+                    )
+                    for function in functions
+                ]
+            )
+
+        return resources
+
+    @property
+    def api_scopes(self) -> Sequence[str]:
+        return ["sql.statement-execution"]
+
+
+class WarehouseModel(IsDatabricksResource):
+    model_config = ConfigDict()
+    name: str
+    description: Optional[str] = None
+    warehouse_id: AnyVariable
+
+    @property
+    def api_scopes(self) -> Sequence[str]:
+        return [
+            "sql.warehouses",
+            "sql.statement-execution",
+        ]
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksSQLWarehouse(
+                warehouse_id=value_of(self.warehouse_id),
+                on_behalf_of_user=self.on_behalf_of_user,
+            )
+        ]
+
+    @model_validator(mode="after")
+    def update_warehouse_id(self) -> Self:
+        self.warehouse_id = value_of(self.warehouse_id)
+        return self
+
+
+class GenieRoomModel(IsDatabricksResource):
    model_config = ConfigDict(use_enum_values=True, extra="forbid")
    name: str
    description: Optional[str] = None
    space_id: AnyVariable

+    _space_details: Optional[GenieSpace] = PrivateAttr(default=None)
+
+    def _get_space_details(self) -> GenieSpace:
+        if self._space_details is None:
+            self._space_details = self.workspace_client.genie.get_space(
+                space_id=self.space_id, include_serialized_space=True
+            )
+        return self._space_details
+
+    def _parse_serialized_space(self) -> dict[str, Any]:
+        """Parse the serialized_space JSON string and return the parsed data."""
+        import json
+
+        space_details = self._get_space_details()
+        if not space_details.serialized_space:
+            return {}
+
+        try:
+            return json.loads(space_details.serialized_space)
+        except json.JSONDecodeError as e:
+            logger.warning(f"Failed to parse serialized_space: {e}")
+            return {}
+
+    @property
+    def warehouse(self) -> Optional[WarehouseModel]:
+        """Extract warehouse information from the Genie space.
+
+        Returns:
+            WarehouseModel instance if warehouse_id is available, None otherwise.
+        """
+        space_details: GenieSpace = self._get_space_details()
+
+        if not space_details.warehouse_id:
+            return None
+
+        try:
+            response: GetWarehouseResponse = self.workspace_client.warehouses.get(
+                space_details.warehouse_id
+            )
+            warehouse_name: str = response.name or space_details.warehouse_id
+
+            warehouse_model = WarehouseModel(
+                name=warehouse_name,
+                warehouse_id=space_details.warehouse_id,
+                on_behalf_of_user=self.on_behalf_of_user,
+                service_principal=self.service_principal,
+                client_id=self.client_id,
+                client_secret=self.client_secret,
+                workspace_host=self.workspace_host,
+                pat=self.pat,
+            )
+
+            # Share the cached workspace client if available
+            if self._workspace_client is not None:
+                warehouse_model._workspace_client = self._workspace_client
+
+            return warehouse_model
+        except Exception as e:
+            logger.warning(
+                f"Failed to fetch warehouse details for {space_details.warehouse_id}: {e}"
+            )
+            return None
+
+    @property
+    def tables(self) -> list[TableModel]:
+        """Extract tables from the serialized Genie space.
+
+        Databricks Genie stores tables in: data_sources.tables[].identifier
+        Only includes tables that actually exist in Unity Catalog.
+        """
+        parsed_space = self._parse_serialized_space()
+        tables_list: list[TableModel] = []
+
+        # Primary structure: data_sources.tables with 'identifier' field
+        if "data_sources" in parsed_space:
+            data_sources = parsed_space["data_sources"]
+            if isinstance(data_sources, dict) and "tables" in data_sources:
+                tables_data = data_sources["tables"]
+                if isinstance(tables_data, list):
+                    for table_item in tables_data:
+                        table_name: str | None = None
+                        if isinstance(table_item, dict):
+                            # Standard Databricks structure uses 'identifier'
+                            table_name = table_item.get("identifier") or table_item.get(
+                                "name"
+                            )
+                        elif isinstance(table_item, str):
+                            table_name = table_item
+
+                        if table_name:
+                            table_model = TableModel(
+                                name=table_name,
+                                on_behalf_of_user=self.on_behalf_of_user,
+                                service_principal=self.service_principal,
+                                client_id=self.client_id,
+                                client_secret=self.client_secret,
+                                workspace_host=self.workspace_host,
+                                pat=self.pat,
+                            )
+                            # Share the cached workspace client if available
+                            if self._workspace_client is not None:
+                                table_model._workspace_client = self._workspace_client
+
+                            # Verify the table exists before adding
+                            if not table_model.exists():
+                                continue
+
+                            tables_list.append(table_model)
+
+        return tables_list
+
+    @property
+    def functions(self) -> list[FunctionModel]:
+        """Extract functions from the serialized Genie space.
+
+        Databricks Genie stores functions in multiple locations:
+        - instructions.sql_functions[].identifier (SQL functions)
+        - data_sources.functions[].identifier (other functions)
+        Only includes functions that actually exist in Unity Catalog.
+        """
+        parsed_space = self._parse_serialized_space()
+        functions_list: list[FunctionModel] = []
+        seen_functions: set[str] = set()
+
+        def add_function_if_exists(function_name: str) -> None:
+            """Helper to add a function if it exists and hasn't been added."""
+            if function_name in seen_functions:
+                return
+
+            seen_functions.add(function_name)
+            function_model = FunctionModel(
+                name=function_name,
+                on_behalf_of_user=self.on_behalf_of_user,
+                service_principal=self.service_principal,
+                client_id=self.client_id,
+                client_secret=self.client_secret,
+                workspace_host=self.workspace_host,
+                pat=self.pat,
+            )
+            # Share the cached workspace client if available
+            if self._workspace_client is not None:
+                function_model._workspace_client = self._workspace_client
+
+            # Verify the function exists before adding
+            if not function_model.exists():
+                return
+
+            functions_list.append(function_model)
+
+        # Primary structure: instructions.sql_functions with 'identifier' field
+        if "instructions" in parsed_space:
+            instructions = parsed_space["instructions"]
+            if isinstance(instructions, dict) and "sql_functions" in instructions:
+                sql_functions_data = instructions["sql_functions"]
+                if isinstance(sql_functions_data, list):
+                    for function_item in sql_functions_data:
+                        if isinstance(function_item, dict):
+                            # SQL functions use 'identifier' field
+                            function_name = function_item.get(
+                                "identifier"
+                            ) or function_item.get("name")
+                            if function_name:
+                                add_function_if_exists(function_name)
+
+        # Secondary structure: data_sources.functions with 'identifier' field
+        if "data_sources" in parsed_space:
+            data_sources = parsed_space["data_sources"]
+            if isinstance(data_sources, dict) and "functions" in data_sources:
+                functions_data = data_sources["functions"]
+                if isinstance(functions_data, list):
+                    for function_item in functions_data:
+                        function_name: str | None = None
+                        if isinstance(function_item, dict):
+                            # Standard Databricks structure uses 'identifier'
+                            function_name = function_item.get(
+                                "identifier"
+                            ) or function_item.get("name")
+                        elif isinstance(function_item, str):
+                            function_name = function_item
+
+                        if function_name:
+                            add_function_if_exists(function_name)
+
+        return functions_list
+
    @property
    def api_scopes(self) -> Sequence[str]:
        return [
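The new introspection walks the space's `serialized_space` JSON (`data_sources.tables[].identifier`, `instructions.sql_functions[].identifier`, `data_sources.functions[].identifier`), verifies each object via `exists()`, and shares one cached `WorkspaceClient` across everything it derives. A sketch of the intended use; it needs live workspace access and a real space_id:

```python
from dao_ai.config import GenieRoomModel

room = GenieRoomModel(name="sales", space_id="01ef000000000000000000000000")

warehouse = room.warehouse   # WarehouseModel behind the space, or None
tables = room.tables         # only tables that still exist in Unity Catalog
functions = room.functions   # SQL + data-source functions, deduplicated

for table in tables:
    print(table.full_name)   # e.g. "main.retail.sales_orders"
```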
@@ -434,12 +931,24 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
        ]

    @model_validator(mode="after")
-    def update_space_id(self):
+    def update_space_id(self) -> Self:
        self.space_id = value_of(self.space_id)
        return self

+    @model_validator(mode="after")
+    def update_description_from_space(self) -> Self:
+        """Populate description from GenieSpace if not provided."""
+        if not self.description:
+            try:
+                space_details = self._get_space_details()
+                if space_details.description:
+                    self.description = space_details.description
+            except Exception as e:
+                logger.debug(f"Could not fetch description from Genie space: {e}")
+        return self

-class VolumeModel(BaseModel, HasFullName, IsDatabricksResource):
+
+class VolumeModel(IsDatabricksResource, HasFullName):
    model_config = ConfigDict(use_enum_values=True, extra="forbid")
    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
    name: str
@@ -499,7 +1008,7 @@ class VolumePathModel(BaseModel, HasFullName):
        provider.create_path(self)


-class VectorStoreModel(BaseModel, IsDatabricksResource):
+class VectorStoreModel(IsDatabricksResource):
    model_config = ConfigDict(use_enum_values=True, extra="forbid")
    embedding_model: Optional[LLMModel] = None
    index: Optional[IndexModel] = None
@@ -513,13 +1022,13 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
    embedding_source_column: str

    @model_validator(mode="after")
-    def set_default_embedding_model(self):
+    def set_default_embedding_model(self) -> Self:
        if not self.embedding_model:
            self.embedding_model = LLMModel(name="databricks-gte-large-en")
        return self

    @model_validator(mode="after")
-    def set_default_primary_key(self):
+    def set_default_primary_key(self) -> Self:
        if self.primary_key is None:
            from dao_ai.providers.databricks import DatabricksProvider
@@ -540,14 +1049,14 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
        return self

    @model_validator(mode="after")
-    def set_default_index(self):
+    def set_default_index(self) -> Self:
        if self.index is None:
            name: str = f"{self.source_table.name}_index"
            self.index = IndexModel(schema=self.source_table.schema_model, name=name)
        return self

    @model_validator(mode="after")
-    def set_default_endpoint(self):
+    def set_default_endpoint(self) -> Self:
        if self.endpoint is None:
            from dao_ai.providers.databricks import (
                DatabricksProvider,
@@ -598,64 +1107,9 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
        provider.create_vector_store(self)


-class FunctionModel(BaseModel, HasFullName, IsDatabricksResource):
+class ConnectionModel(IsDatabricksResource, HasFullName):
    model_config = ConfigDict()
-    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
-    name: Optional[str] = None
-
-    @model_validator(mode="after")
-    def validate_name_or_schema_required(self) -> "FunctionModel":
-        if not self.name and not self.schema_model:
-            raise ValueError(
-                "Either 'name' or 'schema_model' must be provided for FunctionModel"
-            )
-        return self
-
-    @property
-    def full_name(self) -> str:
-        if self.schema_model:
-            name: str = ""
-            if self.name:
-                name = f".{self.name}"
-            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
-        return self.name
-
-    def as_resources(self) -> Sequence[DatabricksResource]:
-        resources: list[DatabricksResource] = []
-        if self.name:
-            resources.append(
-                DatabricksFunction(
-                    function_name=self.full_name,
-                    on_behalf_of_user=self.on_behalf_of_user,
-                )
-            )
-        else:
-            w: WorkspaceClient = self.workspace_client
-            schema_full_name: str = self.schema_model.full_name
-            functions: Iterator[FunctionInfo] = w.functions.list(
-                catalog_name=self.schema_model.catalog_name,
-                schema_name=self.schema_model.schema_name,
-            )
-            resources.extend(
-                [
-                    DatabricksFunction(
-                        function_name=f"{schema_full_name}.{function.name}",
-                        on_behalf_of_user=self.on_behalf_of_user,
-                    )
-                    for function in functions
-                ]
-            )
-
-        return resources
-
-    @property
-    def api_scopes(self) -> Sequence[str]:
-        return ["sql.statement-execution"]
-
-
-class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
-    model_config = ConfigDict()
-    name: str
+    name: str

    @property
    def full_name(self) -> str:
@@ -680,34 +1134,58 @@ class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
        ]


-class WarehouseModel(BaseModel, IsDatabricksResource):
-    model_config = ConfigDict()
-    name: str
-    description: Optional[str] = None
-    warehouse_id: AnyVariable
-
-    @property
-    def api_scopes(self) -> Sequence[str]:
-        return [
-            "sql.warehouses",
-            "sql.statement-execution",
-        ]
-
-    def as_resources(self) -> Sequence[DatabricksResource]:
-        return [
-            DatabricksSQLWarehouse(
-                warehouse_id=value_of(self.warehouse_id),
-                on_behalf_of_user=self.on_behalf_of_user,
-            )
-        ]
-
-    @model_validator(mode="after")
-    def update_warehouse_id(self):
-        self.warehouse_id = value_of(self.warehouse_id)
-        return self
-
+class DatabaseModel(IsDatabricksResource):
+    """
+    Configuration for database connections supporting both Databricks Lakebase and standard PostgreSQL.
+
+    Authentication is inherited from IsDatabricksResource. Additionally supports:
+    - user/password: For user-based database authentication
+
+    Connection Types (determined by fields provided):
+    - Databricks Lakebase: Provide `instance_name` (authentication optional, supports ambient auth)
+    - Standard PostgreSQL: Provide `host` (authentication required via user/password)
+
+    Note: `instance_name` and `host` are mutually exclusive. Provide one or the other.
+
+    Example Databricks Lakebase with Service Principal:
+    ```yaml
+    databases:
+      my_lakebase:
+        name: my-database
+        instance_name: my-lakebase-instance
+        service_principal:
+          client_id:
+            env: SERVICE_PRINCIPAL_CLIENT_ID
+          client_secret:
+            scope: my-scope
+            secret: sp-client-secret
+        workspace_host:
+          env: DATABRICKS_HOST
+    ```
+
+    Example Databricks Lakebase with Ambient Authentication:
+    ```yaml
+    databases:
+      my_lakebase:
+        name: my-database
+        instance_name: my-lakebase-instance
+        on_behalf_of_user: true
+    ```
+
+    Example Standard PostgreSQL:
+    ```yaml
+    databases:
+      my_postgres:
+        name: my-database
+        host: my-postgres-host.example.com
+        port: 5432
+        database: my_db
+        user: my_user
+        password:
+          env: PGPASSWORD
+    ```
+    """

-class DatabaseModel(BaseModel, IsDatabricksResource):
    model_config = ConfigDict(use_enum_values=True, extra="forbid")
    name: str
    instance_name: Optional[str] = None
@@ -720,80 +1198,137 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
    timeout_seconds: Optional[int] = 10
    capacity: Optional[Literal["CU_1", "CU_2"]] = "CU_2"
    node_count: Optional[int] = None
+    # Database-specific auth (user identity for DB connection)
    user: Optional[AnyVariable] = None
    password: Optional[AnyVariable] = None
-    client_id: Optional[AnyVariable] = None
-    client_secret: Optional[AnyVariable] = None
-    workspace_host: Optional[AnyVariable] = None

    @property
    def api_scopes(self) -> Sequence[str]:
-        return []
+        return ["database.database-instances"]
+
+    @property
+    def is_lakebase(self) -> bool:
+        """Returns True if this is a Databricks Lakebase connection (instance_name provided)."""
+        return self.instance_name is not None

    def as_resources(self) -> Sequence[DatabricksResource]:
-        return [
-            DatabricksLakebase(
-                database_instance_name=self.instance_name,
-                on_behalf_of_user=self.on_behalf_of_user,
-            )
-        ]
+        if self.is_lakebase:
+            return [
+                DatabricksLakebase(
+                    database_instance_name=self.instance_name,
+                    on_behalf_of_user=self.on_behalf_of_user,
+                )
+            ]
+        return []

    @model_validator(mode="after")
-    def update_instance_name(self):
-        if self.instance_name is None:
-            self.instance_name = self.name
+    def validate_connection_type(self) -> Self:
+        """Validate connection configuration based on type.

+        - If instance_name is provided: Databricks Lakebase connection
+          (host is optional - will be fetched from API if not provided)
+        - If only host is provided: Standard PostgreSQL connection
+          (must not have instance_name)
+        """
+        if not self.instance_name and not self.host:
+            raise ValueError(
+                "Either instance_name (Databricks Lakebase) or host (PostgreSQL) must be provided."
+            )
        return self

    @model_validator(mode="after")
-    def update_user(self):
-        if self.client_id or self.user:
+    def update_user(self) -> Self:
+        # Skip if using OBO (passive auth), explicit credentials, or explicit user
+        if self.on_behalf_of_user or self.client_id or self.user or self.pat:
            return self

-        self.user = self.workspace_client.current_user.me().user_name
-        if not self.user:
-            raise ValueError(
-                "Unable to determine current user. Please provide a user name or OAuth credentials."
-            )
+        # For standard PostgreSQL, we need explicit user credentials
+        # For Lakebase with no auth, ambient auth is allowed
+        if not self.is_lakebase:
+            # Standard PostgreSQL - try to determine current user for local development
+            try:
+                self.user = self.workspace_client.current_user.me().user_name
+            except Exception as e:
+                logger.warning(
+                    f"Could not determine current user for PostgreSQL database: {e}. "
+                    f"Please provide explicit user credentials."
+                )
+        else:
+            # For Lakebase, try to determine current user but don't fail if we can't
+            try:
+                self.user = self.workspace_client.current_user.me().user_name
+            except Exception:
+                # If we can't determine user and no explicit auth, that's okay
+                # for Lakebase with ambient auth - credentials will be injected at runtime
+                pass

        return self

    @model_validator(mode="after")
-    def update_host(self):
+    def update_host(self) -> Self:
        if self.host is not None:
            return self

-        existing_instance: DatabaseInstance = (
-            self.workspace_client.database.get_database_instance(
-                name=self.instance_name
-            )
-        )
-        self.host = existing_instance.read_write_dns
+        # If instance_name is provided (Lakebase), try to fetch host from existing instance
+        # This may fail for OBO/ambient auth during model logging (before deployment)
+        if self.is_lakebase:
+            try:
+                existing_instance: DatabaseInstance = (
+                    self.workspace_client.database.get_database_instance(
+                        name=self.instance_name
+                    )
+                )
+                self.host = existing_instance.read_write_dns
+            except Exception as e:
+                # For Lakebase with OBO/ambient auth, we can't fetch at config time
+                # The host will need to be provided explicitly or fetched at runtime
+                if self.on_behalf_of_user:
+                    logger.debug(
+                        f"Could not fetch host for database {self.instance_name} "
+                        f"(Lakebase with OBO mode - will be resolved at runtime): {e}"
+                    )
+                else:
+                    raise ValueError(
+                        f"Could not fetch host for database {self.instance_name}. "
+                        f"Please provide the 'host' explicitly or ensure the instance exists: {e}"
+                    )
        return self

    @model_validator(mode="after")
-    def validate_auth_methods(self):
+    def validate_auth_methods(self) -> Self:
        oauth_fields: Sequence[Any] = [
            self.workspace_host,
            self.client_id,
            self.client_secret,
        ]
        has_oauth: bool = all(field is not None for field in oauth_fields)
+        has_user_auth: bool = self.user is not None
+        has_obo: bool = self.on_behalf_of_user is True
+        has_pat: bool = self.pat is not None

-        pat_fields: Sequence[Any] = [self.user]
-        has_user_auth: bool = all(field is not None for field in pat_fields)
+        # Count how many auth methods are configured
+        auth_methods_count: int = sum([has_oauth, has_user_auth, has_obo, has_pat])

-        if has_oauth and has_user_auth:
+        if auth_methods_count > 1:
            raise ValueError(
-                "Cannot use both OAuth and user authentication methods. "
-                "Please provide either OAuth credentials or user credentials."
+                "Cannot mix authentication methods. "
+                "Please provide exactly one of: "
+                "on_behalf_of_user=true (for passive auth in model serving), "
+                "OAuth credentials (service_principal or client_id + client_secret + workspace_host), "
+                "PAT (personal access token), "
+                "or user credentials (user)."
            )

-        if not has_oauth and not has_user_auth:
+        # For standard PostgreSQL (host-based), at least one auth method must be configured
+        # For Lakebase (instance_name-based), auth is optional (supports ambient authentication)
+        if not self.is_lakebase and auth_methods_count == 0:
            raise ValueError(
-                "At least one authentication method must be provided: "
-                "either OAuth credentials (workspace_host, client_id, client_secret) "
-                "or user credentials (user, password)."
+                "PostgreSQL databases require explicit authentication. "
+                "Please provide one of: "
+                "OAuth credentials (workspace_host, client_id, client_secret), "
+                "service_principal with workspace_host, "
+                "PAT (personal access token), "
+                "or user credentials (user)."
            )

        return self
@@ -804,38 +1339,76 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
        Get database connection parameters as a dictionary.

        Returns a dict with connection parameters suitable for psycopg ConnectionPool.
-        If username is configured, it will be included; otherwise it will be omitted
-        to allow Lakebase to authenticate using the token's identity.
+
+        For Lakebase: Uses Databricks-generated credentials (token-based auth).
+        For standard PostgreSQL: Uses provided user/password credentials.
        """
-        from dao_ai.providers.base import ServiceProvider
-        from dao_ai.providers.databricks import DatabricksProvider
+        import uuid as _uuid

+        from databricks.sdk.service.database import DatabaseCredential
+
+        host: str
+        port: int
+        database: str
        username: str | None = None
+        password_value: str | None = None
+
+        # Resolve host - may need to fetch at runtime for OBO mode
+        host_value: Any = self.host
+        if host_value is None and self.is_lakebase and self.on_behalf_of_user:
+            # Fetch host at runtime for OBO mode
+            existing_instance: DatabaseInstance = (
+                self.workspace_client.database.get_database_instance(
+                    name=self.instance_name
+                )
+            )
+            host_value = existing_instance.read_write_dns

-        if self.client_id and self.client_secret and self.workspace_host:
-            username = value_of(self.client_id)
-        elif self.user:
-            username = value_of(self.user)
+        if host_value is None:
+            instance_or_name = self.instance_name if self.is_lakebase else self.name
+            raise ValueError(
+                f"Database host not configured for {instance_or_name}. "
+                "Please provide 'host' explicitly."
+            )

-        host: str = value_of(self.host)
-        port: int = value_of(self.port)
-        database: str = value_of(self.database)
+        host = value_of(host_value)
+        port = value_of(self.port)
+        database = value_of(self.database)

-        provider: ServiceProvider = DatabricksProvider(
-            client_id=value_of(self.client_id),
-            client_secret=value_of(self.client_secret),
-            workspace_host=value_of(self.workspace_host),
-            pat=value_of(self.password),
-        )
+        if self.is_lakebase:
+            # Lakebase: Use Databricks-generated credentials
+            if self.client_id and self.client_secret and self.workspace_host:
+                username = value_of(self.client_id)
+            elif self.user:
+                username = value_of(self.user)
+            # For OBO mode, no username is needed - the token identity is used

-        token: str = provider.lakebase_password_provider(self.instance_name)
+            # Generate Databricks database credential (token)
+            w: WorkspaceClient = self.workspace_client
+            cred: DatabaseCredential = w.database.generate_database_credential(
+                request_id=str(_uuid.uuid4()),
+                instance_names=[self.instance_name],
+            )
+            password_value = cred.token
+        else:
+            # Standard PostgreSQL: Use provided credentials
+            if self.user:
+                username = value_of(self.user)
+            if self.password:
+                password_value = value_of(self.password)
+
+            if not username or not password_value:
+                raise ValueError(
+                    f"Standard PostgreSQL databases require both 'user' and 'password'. "
+                    f"Database: {self.name}"
+                )

        # Build connection parameters dictionary
        params: dict[str, Any] = {
            "dbname": database,
            "host": host,
            "port": port,
-            "password": token,
+            "password": password_value,
            "sslmode": "require",
        }
@@ -866,11 +1439,86 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
    def create(self, w: WorkspaceClient | None = None) -> None:
        from dao_ai.providers.databricks import DatabricksProvider

-        provider: DatabricksProvider = DatabricksProvider()
+        # Use provided workspace client or fall back to resource's own workspace_client
+        if w is None:
+            w = self.workspace_client
+        provider: DatabricksProvider = DatabricksProvider(w=w)
        provider.create_lakebase(self)
        provider.create_lakebase_instance_role(self)


+class GenieLRUCacheParametersModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    capacity: int = 1000
+    time_to_live_seconds: int | None = (
+        60 * 60 * 24
+    )  # 1 day default, None or negative = never expires
+    warehouse: WarehouseModel
+
+
+class GenieSemanticCacheParametersModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    time_to_live_seconds: int | None = (
+        60 * 60 * 24
+    )  # 1 day default, None or negative = never expires
+    similarity_threshold: float = 0.85  # Minimum similarity for question matching (L2 distance converted to 0-1 scale)
+    context_similarity_threshold: float = 0.80  # Minimum similarity for context matching (L2 distance converted to 0-1 scale)
+    question_weight: Optional[float] = (
+        0.6  # Weight for question similarity in combined score (0-1). If not provided, computed as 1 - context_weight
+    )
+    context_weight: Optional[float] = (
+        None  # Weight for context similarity in combined score (0-1). If not provided, computed as 1 - question_weight
+    )
+    embedding_model: str | LLMModel = "databricks-gte-large-en"
+    embedding_dims: int | None = None  # Auto-detected if None
+    database: DatabaseModel
+    warehouse: WarehouseModel
+    table_name: str = "genie_semantic_cache"
+    context_window_size: int = 3  # Number of previous turns to include for context
+    max_context_tokens: int = (
+        2000  # Maximum context length to prevent extremely long embeddings
+    )
+
+    @model_validator(mode="after")
+    def compute_and_validate_weights(self) -> Self:
+        """
+        Compute missing weight and validate that question_weight + context_weight = 1.0.
+
+        Either question_weight or context_weight (or both) can be provided.
+        The missing one will be computed as 1.0 - provided_weight.
+        If both are provided, they must sum to 1.0.
+        """
+        if self.question_weight is None and self.context_weight is None:
+            # Both missing - use defaults
+            self.question_weight = 0.6
+            self.context_weight = 0.4
+        elif self.question_weight is None:
+            # Compute question_weight from context_weight
+            if not (0.0 <= self.context_weight <= 1.0):
+                raise ValueError(
+                    f"context_weight must be between 0.0 and 1.0, got {self.context_weight}"
+                )
+            self.question_weight = 1.0 - self.context_weight
+        elif self.context_weight is None:
+            # Compute context_weight from question_weight
+            if not (0.0 <= self.question_weight <= 1.0):
+                raise ValueError(
+                    f"question_weight must be between 0.0 and 1.0, got {self.question_weight}"
+                )
+            self.context_weight = 1.0 - self.question_weight
+        else:
+            # Both provided - validate they sum to 1.0
+            total_weight = self.question_weight + self.context_weight
+            if not abs(total_weight - 1.0) < 0.0001:  # Allow small floating point error
+                raise ValueError(
+                    f"question_weight ({self.question_weight}) + context_weight ({self.context_weight}) "
+                    f"must equal 1.0 (got {total_weight}). These weights determine the relative importance "
+                    f"of question vs context similarity in the combined score."
+                )
+
+        return self
+
+
class SearchParametersModel(BaseModel):
    model_config = ConfigDict(use_enum_values=True, extra="forbid")
    num_results: Optional[int] = 10
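The weight contract makes the cache's match score a convex blend of question and context similarity: supply either weight and the other becomes `1 - w`. A worked sketch of the implied arithmetic (the actual blending lives in `dao_ai/genie/cache/semantic.py`, which this release adds but this diff does not show):

```python
# Assumed scoring form implied by the field comments above; the real
# implementation in dao_ai/genie/cache/semantic.py may differ in detail.
def combined_score(
    question_sim: float, context_sim: float, question_weight: float = 0.6
) -> float:
    context_weight = 1.0 - question_weight  # mirrors compute_and_validate_weights
    return question_weight * question_sim + context_weight * context_sim

# A cached entry with question similarity 0.92 and context similarity 0.75:
score = combined_score(0.92, 0.75)  # 0.6*0.92 + 0.4*0.75 = 0.852
# With similarity_threshold=0.85, the question similarity alone clears the bar (0.92 >= 0.85).
```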
@@ -878,6 +1526,55 @@
    query_type: Optional[str] = "ANN"


+class RerankParametersModel(BaseModel):
+    """
+    Configuration for reranking retrieved documents using FlashRank.
+
+    FlashRank provides fast, local reranking without API calls using lightweight
+    cross-encoder models. Reranking improves retrieval quality by reordering results
+    based on semantic relevance to the query.
+
+    Typical workflow:
+    1. Retrieve more documents than needed (e.g., 50 via num_results)
+    2. Rerank all retrieved documents
+    3. Return top_n best matches (e.g., 5)
+
+    Example:
+    ```yaml
+    retriever:
+      search_parameters:
+        num_results: 50  # Retrieve more candidates
+      rerank:
+        model: ms-marco-MiniLM-L-12-v2
+        top_n: 5  # Return top 5 after reranking
+    ```
+
+    Available models (from fastest to most accurate):
+    - "ms-marco-TinyBERT-L-2-v2" (fastest, smallest)
+    - "ms-marco-MiniLM-L-6-v2"
+    - "ms-marco-MiniLM-L-12-v2" (default, good balance)
+    - "rank-T5-flan" (most accurate, slower)
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+
+    model: str = Field(
+        default="ms-marco-MiniLM-L-12-v2",
+        description="FlashRank model name. Default provides good balance of speed and accuracy.",
+    )
+    top_n: Optional[int] = Field(
+        default=None,
+        description="Number of documents to return after reranking. If None, uses search_parameters.num_results.",
+    )
+    cache_dir: Optional[str] = Field(
+        default="~/.dao_ai/cache/flashrank",
+        description="Directory to cache downloaded model weights. Supports tilde expansion (e.g., ~/.dao_ai).",
+    )
+    columns: Optional[list[str]] = Field(
+        default_factory=list, description="Columns to rerank using DatabricksReranker"
+    )
+
+
class RetrieverModel(BaseModel):
    model_config = ConfigDict(use_enum_values=True, extra="forbid")
    vector_store: VectorStoreModel
@@ -885,14 +1582,25 @@
    search_parameters: SearchParametersModel = Field(
        default_factory=SearchParametersModel
    )
+    rerank: Optional[RerankParametersModel | bool] = Field(
+        default=None,
+        description="Optional reranking configuration. Set to true for defaults, or provide RerankParametersModel for custom settings.",
+    )

    @model_validator(mode="after")
-    def set_default_columns(self):
+    def set_default_columns(self) -> Self:
        if not self.columns:
            columns: Sequence[str] = self.vector_store.columns
            self.columns = columns
        return self

+    @model_validator(mode="after")
+    def set_default_reranker(self) -> Self:
+        """Convert bool to RerankParametersModel with defaults."""
+        if isinstance(self.rerank, bool) and self.rerank:
+            self.rerank = RerankParametersModel()
+        return self
+

class FunctionType(str, Enum):
    PYTHON = "python"
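Setting `rerank: true` in YAML is shorthand: `set_default_reranker` swaps the bool for a default-configured `RerankParametersModel`. The equivalent in code:

```python
from dao_ai.config import RerankParametersModel

defaults = RerankParametersModel()
print(defaults.model)   # "ms-marco-MiniLM-L-12-v2"
print(defaults.top_n)   # None -> falls back to search_parameters.num_results

custom = RerankParametersModel(model="rank-T5-flan", top_n=5)  # accuracy over speed
```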
@@ -901,28 +1609,47 @@ class FunctionType(str, Enum):
    MCP = "mcp"


-class HumanInTheLoopActionType(str, Enum):
-    """Supported action types for human-in-the-loop interactions."""
+class HumanInTheLoopModel(BaseModel):
+    """
+    Configuration for Human-in-the-Loop tool approval.

-    ACCEPT = "accept"
-    EDIT = "edit"
-    RESPONSE = "response"
-    DECLINE = "decline"
+    This model configures when and how tools require human approval before execution.
+    It maps to LangChain's HumanInTheLoopMiddleware.

+    LangChain supports three decision types:
+    - "approve": Execute tool with original arguments
+    - "edit": Modify arguments before execution
+    - "reject": Skip execution with optional feedback message
+    """

-class HumanInTheLoopModel(BaseModel):
    model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    review_prompt: str = "Please review the tool call"
-    interrupt_config: dict[str, Any] = Field(
-        default_factory=lambda: {
-            "allow_accept": True,
-            "allow_edit": True,
-            "allow_respond": True,
-            "allow_decline": True,
-        }
+
+    review_prompt: Optional[str] = Field(
+        default=None,
+        description="Message shown to the reviewer when approval is requested",
+    )
+
+    allowed_decisions: list[Literal["approve", "edit", "reject"]] = Field(
+        default_factory=lambda: ["approve", "edit", "reject"],
+        description="List of allowed decision types for this tool",
    )
-    decline_message: str = "Tool call declined by user"
-    custom_actions: Optional[dict[str, str]] = Field(default_factory=dict)
+
+    @model_validator(mode="after")
+    def validate_and_normalize_decisions(self) -> Self:
+        """Validate and normalize allowed decisions."""
+        if not self.allowed_decisions:
+            raise ValueError("At least one decision type must be allowed")
+
+        # Remove duplicates while preserving order
+        seen = set()
+        unique_decisions = []
+        for decision in self.allowed_decisions:
+            if decision not in seen:
+                seen.add(decision)
+                unique_decisions.append(decision)
+        self.allowed_decisions = unique_decisions
+
+        return self


class BaseFunctionModel(ABC, BaseModel):
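The boolean `interrupt_config` flags and the `HumanInTheLoopActionType` enum collapse into one `allowed_decisions` list matching LangChain's `HumanInTheLoopMiddleware` decision types, with order-preserving deduplication. For example:

```python
from dao_ai.config import HumanInTheLoopModel

hitl = HumanInTheLoopModel(
    review_prompt="Approve sending this email?",
    allowed_decisions=["approve", "reject", "approve"],  # duplicate on purpose
)
print(hitl.allowed_decisions)  # ["approve", "reject"] - deduplicated, order kept
```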
@@ -931,7 +1658,6 @@ class BaseFunctionModel(ABC, BaseModel):
         discriminator="type",
     )
     type: FunctionType
-    name: str
     human_in_the_loop: Optional[HumanInTheLoopModel] = None
 
     @abstractmethod
@@ -948,6 +1674,7 @@ class BaseFunctionModel(ABC, BaseModel):
 class PythonFunctionModel(BaseFunctionModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     type: Literal[FunctionType.PYTHON] = FunctionType.PYTHON
+    name: str
 
     @property
     def full_name(self) -> str:
@@ -961,8 +1688,9 @@ class PythonFunctionModel(BaseFunctionModel, HasFullName):
 
 class FactoryFunctionModel(BaseFunctionModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    args: Optional[dict[str, Any]] = Field(default_factory=dict)
     type: Literal[FunctionType.FACTORY] = FunctionType.FACTORY
+    name: str
+    args: Optional[dict[str, Any]] = Field(default_factory=dict)
 
     @property
     def full_name(self) -> str:
@@ -974,7 +1702,7 @@ class FactoryFunctionModel(BaseFunctionModel, HasFullName):
         return [create_factory_tool(self, **kwargs)]
 
     @model_validator(mode="after")
-    def update_args(self):
+    def update_args(self) -> Self:
         for key, value in self.args.items():
             self.args[key] = value_of(value)
         return self
@@ -985,24 +1713,148 @@ class TransportType(str, Enum):
     STDIO = "stdio"
 
 
-class McpFunctionModel(BaseFunctionModel, HasFullName):
+class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
+    """
+    MCP Function Model with authentication inherited from IsDatabricksResource.
+
+    Authentication for MCP connections uses the same options as other resources:
+    - Service Principal (client_id + client_secret + workspace_host)
+    - PAT (pat + workspace_host)
+    - OBO (on_behalf_of_user)
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     type: Literal[FunctionType.MCP] = FunctionType.MCP
-
     transport: TransportType = TransportType.STREAMABLE_HTTP
     command: Optional[str] = "python"
     url: Optional[AnyVariable] = None
-    connection: Optional[ConnectionModel] = None
     headers: dict[str, AnyVariable] = Field(default_factory=dict)
     args: list[str] = Field(default_factory=list)
-    pat: Optional[AnyVariable] = None
-    client_id: Optional[AnyVariable] = None
-    client_secret: Optional[AnyVariable] = None
-    workspace_host: Optional[AnyVariable] = None
+    # MCP-specific fields
+    connection: Optional[ConnectionModel] = None
+    functions: Optional[SchemaModel] = None
+    genie_room: Optional[GenieRoomModel] = None
+    sql: Optional[bool] = None
+    vector_search: Optional[VectorStoreModel] = None
 
     @property
-    def full_name(self) -> str:
-        return self.name
+    def api_scopes(self) -> Sequence[str]:
+        """API scopes for MCP connections."""
+        return [
+            "serving.serving-endpoints",
+            "mcp.genie",
+            "mcp.functions",
+            "mcp.vectorsearch",
+            "mcp.external",
+        ]
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        """MCP functions don't declare static resources."""
+        return []
+
+    def _get_workspace_host(self) -> str:
+        """
+        Get the workspace host, either from config or from workspace client.
+
+        If connection is provided, uses its workspace client.
+        Otherwise, falls back to the default Databricks host.
+
+        Returns:
+            str: The workspace host URL with https:// scheme and without trailing slash
+        """
+        from dao_ai.utils import get_default_databricks_host, normalize_host
+
+        # Try to get workspace_host from config
+        workspace_host: str | None = (
+            normalize_host(value_of(self.workspace_host))
+            if self.workspace_host
+            else None
+        )
+
+        # If no workspace_host in config, get it from workspace client
+        if not workspace_host:
+            # Use connection's workspace client if available
+            if self.connection:
+                workspace_host = normalize_host(
+                    self.connection.workspace_client.config.host
+                )
+            else:
+                # get_default_databricks_host already normalizes the host
+                workspace_host = get_default_databricks_host()
+
+        if not workspace_host:
+            raise ValueError(
+                "Could not determine workspace host. "
+                "Please set workspace_host in config or DATABRICKS_HOST environment variable."
+            )
+
+        # Remove trailing slash
+        return workspace_host.rstrip("/")
+
+    @property
+    def mcp_url(self) -> str:
+        """
+        Get the MCP URL for this function.
+
+        Returns the URL based on the configured source:
+        - If url is set, returns it directly
+        - If connection is set, constructs URL from connection
+        - If genie_room is set, constructs Genie MCP URL
+        - If sql is set, constructs DBSQL MCP URL (serverless)
+        - If vector_search is set, constructs Vector Search MCP URL
+        - If functions is set, constructs UC Functions MCP URL
+
+        URL patterns (per https://docs.databricks.com/aws/en/generative-ai/mcp/managed-mcp):
+        - Genie: https://{host}/api/2.0/mcp/genie/{space_id}
+        - DBSQL: https://{host}/api/2.0/mcp/sql (serverless, workspace-level)
+        - Vector Search: https://{host}/api/2.0/mcp/vector-search/{catalog}/{schema}
+        - UC Functions: https://{host}/api/2.0/mcp/functions/{catalog}/{schema}
+        - Connection: https://{host}/api/2.0/mcp/external/{connection_name}
+        """
+        # Direct URL provided
+        if self.url:
+            return self.url
+
+        # Get workspace host (from config, connection, or default workspace client)
+        workspace_host: str = self._get_workspace_host()
+
+        # UC Connection
+        if self.connection:
+            connection_name: str = self.connection.name
+            return f"{workspace_host}/api/2.0/mcp/external/{connection_name}"
+
+        # Genie Room
+        if self.genie_room:
+            space_id: str = value_of(self.genie_room.space_id)
+            return f"{workspace_host}/api/2.0/mcp/genie/{space_id}"
+
+        # DBSQL MCP server (serverless, workspace-level)
+        if self.sql:
+            return f"{workspace_host}/api/2.0/mcp/sql"
+
+        # Vector Search
+        if self.vector_search:
+            if (
+                not self.vector_search.index
+                or not self.vector_search.index.schema_model
+            ):
+                raise ValueError(
+                    "vector_search must have an index with a schema (catalog/schema) configured"
+                )
+            catalog: str = self.vector_search.index.schema_model.catalog_name
+            schema: str = self.vector_search.index.schema_model.schema_name
+            return f"{workspace_host}/api/2.0/mcp/vector-search/{catalog}/{schema}"
+
+        # UC Functions MCP server
+        if self.functions:
+            catalog: str = self.functions.catalog_name
+            schema: str = self.functions.schema_name
+            return f"{workspace_host}/api/2.0/mcp/functions/{catalog}/{schema}"
+
+        raise ValueError(
+            "No URL source configured. Provide one of: url, connection, genie_room, "
+            "sql, vector_search, or functions"
+        )
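A hedged sketch of URL resolution. The name field is assumed to come from IsDatabricksResource, the host and space id are hypothetical, and the GenieRoomModel construction elides any other required fields:

direct = McpFunctionModel(name="custom", url="https://my-host/api/2.0/mcp/custom")
assert direct.mcp_url == "https://my-host/api/2.0/mcp/custom"  # direct URL short-circuits

genie = McpFunctionModel(
    name="genie",
    genie_room=GenieRoomModel(space_id="01f01234abcd"),  # hypothetical space id
    workspace_host="https://my-workspace.cloud.databricks.com",
)
assert genie.mcp_url.endswith("/api/2.0/mcp/genie/01f01234abcd")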
 
     @field_serializer("transport")
     def serialize_transport(self, value) -> str:
@@ -1011,71 +1863,65 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
         return str(value)
 
     @model_validator(mode="after")
-    def validate_mutually_exclusive(self):
-        if self.transport == TransportType.STREAMABLE_HTTP and not (
-            self.url or self.connection
-        ):
-            raise ValueError(
-                "url or connection must be provided for STREAMABLE_HTTP transport"
-            )
-        if self.transport == TransportType.STDIO and not self.command:
-            raise ValueError("command must not be provided for STDIO transport")
-        if self.transport == TransportType.STDIO and not self.args:
-            raise ValueError("args must not be provided for STDIO transport")
+    def validate_mutually_exclusive(self) -> "McpFunctionModel":
+        """Validate that exactly one URL source is provided."""
+        # Count how many URL sources are provided
+        url_sources: list[tuple[str, Any]] = [
+            ("url", self.url),
+            ("connection", self.connection),
+            ("genie_room", self.genie_room),
+            ("sql", self.sql),
+            ("vector_search", self.vector_search),
+            ("functions", self.functions),
+        ]
+
+        provided_sources: list[str] = [
+            name for name, value in url_sources if value is not None
+        ]
+
+        if self.transport == TransportType.STREAMABLE_HTTP:
+            if len(provided_sources) == 0:
+                raise ValueError(
+                    "For STREAMABLE_HTTP transport, exactly one of the following must be provided: "
+                    "url, connection, genie_room, sql, vector_search, or functions"
+                )
+            if len(provided_sources) > 1:
+                raise ValueError(
+                    f"For STREAMABLE_HTTP transport, only one URL source can be provided. "
+                    f"Found: {', '.join(provided_sources)}. "
+                    f"Please provide only one of: url, connection, genie_room, sql, vector_search, or functions"
+                )
+
+        if self.transport == TransportType.STDIO:
+            if not self.command:
+                raise ValueError("command must be provided for STDIO transport")
+            if not self.args:
+                raise ValueError("args must be provided for STDIO transport")
+
         return self
 
     @model_validator(mode="after")
-    def update_url(self):
+    def update_url(self) -> "McpFunctionModel":
         self.url = value_of(self.url)
         return self
 
     @model_validator(mode="after")
-    def update_headers(self):
+    def update_headers(self) -> "McpFunctionModel":
         for key, value in self.headers.items():
             self.headers[key] = value_of(value)
         return self
 
-    @model_validator(mode="after")
-    def validate_auth_methods(self):
-        oauth_fields: Sequence[Any] = [
-            self.client_id,
-            self.client_secret,
-        ]
-        has_oauth: bool = all(field is not None for field in oauth_fields)
-
-        pat_fields: Sequence[Any] = [self.pat]
-        has_user_auth: bool = all(field is not None for field in pat_fields)
-
-        if has_oauth and has_user_auth:
-            raise ValueError(
-                "Cannot use both OAuth and user authentication methods. "
-                "Please provide either OAuth credentials or user credentials."
-            )
-
-        if (has_oauth or has_user_auth) and not self.workspace_host:
-            raise ValueError(
-                "Workspace host must be provided when using OAuth or user credentials."
-            )
-
-        return self
-
     def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
         from dao_ai.tools import create_mcp_tools
 
         return create_mcp_tools(self)
 
 
-class UnityCatalogFunctionModel(BaseFunctionModel, HasFullName):
+class UnityCatalogFunctionModel(BaseFunctionModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
-    partial_args: Optional[dict[str, AnyVariable]] = Field(default_factory=dict)
     type: Literal[FunctionType.UNITY_CATALOG] = FunctionType.UNITY_CATALOG
-
-    @property
-    def full_name(self) -> str:
-        if self.schema_model:
-            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}.{self.name}"
-        return self.name
+    resource: FunctionModel
+    partial_args: Optional[dict[str, AnyVariable]] = Field(default_factory=dict)
 
     def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
         from dao_ai.tools import create_uc_tools
@@ -1100,12 +1946,97 @@ class ToolModel(BaseModel):
     function: AnyTool
 
 
-class GuardrailModel(BaseModel):
-    model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    name: str
-    model: LLMModel
-    prompt: str
-    num_retries: Optional[int] = 3
+class PromptModel(BaseModel, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
+    name: str
+    description: Optional[str] = None
+    default_template: Optional[str] = None
+    alias: Optional[str] = None
+    version: Optional[int] = None
+    tags: Optional[dict[str, Any]] = Field(default_factory=dict)
+    auto_register: bool = Field(
+        default=False,
+        description="Whether to automatically register the default_template to the prompt registry. "
+        "If False, the prompt will only be loaded from the registry (never created/updated). "
+        "Defaults to False.",
+    )
+
+    @property
+    def template(self) -> str:
+        from dao_ai.providers.databricks import DatabricksProvider
+
+        provider: DatabricksProvider = DatabricksProvider()
+        prompt_version = provider.get_prompt(self)
+        return prompt_version.to_single_brace_format()
+
+    @property
+    def full_name(self) -> str:
+        prompt_name: str = self.name
+        if self.schema_model:
+            prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
+        return prompt_name
+
+    @property
+    def uri(self) -> str:
+        prompt_uri: str = f"prompts:/{self.full_name}"
+
+        if self.alias:
+            prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
+        elif self.version:
+            prompt_uri = f"prompts:/{self.full_name}/{self.version}"
+        else:
+            prompt_uri = f"prompts:/{self.full_name}@latest"
+
+        return prompt_uri
+
+    def as_prompt(self) -> PromptVersion:
+        prompt_version: PromptVersion = load_prompt(self.uri)
+        return prompt_version
+
+    @model_validator(mode="after")
+    def validate_mutually_exclusive(self) -> Self:
+        if self.alias and self.version:
+            raise ValueError("Cannot specify both alias and version")
+        return self
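The uri property in action, following directly from the code above (no schema attached, so full_name is just the name; "router" is an illustrative prompt name):

assert PromptModel(name="router", alias="prod").uri == "prompts:/router@prod"
assert PromptModel(name="router", version=3).uri == "prompts:/router/3"
assert PromptModel(name="router").uri == "prompts:/router@latest"
# PromptModel(name="router", alias="prod", version=3)  -> raises ValueError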
+
+
+class GuardrailModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    model: str | LLMModel
+    prompt: str | PromptModel
+    num_retries: Optional[int] = 3
+
+    @model_validator(mode="after")
+    def validate_llm_model(self) -> Self:
+        if isinstance(self.model, str):
+            self.model = LLMModel(name=self.model)
+        return self
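A short sketch of the string-to-model coercion (the endpoint name and prompt text are illustrative):

guardrail = GuardrailModel(
    name="tone_check",
    model="databricks-meta-llama-3-3-70b-instruct",  # plain string accepted
    prompt="Flag responses that are rude or off-topic.",
)
assert isinstance(guardrail.model, LLMModel)  # coerced by validate_llm_model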
+
+
+class MiddlewareModel(BaseModel):
+    """Configuration for middleware that can be applied to agents.
+
+    Middleware is defined at the AppConfig level and can be referenced by name
+    in agent configurations using YAML anchors for reusability.
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str = Field(
+        description="Fully qualified name of the middleware factory function"
+    )
+    args: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Arguments to pass to the middleware factory function",
+    )
+
+    @model_validator(mode="after")
+    def resolve_args(self) -> Self:
+        """Resolve any variable references in args."""
+        for key, value in self.args.items():
+            self.args[key] = value_of(value)
+        return self
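For example, a middleware entry might look like the following sketch (the factory FQN and its argument are hypothetical; args values are passed through value_of() by the validator above):

summarizer = MiddlewareModel(
    name="dao_ai.middleware.summarization_middleware",  # hypothetical factory FQN
    args={"max_tokens": 2048},
)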
 
 
 class StorageType(str, Enum):
@@ -1116,14 +2047,12 @@ class StorageType(str, Enum):
 class CheckpointerModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
-    type: Optional[StorageType] = StorageType.MEMORY
     database: Optional[DatabaseModel] = None
 
-    @model_validator(mode="after")
-    def validate_postgres_requires_database(self):
-        if self.type == StorageType.POSTGRES and not self.database:
-            raise ValueError("Database must be provided when storage type is POSTGRES")
-        return self
+    @property
+    def storage_type(self) -> StorageType:
+        """Infer storage type from database presence."""
+        return StorageType.POSTGRES if self.database else StorageType.MEMORY
 
     def as_checkpointer(self) -> BaseCheckpointSaver:
         from dao_ai.memory import CheckpointManager
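Inferring the storage type from the presence of a database makes the old POSTGRES-requires-database invariant hold by construction, so the removed validator is no longer needed. A sketch (my_database stands in for a configured DatabaseModel):

assert CheckpointerModel(name="scratch").storage_type == StorageType.MEMORY
assert CheckpointerModel(name="durable", database=my_database).storage_type == StorageType.POSTGRES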
@@ -1139,16 +2068,14 @@ class StoreModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     embedding_model: Optional[LLMModel] = None
-    type: Optional[StorageType] = StorageType.MEMORY
     dims: Optional[int] = 1536
     database: Optional[DatabaseModel] = None
     namespace: Optional[str] = None
 
-    @model_validator(mode="after")
-    def validate_postgres_requires_database(self):
-        if self.type == StorageType.POSTGRES and not self.database:
-            raise ValueError("Database must be provided when storage type is POSTGRES")
-        return self
+    @property
+    def storage_type(self) -> StorageType:
+        """Infer storage type from database presence."""
+        return StorageType.POSTGRES if self.database else StorageType.MEMORY
 
     def as_store(self) -> BaseStore:
         from dao_ai.memory import StoreManager
@@ -1166,41 +2093,158 @@ class MemoryModel(BaseModel):
 FunctionHook: TypeAlias = PythonFunctionModel | FactoryFunctionModel | str
 
 
-class PromptModel(BaseModel, HasFullName):
+class ResponseFormatModel(BaseModel):
+    """
+    Configuration for structured response formats.
+
+    The response_schema field accepts either a type or a string:
+    - Type (Pydantic model, dataclass, etc.): Used directly for structured output
+    - String: First attempts to load as a fully qualified type name, falls back to JSON schema string
+
+    This unified approach simplifies the API while maintaining flexibility.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
-    name: str
-    description: Optional[str] = None
-    default_template: Optional[str] = None
-    alias: Optional[str] = None
-    version: Optional[int] = None
-    tags: Optional[dict[str, Any]] = Field(default_factory=dict)
+    use_tool: Optional[bool] = Field(
+        default=None,
+        description=(
+            "Strategy for structured output: "
+            "None (default) = auto-detect from model capabilities, "
+            "False = force ProviderStrategy (native), "
+            "True = force ToolStrategy (function calling)"
+        ),
+    )
+    response_schema: Optional[str | type] = Field(
+        default=None,
+        description="Type or string for response format. String attempts FQN import, falls back to JSON schema.",
+    )
 
-    @property
-    def template(self) -> str:
-        from dao_ai.providers.databricks import DatabricksProvider
+    def as_strategy(self) -> ProviderStrategy | ToolStrategy | type | dict[str, Any] | None:
+        """
+        Convert response_schema to appropriate LangChain strategy.
 
-        provider: DatabricksProvider = DatabricksProvider()
-        prompt: str = provider.get_prompt(self)
-        return prompt
+        Returns:
+            - None if no response_schema configured
+            - Raw schema/type for auto-detection (when use_tool=None)
+            - ToolStrategy wrapping the schema (when use_tool=True)
+            - ProviderStrategy wrapping the schema (when use_tool=False)
 
-    @property
-    def full_name(self) -> str:
-        if self.schema_model:
-            name: str = ""
-            if self.name:
-                name = f".{self.name}"
-            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
-        return self.name
+        Raises:
+            ValueError: If response_schema is a JSON schema string that cannot be parsed
+        """
+        if self.response_schema is None:
+            return None
+
+        schema = self.response_schema
+
+        # Handle type schemas (Pydantic, dataclass, etc.)
+        if self.is_type_schema:
+            if self.use_tool is None:
+                # Auto-detect: Pass schema directly, let LangChain decide
+                return schema
+            elif self.use_tool is True:
+                # Force ToolStrategy (function calling)
+                return ToolStrategy(schema)
+            else:  # use_tool is False
+                # Force ProviderStrategy (native structured output)
+                return ProviderStrategy(schema)
+
+        # Handle JSON schema strings
+        elif self.is_json_schema:
+            import json
+
+            try:
+                schema_dict = json.loads(schema)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Invalid JSON schema string: {e}") from e
+
+            # Apply same use_tool logic as type schemas
+            if self.use_tool is None:
+                # Auto-detect
+                return schema_dict
+            elif self.use_tool is True:
+                # Force ToolStrategy
+                return ToolStrategy(schema_dict)
+            else:  # use_tool is False
+                # Force ProviderStrategy
+                return ProviderStrategy(schema_dict)
+
+        return None
 
     @model_validator(mode="after")
-    def validate_mutually_exclusive(self):
-        if self.alias and self.version:
-            raise ValueError("Cannot specify both alias and version")
-        return self
+    def validate_response_schema(self) -> Self:
+        """
+        Validate and convert response_schema.
+
+        Processing logic:
+        1. If None: no response format specified
+        2. If type: use directly as structured output type
+        3. If str: try to load as FQN using type_from_fqn
+           - Success: response_schema becomes the loaded type
+           - Failure: keep as string (treated as JSON schema)
+
+        After validation, response_schema is one of:
+        - None (no schema)
+        - type (use for structured output)
+        - str (JSON schema)
+
+        Returns:
+            Self with validated response_schema
+        """
+        if self.response_schema is None:
+            return self
+
+        # If already a type, return
+        if isinstance(self.response_schema, type):
+            return self
+
+        # If it's a string, try to load as type, fallback to json_schema
+        if isinstance(self.response_schema, str):
+            from dao_ai.utils import type_from_fqn
+
+            try:
+                resolved_type = type_from_fqn(self.response_schema)
+                self.response_schema = resolved_type
+                logger.debug(
+                    f"Resolved response_schema string to type: {resolved_type}"
+                )
+                return self
+            except (ValueError, ImportError, AttributeError, TypeError) as e:
+                # Keep as string - it's a JSON schema
+                logger.debug(
+                    f"Could not resolve '{self.response_schema}' as type: {e}. "
+                    f"Treating as JSON schema string."
+                )
+                return self
+
+        # Invalid type
+        raise ValueError(
+            f"response_schema must be None, type, or str, got {type(self.response_schema)}"
+        )
+
+    @property
+    def is_type_schema(self) -> bool:
+        """Returns True if response_schema is a type (not JSON schema string)."""
+        return isinstance(self.response_schema, type)
+
+    @property
+    def is_json_schema(self) -> bool:
+        """Returns True if response_schema is a JSON schema string (not a type)."""
+        return isinstance(self.response_schema, str)
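A sketch of the three strategies over one schema, using an inline Pydantic model (ToolStrategy and ProviderStrategy are called positionally, mirroring the code above):

from pydantic import BaseModel as _Base

class Weather(_Base):
    temperature_c: float
    summary: str

auto = ResponseFormatModel(response_schema=Weather)
assert auto.is_type_schema and auto.as_strategy() is Weather  # auto-detect passes the type through

forced_tool = ResponseFormatModel(response_schema=Weather, use_tool=True)
# forced_tool.as_strategy() returns ToolStrategy(Weather)

forced_native = ResponseFormatModel(response_schema=Weather, use_tool=False)
# forced_native.as_strategy() returns ProviderStrategy(Weather)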
 
 
 class AgentModel(BaseModel):
+    """
+    Configuration model for an agent in the DAO AI framework.
+
+    Agents combine an LLM with tools and middleware to create systems that can
+    reason about tasks, decide which tools to use, and iteratively work towards solutions.
+
+    Middleware replaces the previous pre_agent_hook and post_agent_hook patterns,
+    providing a more flexible and composable way to customize agent behavior.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     description: Optional[str] = None
@@ -1209,9 +2253,54 @@ class AgentModel(BaseModel):
     guardrails: list[GuardrailModel] = Field(default_factory=list)
     prompt: Optional[str | PromptModel] = None
     handoff_prompt: Optional[str] = None
-    create_agent_hook: Optional[FunctionHook] = None
-    pre_agent_hook: Optional[FunctionHook] = None
-    post_agent_hook: Optional[FunctionHook] = None
+    middleware: list[MiddlewareModel] = Field(
+        default_factory=list,
+        description="List of middleware to apply to this agent",
+    )
+    response_format: Optional[ResponseFormatModel | type | str] = None
+
+    @model_validator(mode="after")
+    def validate_response_format(self) -> Self:
+        """
+        Validate and normalize response_format.
+
+        Accepts:
+        - None (no response format)
+        - ResponseFormatModel (already validated)
+        - type (Pydantic model, dataclass, etc.) - converts to ResponseFormatModel
+        - str (FQN or json_schema) - converts to ResponseFormatModel (smart fallback)
+
+        ResponseFormatModel handles the logic of trying FQN import and falling back to JSON schema.
+        """
+        if self.response_format is None or isinstance(
+            self.response_format, ResponseFormatModel
+        ):
+            return self
+
+        # Convert type or str to ResponseFormatModel
+        # ResponseFormatModel's validator will handle the smart type loading and fallback
+        if isinstance(self.response_format, (type, str)):
+            self.response_format = ResponseFormatModel(
+                response_schema=self.response_format
+            )
+            return self
+
+        # Invalid type
+        raise ValueError(
+            f"response_format must be None, ResponseFormatModel, type, or str, "
+            f"got {type(self.response_format)}"
+        )
+
+    def as_runnable(self) -> RunnableLike:
+        from dao_ai.nodes import create_agent_node
+
+        return create_agent_node(self)
+
+    def as_responses_agent(self) -> ResponsesAgent:
+        from dao_ai.models import create_responses_agent
+
+        graph: CompiledStateGraph = self.as_runnable()
+        return create_responses_agent(graph)
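Sketch of the normalization, assuming AgentModel's model field (defined outside this hunk) accepts an LLMModel, with Weather as in the ResponseFormatModel sketch above and my_llm a placeholder:

agent = AgentModel(
    name="forecaster",
    model=my_llm,  # assumed: an LLMModel instance
    response_format=Weather,  # a bare type is wrapped automatically
)
assert isinstance(agent.response_format, ResponseFormatModel)
assert agent.response_format.response_schema is Weather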
 
 
 class SupervisorModel(BaseModel):
@@ -1219,12 +2308,20 @@ class SupervisorModel(BaseModel):
     model: LLMModel
     tools: list[ToolModel] = Field(default_factory=list)
     prompt: Optional[str] = None
+    middleware: list[MiddlewareModel] = Field(
+        default_factory=list,
+        description="List of middleware to apply to the supervisor",
+    )
 
 
 class SwarmModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     model: LLMModel
     default_agent: Optional[AgentModel | str] = None
+    middleware: list[MiddlewareModel] = Field(
+        default_factory=list,
+        description="List of middleware to apply to all agents in the swarm",
+    )
     handoffs: Optional[dict[str, Optional[list[AgentModel | str]]]] = Field(
         default_factory=dict
     )
@@ -1237,7 +2334,7 @@ class OrchestrationModel(BaseModel):
     memory: Optional[MemoryModel] = None
 
     @model_validator(mode="after")
-    def validate_mutually_exclusive(self):
+    def validate_mutually_exclusive(self) -> Self:
         if self.supervisor is not None and self.swarm is not None:
             raise ValueError("Cannot specify both supervisor and swarm")
         if self.supervisor is None and self.swarm is None:
@@ -1267,9 +2364,21 @@ class Entitlement(str, Enum):
 
 class AppPermissionModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    principals: list[str] = Field(default_factory=list)
+    principals: list[ServicePrincipalModel | str] = Field(default_factory=list)
     entitlements: list[Entitlement]
 
+    @model_validator(mode="after")
+    def resolve_principals(self) -> Self:
+        """Resolve ServicePrincipalModel objects to their client_id."""
+        resolved: list[str] = []
+        for principal in self.principals:
+            if isinstance(principal, ServicePrincipalModel):
+                resolved.append(value_of(principal.client_id))
+            else:
+                resolved.append(principal)
+        self.principals = resolved
+        return self
+
 
 class LogLevel(str, Enum):
     TRACE = "TRACE"
@@ -1330,27 +2439,81 @@ class ChatPayload(BaseModel):
 
         return self
 
+    @model_validator(mode="after")
+    def ensure_thread_id(self) -> "ChatPayload":
+        """Ensure thread_id or conversation_id is present in configurable, generating UUID if needed."""
+        import uuid
+
+        if self.custom_inputs is None:
+            self.custom_inputs = {}
+
+        # Get or create configurable section
+        configurable: dict[str, Any] = self.custom_inputs.get("configurable", {})
+
+        # Check if thread_id or conversation_id exists
+        has_thread_id = configurable.get("thread_id") is not None
+        has_conversation_id = configurable.get("conversation_id") is not None
+
+        # If neither is provided, generate a UUID for conversation_id
+        if not has_thread_id and not has_conversation_id:
+            configurable["conversation_id"] = str(uuid.uuid4())
+            self.custom_inputs["configurable"] = configurable
+
+        return self
+
+    def as_messages(self) -> Sequence[BaseMessage]:
+        return messages_from_dict(
+            [{"type": m.role, "content": m.content} for m in self.messages]
+        )
+
+    def as_agent_request(self) -> ResponsesAgentRequest:
+        from mlflow.types.responses_helpers import Message as _Message
+
+        return ResponsesAgentRequest(
+            input=[_Message(role=m.role, content=m.content) for m in self.messages],
+            custom_inputs=self.custom_inputs,
+        )
+
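Sketch of the thread id defaulting (the message shape is assumed from ChatPayload's messages field, which sits outside this hunk, and relies on Pydantic coercing the dict):

payload = ChatPayload(messages=[{"role": "user", "content": "hello"}])
configurable = payload.custom_inputs["configurable"]
assert "conversation_id" in configurable  # a UUID was generated
# Supplying custom_inputs={"configurable": {"thread_id": "t-1"}} suppresses generation.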
 
 class ChatHistoryModel(BaseModel):
+    """
+    Configuration for chat history summarization.
+
+    Attributes:
+        model: The LLM to use for generating summaries.
+        max_tokens: Maximum tokens to keep after summarization (the "keep" threshold).
+            After summarization, recent messages totaling up to this many tokens are preserved.
+        max_tokens_before_summary: Token threshold that triggers summarization.
+            When conversation exceeds this, summarization runs. Mutually exclusive with
+            max_messages_before_summary. If neither is set, defaults to max_tokens * 10.
+        max_messages_before_summary: Message count threshold that triggers summarization.
+            When conversation exceeds this many messages, summarization runs.
+            Mutually exclusive with max_tokens_before_summary.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     model: LLMModel
-    max_tokens: int = 256
-    max_tokens_before_summary: Optional[int] = None
-    max_messages_before_summary: Optional[int] = None
-    max_summary_tokens: int = 255
-
-    @model_validator(mode="after")
-    def validate_max_summary_tokens(self) -> "ChatHistoryModel":
-        if self.max_summary_tokens >= self.max_tokens:
-            raise ValueError(
-                f"max_summary_tokens ({self.max_summary_tokens}) must be less than max_tokens ({self.max_tokens})"
-            )
-        return self
+    max_tokens: int = Field(
+        default=2048,
+        gt=0,
+        description="Maximum tokens to keep after summarization",
+    )
+    max_tokens_before_summary: Optional[int] = Field(
+        default=None,
+        gt=0,
+        description="Token threshold that triggers summarization",
+    )
+    max_messages_before_summary: Optional[int] = Field(
+        default=None,
+        gt=0,
+        description="Message count threshold that triggers summarization",
+    )
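For example (my_llm stands in for a configured LLMModel; the threshold values are illustrative):

history = ChatHistoryModel(
    model=my_llm,
    max_tokens=2048,                 # keep up to ~2048 tokens after summarizing
    max_messages_before_summary=40,  # summarize once the thread exceeds 40 messages
)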
 
 
 class AppModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
+    service_principal: Optional[ServicePrincipalModel] = None
     description: Optional[str] = None
     log_level: Optional[LogLevel] = "WARNING"
     registered_model: RegisteredModelModel
@@ -1371,23 +2534,54 @@ class AppModel(BaseModel):
     shutdown_hooks: Optional[FunctionHook | list[FunctionHook]] = Field(
         default_factory=list
     )
-    message_hooks: Optional[FunctionHook | list[FunctionHook]] = Field(
-        default_factory=list
-    )
     input_example: Optional[ChatPayload] = None
     chat_history: Optional[ChatHistoryModel] = None
     code_paths: list[str] = Field(default_factory=list)
     pip_requirements: list[str] = Field(default_factory=list)
+    python_version: Optional[str] = Field(
+        default="3.12",
+        description="Python version for Model Serving deployment. Defaults to 3.12, "
+        "which is supported by Databricks Model Serving. This allows deploying from "
+        "environments with different Python versions (e.g., Databricks Apps with 3.11).",
+    )
 
     @model_validator(mode="after")
-    def validate_agents_not_empty(self):
+    def set_databricks_env_vars(self) -> Self:
+        """Set Databricks environment variables for Model Serving.
+
+        Sets DATABRICKS_HOST, DATABRICKS_CLIENT_ID, and DATABRICKS_CLIENT_SECRET.
+        Values explicitly provided in environment_vars take precedence.
+        """
+        from dao_ai.utils import get_default_databricks_host
+
+        # Set DATABRICKS_HOST if not already provided
+        if "DATABRICKS_HOST" not in self.environment_vars:
+            host: str | None = get_default_databricks_host()
+            if host:
+                self.environment_vars["DATABRICKS_HOST"] = host
+
+        # Set service principal credentials if provided
+        if self.service_principal is not None:
+            if "DATABRICKS_CLIENT_ID" not in self.environment_vars:
+                self.environment_vars["DATABRICKS_CLIENT_ID"] = (
+                    self.service_principal.client_id
+                )
+            if "DATABRICKS_CLIENT_SECRET" not in self.environment_vars:
+                self.environment_vars["DATABRICKS_CLIENT_SECRET"] = (
+                    self.service_principal.client_secret
+                )
+        return self
+
+    @model_validator(mode="after")
+    def validate_agents_not_empty(self) -> Self:
         if not self.agents:
             raise ValueError("At least one agent must be specified")
         return self
 
     @model_validator(mode="after")
-    def update_environment_vars(self):
+    def resolve_environment_vars(self) -> Self:
         for key, value in self.environment_vars.items():
+            updated_value: str
             if isinstance(value, SecretVariableModel):
                 updated_value = str(value)
             else:
@@ -1397,7 +2591,7 @@ class AppModel(BaseModel):
         return self
 
     @model_validator(mode="after")
-    def set_default_orchestration(self):
+    def set_default_orchestration(self) -> Self:
         if self.orchestration is None:
             if len(self.agents) > 1:
                 default_agent: AgentModel = self.agents[0]
@@ -1417,14 +2611,14 @@ class AppModel(BaseModel):
         return self
 
     @model_validator(mode="after")
-    def set_default_endpoint_name(self):
+    def set_default_endpoint_name(self) -> Self:
         if self.endpoint_name is None:
             self.endpoint_name = self.name
         return self
 
     @model_validator(mode="after")
-    def set_default_agent(self):
-        default_agent_name = self.agents[0].name
+    def set_default_agent(self) -> Self:
+        default_agent_name: str = self.agents[0].name
 
         if self.orchestration.swarm and not self.orchestration.swarm.default_agent:
             self.orchestration.swarm.default_agent = default_agent_name
@@ -1432,7 +2626,7 @@ class AppModel(BaseModel):
         return self
 
     @model_validator(mode="after")
-    def add_code_paths_to_sys_path(self):
+    def add_code_paths_to_sys_path(self) -> Self:
         for code_path in self.code_paths:
             parent_path: str = str(Path(code_path).parent)
             if parent_path not in sys.path:
@@ -1459,6 +2653,202 @@ class EvaluationModel(BaseModel):
     guidelines: list[GuidelineModel] = Field(default_factory=list)
 
 
+class EvaluationDatasetExpectationsModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    expected_response: Optional[str] = None
+    expected_facts: Optional[list[str]] = None
+
+    @model_validator(mode="after")
+    def validate_mutually_exclusive(self) -> Self:
+        if self.expected_response is not None and self.expected_facts is not None:
+            raise ValueError("Cannot specify both expected_response and expected_facts")
+        return self
+
+
+class EvaluationDatasetEntryModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    inputs: ChatPayload
+    expectations: EvaluationDatasetExpectationsModel
+
+    def to_mlflow_format(self) -> dict[str, Any]:
+        """
+        Convert to MLflow evaluation dataset format.
+
+        Flattens the expectations fields to the top level alongside inputs,
+        which is the format expected by MLflow's Correctness scorer.
+
+        Returns:
+            dict: Flattened dictionary with inputs and expectation fields at top level
+        """
+        result: dict[str, Any] = {"inputs": self.inputs.model_dump()}
+
+        # Flatten expectations to top level for MLflow compatibility
+        if self.expectations.expected_response is not None:
+            result["expected_response"] = self.expectations.expected_response
+        if self.expectations.expected_facts is not None:
+            result["expected_facts"] = self.expectations.expected_facts
+
+        return result
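Sketch of the flattening (the chat payload content is illustrative):

entry = EvaluationDatasetEntryModel(
    inputs=ChatPayload(messages=[{"role": "user", "content": "What is the capital of France?"}]),
    expectations=EvaluationDatasetExpectationsModel(
        expected_facts=["Paris is the capital of France"]
    ),
)
record = entry.to_mlflow_format()
assert record["expected_facts"] == ["Paris is the capital of France"]
assert "expectations" not in record  # flattened for the Correctness scorer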
+
+
+class EvaluationDatasetModel(BaseModel, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
+    name: str
+    data: Optional[list[EvaluationDatasetEntryModel]] = Field(default_factory=list)
+    overwrite: Optional[bool] = False
+
+    def as_dataset(self, w: WorkspaceClient | None = None) -> EvaluationDataset:
+        evaluation_dataset: EvaluationDataset
+        needs_creation: bool = False
+
+        try:
+            evaluation_dataset = get_dataset(name=self.full_name)
+            if self.overwrite:
+                logger.warning(f"Overwriting dataset {self.full_name}")
+                workspace_client: WorkspaceClient = w if w else WorkspaceClient()
+                logger.debug(f"Dropping table: {self.full_name}")
+                workspace_client.tables.delete(full_name=self.full_name)
+                needs_creation = True
+        except Exception:
+            logger.warning(
+                f"Dataset {self.full_name} not found, will create new dataset"
+            )
+            needs_creation = True
+
+        # Create dataset if needed (either new or after overwrite)
+        if needs_creation:
+            evaluation_dataset = create_dataset(name=self.full_name)
+            if self.data:
+                logger.debug(
+                    f"Merging {len(self.data)} entries into dataset {self.full_name}"
+                )
+                # Use to_mlflow_format() to flatten expectations for MLflow compatibility
+                evaluation_dataset.merge_records(
+                    [e.to_mlflow_format() for e in self.data]
+                )
+
+        return evaluation_dataset
+
+    @property
+    def full_name(self) -> str:
+        if self.schema_model:
+            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}.{self.name}"
+        return self.name
+
+
+class PromptOptimizationModel(BaseModel):
+    """Configuration for prompt optimization using GEPA.
+
+    GEPA (Generative Evolution of Prompts and Agents) is an evolutionary
+    optimizer that uses reflective mutation to improve prompts based on
+    evaluation feedback.
+
+    Example:
+        prompt_optimization:
+          name: optimize_my_prompt
+          prompt: *my_prompt
+          agent: *my_agent
+          dataset: *my_training_dataset
+          reflection_model: databricks-meta-llama-3-3-70b-instruct
+          num_candidates: 50
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    prompt: Optional[PromptModel] = None
+    agent: AgentModel
+    dataset: EvaluationDatasetModel  # Training dataset with examples
+    reflection_model: Optional[LLMModel | str] = None
+    num_candidates: Optional[int] = 50
+
+    def optimize(self, w: WorkspaceClient | None = None) -> PromptModel:
+        """
+        Optimize the prompt using GEPA.
+
+        Args:
+            w: Optional WorkspaceClient (not used, kept for API compatibility)
+
+        Returns:
+            PromptModel: The optimized prompt model
+        """
+        from dao_ai.optimization import OptimizationResult, optimize_prompt
+
+        # Get reflection model name
+        reflection_model_name: str | None = None
+        if self.reflection_model:
+            if isinstance(self.reflection_model, str):
+                reflection_model_name = self.reflection_model
+            else:
+                reflection_model_name = self.reflection_model.uri
+
+        # Ensure prompt is set
+        prompt = self.prompt
+        if prompt is None:
+            raise ValueError(
+                f"Prompt optimization '{self.name}' requires a prompt to be set"
+            )
+
+        result: OptimizationResult = optimize_prompt(
+            prompt=prompt,
+            agent=self.agent,
+            dataset=self.dataset,
+            reflection_model=reflection_model_name,
+            num_candidates=self.num_candidates or 50,
+            register_if_improved=True,
+        )
+
+        return result.optimized_prompt
+
+    @model_validator(mode="after")
+    def set_defaults(self) -> Self:
+        # If no prompt is specified, try to use the agent's prompt
+        if self.prompt is None:
+            if isinstance(self.agent.prompt, PromptModel):
+                self.prompt = self.agent.prompt
+            else:
+                raise ValueError(
+                    f"Prompt optimization '{self.name}' requires either an explicit prompt "
+                    f"or an agent with a prompt configured"
+                )
+
+        return self
+
+
+class OptimizationsModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    training_datasets: dict[str, EvaluationDatasetModel] = Field(default_factory=dict)
+    prompt_optimizations: dict[str, PromptOptimizationModel] = Field(
+        default_factory=dict
+    )
+
+    def optimize(self, w: WorkspaceClient | None = None) -> dict[str, PromptModel]:
+        """
+        Optimize all prompts in this configuration.
+
+        This method:
+        1. Ensures all training datasets are created/registered in MLflow
+        2. Runs each prompt optimization
+
+        Args:
+            w: Optional WorkspaceClient for Databricks operations
+
+        Returns:
+            dict[str, PromptModel]: Dictionary mapping optimization names to optimized prompts
+        """
+        # First, ensure all training datasets are created/registered in MLflow
+        logger.info(f"Ensuring {len(self.training_datasets)} training datasets exist")
+        for dataset_name, dataset_model in self.training_datasets.items():
+            logger.debug(f"Creating/updating dataset: {dataset_name}")
+            dataset_model.as_dataset()
+
+        # Run optimizations
+        results: dict[str, PromptModel] = {}
+        for name, optimization in self.prompt_optimizations.items():
+            results[name] = optimization.optimize(w)
+        return results
+
+
 class DatasetFormat(str, Enum):
     CSV = "csv"
     DELTA = "delta"
@@ -1494,7 +2884,7 @@ class UnityCatalogFunctionSqlTestModel(BaseModel):
 
 class UnityCatalogFunctionSqlModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    function: UnityCatalogFunctionModel
+    function: FunctionModel
     ddl: str
     parameters: Optional[dict[str, Any]] = Field(default_factory=dict)
     test: Optional[UnityCatalogFunctionSqlTestModel] = None
@@ -1522,21 +2912,132 @@ class ResourcesModel(BaseModel):
     warehouses: dict[str, WarehouseModel] = Field(default_factory=dict)
     databases: dict[str, DatabaseModel] = Field(default_factory=dict)
     connections: dict[str, ConnectionModel] = Field(default_factory=dict)
+    apps: dict[str, DatabricksAppModel] = Field(default_factory=dict)
+
+    @model_validator(mode="after")
+    def update_genie_warehouses(self) -> Self:
+        """
+        Automatically populate warehouses from genie_rooms.
+
+        Warehouses are extracted from each Genie room and added to the
+        resources if they don't already exist (based on warehouse_id).
+        """
+        if not self.genie_rooms:
+            return self
+
+        # Process warehouses from all genie rooms
+        for genie_room in self.genie_rooms.values():
+            genie_room: GenieRoomModel
+            warehouse: Optional[WarehouseModel] = genie_room.warehouse
+
+            if warehouse is None:
+                continue
+
+            # Check if warehouse already exists based on warehouse_id
+            warehouse_exists: bool = any(
+                existing_warehouse.warehouse_id == warehouse.warehouse_id
+                for existing_warehouse in self.warehouses.values()
+            )
+
+            if not warehouse_exists:
+                warehouse_key: str = normalize_name(
+                    "_".join([genie_room.name, warehouse.warehouse_id])
+                )
+                self.warehouses[warehouse_key] = warehouse
+                logger.trace(
+                    "Added warehouse from Genie room",
+                    room=genie_room.name,
+                    warehouse=warehouse.warehouse_id,
+                    key=warehouse_key,
+                )
+
+        return self
+
+    @model_validator(mode="after")
+    def update_genie_tables(self) -> Self:
+        """
+        Automatically populate tables from genie_rooms.
+
+        Tables are extracted from each Genie room and added to the
+        resources if they don't already exist (based on full_name).
+        """
+        if not self.genie_rooms:
+            return self
+
+        # Process tables from all genie rooms
+        for genie_room in self.genie_rooms.values():
+            genie_room: GenieRoomModel
+            for table in genie_room.tables:
+                table: TableModel
+                table_exists: bool = any(
+                    existing_table.full_name == table.full_name
+                    for existing_table in self.tables.values()
+                )
+                if not table_exists:
+                    table_key: str = normalize_name(
+                        "_".join([genie_room.name, table.full_name])
+                    )
+                    self.tables[table_key] = table
+                    logger.trace(
+                        "Added table from Genie room",
+                        room=genie_room.name,
+                        table=table.name,
+                        key=table_key,
+                    )
+
+        return self
+
+    @model_validator(mode="after")
+    def update_genie_functions(self) -> Self:
+        """
+        Automatically populate functions from genie_rooms.
+
+        Functions are extracted from each Genie room and added to the
+        resources if they don't already exist (based on full_name).
+        """
+        if not self.genie_rooms:
+            return self
+
+        # Process functions from all genie rooms
+        for genie_room in self.genie_rooms.values():
+            genie_room: GenieRoomModel
+            for function in genie_room.functions:
+                function: FunctionModel
+                function_exists: bool = any(
+                    existing_function.full_name == function.full_name
+                    for existing_function in self.functions.values()
+                )
+                if not function_exists:
+                    function_key: str = normalize_name(
+                        "_".join([genie_room.name, function.full_name])
+                    )
+                    self.functions[function_key] = function
+                    logger.trace(
+                        "Added function from Genie room",
+                        room=genie_room.name,
+                        function=function.name,
+                        key=function_key,
+                    )
+
+        return self
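The net effect of the three validators, sketched (sales_room stands in for a GenieRoomModel configured with a warehouse, tables, and functions):

resources = ResourcesModel(genie_rooms={"sales": sales_room})
# After validation, the room's warehouse, tables, and functions also appear in
# resources.warehouses / resources.tables / resources.functions under
# normalize_name("<room name>_<id>") keys; entries that already exist
# (matched by warehouse_id or full_name) are not duplicated.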
 
 
 class AppConfig(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     variables: dict[str, AnyVariable] = Field(default_factory=dict)
+    service_principals: dict[str, ServicePrincipalModel] = Field(default_factory=dict)
     schemas: dict[str, SchemaModel] = Field(default_factory=dict)
     resources: Optional[ResourcesModel] = None
     retrievers: dict[str, RetrieverModel] = Field(default_factory=dict)
     tools: dict[str, ToolModel] = Field(default_factory=dict)
     guardrails: dict[str, GuardrailModel] = Field(default_factory=dict)
+    middleware: dict[str, MiddlewareModel] = Field(default_factory=dict)
     memory: Optional[MemoryModel] = None
     prompts: dict[str, PromptModel] = Field(default_factory=dict)
     agents: dict[str, AgentModel] = Field(default_factory=dict)
     app: Optional[AppModel] = None
     evaluation: Optional[EvaluationModel] = None
+    optimizations: Optional[OptimizationsModel] = None
     datasets: Optional[list[DatasetModel]] = Field(default_factory=list)
     unity_catalog_functions: Optional[list[UnityCatalogFunctionSqlModel]] = Field(
         default_factory=list
@@ -1558,10 +3059,10 @@ class AppConfig(BaseModel):
 
     def initialize(self) -> None:
         from dao_ai.hooks.core import create_hooks
+        from dao_ai.logging import configure_logging
 
         if self.app and self.app.log_level:
-            logger.remove()
-            logger.add(sys.stderr, level=self.app.log_level)
+            configure_logging(level=self.app.log_level)
 
         logger.debug("Calling initialization hooks...")
         initialization_functions: Sequence[Callable[..., Any]] = create_hooks(
@@ -1605,21 +3106,45 @@ class AppConfig(BaseModel):
     def create_agent(
         self,
         w: WorkspaceClient | None = None,
+        vsc: "VectorSearchClient | None" = None,
+        pat: str | None = None,
+        client_id: str | None = None,
+        client_secret: str | None = None,
+        workspace_host: str | None = None,
     ) -> None:
         from dao_ai.providers.base import ServiceProvider
         from dao_ai.providers.databricks import DatabricksProvider
 
-        provider: ServiceProvider = DatabricksProvider(w=w)
+        provider: ServiceProvider = DatabricksProvider(
+            w=w,
+            vsc=vsc,
+            pat=pat,
+            client_id=client_id,
+            client_secret=client_secret,
+            workspace_host=workspace_host,
+        )
         provider.create_agent(self)
 
     def deploy_agent(
         self,
         w: WorkspaceClient | None = None,
+        vsc: "VectorSearchClient | None" = None,
+        pat: str | None = None,
+        client_id: str | None = None,
+        client_secret: str | None = None,
+        workspace_host: str | None = None,
     ) -> None:
         from dao_ai.providers.base import ServiceProvider
         from dao_ai.providers.databricks import DatabricksProvider
 
-        provider: ServiceProvider = DatabricksProvider(w=w)
+        provider: ServiceProvider = DatabricksProvider(
+            w=w,
+            vsc=vsc,
+            pat=pat,
+            client_id=client_id,
+            client_secret=client_secret,
+            workspace_host=workspace_host,
+        )
         provider.deploy_agent(self)
 
     def find_agents(