dao-ai 0.0.28__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70):
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +2 -5
  3. dao_ai/cli.py +342 -58
  4. dao_ai/config.py +1610 -380
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -253
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +158 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/context_editing.py +230 -0
  25. dao_ai/middleware/core.py +67 -0
  26. dao_ai/middleware/guardrails.py +420 -0
  27. dao_ai/middleware/human_in_the_loop.py +233 -0
  28. dao_ai/middleware/message_validation.py +586 -0
  29. dao_ai/middleware/model_call_limit.py +77 -0
  30. dao_ai/middleware/model_retry.py +121 -0
  31. dao_ai/middleware/pii.py +157 -0
  32. dao_ai/middleware/summarization.py +197 -0
  33. dao_ai/middleware/tool_call_limit.py +210 -0
  34. dao_ai/middleware/tool_retry.py +174 -0
  35. dao_ai/models.py +1306 -114
  36. dao_ai/nodes.py +240 -161
  37. dao_ai/optimization.py +674 -0
  38. dao_ai/orchestration/__init__.py +52 -0
  39. dao_ai/orchestration/core.py +294 -0
  40. dao_ai/orchestration/supervisor.py +279 -0
  41. dao_ai/orchestration/swarm.py +271 -0
  42. dao_ai/prompts.py +128 -31
  43. dao_ai/providers/databricks.py +584 -601
  44. dao_ai/state.py +157 -21
  45. dao_ai/tools/__init__.py +13 -5
  46. dao_ai/tools/agent.py +1 -3
  47. dao_ai/tools/core.py +64 -11
  48. dao_ai/tools/email.py +232 -0
  49. dao_ai/tools/genie.py +144 -294
  50. dao_ai/tools/mcp.py +223 -155
  51. dao_ai/tools/memory.py +50 -0
  52. dao_ai/tools/python.py +9 -14
  53. dao_ai/tools/search.py +14 -0
  54. dao_ai/tools/slack.py +22 -10
  55. dao_ai/tools/sql.py +202 -0
  56. dao_ai/tools/time.py +30 -7
  57. dao_ai/tools/unity_catalog.py +165 -88
  58. dao_ai/tools/vector_search.py +331 -221
  59. dao_ai/utils.py +166 -20
  60. dao_ai/vector_search.py +37 -0
  61. dao_ai-0.1.5.dist-info/METADATA +489 -0
  62. dao_ai-0.1.5.dist-info/RECORD +70 -0
  63. dao_ai/chat_models.py +0 -204
  64. dao_ai/guardrails.py +0 -112
  65. dao_ai/tools/human_in_the_loop.py +0 -100
  66. dao_ai-0.0.28.dist-info/METADATA +0 -1168
  67. dao_ai-0.0.28.dist-info/RECORD +0 -41
  68. {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/WHEEL +0 -0
  69. {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/entry_points.txt +0 -0
  70. {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py CHANGED
@@ -12,6 +12,7 @@ from typing import (
12
12
  Iterator,
13
13
  Literal,
14
14
  Optional,
15
+ Self,
15
16
  Sequence,
16
17
  TypeAlias,
17
18
  Union,
@@ -22,13 +23,20 @@ from databricks.sdk.credentials_provider import (
22
23
  CredentialsStrategy,
23
24
  ModelServingUserCredentials,
24
25
  )
26
+ from databricks.sdk.errors.platform import NotFound
25
27
  from databricks.sdk.service.catalog import FunctionInfo, TableInfo
28
+ from databricks.sdk.service.dashboards import GenieSpace
26
29
  from databricks.sdk.service.database import DatabaseInstance
30
+ from databricks.sdk.service.sql import GetWarehouseResponse
27
31
  from databricks.vector_search.client import VectorSearchClient
28
32
  from databricks.vector_search.index import VectorSearchIndex
29
33
  from databricks_langchain import (
34
+ ChatDatabricks,
35
+ DatabricksEmbeddings,
30
36
  DatabricksFunctionClient,
31
37
  )
38
+ from langchain.agents.structured_output import ProviderStrategy, ToolStrategy
39
+ from langchain_core.embeddings import Embeddings
32
40
  from langchain_core.language_models import LanguageModelLike
33
41
  from langchain_core.messages import BaseMessage, messages_from_dict
34
42
  from langchain_core.runnables.base import RunnableLike
@@ -41,6 +49,7 @@ from mlflow.genai.datasets import EvaluationDataset, create_dataset, get_dataset
41
49
  from mlflow.genai.prompts import PromptVersion, load_prompt
42
50
  from mlflow.models import ModelConfig
43
51
  from mlflow.models.resources import (
52
+ DatabricksApp,
44
53
  DatabricksFunction,
45
54
  DatabricksGenieSpace,
46
55
  DatabricksLakebase,
@@ -59,10 +68,13 @@ from pydantic import (
59
68
  BaseModel,
60
69
  ConfigDict,
61
70
  Field,
71
+ PrivateAttr,
62
72
  field_serializer,
63
73
  model_validator,
64
74
  )
65
75
 
76
+ from dao_ai.utils import normalize_name
77
+
66
78
 
67
79
  class HasValue(ABC):
68
80
  @abstractmethod
@@ -81,27 +93,6 @@ class HasFullName(ABC):
81
93
  def full_name(self) -> str: ...
82
94
 
83
95
 
84
- class IsDatabricksResource(ABC):
85
- on_behalf_of_user: Optional[bool] = False
86
-
87
- @abstractmethod
88
- def as_resources(self) -> Sequence[DatabricksResource]: ...
89
-
90
- @property
91
- @abstractmethod
92
- def api_scopes(self) -> Sequence[str]: ...
93
-
94
- @property
95
- def workspace_client(self) -> WorkspaceClient:
96
- credentials_strategy: CredentialsStrategy = None
97
- if self.on_behalf_of_user:
98
- credentials_strategy = ModelServingUserCredentials()
99
- logger.debug(
100
- f"Creating WorkspaceClient with credentials strategy: {credentials_strategy}"
101
- )
102
- return WorkspaceClient(credentials_strategy=credentials_strategy)
103
-
104
-
105
96
  class EnvironmentVariableModel(BaseModel, HasValue):
106
97
  model_config = ConfigDict(
107
98
  frozen=True,
@@ -200,6 +191,162 @@ AnyVariable: TypeAlias = (
200
191
  )
201
192
 
202
193
 
194
class ServicePrincipalModel(BaseModel):
    """OAuth machine-to-machine credentials identifying a Databricks service principal.

    Both values may be given in any ``AnyVariable`` form; consumers resolve
    them to concrete values with ``value_of`` when they need them.
    """

    model_config = ConfigDict(use_enum_values=True, frozen=True)

    # OAuth client (application) id of the service principal.
    client_id: AnyVariable
    # OAuth client secret paired with ``client_id``.
    client_secret: AnyVariable
201
+
202
+
203
class IsDatabricksResource(ABC, BaseModel):
    """
    Base class for Databricks resources with authentication support.

    Authentication Options:
    ----------------------
    1. **On-Behalf-Of User (OBO)**: Set on_behalf_of_user=True to use the
       calling user's identity via ModelServingUserCredentials.

    2. **Service Principal (OAuth M2M)**: Provide service_principal or
       (client_id + client_secret + workspace_host) for service principal auth.

    3. **Personal Access Token (PAT)**: Provide pat (and optionally workspace_host)
       to authenticate with a personal access token.

    4. **Ambient Authentication**: If no credentials provided, uses SDK defaults
       (environment variables, notebook context, etc.)

    Authentication Priority:
    1. OBO (on_behalf_of_user=True)
    2. Service Principal (client_id + client_secret + workspace_host)
    3. PAT (pat + workspace_host)
    4. Ambient/default authentication

    Service-principal settings that are incomplete (for example client_id and
    client_secret without workspace_host) cannot be used for OAuth M2M. In that
    case a warning is logged and resolution continues with PAT/ambient auth.
    """

    model_config = ConfigDict(use_enum_values=True)

    on_behalf_of_user: Optional[bool] = False
    service_principal: Optional[ServicePrincipalModel] = None
    client_id: Optional[AnyVariable] = None
    client_secret: Optional[AnyVariable] = None
    workspace_host: Optional[AnyVariable] = None
    pat: Optional[AnyVariable] = None

    # Lazily created WorkspaceClient, cached after the first `workspace_client` access.
    _workspace_client: Optional[WorkspaceClient] = PrivateAttr(default=None)

    @abstractmethod
    def as_resources(self) -> Sequence[DatabricksResource]: ...

    @property
    @abstractmethod
    def api_scopes(self) -> Sequence[str]: ...

    @model_validator(mode="after")
    def _expand_service_principal(self) -> Self:
        """Expand service_principal into client_id and client_secret if provided.

        Explicitly set client_id / client_secret take precedence over the
        values carried by service_principal.
        """
        if self.service_principal is not None:
            if self.client_id is None:
                self.client_id = self.service_principal.client_id
            if self.client_secret is None:
                self.client_secret = self.service_principal.client_secret
        return self

    @model_validator(mode="after")
    def _validate_auth_not_mixed(self) -> Self:
        """Validate that OAuth client credentials and a PAT are not both provided.

        Raises:
            ValueError: if client_id and client_secret are both set alongside pat.
        """
        has_oauth: bool = self.client_id is not None and self.client_secret is not None
        has_pat: bool = self.pat is not None

        if has_oauth and has_pat:
            raise ValueError(
                "Cannot use both OAuth and user authentication methods. "
                "Please provide either OAuth credentials or user credentials."
            )
        return self

    @property
    def workspace_client(self) -> WorkspaceClient:
        """
        Get a WorkspaceClient configured with the appropriate authentication.

        The client is lazily instantiated on first access and cached for subsequent calls.

        Authentication priority:
        1. If on_behalf_of_user is True, uses ModelServingUserCredentials (OBO)
        2. If service principal credentials are configured (client_id, client_secret,
           workspace_host), uses OAuth M2M
        3. If PAT is configured, uses token authentication
        4. Otherwise, uses default/ambient authentication
        """
        if self._workspace_client is None:
            self._workspace_client = self._create_workspace_client()
        return self._workspace_client

    def _create_workspace_client(self) -> WorkspaceClient:
        """Build a new WorkspaceClient following the documented authentication priority."""
        from dao_ai.utils import normalize_host

        resource_name: str = self.__class__.__name__

        # OBO has the highest priority and ignores any explicit credentials.
        if self.on_behalf_of_user:
            credentials_strategy: CredentialsStrategy = ModelServingUserCredentials()
            logger.debug(
                f"Creating WorkspaceClient for {resource_name} "
                "with OBO credentials strategy"
            )
            return WorkspaceClient(credentials_strategy=credentials_strategy)

        client_id_value: str | None = (
            value_of(self.client_id) if self.client_id else None
        )
        client_secret_value: str | None = (
            value_of(self.client_secret) if self.client_secret else None
        )
        workspace_host_value: str | None = (
            normalize_host(value_of(self.workspace_host))
            if self.workspace_host
            else None
        )

        if client_id_value and client_secret_value and workspace_host_value:
            logger.debug(
                f"Creating WorkspaceClient for {resource_name} with service principal: "
                f"client_id={client_id_value}, host={workspace_host_value}"
            )
            return WorkspaceClient(
                host=workspace_host_value,
                client_id=client_id_value,
                client_secret=client_secret_value,
                auth_type="oauth-m2m",
            )

        if client_id_value or client_secret_value:
            # Some service-principal settings were given but not all three that
            # OAuth M2M needs. Falling through would silently authenticate with a
            # different identity, so make that visible.
            logger.warning(
                f"Incomplete service principal configuration for {resource_name}: "
                "client_id, client_secret and workspace_host are all required for "
                "OAuth M2M; falling back to PAT or default authentication"
            )

        pat_value: str | None = value_of(self.pat) if self.pat else None
        if pat_value:
            logger.debug(f"Creating WorkspaceClient for {resource_name} with PAT")
            # host may be None here; the SDK then resolves it from its defaults.
            return WorkspaceClient(
                host=workspace_host_value,
                token=pat_value,
                auth_type="pat",
            )

        logger.debug(
            f"Creating WorkspaceClient for {resource_name} "
            "with default/ambient authentication"
        )
        return WorkspaceClient()
348
+
349
+
203
350
  class Privilege(str, Enum):
204
351
  ALL_PRIVILEGES = "ALL_PRIVILEGES"
205
352
  USE_CATALOG = "USE_CATALOG"
@@ -226,9 +373,21 @@ class Privilege(str, Enum):
226
373
 
227
374
class PermissionModel(BaseModel):
    """Grant of Unity Catalog privileges to a list of principals.

    Principals may be plain strings or ``ServicePrincipalModel`` entries. After
    validation, every ``ServicePrincipalModel`` is replaced by the value of its
    ``client_id`` (resolved with ``value_of``).
    """

    model_config = ConfigDict(use_enum_values=True, extra="forbid")
    principals: list[ServicePrincipalModel | str] = Field(default_factory=list)
    privileges: list[Privilege]

    @model_validator(mode="after")
    def resolve_principals(self) -> Self:
        """Resolve ServicePrincipalModel objects to their client_id."""
        self.principals = [
            value_of(entry.client_id)
            if isinstance(entry, ServicePrincipalModel)
            else entry
            for entry in self.principals
        ]
        return self
390
+
232
391
 
233
392
  class SchemaModel(BaseModel, HasFullName):
234
393
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -248,7 +407,26 @@ class SchemaModel(BaseModel, HasFullName):
248
407
  provider.create_schema(self)
249
408
 
250
409
 
251
- class TableModel(BaseModel, HasFullName, IsDatabricksResource):
410
class DatabricksAppModel(IsDatabricksResource, HasFullName):
    """A Databricks App referenced by name and exposed as an MLflow ``DatabricksApp`` resource."""

    model_config = ConfigDict(use_enum_values=True, extra="forbid")
    name: str
    url: str

    @property
    def full_name(self) -> str:
        # The app name doubles as its full name.
        return self.name

    @property
    def api_scopes(self) -> Sequence[str]:
        return ["apps.apps"]

    def as_resources(self) -> Sequence[DatabricksResource]:
        app_resource = DatabricksApp(
            app_name=self.name,
            on_behalf_of_user=self.on_behalf_of_user,
        )
        return [app_resource]
427
+
428
+
429
+ class TableModel(IsDatabricksResource, HasFullName):
252
430
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
253
431
  schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
254
432
  name: Optional[str] = None
@@ -274,6 +452,22 @@ class TableModel(BaseModel, HasFullName, IsDatabricksResource):
274
452
  def api_scopes(self) -> Sequence[str]:
275
453
  return []
276
454
 
455
+ def exists(self) -> bool:
456
+ """Check if the table exists in Unity Catalog.
457
+
458
+ Returns:
459
+ True if the table exists, False otherwise.
460
+ """
461
+ try:
462
+ self.workspace_client.tables.get(full_name=self.full_name)
463
+ return True
464
+ except NotFound:
465
+ logger.debug(f"Table not found: {self.full_name}")
466
+ return False
467
+ except Exception as e:
468
+ logger.warning(f"Error checking table existence for {self.full_name}: {e}")
469
+ return False
470
+
277
471
  def as_resources(self) -> Sequence[DatabricksResource]:
278
472
  resources: list[DatabricksResource] = []
279
473
 
@@ -317,12 +511,17 @@ class TableModel(BaseModel, HasFullName, IsDatabricksResource):
317
511
  return resources
318
512
 
319
513
 
320
- class LLMModel(BaseModel, IsDatabricksResource):
514
+ class LLMModel(IsDatabricksResource):
321
515
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
322
516
  name: str
517
+ description: Optional[str] = None
323
518
  temperature: Optional[float] = 0.1
324
519
  max_tokens: Optional[int] = 8192
325
520
  fallbacks: Optional[list[Union[str, "LLMModel"]]] = Field(default_factory=list)
521
+ use_responses_api: Optional[bool] = Field(
522
+ default=False,
523
+ description="Use Responses API for ResponsesAgent endpoints",
524
+ )
326
525
 
327
526
  @property
328
527
  def api_scopes(self) -> Sequence[str]:
@@ -342,19 +541,12 @@ class LLMModel(BaseModel, IsDatabricksResource):
342
541
  ]
343
542
 
344
543
  def as_chat_model(self) -> LanguageModelLike:
345
- # Retrieve langchain chat client from workspace client to enable OBO
346
- # ChatOpenAI does not allow additional inputs at the moment, so we cannot use it directly
347
- # chat_client: LanguageModelLike = self.as_open_ai_client()
348
-
349
- # Create ChatDatabricksWrapper instance directly
350
- from dao_ai.chat_models import ChatDatabricksFiltered
351
-
352
- chat_client: LanguageModelLike = ChatDatabricksFiltered(
353
- model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
544
+ chat_client: LanguageModelLike = ChatDatabricks(
545
+ model=self.name,
546
+ temperature=self.temperature,
547
+ max_tokens=self.max_tokens,
548
+ use_responses_api=self.use_responses_api,
354
549
  )
355
- # chat_client: LanguageModelLike = ChatDatabricks(
356
- # model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
357
- # )
358
550
 
359
551
  fallbacks: Sequence[LanguageModelLike] = []
360
552
  for fallback in self.fallbacks:
@@ -386,6 +578,9 @@ class LLMModel(BaseModel, IsDatabricksResource):
386
578
 
387
579
  return chat_client
388
580
 
581
+ def as_embeddings_model(self) -> Embeddings:
582
+ return DatabricksEmbeddings(endpoint=self.name)
583
+
389
584
 
390
585
  class VectorSearchEndpointType(str, Enum):
391
586
  STANDARD = "STANDARD"
@@ -405,7 +600,9 @@ class VectorSearchEndpoint(BaseModel):
405
600
  return str(value)
406
601
 
407
602
 
408
- class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
603
+ class IndexModel(IsDatabricksResource, HasFullName):
604
+ """Model representing a Databricks Vector Search index."""
605
+
409
606
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
410
607
  schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
411
608
  name: str
@@ -429,13 +626,314 @@ class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
429
626
  )
430
627
  ]
431
628
 
629
+ def exists(self) -> bool:
630
+ """Check if this vector search index exists.
631
+
632
+ Returns:
633
+ True if the index exists, False otherwise.
634
+ """
635
+ try:
636
+ self.workspace_client.vector_search_indexes.get_index(self.full_name)
637
+ return True
638
+ except NotFound:
639
+ logger.debug(f"Index not found: {self.full_name}")
640
+ return False
641
+ except Exception as e:
642
+ logger.warning(f"Error checking index existence for {self.full_name}: {e}")
643
+ return False
644
+
432
645
 
433
- class GenieRoomModel(BaseModel, IsDatabricksResource):
646
class FunctionModel(IsDatabricksResource, HasFullName):
    """A single Unity Catalog function, or every function in a schema.

    At least one of ``name`` or ``schema`` must be supplied. With a schema and
    no name, ``as_resources`` expands to all functions listed in that schema.
    """

    model_config = ConfigDict(use_enum_values=True, extra="forbid")
    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
    name: Optional[str] = None

    @model_validator(mode="after")
    def validate_name_or_schema_required(self) -> Self:
        if self.name or self.schema_model:
            return self
        raise ValueError(
            "Either 'name' or 'schema_model' must be provided for FunctionModel"
        )

    @property
    def full_name(self) -> str:
        if not self.schema_model:
            return self.name
        prefix: str = f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}"
        return f"{prefix}.{self.name}" if self.name else prefix

    def exists(self) -> bool:
        """Report whether ``functions.get`` succeeds for this function's full name.

        Returns:
            True on success. False when the lookup raises ``NotFound`` (logged at
            debug level) or any other exception (logged as a warning).
        """
        try:
            self.workspace_client.functions.get(name=self.full_name)
        except NotFound:
            logger.debug(f"Function not found: {self.full_name}")
            return False
        except Exception as e:
            logger.warning(
                f"Error checking function existence for {self.full_name}: {e}"
            )
            return False
        else:
            return True

    def as_resources(self) -> Sequence[DatabricksResource]:
        if self.name:
            return [
                DatabricksFunction(
                    function_name=self.full_name,
                    on_behalf_of_user=self.on_behalf_of_user,
                )
            ]

        # Schema-only: one resource per function listed in the schema.
        w: WorkspaceClient = self.workspace_client
        schema_full_name: str = self.schema_model.full_name
        listed: Iterator[FunctionInfo] = w.functions.list(
            catalog_name=self.schema_model.catalog_name,
            schema_name=self.schema_model.schema_name,
        )
        return [
            DatabricksFunction(
                function_name=f"{schema_full_name}.{info.name}",
                on_behalf_of_user=self.on_behalf_of_user,
            )
            for info in listed
        ]

    @property
    def api_scopes(self) -> Sequence[str]:
        return ["sql.statement-execution"]
717
+
718
+
719
class WarehouseModel(IsDatabricksResource):
    """A Databricks SQL warehouse referenced by id.

    ``warehouse_id`` accepts any ``AnyVariable`` form and is replaced by its
    resolved value (via ``value_of``) during validation.
    """

    model_config = ConfigDict()
    name: str
    description: Optional[str] = None
    warehouse_id: AnyVariable

    @model_validator(mode="after")
    def update_warehouse_id(self) -> Self:
        self.warehouse_id = value_of(self.warehouse_id)
        return self

    @property
    def api_scopes(self) -> Sequence[str]:
        return ["sql.warehouses", "sql.statement-execution"]

    def as_resources(self) -> Sequence[DatabricksResource]:
        resolved_id = value_of(self.warehouse_id)
        warehouse_resource = DatabricksSQLWarehouse(
            warehouse_id=resolved_id,
            on_behalf_of_user=self.on_behalf_of_user,
        )
        return [warehouse_resource]
744
+
745
+
746
+ class GenieRoomModel(IsDatabricksResource):
434
747
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
435
748
  name: str
436
749
  description: Optional[str] = None
437
750
  space_id: AnyVariable
438
751
 
752
+ _space_details: Optional[GenieSpace] = PrivateAttr(default=None)
753
+
754
+ def _get_space_details(self) -> GenieSpace:
755
+ if self._space_details is None:
756
+ self._space_details = self.workspace_client.genie.get_space(
757
+ space_id=self.space_id, include_serialized_space=True
758
+ )
759
+ return self._space_details
760
+
761
+ def _parse_serialized_space(self) -> dict[str, Any]:
762
+ """Parse the serialized_space JSON string and return the parsed data."""
763
+ import json
764
+
765
+ space_details = self._get_space_details()
766
+ if not space_details.serialized_space:
767
+ return {}
768
+
769
+ try:
770
+ return json.loads(space_details.serialized_space)
771
+ except json.JSONDecodeError as e:
772
+ logger.warning(f"Failed to parse serialized_space: {e}")
773
+ return {}
774
+
775
+ @property
776
+ def warehouse(self) -> Optional[WarehouseModel]:
777
+ """Extract warehouse information from the Genie space.
778
+
779
+ Returns:
780
+ WarehouseModel instance if warehouse_id is available, None otherwise.
781
+ """
782
+ space_details: GenieSpace = self._get_space_details()
783
+
784
+ if not space_details.warehouse_id:
785
+ return None
786
+
787
+ try:
788
+ response: GetWarehouseResponse = self.workspace_client.warehouses.get(
789
+ space_details.warehouse_id
790
+ )
791
+ warehouse_name: str = response.name or space_details.warehouse_id
792
+
793
+ warehouse_model = WarehouseModel(
794
+ name=warehouse_name,
795
+ warehouse_id=space_details.warehouse_id,
796
+ on_behalf_of_user=self.on_behalf_of_user,
797
+ service_principal=self.service_principal,
798
+ client_id=self.client_id,
799
+ client_secret=self.client_secret,
800
+ workspace_host=self.workspace_host,
801
+ pat=self.pat,
802
+ )
803
+
804
+ # Share the cached workspace client if available
805
+ if self._workspace_client is not None:
806
+ warehouse_model._workspace_client = self._workspace_client
807
+
808
+ return warehouse_model
809
+ except Exception as e:
810
+ logger.warning(
811
+ f"Failed to fetch warehouse details for {space_details.warehouse_id}: {e}"
812
+ )
813
+ return None
814
+
815
+ @property
816
+ def tables(self) -> list[TableModel]:
817
+ """Extract tables from the serialized Genie space.
818
+
819
+ Databricks Genie stores tables in: data_sources.tables[].identifier
820
+ Only includes tables that actually exist in Unity Catalog.
821
+ """
822
+ parsed_space = self._parse_serialized_space()
823
+ tables_list: list[TableModel] = []
824
+
825
+ # Primary structure: data_sources.tables with 'identifier' field
826
+ if "data_sources" in parsed_space:
827
+ data_sources = parsed_space["data_sources"]
828
+ if isinstance(data_sources, dict) and "tables" in data_sources:
829
+ tables_data = data_sources["tables"]
830
+ if isinstance(tables_data, list):
831
+ for table_item in tables_data:
832
+ table_name: str | None = None
833
+ if isinstance(table_item, dict):
834
+ # Standard Databricks structure uses 'identifier'
835
+ table_name = table_item.get("identifier") or table_item.get(
836
+ "name"
837
+ )
838
+ elif isinstance(table_item, str):
839
+ table_name = table_item
840
+
841
+ if table_name:
842
+ table_model = TableModel(
843
+ name=table_name,
844
+ on_behalf_of_user=self.on_behalf_of_user,
845
+ service_principal=self.service_principal,
846
+ client_id=self.client_id,
847
+ client_secret=self.client_secret,
848
+ workspace_host=self.workspace_host,
849
+ pat=self.pat,
850
+ )
851
+ # Share the cached workspace client if available
852
+ if self._workspace_client is not None:
853
+ table_model._workspace_client = self._workspace_client
854
+
855
+ # Verify the table exists before adding
856
+ if not table_model.exists():
857
+ continue
858
+
859
+ tables_list.append(table_model)
860
+
861
+ return tables_list
862
+
863
+ @property
864
+ def functions(self) -> list[FunctionModel]:
865
+ """Extract functions from the serialized Genie space.
866
+
867
+ Databricks Genie stores functions in multiple locations:
868
+ - instructions.sql_functions[].identifier (SQL functions)
869
+ - data_sources.functions[].identifier (other functions)
870
+ Only includes functions that actually exist in Unity Catalog.
871
+ """
872
+ parsed_space = self._parse_serialized_space()
873
+ functions_list: list[FunctionModel] = []
874
+ seen_functions: set[str] = set()
875
+
876
+ def add_function_if_exists(function_name: str) -> None:
877
+ """Helper to add a function if it exists and hasn't been added."""
878
+ if function_name in seen_functions:
879
+ return
880
+
881
+ seen_functions.add(function_name)
882
+ function_model = FunctionModel(
883
+ name=function_name,
884
+ on_behalf_of_user=self.on_behalf_of_user,
885
+ service_principal=self.service_principal,
886
+ client_id=self.client_id,
887
+ client_secret=self.client_secret,
888
+ workspace_host=self.workspace_host,
889
+ pat=self.pat,
890
+ )
891
+ # Share the cached workspace client if available
892
+ if self._workspace_client is not None:
893
+ function_model._workspace_client = self._workspace_client
894
+
895
+ # Verify the function exists before adding
896
+ if not function_model.exists():
897
+ return
898
+
899
+ functions_list.append(function_model)
900
+
901
+ # Primary structure: instructions.sql_functions with 'identifier' field
902
+ if "instructions" in parsed_space:
903
+ instructions = parsed_space["instructions"]
904
+ if isinstance(instructions, dict) and "sql_functions" in instructions:
905
+ sql_functions_data = instructions["sql_functions"]
906
+ if isinstance(sql_functions_data, list):
907
+ for function_item in sql_functions_data:
908
+ if isinstance(function_item, dict):
909
+ # SQL functions use 'identifier' field
910
+ function_name = function_item.get(
911
+ "identifier"
912
+ ) or function_item.get("name")
913
+ if function_name:
914
+ add_function_if_exists(function_name)
915
+
916
+ # Secondary structure: data_sources.functions with 'identifier' field
917
+ if "data_sources" in parsed_space:
918
+ data_sources = parsed_space["data_sources"]
919
+ if isinstance(data_sources, dict) and "functions" in data_sources:
920
+ functions_data = data_sources["functions"]
921
+ if isinstance(functions_data, list):
922
+ for function_item in functions_data:
923
+ function_name: str | None = None
924
+ if isinstance(function_item, dict):
925
+ # Standard Databricks structure uses 'identifier'
926
+ function_name = function_item.get(
927
+ "identifier"
928
+ ) or function_item.get("name")
929
+ elif isinstance(function_item, str):
930
+ function_name = function_item
931
+
932
+ if function_name:
933
+ add_function_if_exists(function_name)
934
+
935
+ return functions_list
936
+
439
937
  @property
440
938
  def api_scopes(self) -> Sequence[str]:
441
939
  return [
@@ -451,12 +949,24 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
451
949
  ]
452
950
 
453
951
  @model_validator(mode="after")
454
- def update_space_id(self):
952
+ def update_space_id(self) -> Self:
455
953
  self.space_id = value_of(self.space_id)
456
954
  return self
457
955
 
956
+ @model_validator(mode="after")
957
+ def update_description_from_space(self) -> Self:
958
+ """Populate description from GenieSpace if not provided."""
959
+ if not self.description:
960
+ try:
961
+ space_details = self._get_space_details()
962
+ if space_details.description:
963
+ self.description = space_details.description
964
+ except Exception as e:
965
+ logger.debug(f"Could not fetch description from Genie space: {e}")
966
+ return self
458
967
 
459
- class VolumeModel(BaseModel, HasFullName, IsDatabricksResource):
968
+
969
+ class VolumeModel(IsDatabricksResource, HasFullName):
460
970
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
461
971
  schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
462
972
  name: str
@@ -516,28 +1026,93 @@ class VolumePathModel(BaseModel, HasFullName):
516
1026
  provider.create_path(self)
517
1027
 
518
1028
 
519
- class VectorStoreModel(BaseModel, IsDatabricksResource):
1029
+ class VectorStoreModel(IsDatabricksResource):
1030
+ """
1031
+ Configuration model for a Databricks Vector Search store.
1032
+
1033
+ Supports two modes:
1034
+ 1. **Use Existing Index**: Provide only `index` (fully qualified name).
1035
+ Used for querying an existing vector search index at runtime.
1036
+ 2. **Provisioning Mode**: Provide `source_table` + `embedding_source_column`.
1037
+ Used for creating a new vector search index.
1038
+
1039
+ Examples:
1040
+ Minimal configuration (use existing index):
1041
+ ```yaml
1042
+ vector_stores:
1043
+ products_search:
1044
+ index:
1045
+ name: catalog.schema.my_index
1046
+ ```
1047
+
1048
+ Full provisioning configuration:
1049
+ ```yaml
1050
+ vector_stores:
1051
+ products_search:
1052
+ source_table:
1053
+ schema: *my_schema
1054
+ name: products
1055
+ embedding_source_column: description
1056
+ endpoint:
1057
+ name: my_endpoint
1058
+ ```
1059
+ """
1060
+
520
1061
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
521
- embedding_model: Optional[LLMModel] = None
1062
+
1063
+ # RUNTIME: Only index is truly required for querying existing indexes
522
1064
  index: Optional[IndexModel] = None
1065
+
1066
+ # PROVISIONING ONLY: Required when creating a new index
1067
+ source_table: Optional[TableModel] = None
1068
+ embedding_source_column: Optional[str] = None
1069
+ embedding_model: Optional[LLMModel] = None
523
1070
  endpoint: Optional[VectorSearchEndpoint] = None
524
- source_table: TableModel
1071
+
1072
+ # OPTIONAL: For both modes
525
1073
  source_path: Optional[VolumePathModel] = None
526
1074
  checkpoint_path: Optional[VolumePathModel] = None
527
1075
  primary_key: Optional[str] = None
528
1076
  columns: Optional[list[str]] = Field(default_factory=list)
529
1077
  doc_uri: Optional[str] = None
530
- embedding_source_column: str
531
1078
 
532
1079
  @model_validator(mode="after")
533
- def set_default_embedding_model(self):
534
- if not self.embedding_model:
1080
+ def validate_configuration_mode(self) -> Self:
1081
+ """
1082
+ Validate that configuration is valid for either:
1083
+ - Use existing mode: index is provided
1084
+ - Provisioning mode: source_table + embedding_source_column provided
1085
+ """
1086
+ has_index = self.index is not None
1087
+ has_source_table = self.source_table is not None
1088
+ has_embedding_col = self.embedding_source_column is not None
1089
+
1090
+ # Must have at least index OR source_table
1091
+ if not has_index and not has_source_table:
1092
+ raise ValueError(
1093
+ "Either 'index' (for existing indexes) or 'source_table' "
1094
+ "(for provisioning) must be provided"
1095
+ )
1096
+
1097
+ # If provisioning mode, need embedding_source_column
1098
+ if has_source_table and not has_embedding_col:
1099
+ raise ValueError(
1100
+ "embedding_source_column is required when source_table is provided (provisioning mode)"
1101
+ )
1102
+
1103
+ return self
1104
+
1105
+ @model_validator(mode="after")
1106
+ def set_default_embedding_model(self) -> Self:
1107
+ # Only set default embedding model in provisioning mode
1108
+ if self.source_table is not None and not self.embedding_model:
535
1109
  self.embedding_model = LLMModel(name="databricks-gte-large-en")
536
1110
  return self
537
1111
 
538
1112
  @model_validator(mode="after")
539
- def set_default_primary_key(self):
540
- if self.primary_key is None:
1113
+ def set_default_primary_key(self) -> Self:
1114
+ # Only auto-discover primary key in provisioning mode
1115
+ if self.primary_key is None and self.source_table is not None:
541
1116
  from dao_ai.providers.databricks import DatabricksProvider
542
1117
 
543
1118
  provider: DatabricksProvider = DatabricksProvider()
@@ -557,15 +1132,17 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
557
1132
  return self
558
1133
 
559
1134
  @model_validator(mode="after")
560
- def set_default_index(self):
561
- if self.index is None:
1135
+ def set_default_index(self) -> Self:
1136
+ # Only generate index from source_table in provisioning mode
1137
+ if self.index is None and self.source_table is not None:
562
1138
  name: str = f"{self.source_table.name}_index"
563
1139
  self.index = IndexModel(schema=self.source_table.schema_model, name=name)
564
1140
  return self
565
1141
 
566
1142
  @model_validator(mode="after")
567
- def set_default_endpoint(self):
568
- if self.endpoint is None:
1143
+ def set_default_endpoint(self) -> Self:
1144
+ # Only find/create endpoint in provisioning mode
1145
+ if self.endpoint is None and self.source_table is not None:
569
1146
  from dao_ai.providers.databricks import (
570
1147
  DatabricksProvider,
571
1148
  with_available_indexes,
@@ -600,77 +1177,64 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
600
1177
  return self.index.as_resources()
601
1178
 
602
1179
  def as_index(self, vsc: VectorSearchClient | None = None) -> VectorSearchIndex:
603
- from dao_ai.providers.base import ServiceProvider
604
1180
  from dao_ai.providers.databricks import DatabricksProvider
605
1181
 
606
- provider: ServiceProvider = DatabricksProvider(vsc=vsc)
1182
+ provider: DatabricksProvider = DatabricksProvider(vsc=vsc)
607
1183
  index: VectorSearchIndex = provider.get_vector_index(self)
608
1184
  return index
609
1185
 
610
- def create(self, vsc: VectorSearchClient | None = None) -> None:
611
- from dao_ai.providers.base import ServiceProvider
612
- from dao_ai.providers.databricks import DatabricksProvider
1186
+ def create(self, vsc: VectorSearchClient | None = None) -> None:
1187
+ """
1188
+ Create or validate the vector search index.
1189
+
1190
+ Behavior depends on configuration mode:
1191
+ - **Provisioning Mode** (source_table provided): Creates the index
1192
+ - **Use Existing Mode** (only index provided): Validates the index exists
613
1193
 
614
- provider: ServiceProvider = DatabricksProvider(vsc=vsc)
615
- provider.create_vector_store(self)
1194
+ Args:
1195
+ vsc: Optional VectorSearchClient instance
616
1196
 
1197
+ Raises:
1198
+ ValueError: If configuration is invalid or index doesn't exist
1199
+ """
1200
+ from dao_ai.providers.databricks import DatabricksProvider
617
1201
 
618
- class FunctionModel(BaseModel, HasFullName, IsDatabricksResource):
619
- model_config = ConfigDict()
620
- schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
621
- name: Optional[str] = None
1202
+ provider: DatabricksProvider = DatabricksProvider(vsc=vsc)
622
1203
 
623
- @model_validator(mode="after")
624
- def validate_name_or_schema_required(self) -> "FunctionModel":
625
- if not self.name and not self.schema_model:
626
- raise ValueError(
627
- "Either 'name' or 'schema_model' must be provided for FunctionModel"
628
- )
629
- return self
1204
+ if self.source_table is not None:
1205
+ self._create_new_index(provider)
1206
+ else:
1207
+ self._validate_existing_index(provider)
630
1208
 
631
- @property
632
- def full_name(self) -> str:
633
- if self.schema_model:
634
- name: str = ""
635
- if self.name:
636
- name = f".{self.name}"
637
- return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
638
- return self.name
1209
+ def _validate_existing_index(self, provider: Any) -> None:
1210
+ """Validate that an existing index is accessible."""
1211
+ if self.index is None:
1212
+ raise ValueError("index is required for 'use existing' mode")
639
1213
 
640
- def as_resources(self) -> Sequence[DatabricksResource]:
641
- resources: list[DatabricksResource] = []
642
- if self.name:
643
- resources.append(
644
- DatabricksFunction(
645
- function_name=self.full_name,
646
- on_behalf_of_user=self.on_behalf_of_user,
647
- )
1214
+ if self.index.exists():
1215
+ logger.info(
1216
+ "Vector search index exists and ready",
1217
+ index_name=self.index.full_name,
648
1218
  )
649
1219
  else:
650
- w: WorkspaceClient = self.workspace_client
651
- schema_full_name: str = self.schema_model.full_name
652
- functions: Iterator[FunctionInfo] = w.functions.list(
653
- catalog_name=self.schema_model.catalog_name,
654
- schema_name=self.schema_model.schema_name,
655
- )
656
- resources.extend(
657
- [
658
- DatabricksFunction(
659
- function_name=f"{schema_full_name}.{function.name}",
660
- on_behalf_of_user=self.on_behalf_of_user,
661
- )
662
- for function in functions
663
- ]
1220
+ raise ValueError(
1221
+ f"Index '{self.index.full_name}' does not exist. "
1222
+ "Provide 'source_table' to provision it."
664
1223
  )
665
1224
 
666
- return resources
1225
+ def _create_new_index(self, provider: Any) -> None:
1226
+ """Create a new vector search index from source table."""
1227
+ if self.embedding_source_column is None:
1228
+ raise ValueError("embedding_source_column is required for provisioning")
1229
+ if self.endpoint is None:
1230
+ raise ValueError("endpoint is required for provisioning")
1231
+ if self.index is None:
1232
+ raise ValueError("index is required for provisioning")
667
1233
 
668
- @property
669
- def api_scopes(self) -> Sequence[str]:
670
- return ["sql.statement-execution"]
1234
+ provider.create_vector_store(self)
671
1235
 
672
1236
 
673
- class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
1237
+ class ConnectionModel(IsDatabricksResource, HasFullName):
674
1238
  model_config = ConfigDict()
675
1239
  name: str
676
1240
 
@@ -697,34 +1261,58 @@ class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
697
1261
  ]
698
1262
 
699
1263
 
700
- class WarehouseModel(BaseModel, IsDatabricksResource):
701
- model_config = ConfigDict()
702
- name: str
703
- description: Optional[str] = None
704
- warehouse_id: AnyVariable
705
-
706
- @property
707
- def api_scopes(self) -> Sequence[str]:
708
- return [
709
- "sql.warehouses",
710
- "sql.statement-execution",
711
- ]
712
-
713
- def as_resources(self) -> Sequence[DatabricksResource]:
714
- return [
715
- DatabricksSQLWarehouse(
716
- warehouse_id=value_of(self.warehouse_id),
717
- on_behalf_of_user=self.on_behalf_of_user,
718
- )
719
- ]
720
-
721
- @model_validator(mode="after")
722
- def update_warehouse_id(self):
723
- self.warehouse_id = value_of(self.warehouse_id)
724
- return self
725
-
1264
+ class DatabaseModel(IsDatabricksResource):
1265
+ """
1266
+ Configuration for database connections supporting both Databricks Lakebase and standard PostgreSQL.
1267
+
1268
+ Authentication is inherited from IsDatabricksResource. Additionally supports:
1269
+ - user/password: For user-based database authentication
1270
+
1271
+ Connection Types (determined by fields provided):
1272
+ - Databricks Lakebase: Provide `instance_name` (authentication optional, supports ambient auth)
1273
+ - Standard PostgreSQL: Provide `host` (authentication required via user/password)
1274
+
1275
+ Note: `instance_name` and `host` are mutually exclusive. Provide one or the other.
1276
+
1277
+ Example Databricks Lakebase with Service Principal:
1278
+ ```yaml
1279
+ databases:
1280
+ my_lakebase:
1281
+ name: my-database
1282
+ instance_name: my-lakebase-instance
1283
+ service_principal:
1284
+ client_id:
1285
+ env: SERVICE_PRINCIPAL_CLIENT_ID
1286
+ client_secret:
1287
+ scope: my-scope
1288
+ secret: sp-client-secret
1289
+ workspace_host:
1290
+ env: DATABRICKS_HOST
1291
+ ```
1292
+
1293
+ Example Databricks Lakebase with Ambient Authentication:
1294
+ ```yaml
1295
+ databases:
1296
+ my_lakebase:
1297
+ name: my-database
1298
+ instance_name: my-lakebase-instance
1299
+ on_behalf_of_user: true
1300
+ ```
1301
+
1302
+ Example Standard PostgreSQL:
1303
+ ```yaml
1304
+ databases:
1305
+ my_postgres:
1306
+ name: my-database
1307
+ host: my-postgres-host.example.com
1308
+ port: 5432
1309
+ database: my_db
1310
+ user: my_user
1311
+ password:
1312
+ env: PGPASSWORD
1313
+ ```
1314
+ """
726
1315
 
727
- class DatabaseModel(BaseModel, IsDatabricksResource):
728
1316
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
729
1317
  name: str
730
1318
  instance_name: Optional[str] = None
@@ -737,80 +1325,117 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
737
1325
  timeout_seconds: Optional[int] = 10
738
1326
  capacity: Optional[Literal["CU_1", "CU_2"]] = "CU_2"
739
1327
  node_count: Optional[int] = None
1328
+ # Database-specific auth (user identity for DB connection)
740
1329
  user: Optional[AnyVariable] = None
741
1330
  password: Optional[AnyVariable] = None
742
- client_id: Optional[AnyVariable] = None
743
- client_secret: Optional[AnyVariable] = None
744
- workspace_host: Optional[AnyVariable] = None
745
1331
 
746
1332
  @property
747
1333
  def api_scopes(self) -> Sequence[str]:
748
- return []
1334
+ return ["database.database-instances"]
1335
+
1336
+ @property
1337
+ def is_lakebase(self) -> bool:
1338
+ """Returns True if this is a Databricks Lakebase connection (instance_name provided)."""
1339
+ return self.instance_name is not None
749
1340
 
750
1341
  def as_resources(self) -> Sequence[DatabricksResource]:
751
- return [
752
- DatabricksLakebase(
753
- database_instance_name=self.instance_name,
754
- on_behalf_of_user=self.on_behalf_of_user,
755
- )
756
- ]
1342
+ if self.is_lakebase:
1343
+ return [
1344
+ DatabricksLakebase(
1345
+ database_instance_name=self.instance_name,
1346
+ on_behalf_of_user=self.on_behalf_of_user,
1347
+ )
1348
+ ]
1349
+ return []
757
1350
 
758
1351
  @model_validator(mode="after")
759
- def update_instance_name(self):
760
- if self.instance_name is None:
761
- self.instance_name = self.name
1352
+ def validate_connection_type(self) -> Self:
1353
+ """Validate connection configuration based on type.
762
1354
 
1355
+ - If instance_name is provided: Databricks Lakebase connection
1356
+ (host is optional - will be fetched from API if not provided)
1357
+ - If only host is provided: Standard PostgreSQL connection
1358
+ (must not have instance_name)
1359
+ """
1360
+ if not self.instance_name and not self.host:
1361
+ raise ValueError(
1362
+ "Either instance_name (Databricks Lakebase) or host (PostgreSQL) must be provided."
1363
+ )
763
1364
  return self
764
1365
 
765
1366
  @model_validator(mode="after")
766
- def update_user(self):
767
- if self.client_id or self.user:
1367
+ def update_user(self) -> Self:
1368
+ # Skip if using OBO (passive auth), explicit credentials, or explicit user
1369
+ if self.on_behalf_of_user or self.client_id or self.user or self.pat:
768
1370
  return self
769
1371
 
770
- self.user = self.workspace_client.current_user.me().user_name
771
- if not self.user:
772
- raise ValueError(
773
- "Unable to determine current user. Please provide a user name or OAuth credentials."
774
- )
1372
+ # For standard PostgreSQL, we need explicit user credentials
1373
+ # For Lakebase with no auth, ambient auth is allowed
1374
+ if not self.is_lakebase:
1375
+ # Standard PostgreSQL - try to determine current user for local development
1376
+ try:
1377
+ self.user = self.workspace_client.current_user.me().user_name
1378
+ except Exception as e:
1379
+ logger.warning(
1380
+ f"Could not determine current user for PostgreSQL database: {e}. "
1381
+ f"Please provide explicit user credentials."
1382
+ )
1383
+ else:
1384
+ # For Lakebase, try to determine current user but don't fail if we can't
1385
+ try:
1386
+ self.user = self.workspace_client.current_user.me().user_name
1387
+ except Exception:
1388
+ # If we can't determine user and no explicit auth, that's okay
1389
+ # for Lakebase with ambient auth - credentials will be injected at runtime
1390
+ pass
775
1391
 
776
1392
  return self
777
1393
 
778
1394
  @model_validator(mode="after")
779
- def update_host(self):
780
- if self.host is not None:
1395
+ def update_host(self) -> Self:
1396
+ # Lakebase uses instance_name directly via databricks_langchain - host not needed
1397
+ if self.is_lakebase:
781
1398
  return self
782
1399
 
783
- existing_instance: DatabaseInstance = (
784
- self.workspace_client.database.get_database_instance(
785
- name=self.instance_name
786
- )
787
- )
788
- self.host = existing_instance.read_write_dns
1400
+ # For standard PostgreSQL, host must be provided by the user
1401
+ # (enforced by validate_connection_type)
789
1402
  return self
790
1403
 
791
1404
  @model_validator(mode="after")
792
- def validate_auth_methods(self):
1405
+ def validate_auth_methods(self) -> Self:
793
1406
  oauth_fields: Sequence[Any] = [
794
1407
  self.workspace_host,
795
1408
  self.client_id,
796
1409
  self.client_secret,
797
1410
  ]
798
1411
  has_oauth: bool = all(field is not None for field in oauth_fields)
1412
+ has_user_auth: bool = self.user is not None
1413
+ has_obo: bool = self.on_behalf_of_user is True
1414
+ has_pat: bool = self.pat is not None
799
1415
 
800
- pat_fields: Sequence[Any] = [self.user]
801
- has_user_auth: bool = all(field is not None for field in pat_fields)
1416
+ # Count how many auth methods are configured
1417
+ auth_methods_count: int = sum([has_oauth, has_user_auth, has_obo, has_pat])
802
1418
 
803
- if has_oauth and has_user_auth:
1419
+ if auth_methods_count > 1:
804
1420
  raise ValueError(
805
- "Cannot use both OAuth and user authentication methods. "
806
- "Please provide either OAuth credentials or user credentials."
1421
+ "Cannot mix authentication methods. "
1422
+ "Please provide exactly one of: "
1423
+ "on_behalf_of_user=true (for passive auth in model serving), "
1424
+ "OAuth credentials (service_principal or client_id + client_secret + workspace_host), "
1425
+ "PAT (personal access token), "
1426
+ "or user credentials (user)."
807
1427
  )
808
1428
 
809
- if not has_oauth and not has_user_auth:
1429
+ # For standard PostgreSQL (host-based), at least one auth method must be configured
1430
+ # For Lakebase (instance_name-based), auth is optional (supports ambient authentication)
1431
+ if not self.is_lakebase and auth_methods_count == 0:
810
1432
  raise ValueError(
811
- "At least one authentication method must be provided: "
812
- "either OAuth credentials (workspace_host, client_id, client_secret) "
813
- "or user credentials (user, password)."
1433
+ "PostgreSQL databases require explicit authentication. "
1434
+ "Please provide one of: "
1435
+ "OAuth credentials (workspace_host, client_id, client_secret), "
1436
+ "service_principal with workspace_host, "
1437
+ "PAT (personal access token), "
1438
+ "or user credentials (user)."
814
1439
  )
815
1440
 
816
1441
  return self
@@ -821,38 +1446,76 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
821
1446
  Get database connection parameters as a dictionary.
822
1447
 
823
1448
  Returns a dict with connection parameters suitable for psycopg ConnectionPool.
824
- If username is configured, it will be included; otherwise it will be omitted
825
- to allow Lakebase to authenticate using the token's identity.
1449
+
1450
+ For Lakebase: Uses Databricks-generated credentials (token-based auth).
1451
+ For standard PostgreSQL: Uses provided user/password credentials.
826
1452
  """
827
- from dao_ai.providers.base import ServiceProvider
828
- from dao_ai.providers.databricks import DatabricksProvider
1453
+ import uuid as _uuid
829
1454
 
1455
+ from databricks.sdk.service.database import DatabaseCredential
1456
+
1457
+ host: str
1458
+ port: int
1459
+ database: str
830
1460
  username: str | None = None
1461
+ password_value: str | None = None
1462
+
1463
+ # Resolve host - may need to fetch at runtime for OBO mode
1464
+ host_value: Any = self.host
1465
+ if host_value is None and self.is_lakebase and self.on_behalf_of_user:
1466
+ # Fetch host at runtime for OBO mode
1467
+ existing_instance: DatabaseInstance = (
1468
+ self.workspace_client.database.get_database_instance(
1469
+ name=self.instance_name
1470
+ )
1471
+ )
1472
+ host_value = existing_instance.read_write_dns
1473
+
1474
+ if host_value is None:
1475
+ instance_or_name = self.instance_name if self.is_lakebase else self.name
1476
+ raise ValueError(
1477
+ f"Database host not configured for {instance_or_name}. "
1478
+ "Please provide 'host' explicitly."
1479
+ )
831
1480
 
832
- if self.client_id and self.client_secret and self.workspace_host:
833
- username = value_of(self.client_id)
834
- elif self.user:
835
- username = value_of(self.user)
1481
+ host = value_of(host_value)
1482
+ port = value_of(self.port)
1483
+ database = value_of(self.database)
836
1484
 
837
- host: str = value_of(self.host)
838
- port: int = value_of(self.port)
839
- database: str = value_of(self.database)
1485
+ if self.is_lakebase:
1486
+ # Lakebase: Use Databricks-generated credentials
1487
+ if self.client_id and self.client_secret and self.workspace_host:
1488
+ username = value_of(self.client_id)
1489
+ elif self.user:
1490
+ username = value_of(self.user)
1491
+ # For OBO mode, no username is needed - the token identity is used
840
1492
 
841
- provider: ServiceProvider = DatabricksProvider(
842
- client_id=value_of(self.client_id),
843
- client_secret=value_of(self.client_secret),
844
- workspace_host=value_of(self.workspace_host),
845
- pat=value_of(self.password),
846
- )
1493
+ # Generate Databricks database credential (token)
1494
+ w: WorkspaceClient = self.workspace_client
1495
+ cred: DatabaseCredential = w.database.generate_database_credential(
1496
+ request_id=str(_uuid.uuid4()),
1497
+ instance_names=[self.instance_name],
1498
+ )
1499
+ password_value = cred.token
1500
+ else:
1501
+ # Standard PostgreSQL: Use provided credentials
1502
+ if self.user:
1503
+ username = value_of(self.user)
1504
+ if self.password:
1505
+ password_value = value_of(self.password)
847
1506
 
848
- token: str = provider.lakebase_password_provider(self.instance_name)
1507
+ if not username or not password_value:
1508
+ raise ValueError(
1509
+ f"Standard PostgreSQL databases require both 'user' and 'password'. "
1510
+ f"Database: {self.name}"
1511
+ )
849
1512
 
850
1513
  # Build connection parameters dictionary
851
1514
  params: dict[str, Any] = {
852
1515
  "dbname": database,
853
1516
  "host": host,
854
1517
  "port": port,
855
- "password": token,
1518
+ "password": password_value,
856
1519
  "sslmode": "require",
857
1520
  }
858
1521
 
@@ -883,11 +1546,86 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
883
1546
  def create(self, w: WorkspaceClient | None = None) -> None:
884
1547
  from dao_ai.providers.databricks import DatabricksProvider
885
1548
 
886
- provider: DatabricksProvider = DatabricksProvider()
1549
+ # Use provided workspace client or fall back to resource's own workspace_client
1550
+ if w is None:
1551
+ w = self.workspace_client
1552
+ provider: DatabricksProvider = DatabricksProvider(w=w)
887
1553
  provider.create_lakebase(self)
888
1554
  provider.create_lakebase_instance_role(self)
889
1555
 
890
1556
 
1557
+ class GenieLRUCacheParametersModel(BaseModel):
1558
+ model_config = ConfigDict(use_enum_values=True, extra="forbid")
1559
+ capacity: int = 1000
1560
+ time_to_live_seconds: int | None = (
1561
+ 60 * 60 * 24
1562
+ ) # 1 day default, None or negative = never expires
1563
+ warehouse: WarehouseModel
1564
+
1565
+
1566
+ class GenieSemanticCacheParametersModel(BaseModel):
1567
+ model_config = ConfigDict(use_enum_values=True, extra="forbid")
1568
+ time_to_live_seconds: int | None = (
1569
+ 60 * 60 * 24
1570
+ ) # 1 day default, None or negative = never expires
1571
+ similarity_threshold: float = 0.85 # Minimum similarity for question matching (L2 distance converted to 0-1 scale)
1572
+ context_similarity_threshold: float = 0.80 # Minimum similarity for context matching (L2 distance converted to 0-1 scale)
1573
+ question_weight: Optional[float] = (
1574
+ 0.6 # Weight for question similarity in combined score (0-1). If not provided, computed as 1 - context_weight
1575
+ )
1576
+ context_weight: Optional[float] = (
1577
+ None # Weight for context similarity in combined score (0-1). If not provided, computed as 1 - question_weight
1578
+ )
1579
+ embedding_model: str | LLMModel = "databricks-gte-large-en"
1580
+ embedding_dims: int | None = None # Auto-detected if None
1581
+ database: DatabaseModel
1582
+ warehouse: WarehouseModel
1583
+ table_name: str = "genie_semantic_cache"
1584
+ context_window_size: int = 3 # Number of previous turns to include for context
1585
+ max_context_tokens: int = (
1586
+ 2000 # Maximum context length to prevent extremely long embeddings
1587
+ )
1588
+
1589
+ @model_validator(mode="after")
1590
+ def compute_and_validate_weights(self) -> Self:
1591
+ """
1592
+ Compute missing weight and validate that question_weight + context_weight = 1.0.
1593
+
1594
+ Either question_weight or context_weight (or both) can be provided.
1595
+ The missing one will be computed as 1.0 - provided_weight.
1596
+ If both are provided, they must sum to 1.0.
1597
+ """
1598
+ if self.question_weight is None and self.context_weight is None:
1599
+ # Both missing - use defaults
1600
+ self.question_weight = 0.6
1601
+ self.context_weight = 0.4
1602
+ elif self.question_weight is None:
1603
+ # Compute question_weight from context_weight
1604
+ if not (0.0 <= self.context_weight <= 1.0):
1605
+ raise ValueError(
1606
+ f"context_weight must be between 0.0 and 1.0, got {self.context_weight}"
1607
+ )
1608
+ self.question_weight = 1.0 - self.context_weight
1609
+ elif self.context_weight is None:
1610
+ # Compute context_weight from question_weight
1611
+ if not (0.0 <= self.question_weight <= 1.0):
1612
+ raise ValueError(
1613
+ f"question_weight must be between 0.0 and 1.0, got {self.question_weight}"
1614
+ )
1615
+ self.context_weight = 1.0 - self.question_weight
1616
+ else:
1617
+ # Both provided - validate they sum to 1.0
1618
+ total_weight = self.question_weight + self.context_weight
1619
+ if not abs(total_weight - 1.0) < 0.0001: # Allow small floating point error
1620
+ raise ValueError(
1621
+ f"question_weight ({self.question_weight}) + context_weight ({self.context_weight}) "
1622
+ f"must equal 1.0 (got {total_weight}). These weights determine the relative importance "
1623
+ f"of question vs context similarity in the combined score."
1624
+ )
1625
+
1626
+ return self
1627
+
1628
+
891
1629
  class SearchParametersModel(BaseModel):
892
1630
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
893
1631
  num_results: Optional[int] = 10
@@ -918,11 +1656,13 @@ class RerankParametersModel(BaseModel):
918
1656
  top_n: 5 # Return top 5 after reranking
919
1657
  ```
920
1658
 
921
- Available models (from fastest to most accurate):
922
- - "ms-marco-TinyBERT-L-2-v2" (fastest, smallest)
923
- - "ms-marco-MiniLM-L-6-v2"
924
- - "ms-marco-MiniLM-L-12-v2" (default, good balance)
925
- - "rank-T5-flan" (most accurate, slower)
1659
+ Available models (see https://github.com/PrithivirajDamodaran/FlashRank):
1660
+ - "ms-marco-TinyBERT-L-2-v2" (~4MB, fastest)
1661
+ - "ms-marco-MiniLM-L-12-v2" (~34MB, best cross-encoder, default)
1662
+ - "rank-T5-flan" (~110MB, best non cross-encoder)
1663
+ - "ms-marco-MultiBERT-L-12" (~150MB, multilingual 100+ languages)
1664
+ - "ce-esci-MiniLM-L12-v2" (e-commerce optimized, Amazon ESCI)
1665
+ - "miniReranker_arabic_v1" (Arabic language)
926
1666
  """
927
1667
 
928
1668
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -936,8 +1676,8 @@ class RerankParametersModel(BaseModel):
936
1676
  description="Number of documents to return after reranking. If None, uses search_parameters.num_results.",
937
1677
  )
938
1678
  cache_dir: Optional[str] = Field(
939
- default="/tmp/flashrank_cache",
940
- description="Directory to cache downloaded model weights.",
1679
+ default="~/.dao_ai/cache/flashrank",
1680
+ description="Directory to cache downloaded model weights. Supports tilde expansion (e.g., ~/.dao_ai).",
941
1681
  )
942
1682
  columns: Optional[list[str]] = Field(
943
1683
  default_factory=list, description="Columns to rerank using DatabricksReranker"
@@ -957,14 +1697,14 @@ class RetrieverModel(BaseModel):
957
1697
  )
958
1698
 
959
1699
  @model_validator(mode="after")
960
- def set_default_columns(self):
1700
+ def set_default_columns(self) -> Self:
961
1701
  if not self.columns:
962
1702
  columns: Sequence[str] = self.vector_store.columns
963
1703
  self.columns = columns
964
1704
  return self
965
1705
 
966
1706
  @model_validator(mode="after")
967
- def set_default_reranker(self):
1707
+ def set_default_reranker(self) -> Self:
968
1708
  """Convert bool to ReRankParametersModel with defaults."""
969
1709
  if isinstance(self.rerank, bool) and self.rerank:
970
1710
  self.rerank = RerankParametersModel()
@@ -978,28 +1718,47 @@ class FunctionType(str, Enum):
978
1718
  MCP = "mcp"
979
1719
 
980
1720
 
981
- class HumanInTheLoopActionType(str, Enum):
982
- """Supported action types for human-in-the-loop interactions."""
1721
+ class HumanInTheLoopModel(BaseModel):
1722
+ """
1723
+ Configuration for Human-in-the-Loop tool approval.
983
1724
 
984
- ACCEPT = "accept"
985
- EDIT = "edit"
986
- RESPONSE = "response"
987
- DECLINE = "decline"
1725
+ This model configures when and how tools require human approval before execution.
1726
+ It maps to LangChain's HumanInTheLoopMiddleware.
988
1727
 
1728
+ LangChain supports three decision types:
1729
+ - "approve": Execute tool with original arguments
1730
+ - "edit": Modify arguments before execution
1731
+ - "reject": Skip execution with optional feedback message
1732
+ """
989
1733
 
990
- class HumanInTheLoopModel(BaseModel):
991
1734
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
992
- review_prompt: str = "Please review the tool call"
993
- interrupt_config: dict[str, Any] = Field(
994
- default_factory=lambda: {
995
- "allow_accept": True,
996
- "allow_edit": True,
997
- "allow_respond": True,
998
- "allow_decline": True,
999
- }
1735
+
1736
+ review_prompt: Optional[str] = Field(
1737
+ default=None,
1738
+ description="Message shown to the reviewer when approval is requested",
1739
+ )
1740
+
1741
+ allowed_decisions: list[Literal["approve", "edit", "reject"]] = Field(
1742
+ default_factory=lambda: ["approve", "edit", "reject"],
1743
+ description="List of allowed decision types for this tool",
1000
1744
  )
1001
- decline_message: str = "Tool call declined by user"
1002
- custom_actions: Optional[dict[str, str]] = Field(default_factory=dict)
1745
+
1746
+ @model_validator(mode="after")
1747
+ def validate_and_normalize_decisions(self) -> Self:
1748
+ """Validate and normalize allowed decisions."""
1749
+ if not self.allowed_decisions:
1750
+ raise ValueError("At least one decision type must be allowed")
1751
+
1752
+ # Remove duplicates while preserving order
1753
+ seen = set()
1754
+ unique_decisions = []
1755
+ for decision in self.allowed_decisions:
1756
+ if decision not in seen:
1757
+ seen.add(decision)
1758
+ unique_decisions.append(decision)
1759
+ self.allowed_decisions = unique_decisions
1760
+
1761
+ return self
1003
1762
 
1004
1763
 
1005
1764
  class BaseFunctionModel(ABC, BaseModel):
@@ -1008,7 +1767,6 @@ class BaseFunctionModel(ABC, BaseModel):
1008
1767
  discriminator="type",
1009
1768
  )
1010
1769
  type: FunctionType
1011
- name: str
1012
1770
  human_in_the_loop: Optional[HumanInTheLoopModel] = None
1013
1771
 
1014
1772
  @abstractmethod
@@ -1025,6 +1783,7 @@ class BaseFunctionModel(ABC, BaseModel):
1025
1783
  class PythonFunctionModel(BaseFunctionModel, HasFullName):
1026
1784
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1027
1785
  type: Literal[FunctionType.PYTHON] = FunctionType.PYTHON
1786
+ name: str
1028
1787
 
1029
1788
  @property
1030
1789
  def full_name(self) -> str:
@@ -1038,8 +1797,9 @@ class PythonFunctionModel(BaseFunctionModel, HasFullName):
1038
1797
 
1039
1798
  class FactoryFunctionModel(BaseFunctionModel, HasFullName):
1040
1799
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1041
- args: Optional[dict[str, Any]] = Field(default_factory=dict)
1042
1800
  type: Literal[FunctionType.FACTORY] = FunctionType.FACTORY
1801
+ name: str
1802
+ args: Optional[dict[str, Any]] = Field(default_factory=dict)
1043
1803
 
1044
1804
  @property
1045
1805
  def full_name(self) -> str:
@@ -1051,7 +1811,7 @@ class FactoryFunctionModel(BaseFunctionModel, HasFullName):
1051
1811
  return [create_factory_tool(self, **kwargs)]
1052
1812
 
1053
1813
  @model_validator(mode="after")
1054
- def update_args(self):
1814
+ def update_args(self) -> Self:
1055
1815
  for key, value in self.args.items():
1056
1816
  self.args[key] = value_of(value)
1057
1817
  return self
@@ -1062,7 +1822,16 @@ class TransportType(str, Enum):
1062
1822
  STDIO = "stdio"
1063
1823
 
1064
1824
 
1065
- class McpFunctionModel(BaseFunctionModel, HasFullName):
1825
+ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
1826
+ """
1827
+ MCP Function Model with authentication inherited from IsDatabricksResource.
1828
+
1829
+ Authentication for MCP connections uses the same options as other resources:
1830
+ - Service Principal (client_id + client_secret + workspace_host)
1831
+ - PAT (pat + workspace_host)
1832
+ - OBO (on_behalf_of_user)
1833
+ """
1834
+
1066
1835
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1067
1836
  type: Literal[FunctionType.MCP] = FunctionType.MCP
1068
1837
  transport: TransportType = TransportType.STREAMABLE_HTTP
@@ -1070,10 +1839,7 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
1070
1839
  url: Optional[AnyVariable] = None
1071
1840
  headers: dict[str, AnyVariable] = Field(default_factory=dict)
1072
1841
  args: list[str] = Field(default_factory=list)
1073
- pat: Optional[AnyVariable] = None
1074
- client_id: Optional[AnyVariable] = None
1075
- client_secret: Optional[AnyVariable] = None
1076
- workspace_host: Optional[AnyVariable] = None
1842
+ # MCP-specific fields
1077
1843
  connection: Optional[ConnectionModel] = None
1078
1844
  functions: Optional[SchemaModel] = None
1079
1845
  genie_room: Optional[GenieRoomModel] = None
@@ -1081,35 +1847,55 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
1081
1847
  vector_search: Optional[VectorStoreModel] = None
1082
1848
 
1083
1849
  @property
1084
- def full_name(self) -> str:
1085
- return self.name
1850
+ def api_scopes(self) -> Sequence[str]:
1851
+ """API scopes for MCP connections."""
1852
+ return [
1853
+ "serving.serving-endpoints",
1854
+ "mcp.genie",
1855
+ "mcp.functions",
1856
+ "mcp.vectorsearch",
1857
+ "mcp.external",
1858
+ ]
1859
+
1860
+ def as_resources(self) -> Sequence[DatabricksResource]:
1861
+ """MCP functions don't declare static resources."""
1862
+ return []
1086
1863
 
1087
1864
  def _get_workspace_host(self) -> str:
1088
1865
  """
1089
1866
  Get the workspace host, either from config or from workspace client.
1090
1867
 
1091
1868
  If connection is provided, uses its workspace client.
1092
- Otherwise, falls back to creating a new workspace client.
1869
+ Otherwise, falls back to the default Databricks host.
1093
1870
 
1094
1871
  Returns:
1095
- str: The workspace host URL without trailing slash
1872
+ str: The workspace host URL with https:// scheme and without trailing slash
1096
1873
  """
1097
- from databricks.sdk import WorkspaceClient
1874
+ from dao_ai.utils import get_default_databricks_host, normalize_host
1098
1875
 
1099
1876
  # Try to get workspace_host from config
1100
1877
  workspace_host: str | None = (
1101
- value_of(self.workspace_host) if self.workspace_host else None
1878
+ normalize_host(value_of(self.workspace_host))
1879
+ if self.workspace_host
1880
+ else None
1102
1881
  )
1103
1882
 
1104
1883
  # If no workspace_host in config, get it from workspace client
1105
1884
  if not workspace_host:
1106
1885
  # Use connection's workspace client if available
1107
1886
  if self.connection:
1108
- workspace_host = self.connection.workspace_client.config.host
1887
+ workspace_host = normalize_host(
1888
+ self.connection.workspace_client.config.host
1889
+ )
1109
1890
  else:
1110
- # Create a default workspace client
1111
- w: WorkspaceClient = WorkspaceClient()
1112
- workspace_host = w.config.host
1891
+ # get_default_databricks_host already normalizes the host
1892
+ workspace_host = get_default_databricks_host()
1893
+
1894
+ if not workspace_host:
1895
+ raise ValueError(
1896
+ "Could not determine workspace host. "
1897
+ "Please set workspace_host in config or DATABRICKS_HOST environment variable."
1898
+ )
1113
1899
 
1114
1900
  # Remove trailing slash
1115
1901
  return workspace_host.rstrip("/")
@@ -1234,74 +2020,132 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
1234
2020
  self.headers[key] = value_of(value)
1235
2021
  return self
1236
2022
 
1237
- @model_validator(mode="after")
1238
- def validate_auth_methods(self) -> "McpFunctionModel":
1239
- oauth_fields: Sequence[Any] = [
1240
- self.client_id,
1241
- self.client_secret,
1242
- ]
1243
- has_oauth: bool = all(field is not None for field in oauth_fields)
1244
-
1245
- pat_fields: Sequence[Any] = [self.pat]
1246
- has_user_auth: bool = all(field is not None for field in pat_fields)
2023
+ def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
2024
+ from dao_ai.tools import create_mcp_tools
1247
2025
 
1248
- if has_oauth and has_user_auth:
1249
- raise ValueError(
1250
- "Cannot use both OAuth and user authentication methods. "
1251
- "Please provide either OAuth credentials or user credentials."
1252
- )
2026
+ return create_mcp_tools(self)
1253
2027
 
1254
- # Note: workspace_host is optional - it will be derived from workspace client if not provided
1255
2028
 
1256
- return self
2029
+ class UnityCatalogFunctionModel(BaseFunctionModel):
2030
+ model_config = ConfigDict(use_enum_values=True, extra="forbid")
2031
+ type: Literal[FunctionType.UNITY_CATALOG] = FunctionType.UNITY_CATALOG
2032
+ resource: FunctionModel
2033
+ partial_args: Optional[dict[str, AnyVariable]] = Field(default_factory=dict)
1257
2034
 
1258
2035
  def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
1259
- from dao_ai.tools import create_mcp_tools
2036
+ from dao_ai.tools import create_uc_tools
1260
2037
 
1261
- return create_mcp_tools(self)
2038
+ return create_uc_tools(self)
2039
+
2040
+
2041
+ AnyTool: TypeAlias = (
2042
+ Union[
2043
+ PythonFunctionModel,
2044
+ FactoryFunctionModel,
2045
+ UnityCatalogFunctionModel,
2046
+ McpFunctionModel,
2047
+ ]
2048
+ | str
2049
+ )
2050
+
2051
+
2052
+ class ToolModel(BaseModel):
2053
+ model_config = ConfigDict(use_enum_values=True, extra="forbid")
2054
+ name: str
2055
+ function: AnyTool
1262
2056
 
1263
2057
 
1264
- class UnityCatalogFunctionModel(BaseFunctionModel, HasFullName):
2058
+ class PromptModel(BaseModel, HasFullName):
1265
2059
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1266
2060
  schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
1267
- partial_args: Optional[dict[str, AnyVariable]] = Field(default_factory=dict)
1268
- type: Literal[FunctionType.UNITY_CATALOG] = FunctionType.UNITY_CATALOG
2061
+ name: str
2062
+ description: Optional[str] = None
2063
+ default_template: Optional[str] = None
2064
+ alias: Optional[str] = None
2065
+ version: Optional[int] = None
2066
+ tags: Optional[dict[str, Any]] = Field(default_factory=dict)
2067
+ auto_register: bool = Field(
2068
+ default=False,
2069
+ description="Whether to automatically register the default_template to the prompt registry. "
2070
+ "If False, the prompt will only be loaded from the registry (never created/updated). "
2071
+ "Defaults to True for backward compatibility.",
2072
+ )
2073
+
2074
+ @property
2075
+ def template(self) -> str:
2076
+ from dao_ai.providers.databricks import DatabricksProvider
2077
+
2078
+ provider: DatabricksProvider = DatabricksProvider()
2079
+ prompt_version = provider.get_prompt(self)
2080
+ return prompt_version.to_single_brace_format()
2081
+
2082
+ @property
2083
+ def full_name(self) -> str:
2084
+ prompt_name: str = self.name
2085
+ if self.schema_model:
2086
+ prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
2087
+ return prompt_name
1269
2088
 
1270
2089
  @property
1271
- def full_name(self) -> str:
1272
- if self.schema_model:
1273
- return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}.{self.name}"
1274
- return self.name
2090
+ def uri(self) -> str:
2091
+ prompt_uri: str = f"prompts:/{self.full_name}"
1275
2092
 
1276
- def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
1277
- from dao_ai.tools import create_uc_tools
2093
+ if self.alias:
2094
+ prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
2095
+ elif self.version:
2096
+ prompt_uri = f"prompts:/{self.full_name}/{self.version}"
2097
+ else:
2098
+ prompt_uri = f"prompts:/{self.full_name}@latest"
1278
2099
 
1279
- return create_uc_tools(self)
2100
+ return prompt_uri
1280
2101
 
2102
+ def as_prompt(self) -> PromptVersion:
2103
+ prompt_version: PromptVersion = load_prompt(self.uri)
2104
+ return prompt_version
1281
2105
 
1282
- AnyTool: TypeAlias = (
1283
- Union[
1284
- PythonFunctionModel,
1285
- FactoryFunctionModel,
1286
- UnityCatalogFunctionModel,
1287
- McpFunctionModel,
1288
- ]
1289
- | str
1290
- )
2106
+ @model_validator(mode="after")
2107
+ def validate_mutually_exclusive(self) -> Self:
2108
+ if self.alias and self.version:
2109
+ raise ValueError("Cannot specify both alias and version")
2110
+ return self
1291
2111
 
1292
2112
 
1293
- class ToolModel(BaseModel):
2113
+ class GuardrailModel(BaseModel):
1294
2114
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1295
2115
  name: str
1296
- function: AnyTool
2116
+ model: str | LLMModel
2117
+ prompt: str | PromptModel
2118
+ num_retries: Optional[int] = 3
2119
+
2120
+ @model_validator(mode="after")
2121
+ def validate_llm_model(self) -> Self:
2122
+ if isinstance(self.model, str):
2123
+ self.model = LLMModel(name=self.model)
2124
+ return self
1297
2125
 
1298
2126
 
1299
- class GuardrailModel(BaseModel):
2127
+ class MiddlewareModel(BaseModel):
2128
+ """Configuration for middleware that can be applied to agents.
2129
+
2130
+ Middleware is defined at the AppConfig level and can be referenced by name
2131
+ in agent configurations using YAML anchors for reusability.
2132
+ """
2133
+
1300
2134
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1301
- name: str
1302
- model: LLMModel
1303
- prompt: str
1304
- num_retries: Optional[int] = 3
2135
+ name: str = Field(
2136
+ description="Fully qualified name of the middleware factory function"
2137
+ )
2138
+ args: dict[str, Any] = Field(
2139
+ default_factory=dict,
2140
+ description="Arguments to pass to the middleware factory function",
2141
+ )
2142
+
2143
+ @model_validator(mode="after")
2144
+ def resolve_args(self) -> Self:
2145
+ """Resolve any variable references in args."""
2146
+ for key, value in self.args.items():
2147
+ self.args[key] = value_of(value)
2148
+ return self
1305
2149
 
1306
2150
 
1307
2151
  class StorageType(str, Enum):
@@ -1312,14 +2156,12 @@ class StorageType(str, Enum):
1312
2156
  class CheckpointerModel(BaseModel):
1313
2157
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1314
2158
  name: str
1315
- type: Optional[StorageType] = StorageType.MEMORY
1316
2159
  database: Optional[DatabaseModel] = None
1317
2160
 
1318
- @model_validator(mode="after")
1319
- def validate_postgres_requires_database(self):
1320
- if self.type == StorageType.POSTGRES and not self.database:
1321
- raise ValueError("Database must be provided when storage type is POSTGRES")
1322
- return self
2161
+ @property
2162
+ def storage_type(self) -> StorageType:
2163
+ """Infer storage type from database presence."""
2164
+ return StorageType.POSTGRES if self.database else StorageType.MEMORY
1323
2165
 
1324
2166
  def as_checkpointer(self) -> BaseCheckpointSaver:
1325
2167
  from dao_ai.memory import CheckpointManager
@@ -1335,16 +2177,14 @@ class StoreModel(BaseModel):
1335
2177
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1336
2178
  name: str
1337
2179
  embedding_model: Optional[LLMModel] = None
1338
- type: Optional[StorageType] = StorageType.MEMORY
1339
2180
  dims: Optional[int] = 1536
1340
2181
  database: Optional[DatabaseModel] = None
1341
2182
  namespace: Optional[str] = None
1342
2183
 
1343
- @model_validator(mode="after")
1344
- def validate_postgres_requires_database(self):
1345
- if self.type == StorageType.POSTGRES and not self.database:
1346
- raise ValueError("Database must be provided when storage type is POSTGRES")
1347
- return self
2184
+ @property
2185
+ def storage_type(self) -> StorageType:
2186
+ """Infer storage type from database presence."""
2187
+ return StorageType.POSTGRES if self.database else StorageType.MEMORY
1348
2188
 
1349
2189
  def as_store(self) -> BaseStore:
1350
2190
  from dao_ai.memory import StoreManager
@@ -1362,56 +2202,158 @@ class MemoryModel(BaseModel):
1362
2202
  FunctionHook: TypeAlias = PythonFunctionModel | FactoryFunctionModel | str
1363
2203
 
1364
2204
 
1365
- class PromptModel(BaseModel, HasFullName):
2205
+ class ResponseFormatModel(BaseModel):
2206
+ """
2207
+ Configuration for structured response formats.
2208
+
2209
+ The response_schema field accepts either a type or a string:
2210
+ - Type (Pydantic model, dataclass, etc.): Used directly for structured output
2211
+ - String: First attempts to load as a fully qualified type name, falls back to JSON schema string
2212
+
2213
+ This unified approach simplifies the API while maintaining flexibility.
2214
+ """
2215
+
1366
2216
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1367
- schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
1368
- name: str
1369
- description: Optional[str] = None
1370
- default_template: Optional[str] = None
1371
- alias: Optional[str] = None
1372
- version: Optional[int] = None
1373
- tags: Optional[dict[str, Any]] = Field(default_factory=dict)
2217
+ use_tool: Optional[bool] = Field(
2218
+ default=None,
2219
+ description=(
2220
+ "Strategy for structured output: "
2221
+ "None (default) = auto-detect from model capabilities, "
2222
+ "False = force ProviderStrategy (native), "
2223
+ "True = force ToolStrategy (function calling)"
2224
+ ),
2225
+ )
2226
+ response_schema: Optional[str | type] = Field(
2227
+ default=None,
2228
+ description="Type or string for response format. String attempts FQN import, falls back to JSON schema.",
2229
+ )
1374
2230
 
1375
- @property
1376
- def template(self) -> str:
1377
- from dao_ai.providers.databricks import DatabricksProvider
2231
+ def as_strategy(self) -> ProviderStrategy | ToolStrategy:
2232
+ """
2233
+ Convert response_schema to appropriate LangChain strategy.
1378
2234
 
1379
- provider: DatabricksProvider = DatabricksProvider()
1380
- prompt_version = provider.get_prompt(self)
1381
- return prompt_version.to_single_brace_format()
2235
+ Returns:
2236
+ - None if no response_schema configured
2237
+ - Raw schema/type for auto-detection (when use_tool=None)
2238
+ - ToolStrategy wrapping the schema (when use_tool=True)
2239
+ - ProviderStrategy wrapping the schema (when use_tool=False)
1382
2240
 
1383
- @property
1384
- def full_name(self) -> str:
1385
- prompt_name: str = self.name
1386
- if self.schema_model:
1387
- prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
1388
- return prompt_name
2241
+ Raises:
2242
+ ValueError: If response_schema is a JSON schema string that cannot be parsed
2243
+ """
1389
2244
 
1390
- @property
1391
- def uri(self) -> str:
1392
- prompt_uri: str = f"prompts:/{self.full_name}"
2245
+ if self.response_schema is None:
2246
+ return None
1393
2247
 
1394
- if self.alias:
1395
- prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
1396
- elif self.version:
1397
- prompt_uri = f"prompts:/{self.full_name}/{self.version}"
1398
- else:
1399
- prompt_uri = f"prompts:/{self.full_name}@latest"
2248
+ schema = self.response_schema
1400
2249
 
1401
- return prompt_uri
2250
+ # Handle type schemas (Pydantic, dataclass, etc.)
2251
+ if self.is_type_schema:
2252
+ if self.use_tool is None:
2253
+ # Auto-detect: Pass schema directly, let LangChain decide
2254
+ return schema
2255
+ elif self.use_tool is True:
2256
+ # Force ToolStrategy (function calling)
2257
+ return ToolStrategy(schema)
2258
+ else: # use_tool is False
2259
+ # Force ProviderStrategy (native structured output)
2260
+ return ProviderStrategy(schema)
1402
2261
 
1403
- def as_prompt(self) -> PromptVersion:
1404
- prompt_version: PromptVersion = load_prompt(self.uri)
1405
- return prompt_version
2262
+ # Handle JSON schema strings
2263
+ elif self.is_json_schema:
2264
+ import json
2265
+
2266
+ try:
2267
+ schema_dict = json.loads(schema)
2268
+ except json.JSONDecodeError as e:
2269
+ raise ValueError(f"Invalid JSON schema string: {e}") from e
2270
+
2271
+ # Apply same use_tool logic as type schemas
2272
+ if self.use_tool is None:
2273
+ # Auto-detect
2274
+ return schema_dict
2275
+ elif self.use_tool is True:
2276
+ # Force ToolStrategy
2277
+ return ToolStrategy(schema_dict)
2278
+ else: # use_tool is False
2279
+ # Force ProviderStrategy
2280
+ return ProviderStrategy(schema_dict)
2281
+
2282
+ return None
1406
2283
 
1407
2284
  @model_validator(mode="after")
1408
- def validate_mutually_exclusive(self):
1409
- if self.alias and self.version:
1410
- raise ValueError("Cannot specify both alias and version")
1411
- return self
2285
+ def validate_response_schema(self) -> Self:
2286
+ """
2287
+ Validate and convert response_schema.
2288
+
2289
+ Processing logic:
2290
+ 1. If None: no response format specified
2291
+ 2. If type: use directly as structured output type
2292
+ 3. If str: try to load as FQN using type_from_fqn
2293
+ - Success: response_schema becomes the loaded type
2294
+ - Failure: keep as string (treated as JSON schema)
2295
+
2296
+ After validation, response_schema is one of:
2297
+ - None (no schema)
2298
+ - type (use for structured output)
2299
+ - str (JSON schema)
2300
+
2301
+ Returns:
2302
+ Self with validated response_schema
2303
+ """
2304
+ if self.response_schema is None:
2305
+ return self
2306
+
2307
+ # If already a type, return
2308
+ if isinstance(self.response_schema, type):
2309
+ return self
2310
+
2311
+ # If it's a string, try to load as type, fallback to json_schema
2312
+ if isinstance(self.response_schema, str):
2313
+ from dao_ai.utils import type_from_fqn
2314
+
2315
+ try:
2316
+ resolved_type = type_from_fqn(self.response_schema)
2317
+ self.response_schema = resolved_type
2318
+ logger.debug(
2319
+ f"Resolved response_schema string to type: {resolved_type}"
2320
+ )
2321
+ return self
2322
+ except (ValueError, ImportError, AttributeError, TypeError) as e:
2323
+ # Keep as string - it's a JSON schema
2324
+ logger.debug(
2325
+ f"Could not resolve '{self.response_schema}' as type: {e}. "
2326
+ f"Treating as JSON schema string."
2327
+ )
2328
+ return self
2329
+
2330
+ # Invalid type
2331
+ raise ValueError(
2332
+ f"response_schema must be None, type, or str, got {type(self.response_schema)}"
2333
+ )
2334
+
2335
+ @property
2336
+ def is_type_schema(self) -> bool:
2337
+ """Returns True if response_schema is a type (not JSON schema string)."""
2338
+ return isinstance(self.response_schema, type)
2339
+
2340
+ @property
2341
+ def is_json_schema(self) -> bool:
2342
+ """Returns True if response_schema is a JSON schema string (not a type)."""
2343
+ return isinstance(self.response_schema, str)
1412
2344
 
1413
2345
 
1414
2346
  class AgentModel(BaseModel):
2347
+ """
2348
+ Configuration model for an agent in the DAO AI framework.
2349
+
2350
+ Agents combine an LLM with tools and middleware to create systems that can
2351
+ reason about tasks, decide which tools to use, and iteratively work towards solutions.
2352
+
2353
+ Middleware replaces the previous pre_agent_hook and post_agent_hook patterns,
2354
+ providing a more flexible and composable way to customize agent behavior.
2355
+ """
2356
+
1415
2357
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1416
2358
  name: str
1417
2359
  description: Optional[str] = None
@@ -1420,9 +2362,43 @@ class AgentModel(BaseModel):
1420
2362
  guardrails: list[GuardrailModel] = Field(default_factory=list)
1421
2363
  prompt: Optional[str | PromptModel] = None
1422
2364
  handoff_prompt: Optional[str] = None
1423
- create_agent_hook: Optional[FunctionHook] = None
1424
- pre_agent_hook: Optional[FunctionHook] = None
1425
- post_agent_hook: Optional[FunctionHook] = None
2365
+ middleware: list[MiddlewareModel] = Field(
2366
+ default_factory=list,
2367
+ description="List of middleware to apply to this agent",
2368
+ )
2369
+ response_format: Optional[ResponseFormatModel | type | str] = None
2370
+
2371
+ @model_validator(mode="after")
2372
+ def validate_response_format(self) -> Self:
2373
+ """
2374
+ Validate and normalize response_format.
2375
+
2376
+ Accepts:
2377
+ - None (no response format)
2378
+ - ResponseFormatModel (already validated)
2379
+ - type (Pydantic model, dataclass, etc.) - converts to ResponseFormatModel
2380
+ - str (FQN or json_schema) - converts to ResponseFormatModel (smart fallback)
2381
+
2382
+ ResponseFormatModel handles the logic of trying FQN import and falling back to JSON schema.
2383
+ """
2384
+ if self.response_format is None or isinstance(
2385
+ self.response_format, ResponseFormatModel
2386
+ ):
2387
+ return self
2388
+
2389
+ # Convert type or str to ResponseFormatModel
2390
+ # ResponseFormatModel's validator will handle the smart type loading and fallback
2391
+ if isinstance(self.response_format, (type, str)):
2392
+ self.response_format = ResponseFormatModel(
2393
+ response_schema=self.response_format
2394
+ )
2395
+ return self
2396
+
2397
+ # Invalid type
2398
+ raise ValueError(
2399
+ f"response_format must be None, ResponseFormatModel, type, or str, "
2400
+ f"got {type(self.response_format)}"
2401
+ )
1426
2402
 
1427
2403
  def as_runnable(self) -> RunnableLike:
1428
2404
  from dao_ai.nodes import create_agent_node
@@ -1441,12 +2417,20 @@ class SupervisorModel(BaseModel):
1441
2417
  model: LLMModel
1442
2418
  tools: list[ToolModel] = Field(default_factory=list)
1443
2419
  prompt: Optional[str] = None
2420
+ middleware: list[MiddlewareModel] = Field(
2421
+ default_factory=list,
2422
+ description="List of middleware to apply to the supervisor",
2423
+ )
1444
2424
 
1445
2425
 
1446
2426
  class SwarmModel(BaseModel):
1447
2427
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1448
2428
  model: LLMModel
1449
2429
  default_agent: Optional[AgentModel | str] = None
2430
+ middleware: list[MiddlewareModel] = Field(
2431
+ default_factory=list,
2432
+ description="List of middleware to apply to all agents in the swarm",
2433
+ )
1450
2434
  handoffs: Optional[dict[str, Optional[list[AgentModel | str]]]] = Field(
1451
2435
  default_factory=dict
1452
2436
  )
@@ -1459,7 +2443,7 @@ class OrchestrationModel(BaseModel):
1459
2443
  memory: Optional[MemoryModel] = None
1460
2444
 
1461
2445
  @model_validator(mode="after")
1462
- def validate_mutually_exclusive(self):
2446
+ def validate_mutually_exclusive(self) -> Self:
1463
2447
  if self.supervisor is not None and self.swarm is not None:
1464
2448
  raise ValueError("Cannot specify both supervisor and swarm")
1465
2449
  if self.supervisor is None and self.swarm is None:
@@ -1489,9 +2473,21 @@ class Entitlement(str, Enum):
1489
2473
 
1490
2474
  class AppPermissionModel(BaseModel):
1491
2475
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1492
- principals: list[str] = Field(default_factory=list)
2476
+ principals: list[ServicePrincipalModel | str] = Field(default_factory=list)
1493
2477
  entitlements: list[Entitlement]
1494
2478
 
2479
+ @model_validator(mode="after")
2480
+ def resolve_principals(self) -> Self:
2481
+ """Resolve ServicePrincipalModel objects to their client_id."""
2482
+ resolved: list[str] = []
2483
+ for principal in self.principals:
2484
+ if isinstance(principal, ServicePrincipalModel):
2485
+ resolved.append(value_of(principal.client_id))
2486
+ else:
2487
+ resolved.append(principal)
2488
+ self.principals = resolved
2489
+ return self
2490
+
1495
2491
 
1496
2492
  class LogLevel(str, Enum):
1497
2493
  TRACE = "TRACE"
@@ -1552,6 +2548,28 @@ class ChatPayload(BaseModel):
1552
2548
 
1553
2549
  return self
1554
2550
 
2551
+ @model_validator(mode="after")
2552
+ def ensure_thread_id(self) -> "ChatPayload":
2553
+ """Ensure thread_id or conversation_id is present in configurable, generating UUID if needed."""
2554
+ import uuid
2555
+
2556
+ if self.custom_inputs is None:
2557
+ self.custom_inputs = {}
2558
+
2559
+ # Get or create configurable section
2560
+ configurable: dict[str, Any] = self.custom_inputs.get("configurable", {})
2561
+
2562
+ # Check if thread_id or conversation_id exists
2563
+ has_thread_id = configurable.get("thread_id") is not None
2564
+ has_conversation_id = configurable.get("conversation_id") is not None
2565
+
2566
+ # If neither is provided, generate a UUID for conversation_id
2567
+ if not has_thread_id and not has_conversation_id:
2568
+ configurable["conversation_id"] = str(uuid.uuid4())
2569
+ self.custom_inputs["configurable"] = configurable
2570
+
2571
+ return self
2572
+
1555
2573
  def as_messages(self) -> Sequence[BaseMessage]:
1556
2574
  return messages_from_dict(
1557
2575
  [{"type": m.role, "content": m.content} for m in self.messages]
@@ -1567,25 +2585,44 @@ class ChatPayload(BaseModel):
1567
2585
 
1568
2586
 
1569
2587
  class ChatHistoryModel(BaseModel):
2588
+ """
2589
+ Configuration for chat history summarization.
2590
+
2591
+ Attributes:
2592
+ model: The LLM to use for generating summaries.
2593
+ max_tokens: Maximum tokens to keep after summarization (the "keep" threshold).
2594
+ After summarization, recent messages totaling up to this many tokens are preserved.
2595
+ max_tokens_before_summary: Token threshold that triggers summarization.
2596
+ When conversation exceeds this, summarization runs. Mutually exclusive with
2597
+ max_messages_before_summary. If neither is set, defaults to max_tokens * 10.
2598
+ max_messages_before_summary: Message count threshold that triggers summarization.
2599
+ When conversation exceeds this many messages, summarization runs.
2600
+ Mutually exclusive with max_tokens_before_summary.
2601
+ """
2602
+
1570
2603
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1571
2604
  model: LLMModel
1572
- max_tokens: int = 256
1573
- max_tokens_before_summary: Optional[int] = None
1574
- max_messages_before_summary: Optional[int] = None
1575
- max_summary_tokens: int = 255
1576
-
1577
- @model_validator(mode="after")
1578
- def validate_max_summary_tokens(self) -> "ChatHistoryModel":
1579
- if self.max_summary_tokens >= self.max_tokens:
1580
- raise ValueError(
1581
- f"max_summary_tokens ({self.max_summary_tokens}) must be less than max_tokens ({self.max_tokens})"
1582
- )
1583
- return self
2605
+ max_tokens: int = Field(
2606
+ default=2048,
2607
+ gt=0,
2608
+ description="Maximum tokens to keep after summarization",
2609
+ )
2610
+ max_tokens_before_summary: Optional[int] = Field(
2611
+ default=None,
2612
+ gt=0,
2613
+ description="Token threshold that triggers summarization",
2614
+ )
2615
+ max_messages_before_summary: Optional[int] = Field(
2616
+ default=None,
2617
+ gt=0,
2618
+ description="Message count threshold that triggers summarization",
2619
+ )
1584
2620
 
1585
2621
 
1586
2622
  class AppModel(BaseModel):
1587
2623
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1588
2624
  name: str
2625
+ service_principal: Optional[ServicePrincipalModel] = None
1589
2626
  description: Optional[str] = None
1590
2627
  log_level: Optional[LogLevel] = "WARNING"
1591
2628
  registered_model: RegisteredModelModel
@@ -1606,23 +2643,54 @@ class AppModel(BaseModel):
1606
2643
  shutdown_hooks: Optional[FunctionHook | list[FunctionHook]] = Field(
1607
2644
  default_factory=list
1608
2645
  )
1609
- message_hooks: Optional[FunctionHook | list[FunctionHook]] = Field(
1610
- default_factory=list
1611
- )
1612
2646
  input_example: Optional[ChatPayload] = None
1613
2647
  chat_history: Optional[ChatHistoryModel] = None
1614
2648
  code_paths: list[str] = Field(default_factory=list)
1615
2649
  pip_requirements: list[str] = Field(default_factory=list)
2650
+ python_version: Optional[str] = Field(
2651
+ default="3.12",
2652
+ description="Python version for Model Serving deployment. Defaults to 3.12 "
2653
+ "which is supported by Databricks Model Serving. This allows deploying from "
2654
+ "environments with different Python versions (e.g., Databricks Apps with 3.11).",
2655
+ )
2656
+
2657
+ @model_validator(mode="after")
2658
+ def set_databricks_env_vars(self) -> Self:
2659
+ """Set Databricks environment variables for Model Serving.
2660
+
2661
+ Sets DATABRICKS_HOST, DATABRICKS_CLIENT_ID, and DATABRICKS_CLIENT_SECRET.
2662
+ Values explicitly provided in environment_vars take precedence.
2663
+ """
2664
+ from dao_ai.utils import get_default_databricks_host
2665
+
2666
+ # Set DATABRICKS_HOST if not already provided
2667
+ if "DATABRICKS_HOST" not in self.environment_vars:
2668
+ host: str | None = get_default_databricks_host()
2669
+ if host:
2670
+ self.environment_vars["DATABRICKS_HOST"] = host
2671
+
2672
+ # Set service principal credentials if provided
2673
+ if self.service_principal is not None:
2674
+ if "DATABRICKS_CLIENT_ID" not in self.environment_vars:
2675
+ self.environment_vars["DATABRICKS_CLIENT_ID"] = (
2676
+ self.service_principal.client_id
2677
+ )
2678
+ if "DATABRICKS_CLIENT_SECRET" not in self.environment_vars:
2679
+ self.environment_vars["DATABRICKS_CLIENT_SECRET"] = (
2680
+ self.service_principal.client_secret
2681
+ )
2682
+ return self
1616
2683
 
1617
2684
  @model_validator(mode="after")
1618
- def validate_agents_not_empty(self):
2685
+ def validate_agents_not_empty(self) -> Self:
1619
2686
  if not self.agents:
1620
2687
  raise ValueError("At least one agent must be specified")
1621
2688
  return self
1622
2689
 
1623
2690
  @model_validator(mode="after")
1624
- def update_environment_vars(self):
2691
+ def resolve_environment_vars(self) -> Self:
1625
2692
  for key, value in self.environment_vars.items():
2693
+ updated_value: str
1626
2694
  if isinstance(value, SecretVariableModel):
1627
2695
  updated_value = str(value)
1628
2696
  else:
@@ -1632,7 +2700,7 @@ class AppModel(BaseModel):
1632
2700
  return self
1633
2701
 
1634
2702
  @model_validator(mode="after")
1635
- def set_default_orchestration(self):
2703
+ def set_default_orchestration(self) -> Self:
1636
2704
  if self.orchestration is None:
1637
2705
  if len(self.agents) > 1:
1638
2706
  default_agent: AgentModel = self.agents[0]
@@ -1652,14 +2720,14 @@ class AppModel(BaseModel):
1652
2720
  return self
1653
2721
 
1654
2722
  @model_validator(mode="after")
1655
- def set_default_endpoint_name(self):
2723
+ def set_default_endpoint_name(self) -> Self:
1656
2724
  if self.endpoint_name is None:
1657
2725
  self.endpoint_name = self.name
1658
2726
  return self
1659
2727
 
1660
2728
  @model_validator(mode="after")
1661
- def set_default_agent(self):
1662
- default_agent_name = self.agents[0].name
2729
+ def set_default_agent(self) -> Self:
2730
+ default_agent_name: str = self.agents[0].name
1663
2731
 
1664
2732
  if self.orchestration.swarm and not self.orchestration.swarm.default_agent:
1665
2733
  self.orchestration.swarm.default_agent = default_agent_name
@@ -1667,7 +2735,7 @@ class AppModel(BaseModel):
1667
2735
  return self
1668
2736
 
1669
2737
  @model_validator(mode="after")
1670
- def add_code_paths_to_sys_path(self):
2738
+ def add_code_paths_to_sys_path(self) -> Self:
1671
2739
  for code_path in self.code_paths:
1672
2740
  parent_path: str = str(Path(code_path).parent)
1673
2741
  if parent_path not in sys.path:
@@ -1700,7 +2768,7 @@ class EvaluationDatasetExpectationsModel(BaseModel):
1700
2768
  expected_facts: Optional[list[str]] = None
1701
2769
 
1702
2770
  @model_validator(mode="after")
1703
- def validate_mutually_exclusive(self):
2771
+ def validate_mutually_exclusive(self) -> Self:
1704
2772
  if self.expected_response is not None and self.expected_facts is not None:
1705
2773
  raise ValueError("Cannot specify both expected_response and expected_facts")
1706
2774
  return self
@@ -1779,36 +2847,70 @@ class EvaluationDatasetModel(BaseModel, HasFullName):
1779
2847
 
1780
2848
 
1781
2849
  class PromptOptimizationModel(BaseModel):
2850
+ """Configuration for prompt optimization using GEPA.
2851
+
2852
+ GEPA (Generative Evolution of Prompts and Agents) is an evolutionary
2853
+ optimizer that uses reflective mutation to improve prompts based on
2854
+ evaluation feedback.
2855
+
2856
+ Example:
2857
+ prompt_optimization:
2858
+ name: optimize_my_prompt
2859
+ prompt: *my_prompt
2860
+ agent: *my_agent
2861
+ dataset: *my_training_dataset
2862
+ reflection_model: databricks-meta-llama-3-3-70b-instruct
2863
+ num_candidates: 50
2864
+ """
2865
+
1782
2866
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1783
2867
  name: str
1784
2868
  prompt: Optional[PromptModel] = None
1785
2869
  agent: AgentModel
1786
- dataset: (
1787
- EvaluationDatasetModel | str
1788
- ) # Reference to dataset name (looked up in OptimizationsModel.training_datasets or MLflow)
2870
+ dataset: EvaluationDatasetModel # Training dataset with examples
1789
2871
  reflection_model: Optional[LLMModel | str] = None
1790
2872
  num_candidates: Optional[int] = 50
1791
- scorer_model: Optional[LLMModel | str] = None
1792
2873
 
1793
2874
  def optimize(self, w: WorkspaceClient | None = None) -> PromptModel:
1794
2875
  """
1795
- Optimize the prompt using MLflow's prompt optimization.
2876
+ Optimize the prompt using GEPA.
1796
2877
 
1797
2878
  Args:
1798
- w: Optional WorkspaceClient for Databricks operations
2879
+ w: Optional WorkspaceClient (not used, kept for API compatibility)
1799
2880
 
1800
2881
  Returns:
1801
- PromptModel: The optimized prompt model with new URI
2882
+ PromptModel: The optimized prompt model
1802
2883
  """
1803
- from dao_ai.providers.base import ServiceProvider
1804
- from dao_ai.providers.databricks import DatabricksProvider
2884
+ from dao_ai.optimization import OptimizationResult, optimize_prompt
1805
2885
 
1806
- provider: ServiceProvider = DatabricksProvider(w=w)
1807
- optimized_prompt: PromptModel = provider.optimize_prompt(self)
1808
- return optimized_prompt
2886
+ # Get reflection model name
2887
+ reflection_model_name: str | None = None
2888
+ if self.reflection_model:
2889
+ if isinstance(self.reflection_model, str):
2890
+ reflection_model_name = self.reflection_model
2891
+ else:
2892
+ reflection_model_name = self.reflection_model.uri
2893
+
2894
+ # Ensure prompt is set
2895
+ prompt = self.prompt
2896
+ if prompt is None:
2897
+ raise ValueError(
2898
+ f"Prompt optimization '{self.name}' requires a prompt to be set"
2899
+ )
2900
+
2901
+ result: OptimizationResult = optimize_prompt(
2902
+ prompt=prompt,
2903
+ agent=self.agent,
2904
+ dataset=self.dataset,
2905
+ reflection_model=reflection_model_name,
2906
+ num_candidates=self.num_candidates or 50,
2907
+ register_if_improved=True,
2908
+ )
2909
+
2910
+ return result.optimized_prompt
1809
2911
 
1810
2912
  @model_validator(mode="after")
1811
- def set_defaults(self):
2913
+ def set_defaults(self) -> Self:
1812
2914
  # If no prompt is specified, try to use the agent's prompt
1813
2915
  if self.prompt is None:
1814
2916
  if isinstance(self.agent.prompt, PromptModel):
@@ -1819,12 +2921,6 @@ class PromptOptimizationModel(BaseModel):
1819
2921
  f"or an agent with a prompt configured"
1820
2922
  )
1821
2923
 
1822
- if self.reflection_model is None:
1823
- self.reflection_model = self.agent.model
1824
-
1825
- if self.scorer_model is None:
1826
- self.scorer_model = self.agent.model
1827
-
1828
2924
  return self
1829
2925
 
1830
2926
 
@@ -1897,7 +2993,7 @@ class UnityCatalogFunctionSqlTestModel(BaseModel):
1897
2993
 
1898
2994
  class UnityCatalogFunctionSqlModel(BaseModel):
1899
2995
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1900
- function: UnityCatalogFunctionModel
2996
+ function: FunctionModel
1901
2997
  ddl: str
1902
2998
  parameters: Optional[dict[str, Any]] = Field(default_factory=dict)
1903
2999
  test: Optional[UnityCatalogFunctionSqlTestModel] = None
@@ -1925,16 +3021,126 @@ class ResourcesModel(BaseModel):
1925
3021
  warehouses: dict[str, WarehouseModel] = Field(default_factory=dict)
1926
3022
  databases: dict[str, DatabaseModel] = Field(default_factory=dict)
1927
3023
  connections: dict[str, ConnectionModel] = Field(default_factory=dict)
3024
+ apps: dict[str, DatabricksAppModel] = Field(default_factory=dict)
3025
+
3026
+ @model_validator(mode="after")
3027
+ def update_genie_warehouses(self) -> Self:
3028
+ """
3029
+ Automatically populate warehouses from genie_rooms.
3030
+
3031
+ Warehouses are extracted from each Genie room and added to the
3032
+ resources if they don't already exist (based on warehouse_id).
3033
+ """
3034
+ if not self.genie_rooms:
3035
+ return self
3036
+
3037
+ # Process warehouses from all genie rooms
3038
+ for genie_room in self.genie_rooms.values():
3039
+ genie_room: GenieRoomModel
3040
+ warehouse: Optional[WarehouseModel] = genie_room.warehouse
3041
+
3042
+ if warehouse is None:
3043
+ continue
3044
+
3045
+ # Check if warehouse already exists based on warehouse_id
3046
+ warehouse_exists: bool = any(
3047
+ existing_warehouse.warehouse_id == warehouse.warehouse_id
3048
+ for existing_warehouse in self.warehouses.values()
3049
+ )
3050
+
3051
+ if not warehouse_exists:
3052
+ warehouse_key: str = normalize_name(
3053
+ "_".join([genie_room.name, warehouse.warehouse_id])
3054
+ )
3055
+ self.warehouses[warehouse_key] = warehouse
3056
+ logger.trace(
3057
+ "Added warehouse from Genie room",
3058
+ room=genie_room.name,
3059
+ warehouse=warehouse.warehouse_id,
3060
+ key=warehouse_key,
3061
+ )
3062
+
3063
+ return self
3064
+
3065
+ @model_validator(mode="after")
3066
+ def update_genie_tables(self) -> Self:
3067
+ """
3068
+ Automatically populate tables from genie_rooms.
3069
+
3070
+ Tables are extracted from each Genie room and added to the
3071
+ resources if they don't already exist (based on full_name).
3072
+ """
3073
+ if not self.genie_rooms:
3074
+ return self
3075
+
3076
+ # Process tables from all genie rooms
3077
+ for genie_room in self.genie_rooms.values():
3078
+ genie_room: GenieRoomModel
3079
+ for table in genie_room.tables:
3080
+ table: TableModel
3081
+ table_exists: bool = any(
3082
+ existing_table.full_name == table.full_name
3083
+ for existing_table in self.tables.values()
3084
+ )
3085
+ if not table_exists:
3086
+ table_key: str = normalize_name(
3087
+ "_".join([genie_room.name, table.full_name])
3088
+ )
3089
+ self.tables[table_key] = table
3090
+ logger.trace(
3091
+ "Added table from Genie room",
3092
+ room=genie_room.name,
3093
+ table=table.name,
3094
+ key=table_key,
3095
+ )
3096
+
3097
+ return self
3098
+
3099
+ @model_validator(mode="after")
3100
+ def update_genie_functions(self) -> Self:
3101
+ """
3102
+ Automatically populate functions from genie_rooms.
3103
+
3104
+ Functions are extracted from each Genie room and added to the
3105
+ resources if they don't already exist (based on full_name).
3106
+ """
3107
+ if not self.genie_rooms:
3108
+ return self
3109
+
3110
+ # Process functions from all genie rooms
3111
+ for genie_room in self.genie_rooms.values():
3112
+ genie_room: GenieRoomModel
3113
+ for function in genie_room.functions:
3114
+ function: FunctionModel
3115
+ function_exists: bool = any(
3116
+ existing_function.full_name == function.full_name
3117
+ for existing_function in self.functions.values()
3118
+ )
3119
+ if not function_exists:
3120
+ function_key: str = normalize_name(
3121
+ "_".join([genie_room.name, function.full_name])
3122
+ )
3123
+ self.functions[function_key] = function
3124
+ logger.trace(
3125
+ "Added function from Genie room",
3126
+ room=genie_room.name,
3127
+ function=function.name,
3128
+ key=function_key,
3129
+ )
3130
+
3131
+ return self
1928
3132
 
1929
3133
 
1930
3134
  class AppConfig(BaseModel):
1931
3135
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1932
3136
  variables: dict[str, AnyVariable] = Field(default_factory=dict)
3137
+ service_principals: dict[str, ServicePrincipalModel] = Field(default_factory=dict)
1933
3138
  schemas: dict[str, SchemaModel] = Field(default_factory=dict)
1934
3139
  resources: Optional[ResourcesModel] = None
1935
3140
  retrievers: dict[str, RetrieverModel] = Field(default_factory=dict)
1936
3141
  tools: dict[str, ToolModel] = Field(default_factory=dict)
1937
3142
  guardrails: dict[str, GuardrailModel] = Field(default_factory=dict)
3143
+ middleware: dict[str, MiddlewareModel] = Field(default_factory=dict)
1938
3144
  memory: Optional[MemoryModel] = None
1939
3145
  prompts: dict[str, PromptModel] = Field(default_factory=dict)
1940
3146
  agents: dict[str, AgentModel] = Field(default_factory=dict)
@@ -1962,10 +3168,10 @@ class AppConfig(BaseModel):
1962
3168
 
1963
3169
  def initialize(self) -> None:
1964
3170
  from dao_ai.hooks.core import create_hooks
3171
+ from dao_ai.logging import configure_logging
1965
3172
 
1966
3173
  if self.app and self.app.log_level:
1967
- logger.remove()
1968
- logger.add(sys.stderr, level=self.app.log_level)
3174
+ configure_logging(level=self.app.log_level)
1969
3175
 
1970
3176
  logger.debug("Calling initialization hooks...")
1971
3177
  initialization_functions: Sequence[Callable[..., Any]] = create_hooks(
@@ -2009,21 +3215,45 @@ class AppConfig(BaseModel):
2009
3215
  def create_agent(
2010
3216
  self,
2011
3217
  w: WorkspaceClient | None = None,
3218
+ vsc: "VectorSearchClient | None" = None,
3219
+ pat: str | None = None,
3220
+ client_id: str | None = None,
3221
+ client_secret: str | None = None,
3222
+ workspace_host: str | None = None,
2012
3223
  ) -> None:
2013
3224
  from dao_ai.providers.base import ServiceProvider
2014
3225
  from dao_ai.providers.databricks import DatabricksProvider
2015
3226
 
2016
- provider: ServiceProvider = DatabricksProvider(w=w)
3227
+ provider: ServiceProvider = DatabricksProvider(
3228
+ w=w,
3229
+ vsc=vsc,
3230
+ pat=pat,
3231
+ client_id=client_id,
3232
+ client_secret=client_secret,
3233
+ workspace_host=workspace_host,
3234
+ )
2017
3235
  provider.create_agent(self)
2018
3236
 
2019
3237
  def deploy_agent(
2020
3238
  self,
2021
3239
  w: WorkspaceClient | None = None,
3240
+ vsc: "VectorSearchClient | None" = None,
3241
+ pat: str | None = None,
3242
+ client_id: str | None = None,
3243
+ client_secret: str | None = None,
3244
+ workspace_host: str | None = None,
2022
3245
  ) -> None:
2023
3246
  from dao_ai.providers.base import ServiceProvider
2024
3247
  from dao_ai.providers.databricks import DatabricksProvider
2025
3248
 
2026
- provider: ServiceProvider = DatabricksProvider(w=w)
3249
+ provider: ServiceProvider = DatabricksProvider(
3250
+ w=w,
3251
+ vsc=vsc,
3252
+ pat=pat,
3253
+ client_id=client_id,
3254
+ client_secret=client_secret,
3255
+ workspace_host=workspace_host,
3256
+ )
2027
3257
  provider.deploy_agent(self)
2028
3258
 
2029
3259
  def find_agents(