dao-ai 0.0.28__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +2 -5
  3. dao_ai/cli.py +245 -40
  4. dao_ai/config.py +1491 -370
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -253
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +125 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/core.py +67 -0
  25. dao_ai/middleware/guardrails.py +420 -0
  26. dao_ai/middleware/human_in_the_loop.py +232 -0
  27. dao_ai/middleware/message_validation.py +586 -0
  28. dao_ai/middleware/summarization.py +197 -0
  29. dao_ai/models.py +1306 -114
  30. dao_ai/nodes.py +245 -159
  31. dao_ai/optimization.py +674 -0
  32. dao_ai/orchestration/__init__.py +52 -0
  33. dao_ai/orchestration/core.py +294 -0
  34. dao_ai/orchestration/supervisor.py +278 -0
  35. dao_ai/orchestration/swarm.py +271 -0
  36. dao_ai/prompts.py +128 -31
  37. dao_ai/providers/databricks.py +573 -601
  38. dao_ai/state.py +157 -21
  39. dao_ai/tools/__init__.py +13 -5
  40. dao_ai/tools/agent.py +1 -3
  41. dao_ai/tools/core.py +64 -11
  42. dao_ai/tools/email.py +232 -0
  43. dao_ai/tools/genie.py +144 -294
  44. dao_ai/tools/mcp.py +223 -155
  45. dao_ai/tools/memory.py +50 -0
  46. dao_ai/tools/python.py +9 -14
  47. dao_ai/tools/search.py +14 -0
  48. dao_ai/tools/slack.py +22 -10
  49. dao_ai/tools/sql.py +202 -0
  50. dao_ai/tools/time.py +30 -7
  51. dao_ai/tools/unity_catalog.py +165 -88
  52. dao_ai/tools/vector_search.py +331 -221
  53. dao_ai/utils.py +166 -20
  54. dao_ai-0.1.2.dist-info/METADATA +455 -0
  55. dao_ai-0.1.2.dist-info/RECORD +64 -0
  56. dao_ai/chat_models.py +0 -204
  57. dao_ai/guardrails.py +0 -112
  58. dao_ai/tools/human_in_the_loop.py +0 -100
  59. dao_ai-0.0.28.dist-info/METADATA +0 -1168
  60. dao_ai-0.0.28.dist-info/RECORD +0 -41
  61. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +0 -0
  62. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
  63. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py CHANGED
@@ -12,6 +12,7 @@ from typing import (
12
12
  Iterator,
13
13
  Literal,
14
14
  Optional,
15
+ Self,
15
16
  Sequence,
16
17
  TypeAlias,
17
18
  Union,
@@ -22,13 +23,20 @@ from databricks.sdk.credentials_provider import (
22
23
  CredentialsStrategy,
23
24
  ModelServingUserCredentials,
24
25
  )
26
+ from databricks.sdk.errors.platform import NotFound
25
27
  from databricks.sdk.service.catalog import FunctionInfo, TableInfo
28
+ from databricks.sdk.service.dashboards import GenieSpace
26
29
  from databricks.sdk.service.database import DatabaseInstance
30
+ from databricks.sdk.service.sql import GetWarehouseResponse
27
31
  from databricks.vector_search.client import VectorSearchClient
28
32
  from databricks.vector_search.index import VectorSearchIndex
29
33
  from databricks_langchain import (
34
+ ChatDatabricks,
35
+ DatabricksEmbeddings,
30
36
  DatabricksFunctionClient,
31
37
  )
38
+ from langchain.agents.structured_output import ProviderStrategy, ToolStrategy
39
+ from langchain_core.embeddings import Embeddings
32
40
  from langchain_core.language_models import LanguageModelLike
33
41
  from langchain_core.messages import BaseMessage, messages_from_dict
34
42
  from langchain_core.runnables.base import RunnableLike
@@ -41,6 +49,7 @@ from mlflow.genai.datasets import EvaluationDataset, create_dataset, get_dataset
41
49
  from mlflow.genai.prompts import PromptVersion, load_prompt
42
50
  from mlflow.models import ModelConfig
43
51
  from mlflow.models.resources import (
52
+ DatabricksApp,
44
53
  DatabricksFunction,
45
54
  DatabricksGenieSpace,
46
55
  DatabricksLakebase,
@@ -59,10 +68,13 @@ from pydantic import (
59
68
  BaseModel,
60
69
  ConfigDict,
61
70
  Field,
71
+ PrivateAttr,
62
72
  field_serializer,
63
73
  model_validator,
64
74
  )
65
75
 
76
+ from dao_ai.utils import normalize_name
77
+
66
78
 
67
79
  class HasValue(ABC):
68
80
  @abstractmethod
@@ -81,27 +93,6 @@ class HasFullName(ABC):
81
93
  def full_name(self) -> str: ...
82
94
 
83
95
 
84
- class IsDatabricksResource(ABC):
85
- on_behalf_of_user: Optional[bool] = False
86
-
87
- @abstractmethod
88
- def as_resources(self) -> Sequence[DatabricksResource]: ...
89
-
90
- @property
91
- @abstractmethod
92
- def api_scopes(self) -> Sequence[str]: ...
93
-
94
- @property
95
- def workspace_client(self) -> WorkspaceClient:
96
- credentials_strategy: CredentialsStrategy = None
97
- if self.on_behalf_of_user:
98
- credentials_strategy = ModelServingUserCredentials()
99
- logger.debug(
100
- f"Creating WorkspaceClient with credentials strategy: {credentials_strategy}"
101
- )
102
- return WorkspaceClient(credentials_strategy=credentials_strategy)
103
-
104
-
105
96
  class EnvironmentVariableModel(BaseModel, HasValue):
106
97
  model_config = ConfigDict(
107
98
  frozen=True,
@@ -200,6 +191,162 @@ AnyVariable: TypeAlias = (
200
191
  )
201
192
 
202
193
 
194
+ class ServicePrincipalModel(BaseModel):
195
+ model_config = ConfigDict(
196
+ frozen=True,
197
+ use_enum_values=True,
198
+ )
199
+ client_id: AnyVariable
200
+ client_secret: AnyVariable
201
+
202
+
203
+ class IsDatabricksResource(ABC, BaseModel):
204
+ """
205
+ Base class for Databricks resources with authentication support.
206
+
207
+ Authentication Options:
208
+ ----------------------
209
+ 1. **On-Behalf-Of User (OBO)**: Set on_behalf_of_user=True to use the
210
+ calling user's identity via ModelServingUserCredentials.
211
+
212
+ 2. **Service Principal (OAuth M2M)**: Provide service_principal or
213
+ (client_id + client_secret + workspace_host) for service principal auth.
214
+
215
+ 3. **Personal Access Token (PAT)**: Provide pat (and optionally workspace_host)
216
+ to authenticate with a personal access token.
217
+
218
+ 4. **Ambient Authentication**: If no credentials provided, uses SDK defaults
219
+ (environment variables, notebook context, etc.)
220
+
221
+ Authentication Priority:
222
+ 1. OBO (on_behalf_of_user=True)
223
+ 2. Service Principal (client_id + client_secret + workspace_host)
224
+ 3. PAT (pat + workspace_host)
225
+ 4. Ambient/default authentication
226
+ """
227
+
228
+ model_config = ConfigDict(use_enum_values=True)
229
+
230
+ on_behalf_of_user: Optional[bool] = False
231
+ service_principal: Optional[ServicePrincipalModel] = None
232
+ client_id: Optional[AnyVariable] = None
233
+ client_secret: Optional[AnyVariable] = None
234
+ workspace_host: Optional[AnyVariable] = None
235
+ pat: Optional[AnyVariable] = None
236
+
237
+ # Private attribute to cache the workspace client (lazy instantiation)
238
+ _workspace_client: Optional[WorkspaceClient] = PrivateAttr(default=None)
239
+
240
+ @abstractmethod
241
+ def as_resources(self) -> Sequence[DatabricksResource]: ...
242
+
243
+ @property
244
+ @abstractmethod
245
+ def api_scopes(self) -> Sequence[str]: ...
246
+
247
+ @model_validator(mode="after")
248
+ def _expand_service_principal(self) -> Self:
249
+ """Expand service_principal into client_id and client_secret if provided."""
250
+ if self.service_principal is not None:
251
+ if self.client_id is None:
252
+ self.client_id = self.service_principal.client_id
253
+ if self.client_secret is None:
254
+ self.client_secret = self.service_principal.client_secret
255
+ return self
256
+
257
+ @model_validator(mode="after")
258
+ def _validate_auth_not_mixed(self) -> Self:
259
+ """Validate that OAuth and PAT authentication are not both provided."""
260
+ has_oauth: bool = self.client_id is not None and self.client_secret is not None
261
+ has_pat: bool = self.pat is not None
262
+
263
+ if has_oauth and has_pat:
264
+ raise ValueError(
265
+ "Cannot use both OAuth and user authentication methods. "
266
+ "Please provide either OAuth credentials or user credentials."
267
+ )
268
+ return self
269
+
270
+ @property
271
+ def workspace_client(self) -> WorkspaceClient:
272
+ """
273
+ Get a WorkspaceClient configured with the appropriate authentication.
274
+
275
+ The client is lazily instantiated on first access and cached for subsequent calls.
276
+
277
+ Authentication priority:
278
+ 1. If on_behalf_of_user is True, uses ModelServingUserCredentials (OBO)
279
+ 2. If service principal credentials are configured (client_id, client_secret,
280
+ workspace_host), uses OAuth M2M
281
+ 3. If PAT is configured, uses token authentication
282
+ 4. Otherwise, uses default/ambient authentication
283
+ """
284
+ # Return cached client if already instantiated
285
+ if self._workspace_client is not None:
286
+ return self._workspace_client
287
+
288
+ from dao_ai.utils import normalize_host
289
+
290
+ # Check for OBO first (highest priority)
291
+ if self.on_behalf_of_user:
292
+ credentials_strategy: CredentialsStrategy = ModelServingUserCredentials()
293
+ logger.debug(
294
+ f"Creating WorkspaceClient for {self.__class__.__name__} "
295
+ f"with OBO credentials strategy"
296
+ )
297
+ self._workspace_client = WorkspaceClient(
298
+ credentials_strategy=credentials_strategy
299
+ )
300
+ return self._workspace_client
301
+
302
+ # Check for service principal credentials
303
+ client_id_value: str | None = (
304
+ value_of(self.client_id) if self.client_id else None
305
+ )
306
+ client_secret_value: str | None = (
307
+ value_of(self.client_secret) if self.client_secret else None
308
+ )
309
+ workspace_host_value: str | None = (
310
+ normalize_host(value_of(self.workspace_host))
311
+ if self.workspace_host
312
+ else None
313
+ )
314
+
315
+ if client_id_value and client_secret_value and workspace_host_value:
316
+ logger.debug(
317
+ f"Creating WorkspaceClient for {self.__class__.__name__} with service principal: "
318
+ f"client_id={client_id_value}, host={workspace_host_value}"
319
+ )
320
+ self._workspace_client = WorkspaceClient(
321
+ host=workspace_host_value,
322
+ client_id=client_id_value,
323
+ client_secret=client_secret_value,
324
+ auth_type="oauth-m2m",
325
+ )
326
+ return self._workspace_client
327
+
328
+ # Check for PAT authentication
329
+ pat_value: str | None = value_of(self.pat) if self.pat else None
330
+ if pat_value:
331
+ logger.debug(
332
+ f"Creating WorkspaceClient for {self.__class__.__name__} with PAT"
333
+ )
334
+ self._workspace_client = WorkspaceClient(
335
+ host=workspace_host_value,
336
+ token=pat_value,
337
+ auth_type="pat",
338
+ )
339
+ return self._workspace_client
340
+
341
+ # Default: use ambient authentication
342
+ logger.debug(
343
+ f"Creating WorkspaceClient for {self.__class__.__name__} "
344
+ "with default/ambient authentication"
345
+ )
346
+ self._workspace_client = WorkspaceClient()
347
+ return self._workspace_client
348
+
349
+
203
350
  class Privilege(str, Enum):
204
351
  ALL_PRIVILEGES = "ALL_PRIVILEGES"
205
352
  USE_CATALOG = "USE_CATALOG"
@@ -226,9 +373,21 @@ class Privilege(str, Enum):
226
373
 
227
374
  class PermissionModel(BaseModel):
228
375
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
229
- principals: list[str] = Field(default_factory=list)
376
+ principals: list[ServicePrincipalModel | str] = Field(default_factory=list)
230
377
  privileges: list[Privilege]
231
378
 
379
+ @model_validator(mode="after")
380
+ def resolve_principals(self) -> Self:
381
+ """Resolve ServicePrincipalModel objects to their client_id."""
382
+ resolved: list[str] = []
383
+ for principal in self.principals:
384
+ if isinstance(principal, ServicePrincipalModel):
385
+ resolved.append(value_of(principal.client_id))
386
+ else:
387
+ resolved.append(principal)
388
+ self.principals = resolved
389
+ return self
390
+
232
391
 
233
392
  class SchemaModel(BaseModel, HasFullName):
234
393
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -248,7 +407,26 @@ class SchemaModel(BaseModel, HasFullName):
248
407
  provider.create_schema(self)
249
408
 
250
409
 
251
- class TableModel(BaseModel, HasFullName, IsDatabricksResource):
410
+ class DatabricksAppModel(IsDatabricksResource, HasFullName):
411
+ model_config = ConfigDict(use_enum_values=True, extra="forbid")
412
+ name: str
413
+ url: str
414
+
415
+ @property
416
+ def full_name(self) -> str:
417
+ return self.name
418
+
419
+ @property
420
+ def api_scopes(self) -> Sequence[str]:
421
+ return ["apps.apps"]
422
+
423
+ def as_resources(self) -> Sequence[DatabricksResource]:
424
+ return [
425
+ DatabricksApp(app_name=self.name, on_behalf_of_user=self.on_behalf_of_user)
426
+ ]
427
+
428
+
429
+ class TableModel(IsDatabricksResource, HasFullName):
252
430
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
253
431
  schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
254
432
  name: Optional[str] = None
@@ -274,6 +452,22 @@ class TableModel(BaseModel, HasFullName, IsDatabricksResource):
274
452
  def api_scopes(self) -> Sequence[str]:
275
453
  return []
276
454
 
455
+ def exists(self) -> bool:
456
+ """Check if the table exists in Unity Catalog.
457
+
458
+ Returns:
459
+ True if the table exists, False otherwise.
460
+ """
461
+ try:
462
+ self.workspace_client.tables.get(full_name=self.full_name)
463
+ return True
464
+ except NotFound:
465
+ logger.debug(f"Table not found: {self.full_name}")
466
+ return False
467
+ except Exception as e:
468
+ logger.warning(f"Error checking table existence for {self.full_name}: {e}")
469
+ return False
470
+
277
471
  def as_resources(self) -> Sequence[DatabricksResource]:
278
472
  resources: list[DatabricksResource] = []
279
473
 
@@ -317,12 +511,17 @@ class TableModel(BaseModel, HasFullName, IsDatabricksResource):
317
511
  return resources
318
512
 
319
513
 
320
- class LLMModel(BaseModel, IsDatabricksResource):
514
+ class LLMModel(IsDatabricksResource):
321
515
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
322
516
  name: str
517
+ description: Optional[str] = None
323
518
  temperature: Optional[float] = 0.1
324
519
  max_tokens: Optional[int] = 8192
325
520
  fallbacks: Optional[list[Union[str, "LLMModel"]]] = Field(default_factory=list)
521
+ use_responses_api: Optional[bool] = Field(
522
+ default=False,
523
+ description="Use Responses API for ResponsesAgent endpoints",
524
+ )
326
525
 
327
526
  @property
328
527
  def api_scopes(self) -> Sequence[str]:
@@ -342,19 +541,12 @@ class LLMModel(BaseModel, IsDatabricksResource):
342
541
  ]
343
542
 
344
543
  def as_chat_model(self) -> LanguageModelLike:
345
- # Retrieve langchain chat client from workspace client to enable OBO
346
- # ChatOpenAI does not allow additional inputs at the moment, so we cannot use it directly
347
- # chat_client: LanguageModelLike = self.as_open_ai_client()
348
-
349
- # Create ChatDatabricksWrapper instance directly
350
- from dao_ai.chat_models import ChatDatabricksFiltered
351
-
352
- chat_client: LanguageModelLike = ChatDatabricksFiltered(
353
- model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
544
+ chat_client: LanguageModelLike = ChatDatabricks(
545
+ model=self.name,
546
+ temperature=self.temperature,
547
+ max_tokens=self.max_tokens,
548
+ use_responses_api=self.use_responses_api,
354
549
  )
355
- # chat_client: LanguageModelLike = ChatDatabricks(
356
- # model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
357
- # )
358
550
 
359
551
  fallbacks: Sequence[LanguageModelLike] = []
360
552
  for fallback in self.fallbacks:
@@ -386,6 +578,9 @@ class LLMModel(BaseModel, IsDatabricksResource):
386
578
 
387
579
  return chat_client
388
580
 
581
+ def as_embeddings_model(self) -> Embeddings:
582
+ return DatabricksEmbeddings(endpoint=self.name)
583
+
389
584
 
390
585
  class VectorSearchEndpointType(str, Enum):
391
586
  STANDARD = "STANDARD"
@@ -405,7 +600,7 @@ class VectorSearchEndpoint(BaseModel):
405
600
  return str(value)
406
601
 
407
602
 
408
- class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
603
+ class IndexModel(IsDatabricksResource, HasFullName):
409
604
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
410
605
  schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
411
606
  name: str
@@ -430,12 +625,297 @@ class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
430
625
  ]
431
626
 
432
627
 
433
- class GenieRoomModel(BaseModel, IsDatabricksResource):
628
+ class FunctionModel(IsDatabricksResource, HasFullName):
629
+ model_config = ConfigDict(use_enum_values=True, extra="forbid")
630
+ schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
631
+ name: Optional[str] = None
632
+
633
+ @model_validator(mode="after")
634
+ def validate_name_or_schema_required(self) -> Self:
635
+ if not self.name and not self.schema_model:
636
+ raise ValueError(
637
+ "Either 'name' or 'schema_model' must be provided for FunctionModel"
638
+ )
639
+ return self
640
+
641
+ @property
642
+ def full_name(self) -> str:
643
+ if self.schema_model:
644
+ name: str = ""
645
+ if self.name:
646
+ name = f".{self.name}"
647
+ return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
648
+ return self.name
649
+
650
+ def exists(self) -> bool:
651
+ """Check if the function exists in Unity Catalog.
652
+
653
+ Returns:
654
+ True if the function exists, False otherwise.
655
+ """
656
+ try:
657
+ self.workspace_client.functions.get(name=self.full_name)
658
+ return True
659
+ except NotFound:
660
+ logger.debug(f"Function not found: {self.full_name}")
661
+ return False
662
+ except Exception as e:
663
+ logger.warning(
664
+ f"Error checking function existence for {self.full_name}: {e}"
665
+ )
666
+ return False
667
+
668
+ def as_resources(self) -> Sequence[DatabricksResource]:
669
+ resources: list[DatabricksResource] = []
670
+ if self.name:
671
+ resources.append(
672
+ DatabricksFunction(
673
+ function_name=self.full_name,
674
+ on_behalf_of_user=self.on_behalf_of_user,
675
+ )
676
+ )
677
+ else:
678
+ w: WorkspaceClient = self.workspace_client
679
+ schema_full_name: str = self.schema_model.full_name
680
+ functions: Iterator[FunctionInfo] = w.functions.list(
681
+ catalog_name=self.schema_model.catalog_name,
682
+ schema_name=self.schema_model.schema_name,
683
+ )
684
+ resources.extend(
685
+ [
686
+ DatabricksFunction(
687
+ function_name=f"{schema_full_name}.{function.name}",
688
+ on_behalf_of_user=self.on_behalf_of_user,
689
+ )
690
+ for function in functions
691
+ ]
692
+ )
693
+
694
+ return resources
695
+
696
+ @property
697
+ def api_scopes(self) -> Sequence[str]:
698
+ return ["sql.statement-execution"]
699
+
700
+
701
+ class WarehouseModel(IsDatabricksResource):
702
+ model_config = ConfigDict()
703
+ name: str
704
+ description: Optional[str] = None
705
+ warehouse_id: AnyVariable
706
+
707
+ @property
708
+ def api_scopes(self) -> Sequence[str]:
709
+ return [
710
+ "sql.warehouses",
711
+ "sql.statement-execution",
712
+ ]
713
+
714
+ def as_resources(self) -> Sequence[DatabricksResource]:
715
+ return [
716
+ DatabricksSQLWarehouse(
717
+ warehouse_id=value_of(self.warehouse_id),
718
+ on_behalf_of_user=self.on_behalf_of_user,
719
+ )
720
+ ]
721
+
722
+ @model_validator(mode="after")
723
+ def update_warehouse_id(self) -> Self:
724
+ self.warehouse_id = value_of(self.warehouse_id)
725
+ return self
726
+
727
+
728
+ class GenieRoomModel(IsDatabricksResource):
434
729
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
435
730
  name: str
436
731
  description: Optional[str] = None
437
732
  space_id: AnyVariable
438
733
 
734
+ _space_details: Optional[GenieSpace] = PrivateAttr(default=None)
735
+
736
+ def _get_space_details(self) -> GenieSpace:
737
+ if self._space_details is None:
738
+ self._space_details = self.workspace_client.genie.get_space(
739
+ space_id=self.space_id, include_serialized_space=True
740
+ )
741
+ return self._space_details
742
+
743
+ def _parse_serialized_space(self) -> dict[str, Any]:
744
+ """Parse the serialized_space JSON string and return the parsed data."""
745
+ import json
746
+
747
+ space_details = self._get_space_details()
748
+ if not space_details.serialized_space:
749
+ return {}
750
+
751
+ try:
752
+ return json.loads(space_details.serialized_space)
753
+ except json.JSONDecodeError as e:
754
+ logger.warning(f"Failed to parse serialized_space: {e}")
755
+ return {}
756
+
757
+ @property
758
+ def warehouse(self) -> Optional[WarehouseModel]:
759
+ """Extract warehouse information from the Genie space.
760
+
761
+ Returns:
762
+ WarehouseModel instance if warehouse_id is available, None otherwise.
763
+ """
764
+ space_details: GenieSpace = self._get_space_details()
765
+
766
+ if not space_details.warehouse_id:
767
+ return None
768
+
769
+ try:
770
+ response: GetWarehouseResponse = self.workspace_client.warehouses.get(
771
+ space_details.warehouse_id
772
+ )
773
+ warehouse_name: str = response.name or space_details.warehouse_id
774
+
775
+ warehouse_model = WarehouseModel(
776
+ name=warehouse_name,
777
+ warehouse_id=space_details.warehouse_id,
778
+ on_behalf_of_user=self.on_behalf_of_user,
779
+ service_principal=self.service_principal,
780
+ client_id=self.client_id,
781
+ client_secret=self.client_secret,
782
+ workspace_host=self.workspace_host,
783
+ pat=self.pat,
784
+ )
785
+
786
+ # Share the cached workspace client if available
787
+ if self._workspace_client is not None:
788
+ warehouse_model._workspace_client = self._workspace_client
789
+
790
+ return warehouse_model
791
+ except Exception as e:
792
+ logger.warning(
793
+ f"Failed to fetch warehouse details for {space_details.warehouse_id}: {e}"
794
+ )
795
+ return None
796
+
797
+ @property
798
+ def tables(self) -> list[TableModel]:
799
+ """Extract tables from the serialized Genie space.
800
+
801
+ Databricks Genie stores tables in: data_sources.tables[].identifier
802
+ Only includes tables that actually exist in Unity Catalog.
803
+ """
804
+ parsed_space = self._parse_serialized_space()
805
+ tables_list: list[TableModel] = []
806
+
807
+ # Primary structure: data_sources.tables with 'identifier' field
808
+ if "data_sources" in parsed_space:
809
+ data_sources = parsed_space["data_sources"]
810
+ if isinstance(data_sources, dict) and "tables" in data_sources:
811
+ tables_data = data_sources["tables"]
812
+ if isinstance(tables_data, list):
813
+ for table_item in tables_data:
814
+ table_name: str | None = None
815
+ if isinstance(table_item, dict):
816
+ # Standard Databricks structure uses 'identifier'
817
+ table_name = table_item.get("identifier") or table_item.get(
818
+ "name"
819
+ )
820
+ elif isinstance(table_item, str):
821
+ table_name = table_item
822
+
823
+ if table_name:
824
+ table_model = TableModel(
825
+ name=table_name,
826
+ on_behalf_of_user=self.on_behalf_of_user,
827
+ service_principal=self.service_principal,
828
+ client_id=self.client_id,
829
+ client_secret=self.client_secret,
830
+ workspace_host=self.workspace_host,
831
+ pat=self.pat,
832
+ )
833
+ # Share the cached workspace client if available
834
+ if self._workspace_client is not None:
835
+ table_model._workspace_client = self._workspace_client
836
+
837
+ # Verify the table exists before adding
838
+ if not table_model.exists():
839
+ continue
840
+
841
+ tables_list.append(table_model)
842
+
843
+ return tables_list
844
+
845
+ @property
846
+ def functions(self) -> list[FunctionModel]:
847
+ """Extract functions from the serialized Genie space.
848
+
849
+ Databricks Genie stores functions in multiple locations:
850
+ - instructions.sql_functions[].identifier (SQL functions)
851
+ - data_sources.functions[].identifier (other functions)
852
+ Only includes functions that actually exist in Unity Catalog.
853
+ """
854
+ parsed_space = self._parse_serialized_space()
855
+ functions_list: list[FunctionModel] = []
856
+ seen_functions: set[str] = set()
857
+
858
+ def add_function_if_exists(function_name: str) -> None:
859
+ """Helper to add a function if it exists and hasn't been added."""
860
+ if function_name in seen_functions:
861
+ return
862
+
863
+ seen_functions.add(function_name)
864
+ function_model = FunctionModel(
865
+ name=function_name,
866
+ on_behalf_of_user=self.on_behalf_of_user,
867
+ service_principal=self.service_principal,
868
+ client_id=self.client_id,
869
+ client_secret=self.client_secret,
870
+ workspace_host=self.workspace_host,
871
+ pat=self.pat,
872
+ )
873
+ # Share the cached workspace client if available
874
+ if self._workspace_client is not None:
875
+ function_model._workspace_client = self._workspace_client
876
+
877
+ # Verify the function exists before adding
878
+ if not function_model.exists():
879
+ return
880
+
881
+ functions_list.append(function_model)
882
+
883
+ # Primary structure: instructions.sql_functions with 'identifier' field
884
+ if "instructions" in parsed_space:
885
+ instructions = parsed_space["instructions"]
886
+ if isinstance(instructions, dict) and "sql_functions" in instructions:
887
+ sql_functions_data = instructions["sql_functions"]
888
+ if isinstance(sql_functions_data, list):
889
+ for function_item in sql_functions_data:
890
+ if isinstance(function_item, dict):
891
+ # SQL functions use 'identifier' field
892
+ function_name = function_item.get(
893
+ "identifier"
894
+ ) or function_item.get("name")
895
+ if function_name:
896
+ add_function_if_exists(function_name)
897
+
898
+ # Secondary structure: data_sources.functions with 'identifier' field
899
+ if "data_sources" in parsed_space:
900
+ data_sources = parsed_space["data_sources"]
901
+ if isinstance(data_sources, dict) and "functions" in data_sources:
902
+ functions_data = data_sources["functions"]
903
+ if isinstance(functions_data, list):
904
+ for function_item in functions_data:
905
+ function_name: str | None = None
906
+ if isinstance(function_item, dict):
907
+ # Standard Databricks structure uses 'identifier'
908
+ function_name = function_item.get(
909
+ "identifier"
910
+ ) or function_item.get("name")
911
+ elif isinstance(function_item, str):
912
+ function_name = function_item
913
+
914
+ if function_name:
915
+ add_function_if_exists(function_name)
916
+
917
+ return functions_list
918
+
439
919
  @property
440
920
  def api_scopes(self) -> Sequence[str]:
441
921
  return [
@@ -451,12 +931,24 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
451
931
  ]
452
932
 
453
933
  @model_validator(mode="after")
454
- def update_space_id(self):
934
+ def update_space_id(self) -> Self:
455
935
  self.space_id = value_of(self.space_id)
456
936
  return self
457
937
 
938
+ @model_validator(mode="after")
939
+ def update_description_from_space(self) -> Self:
940
+ """Populate description from GenieSpace if not provided."""
941
+ if not self.description:
942
+ try:
943
+ space_details = self._get_space_details()
944
+ if space_details.description:
945
+ self.description = space_details.description
946
+ except Exception as e:
947
+ logger.debug(f"Could not fetch description from Genie space: {e}")
948
+ return self
949
+
458
950
 
459
- class VolumeModel(BaseModel, HasFullName, IsDatabricksResource):
951
+ class VolumeModel(IsDatabricksResource, HasFullName):
460
952
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
461
953
  schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
462
954
  name: str
@@ -516,7 +1008,7 @@ class VolumePathModel(BaseModel, HasFullName):
516
1008
  provider.create_path(self)
517
1009
 
518
1010
 
519
- class VectorStoreModel(BaseModel, IsDatabricksResource):
1011
+ class VectorStoreModel(IsDatabricksResource):
520
1012
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
521
1013
  embedding_model: Optional[LLMModel] = None
522
1014
  index: Optional[IndexModel] = None
@@ -530,13 +1022,13 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
530
1022
  embedding_source_column: str
531
1023
 
532
1024
  @model_validator(mode="after")
533
- def set_default_embedding_model(self):
1025
+ def set_default_embedding_model(self) -> Self:
534
1026
  if not self.embedding_model:
535
1027
  self.embedding_model = LLMModel(name="databricks-gte-large-en")
536
1028
  return self
537
1029
 
538
1030
  @model_validator(mode="after")
539
- def set_default_primary_key(self):
1031
+ def set_default_primary_key(self) -> Self:
540
1032
  if self.primary_key is None:
541
1033
  from dao_ai.providers.databricks import DatabricksProvider
542
1034
 
@@ -557,14 +1049,14 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
557
1049
  return self
558
1050
 
559
1051
  @model_validator(mode="after")
560
- def set_default_index(self):
1052
+ def set_default_index(self) -> Self:
561
1053
  if self.index is None:
562
1054
  name: str = f"{self.source_table.name}_index"
563
1055
  self.index = IndexModel(schema=self.source_table.schema_model, name=name)
564
1056
  return self
565
1057
 
566
1058
  @model_validator(mode="after")
567
- def set_default_endpoint(self):
1059
+ def set_default_endpoint(self) -> Self:
568
1060
  if self.endpoint is None:
569
1061
  from dao_ai.providers.databricks import (
570
1062
  DatabricksProvider,
@@ -615,62 +1107,7 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
615
1107
  provider.create_vector_store(self)
616
1108
 
617
1109
 
618
- class FunctionModel(BaseModel, HasFullName, IsDatabricksResource):
619
- model_config = ConfigDict()
620
- schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
621
- name: Optional[str] = None
622
-
623
- @model_validator(mode="after")
624
- def validate_name_or_schema_required(self) -> "FunctionModel":
625
- if not self.name and not self.schema_model:
626
- raise ValueError(
627
- "Either 'name' or 'schema_model' must be provided for FunctionModel"
628
- )
629
- return self
630
-
631
- @property
632
- def full_name(self) -> str:
633
- if self.schema_model:
634
- name: str = ""
635
- if self.name:
636
- name = f".{self.name}"
637
- return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
638
- return self.name
639
-
640
- def as_resources(self) -> Sequence[DatabricksResource]:
641
- resources: list[DatabricksResource] = []
642
- if self.name:
643
- resources.append(
644
- DatabricksFunction(
645
- function_name=self.full_name,
646
- on_behalf_of_user=self.on_behalf_of_user,
647
- )
648
- )
649
- else:
650
- w: WorkspaceClient = self.workspace_client
651
- schema_full_name: str = self.schema_model.full_name
652
- functions: Iterator[FunctionInfo] = w.functions.list(
653
- catalog_name=self.schema_model.catalog_name,
654
- schema_name=self.schema_model.schema_name,
655
- )
656
- resources.extend(
657
- [
658
- DatabricksFunction(
659
- function_name=f"{schema_full_name}.{function.name}",
660
- on_behalf_of_user=self.on_behalf_of_user,
661
- )
662
- for function in functions
663
- ]
664
- )
665
-
666
- return resources
667
-
668
- @property
669
- def api_scopes(self) -> Sequence[str]:
670
- return ["sql.statement-execution"]
671
-
672
-
673
- class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
1110
+ class ConnectionModel(IsDatabricksResource, HasFullName):
674
1111
  model_config = ConfigDict()
675
1112
  name: str
676
1113
 
@@ -697,34 +1134,58 @@ class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
697
1134
  ]
698
1135
 
699
1136
 
700
- class WarehouseModel(BaseModel, IsDatabricksResource):
701
- model_config = ConfigDict()
702
- name: str
703
- description: Optional[str] = None
704
- warehouse_id: AnyVariable
705
-
706
- @property
707
- def api_scopes(self) -> Sequence[str]:
708
- return [
709
- "sql.warehouses",
710
- "sql.statement-execution",
711
- ]
712
-
713
- def as_resources(self) -> Sequence[DatabricksResource]:
714
- return [
715
- DatabricksSQLWarehouse(
716
- warehouse_id=value_of(self.warehouse_id),
717
- on_behalf_of_user=self.on_behalf_of_user,
718
- )
719
- ]
720
-
721
- @model_validator(mode="after")
722
- def update_warehouse_id(self):
723
- self.warehouse_id = value_of(self.warehouse_id)
724
- return self
725
-
1137
+ class DatabaseModel(IsDatabricksResource):
1138
+ """
1139
+ Configuration for database connections supporting both Databricks Lakebase and standard PostgreSQL.
1140
+
1141
+ Authentication is inherited from IsDatabricksResource. Additionally supports:
1142
+ - user/password: For user-based database authentication
1143
+
1144
+ Connection Types (determined by fields provided):
1145
+ - Databricks Lakebase: Provide `instance_name` (authentication optional, supports ambient auth)
1146
+ - Standard PostgreSQL: Provide `host` (authentication required via user/password)
1147
+
1148
+ Note: `instance_name` and `host` are mutually exclusive. Provide one or the other.
1149
+
1150
+ Example Databricks Lakebase with Service Principal:
1151
+ ```yaml
1152
+ databases:
1153
+ my_lakebase:
1154
+ name: my-database
1155
+ instance_name: my-lakebase-instance
1156
+ service_principal:
1157
+ client_id:
1158
+ env: SERVICE_PRINCIPAL_CLIENT_ID
1159
+ client_secret:
1160
+ scope: my-scope
1161
+ secret: sp-client-secret
1162
+ workspace_host:
1163
+ env: DATABRICKS_HOST
1164
+ ```
1165
+
1166
+ Example Databricks Lakebase with Ambient Authentication:
1167
+ ```yaml
1168
+ databases:
1169
+ my_lakebase:
1170
+ name: my-database
1171
+ instance_name: my-lakebase-instance
1172
+ on_behalf_of_user: true
1173
+ ```
1174
+
1175
+ Example Standard PostgreSQL:
1176
+ ```yaml
1177
+ databases:
1178
+ my_postgres:
1179
+ name: my-database
1180
+ host: my-postgres-host.example.com
1181
+ port: 5432
1182
+ database: my_db
1183
+ user: my_user
1184
+ password:
1185
+ env: PGPASSWORD
1186
+ ```
1187
+ """
726
1188
 
727
- class DatabaseModel(BaseModel, IsDatabricksResource):
728
1189
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
729
1190
  name: str
730
1191
  instance_name: Optional[str] = None
@@ -737,80 +1198,137 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
737
1198
  timeout_seconds: Optional[int] = 10
738
1199
  capacity: Optional[Literal["CU_1", "CU_2"]] = "CU_2"
739
1200
  node_count: Optional[int] = None
1201
+ # Database-specific auth (user identity for DB connection)
740
1202
  user: Optional[AnyVariable] = None
741
1203
  password: Optional[AnyVariable] = None
742
- client_id: Optional[AnyVariable] = None
743
- client_secret: Optional[AnyVariable] = None
744
- workspace_host: Optional[AnyVariable] = None
745
1204
 
746
1205
  @property
747
1206
  def api_scopes(self) -> Sequence[str]:
748
- return []
1207
+ return ["database.database-instances"]
1208
+
1209
+ @property
1210
+ def is_lakebase(self) -> bool:
1211
+ """Returns True if this is a Databricks Lakebase connection (instance_name provided)."""
1212
+ return self.instance_name is not None
749
1213
 
750
1214
  def as_resources(self) -> Sequence[DatabricksResource]:
751
- return [
752
- DatabricksLakebase(
753
- database_instance_name=self.instance_name,
754
- on_behalf_of_user=self.on_behalf_of_user,
755
- )
756
- ]
1215
+ if self.is_lakebase:
1216
+ return [
1217
+ DatabricksLakebase(
1218
+ database_instance_name=self.instance_name,
1219
+ on_behalf_of_user=self.on_behalf_of_user,
1220
+ )
1221
+ ]
1222
+ return []
757
1223
 
758
1224
  @model_validator(mode="after")
759
- def update_instance_name(self):
760
- if self.instance_name is None:
761
- self.instance_name = self.name
1225
+ def validate_connection_type(self) -> Self:
1226
+ """Validate connection configuration based on type.
762
1227
 
1228
+ - If instance_name is provided: Databricks Lakebase connection
1229
+ (host is optional - will be fetched from API if not provided)
1230
+ - If only host is provided: Standard PostgreSQL connection
1231
+ (must not have instance_name)
1232
+ """
1233
+ if not self.instance_name and not self.host:
1234
+ raise ValueError(
1235
+ "Either instance_name (Databricks Lakebase) or host (PostgreSQL) must be provided."
1236
+ )
763
1237
  return self
764
1238
 
765
1239
  @model_validator(mode="after")
766
- def update_user(self):
767
- if self.client_id or self.user:
1240
+ def update_user(self) -> Self:
1241
+ # Skip if using OBO (passive auth), explicit credentials, or explicit user
1242
+ if self.on_behalf_of_user or self.client_id or self.user or self.pat:
768
1243
  return self
769
1244
 
770
- self.user = self.workspace_client.current_user.me().user_name
771
- if not self.user:
772
- raise ValueError(
773
- "Unable to determine current user. Please provide a user name or OAuth credentials."
774
- )
1245
+ # For standard PostgreSQL, we need explicit user credentials
1246
+ # For Lakebase with no auth, ambient auth is allowed
1247
+ if not self.is_lakebase:
1248
+ # Standard PostgreSQL - try to determine current user for local development
1249
+ try:
1250
+ self.user = self.workspace_client.current_user.me().user_name
1251
+ except Exception as e:
1252
+ logger.warning(
1253
+ f"Could not determine current user for PostgreSQL database: {e}. "
1254
+ f"Please provide explicit user credentials."
1255
+ )
1256
+ else:
1257
+ # For Lakebase, try to determine current user but don't fail if we can't
1258
+ try:
1259
+ self.user = self.workspace_client.current_user.me().user_name
1260
+ except Exception:
1261
+ # If we can't determine user and no explicit auth, that's okay
1262
+ # for Lakebase with ambient auth - credentials will be injected at runtime
1263
+ pass
775
1264
 
776
1265
  return self
777
1266
 
778
1267
  @model_validator(mode="after")
779
- def update_host(self):
1268
+ def update_host(self) -> Self:
780
1269
  if self.host is not None:
781
1270
  return self
782
1271
 
783
- existing_instance: DatabaseInstance = (
784
- self.workspace_client.database.get_database_instance(
785
- name=self.instance_name
786
- )
787
- )
788
- self.host = existing_instance.read_write_dns
1272
+ # If instance_name is provided (Lakebase), try to fetch host from existing instance
1273
+ # This may fail for OBO/ambient auth during model logging (before deployment)
1274
+ if self.is_lakebase:
1275
+ try:
1276
+ existing_instance: DatabaseInstance = (
1277
+ self.workspace_client.database.get_database_instance(
1278
+ name=self.instance_name
1279
+ )
1280
+ )
1281
+ self.host = existing_instance.read_write_dns
1282
+ except Exception as e:
1283
+ # For Lakebase with OBO/ambient auth, we can't fetch at config time
1284
+ # The host will need to be provided explicitly or fetched at runtime
1285
+ if self.on_behalf_of_user:
1286
+ logger.debug(
1287
+ f"Could not fetch host for database {self.instance_name} "
1288
+ f"(Lakebase with OBO mode - will be resolved at runtime): {e}"
1289
+ )
1290
+ else:
1291
+ raise ValueError(
1292
+ f"Could not fetch host for database {self.instance_name}. "
1293
+ f"Please provide the 'host' explicitly or ensure the instance exists: {e}"
1294
+ )
789
1295
  return self
790
1296
 
791
1297
  @model_validator(mode="after")
792
- def validate_auth_methods(self):
1298
+ def validate_auth_methods(self) -> Self:
793
1299
  oauth_fields: Sequence[Any] = [
794
1300
  self.workspace_host,
795
1301
  self.client_id,
796
1302
  self.client_secret,
797
1303
  ]
798
1304
  has_oauth: bool = all(field is not None for field in oauth_fields)
1305
+ has_user_auth: bool = self.user is not None
1306
+ has_obo: bool = self.on_behalf_of_user is True
1307
+ has_pat: bool = self.pat is not None
799
1308
 
800
- pat_fields: Sequence[Any] = [self.user]
801
- has_user_auth: bool = all(field is not None for field in pat_fields)
1309
+ # Count how many auth methods are configured
1310
+ auth_methods_count: int = sum([has_oauth, has_user_auth, has_obo, has_pat])
802
1311
 
803
- if has_oauth and has_user_auth:
1312
+ if auth_methods_count > 1:
804
1313
  raise ValueError(
805
- "Cannot use both OAuth and user authentication methods. "
806
- "Please provide either OAuth credentials or user credentials."
1314
+ "Cannot mix authentication methods. "
1315
+ "Please provide exactly one of: "
1316
+ "on_behalf_of_user=true (for passive auth in model serving), "
1317
+ "OAuth credentials (service_principal or client_id + client_secret + workspace_host), "
1318
+ "PAT (personal access token), "
1319
+ "or user credentials (user)."
807
1320
  )
808
1321
 
809
- if not has_oauth and not has_user_auth:
1322
+ # For standard PostgreSQL (host-based), at least one auth method must be configured
1323
+ # For Lakebase (instance_name-based), auth is optional (supports ambient authentication)
1324
+ if not self.is_lakebase and auth_methods_count == 0:
810
1325
  raise ValueError(
811
- "At least one authentication method must be provided: "
812
- "either OAuth credentials (workspace_host, client_id, client_secret) "
813
- "or user credentials (user, password)."
1326
+ "PostgreSQL databases require explicit authentication. "
1327
+ "Please provide one of: "
1328
+ "OAuth credentials (workspace_host, client_id, client_secret), "
1329
+ "service_principal with workspace_host, "
1330
+ "PAT (personal access token), "
1331
+ "or user credentials (user)."
814
1332
  )
815
1333
 
816
1334
  return self
@@ -821,38 +1339,76 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
821
1339
  Get database connection parameters as a dictionary.
822
1340
 
823
1341
  Returns a dict with connection parameters suitable for psycopg ConnectionPool.
824
- If username is configured, it will be included; otherwise it will be omitted
825
- to allow Lakebase to authenticate using the token's identity.
1342
+
1343
+ For Lakebase: Uses Databricks-generated credentials (token-based auth).
1344
+ For standard PostgreSQL: Uses provided user/password credentials.
826
1345
  """
827
- from dao_ai.providers.base import ServiceProvider
828
- from dao_ai.providers.databricks import DatabricksProvider
1346
+ import uuid as _uuid
1347
+
1348
+ from databricks.sdk.service.database import DatabaseCredential
829
1349
 
1350
+ host: str
1351
+ port: int
1352
+ database: str
830
1353
  username: str | None = None
1354
+ password_value: str | None = None
1355
+
1356
+ # Resolve host - may need to fetch at runtime for OBO mode
1357
+ host_value: Any = self.host
1358
+ if host_value is None and self.is_lakebase and self.on_behalf_of_user:
1359
+ # Fetch host at runtime for OBO mode
1360
+ existing_instance: DatabaseInstance = (
1361
+ self.workspace_client.database.get_database_instance(
1362
+ name=self.instance_name
1363
+ )
1364
+ )
1365
+ host_value = existing_instance.read_write_dns
831
1366
 
832
- if self.client_id and self.client_secret and self.workspace_host:
833
- username = value_of(self.client_id)
834
- elif self.user:
835
- username = value_of(self.user)
1367
+ if host_value is None:
1368
+ instance_or_name = self.instance_name if self.is_lakebase else self.name
1369
+ raise ValueError(
1370
+ f"Database host not configured for {instance_or_name}. "
1371
+ "Please provide 'host' explicitly."
1372
+ )
836
1373
 
837
- host: str = value_of(self.host)
838
- port: int = value_of(self.port)
839
- database: str = value_of(self.database)
1374
+ host = value_of(host_value)
1375
+ port = value_of(self.port)
1376
+ database = value_of(self.database)
840
1377
 
841
- provider: ServiceProvider = DatabricksProvider(
842
- client_id=value_of(self.client_id),
843
- client_secret=value_of(self.client_secret),
844
- workspace_host=value_of(self.workspace_host),
845
- pat=value_of(self.password),
846
- )
1378
+ if self.is_lakebase:
1379
+ # Lakebase: Use Databricks-generated credentials
1380
+ if self.client_id and self.client_secret and self.workspace_host:
1381
+ username = value_of(self.client_id)
1382
+ elif self.user:
1383
+ username = value_of(self.user)
1384
+ # For OBO mode, no username is needed - the token identity is used
1385
+
1386
+ # Generate Databricks database credential (token)
1387
+ w: WorkspaceClient = self.workspace_client
1388
+ cred: DatabaseCredential = w.database.generate_database_credential(
1389
+ request_id=str(_uuid.uuid4()),
1390
+ instance_names=[self.instance_name],
1391
+ )
1392
+ password_value = cred.token
1393
+ else:
1394
+ # Standard PostgreSQL: Use provided credentials
1395
+ if self.user:
1396
+ username = value_of(self.user)
1397
+ if self.password:
1398
+ password_value = value_of(self.password)
847
1399
 
848
- token: str = provider.lakebase_password_provider(self.instance_name)
1400
+ if not username or not password_value:
1401
+ raise ValueError(
1402
+ f"Standard PostgreSQL databases require both 'user' and 'password'. "
1403
+ f"Database: {self.name}"
1404
+ )
849
1405
 
850
1406
  # Build connection parameters dictionary
851
1407
  params: dict[str, Any] = {
852
1408
  "dbname": database,
853
1409
  "host": host,
854
1410
  "port": port,
855
- "password": token,
1411
+ "password": password_value,
856
1412
  "sslmode": "require",
857
1413
  }
858
1414
 
@@ -883,11 +1439,86 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
883
1439
  def create(self, w: WorkspaceClient | None = None) -> None:
884
1440
  from dao_ai.providers.databricks import DatabricksProvider
885
1441
 
886
- provider: DatabricksProvider = DatabricksProvider()
1442
+ # Use provided workspace client or fall back to resource's own workspace_client
1443
+ if w is None:
1444
+ w = self.workspace_client
1445
+ provider: DatabricksProvider = DatabricksProvider(w=w)
887
1446
  provider.create_lakebase(self)
888
1447
  provider.create_lakebase_instance_role(self)
889
1448
 
890
1449
 
1450
+ class GenieLRUCacheParametersModel(BaseModel):
1451
+ model_config = ConfigDict(use_enum_values=True, extra="forbid")
1452
+ capacity: int = 1000
1453
+ time_to_live_seconds: int | None = (
1454
+ 60 * 60 * 24
1455
+ ) # 1 day default, None or negative = never expires
1456
+ warehouse: WarehouseModel
1457
+
1458
+
1459
+ class GenieSemanticCacheParametersModel(BaseModel):
1460
+ model_config = ConfigDict(use_enum_values=True, extra="forbid")
1461
+ time_to_live_seconds: int | None = (
1462
+ 60 * 60 * 24
1463
+ ) # 1 day default, None or negative = never expires
1464
+ similarity_threshold: float = 0.85 # Minimum similarity for question matching (L2 distance converted to 0-1 scale)
1465
+ context_similarity_threshold: float = 0.80 # Minimum similarity for context matching (L2 distance converted to 0-1 scale)
1466
+ question_weight: Optional[float] = (
1467
+ 0.6 # Weight for question similarity in combined score (0-1). If not provided, computed as 1 - context_weight
1468
+ )
1469
+ context_weight: Optional[float] = (
1470
+ None # Weight for context similarity in combined score (0-1). If not provided, computed as 1 - question_weight
1471
+ )
1472
+ embedding_model: str | LLMModel = "databricks-gte-large-en"
1473
+ embedding_dims: int | None = None # Auto-detected if None
1474
+ database: DatabaseModel
1475
+ warehouse: WarehouseModel
1476
+ table_name: str = "genie_semantic_cache"
1477
+ context_window_size: int = 3 # Number of previous turns to include for context
1478
+ max_context_tokens: int = (
1479
+ 2000 # Maximum context length to prevent extremely long embeddings
1480
+ )
1481
+
1482
+ @model_validator(mode="after")
1483
+ def compute_and_validate_weights(self) -> Self:
1484
+ """
1485
+ Compute missing weight and validate that question_weight + context_weight = 1.0.
1486
+
1487
+ Either question_weight or context_weight (or both) can be provided.
1488
+ The missing one will be computed as 1.0 - provided_weight.
1489
+ If both are provided, they must sum to 1.0.
1490
+ """
1491
+ if self.question_weight is None and self.context_weight is None:
1492
+ # Both missing - use defaults
1493
+ self.question_weight = 0.6
1494
+ self.context_weight = 0.4
1495
+ elif self.question_weight is None:
1496
+ # Compute question_weight from context_weight
1497
+ if not (0.0 <= self.context_weight <= 1.0):
1498
+ raise ValueError(
1499
+ f"context_weight must be between 0.0 and 1.0, got {self.context_weight}"
1500
+ )
1501
+ self.question_weight = 1.0 - self.context_weight
1502
+ elif self.context_weight is None:
1503
+ # Compute context_weight from question_weight
1504
+ if not (0.0 <= self.question_weight <= 1.0):
1505
+ raise ValueError(
1506
+ f"question_weight must be between 0.0 and 1.0, got {self.question_weight}"
1507
+ )
1508
+ self.context_weight = 1.0 - self.question_weight
1509
+ else:
1510
+ # Both provided - validate they sum to 1.0
1511
+ total_weight = self.question_weight + self.context_weight
1512
+ if not abs(total_weight - 1.0) < 0.0001: # Allow small floating point error
1513
+ raise ValueError(
1514
+ f"question_weight ({self.question_weight}) + context_weight ({self.context_weight}) "
1515
+ f"must equal 1.0 (got {total_weight}). These weights determine the relative importance "
1516
+ f"of question vs context similarity in the combined score."
1517
+ )
1518
+
1519
+ return self
1520
+
1521
+
891
1522
  class SearchParametersModel(BaseModel):
892
1523
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
893
1524
  num_results: Optional[int] = 10
@@ -936,8 +1567,8 @@ class RerankParametersModel(BaseModel):
936
1567
  description="Number of documents to return after reranking. If None, uses search_parameters.num_results.",
937
1568
  )
938
1569
  cache_dir: Optional[str] = Field(
939
- default="/tmp/flashrank_cache",
940
- description="Directory to cache downloaded model weights.",
1570
+ default="~/.dao_ai/cache/flashrank",
1571
+ description="Directory to cache downloaded model weights. Supports tilde expansion (e.g., ~/.dao_ai).",
941
1572
  )
942
1573
  columns: Optional[list[str]] = Field(
943
1574
  default_factory=list, description="Columns to rerank using DatabricksReranker"
@@ -957,14 +1588,14 @@ class RetrieverModel(BaseModel):
957
1588
  )
958
1589
 
959
1590
  @model_validator(mode="after")
960
- def set_default_columns(self):
1591
+ def set_default_columns(self) -> Self:
961
1592
  if not self.columns:
962
1593
  columns: Sequence[str] = self.vector_store.columns
963
1594
  self.columns = columns
964
1595
  return self
965
1596
 
966
1597
  @model_validator(mode="after")
967
- def set_default_reranker(self):
1598
+ def set_default_reranker(self) -> Self:
968
1599
  """Convert bool to ReRankParametersModel with defaults."""
969
1600
  if isinstance(self.rerank, bool) and self.rerank:
970
1601
  self.rerank = RerankParametersModel()
@@ -978,28 +1609,47 @@ class FunctionType(str, Enum):
978
1609
  MCP = "mcp"
979
1610
 
980
1611
 
981
- class HumanInTheLoopActionType(str, Enum):
982
- """Supported action types for human-in-the-loop interactions."""
1612
+ class HumanInTheLoopModel(BaseModel):
1613
+ """
1614
+ Configuration for Human-in-the-Loop tool approval.
983
1615
 
984
- ACCEPT = "accept"
985
- EDIT = "edit"
986
- RESPONSE = "response"
987
- DECLINE = "decline"
1616
+ This model configures when and how tools require human approval before execution.
1617
+ It maps to LangChain's HumanInTheLoopMiddleware.
988
1618
 
1619
+ LangChain supports three decision types:
1620
+ - "approve": Execute tool with original arguments
1621
+ - "edit": Modify arguments before execution
1622
+ - "reject": Skip execution with optional feedback message
1623
+ """
989
1624
 
990
- class HumanInTheLoopModel(BaseModel):
991
1625
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
992
- review_prompt: str = "Please review the tool call"
993
- interrupt_config: dict[str, Any] = Field(
994
- default_factory=lambda: {
995
- "allow_accept": True,
996
- "allow_edit": True,
997
- "allow_respond": True,
998
- "allow_decline": True,
999
- }
1626
+
1627
+ review_prompt: Optional[str] = Field(
1628
+ default=None,
1629
+ description="Message shown to the reviewer when approval is requested",
1630
+ )
1631
+
1632
+ allowed_decisions: list[Literal["approve", "edit", "reject"]] = Field(
1633
+ default_factory=lambda: ["approve", "edit", "reject"],
1634
+ description="List of allowed decision types for this tool",
1000
1635
  )
1001
- decline_message: str = "Tool call declined by user"
1002
- custom_actions: Optional[dict[str, str]] = Field(default_factory=dict)
1636
+
1637
+ @model_validator(mode="after")
1638
+ def validate_and_normalize_decisions(self) -> Self:
1639
+ """Validate and normalize allowed decisions."""
1640
+ if not self.allowed_decisions:
1641
+ raise ValueError("At least one decision type must be allowed")
1642
+
1643
+ # Remove duplicates while preserving order
1644
+ seen = set()
1645
+ unique_decisions = []
1646
+ for decision in self.allowed_decisions:
1647
+ if decision not in seen:
1648
+ seen.add(decision)
1649
+ unique_decisions.append(decision)
1650
+ self.allowed_decisions = unique_decisions
1651
+
1652
+ return self
1003
1653
 
1004
1654
 
1005
1655
  class BaseFunctionModel(ABC, BaseModel):
@@ -1008,7 +1658,6 @@ class BaseFunctionModel(ABC, BaseModel):
1008
1658
  discriminator="type",
1009
1659
  )
1010
1660
  type: FunctionType
1011
- name: str
1012
1661
  human_in_the_loop: Optional[HumanInTheLoopModel] = None
1013
1662
 
1014
1663
  @abstractmethod
@@ -1025,6 +1674,7 @@ class BaseFunctionModel(ABC, BaseModel):
1025
1674
  class PythonFunctionModel(BaseFunctionModel, HasFullName):
1026
1675
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1027
1676
  type: Literal[FunctionType.PYTHON] = FunctionType.PYTHON
1677
+ name: str
1028
1678
 
1029
1679
  @property
1030
1680
  def full_name(self) -> str:
@@ -1038,8 +1688,9 @@ class PythonFunctionModel(BaseFunctionModel, HasFullName):
1038
1688
 
1039
1689
  class FactoryFunctionModel(BaseFunctionModel, HasFullName):
1040
1690
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1041
- args: Optional[dict[str, Any]] = Field(default_factory=dict)
1042
1691
  type: Literal[FunctionType.FACTORY] = FunctionType.FACTORY
1692
+ name: str
1693
+ args: Optional[dict[str, Any]] = Field(default_factory=dict)
1043
1694
 
1044
1695
  @property
1045
1696
  def full_name(self) -> str:
@@ -1051,7 +1702,7 @@ class FactoryFunctionModel(BaseFunctionModel, HasFullName):
1051
1702
  return [create_factory_tool(self, **kwargs)]
1052
1703
 
1053
1704
  @model_validator(mode="after")
1054
- def update_args(self):
1705
+ def update_args(self) -> Self:
1055
1706
  for key, value in self.args.items():
1056
1707
  self.args[key] = value_of(value)
1057
1708
  return self
@@ -1062,7 +1713,16 @@ class TransportType(str, Enum):
1062
1713
  STDIO = "stdio"
1063
1714
 
1064
1715
 
1065
- class McpFunctionModel(BaseFunctionModel, HasFullName):
1716
+ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
1717
+ """
1718
+ MCP Function Model with authentication inherited from IsDatabricksResource.
1719
+
1720
+ Authentication for MCP connections uses the same options as other resources:
1721
+ - Service Principal (client_id + client_secret + workspace_host)
1722
+ - PAT (pat + workspace_host)
1723
+ - OBO (on_behalf_of_user)
1724
+ """
1725
+
1066
1726
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1067
1727
  type: Literal[FunctionType.MCP] = FunctionType.MCP
1068
1728
  transport: TransportType = TransportType.STREAMABLE_HTTP
@@ -1070,10 +1730,7 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
1070
1730
  url: Optional[AnyVariable] = None
1071
1731
  headers: dict[str, AnyVariable] = Field(default_factory=dict)
1072
1732
  args: list[str] = Field(default_factory=list)
1073
- pat: Optional[AnyVariable] = None
1074
- client_id: Optional[AnyVariable] = None
1075
- client_secret: Optional[AnyVariable] = None
1076
- workspace_host: Optional[AnyVariable] = None
1733
+ # MCP-specific fields
1077
1734
  connection: Optional[ConnectionModel] = None
1078
1735
  functions: Optional[SchemaModel] = None
1079
1736
  genie_room: Optional[GenieRoomModel] = None
@@ -1081,35 +1738,55 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
1081
1738
  vector_search: Optional[VectorStoreModel] = None
1082
1739
 
1083
1740
  @property
1084
- def full_name(self) -> str:
1085
- return self.name
1741
+ def api_scopes(self) -> Sequence[str]:
1742
+ """API scopes for MCP connections."""
1743
+ return [
1744
+ "serving.serving-endpoints",
1745
+ "mcp.genie",
1746
+ "mcp.functions",
1747
+ "mcp.vectorsearch",
1748
+ "mcp.external",
1749
+ ]
1750
+
1751
+ def as_resources(self) -> Sequence[DatabricksResource]:
1752
+ """MCP functions don't declare static resources."""
1753
+ return []
1086
1754
 
1087
1755
  def _get_workspace_host(self) -> str:
1088
1756
  """
1089
1757
  Get the workspace host, either from config or from workspace client.
1090
1758
 
1091
1759
  If connection is provided, uses its workspace client.
1092
- Otherwise, falls back to creating a new workspace client.
1760
+ Otherwise, falls back to the default Databricks host.
1093
1761
 
1094
1762
  Returns:
1095
- str: The workspace host URL without trailing slash
1763
+ str: The workspace host URL with https:// scheme and without trailing slash
1096
1764
  """
1097
- from databricks.sdk import WorkspaceClient
1765
+ from dao_ai.utils import get_default_databricks_host, normalize_host
1098
1766
 
1099
1767
  # Try to get workspace_host from config
1100
1768
  workspace_host: str | None = (
1101
- value_of(self.workspace_host) if self.workspace_host else None
1769
+ normalize_host(value_of(self.workspace_host))
1770
+ if self.workspace_host
1771
+ else None
1102
1772
  )
1103
1773
 
1104
1774
  # If no workspace_host in config, get it from workspace client
1105
1775
  if not workspace_host:
1106
1776
  # Use connection's workspace client if available
1107
1777
  if self.connection:
1108
- workspace_host = self.connection.workspace_client.config.host
1778
+ workspace_host = normalize_host(
1779
+ self.connection.workspace_client.config.host
1780
+ )
1109
1781
  else:
1110
- # Create a default workspace client
1111
- w: WorkspaceClient = WorkspaceClient()
1112
- workspace_host = w.config.host
1782
+ # get_default_databricks_host already normalizes the host
1783
+ workspace_host = get_default_databricks_host()
1784
+
1785
+ if not workspace_host:
1786
+ raise ValueError(
1787
+ "Could not determine workspace host. "
1788
+ "Please set workspace_host in config or DATABRICKS_HOST environment variable."
1789
+ )
1113
1790
 
1114
1791
  # Remove trailing slash
1115
1792
  return workspace_host.rstrip("/")
@@ -1234,74 +1911,132 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
1234
1911
  self.headers[key] = value_of(value)
1235
1912
  return self
1236
1913
 
1237
- @model_validator(mode="after")
1238
- def validate_auth_methods(self) -> "McpFunctionModel":
1239
- oauth_fields: Sequence[Any] = [
1240
- self.client_id,
1241
- self.client_secret,
1242
- ]
1243
- has_oauth: bool = all(field is not None for field in oauth_fields)
1244
-
1245
- pat_fields: Sequence[Any] = [self.pat]
1246
- has_user_auth: bool = all(field is not None for field in pat_fields)
1914
+ def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
1915
+ from dao_ai.tools import create_mcp_tools
1247
1916
 
1248
- if has_oauth and has_user_auth:
1249
- raise ValueError(
1250
- "Cannot use both OAuth and user authentication methods. "
1251
- "Please provide either OAuth credentials or user credentials."
1252
- )
1917
+ return create_mcp_tools(self)
1253
1918
 
1254
- # Note: workspace_host is optional - it will be derived from workspace client if not provided
1255
1919
 
1256
- return self
1920
+ class UnityCatalogFunctionModel(BaseFunctionModel):
1921
+ model_config = ConfigDict(use_enum_values=True, extra="forbid")
1922
+ type: Literal[FunctionType.UNITY_CATALOG] = FunctionType.UNITY_CATALOG
1923
+ resource: FunctionModel
1924
+ partial_args: Optional[dict[str, AnyVariable]] = Field(default_factory=dict)
1257
1925
 
1258
1926
  def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
1259
- from dao_ai.tools import create_mcp_tools
1927
+ from dao_ai.tools import create_uc_tools
1260
1928
 
1261
- return create_mcp_tools(self)
1929
+ return create_uc_tools(self)
1930
+
1931
+
1932
+ AnyTool: TypeAlias = (
1933
+ Union[
1934
+ PythonFunctionModel,
1935
+ FactoryFunctionModel,
1936
+ UnityCatalogFunctionModel,
1937
+ McpFunctionModel,
1938
+ ]
1939
+ | str
1940
+ )
1941
+
1942
+
1943
+ class ToolModel(BaseModel):
1944
+ model_config = ConfigDict(use_enum_values=True, extra="forbid")
1945
+ name: str
1946
+ function: AnyTool
1262
1947
 
1263
1948
 
1264
- class UnityCatalogFunctionModel(BaseFunctionModel, HasFullName):
1949
+ class PromptModel(BaseModel, HasFullName):
1265
1950
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1266
1951
  schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
1267
- partial_args: Optional[dict[str, AnyVariable]] = Field(default_factory=dict)
1268
- type: Literal[FunctionType.UNITY_CATALOG] = FunctionType.UNITY_CATALOG
1952
+ name: str
1953
+ description: Optional[str] = None
1954
+ default_template: Optional[str] = None
1955
+ alias: Optional[str] = None
1956
+ version: Optional[int] = None
1957
+ tags: Optional[dict[str, Any]] = Field(default_factory=dict)
1958
+ auto_register: bool = Field(
1959
+ default=False,
1960
+ description="Whether to automatically register the default_template to the prompt registry. "
1961
+ "If False, the prompt will only be loaded from the registry (never created/updated). "
1962
+ "Defaults to True for backward compatibility.",
1963
+ )
1964
+
1965
+ @property
1966
+ def template(self) -> str:
1967
+ from dao_ai.providers.databricks import DatabricksProvider
1968
+
1969
+ provider: DatabricksProvider = DatabricksProvider()
1970
+ prompt_version = provider.get_prompt(self)
1971
+ return prompt_version.to_single_brace_format()
1972
+
1973
+ @property
1974
+ def full_name(self) -> str:
1975
+ prompt_name: str = self.name
1976
+ if self.schema_model:
1977
+ prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
1978
+ return prompt_name
1269
1979
 
1270
1980
  @property
1271
- def full_name(self) -> str:
1272
- if self.schema_model:
1273
- return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}.{self.name}"
1274
- return self.name
1981
+ def uri(self) -> str:
1982
+ prompt_uri: str = f"prompts:/{self.full_name}"
1275
1983
 
1276
- def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
1277
- from dao_ai.tools import create_uc_tools
1984
+ if self.alias:
1985
+ prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
1986
+ elif self.version:
1987
+ prompt_uri = f"prompts:/{self.full_name}/{self.version}"
1988
+ else:
1989
+ prompt_uri = f"prompts:/{self.full_name}@latest"
1278
1990
 
1279
- return create_uc_tools(self)
1991
+ return prompt_uri
1280
1992
 
1993
+ def as_prompt(self) -> PromptVersion:
1994
+ prompt_version: PromptVersion = load_prompt(self.uri)
1995
+ return prompt_version
1281
1996
 
1282
- AnyTool: TypeAlias = (
1283
- Union[
1284
- PythonFunctionModel,
1285
- FactoryFunctionModel,
1286
- UnityCatalogFunctionModel,
1287
- McpFunctionModel,
1288
- ]
1289
- | str
1290
- )
1997
+ @model_validator(mode="after")
1998
+ def validate_mutually_exclusive(self) -> Self:
1999
+ if self.alias and self.version:
2000
+ raise ValueError("Cannot specify both alias and version")
2001
+ return self
1291
2002
 
1292
2003
 
1293
- class ToolModel(BaseModel):
2004
+ class GuardrailModel(BaseModel):
1294
2005
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1295
2006
  name: str
1296
- function: AnyTool
2007
+ model: str | LLMModel
2008
+ prompt: str | PromptModel
2009
+ num_retries: Optional[int] = 3
2010
+
2011
+ @model_validator(mode="after")
2012
+ def validate_llm_model(self) -> Self:
2013
+ if isinstance(self.model, str):
2014
+ self.model = LLMModel(name=self.model)
2015
+ return self
1297
2016
 
1298
2017
 
1299
- class GuardrailModel(BaseModel):
2018
+ class MiddlewareModel(BaseModel):
2019
+ """Configuration for middleware that can be applied to agents.
2020
+
2021
+ Middleware is defined at the AppConfig level and can be referenced by name
2022
+ in agent configurations using YAML anchors for reusability.
2023
+ """
2024
+
1300
2025
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1301
- name: str
1302
- model: LLMModel
1303
- prompt: str
1304
- num_retries: Optional[int] = 3
2026
+ name: str = Field(
2027
+ description="Fully qualified name of the middleware factory function"
2028
+ )
2029
+ args: dict[str, Any] = Field(
2030
+ default_factory=dict,
2031
+ description="Arguments to pass to the middleware factory function",
2032
+ )
2033
+
2034
+ @model_validator(mode="after")
2035
+ def resolve_args(self) -> Self:
2036
+ """Resolve any variable references in args."""
2037
+ for key, value in self.args.items():
2038
+ self.args[key] = value_of(value)
2039
+ return self
1305
2040
 
1306
2041
 
1307
2042
  class StorageType(str, Enum):
@@ -1312,14 +2047,12 @@ class StorageType(str, Enum):
1312
2047
  class CheckpointerModel(BaseModel):
1313
2048
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1314
2049
  name: str
1315
- type: Optional[StorageType] = StorageType.MEMORY
1316
2050
  database: Optional[DatabaseModel] = None
1317
2051
 
1318
- @model_validator(mode="after")
1319
- def validate_postgres_requires_database(self):
1320
- if self.type == StorageType.POSTGRES and not self.database:
1321
- raise ValueError("Database must be provided when storage type is POSTGRES")
1322
- return self
2052
+ @property
2053
+ def storage_type(self) -> StorageType:
2054
+ """Infer storage type from database presence."""
2055
+ return StorageType.POSTGRES if self.database else StorageType.MEMORY
1323
2056
 
1324
2057
  def as_checkpointer(self) -> BaseCheckpointSaver:
1325
2058
  from dao_ai.memory import CheckpointManager
@@ -1335,16 +2068,14 @@ class StoreModel(BaseModel):
1335
2068
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1336
2069
  name: str
1337
2070
  embedding_model: Optional[LLMModel] = None
1338
- type: Optional[StorageType] = StorageType.MEMORY
1339
2071
  dims: Optional[int] = 1536
1340
2072
  database: Optional[DatabaseModel] = None
1341
2073
  namespace: Optional[str] = None
1342
2074
 
1343
- @model_validator(mode="after")
1344
- def validate_postgres_requires_database(self):
1345
- if self.type == StorageType.POSTGRES and not self.database:
1346
- raise ValueError("Database must be provided when storage type is POSTGRES")
1347
- return self
2075
+ @property
2076
+ def storage_type(self) -> StorageType:
2077
+ """Infer storage type from database presence."""
2078
+ return StorageType.POSTGRES if self.database else StorageType.MEMORY
1348
2079
 
1349
2080
  def as_store(self) -> BaseStore:
1350
2081
  from dao_ai.memory import StoreManager
@@ -1362,56 +2093,158 @@ class MemoryModel(BaseModel):
1362
2093
  FunctionHook: TypeAlias = PythonFunctionModel | FactoryFunctionModel | str
1363
2094
 
1364
2095
 
1365
- class PromptModel(BaseModel, HasFullName):
2096
class ResponseFormatModel(BaseModel):
    """
    Configuration for structured response formats.

    The response_schema field accepts either a type or a string:
    - Type (Pydantic model, dataclass, etc.): Used directly for structured output
    - String: First attempts to load as a fully qualified type name, falls back to JSON schema string

    This unified approach simplifies the API while maintaining flexibility.
    """

    model_config = ConfigDict(use_enum_values=True, extra="forbid")
    use_tool: Optional[bool] = Field(
        default=None,
        description=(
            "Strategy for structured output: "
            "None (default) = auto-detect from model capabilities, "
            "False = force ProviderStrategy (native), "
            "True = force ToolStrategy (function calling)"
        ),
    )
    response_schema: Optional[str | type] = Field(
        default=None,
        description="Type or string for response format. String attempts FQN import, falls back to JSON schema.",
    )

    def _wrap_schema(self, schema: Any) -> Any:
        """Apply the use_tool policy to a resolved schema.

        None -> return the schema unchanged (LangChain auto-detects),
        True -> force ToolStrategy (function calling),
        False -> force ProviderStrategy (native structured output).
        """
        if self.use_tool is None:
            return schema
        if self.use_tool:
            return ToolStrategy(schema)
        return ProviderStrategy(schema)

    def as_strategy(self) -> ProviderStrategy | ToolStrategy:
        """
        Convert response_schema to appropriate LangChain strategy.

        Returns:
            - None if no response_schema configured
            - Raw schema/type for auto-detection (when use_tool=None)
            - ToolStrategy wrapping the schema (when use_tool=True)
            - ProviderStrategy wrapping the schema (when use_tool=False)

        Raises:
            ValueError: If response_schema is a JSON schema string that cannot be parsed
        """
        if self.response_schema is None:
            return None

        # Type schemas (Pydantic model, dataclass, etc.) go straight through
        # the use_tool policy.
        if self.is_type_schema:
            return self._wrap_schema(self.response_schema)

        # JSON schema strings must parse to a dict before the same policy applies.
        if self.is_json_schema:
            import json

            try:
                schema_dict = json.loads(self.response_schema)
            except json.JSONDecodeError as e:
                raise ValueError(f"Invalid JSON schema string: {e}") from e
            return self._wrap_schema(schema_dict)

        return None

    @model_validator(mode="after")
    def validate_response_schema(self) -> Self:
        """
        Validate and convert response_schema.

        Processing logic:
        1. If None: no response format specified
        2. If type: use directly as structured output type
        3. If str: try to load as FQN using type_from_fqn
           - Success: response_schema becomes the loaded type
           - Failure: keep as string (treated as JSON schema)

        After validation, response_schema is one of:
        - None (no schema)
        - type (use for structured output)
        - str (JSON schema)

        Returns:
            Self with validated response_schema
        """
        if self.response_schema is None:
            return self

        # Already a concrete type: nothing to resolve.
        if isinstance(self.response_schema, type):
            return self

        # A string: try FQN import first, otherwise treat as JSON schema.
        if isinstance(self.response_schema, str):
            from dao_ai.utils import type_from_fqn

            try:
                resolved_type = type_from_fqn(self.response_schema)
                self.response_schema = resolved_type
                logger.debug(
                    f"Resolved response_schema string to type: {resolved_type}"
                )
                return self
            except (ValueError, ImportError, AttributeError, TypeError) as e:
                # Keep as string - it's a JSON schema
                logger.debug(
                    f"Could not resolve '{self.response_schema}' as type: {e}. "
                    f"Treating as JSON schema string."
                )
                return self

        # Invalid type
        raise ValueError(
            f"response_schema must be None, type, or str, got {type(self.response_schema)}"
        )

    @property
    def is_type_schema(self) -> bool:
        """Returns True if response_schema is a type (not JSON schema string)."""
        return isinstance(self.response_schema, type)

    @property
    def is_json_schema(self) -> bool:
        """Returns True if response_schema is a JSON schema string (not a type)."""
        return isinstance(self.response_schema, str)
1412
2235
 
1413
2236
 
1414
2237
  class AgentModel(BaseModel):
2238
+ """
2239
+ Configuration model for an agent in the DAO AI framework.
2240
+
2241
+ Agents combine an LLM with tools and middleware to create systems that can
2242
+ reason about tasks, decide which tools to use, and iteratively work towards solutions.
2243
+
2244
+ Middleware replaces the previous pre_agent_hook and post_agent_hook patterns,
2245
+ providing a more flexible and composable way to customize agent behavior.
2246
+ """
2247
+
1415
2248
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1416
2249
  name: str
1417
2250
  description: Optional[str] = None
@@ -1420,9 +2253,43 @@ class AgentModel(BaseModel):
1420
2253
  guardrails: list[GuardrailModel] = Field(default_factory=list)
1421
2254
  prompt: Optional[str | PromptModel] = None
1422
2255
  handoff_prompt: Optional[str] = None
1423
- create_agent_hook: Optional[FunctionHook] = None
1424
- pre_agent_hook: Optional[FunctionHook] = None
1425
- post_agent_hook: Optional[FunctionHook] = None
2256
+ middleware: list[MiddlewareModel] = Field(
2257
+ default_factory=list,
2258
+ description="List of middleware to apply to this agent",
2259
+ )
2260
+ response_format: Optional[ResponseFormatModel | type | str] = None
2261
+
2262
+ @model_validator(mode="after")
2263
+ def validate_response_format(self) -> Self:
2264
+ """
2265
+ Validate and normalize response_format.
2266
+
2267
+ Accepts:
2268
+ - None (no response format)
2269
+ - ResponseFormatModel (already validated)
2270
+ - type (Pydantic model, dataclass, etc.) - converts to ResponseFormatModel
2271
+ - str (FQN or json_schema) - converts to ResponseFormatModel (smart fallback)
2272
+
2273
+ ResponseFormatModel handles the logic of trying FQN import and falling back to JSON schema.
2274
+ """
2275
+ if self.response_format is None or isinstance(
2276
+ self.response_format, ResponseFormatModel
2277
+ ):
2278
+ return self
2279
+
2280
+ # Convert type or str to ResponseFormatModel
2281
+ # ResponseFormatModel's validator will handle the smart type loading and fallback
2282
+ if isinstance(self.response_format, (type, str)):
2283
+ self.response_format = ResponseFormatModel(
2284
+ response_schema=self.response_format
2285
+ )
2286
+ return self
2287
+
2288
+ # Invalid type
2289
+ raise ValueError(
2290
+ f"response_format must be None, ResponseFormatModel, type, or str, "
2291
+ f"got {type(self.response_format)}"
2292
+ )
1426
2293
 
1427
2294
  def as_runnable(self) -> RunnableLike:
1428
2295
  from dao_ai.nodes import create_agent_node
@@ -1441,12 +2308,20 @@ class SupervisorModel(BaseModel):
1441
2308
  model: LLMModel
1442
2309
  tools: list[ToolModel] = Field(default_factory=list)
1443
2310
  prompt: Optional[str] = None
2311
+ middleware: list[MiddlewareModel] = Field(
2312
+ default_factory=list,
2313
+ description="List of middleware to apply to the supervisor",
2314
+ )
1444
2315
 
1445
2316
 
1446
2317
  class SwarmModel(BaseModel):
1447
2318
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1448
2319
  model: LLMModel
1449
2320
  default_agent: Optional[AgentModel | str] = None
2321
+ middleware: list[MiddlewareModel] = Field(
2322
+ default_factory=list,
2323
+ description="List of middleware to apply to all agents in the swarm",
2324
+ )
1450
2325
  handoffs: Optional[dict[str, Optional[list[AgentModel | str]]]] = Field(
1451
2326
  default_factory=dict
1452
2327
  )
@@ -1459,7 +2334,7 @@ class OrchestrationModel(BaseModel):
1459
2334
  memory: Optional[MemoryModel] = None
1460
2335
 
1461
2336
  @model_validator(mode="after")
1462
- def validate_mutually_exclusive(self):
2337
+ def validate_mutually_exclusive(self) -> Self:
1463
2338
  if self.supervisor is not None and self.swarm is not None:
1464
2339
  raise ValueError("Cannot specify both supervisor and swarm")
1465
2340
  if self.supervisor is None and self.swarm is None:
@@ -1489,9 +2364,21 @@ class Entitlement(str, Enum):
1489
2364
 
1490
2365
class AppPermissionModel(BaseModel):
    """Permission grant: a set of entitlements for a list of principals."""

    model_config = ConfigDict(use_enum_values=True, extra="forbid")
    principals: list[ServicePrincipalModel | str] = Field(default_factory=list)
    entitlements: list[Entitlement]

    @model_validator(mode="after")
    def resolve_principals(self) -> Self:
        """Resolve ServicePrincipalModel objects to their client_id."""
        self.principals = [
            value_of(entry.client_id)
            if isinstance(entry, ServicePrincipalModel)
            else entry
            for entry in self.principals
        ]
        return self
2381
+
1495
2382
 
1496
2383
  class LogLevel(str, Enum):
1497
2384
  TRACE = "TRACE"
@@ -1552,6 +2439,28 @@ class ChatPayload(BaseModel):
1552
2439
 
1553
2440
  return self
1554
2441
 
2442
+ @model_validator(mode="after")
2443
+ def ensure_thread_id(self) -> "ChatPayload":
2444
+ """Ensure thread_id or conversation_id is present in configurable, generating UUID if needed."""
2445
+ import uuid
2446
+
2447
+ if self.custom_inputs is None:
2448
+ self.custom_inputs = {}
2449
+
2450
+ # Get or create configurable section
2451
+ configurable: dict[str, Any] = self.custom_inputs.get("configurable", {})
2452
+
2453
+ # Check if thread_id or conversation_id exists
2454
+ has_thread_id = configurable.get("thread_id") is not None
2455
+ has_conversation_id = configurable.get("conversation_id") is not None
2456
+
2457
+ # If neither is provided, generate a UUID for conversation_id
2458
+ if not has_thread_id and not has_conversation_id:
2459
+ configurable["conversation_id"] = str(uuid.uuid4())
2460
+ self.custom_inputs["configurable"] = configurable
2461
+
2462
+ return self
2463
+
1555
2464
  def as_messages(self) -> Sequence[BaseMessage]:
1556
2465
  return messages_from_dict(
1557
2466
  [{"type": m.role, "content": m.content} for m in self.messages]
@@ -1567,25 +2476,44 @@ class ChatPayload(BaseModel):
1567
2476
 
1568
2477
 
1569
2478
class ChatHistoryModel(BaseModel):
    """
    Configuration for chat history summarization.

    Attributes:
        model: The LLM used to generate conversation summaries.
        max_tokens: The "keep" threshold — after summarization, the most
            recent messages totaling up to this many tokens are preserved.
        max_tokens_before_summary: Token threshold that triggers summarization.
            Mutually exclusive with max_messages_before_summary. If neither
            trigger is configured, it defaults to max_tokens * 10.
        max_messages_before_summary: Message-count threshold that triggers
            summarization. Mutually exclusive with max_tokens_before_summary.
    """

    model_config = ConfigDict(use_enum_values=True, extra="forbid")
    model: LLMModel
    max_tokens: int = Field(
        default=2048,
        gt=0,
        description="Maximum tokens to keep after summarization",
    )
    max_tokens_before_summary: Optional[int] = Field(
        default=None,
        gt=0,
        description="Token threshold that triggers summarization",
    )
    max_messages_before_summary: Optional[int] = Field(
        default=None,
        gt=0,
        description="Message count threshold that triggers summarization",
    )
1584
2511
 
1585
2512
 
1586
2513
  class AppModel(BaseModel):
1587
2514
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1588
2515
  name: str
2516
+ service_principal: Optional[ServicePrincipalModel] = None
1589
2517
  description: Optional[str] = None
1590
2518
  log_level: Optional[LogLevel] = "WARNING"
1591
2519
  registered_model: RegisteredModelModel
@@ -1606,23 +2534,54 @@ class AppModel(BaseModel):
1606
2534
  shutdown_hooks: Optional[FunctionHook | list[FunctionHook]] = Field(
1607
2535
  default_factory=list
1608
2536
  )
1609
- message_hooks: Optional[FunctionHook | list[FunctionHook]] = Field(
1610
- default_factory=list
1611
- )
1612
2537
  input_example: Optional[ChatPayload] = None
1613
2538
  chat_history: Optional[ChatHistoryModel] = None
1614
2539
  code_paths: list[str] = Field(default_factory=list)
1615
2540
  pip_requirements: list[str] = Field(default_factory=list)
2541
+ python_version: Optional[str] = Field(
2542
+ default="3.12",
2543
+ description="Python version for Model Serving deployment. Defaults to 3.12 "
2544
+ "which is supported by Databricks Model Serving. This allows deploying from "
2545
+ "environments with different Python versions (e.g., Databricks Apps with 3.11).",
2546
+ )
2547
+
2548
+ @model_validator(mode="after")
2549
+ def set_databricks_env_vars(self) -> Self:
2550
+ """Set Databricks environment variables for Model Serving.
2551
+
2552
+ Sets DATABRICKS_HOST, DATABRICKS_CLIENT_ID, and DATABRICKS_CLIENT_SECRET.
2553
+ Values explicitly provided in environment_vars take precedence.
2554
+ """
2555
+ from dao_ai.utils import get_default_databricks_host
2556
+
2557
+ # Set DATABRICKS_HOST if not already provided
2558
+ if "DATABRICKS_HOST" not in self.environment_vars:
2559
+ host: str | None = get_default_databricks_host()
2560
+ if host:
2561
+ self.environment_vars["DATABRICKS_HOST"] = host
2562
+
2563
+ # Set service principal credentials if provided
2564
+ if self.service_principal is not None:
2565
+ if "DATABRICKS_CLIENT_ID" not in self.environment_vars:
2566
+ self.environment_vars["DATABRICKS_CLIENT_ID"] = (
2567
+ self.service_principal.client_id
2568
+ )
2569
+ if "DATABRICKS_CLIENT_SECRET" not in self.environment_vars:
2570
+ self.environment_vars["DATABRICKS_CLIENT_SECRET"] = (
2571
+ self.service_principal.client_secret
2572
+ )
2573
+ return self
1616
2574
 
1617
2575
  @model_validator(mode="after")
1618
- def validate_agents_not_empty(self):
2576
+ def validate_agents_not_empty(self) -> Self:
1619
2577
  if not self.agents:
1620
2578
  raise ValueError("At least one agent must be specified")
1621
2579
  return self
1622
2580
 
1623
2581
  @model_validator(mode="after")
1624
- def update_environment_vars(self):
2582
+ def resolve_environment_vars(self) -> Self:
1625
2583
  for key, value in self.environment_vars.items():
2584
+ updated_value: str
1626
2585
  if isinstance(value, SecretVariableModel):
1627
2586
  updated_value = str(value)
1628
2587
  else:
@@ -1632,7 +2591,7 @@ class AppModel(BaseModel):
1632
2591
  return self
1633
2592
 
1634
2593
  @model_validator(mode="after")
1635
- def set_default_orchestration(self):
2594
+ def set_default_orchestration(self) -> Self:
1636
2595
  if self.orchestration is None:
1637
2596
  if len(self.agents) > 1:
1638
2597
  default_agent: AgentModel = self.agents[0]
@@ -1652,14 +2611,14 @@ class AppModel(BaseModel):
1652
2611
  return self
1653
2612
 
1654
2613
  @model_validator(mode="after")
1655
- def set_default_endpoint_name(self):
2614
+ def set_default_endpoint_name(self) -> Self:
1656
2615
  if self.endpoint_name is None:
1657
2616
  self.endpoint_name = self.name
1658
2617
  return self
1659
2618
 
1660
2619
  @model_validator(mode="after")
1661
- def set_default_agent(self):
1662
- default_agent_name = self.agents[0].name
2620
+ def set_default_agent(self) -> Self:
2621
+ default_agent_name: str = self.agents[0].name
1663
2622
 
1664
2623
  if self.orchestration.swarm and not self.orchestration.swarm.default_agent:
1665
2624
  self.orchestration.swarm.default_agent = default_agent_name
@@ -1667,7 +2626,7 @@ class AppModel(BaseModel):
1667
2626
  return self
1668
2627
 
1669
2628
  @model_validator(mode="after")
1670
- def add_code_paths_to_sys_path(self):
2629
+ def add_code_paths_to_sys_path(self) -> Self:
1671
2630
  for code_path in self.code_paths:
1672
2631
  parent_path: str = str(Path(code_path).parent)
1673
2632
  if parent_path not in sys.path:
@@ -1700,7 +2659,7 @@ class EvaluationDatasetExpectationsModel(BaseModel):
1700
2659
  expected_facts: Optional[list[str]] = None
1701
2660
 
1702
2661
  @model_validator(mode="after")
1703
- def validate_mutually_exclusive(self):
2662
+ def validate_mutually_exclusive(self) -> Self:
1704
2663
  if self.expected_response is not None and self.expected_facts is not None:
1705
2664
  raise ValueError("Cannot specify both expected_response and expected_facts")
1706
2665
  return self
@@ -1779,36 +2738,70 @@ class EvaluationDatasetModel(BaseModel, HasFullName):
1779
2738
 
1780
2739
 
1781
2740
  class PromptOptimizationModel(BaseModel):
2741
+ """Configuration for prompt optimization using GEPA.
2742
+
2743
+ GEPA (Generative Evolution of Prompts and Agents) is an evolutionary
2744
+ optimizer that uses reflective mutation to improve prompts based on
2745
+ evaluation feedback.
2746
+
2747
+ Example:
2748
+ prompt_optimization:
2749
+ name: optimize_my_prompt
2750
+ prompt: *my_prompt
2751
+ agent: *my_agent
2752
+ dataset: *my_training_dataset
2753
+ reflection_model: databricks-meta-llama-3-3-70b-instruct
2754
+ num_candidates: 50
2755
+ """
2756
+
1782
2757
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1783
2758
  name: str
1784
2759
  prompt: Optional[PromptModel] = None
1785
2760
  agent: AgentModel
1786
- dataset: (
1787
- EvaluationDatasetModel | str
1788
- ) # Reference to dataset name (looked up in OptimizationsModel.training_datasets or MLflow)
2761
+ dataset: EvaluationDatasetModel # Training dataset with examples
1789
2762
  reflection_model: Optional[LLMModel | str] = None
1790
2763
  num_candidates: Optional[int] = 50
1791
- scorer_model: Optional[LLMModel | str] = None
1792
2764
 
1793
2765
  def optimize(self, w: WorkspaceClient | None = None) -> PromptModel:
1794
2766
  """
1795
- Optimize the prompt using MLflow's prompt optimization.
2767
+ Optimize the prompt using GEPA.
1796
2768
 
1797
2769
  Args:
1798
- w: Optional WorkspaceClient for Databricks operations
2770
+ w: Optional WorkspaceClient (not used, kept for API compatibility)
1799
2771
 
1800
2772
  Returns:
1801
- PromptModel: The optimized prompt model with new URI
2773
+ PromptModel: The optimized prompt model
1802
2774
  """
1803
- from dao_ai.providers.base import ServiceProvider
1804
- from dao_ai.providers.databricks import DatabricksProvider
2775
+ from dao_ai.optimization import OptimizationResult, optimize_prompt
1805
2776
 
1806
- provider: ServiceProvider = DatabricksProvider(w=w)
1807
- optimized_prompt: PromptModel = provider.optimize_prompt(self)
1808
- return optimized_prompt
2777
+ # Get reflection model name
2778
+ reflection_model_name: str | None = None
2779
+ if self.reflection_model:
2780
+ if isinstance(self.reflection_model, str):
2781
+ reflection_model_name = self.reflection_model
2782
+ else:
2783
+ reflection_model_name = self.reflection_model.uri
2784
+
2785
+ # Ensure prompt is set
2786
+ prompt = self.prompt
2787
+ if prompt is None:
2788
+ raise ValueError(
2789
+ f"Prompt optimization '{self.name}' requires a prompt to be set"
2790
+ )
2791
+
2792
+ result: OptimizationResult = optimize_prompt(
2793
+ prompt=prompt,
2794
+ agent=self.agent,
2795
+ dataset=self.dataset,
2796
+ reflection_model=reflection_model_name,
2797
+ num_candidates=self.num_candidates or 50,
2798
+ register_if_improved=True,
2799
+ )
2800
+
2801
+ return result.optimized_prompt
1809
2802
 
1810
2803
  @model_validator(mode="after")
1811
- def set_defaults(self):
2804
+ def set_defaults(self) -> Self:
1812
2805
  # If no prompt is specified, try to use the agent's prompt
1813
2806
  if self.prompt is None:
1814
2807
  if isinstance(self.agent.prompt, PromptModel):
@@ -1819,12 +2812,6 @@ class PromptOptimizationModel(BaseModel):
1819
2812
  f"or an agent with a prompt configured"
1820
2813
  )
1821
2814
 
1822
- if self.reflection_model is None:
1823
- self.reflection_model = self.agent.model
1824
-
1825
- if self.scorer_model is None:
1826
- self.scorer_model = self.agent.model
1827
-
1828
2815
  return self
1829
2816
 
1830
2817
 
@@ -1897,7 +2884,7 @@ class UnityCatalogFunctionSqlTestModel(BaseModel):
1897
2884
 
1898
2885
  class UnityCatalogFunctionSqlModel(BaseModel):
1899
2886
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1900
- function: UnityCatalogFunctionModel
2887
+ function: FunctionModel
1901
2888
  ddl: str
1902
2889
  parameters: Optional[dict[str, Any]] = Field(default_factory=dict)
1903
2890
  test: Optional[UnityCatalogFunctionSqlTestModel] = None
@@ -1925,16 +2912,126 @@ class ResourcesModel(BaseModel):
1925
2912
  warehouses: dict[str, WarehouseModel] = Field(default_factory=dict)
1926
2913
  databases: dict[str, DatabaseModel] = Field(default_factory=dict)
1927
2914
  connections: dict[str, ConnectionModel] = Field(default_factory=dict)
2915
+ apps: dict[str, DatabricksAppModel] = Field(default_factory=dict)
2916
+
2917
+ @model_validator(mode="after")
2918
+ def update_genie_warehouses(self) -> Self:
2919
+ """
2920
+ Automatically populate warehouses from genie_rooms.
2921
+
2922
+ Warehouses are extracted from each Genie room and added to the
2923
+ resources if they don't already exist (based on warehouse_id).
2924
+ """
2925
+ if not self.genie_rooms:
2926
+ return self
2927
+
2928
+ # Process warehouses from all genie rooms
2929
+ for genie_room in self.genie_rooms.values():
2930
+ genie_room: GenieRoomModel
2931
+ warehouse: Optional[WarehouseModel] = genie_room.warehouse
2932
+
2933
+ if warehouse is None:
2934
+ continue
2935
+
2936
+ # Check if warehouse already exists based on warehouse_id
2937
+ warehouse_exists: bool = any(
2938
+ existing_warehouse.warehouse_id == warehouse.warehouse_id
2939
+ for existing_warehouse in self.warehouses.values()
2940
+ )
2941
+
2942
+ if not warehouse_exists:
2943
+ warehouse_key: str = normalize_name(
2944
+ "_".join([genie_room.name, warehouse.warehouse_id])
2945
+ )
2946
+ self.warehouses[warehouse_key] = warehouse
2947
+ logger.trace(
2948
+ "Added warehouse from Genie room",
2949
+ room=genie_room.name,
2950
+ warehouse=warehouse.warehouse_id,
2951
+ key=warehouse_key,
2952
+ )
2953
+
2954
+ return self
2955
+
2956
+ @model_validator(mode="after")
2957
+ def update_genie_tables(self) -> Self:
2958
+ """
2959
+ Automatically populate tables from genie_rooms.
2960
+
2961
+ Tables are extracted from each Genie room and added to the
2962
+ resources if they don't already exist (based on full_name).
2963
+ """
2964
+ if not self.genie_rooms:
2965
+ return self
2966
+
2967
+ # Process tables from all genie rooms
2968
+ for genie_room in self.genie_rooms.values():
2969
+ genie_room: GenieRoomModel
2970
+ for table in genie_room.tables:
2971
+ table: TableModel
2972
+ table_exists: bool = any(
2973
+ existing_table.full_name == table.full_name
2974
+ for existing_table in self.tables.values()
2975
+ )
2976
+ if not table_exists:
2977
+ table_key: str = normalize_name(
2978
+ "_".join([genie_room.name, table.full_name])
2979
+ )
2980
+ self.tables[table_key] = table
2981
+ logger.trace(
2982
+ "Added table from Genie room",
2983
+ room=genie_room.name,
2984
+ table=table.name,
2985
+ key=table_key,
2986
+ )
2987
+
2988
+ return self
2989
+
2990
+ @model_validator(mode="after")
2991
+ def update_genie_functions(self) -> Self:
2992
+ """
2993
+ Automatically populate functions from genie_rooms.
2994
+
2995
+ Functions are extracted from each Genie room and added to the
2996
+ resources if they don't already exist (based on full_name).
2997
+ """
2998
+ if not self.genie_rooms:
2999
+ return self
3000
+
3001
+ # Process functions from all genie rooms
3002
+ for genie_room in self.genie_rooms.values():
3003
+ genie_room: GenieRoomModel
3004
+ for function in genie_room.functions:
3005
+ function: FunctionModel
3006
+ function_exists: bool = any(
3007
+ existing_function.full_name == function.full_name
3008
+ for existing_function in self.functions.values()
3009
+ )
3010
+ if not function_exists:
3011
+ function_key: str = normalize_name(
3012
+ "_".join([genie_room.name, function.full_name])
3013
+ )
3014
+ self.functions[function_key] = function
3015
+ logger.trace(
3016
+ "Added function from Genie room",
3017
+ room=genie_room.name,
3018
+ function=function.name,
3019
+ key=function_key,
3020
+ )
3021
+
3022
+ return self
1928
3023
 
1929
3024
 
1930
3025
  class AppConfig(BaseModel):
1931
3026
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1932
3027
  variables: dict[str, AnyVariable] = Field(default_factory=dict)
3028
+ service_principals: dict[str, ServicePrincipalModel] = Field(default_factory=dict)
1933
3029
  schemas: dict[str, SchemaModel] = Field(default_factory=dict)
1934
3030
  resources: Optional[ResourcesModel] = None
1935
3031
  retrievers: dict[str, RetrieverModel] = Field(default_factory=dict)
1936
3032
  tools: dict[str, ToolModel] = Field(default_factory=dict)
1937
3033
  guardrails: dict[str, GuardrailModel] = Field(default_factory=dict)
3034
+ middleware: dict[str, MiddlewareModel] = Field(default_factory=dict)
1938
3035
  memory: Optional[MemoryModel] = None
1939
3036
  prompts: dict[str, PromptModel] = Field(default_factory=dict)
1940
3037
  agents: dict[str, AgentModel] = Field(default_factory=dict)
@@ -1962,10 +3059,10 @@ class AppConfig(BaseModel):
1962
3059
 
1963
3060
  def initialize(self) -> None:
1964
3061
  from dao_ai.hooks.core import create_hooks
3062
+ from dao_ai.logging import configure_logging
1965
3063
 
1966
3064
  if self.app and self.app.log_level:
1967
- logger.remove()
1968
- logger.add(sys.stderr, level=self.app.log_level)
3065
+ configure_logging(level=self.app.log_level)
1969
3066
 
1970
3067
  logger.debug("Calling initialization hooks...")
1971
3068
  initialization_functions: Sequence[Callable[..., Any]] = create_hooks(
@@ -2009,21 +3106,45 @@ class AppConfig(BaseModel):
2009
3106
  def create_agent(
2010
3107
  self,
2011
3108
  w: WorkspaceClient | None = None,
3109
+ vsc: "VectorSearchClient | None" = None,
3110
+ pat: str | None = None,
3111
+ client_id: str | None = None,
3112
+ client_secret: str | None = None,
3113
+ workspace_host: str | None = None,
2012
3114
  ) -> None:
2013
3115
  from dao_ai.providers.base import ServiceProvider
2014
3116
  from dao_ai.providers.databricks import DatabricksProvider
2015
3117
 
2016
- provider: ServiceProvider = DatabricksProvider(w=w)
3118
+ provider: ServiceProvider = DatabricksProvider(
3119
+ w=w,
3120
+ vsc=vsc,
3121
+ pat=pat,
3122
+ client_id=client_id,
3123
+ client_secret=client_secret,
3124
+ workspace_host=workspace_host,
3125
+ )
2017
3126
  provider.create_agent(self)
2018
3127
 
2019
3128
  def deploy_agent(
2020
3129
  self,
2021
3130
  w: WorkspaceClient | None = None,
3131
+ vsc: "VectorSearchClient | None" = None,
3132
+ pat: str | None = None,
3133
+ client_id: str | None = None,
3134
+ client_secret: str | None = None,
3135
+ workspace_host: str | None = None,
2022
3136
  ) -> None:
2023
3137
  from dao_ai.providers.base import ServiceProvider
2024
3138
  from dao_ai.providers.databricks import DatabricksProvider
2025
3139
 
2026
- provider: ServiceProvider = DatabricksProvider(w=w)
3140
+ provider: ServiceProvider = DatabricksProvider(
3141
+ w=w,
3142
+ vsc=vsc,
3143
+ pat=pat,
3144
+ client_id=client_id,
3145
+ client_secret=client_secret,
3146
+ workspace_host=workspace_host,
3147
+ )
2027
3148
  provider.deploy_agent(self)
2028
3149
 
2029
3150
  def find_agents(