dao-ai 0.0.35__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/cli.py +195 -30
  3. dao_ai/config.py +797 -242
  4. dao_ai/genie/__init__.py +38 -0
  5. dao_ai/genie/cache/__init__.py +43 -0
  6. dao_ai/genie/cache/base.py +72 -0
  7. dao_ai/genie/cache/core.py +75 -0
  8. dao_ai/genie/cache/lru.py +329 -0
  9. dao_ai/genie/cache/semantic.py +919 -0
  10. dao_ai/genie/core.py +35 -0
  11. dao_ai/graph.py +27 -253
  12. dao_ai/hooks/__init__.py +9 -6
  13. dao_ai/hooks/core.py +22 -190
  14. dao_ai/memory/__init__.py +10 -0
  15. dao_ai/memory/core.py +23 -5
  16. dao_ai/memory/databricks.py +389 -0
  17. dao_ai/memory/postgres.py +2 -2
  18. dao_ai/messages.py +6 -4
  19. dao_ai/middleware/__init__.py +125 -0
  20. dao_ai/middleware/assertions.py +778 -0
  21. dao_ai/middleware/base.py +50 -0
  22. dao_ai/middleware/core.py +61 -0
  23. dao_ai/middleware/guardrails.py +415 -0
  24. dao_ai/middleware/human_in_the_loop.py +228 -0
  25. dao_ai/middleware/message_validation.py +554 -0
  26. dao_ai/middleware/summarization.py +192 -0
  27. dao_ai/models.py +1177 -108
  28. dao_ai/nodes.py +118 -161
  29. dao_ai/optimization.py +664 -0
  30. dao_ai/orchestration/__init__.py +52 -0
  31. dao_ai/orchestration/core.py +287 -0
  32. dao_ai/orchestration/supervisor.py +264 -0
  33. dao_ai/orchestration/swarm.py +226 -0
  34. dao_ai/prompts.py +126 -29
  35. dao_ai/providers/databricks.py +126 -381
  36. dao_ai/state.py +139 -21
  37. dao_ai/tools/__init__.py +11 -5
  38. dao_ai/tools/core.py +57 -4
  39. dao_ai/tools/email.py +280 -0
  40. dao_ai/tools/genie.py +108 -35
  41. dao_ai/tools/mcp.py +4 -3
  42. dao_ai/tools/memory.py +50 -0
  43. dao_ai/tools/python.py +4 -12
  44. dao_ai/tools/search.py +14 -0
  45. dao_ai/tools/slack.py +1 -1
  46. dao_ai/tools/unity_catalog.py +8 -6
  47. dao_ai/tools/vector_search.py +16 -9
  48. dao_ai/utils.py +72 -8
  49. dao_ai-0.1.0.dist-info/METADATA +1878 -0
  50. dao_ai-0.1.0.dist-info/RECORD +62 -0
  51. dao_ai/chat_models.py +0 -204
  52. dao_ai/guardrails.py +0 -112
  53. dao_ai/tools/human_in_the_loop.py +0 -100
  54. dao_ai-0.0.35.dist-info/METADATA +0 -1169
  55. dao_ai-0.0.35.dist-info/RECORD +0 -41
  56. {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/WHEEL +0 -0
  57. {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/entry_points.txt +0 -0
  58. {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py CHANGED
@@ -28,8 +28,12 @@ from databricks.sdk.service.database import DatabaseInstance
 from databricks.vector_search.client import VectorSearchClient
 from databricks.vector_search.index import VectorSearchIndex
 from databricks_langchain import (
+    ChatDatabricks,
+    DatabricksEmbeddings,
     DatabricksFunctionClient,
 )
+from langchain.agents.structured_output import ProviderStrategy, ToolStrategy
+from langchain_core.embeddings import Embeddings
 from langchain_core.language_models import LanguageModelLike
 from langchain_core.messages import BaseMessage, messages_from_dict
 from langchain_core.runnables.base import RunnableLike
@@ -42,6 +46,7 @@ from mlflow.genai.datasets import EvaluationDataset, create_dataset, get_dataset
 from mlflow.genai.prompts import PromptVersion, load_prompt
 from mlflow.models import ModelConfig
 from mlflow.models.resources import (
+    DatabricksApp,
     DatabricksFunction,
     DatabricksGenieSpace,
     DatabricksLakebase,
@@ -82,27 +87,6 @@ class HasFullName(ABC):
     def full_name(self) -> str: ...


-class IsDatabricksResource(ABC):
-    on_behalf_of_user: Optional[bool] = False
-
-    @abstractmethod
-    def as_resources(self) -> Sequence[DatabricksResource]: ...
-
-    @property
-    @abstractmethod
-    def api_scopes(self) -> Sequence[str]: ...
-
-    @property
-    def workspace_client(self) -> WorkspaceClient:
-        credentials_strategy: CredentialsStrategy = None
-        if self.on_behalf_of_user:
-            credentials_strategy = ModelServingUserCredentials()
-        logger.debug(
-            f"Creating WorkspaceClient with credentials strategy: {credentials_strategy}"
-        )
-        return WorkspaceClient(credentials_strategy=credentials_strategy)
-
-
 class EnvironmentVariableModel(BaseModel, HasValue):
     model_config = ConfigDict(
         frozen=True,
@@ -210,6 +194,138 @@ class ServicePrincipalModel(BaseModel):
     client_secret: AnyVariable


+class IsDatabricksResource(ABC, BaseModel):
+    """
+    Base class for Databricks resources with authentication support.
+
+    Authentication Options:
+    ----------------------
+    1. **On-Behalf-Of User (OBO)**: Set on_behalf_of_user=True to use the
+       calling user's identity via ModelServingUserCredentials.
+
+    2. **Service Principal (OAuth M2M)**: Provide service_principal or
+       (client_id + client_secret + workspace_host) for service principal auth.
+
+    3. **Personal Access Token (PAT)**: Provide pat (and optionally workspace_host)
+       to authenticate with a personal access token.
+
+    4. **Ambient Authentication**: If no credentials provided, uses SDK defaults
+       (environment variables, notebook context, etc.)
+
+    Authentication Priority:
+    1. OBO (on_behalf_of_user=True)
+    2. Service Principal (client_id + client_secret + workspace_host)
+    3. PAT (pat + workspace_host)
+    4. Ambient/default authentication
+    """
+
+    model_config = ConfigDict(use_enum_values=True)
+
+    on_behalf_of_user: Optional[bool] = False
+    service_principal: Optional[ServicePrincipalModel] = None
+    client_id: Optional[AnyVariable] = None
+    client_secret: Optional[AnyVariable] = None
+    workspace_host: Optional[AnyVariable] = None
+    pat: Optional[AnyVariable] = None
+
+    @abstractmethod
+    def as_resources(self) -> Sequence[DatabricksResource]: ...
+
+    @property
+    @abstractmethod
+    def api_scopes(self) -> Sequence[str]: ...
+
+    @model_validator(mode="after")
+    def _expand_service_principal(self) -> Self:
+        """Expand service_principal into client_id and client_secret if provided."""
+        if self.service_principal is not None:
+            if self.client_id is None:
+                self.client_id = self.service_principal.client_id
+            if self.client_secret is None:
+                self.client_secret = self.service_principal.client_secret
+        return self
+
+    @model_validator(mode="after")
+    def _validate_auth_not_mixed(self) -> Self:
+        """Validate that OAuth and PAT authentication are not both provided."""
+        has_oauth: bool = self.client_id is not None and self.client_secret is not None
+        has_pat: bool = self.pat is not None
+
+        if has_oauth and has_pat:
+            raise ValueError(
+                "Cannot use both OAuth and user authentication methods. "
+                "Please provide either OAuth credentials or user credentials."
+            )
+        return self
+
+    @property
+    def workspace_client(self) -> WorkspaceClient:
+        """
+        Get a WorkspaceClient configured with the appropriate authentication.
+
+        Authentication priority:
+        1. If on_behalf_of_user is True, uses ModelServingUserCredentials (OBO)
+        2. If service principal credentials are configured (client_id, client_secret,
+           workspace_host), uses OAuth M2M
+        3. If PAT is configured, uses token authentication
+        4. Otherwise, uses default/ambient authentication
+        """
+        from dao_ai.utils import normalize_host
+
+        # Check for OBO first (highest priority)
+        if self.on_behalf_of_user:
+            credentials_strategy: CredentialsStrategy = ModelServingUserCredentials()
+            logger.debug(
+                f"Creating WorkspaceClient for {self.__class__.__name__} "
+                f"with OBO credentials strategy"
+            )
+            return WorkspaceClient(credentials_strategy=credentials_strategy)
+
+        # Check for service principal credentials
+        client_id_value: str | None = (
+            value_of(self.client_id) if self.client_id else None
+        )
+        client_secret_value: str | None = (
+            value_of(self.client_secret) if self.client_secret else None
+        )
+        workspace_host_value: str | None = (
+            normalize_host(value_of(self.workspace_host))
+            if self.workspace_host
+            else None
+        )
+
+        if client_id_value and client_secret_value and workspace_host_value:
+            logger.debug(
+                f"Creating WorkspaceClient for {self.__class__.__name__} with service principal: "
+                f"client_id={client_id_value}, host={workspace_host_value}"
+            )
+            return WorkspaceClient(
+                host=workspace_host_value,
+                client_id=client_id_value,
+                client_secret=client_secret_value,
+                auth_type="oauth-m2m",
+            )
+
+        # Check for PAT authentication
+        pat_value: str | None = value_of(self.pat) if self.pat else None
+        if pat_value:
+            logger.debug(
+                f"Creating WorkspaceClient for {self.__class__.__name__} with PAT"
+            )
+            return WorkspaceClient(
+                host=workspace_host_value,
+                token=pat_value,
+                auth_type="pat",
+            )
+
+        # Default: use ambient authentication
+        logger.debug(
+            f"Creating WorkspaceClient for {self.__class__.__name__} "
+            "with default/ambient authentication"
+        )
+        return WorkspaceClient()
+
+
 class Privilege(str, Enum):
     ALL_PRIVILEGES = "ALL_PRIVILEGES"
     USE_CATALOG = "USE_CATALOG"
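Note: the relocated IsDatabricksResource resolves credentials in a fixed order (OBO, then service principal, then PAT, then ambient). A minimal sketch of a toy subclass exercising that priority; the class name, token, and host below are placeholders, and the import path is assumed:

```python
from typing import Sequence

from dao_ai.config import IsDatabricksResource  # assumed import path

class MyResource(IsDatabricksResource):
    """Toy resource: declares no static resources or API scopes."""

    def as_resources(self) -> Sequence:
        return []

    @property
    def api_scopes(self) -> Sequence[str]:
        return []

# PAT + workspace_host -> WorkspaceClient(host=..., token=..., auth_type="pat")
r = MyResource(pat="dapi-placeholder", workspace_host="https://example.cloud.databricks.com")
w = r.workspace_client

# No credentials at all -> ambient/default SDK authentication
ambient = MyResource().workspace_client
```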
@@ -270,7 +386,26 @@ class SchemaModel(BaseModel, HasFullName):
         provider.create_schema(self)


-class TableModel(BaseModel, HasFullName, IsDatabricksResource):
+class DatabricksAppModel(IsDatabricksResource, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    url: str
+
+    @property
+    def full_name(self) -> str:
+        return self.name
+
+    @property
+    def api_scopes(self) -> Sequence[str]:
+        return ["apps.apps"]
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksApp(app_name=self.name, on_behalf_of_user=self.on_behalf_of_user)
+        ]
+
+
+class TableModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: Optional[str] = None
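Note: a short sketch of the new DatabricksAppModel (values are placeholders); full_name is just the app name and as_resources() emits a single DatabricksApp entry:

```python
from dao_ai.config import DatabricksAppModel  # assumed import path

app = DatabricksAppModel(name="my-app", url="https://my-app.example.databricksapps.com")
assert app.full_name == "my-app"
assert app.api_scopes == ["apps.apps"]
resources = app.as_resources()  # [DatabricksApp(app_name="my-app", on_behalf_of_user=False)]
```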
@@ -339,12 +474,16 @@ class TableModel(BaseModel, HasFullName, IsDatabricksResource):
         return resources


-class LLMModel(BaseModel, IsDatabricksResource):
+class LLMModel(IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     temperature: Optional[float] = 0.1
     max_tokens: Optional[int] = 8192
     fallbacks: Optional[list[Union[str, "LLMModel"]]] = Field(default_factory=list)
+    use_responses_api: Optional[bool] = Field(
+        default=False,
+        description="Use Responses API for ResponsesAgent endpoints",
+    )

     @property
     def api_scopes(self) -> Sequence[str]:
@@ -364,19 +503,12 @@ class LLMModel(BaseModel, IsDatabricksResource):
         ]

     def as_chat_model(self) -> LanguageModelLike:
-        # Retrieve langchain chat client from workspace client to enable OBO
-        # ChatOpenAI does not allow additional inputs at the moment, so we cannot use it directly
-        # chat_client: LanguageModelLike = self.as_open_ai_client()
-
-        # Create ChatDatabricksWrapper instance directly
-        from dao_ai.chat_models import ChatDatabricksFiltered
-
-        chat_client: LanguageModelLike = ChatDatabricksFiltered(
-            model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
+        chat_client: LanguageModelLike = ChatDatabricks(
+            model=self.name,
+            temperature=self.temperature,
+            max_tokens=self.max_tokens,
+            use_responses_api=self.use_responses_api,
         )
-        # chat_client: LanguageModelLike = ChatDatabricks(
-        #     model=self.name, temperature=self.temperature, max_tokens=self.max_tokens
-        # )

         fallbacks: Sequence[LanguageModelLike] = []
         for fallback in self.fallbacks:
@@ -408,6 +540,9 @@ class LLMModel(BaseModel, IsDatabricksResource):
 
         return chat_client

+    def as_embeddings_model(self) -> Embeddings:
+        return DatabricksEmbeddings(endpoint=self.name)
+

 class VectorSearchEndpointType(str, Enum):
     STANDARD = "STANDARD"
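Note: with ChatDatabricksFiltered removed, as_chat_model() returns a plain ChatDatabricks, and the same LLMModel shape can now back an embeddings endpoint. A usage sketch, assuming these serving endpoints exist in the target workspace:

```python
from dao_ai.config import LLMModel  # assumed import path

llm = LLMModel(name="databricks-meta-llama-3-3-70b-instruct", temperature=0.0)
chat = llm.as_chat_model()  # ChatDatabricks(...); use_responses_api defaults to False

embeddings = LLMModel(name="databricks-gte-large-en").as_embeddings_model()
# -> DatabricksEmbeddings(endpoint="databricks-gte-large-en")
```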
@@ -427,7 +562,7 @@ class VectorSearchEndpoint(BaseModel):
         return str(value)


-class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
+class IndexModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: str
@@ -452,7 +587,7 @@ class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
         ]


-class GenieRoomModel(BaseModel, IsDatabricksResource):
+class GenieRoomModel(IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     description: Optional[str] = None
@@ -478,7 +613,7 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
         return self


-class VolumeModel(BaseModel, HasFullName, IsDatabricksResource):
+class VolumeModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: str
@@ -538,7 +673,7 @@ class VolumePathModel(BaseModel, HasFullName):
         provider.create_path(self)


-class VectorStoreModel(BaseModel, IsDatabricksResource):
+class VectorStoreModel(IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     embedding_model: Optional[LLMModel] = None
     index: Optional[IndexModel] = None
@@ -637,7 +772,7 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
         provider.create_vector_store(self)


-class FunctionModel(BaseModel, HasFullName, IsDatabricksResource):
+class FunctionModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict()
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: Optional[str] = None
@@ -692,7 +827,7 @@ class FunctionModel(BaseModel, HasFullName, IsDatabricksResource):
         return ["sql.statement-execution"]


-class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
+class ConnectionModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict()
     name: str

@@ -719,7 +854,7 @@ class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
         ]


-class WarehouseModel(BaseModel, IsDatabricksResource):
+class WarehouseModel(IsDatabricksResource):
     model_config = ConfigDict()
     name: str
     description: Optional[str] = None
@@ -746,30 +881,28 @@ class WarehouseModel(BaseModel, IsDatabricksResource):
         return self


-class DatabaseModel(BaseModel, IsDatabricksResource):
+class DatabaseType(str, Enum):
+    POSTGRES = "postgres"
+    LAKEBASE = "lakebase"
+
+
+class DatabaseModel(IsDatabricksResource):
     """
     Configuration for a Databricks Lakebase (PostgreSQL) database instance.

-    Authentication Model:
-    --------------------
-    This model uses TWO separate authentication contexts:
-
-    1. **Workspace API Authentication** (inherited from IsDatabricksResource):
-       - Uses ambient/default authentication (environment variables, notebook context, app service principal)
-       - Used for: discovering database instance, getting host DNS, checking instance status
-       - Controlled by: DATABRICKS_HOST, DATABRICKS_TOKEN env vars, or SDK default config
+    Authentication is inherited from IsDatabricksResource. Additionally supports:
+    - user/password: For user-based database authentication

-    2. **Database Connection Authentication** (configured via service_principal, client_id/client_secret, OR user):
-       - Used for: connecting to the PostgreSQL database as a specific identity
-       - Service Principal: Set service_principal with workspace_host to connect as a service principal
-       - OAuth M2M: Set client_id, client_secret, workspace_host to connect as a service principal
-       - User Auth: Set user (and optionally password) to connect as a user identity
+    Database Type:
+    - lakebase: Databricks-managed Lakebase instance (authentication optional, supports ambient auth)
+    - postgres: Standard PostgreSQL database (authentication required)

     Example Service Principal Configuration:
     ```yaml
     databases:
       my_lakebase:
         name: my-database
+        type: lakebase
         service_principal:
           client_id:
             env: SERVICE_PRINCIPAL_CLIENT_ID
@@ -780,31 +913,28 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
           env: DATABRICKS_HOST
     ```

-    Example OAuth M2M Configuration (alternative):
+    Example User Configuration:
     ```yaml
     databases:
       my_lakebase:
         name: my-database
-        client_id:
-          env: SERVICE_PRINCIPAL_CLIENT_ID
-        client_secret:
-          scope: my-scope
-          secret: sp-client-secret
-        workspace_host:
-          env: DATABRICKS_HOST
+        type: lakebase
+        user: my-user@databricks.com
     ```

-    Example User Configuration:
+    Example Ambient Authentication (Lakebase only):
     ```yaml
     databases:
       my_lakebase:
         name: my-database
-        user: my-user@databricks.com
+        type: lakebase
+        on_behalf_of_user: true
     ```
     """

     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
+    type: Optional[DatabaseType] = DatabaseType.LAKEBASE
     instance_name: Optional[str] = None
     description: Optional[str] = None
     host: Optional[AnyVariable] = None
@@ -815,16 +945,18 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
     timeout_seconds: Optional[int] = 10
     capacity: Optional[Literal["CU_1", "CU_2"]] = "CU_2"
     node_count: Optional[int] = None
+    # Database-specific auth (user identity for DB connection)
     user: Optional[AnyVariable] = None
     password: Optional[AnyVariable] = None
-    service_principal: Optional[ServicePrincipalModel] = None
-    client_id: Optional[AnyVariable] = None
-    client_secret: Optional[AnyVariable] = None
-    workspace_host: Optional[AnyVariable] = None
+
+    @field_serializer("type")
+    def serialize_type(self, value: DatabaseType | None) -> str | None:
+        """Serialize the database type enum to its string value."""
+        return value.value if value is not None else None

     @property
     def api_scopes(self) -> Sequence[str]:
-        return []
+        return ["database.database-instances"]

     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
@@ -838,29 +970,33 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
     def update_instance_name(self) -> Self:
         if self.instance_name is None:
             self.instance_name = self.name
-
-        return self
-
-    @model_validator(mode="after")
-    def expand_service_principal(self) -> Self:
-        """Expand service_principal into client_id and client_secret if provided."""
-        if self.service_principal is not None:
-            if self.client_id is None:
-                self.client_id = self.service_principal.client_id
-            if self.client_secret is None:
-                self.client_secret = self.service_principal.client_secret
         return self

     @model_validator(mode="after")
     def update_user(self) -> Self:
-        if self.client_id or self.user:
+        # Skip if using OBO (passive auth), explicit credentials, or explicit user
+        if self.on_behalf_of_user or self.client_id or self.user or self.pat:
             return self

-        self.user = self.workspace_client.current_user.me().user_name
-        if not self.user:
-            raise ValueError(
-                "Unable to determine current user. Please provide a user name or OAuth credentials."
-            )
+        # For postgres, we need explicit user credentials
+        # For lakebase with no auth, ambient auth is allowed
+        if self.type == DatabaseType.POSTGRES:
+            # Try to determine current user for local development
+            try:
+                self.user = self.workspace_client.current_user.me().user_name
+            except Exception as e:
+                logger.warning(
+                    f"Could not determine current user for PostgreSQL database: {e}. "
+                    f"Please provide explicit user credentials."
+                )
+        else:
+            # For lakebase, try to determine current user but don't fail if we can't
+            try:
+                self.user = self.workspace_client.current_user.me().user_name
+            except Exception:
+                # If we can't determine user and no explicit auth, that's okay
+                # for lakebase with ambient auth - credentials will be injected at runtime
+                pass

         return self

@@ -869,12 +1005,28 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
         if self.host is not None:
             return self

-        existing_instance: DatabaseInstance = (
-            self.workspace_client.database.get_database_instance(
-                name=self.instance_name
+        # Try to fetch host from existing instance
+        # This may fail for OBO/ambient auth during model logging (before deployment)
+        try:
+            existing_instance: DatabaseInstance = (
+                self.workspace_client.database.get_database_instance(
+                    name=self.instance_name
+                )
             )
-        )
-        self.host = existing_instance.read_write_dns
+            self.host = existing_instance.read_write_dns
+        except Exception as e:
+            # For lakebase with OBO/ambient auth, we can't fetch at config time
+            # The host will need to be provided explicitly or fetched at runtime
+            if self.type == DatabaseType.LAKEBASE and self.on_behalf_of_user:
+                logger.debug(
+                    f"Could not fetch host for database {self.instance_name} "
+                    f"(Lakebase with OBO mode - will be resolved at runtime): {e}"
+                )
+            else:
+                raise ValueError(
+                    f"Could not fetch host for database {self.instance_name}. "
+                    f"Please provide the 'host' explicitly or ensure the instance exists: {e}"
+                )
         return self

     @model_validator(mode="after")
@@ -885,21 +1037,33 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
             self.client_secret,
         ]
         has_oauth: bool = all(field is not None for field in oauth_fields)
+        has_user_auth: bool = self.user is not None
+        has_obo: bool = self.on_behalf_of_user is True
+        has_pat: bool = self.pat is not None

-        pat_fields: Sequence[Any] = [self.user]
-        has_user_auth: bool = all(field is not None for field in pat_fields)
+        # Count how many auth methods are configured
+        auth_methods_count: int = sum([has_oauth, has_user_auth, has_obo, has_pat])

-        if has_oauth and has_user_auth:
+        if auth_methods_count > 1:
             raise ValueError(
-                "Cannot use both OAuth and user authentication methods. "
-                "Please provide either OAuth credentials or user credentials."
+                "Cannot mix authentication methods. "
+                "Please provide exactly one of: "
+                "on_behalf_of_user=true (for passive auth in model serving), "
+                "OAuth credentials (service_principal or client_id + client_secret + workspace_host), "
+                "PAT (personal access token), "
+                "or user credentials (user)."
             )

-        if not has_oauth and not has_user_auth:
+        # For postgres type, at least one auth method must be configured
+        # For lakebase type, auth is optional (supports ambient authentication)
+        if self.type == DatabaseType.POSTGRES and auth_methods_count == 0:
             raise ValueError(
-                "At least one authentication method must be provided: "
-                "either OAuth credentials (workspace_host, client_id, client_secret), "
-                "service_principal with workspace_host, or user credentials (user, password)."
+                "PostgreSQL databases require explicit authentication. "
+                "Please provide one of: "
+                "OAuth credentials (workspace_host, client_id, client_secret), "
+                "service_principal with workspace_host, "
+                "PAT (personal access token), "
+                "or user credentials (user)."
             )

         return self
@@ -913,8 +1077,9 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
         If username is configured, it will be included; otherwise it will be omitted
         to allow Lakebase to authenticate using the token's identity.
         """
-        from dao_ai.providers.base import ServiceProvider
-        from dao_ai.providers.databricks import DatabricksProvider
+        import uuid as _uuid
+
+        from databricks.sdk.service.database import DatabaseCredential

         username: str | None = None

@@ -922,19 +1087,36 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
             username = value_of(self.client_id)
         elif self.user:
             username = value_of(self.user)
+        # For OBO mode, no username is needed - the token identity is used
+
+        # Resolve host - may need to fetch at runtime for OBO mode
+        host_value: Any = self.host
+        if host_value is None and self.on_behalf_of_user:
+            # Fetch host at runtime for OBO mode
+            existing_instance: DatabaseInstance = (
+                self.workspace_client.database.get_database_instance(
+                    name=self.instance_name
+                )
+            )
+            host_value = existing_instance.read_write_dns
+
+        if host_value is None:
+            raise ValueError(
+                f"Database host not configured for {self.instance_name}. "
+                "Please provide 'host' explicitly."
+            )

-        host: str = value_of(self.host)
+        host: str = value_of(host_value)
         port: int = value_of(self.port)
         database: str = value_of(self.database)

-        provider: ServiceProvider = DatabricksProvider(
-            client_id=value_of(self.client_id),
-            client_secret=value_of(self.client_secret),
-            workspace_host=value_of(self.workspace_host),
-            pat=value_of(self.password),
+        # Use the resource's own workspace_client to generate the database credential
+        w: WorkspaceClient = self.workspace_client
+        cred: DatabaseCredential = w.database.generate_database_credential(
+            request_id=str(_uuid.uuid4()),
+            instance_names=[self.instance_name],
         )
-
-        token: str = provider.lakebase_password_provider(self.instance_name)
+        token: str = cred.token

         # Build connection parameters dictionary
         params: dict[str, Any] = {
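Note: the provider indirection is gone; connection parameters are now built from a short-lived credential minted by the resource's own WorkspaceClient. A standalone sketch of the same call shown in the diff (the instance name is a placeholder):

```python
import uuid

from databricks.sdk import WorkspaceClient

w = WorkspaceClient()  # ambient auth here; OBO/SP/PAT arrive via workspace_client
cred = w.database.generate_database_credential(
    request_id=str(uuid.uuid4()),
    instance_names=["my-database"],  # placeholder Lakebase instance name
)
pg_password = cred.token  # used as the PostgreSQL password
```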
@@ -972,11 +1154,86 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
     def create(self, w: WorkspaceClient | None = None) -> None:
         from dao_ai.providers.databricks import DatabricksProvider

+        # Use provided workspace client or fall back to resource's own workspace_client
+        if w is None:
+            w = self.workspace_client
         provider: DatabricksProvider = DatabricksProvider(w=w)
         provider.create_lakebase(self)
         provider.create_lakebase_instance_role(self)


+class GenieLRUCacheParametersModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    capacity: int = 1000
+    time_to_live_seconds: int | None = (
+        60 * 60 * 24
+    )  # 1 day default, None or negative = never expires
+    warehouse: WarehouseModel
+
+
+class GenieSemanticCacheParametersModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    time_to_live_seconds: int | None = (
+        60 * 60 * 24
+    )  # 1 day default, None or negative = never expires
+    similarity_threshold: float = 0.85  # Minimum similarity for question matching (L2 distance converted to 0-1 scale)
+    context_similarity_threshold: float = 0.80  # Minimum similarity for context matching (L2 distance converted to 0-1 scale)
+    question_weight: Optional[float] = (
+        0.6  # Weight for question similarity in combined score (0-1). If not provided, computed as 1 - context_weight
+    )
+    context_weight: Optional[float] = (
+        None  # Weight for context similarity in combined score (0-1). If not provided, computed as 1 - question_weight
+    )
+    embedding_model: str | LLMModel = "databricks-gte-large-en"
+    embedding_dims: int | None = None  # Auto-detected if None
+    database: DatabaseModel
+    warehouse: WarehouseModel
+    table_name: str = "genie_semantic_cache"
+    context_window_size: int = 3  # Number of previous turns to include for context
+    max_context_tokens: int = (
+        2000  # Maximum context length to prevent extremely long embeddings
+    )
+
+    @model_validator(mode="after")
+    def compute_and_validate_weights(self) -> Self:
+        """
+        Compute missing weight and validate that question_weight + context_weight = 1.0.
+
+        Either question_weight or context_weight (or both) can be provided.
+        The missing one will be computed as 1.0 - provided_weight.
+        If both are provided, they must sum to 1.0.
+        """
+        if self.question_weight is None and self.context_weight is None:
+            # Both missing - use defaults
+            self.question_weight = 0.6
+            self.context_weight = 0.4
+        elif self.question_weight is None:
+            # Compute question_weight from context_weight
+            if not (0.0 <= self.context_weight <= 1.0):
+                raise ValueError(
+                    f"context_weight must be between 0.0 and 1.0, got {self.context_weight}"
+                )
+            self.question_weight = 1.0 - self.context_weight
+        elif self.context_weight is None:
+            # Compute context_weight from question_weight
+            if not (0.0 <= self.question_weight <= 1.0):
+                raise ValueError(
+                    f"question_weight must be between 0.0 and 1.0, got {self.question_weight}"
+                )
+            self.context_weight = 1.0 - self.question_weight
+        else:
+            # Both provided - validate they sum to 1.0
+            total_weight = self.question_weight + self.context_weight
+            if not abs(total_weight - 1.0) < 0.0001:  # Allow small floating point error
+                raise ValueError(
+                    f"question_weight ({self.question_weight}) + context_weight ({self.context_weight}) "
+                    f"must equal 1.0 (got {total_weight}). These weights determine the relative importance "
+                    f"of question vs context similarity in the combined score."
+                )
+
+        return self
+
+
 class SearchParametersModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     num_results: Optional[int] = 10
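Note: compute_and_validate_weights derives the missing weight and enforces question_weight + context_weight == 1.0. A quick sketch; the database and warehouse objects are hypothetical stand-ins:

```python
params = GenieSemanticCacheParametersModel(
    context_weight=0.3,      # question_weight is derived as 1.0 - 0.3 = 0.7
    database=my_database,    # hypothetical DatabaseModel
    warehouse=my_warehouse,  # hypothetical WarehouseModel
)
assert params.question_weight == 0.7

# Supplying both weights with a sum other than 1.0 raises ValueError.
```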
@@ -1067,28 +1324,47 @@ class FunctionType(str, Enum):
     MCP = "mcp"


-class HumanInTheLoopActionType(str, Enum):
-    """Supported action types for human-in-the-loop interactions."""
+class HumanInTheLoopModel(BaseModel):
+    """
+    Configuration for Human-in-the-Loop tool approval.

-    ACCEPT = "accept"
-    EDIT = "edit"
-    RESPONSE = "response"
-    DECLINE = "decline"
+    This model configures when and how tools require human approval before execution.
+    It maps to LangChain's HumanInTheLoopMiddleware.

+    LangChain supports three decision types:
+    - "approve": Execute tool with original arguments
+    - "edit": Modify arguments before execution
+    - "reject": Skip execution with optional feedback message
+    """

-class HumanInTheLoopModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    review_prompt: str = "Please review the tool call"
-    interrupt_config: dict[str, Any] = Field(
-        default_factory=lambda: {
-            "allow_accept": True,
-            "allow_edit": True,
-            "allow_respond": True,
-            "allow_decline": True,
-        }
+
+    review_prompt: Optional[str] = Field(
+        default=None,
+        description="Message shown to the reviewer when approval is requested",
+    )
+
+    allowed_decisions: list[Literal["approve", "edit", "reject"]] = Field(
+        default_factory=lambda: ["approve", "edit", "reject"],
+        description="List of allowed decision types for this tool",
     )
-    decline_message: str = "Tool call declined by user"
-    custom_actions: Optional[dict[str, str]] = Field(default_factory=dict)
+
+    @model_validator(mode="after")
+    def validate_and_normalize_decisions(self) -> Self:
+        """Validate and normalize allowed decisions."""
+        if not self.allowed_decisions:
+            raise ValueError("At least one decision type must be allowed")
+
+        # Remove duplicates while preserving order
+        seen = set()
+        unique_decisions = []
+        for decision in self.allowed_decisions:
+            if decision not in seen:
+                seen.add(decision)
+                unique_decisions.append(decision)
+        self.allowed_decisions = unique_decisions
+
+        return self


 class BaseFunctionModel(ABC, BaseModel):
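Note: the old accept/edit/respond/decline interrupt_config collapses into a single allowed_decisions list, and the validator drops duplicates while preserving order. A sketch:

```python
hitl = HumanInTheLoopModel(
    review_prompt="Approve sending this email?",
    allowed_decisions=["approve", "reject", "approve"],  # duplicate is removed
)
assert hitl.allowed_decisions == ["approve", "reject"]

# An empty list raises: "At least one decision type must be allowed".
```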
@@ -1151,7 +1427,16 @@ class TransportType(str, Enum):
     STDIO = "stdio"


-class McpFunctionModel(BaseFunctionModel, HasFullName):
+class McpFunctionModel(BaseFunctionModel, IsDatabricksResource, HasFullName):
+    """
+    MCP Function Model with authentication inherited from IsDatabricksResource.
+
+    Authentication for MCP connections uses the same options as other resources:
+    - Service Principal (client_id + client_secret + workspace_host)
+    - PAT (pat + workspace_host)
+    - OBO (on_behalf_of_user)
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     type: Literal[FunctionType.MCP] = FunctionType.MCP
     transport: TransportType = TransportType.STREAMABLE_HTTP
@@ -1159,26 +1444,27 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
     url: Optional[AnyVariable] = None
     headers: dict[str, AnyVariable] = Field(default_factory=dict)
     args: list[str] = Field(default_factory=list)
-    pat: Optional[AnyVariable] = None
-    service_principal: Optional[ServicePrincipalModel] = None
-    client_id: Optional[AnyVariable] = None
-    client_secret: Optional[AnyVariable] = None
-    workspace_host: Optional[AnyVariable] = None
+    # MCP-specific fields
     connection: Optional[ConnectionModel] = None
     functions: Optional[SchemaModel] = None
     genie_room: Optional[GenieRoomModel] = None
     sql: Optional[bool] = None
     vector_search: Optional[VectorStoreModel] = None

-    @model_validator(mode="after")
-    def expand_service_principal(self) -> Self:
-        """Expand service_principal into client_id and client_secret if provided."""
-        if self.service_principal is not None:
-            if self.client_id is None:
-                self.client_id = self.service_principal.client_id
-            if self.client_secret is None:
-                self.client_secret = self.service_principal.client_secret
-        return self
+    @property
+    def api_scopes(self) -> Sequence[str]:
+        """API scopes for MCP connections."""
+        return [
+            "serving.serving-endpoints",
+            "mcp.genie",
+            "mcp.functions",
+            "mcp.vectorsearch",
+            "mcp.external",
+        ]
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        """MCP functions don't declare static resources."""
+        return []

     @property
     def full_name(self) -> str:
@@ -1343,27 +1629,6 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
             self.headers[key] = value_of(value)
         return self

-    @model_validator(mode="after")
-    def validate_auth_methods(self) -> "McpFunctionModel":
-        oauth_fields: Sequence[Any] = [
-            self.client_id,
-            self.client_secret,
-        ]
-        has_oauth: bool = all(field is not None for field in oauth_fields)
-
-        pat_fields: Sequence[Any] = [self.pat]
-        has_user_auth: bool = all(field is not None for field in pat_fields)
-
-        if has_oauth and has_user_auth:
-            raise ValueError(
-                "Cannot use both OAuth and user authentication methods. "
-                "Please provide either OAuth credentials or user credentials."
-            )
-
-        # Note: workspace_host is optional - it will be derived from workspace client if not provided
-
-        return self
-
     def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
         from dao_ai.tools import create_mcp_tools

@@ -1405,17 +1670,97 @@ class ToolModel(BaseModel):
     function: AnyTool


+class PromptModel(BaseModel, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
+    name: str
+    description: Optional[str] = None
+    default_template: Optional[str] = None
+    alias: Optional[str] = None
+    version: Optional[int] = None
+    tags: Optional[dict[str, Any]] = Field(default_factory=dict)
+
+    @property
+    def template(self) -> str:
+        from dao_ai.providers.databricks import DatabricksProvider
+
+        provider: DatabricksProvider = DatabricksProvider()
+        prompt_version = provider.get_prompt(self)
+        return prompt_version.to_single_brace_format()
+
+    @property
+    def full_name(self) -> str:
+        prompt_name: str = self.name
+        if self.schema_model:
+            prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
+        return prompt_name
+
+    @property
+    def uri(self) -> str:
+        prompt_uri: str = f"prompts:/{self.full_name}"
+
+        if self.alias:
+            prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
+        elif self.version:
+            prompt_uri = f"prompts:/{self.full_name}/{self.version}"
+        else:
+            prompt_uri = f"prompts:/{self.full_name}@latest"
+
+        return prompt_uri
+
+    def as_prompt(self) -> PromptVersion:
+        prompt_version: PromptVersion = load_prompt(self.uri)
+        return prompt_version
+
+    @model_validator(mode="after")
+    def validate_mutually_exclusive(self) -> Self:
+        if self.alias and self.version:
+            raise ValueError("Cannot specify both alias and version")
+        return self
+
+
 class GuardrailModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
-    model: LLMModel
-    prompt: str
+    model: str | LLMModel
+    prompt: str | PromptModel
     num_retries: Optional[int] = 3

+    @model_validator(mode="after")
+    def validate_llm_model(self) -> Self:
+        if isinstance(self.model, str):
+            self.model = LLMModel(name=self.model)
+        return self
+
+
+class MiddlewareModel(BaseModel):
+    """Configuration for middleware that can be applied to agents.
+
+    Middleware is defined at the AppConfig level and can be referenced by name
+    in agent configurations using YAML anchors for reusability.
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str = Field(
+        description="Fully qualified name of the middleware factory function"
+    )
+    args: dict[str, Any] = Field(
+        default_factory=dict,
+        description="Arguments to pass to the middleware factory function",
+    )
+
+    @model_validator(mode="after")
+    def resolve_args(self) -> Self:
+        """Resolve any variable references in args."""
+        for key, value in self.args.items():
+            self.args[key] = value_of(value)
+        return self
+

 class StorageType(str, Enum):
     POSTGRES = "postgres"
     MEMORY = "memory"
+    LAKEBASE = "lakebase"


 class CheckpointerModel(BaseModel):
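Note: PromptModel moved earlier in the file unchanged; its uri still prefers an alias, then a pinned version, then @latest. A sketch with placeholder names:

```python
assert PromptModel(name="helper", alias="prod").uri == "prompts:/helper@prod"
assert PromptModel(name="helper", version=3).uri == "prompts:/helper/3"
assert PromptModel(name="helper").uri == "prompts:/helper@latest"

# Setting both alias and version raises "Cannot specify both alias and version".
```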
@@ -1425,8 +1770,11 @@ class CheckpointerModel(BaseModel):
     database: Optional[DatabaseModel] = None

     @model_validator(mode="after")
-    def validate_postgres_requires_database(self) -> Self:
-        if self.type == StorageType.POSTGRES and not self.database:
+    def validate_storage_requires_database(self) -> Self:
+        if (
+            self.type in [StorageType.POSTGRES, StorageType.LAKEBASE]
+            and not self.database
+        ):
             raise ValueError("Database must be provided when storage type is POSTGRES")
         return self

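Note: both postgres and the new lakebase storage types now require a database. A configuration sketch; the database object is a hypothetical stand-in:

```python
cp = CheckpointerModel(type="lakebase", database=my_database)  # hypothetical DatabaseModel

# CheckpointerModel(type="lakebase") with no database raises ValueError.
```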
@@ -1471,56 +1819,158 @@ class MemoryModel(BaseModel):
 FunctionHook: TypeAlias = PythonFunctionModel | FactoryFunctionModel | str


-class PromptModel(BaseModel, HasFullName):
+class ResponseFormatModel(BaseModel):
+    """
+    Configuration for structured response formats.
+
+    The response_schema field accepts either a type or a string:
+    - Type (Pydantic model, dataclass, etc.): Used directly for structured output
+    - String: First attempts to load as a fully qualified type name, falls back to JSON schema string
+
+    This unified approach simplifies the API while maintaining flexibility.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
-    name: str
-    description: Optional[str] = None
-    default_template: Optional[str] = None
-    alias: Optional[str] = None
-    version: Optional[int] = None
-    tags: Optional[dict[str, Any]] = Field(default_factory=dict)
+    use_tool: Optional[bool] = Field(
+        default=None,
+        description=(
+            "Strategy for structured output: "
+            "None (default) = auto-detect from model capabilities, "
+            "False = force ProviderStrategy (native), "
+            "True = force ToolStrategy (function calling)"
+        ),
+    )
+    response_schema: Optional[str | type] = Field(
+        default=None,
+        description="Type or string for response format. String attempts FQN import, falls back to JSON schema.",
+    )

-    @property
-    def template(self) -> str:
-        from dao_ai.providers.databricks import DatabricksProvider
+    def as_strategy(self) -> ProviderStrategy | ToolStrategy:
+        """
+        Convert response_schema to appropriate LangChain strategy.

-        provider: DatabricksProvider = DatabricksProvider()
-        prompt_version = provider.get_prompt(self)
-        return prompt_version.to_single_brace_format()
+        Returns:
+            - None if no response_schema configured
+            - Raw schema/type for auto-detection (when use_tool=None)
+            - ToolStrategy wrapping the schema (when use_tool=True)
+            - ProviderStrategy wrapping the schema (when use_tool=False)

-    @property
-    def full_name(self) -> str:
-        prompt_name: str = self.name
-        if self.schema_model:
-            prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
-        return prompt_name
+        Raises:
+            ValueError: If response_schema is a JSON schema string that cannot be parsed
+        """

-    @property
-    def uri(self) -> str:
-        prompt_uri: str = f"prompts:/{self.full_name}"
+        if self.response_schema is None:
+            return None

-        if self.alias:
-            prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
-        elif self.version:
-            prompt_uri = f"prompts:/{self.full_name}/{self.version}"
-        else:
-            prompt_uri = f"prompts:/{self.full_name}@latest"
+        schema = self.response_schema

-        return prompt_uri
+        # Handle type schemas (Pydantic, dataclass, etc.)
+        if self.is_type_schema:
+            if self.use_tool is None:
+                # Auto-detect: Pass schema directly, let LangChain decide
+                return schema
+            elif self.use_tool is True:
+                # Force ToolStrategy (function calling)
+                return ToolStrategy(schema)
+            else:  # use_tool is False
+                # Force ProviderStrategy (native structured output)
+                return ProviderStrategy(schema)

-    def as_prompt(self) -> PromptVersion:
-        prompt_version: PromptVersion = load_prompt(self.uri)
-        return prompt_version
+        # Handle JSON schema strings
+        elif self.is_json_schema:
+            import json
+
+            try:
+                schema_dict = json.loads(schema)
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Invalid JSON schema string: {e}") from e
+
+            # Apply same use_tool logic as type schemas
+            if self.use_tool is None:
+                # Auto-detect
+                return schema_dict
+            elif self.use_tool is True:
+                # Force ToolStrategy
+                return ToolStrategy(schema_dict)
+            else:  # use_tool is False
+                # Force ProviderStrategy
+                return ProviderStrategy(schema_dict)
+
+        return None

     @model_validator(mode="after")
-    def validate_mutually_exclusive(self) -> Self:
-        if self.alias and self.version:
-            raise ValueError("Cannot specify both alias and version")
-        return self
+    def validate_response_schema(self) -> Self:
+        """
+        Validate and convert response_schema.
+
+        Processing logic:
+        1. If None: no response format specified
+        2. If type: use directly as structured output type
+        3. If str: try to load as FQN using type_from_fqn
+           - Success: response_schema becomes the loaded type
+           - Failure: keep as string (treated as JSON schema)
+
+        After validation, response_schema is one of:
+        - None (no schema)
+        - type (use for structured output)
+        - str (JSON schema)
+
+        Returns:
+            Self with validated response_schema
+        """
+        if self.response_schema is None:
+            return self
+
+        # If already a type, return
+        if isinstance(self.response_schema, type):
+            return self
+
+        # If it's a string, try to load as type, fallback to json_schema
+        if isinstance(self.response_schema, str):
+            from dao_ai.utils import type_from_fqn
+
+            try:
+                resolved_type = type_from_fqn(self.response_schema)
+                self.response_schema = resolved_type
+                logger.debug(
+                    f"Resolved response_schema string to type: {resolved_type}"
+                )
+                return self
+            except (ValueError, ImportError, AttributeError, TypeError) as e:
+                # Keep as string - it's a JSON schema
+                logger.debug(
+                    f"Could not resolve '{self.response_schema}' as type: {e}. "
+                    f"Treating as JSON schema string."
+                )
+                return self
+
+        # Invalid type
+        raise ValueError(
+            f"response_schema must be None, type, or str, got {type(self.response_schema)}"
+        )
+
+    @property
+    def is_type_schema(self) -> bool:
+        """Returns True if response_schema is a type (not JSON schema string)."""
+        return isinstance(self.response_schema, type)
+
+    @property
+    def is_json_schema(self) -> bool:
+        """Returns True if response_schema is a JSON schema string (not a type)."""
+        return isinstance(self.response_schema, str)


 class AgentModel(BaseModel):
+    """
+    Configuration model for an agent in the DAO AI framework.
+
+    Agents combine an LLM with tools and middleware to create systems that can
+    reason about tasks, decide which tools to use, and iteratively work towards solutions.
+
+    Middleware replaces the previous pre_agent_hook and post_agent_hook patterns,
+    providing a more flexible and composable way to customize agent behavior.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     description: Optional[str] = None
@@ -1529,9 +1979,43 @@ class AgentModel(BaseModel):
     guardrails: list[GuardrailModel] = Field(default_factory=list)
     prompt: Optional[str | PromptModel] = None
     handoff_prompt: Optional[str] = None
-    create_agent_hook: Optional[FunctionHook] = None
-    pre_agent_hook: Optional[FunctionHook] = None
-    post_agent_hook: Optional[FunctionHook] = None
+    middleware: list[MiddlewareModel] = Field(
+        default_factory=list,
+        description="List of middleware to apply to this agent",
+    )
+    response_format: Optional[ResponseFormatModel | type | str] = None
+
+    @model_validator(mode="after")
+    def validate_response_format(self) -> Self:
+        """
+        Validate and normalize response_format.
+
+        Accepts:
+        - None (no response format)
+        - ResponseFormatModel (already validated)
+        - type (Pydantic model, dataclass, etc.) - converts to ResponseFormatModel
+        - str (FQN or json_schema) - converts to ResponseFormatModel (smart fallback)
+
+        ResponseFormatModel handles the logic of trying FQN import and falling back to JSON schema.
+        """
+        if self.response_format is None or isinstance(
+            self.response_format, ResponseFormatModel
+        ):
+            return self
+
+        # Convert type or str to ResponseFormatModel
+        # ResponseFormatModel's validator will handle the smart type loading and fallback
+        if isinstance(self.response_format, (type, str)):
+            self.response_format = ResponseFormatModel(
+                response_schema=self.response_format
+            )
+            return self
+
+        # Invalid type
+        raise ValueError(
+            f"response_format must be None, ResponseFormatModel, type, or str, "
+            f"got {type(self.response_format)}"
+        )

     def as_runnable(self) -> RunnableLike:
         from dao_ai.nodes import create_agent_node
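Note: response_format accepts a raw type or string and is normalized into a ResponseFormatModel; as_strategy() then passes the schema through for auto-detection or forces a strategy. A sketch with a hypothetical Pydantic schema and LLM:

```python
from pydantic import BaseModel

class Answer(BaseModel):  # hypothetical output schema
    text: str

agent = AgentModel(name="qa", model=my_llm, response_format=Answer)  # my_llm: LLMModel
assert isinstance(agent.response_format, ResponseFormatModel)

agent.response_format.as_strategy()  # Answer itself (use_tool=None, auto-detect)
ResponseFormatModel(response_schema=Answer, use_tool=True).as_strategy()   # ToolStrategy(Answer)
ResponseFormatModel(response_schema=Answer, use_tool=False).as_strategy()  # ProviderStrategy(Answer)
```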
@@ -1550,6 +2034,10 @@ class SupervisorModel(BaseModel):
     model: LLMModel
     tools: list[ToolModel] = Field(default_factory=list)
     prompt: Optional[str] = None
+    middleware: list[MiddlewareModel] = Field(
+        default_factory=list,
+        description="List of middleware to apply to the supervisor",
+    )


 class SwarmModel(BaseModel):
@@ -1673,6 +2161,28 @@ class ChatPayload(BaseModel):
 
         return self

+    @model_validator(mode="after")
+    def ensure_thread_id(self) -> "ChatPayload":
+        """Ensure thread_id or conversation_id is present in configurable, generating UUID if needed."""
+        import uuid
+
+        if self.custom_inputs is None:
+            self.custom_inputs = {}
+
+        # Get or create configurable section
+        configurable: dict[str, Any] = self.custom_inputs.get("configurable", {})
+
+        # Check if thread_id or conversation_id exists
+        has_thread_id = configurable.get("thread_id") is not None
+        has_conversation_id = configurable.get("conversation_id") is not None
+
+        # If neither is provided, generate a UUID for conversation_id
+        if not has_thread_id and not has_conversation_id:
+            configurable["conversation_id"] = str(uuid.uuid4())
+            self.custom_inputs["configurable"] = configurable
+
+        return self
+
     def as_messages(self) -> Sequence[BaseMessage]:
         return messages_from_dict(
             [{"type": m.role, "content": m.content} for m in self.messages]
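Note: ensure_thread_id guarantees a conversation identifier on every payload. A sketch, assuming ChatPayload.messages accepts the role/content shape used by as_messages() above:

```python
payload = ChatPayload(messages=[{"role": "user", "content": "hello"}])

# Neither thread_id nor conversation_id was supplied, so one is generated:
cid = payload.custom_inputs["configurable"]["conversation_id"]  # fresh UUID string

# An explicit thread_id is left untouched:
payload2 = ChatPayload(
    messages=[{"role": "user", "content": "hi"}],
    custom_inputs={"configurable": {"thread_id": "abc-123"}},
)
```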
@@ -1688,20 +2198,38 @@ class ChatPayload(BaseModel):
 
 
 class ChatHistoryModel(BaseModel):
+    """
+    Configuration for chat history summarization.
+
+    Attributes:
+        model: The LLM to use for generating summaries.
+        max_tokens: Maximum tokens to keep after summarization (the "keep" threshold).
+            After summarization, recent messages totaling up to this many tokens are preserved.
+        max_tokens_before_summary: Token threshold that triggers summarization.
+            When conversation exceeds this, summarization runs. Mutually exclusive with
+            max_messages_before_summary. If neither is set, defaults to max_tokens * 10.
+        max_messages_before_summary: Message count threshold that triggers summarization.
+            When conversation exceeds this many messages, summarization runs.
+            Mutually exclusive with max_tokens_before_summary.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     model: LLMModel
-    max_tokens: int = 256
-    max_tokens_before_summary: Optional[int] = None
-    max_messages_before_summary: Optional[int] = None
-    max_summary_tokens: int = 255
-
-    @model_validator(mode="after")
-    def validate_max_summary_tokens(self) -> "ChatHistoryModel":
-        if self.max_summary_tokens >= self.max_tokens:
-            raise ValueError(
-                f"max_summary_tokens ({self.max_summary_tokens}) must be less than max_tokens ({self.max_tokens})"
-            )
-        return self
+    max_tokens: int = Field(
+        default=2048,
+        gt=0,
+        description="Maximum tokens to keep after summarization",
+    )
+    max_tokens_before_summary: Optional[int] = Field(
+        default=None,
+        gt=0,
+        description="Token threshold that triggers summarization",
+    )
+    max_messages_before_summary: Optional[int] = Field(
+        default=None,
+        gt=0,
+        description="Message count threshold that triggers summarization",
+    )


 class AppModel(BaseModel):
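Note: the summarization knobs are now Field-validated (gt=0) and the old max_summary_tokens coupling is gone. A configuration sketch; the LLM object is a hypothetical stand-in:

```python
history = ChatHistoryModel(
    model=my_llm,                    # hypothetical LLMModel used to write summaries
    max_tokens=2048,                 # keep up to ~2048 tokens after summarizing
    max_tokens_before_summary=8192,  # summarize once the conversation exceeds this
)
# max_tokens_before_summary and max_messages_before_summary are mutually
# exclusive triggers; per the docstring, the token trigger defaults to
# max_tokens * 10 when neither is set.
```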
@@ -1728,9 +2256,6 @@ class AppModel(BaseModel):
     shutdown_hooks: Optional[FunctionHook | list[FunctionHook]] = Field(
         default_factory=list
     )
-    message_hooks: Optional[FunctionHook | list[FunctionHook]] = Field(
-        default_factory=list
-    )
     input_example: Optional[ChatPayload] = None
     chat_history: Optional[ChatHistoryModel] = None
     code_paths: list[str] = Field(default_factory=list)
@@ -1935,33 +2460,67 @@ class EvaluationDatasetModel(BaseModel, HasFullName):
 
 
 class PromptOptimizationModel(BaseModel):
+    """Configuration for prompt optimization using GEPA.
+
+    GEPA (Generative Evolution of Prompts and Agents) is an evolutionary
+    optimizer that uses reflective mutation to improve prompts based on
+    evaluation feedback.
+
+    Example:
+        prompt_optimization:
+          name: optimize_my_prompt
+          prompt: *my_prompt
+          agent: *my_agent
+          dataset: *my_training_dataset
+          reflection_model: databricks-meta-llama-3-3-70b-instruct
+          num_candidates: 50
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     prompt: Optional[PromptModel] = None
     agent: AgentModel
-    dataset: (
-        EvaluationDatasetModel | str
-    )  # Reference to dataset name (looked up in OptimizationsModel.training_datasets or MLflow)
+    dataset: EvaluationDatasetModel  # Training dataset with examples
     reflection_model: Optional[LLMModel | str] = None
     num_candidates: Optional[int] = 50
-    scorer_model: Optional[LLMModel | str] = None

     def optimize(self, w: WorkspaceClient | None = None) -> PromptModel:
         """
-        Optimize the prompt using MLflow's prompt optimization.
+        Optimize the prompt using GEPA.

         Args:
-            w: Optional WorkspaceClient for Databricks operations
+            w: Optional WorkspaceClient (not used, kept for API compatibility)

         Returns:
-            PromptModel: The optimized prompt model with new URI
+            PromptModel: The optimized prompt model
         """
-        from dao_ai.providers.base import ServiceProvider
-        from dao_ai.providers.databricks import DatabricksProvider
+        from dao_ai.optimization import OptimizationResult, optimize_prompt

-        provider: ServiceProvider = DatabricksProvider(w=w)
-        optimized_prompt: PromptModel = provider.optimize_prompt(self)
-        return optimized_prompt
+        # Get reflection model name
+        reflection_model_name: str | None = None
+        if self.reflection_model:
+            if isinstance(self.reflection_model, str):
+                reflection_model_name = self.reflection_model
+            else:
+                reflection_model_name = self.reflection_model.uri
+
+        # Ensure prompt is set
+        prompt = self.prompt
+        if prompt is None:
+            raise ValueError(
+                f"Prompt optimization '{self.name}' requires a prompt to be set"
+            )
+
+        result: OptimizationResult = optimize_prompt(
+            prompt=prompt,
+            agent=self.agent,
+            dataset=self.dataset,
+            reflection_model=reflection_model_name,
+            num_candidates=self.num_candidates or 50,
+            register_if_improved=True,
+        )
+
+        return result.optimized_prompt

     @model_validator(mode="after")
     def set_defaults(self) -> Self:
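Note: optimize() now delegates to dao_ai.optimization.optimize_prompt. A minimal driver sketch; all objects are hypothetical stand-ins wired as in the docstring's YAML example:

```python
opt = PromptOptimizationModel(
    name="optimize_my_prompt",
    prompt=my_prompt,    # hypothetical PromptModel
    agent=my_agent,      # hypothetical AgentModel
    dataset=my_dataset,  # hypothetical EvaluationDatasetModel
    reflection_model="databricks-meta-llama-3-3-70b-instruct",
)
better = opt.optimize()  # GEPA loop; registers the result only if it improved
```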
@@ -1975,12 +2534,6 @@ class PromptOptimizationModel(BaseModel):
                 f"or an agent with a prompt configured"
             )

-        if self.reflection_model is None:
-            self.reflection_model = self.agent.model
-
-        if self.scorer_model is None:
-            self.scorer_model = self.agent.model
-
         return self

@@ -2081,6 +2634,7 @@ class ResourcesModel(BaseModel):
     warehouses: dict[str, WarehouseModel] = Field(default_factory=dict)
     databases: dict[str, DatabaseModel] = Field(default_factory=dict)
     connections: dict[str, ConnectionModel] = Field(default_factory=dict)
+    apps: dict[str, DatabricksAppModel] = Field(default_factory=dict)


 class AppConfig(BaseModel):
@@ -2092,6 +2646,7 @@ class AppConfig(BaseModel):
     retrievers: dict[str, RetrieverModel] = Field(default_factory=dict)
     tools: dict[str, ToolModel] = Field(default_factory=dict)
     guardrails: dict[str, GuardrailModel] = Field(default_factory=dict)
+    middleware: dict[str, MiddlewareModel] = Field(default_factory=dict)
     memory: Optional[MemoryModel] = None
     prompts: dict[str, PromptModel] = Field(default_factory=dict)
     agents: dict[str, AgentModel] = Field(default_factory=dict)