dao-ai 0.0.36__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/cli.py +195 -30
  3. dao_ai/config.py +770 -244
  4. dao_ai/genie/__init__.py +1 -22
  5. dao_ai/genie/cache/__init__.py +1 -2
  6. dao_ai/genie/cache/base.py +20 -70
  7. dao_ai/genie/cache/core.py +75 -0
  8. dao_ai/genie/cache/lru.py +44 -21
  9. dao_ai/genie/cache/semantic.py +390 -109
  10. dao_ai/genie/core.py +35 -0
  11. dao_ai/graph.py +27 -253
  12. dao_ai/hooks/__init__.py +9 -6
  13. dao_ai/hooks/core.py +22 -190
  14. dao_ai/memory/__init__.py +10 -0
  15. dao_ai/memory/core.py +23 -5
  16. dao_ai/memory/databricks.py +389 -0
  17. dao_ai/memory/postgres.py +2 -2
  18. dao_ai/messages.py +6 -4
  19. dao_ai/middleware/__init__.py +125 -0
  20. dao_ai/middleware/assertions.py +778 -0
  21. dao_ai/middleware/base.py +50 -0
  22. dao_ai/middleware/core.py +61 -0
  23. dao_ai/middleware/guardrails.py +415 -0
  24. dao_ai/middleware/human_in_the_loop.py +228 -0
  25. dao_ai/middleware/message_validation.py +554 -0
  26. dao_ai/middleware/summarization.py +192 -0
  27. dao_ai/models.py +1177 -108
  28. dao_ai/nodes.py +118 -161
  29. dao_ai/optimization.py +664 -0
  30. dao_ai/orchestration/__init__.py +52 -0
  31. dao_ai/orchestration/core.py +287 -0
  32. dao_ai/orchestration/supervisor.py +264 -0
  33. dao_ai/orchestration/swarm.py +226 -0
  34. dao_ai/prompts.py +126 -29
  35. dao_ai/providers/databricks.py +126 -381
  36. dao_ai/state.py +139 -21
  37. dao_ai/tools/__init__.py +8 -5
  38. dao_ai/tools/core.py +57 -4
  39. dao_ai/tools/email.py +280 -0
  40. dao_ai/tools/genie.py +47 -24
  41. dao_ai/tools/mcp.py +4 -3
  42. dao_ai/tools/memory.py +50 -0
  43. dao_ai/tools/python.py +4 -12
  44. dao_ai/tools/search.py +14 -0
  45. dao_ai/tools/slack.py +1 -1
  46. dao_ai/tools/unity_catalog.py +8 -6
  47. dao_ai/tools/vector_search.py +16 -9
  48. dao_ai/utils.py +72 -8
  49. dao_ai-0.1.1.dist-info/METADATA +1878 -0
  50. dao_ai-0.1.1.dist-info/RECORD +62 -0
  51. dao_ai/chat_models.py +0 -204
  52. dao_ai/guardrails.py +0 -112
  53. dao_ai/tools/genie/__init__.py +0 -236
  54. dao_ai/tools/human_in_the_loop.py +0 -100
  55. dao_ai-0.0.36.dist-info/METADATA +0 -951
  56. dao_ai-0.0.36.dist-info/RECORD +0 -47
  57. {dao_ai-0.0.36.dist-info → dao_ai-0.1.1.dist-info}/WHEEL +0 -0
  58. {dao_ai-0.0.36.dist-info → dao_ai-0.1.1.dist-info}/entry_points.txt +0 -0
  59. {dao_ai-0.0.36.dist-info → dao_ai-0.1.1.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py CHANGED
@@ -28,9 +28,11 @@ from databricks.sdk.service.database import DatabaseInstance
 from databricks.vector_search.client import VectorSearchClient
 from databricks.vector_search.index import VectorSearchIndex
 from databricks_langchain import (
+    ChatDatabricks,
     DatabricksEmbeddings,
     DatabricksFunctionClient,
 )
+from langchain.agents.structured_output import ProviderStrategy, ToolStrategy
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models import LanguageModelLike
 from langchain_core.messages import BaseMessage, messages_from_dict
@@ -44,6 +46,7 @@ from mlflow.genai.datasets import EvaluationDataset, create_dataset, get_dataset
 from mlflow.genai.prompts import PromptVersion, load_prompt
 from mlflow.models import ModelConfig
 from mlflow.models.resources import (
+    DatabricksApp,
     DatabricksFunction,
     DatabricksGenieSpace,
     DatabricksLakebase,
@@ -84,27 +87,6 @@ class HasFullName(ABC):
     def full_name(self) -> str: ...
 
 
-class IsDatabricksResource(ABC):
-    on_behalf_of_user: Optional[bool] = False
-
-    @abstractmethod
-    def as_resources(self) -> Sequence[DatabricksResource]: ...
-
-    @property
-    @abstractmethod
-    def api_scopes(self) -> Sequence[str]: ...
-
-    @property
-    def workspace_client(self) -> WorkspaceClient:
-        credentials_strategy: CredentialsStrategy = None
-        if self.on_behalf_of_user:
-            credentials_strategy = ModelServingUserCredentials()
-        logger.debug(
-            f"Creating WorkspaceClient with credentials strategy: {credentials_strategy}"
-        )
-        return WorkspaceClient(credentials_strategy=credentials_strategy)
-
-
 class EnvironmentVariableModel(BaseModel, HasValue):
     model_config = ConfigDict(
         frozen=True,
@@ -212,6 +194,138 @@ class ServicePrincipalModel(BaseModel):
     client_secret: AnyVariable
 
 
+class IsDatabricksResource(ABC, BaseModel):
+    """
+    Base class for Databricks resources with authentication support.
+
+    Authentication Options:
+    ----------------------
+    1. **On-Behalf-Of User (OBO)**: Set on_behalf_of_user=True to use the
+       calling user's identity via ModelServingUserCredentials.
+
+    2. **Service Principal (OAuth M2M)**: Provide service_principal or
+       (client_id + client_secret + workspace_host) for service principal auth.
+
+    3. **Personal Access Token (PAT)**: Provide pat (and optionally workspace_host)
+       to authenticate with a personal access token.
+
+    4. **Ambient Authentication**: If no credentials provided, uses SDK defaults
+       (environment variables, notebook context, etc.)
+
+    Authentication Priority:
+    1. OBO (on_behalf_of_user=True)
+    2. Service Principal (client_id + client_secret + workspace_host)
+    3. PAT (pat + workspace_host)
+    4. Ambient/default authentication
+    """
+
+    model_config = ConfigDict(use_enum_values=True)
+
+    on_behalf_of_user: Optional[bool] = False
+    service_principal: Optional[ServicePrincipalModel] = None
+    client_id: Optional[AnyVariable] = None
+    client_secret: Optional[AnyVariable] = None
+    workspace_host: Optional[AnyVariable] = None
+    pat: Optional[AnyVariable] = None
+
+    @abstractmethod
+    def as_resources(self) -> Sequence[DatabricksResource]: ...
+
+    @property
+    @abstractmethod
+    def api_scopes(self) -> Sequence[str]: ...
+
+    @model_validator(mode="after")
+    def _expand_service_principal(self) -> Self:
+        """Expand service_principal into client_id and client_secret if provided."""
+        if self.service_principal is not None:
+            if self.client_id is None:
+                self.client_id = self.service_principal.client_id
+            if self.client_secret is None:
+                self.client_secret = self.service_principal.client_secret
+        return self
+
+    @model_validator(mode="after")
+    def _validate_auth_not_mixed(self) -> Self:
+        """Validate that OAuth and PAT authentication are not both provided."""
+        has_oauth: bool = self.client_id is not None and self.client_secret is not None
+        has_pat: bool = self.pat is not None
+
+        if has_oauth and has_pat:
+            raise ValueError(
+                "Cannot use both OAuth and user authentication methods. "
+                "Please provide either OAuth credentials or user credentials."
+            )
+        return self
+
+    @property
+    def workspace_client(self) -> WorkspaceClient:
+        """
+        Get a WorkspaceClient configured with the appropriate authentication.
+
+        Authentication priority:
+        1. If on_behalf_of_user is True, uses ModelServingUserCredentials (OBO)
+        2. If service principal credentials are configured (client_id, client_secret,
+           workspace_host), uses OAuth M2M
+        3. If PAT is configured, uses token authentication
+        4. Otherwise, uses default/ambient authentication
+        """
+        from dao_ai.utils import normalize_host
+
+        # Check for OBO first (highest priority)
+        if self.on_behalf_of_user:
+            credentials_strategy: CredentialsStrategy = ModelServingUserCredentials()
+            logger.debug(
+                f"Creating WorkspaceClient for {self.__class__.__name__} "
+                f"with OBO credentials strategy"
+            )
+            return WorkspaceClient(credentials_strategy=credentials_strategy)
+
+        # Check for service principal credentials
+        client_id_value: str | None = (
+            value_of(self.client_id) if self.client_id else None
+        )
+        client_secret_value: str | None = (
+            value_of(self.client_secret) if self.client_secret else None
+        )
+        workspace_host_value: str | None = (
+            normalize_host(value_of(self.workspace_host))
+            if self.workspace_host
+            else None
+        )
+
+        if client_id_value and client_secret_value and workspace_host_value:
+            logger.debug(
+                f"Creating WorkspaceClient for {self.__class__.__name__} with service principal: "
+                f"client_id={client_id_value}, host={workspace_host_value}"
+            )
+            return WorkspaceClient(
+                host=workspace_host_value,
+                client_id=client_id_value,
+                client_secret=client_secret_value,
+                auth_type="oauth-m2m",
+            )
+
+        # Check for PAT authentication
+        pat_value: str | None = value_of(self.pat) if self.pat else None
+        if pat_value:
+            logger.debug(
+                f"Creating WorkspaceClient for {self.__class__.__name__} with PAT"
+            )
+            return WorkspaceClient(
+                host=workspace_host_value,
+                token=pat_value,
+                auth_type="pat",
+            )
+
+        # Default: use ambient authentication
+        logger.debug(
+            f"Creating WorkspaceClient for {self.__class__.__name__} "
+            "with default/ambient authentication"
+        )
+        return WorkspaceClient()
+
+
 class Privilege(str, Enum):
     ALL_PRIVILEGES = "ALL_PRIVILEGES"
     USE_CATALOG = "USE_CATALOG"
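The hunk above centralizes authentication on the new IsDatabricksResource base. A minimal sketch (not part of the diff) of the four modes, shown on LLMModel, which now inherits these fields; endpoint names and credential values are placeholders:

```python
from dao_ai.config import LLMModel

# 1. OBO: Model Serving forwards the calling user's identity
llm_obo = LLMModel(name="databricks-meta-llama-3-3-70b-instruct", on_behalf_of_user=True)

# 2. Service principal (OAuth M2M); AnyVariable values may also be env/secret refs
llm_sp = LLMModel(
    name="databricks-meta-llama-3-3-70b-instruct",
    client_id="<sp-client-id>",
    client_secret="<sp-client-secret>",
    workspace_host="https://my-workspace.cloud.databricks.com",
)

# 3. Personal access token
llm_pat = LLMModel(name="databricks-meta-llama-3-3-70b-instruct", pat="<token>")

# 4. No credentials: ambient SDK auth (env vars, notebook context, ...)
llm_ambient = LLMModel(name="databricks-meta-llama-3-3-70b-instruct")

w = llm_sp.workspace_client  # WorkspaceClient resolved per the priority order above
```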
@@ -272,7 +386,26 @@ class SchemaModel(BaseModel, HasFullName):
         provider.create_schema(self)
 
 
-class TableModel(BaseModel, HasFullName, IsDatabricksResource):
+class DatabricksAppModel(IsDatabricksResource, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    url: str
+
+    @property
+    def full_name(self) -> str:
+        return self.name
+
+    @property
+    def api_scopes(self) -> Sequence[str]:
+        return ["apps.apps"]
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksApp(app_name=self.name, on_behalf_of_user=self.on_behalf_of_user)
+        ]
+
+
+class TableModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: Optional[str] = None
@@ -341,12 +474,16 @@ class TableModel(BaseModel, HasFullName, IsDatabricksResource):
         return resources
 
 
-class LLMModel(BaseModel, IsDatabricksResource):
+class LLMModel(IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     temperature: Optional[float] = 0.1
     max_tokens: Optional[int] = 8192
     fallbacks: Optional[list[Union[str, "LLMModel"]]] = Field(default_factory=list)
+    use_responses_api: Optional[bool] = Field(
+        default=False,
+        description="Use Responses API for ResponsesAgent endpoints",
+    )
 
     @property
     def api_scopes(self) -> Sequence[str]:
@@ -366,19 +503,12 @@ class LLMModel(BaseModel, IsDatabricksResource):
         ]
 
     def as_chat_model(self) -> LanguageModelLike:
-        # Retrieve langchain chat client from workspace client to enable OBO
-        # ChatOpenAI does not allow additional inputs at the moment, so we cannot use it directly
-        # chat_client: LanguageModelLike = self.as_open_ai_client()
-
-        # Create ChatDatabricksWrapper instance directly
-        from dao_ai.chat_models import ChatDatabricksFiltered
-
-        chat_client: LanguageModelLike = ChatDatabricksFiltered(
-            model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
+        chat_client: LanguageModelLike = ChatDatabricks(
+            model=self.name,
+            temperature=self.temperature,
+            max_tokens=self.max_tokens,
+            use_responses_api=self.use_responses_api,
         )
-        # chat_client: LanguageModelLike = ChatDatabricks(
-        #     model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
-        # )
 
         fallbacks: Sequence[LanguageModelLike] = []
         for fallback in self.fallbacks:
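Usage note (a sketch against the 0.1.1 API in this hunk; the endpoint name is a placeholder): the Responses API is now an ordinary config flag instead of the removed ChatDatabricksFiltered wrapper.

```python
from dao_ai.config import LLMModel

llm = LLMModel(
    name="databricks-claude-sonnet-4",  # placeholder endpoint name
    temperature=0.1,
    max_tokens=8192,
    use_responses_api=True,  # forwarded straight to ChatDatabricks
    fallbacks=["databricks-meta-llama-3-3-70b-instruct"],
)
chat = llm.as_chat_model()  # ChatDatabricks, with any fallbacks attached
```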
@@ -432,7 +562,7 @@ class VectorSearchEndpoint(BaseModel):
         return str(value)
 
 
-class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
+class IndexModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: str
@@ -457,7 +587,7 @@ class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
         ]
 
 
-class GenieRoomModel(BaseModel, IsDatabricksResource):
+class GenieRoomModel(IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     description: Optional[str] = None
@@ -483,7 +613,7 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
         return self
 
 
-class VolumeModel(BaseModel, HasFullName, IsDatabricksResource):
+class VolumeModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: str
@@ -543,7 +673,7 @@ class VolumePathModel(BaseModel, HasFullName):
         provider.create_path(self)
 
 
-class VectorStoreModel(BaseModel, IsDatabricksResource):
+class VectorStoreModel(IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     embedding_model: Optional[LLMModel] = None
     index: Optional[IndexModel] = None
@@ -642,7 +772,7 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
         provider.create_vector_store(self)
 
 
-class FunctionModel(BaseModel, HasFullName, IsDatabricksResource):
+class FunctionModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict()
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: Optional[str] = None
@@ -697,7 +827,7 @@ class FunctionModel(BaseModel, HasFullName, IsDatabricksResource):
         return ["sql.statement-execution"]
 
 
-class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
+class ConnectionModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict()
     name: str
 
@@ -724,7 +854,7 @@ class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
         ]
 
 
-class WarehouseModel(BaseModel, IsDatabricksResource):
+class WarehouseModel(IsDatabricksResource):
     model_config = ConfigDict()
     name: str
     description: Optional[str] = None
@@ -751,30 +881,28 @@ class WarehouseModel(BaseModel, IsDatabricksResource):
         return self
 
 
-class DatabaseModel(BaseModel, IsDatabricksResource):
+class DatabaseType(str, Enum):
+    POSTGRES = "postgres"
+    LAKEBASE = "lakebase"
+
+
+class DatabaseModel(IsDatabricksResource):
     """
     Configuration for a Databricks Lakebase (PostgreSQL) database instance.
 
-    Authentication Model:
-    --------------------
-    This model uses TWO separate authentication contexts:
-
-    1. **Workspace API Authentication** (inherited from IsDatabricksResource):
-       - Uses ambient/default authentication (environment variables, notebook context, app service principal)
-       - Used for: discovering database instance, getting host DNS, checking instance status
-       - Controlled by: DATABRICKS_HOST, DATABRICKS_TOKEN env vars, or SDK default config
+    Authentication is inherited from IsDatabricksResource. Additionally supports:
+    - user/password: For user-based database authentication
 
-    2. **Database Connection Authentication** (configured via service_principal, client_id/client_secret, OR user):
-       - Used for: connecting to the PostgreSQL database as a specific identity
-       - Service Principal: Set service_principal with workspace_host to connect as a service principal
-       - OAuth M2M: Set client_id, client_secret, workspace_host to connect as a service principal
-       - User Auth: Set user (and optionally password) to connect as a user identity
+    Database Type:
+    - lakebase: Databricks-managed Lakebase instance (authentication optional, supports ambient auth)
+    - postgres: Standard PostgreSQL database (authentication required)
 
     Example Service Principal Configuration:
     ```yaml
     databases:
       my_lakebase:
         name: my-database
+        type: lakebase
         service_principal:
           client_id:
             env: SERVICE_PRINCIPAL_CLIENT_ID
@@ -785,31 +913,28 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
           workspace_host:
             env: DATABRICKS_HOST
     ```
 
-    Example OAuth M2M Configuration (alternative):
+    Example User Configuration:
     ```yaml
     databases:
       my_lakebase:
         name: my-database
-        client_id:
-          env: SERVICE_PRINCIPAL_CLIENT_ID
-        client_secret:
-          scope: my-scope
-          secret: sp-client-secret
-        workspace_host:
-          env: DATABRICKS_HOST
+        type: lakebase
+        user: my-user@databricks.com
     ```
 
-    Example User Configuration:
+    Example Ambient Authentication (Lakebase only):
     ```yaml
     databases:
       my_lakebase:
         name: my-database
-        user: my-user@databricks.com
+        type: lakebase
+        on_behalf_of_user: true
     ```
     """
 
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
+    type: Optional[DatabaseType] = DatabaseType.LAKEBASE
     instance_name: Optional[str] = None
     description: Optional[str] = None
     host: Optional[AnyVariable] = None
@@ -820,16 +945,18 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
     timeout_seconds: Optional[int] = 10
     capacity: Optional[Literal["CU_1", "CU_2"]] = "CU_2"
     node_count: Optional[int] = None
+    # Database-specific auth (user identity for DB connection)
     user: Optional[AnyVariable] = None
     password: Optional[AnyVariable] = None
-    service_principal: Optional[ServicePrincipalModel] = None
-    client_id: Optional[AnyVariable] = None
-    client_secret: Optional[AnyVariable] = None
-    workspace_host: Optional[AnyVariable] = None
+
+    @field_serializer("type")
+    def serialize_type(self, value: DatabaseType | None) -> str | None:
+        """Serialize the database type enum to its string value."""
+        return value.value if value is not None else None
 
     @property
     def api_scopes(self) -> Sequence[str]:
-        return []
+        return ["database.database-instances"]
 
     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
@@ -843,29 +970,33 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
     def update_instance_name(self) -> Self:
         if self.instance_name is None:
             self.instance_name = self.name
-
-        return self
-
-    @model_validator(mode="after")
-    def expand_service_principal(self) -> Self:
-        """Expand service_principal into client_id and client_secret if provided."""
-        if self.service_principal is not None:
-            if self.client_id is None:
-                self.client_id = self.service_principal.client_id
-            if self.client_secret is None:
-                self.client_secret = self.service_principal.client_secret
         return self
 
     @model_validator(mode="after")
     def update_user(self) -> Self:
-        if self.client_id or self.user:
+        # Skip if using OBO (passive auth), explicit credentials, or explicit user
+        if self.on_behalf_of_user or self.client_id or self.user or self.pat:
             return self
 
-        self.user = self.workspace_client.current_user.me().user_name
-        if not self.user:
-            raise ValueError(
-                "Unable to determine current user. Please provide a user name or OAuth credentials."
-            )
+        # For postgres, we need explicit user credentials
+        # For lakebase with no auth, ambient auth is allowed
+        if self.type == DatabaseType.POSTGRES:
+            # Try to determine current user for local development
+            try:
+                self.user = self.workspace_client.current_user.me().user_name
+            except Exception as e:
+                logger.warning(
+                    f"Could not determine current user for PostgreSQL database: {e}. "
+                    f"Please provide explicit user credentials."
+                )
+        else:
+            # For lakebase, try to determine current user but don't fail if we can't
+            try:
+                self.user = self.workspace_client.current_user.me().user_name
+            except Exception:
+                # If we can't determine user and no explicit auth, that's okay
+                # for lakebase with ambient auth - credentials will be injected at runtime
+                pass
 
         return self
 
@@ -874,12 +1005,28 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
         if self.host is not None:
             return self
 
-        existing_instance: DatabaseInstance = (
-            self.workspace_client.database.get_database_instance(
-                name=self.instance_name
+        # Try to fetch host from existing instance
+        # This may fail for OBO/ambient auth during model logging (before deployment)
+        try:
+            existing_instance: DatabaseInstance = (
+                self.workspace_client.database.get_database_instance(
+                    name=self.instance_name
+                )
             )
-        )
-        self.host = existing_instance.read_write_dns
+            self.host = existing_instance.read_write_dns
+        except Exception as e:
+            # For lakebase with OBO/ambient auth, we can't fetch at config time
+            # The host will need to be provided explicitly or fetched at runtime
+            if self.type == DatabaseType.LAKEBASE and self.on_behalf_of_user:
+                logger.debug(
+                    f"Could not fetch host for database {self.instance_name} "
+                    f"(Lakebase with OBO mode - will be resolved at runtime): {e}"
+                )
+            else:
+                raise ValueError(
+                    f"Could not fetch host for database {self.instance_name}. "
+                    f"Please provide the 'host' explicitly or ensure the instance exists: {e}"
+                )
         return self
 
     @model_validator(mode="after")
@@ -890,21 +1037,33 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
             self.client_secret,
         ]
         has_oauth: bool = all(field is not None for field in oauth_fields)
+        has_user_auth: bool = self.user is not None
+        has_obo: bool = self.on_behalf_of_user is True
+        has_pat: bool = self.pat is not None
 
-        pat_fields: Sequence[Any] = [self.user]
-        has_user_auth: bool = all(field is not None for field in pat_fields)
+        # Count how many auth methods are configured
+        auth_methods_count: int = sum([has_oauth, has_user_auth, has_obo, has_pat])
 
-        if has_oauth and has_user_auth:
+        if auth_methods_count > 1:
             raise ValueError(
-                "Cannot use both OAuth and user authentication methods. "
-                "Please provide either OAuth credentials or user credentials."
+                "Cannot mix authentication methods. "
+                "Please provide exactly one of: "
+                "on_behalf_of_user=true (for passive auth in model serving), "
+                "OAuth credentials (service_principal or client_id + client_secret + workspace_host), "
+                "PAT (personal access token), "
+                "or user credentials (user)."
             )
 
-        if not has_oauth and not has_user_auth:
+        # For postgres type, at least one auth method must be configured
+        # For lakebase type, auth is optional (supports ambient authentication)
+        if self.type == DatabaseType.POSTGRES and auth_methods_count == 0:
             raise ValueError(
-                "At least one authentication method must be provided: "
-                "either OAuth credentials (workspace_host, client_id, client_secret), "
-                "service_principal with workspace_host, or user credentials (user, password)."
+                "PostgreSQL databases require explicit authentication. "
+                "Please provide one of: "
+                "OAuth credentials (workspace_host, client_id, client_secret), "
+                "service_principal with workspace_host, "
+                "PAT (personal access token), "
+                "or user credentials (user)."
             )
 
         return self
@@ -918,8 +1077,9 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
         If username is configured, it will be included; otherwise it will be omitted
         to allow Lakebase to authenticate using the token's identity.
         """
-        from dao_ai.providers.base import ServiceProvider
-        from dao_ai.providers.databricks import DatabricksProvider
+        import uuid as _uuid
+
+        from databricks.sdk.service.database import DatabaseCredential
 
         username: str | None = None
 
@@ -927,19 +1087,36 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
             username = value_of(self.client_id)
         elif self.user:
             username = value_of(self.user)
+        # For OBO mode, no username is needed - the token identity is used
+
+        # Resolve host - may need to fetch at runtime for OBO mode
+        host_value: Any = self.host
+        if host_value is None and self.on_behalf_of_user:
+            # Fetch host at runtime for OBO mode
+            existing_instance: DatabaseInstance = (
+                self.workspace_client.database.get_database_instance(
+                    name=self.instance_name
+                )
+            )
+            host_value = existing_instance.read_write_dns
 
-        host: str = value_of(self.host)
+        if host_value is None:
+            raise ValueError(
+                f"Database host not configured for {self.instance_name}. "
+                "Please provide 'host' explicitly."
+            )
+
+        host: str = value_of(host_value)
         port: int = value_of(self.port)
         database: str = value_of(self.database)
 
-        provider: ServiceProvider = DatabricksProvider(
-            client_id=value_of(self.client_id),
-            client_secret=value_of(self.client_secret),
-            workspace_host=value_of(self.workspace_host),
-            pat=value_of(self.password),
+        # Use the resource's own workspace_client to generate the database credential
+        w: WorkspaceClient = self.workspace_client
+        cred: DatabaseCredential = w.database.generate_database_credential(
+            request_id=str(_uuid.uuid4()),
+            instance_names=[self.instance_name],
         )
-
-        token: str = provider.lakebase_password_provider(self.instance_name)
+        token: str = cred.token
 
         # Build connection parameters dictionary
         params: dict[str, Any] = {
@@ -977,6 +1154,9 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
     def create(self, w: WorkspaceClient | None = None) -> None:
         from dao_ai.providers.databricks import DatabricksProvider
 
+        # Use provided workspace client or fall back to resource's own workspace_client
+        if w is None:
+            w = self.workspace_client
         provider: DatabricksProvider = DatabricksProvider(w=w)
         provider.create_lakebase(self)
         provider.create_lakebase_instance_role(self)
@@ -996,14 +1176,62 @@ class GenieSemanticCacheParametersModel(BaseModel):
     time_to_live_seconds: int | None = (
         60 * 60 * 24
     )  # 1 day default, None or negative = never expires
-    similarity_threshold: float = (
-        0.85  # Minimum similarity for cache hit (L2 distance converted to 0-1 scale)
+    similarity_threshold: float = 0.85  # Minimum similarity for question matching (L2 distance converted to 0-1 scale)
+    context_similarity_threshold: float = 0.80  # Minimum similarity for context matching (L2 distance converted to 0-1 scale)
+    question_weight: Optional[float] = (
+        0.6  # Weight for question similarity in combined score (0-1). If not provided, computed as 1 - context_weight
+    )
+    context_weight: Optional[float] = (
+        None  # Weight for context similarity in combined score (0-1). If not provided, computed as 1 - question_weight
     )
     embedding_model: str | LLMModel = "databricks-gte-large-en"
     embedding_dims: int | None = None  # Auto-detected if None
     database: DatabaseModel
     warehouse: WarehouseModel
     table_name: str = "genie_semantic_cache"
+    context_window_size: int = 3  # Number of previous turns to include for context
+    max_context_tokens: int = (
+        2000  # Maximum context length to prevent extremely long embeddings
+    )
+
+    @model_validator(mode="after")
+    def compute_and_validate_weights(self) -> Self:
+        """
+        Compute missing weight and validate that question_weight + context_weight = 1.0.
+
+        Either question_weight or context_weight (or both) can be provided.
+        The missing one will be computed as 1.0 - provided_weight.
+        If both are provided, they must sum to 1.0.
+        """
+        if self.question_weight is None and self.context_weight is None:
+            # Both missing - use defaults
+            self.question_weight = 0.6
+            self.context_weight = 0.4
+        elif self.question_weight is None:
+            # Compute question_weight from context_weight
+            if not (0.0 <= self.context_weight <= 1.0):
+                raise ValueError(
+                    f"context_weight must be between 0.0 and 1.0, got {self.context_weight}"
+                )
+            self.question_weight = 1.0 - self.context_weight
+        elif self.context_weight is None:
+            # Compute context_weight from question_weight
+            if not (0.0 <= self.question_weight <= 1.0):
+                raise ValueError(
+                    f"question_weight must be between 0.0 and 1.0, got {self.question_weight}"
+                )
+            self.context_weight = 1.0 - self.question_weight
+        else:
+            # Both provided - validate they sum to 1.0
+            total_weight = self.question_weight + self.context_weight
+            if not abs(total_weight - 1.0) < 0.0001:  # Allow small floating point error
+                raise ValueError(
+                    f"question_weight ({self.question_weight}) + context_weight ({self.context_weight}) "
+                    f"must equal 1.0 (got {total_weight}). These weights determine the relative importance "
+                    f"of question vs context similarity in the combined score."
+                )
+
+        return self
 
 
 class SearchParametersModel(BaseModel):
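A worked example of the combined score these weights imply (illustrative only; the actual scoring lives in dao_ai/genie/cache/semantic.py):

```python
# question_weight + context_weight must equal 1.0 (enforced by the validator above)
question_weight, context_weight = 0.6, 0.4

# hypothetical similarities on the 0-1 scale (L2 distance already converted)
question_similarity = 0.92
context_similarity = 0.75

combined = question_weight * question_similarity + context_weight * context_similarity
# 0.6 * 0.92 + 0.4 * 0.75 = 0.852 -> compared against the configured thresholds
```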
@@ -1096,28 +1324,47 @@ class FunctionType(str, Enum):
     MCP = "mcp"
 
 
-class HumanInTheLoopActionType(str, Enum):
-    """Supported action types for human-in-the-loop interactions."""
+class HumanInTheLoopModel(BaseModel):
+    """
+    Configuration for Human-in-the-Loop tool approval.
 
-    ACCEPT = "accept"
-    EDIT = "edit"
-    RESPONSE = "response"
-    DECLINE = "decline"
+    This model configures when and how tools require human approval before execution.
+    It maps to LangChain's HumanInTheLoopMiddleware.
 
+    LangChain supports three decision types:
+    - "approve": Execute tool with original arguments
+    - "edit": Modify arguments before execution
+    - "reject": Skip execution with optional feedback message
+    """
 
-class HumanInTheLoopModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    review_prompt: str = "Please review the tool call"
-    interrupt_config: dict[str, Any] = Field(
-        default_factory=lambda: {
-            "allow_accept": True,
-            "allow_edit": True,
-            "allow_respond": True,
-            "allow_decline": True,
-        }
+
+    review_prompt: Optional[str] = Field(
+        default=None,
+        description="Message shown to the reviewer when approval is requested",
+    )
+
+    allowed_decisions: list[Literal["approve", "edit", "reject"]] = Field(
+        default_factory=lambda: ["approve", "edit", "reject"],
+        description="List of allowed decision types for this tool",
     )
-    decline_message: str = "Tool call declined by user"
-    custom_actions: Optional[dict[str, str]] = Field(default_factory=dict)
+
+    @model_validator(mode="after")
+    def validate_and_normalize_decisions(self) -> Self:
+        """Validate and normalize allowed decisions."""
+        if not self.allowed_decisions:
+            raise ValueError("At least one decision type must be allowed")
+
+        # Remove duplicates while preserving order
+        seen = set()
+        unique_decisions = []
+        for decision in self.allowed_decisions:
+            if decision not in seen:
+                seen.add(decision)
+                unique_decisions.append(decision)
+        self.allowed_decisions = unique_decisions
+
+        return self
 
 
 class BaseFunctionModel(ABC, BaseModel):
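A short sketch of the reworked model (field names as in the hunk above; how the tool itself gets wrapped is outside this hunk):

```python
from dao_ai.config import HumanInTheLoopModel

# Restrict a sensitive tool to approve/reject only - no argument edits
hitl = HumanInTheLoopModel(
    review_prompt="Approve sending this email?",
    allowed_decisions=["approve", "reject"],
)
# Duplicate decisions are dropped by validate_and_normalize_decisions;
# an empty allowed_decisions list raises ValueError.
```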
@@ -1180,7 +1427,16 @@ class TransportType(str, Enum):
     STDIO = "stdio"
 
 
-class McpFunctionModel(BaseFunctionModel, HasFullName):
+class McpFunctionModel(BaseFunctionModel, IsDatabricksResource, HasFullName):
+    """
+    MCP Function Model with authentication inherited from IsDatabricksResource.
+
+    Authentication for MCP connections uses the same options as other resources:
+    - Service Principal (client_id + client_secret + workspace_host)
+    - PAT (pat + workspace_host)
+    - OBO (on_behalf_of_user)
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     type: Literal[FunctionType.MCP] = FunctionType.MCP
     transport: TransportType = TransportType.STREAMABLE_HTTP
@@ -1188,26 +1444,27 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
     url: Optional[AnyVariable] = None
     headers: dict[str, AnyVariable] = Field(default_factory=dict)
     args: list[str] = Field(default_factory=list)
-    pat: Optional[AnyVariable] = None
-    service_principal: Optional[ServicePrincipalModel] = None
-    client_id: Optional[AnyVariable] = None
-    client_secret: Optional[AnyVariable] = None
-    workspace_host: Optional[AnyVariable] = None
+    # MCP-specific fields
     connection: Optional[ConnectionModel] = None
     functions: Optional[SchemaModel] = None
     genie_room: Optional[GenieRoomModel] = None
     sql: Optional[bool] = None
     vector_search: Optional[VectorStoreModel] = None
 
-    @model_validator(mode="after")
-    def expand_service_principal(self) -> Self:
-        """Expand service_principal into client_id and client_secret if provided."""
-        if self.service_principal is not None:
-            if self.client_id is None:
-                self.client_id = self.service_principal.client_id
-            if self.client_secret is None:
-                self.client_secret = self.service_principal.client_secret
-        return self
+    @property
+    def api_scopes(self) -> Sequence[str]:
+        """API scopes for MCP connections."""
+        return [
+            "serving.serving-endpoints",
+            "mcp.genie",
+            "mcp.functions",
+            "mcp.vectorsearch",
+            "mcp.external",
+        ]
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        """MCP functions don't declare static resources."""
+        return []
 
     @property
     def full_name(self) -> str:
@@ -1372,27 +1629,6 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
             self.headers[key] = value_of(value)
         return self
 
-    @model_validator(mode="after")
-    def validate_auth_methods(self) -> "McpFunctionModel":
-        oauth_fields: Sequence[Any] = [
-            self.client_id,
-            self.client_secret,
-        ]
-        has_oauth: bool = all(field is not None for field in oauth_fields)
-
-        pat_fields: Sequence[Any] = [self.pat]
-        has_user_auth: bool = all(field is not None for field in pat_fields)
-
-        if has_oauth and has_user_auth:
-            raise ValueError(
-                "Cannot use both OAuth and user authentication methods. "
-                "Please provide either OAuth credentials or user credentials."
-            )
-
-        # Note: workspace_host is optional - it will be derived from workspace client if not provided
-
-        return self
-
     def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
         from dao_ai.tools import create_mcp_tools
 
@@ -1434,17 +1670,97 @@ class ToolModel(BaseModel):
     function: AnyTool
 
 
+class PromptModel(BaseModel, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
+    name: str
+    description: Optional[str] = None
+    default_template: Optional[str] = None
+    alias: Optional[str] = None
+    version: Optional[int] = None
+    tags: Optional[dict[str, Any]] = Field(default_factory=dict)
+
+    @property
+    def template(self) -> str:
+        from dao_ai.providers.databricks import DatabricksProvider
+
+        provider: DatabricksProvider = DatabricksProvider()
+        prompt_version = provider.get_prompt(self)
+        return prompt_version.to_single_brace_format()
+
+    @property
+    def full_name(self) -> str:
+        prompt_name: str = self.name
+        if self.schema_model:
+            prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
+        return prompt_name
+
+    @property
+    def uri(self) -> str:
+        prompt_uri: str = f"prompts:/{self.full_name}"
+
+        if self.alias:
+            prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
+        elif self.version:
+            prompt_uri = f"prompts:/{self.full_name}/{self.version}"
+        else:
+            prompt_uri = f"prompts:/{self.full_name}@latest"
+
+        return prompt_uri
+
+    def as_prompt(self) -> PromptVersion:
+        prompt_version: PromptVersion = load_prompt(self.uri)
+        return prompt_version
+
+    @model_validator(mode="after")
+    def validate_mutually_exclusive(self) -> Self:
+        if self.alias and self.version:
+            raise ValueError("Cannot specify both alias and version")
+        return self
+
+
 class GuardrailModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
-    model: LLMModel
-    prompt: str
+    model: str | LLMModel
+    prompt: str | PromptModel
     num_retries: Optional[int] = 3
 
+    @model_validator(mode="after")
+    def validate_llm_model(self) -> Self:
+        if isinstance(self.model, str):
+            self.model = LLMModel(name=self.model)
+        return self
+
+
+class MiddlewareModel(BaseModel):
+    """Configuration for middleware that can be applied to agents.
+
+    Middleware is defined at the AppConfig level and can be referenced by name
+    in agent configurations using YAML anchors for reusability.
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str = Field(
+        description="Fully qualified name of the middleware factory function"
+    )
+    args: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Arguments to pass to the middleware factory function",
+    )
+
+    @model_validator(mode="after")
+    def resolve_args(self) -> Self:
+        """Resolve any variable references in args."""
+        for key, value in self.args.items():
+            self.args[key] = value_of(value)
+        return self
+
 
 class StorageType(str, Enum):
     POSTGRES = "postgres"
     MEMORY = "memory"
+    LAKEBASE = "lakebase"
 
 
 class CheckpointerModel(BaseModel):
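A sketch of a MiddlewareModel entry; the factory path below is hypothetical (dao_ai/middleware/summarization.py ships in this release, but its factory names are not part of this hunk):

```python
from dao_ai.config import MiddlewareModel

mw = MiddlewareModel(
    name="dao_ai.middleware.summarization.create_summarization_middleware",  # hypothetical FQN
    args={"max_tokens": 2048},  # values are resolved through value_of()
)
```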
@@ -1454,8 +1770,11 @@ class CheckpointerModel(BaseModel):
     database: Optional[DatabaseModel] = None
 
     @model_validator(mode="after")
-    def validate_postgres_requires_database(self) -> Self:
-        if self.type == StorageType.POSTGRES and not self.database:
+    def validate_storage_requires_database(self) -> Self:
+        if (
+            self.type in [StorageType.POSTGRES, StorageType.LAKEBASE]
+            and not self.database
+        ):
             raise ValueError("Database must be provided when storage type is POSTGRES")
         return self
 
@@ -1500,56 +1819,158 @@ class MemoryModel(BaseModel):
 FunctionHook: TypeAlias = PythonFunctionModel | FactoryFunctionModel | str
 
 
-class PromptModel(BaseModel, HasFullName):
+class ResponseFormatModel(BaseModel):
+    """
+    Configuration for structured response formats.
+
+    The response_schema field accepts either a type or a string:
+    - Type (Pydantic model, dataclass, etc.): Used directly for structured output
+    - String: First attempts to load as a fully qualified type name, falls back to JSON schema string
+
+    This unified approach simplifies the API while maintaining flexibility.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
-    name: str
-    description: Optional[str] = None
-    default_template: Optional[str] = None
-    alias: Optional[str] = None
-    version: Optional[int] = None
-    tags: Optional[dict[str, Any]] = Field(default_factory=dict)
+    use_tool: Optional[bool] = Field(
+        default=None,
+        description=(
+            "Strategy for structured output: "
+            "None (default) = auto-detect from model capabilities, "
+            "False = force ProviderStrategy (native), "
+            "True = force ToolStrategy (function calling)"
+        ),
+    )
+    response_schema: Optional[str | type] = Field(
+        default=None,
+        description="Type or string for response format. String attempts FQN import, falls back to JSON schema.",
+    )
 
-    @property
-    def template(self) -> str:
-        from dao_ai.providers.databricks import DatabricksProvider
+    def as_strategy(self) -> ProviderStrategy | ToolStrategy:
+        """
+        Convert response_schema to appropriate LangChain strategy.
 
-        provider: DatabricksProvider = DatabricksProvider()
-        prompt_version = provider.get_prompt(self)
-        return prompt_version.to_single_brace_format()
+        Returns:
+            - None if no response_schema configured
+            - Raw schema/type for auto-detection (when use_tool=None)
+            - ToolStrategy wrapping the schema (when use_tool=True)
+            - ProviderStrategy wrapping the schema (when use_tool=False)
 
-    @property
-    def full_name(self) -> str:
-        prompt_name: str = self.name
-        if self.schema_model:
-            prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
-        return prompt_name
+        Raises:
+            ValueError: If response_schema is a JSON schema string that cannot be parsed
+        """
 
-    @property
-    def uri(self) -> str:
-        prompt_uri: str = f"prompts:/{self.full_name}"
+        if self.response_schema is None:
+            return None
 
-        if self.alias:
-            prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
-        elif self.version:
-            prompt_uri = f"prompts:/{self.full_name}/{self.version}"
-        else:
-            prompt_uri = f"prompts:/{self.full_name}@latest"
+        schema = self.response_schema
 
-        return prompt_uri
+        # Handle type schemas (Pydantic, dataclass, etc.)
+        if self.is_type_schema:
+            if self.use_tool is None:
+                # Auto-detect: Pass schema directly, let LangChain decide
+                return schema
+            elif self.use_tool is True:
+                # Force ToolStrategy (function calling)
+                return ToolStrategy(schema)
+            else:  # use_tool is False
+                # Force ProviderStrategy (native structured output)
+                return ProviderStrategy(schema)
 
-    def as_prompt(self) -> PromptVersion:
-        prompt_version: PromptVersion = load_prompt(self.uri)
-        return prompt_version
+        # Handle JSON schema strings
+        elif self.is_json_schema:
+            import json
+
+            try:
+                schema_dict = json.loads(schema)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Invalid JSON schema string: {e}") from e
+
+            # Apply same use_tool logic as type schemas
+            if self.use_tool is None:
+                # Auto-detect
+                return schema_dict
+            elif self.use_tool is True:
+                # Force ToolStrategy
+                return ToolStrategy(schema_dict)
+            else:  # use_tool is False
+                # Force ProviderStrategy
+                return ProviderStrategy(schema_dict)
+
+        return None
 
     @model_validator(mode="after")
-    def validate_mutually_exclusive(self) -> Self:
-        if self.alias and self.version:
-            raise ValueError("Cannot specify both alias and version")
-        return self
+    def validate_response_schema(self) -> Self:
+        """
+        Validate and convert response_schema.
+
+        Processing logic:
+        1. If None: no response format specified
+        2. If type: use directly as structured output type
+        3. If str: try to load as FQN using type_from_fqn
+           - Success: response_schema becomes the loaded type
+           - Failure: keep as string (treated as JSON schema)
+
+        After validation, response_schema is one of:
+        - None (no schema)
+        - type (use for structured output)
+        - str (JSON schema)
+
+        Returns:
+            Self with validated response_schema
+        """
+        if self.response_schema is None:
+            return self
+
+        # If already a type, return
+        if isinstance(self.response_schema, type):
+            return self
+
+        # If it's a string, try to load as type, fallback to json_schema
+        if isinstance(self.response_schema, str):
+            from dao_ai.utils import type_from_fqn
+
+            try:
+                resolved_type = type_from_fqn(self.response_schema)
+                self.response_schema = resolved_type
+                logger.debug(
+                    f"Resolved response_schema string to type: {resolved_type}"
+                )
+                return self
+            except (ValueError, ImportError, AttributeError, TypeError) as e:
+                # Keep as string - it's a JSON schema
+                logger.debug(
+                    f"Could not resolve '{self.response_schema}' as type: {e}. "
+                    f"Treating as JSON schema string."
+                )
+                return self
+
+        # Invalid type
+        raise ValueError(
+            f"response_schema must be None, type, or str, got {type(self.response_schema)}"
+        )
+
+    @property
+    def is_type_schema(self) -> bool:
+        """Returns True if response_schema is a type (not JSON schema string)."""
+        return isinstance(self.response_schema, type)
+
+    @property
+    def is_json_schema(self) -> bool:
+        """Returns True if response_schema is a JSON schema string (not a type)."""
+        return isinstance(self.response_schema, str)
 
 
 class AgentModel(BaseModel):
+    """
+    Configuration model for an agent in the DAO AI framework.
+
+    Agents combine an LLM with tools and middleware to create systems that can
+    reason about tasks, decide which tools to use, and iteratively work towards solutions.
+
+    Middleware replaces the previous pre_agent_hook and post_agent_hook patterns,
+    providing a more flexible and composable way to customize agent behavior.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     description: Optional[str] = None
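A sketch of the three response_schema forms the new model accepts (the module path in the FQN example is hypothetical):

```python
from pydantic import BaseModel

from dao_ai.config import ResponseFormatModel

class Answer(BaseModel):
    text: str
    confidence: float

fmt_type = ResponseFormatModel(response_schema=Answer)                  # a type
fmt_fqn = ResponseFormatModel(response_schema="my_pkg.schemas.Answer")  # FQN string
fmt_json = ResponseFormatModel(response_schema='{"type": "object"}')    # JSON schema string

fmt_type.as_strategy()  # Answer itself (use_tool=None -> auto-detect)
ResponseFormatModel(response_schema=Answer, use_tool=True).as_strategy()   # ToolStrategy(Answer)
ResponseFormatModel(response_schema=Answer, use_tool=False).as_strategy()  # ProviderStrategy(Answer)
```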
@@ -1558,9 +1979,43 @@ class AgentModel(BaseModel):
     guardrails: list[GuardrailModel] = Field(default_factory=list)
     prompt: Optional[str | PromptModel] = None
     handoff_prompt: Optional[str] = None
-    create_agent_hook: Optional[FunctionHook] = None
-    pre_agent_hook: Optional[FunctionHook] = None
-    post_agent_hook: Optional[FunctionHook] = None
+    middleware: list[MiddlewareModel] = Field(
+        default_factory=list,
+        description="List of middleware to apply to this agent",
+    )
+    response_format: Optional[ResponseFormatModel | type | str] = None
+
+    @model_validator(mode="after")
+    def validate_response_format(self) -> Self:
+        """
+        Validate and normalize response_format.
+
+        Accepts:
+        - None (no response format)
+        - ResponseFormatModel (already validated)
+        - type (Pydantic model, dataclass, etc.) - converts to ResponseFormatModel
+        - str (FQN or json_schema) - converts to ResponseFormatModel (smart fallback)
+
+        ResponseFormatModel handles the logic of trying FQN import and falling back to JSON schema.
+        """
+        if self.response_format is None or isinstance(
+            self.response_format, ResponseFormatModel
+        ):
+            return self
+
+        # Convert type or str to ResponseFormatModel
+        # ResponseFormatModel's validator will handle the smart type loading and fallback
+        if isinstance(self.response_format, (type, str)):
+            self.response_format = ResponseFormatModel(
+                response_schema=self.response_format
+            )
+            return self
+
+        # Invalid type
+        raise ValueError(
+            f"response_format must be None, ResponseFormatModel, type, or str, "
+            f"got {type(self.response_format)}"
+        )
 
     def as_runnable(self) -> RunnableLike:
         from dao_ai.nodes import create_agent_node
@@ -1579,6 +2034,10 @@ class SupervisorModel(BaseModel):
     model: LLMModel
     tools: list[ToolModel] = Field(default_factory=list)
     prompt: Optional[str] = None
+    middleware: list[MiddlewareModel] = Field(
+        default_factory=list,
+        description="List of middleware to apply to the supervisor",
+    )
 
 
 class SwarmModel(BaseModel):
@@ -1702,6 +2161,28 @@ class ChatPayload(BaseModel):
 
         return self
 
+    @model_validator(mode="after")
+    def ensure_thread_id(self) -> "ChatPayload":
+        """Ensure thread_id or conversation_id is present in configurable, generating UUID if needed."""
+        import uuid
+
+        if self.custom_inputs is None:
+            self.custom_inputs = {}
+
+        # Get or create configurable section
+        configurable: dict[str, Any] = self.custom_inputs.get("configurable", {})
+
+        # Check if thread_id or conversation_id exists
+        has_thread_id = configurable.get("thread_id") is not None
+        has_conversation_id = configurable.get("conversation_id") is not None
+
+        # If neither is provided, generate a UUID for conversation_id
+        if not has_thread_id and not has_conversation_id:
+            configurable["conversation_id"] = str(uuid.uuid4())
+            self.custom_inputs["configurable"] = configurable
+
+        return self
+
     def as_messages(self) -> Sequence[BaseMessage]:
         return messages_from_dict(
             [{"type": m.role, "content": m.content} for m in self.messages]
@@ -1717,20 +2198,38 @@ class ChatPayload(BaseModel):
 
 
 class ChatHistoryModel(BaseModel):
+    """
+    Configuration for chat history summarization.
+
+    Attributes:
+        model: The LLM to use for generating summaries.
+        max_tokens: Maximum tokens to keep after summarization (the "keep" threshold).
+            After summarization, recent messages totaling up to this many tokens are preserved.
+        max_tokens_before_summary: Token threshold that triggers summarization.
+            When conversation exceeds this, summarization runs. Mutually exclusive with
+            max_messages_before_summary. If neither is set, defaults to max_tokens * 10.
+        max_messages_before_summary: Message count threshold that triggers summarization.
+            When conversation exceeds this many messages, summarization runs.
+            Mutually exclusive with max_tokens_before_summary.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     model: LLMModel
-    max_tokens: int = 256
-    max_tokens_before_summary: Optional[int] = None
-    max_messages_before_summary: Optional[int] = None
-    max_summary_tokens: int = 255
-
-    @model_validator(mode="after")
-    def validate_max_summary_tokens(self) -> "ChatHistoryModel":
-        if self.max_summary_tokens >= self.max_tokens:
-            raise ValueError(
-                f"max_summary_tokens ({self.max_summary_tokens}) must be less than max_tokens ({self.max_tokens})"
-            )
-        return self
+    max_tokens: int = Field(
+        default=2048,
+        gt=0,
+        description="Maximum tokens to keep after summarization",
+    )
+    max_tokens_before_summary: Optional[int] = Field(
+        default=None,
+        gt=0,
+        description="Token threshold that triggers summarization",
+    )
+    max_messages_before_summary: Optional[int] = Field(
+        default=None,
+        gt=0,
+        description="Message count threshold that triggers summarization",
+    )
 
 
 class AppModel(BaseModel):
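Sketch under the new semantics (the model name is a placeholder): trigger on message count and keep roughly 2048 tokens of recent history.

```python
from dao_ai.config import ChatHistoryModel, LLMModel

history = ChatHistoryModel(
    model=LLMModel(name="databricks-meta-llama-3-3-70b-instruct"),
    max_tokens=2048,                 # "keep" threshold after summarization
    max_messages_before_summary=40,  # trigger; mutually exclusive with the token trigger
)
```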
@@ -1757,9 +2256,6 @@ class AppModel(BaseModel):
     shutdown_hooks: Optional[FunctionHook | list[FunctionHook]] = Field(
         default_factory=list
     )
-    message_hooks: Optional[FunctionHook | list[FunctionHook]] = Field(
-        default_factory=list
-    )
     input_example: Optional[ChatPayload] = None
     chat_history: Optional[ChatHistoryModel] = None
     code_paths: list[str] = Field(default_factory=list)
@@ -1964,33 +2460,67 @@ class EvaluationDatasetModel(BaseModel, HasFullName):
 
 
 class PromptOptimizationModel(BaseModel):
+    """Configuration for prompt optimization using GEPA.
+
+    GEPA (Generative Evolution of Prompts and Agents) is an evolutionary
+    optimizer that uses reflective mutation to improve prompts based on
+    evaluation feedback.
+
+    Example:
+        prompt_optimization:
+          name: optimize_my_prompt
+          prompt: *my_prompt
+          agent: *my_agent
+          dataset: *my_training_dataset
+          reflection_model: databricks-meta-llama-3-3-70b-instruct
+          num_candidates: 50
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     prompt: Optional[PromptModel] = None
     agent: AgentModel
-    dataset: (
-        EvaluationDatasetModel | str
-    )  # Reference to dataset name (looked up in OptimizationsModel.training_datasets or MLflow)
+    dataset: EvaluationDatasetModel  # Training dataset with examples
     reflection_model: Optional[LLMModel | str] = None
     num_candidates: Optional[int] = 50
-    scorer_model: Optional[LLMModel | str] = None
 
     def optimize(self, w: WorkspaceClient | None = None) -> PromptModel:
         """
-        Optimize the prompt using MLflow's prompt optimization.
+        Optimize the prompt using GEPA.
 
         Args:
-            w: Optional WorkspaceClient for Databricks operations
+            w: Optional WorkspaceClient (not used, kept for API compatibility)
 
         Returns:
-            PromptModel: The optimized prompt model with new URI
+            PromptModel: The optimized prompt model
         """
-        from dao_ai.providers.base import ServiceProvider
-        from dao_ai.providers.databricks import DatabricksProvider
+        from dao_ai.optimization import OptimizationResult, optimize_prompt
 
-        provider: ServiceProvider = DatabricksProvider(w=w)
-        optimized_prompt: PromptModel = provider.optimize_prompt(self)
-        return optimized_prompt
+        # Get reflection model name
+        reflection_model_name: str | None = None
+        if self.reflection_model:
+            if isinstance(self.reflection_model, str):
+                reflection_model_name = self.reflection_model
+            else:
+                reflection_model_name = self.reflection_model.uri
+
+        # Ensure prompt is set
+        prompt = self.prompt
+        if prompt is None:
+            raise ValueError(
+                f"Prompt optimization '{self.name}' requires a prompt to be set"
+            )
+
+        result: OptimizationResult = optimize_prompt(
+            prompt=prompt,
+            agent=self.agent,
+            dataset=self.dataset,
+            reflection_model=reflection_model_name,
+            num_candidates=self.num_candidates or 50,
+            register_if_improved=True,
+        )
+
+        return result.optimized_prompt
 
     @model_validator(mode="after")
     def set_defaults(self) -> Self:
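The same flow invoked directly (mirrors the optimize() body above; the my_* config objects are assumed to exist):

```python
from dao_ai.optimization import OptimizationResult, optimize_prompt

result: OptimizationResult = optimize_prompt(
    prompt=my_prompt,    # PromptModel
    agent=my_agent,      # AgentModel
    dataset=my_dataset,  # EvaluationDatasetModel
    reflection_model="databricks-meta-llama-3-3-70b-instruct",
    num_candidates=50,
    register_if_improved=True,
)
new_prompt = result.optimized_prompt
```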
@@ -2004,12 +2534,6 @@ class PromptOptimizationModel(BaseModel):
                 f"or an agent with a prompt configured"
             )
 
-        if self.reflection_model is None:
-            self.reflection_model = self.agent.model
-
-        if self.scorer_model is None:
-            self.scorer_model = self.agent.model
-
         return self
 
 
@@ -2110,6 +2634,7 @@ class ResourcesModel(BaseModel):
     warehouses: dict[str, WarehouseModel] = Field(default_factory=dict)
     databases: dict[str, DatabaseModel] = Field(default_factory=dict)
     connections: dict[str, ConnectionModel] = Field(default_factory=dict)
+    apps: dict[str, DatabricksAppModel] = Field(default_factory=dict)
 
 
 class AppConfig(BaseModel):
  class AppConfig(BaseModel):
@@ -2121,6 +2646,7 @@ class AppConfig(BaseModel):
2121
2646
  retrievers: dict[str, RetrieverModel] = Field(default_factory=dict)
2122
2647
  tools: dict[str, ToolModel] = Field(default_factory=dict)
2123
2648
  guardrails: dict[str, GuardrailModel] = Field(default_factory=dict)
2649
+ middleware: dict[str, MiddlewareModel] = Field(default_factory=dict)
2124
2650
  memory: Optional[MemoryModel] = None
2125
2651
  prompts: dict[str, PromptModel] = Field(default_factory=dict)
2126
2652
  agents: dict[str, AgentModel] = Field(default_factory=dict)