dao-ai 0.1.2__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. dao_ai/apps/__init__.py +24 -0
  2. dao_ai/apps/handlers.py +105 -0
  3. dao_ai/apps/model_serving.py +29 -0
  4. dao_ai/apps/resources.py +1122 -0
  5. dao_ai/apps/server.py +39 -0
  6. dao_ai/cli.py +546 -37
  7. dao_ai/config.py +1179 -139
  8. dao_ai/evaluation.py +543 -0
  9. dao_ai/genie/__init__.py +55 -7
  10. dao_ai/genie/cache/__init__.py +34 -7
  11. dao_ai/genie/cache/base.py +143 -2
  12. dao_ai/genie/cache/context_aware/__init__.py +31 -0
  13. dao_ai/genie/cache/context_aware/base.py +1151 -0
  14. dao_ai/genie/cache/context_aware/in_memory.py +609 -0
  15. dao_ai/genie/cache/context_aware/persistent.py +802 -0
  16. dao_ai/genie/cache/context_aware/postgres.py +1166 -0
  17. dao_ai/genie/cache/core.py +1 -1
  18. dao_ai/genie/cache/lru.py +257 -75
  19. dao_ai/genie/cache/optimization.py +890 -0
  20. dao_ai/genie/core.py +235 -11
  21. dao_ai/memory/postgres.py +175 -39
  22. dao_ai/middleware/__init__.py +38 -0
  23. dao_ai/middleware/assertions.py +3 -3
  24. dao_ai/middleware/context_editing.py +230 -0
  25. dao_ai/middleware/core.py +4 -4
  26. dao_ai/middleware/guardrails.py +3 -3
  27. dao_ai/middleware/human_in_the_loop.py +3 -2
  28. dao_ai/middleware/message_validation.py +4 -4
  29. dao_ai/middleware/model_call_limit.py +77 -0
  30. dao_ai/middleware/model_retry.py +121 -0
  31. dao_ai/middleware/pii.py +157 -0
  32. dao_ai/middleware/summarization.py +1 -1
  33. dao_ai/middleware/tool_call_limit.py +210 -0
  34. dao_ai/middleware/tool_retry.py +174 -0
  35. dao_ai/middleware/tool_selector.py +129 -0
  36. dao_ai/models.py +327 -370
  37. dao_ai/nodes.py +9 -16
  38. dao_ai/orchestration/core.py +33 -9
  39. dao_ai/orchestration/supervisor.py +29 -13
  40. dao_ai/orchestration/swarm.py +6 -1
  41. dao_ai/{prompts.py → prompts/__init__.py} +12 -61
  42. dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
  43. dao_ai/prompts/instruction_reranker.yaml +14 -0
  44. dao_ai/prompts/router.yaml +37 -0
  45. dao_ai/prompts/verifier.yaml +46 -0
  46. dao_ai/providers/base.py +28 -2
  47. dao_ai/providers/databricks.py +363 -33
  48. dao_ai/state.py +1 -0
  49. dao_ai/tools/__init__.py +5 -3
  50. dao_ai/tools/genie.py +103 -26
  51. dao_ai/tools/instructed_retriever.py +366 -0
  52. dao_ai/tools/instruction_reranker.py +202 -0
  53. dao_ai/tools/mcp.py +539 -97
  54. dao_ai/tools/router.py +89 -0
  55. dao_ai/tools/slack.py +13 -2
  56. dao_ai/tools/sql.py +7 -3
  57. dao_ai/tools/unity_catalog.py +32 -10
  58. dao_ai/tools/vector_search.py +493 -160
  59. dao_ai/tools/verifier.py +159 -0
  60. dao_ai/utils.py +182 -2
  61. dao_ai/vector_search.py +46 -1
  62. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/METADATA +45 -9
  63. dao_ai-0.1.20.dist-info/RECORD +89 -0
  64. dao_ai/agent_as_code.py +0 -22
  65. dao_ai/genie/cache/semantic.py +0 -970
  66. dao_ai-0.1.2.dist-info/RECORD +0 -64
  67. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/WHEEL +0 -0
  68. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/entry_points.txt +0 -0
  69. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py CHANGED
@@ -7,6 +7,7 @@ from enum import Enum
 from os import PathLike
 from pathlib import Path
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Iterator,
@@ -18,12 +19,20 @@ from typing import (
     Union,
 )
 
+if TYPE_CHECKING:
+    from dao_ai.genie.cache.optimization import (
+        SemanticCacheEvalDataset,
+        ThresholdOptimizationResult,
+    )
+    from dao_ai.state import Context
+
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.credentials_provider import (
     CredentialsStrategy,
     ModelServingUserCredentials,
 )
 from databricks.sdk.errors.platform import NotFound
+from databricks.sdk.service.apps import App
 from databricks.sdk.service.catalog import FunctionInfo, TableInfo
 from databricks.sdk.service.dashboards import GenieSpace
 from databricks.sdk.service.database import DatabaseInstance
@@ -147,7 +156,7 @@ class PrimitiveVariableModel(BaseModel, HasValue):
         return str(value)
 
     @model_validator(mode="after")
-    def validate_value(self) -> "PrimitiveVariableModel":
+    def validate_value(self) -> Self:
         if not isinstance(self.as_value(), (str, int, float, bool)):
             raise ValueError("Value must be a primitive type (str, int, float, bool)")
         return self
@@ -207,7 +216,9 @@ class IsDatabricksResource(ABC, BaseModel):
     Authentication Options:
     ----------------------
     1. **On-Behalf-Of User (OBO)**: Set on_behalf_of_user=True to use the
-       calling user's identity via ModelServingUserCredentials.
+       calling user's identity. Implementation varies by deployment:
+       - Databricks Apps: Uses X-Forwarded-Access-Token from request headers
+       - Model Serving: Uses ModelServingUserCredentials
 
     2. **Service Principal (OAuth M2M)**: Provide service_principal or
        (client_id + client_secret + workspace_host) for service principal auth.
@@ -220,9 +231,17 @@ class IsDatabricksResource(ABC, BaseModel):
 
     Authentication Priority:
     1. OBO (on_behalf_of_user=True)
+       - Checks for forwarded headers (Databricks Apps)
+       - Falls back to ModelServingUserCredentials (Model Serving)
     2. Service Principal (client_id + client_secret + workspace_host)
    3. PAT (pat + workspace_host)
    4. Ambient/default authentication
+
+    Note: When on_behalf_of_user=True, the agent acts as the calling user regardless
+    of deployment target. In Databricks Apps, this uses X-Forwarded-Access-Token
+    automatically captured by MLflow AgentServer. In Model Serving, this uses
+    ModelServingUserCredentials. Forwarded headers are ONLY used when
+    on_behalf_of_user=True.
     """
 
     model_config = ConfigDict(use_enum_values=True)
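The precedence documented above reduces to a short first-match-wins chain. A minimal, self-contained sketch of that rule (the helper name and string labels here are hypothetical; the real logic is the `workspace_client` property changed in the hunks below):

```python
from typing import Optional

def resolve_auth_mode(
    on_behalf_of_user: bool,
    client_id: Optional[str],
    client_secret: Optional[str],
    pat: Optional[str],
) -> str:
    """Mirror the documented priority: OBO > service principal > PAT > ambient."""
    if on_behalf_of_user:
        # Forwarded headers in Databricks Apps, ModelServingUserCredentials in Model Serving
        return "obo"
    if client_id and client_secret:
        return "oauth-m2m"
    if pat:
        return "pat"
    return "ambient"

assert resolve_auth_mode(True, "id", "secret", "tok") == "obo"
assert resolve_auth_mode(False, None, None, "tok") == "pat"
```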
@@ -234,9 +253,6 @@ class IsDatabricksResource(ABC, BaseModel):
     workspace_host: Optional[AnyVariable] = None
     pat: Optional[AnyVariable] = None
 
-    # Private attribute to cache the workspace client (lazy instantiation)
-    _workspace_client: Optional[WorkspaceClient] = PrivateAttr(default=None)
-
     @abstractmethod
     def as_resources(self) -> Sequence[DatabricksResource]: ...
 
@@ -272,19 +288,16 @@ class IsDatabricksResource(ABC, BaseModel):
         """
         Get a WorkspaceClient configured with the appropriate authentication.
 
-        The client is lazily instantiated on first access and cached for subsequent calls.
+        A new client is created on each access.
 
         Authentication priority:
-        1. If on_behalf_of_user is True, uses ModelServingUserCredentials (OBO)
-        2. If service principal credentials are configured (client_id, client_secret,
-           workspace_host), uses OAuth M2M
-        3. If PAT is configured, uses token authentication
-        4. Otherwise, uses default/ambient authentication
+        1. On-Behalf-Of User (on_behalf_of_user=True):
+           - Uses ModelServingUserCredentials (Model Serving)
+           - For Databricks Apps with headers, use workspace_client_from(context)
+        2. Service Principal (client_id + client_secret + workspace_host)
+        3. PAT (pat + workspace_host)
+        4. Ambient/default authentication
         """
-        # Return cached client if already instantiated
-        if self._workspace_client is not None:
-            return self._workspace_client
-
         from dao_ai.utils import normalize_host
 
         # Check for OBO first (highest priority)
@@ -292,12 +305,9 @@ class IsDatabricksResource(ABC, BaseModel):
             credentials_strategy: CredentialsStrategy = ModelServingUserCredentials()
             logger.debug(
                 f"Creating WorkspaceClient for {self.__class__.__name__} "
-                f"with OBO credentials strategy"
-            )
-            self._workspace_client = WorkspaceClient(
-                credentials_strategy=credentials_strategy
+                f"with OBO credentials strategy (Model Serving)"
             )
-            return self._workspace_client
+            return WorkspaceClient(credentials_strategy=credentials_strategy)
 
         # Check for service principal credentials
         client_id_value: str | None = (
@@ -312,18 +322,24 @@ class IsDatabricksResource(ABC, BaseModel):
             else None
         )
 
-        if client_id_value and client_secret_value and workspace_host_value:
+        if client_id_value and client_secret_value:
+            # If workspace_host is not provided, check DATABRICKS_HOST env var first,
+            # then fall back to WorkspaceClient().config.host
+            if not workspace_host_value:
+                workspace_host_value = os.getenv("DATABRICKS_HOST")
+            if not workspace_host_value:
+                workspace_host_value = WorkspaceClient().config.host
+
             logger.debug(
                 f"Creating WorkspaceClient for {self.__class__.__name__} with service principal: "
                 f"client_id={client_id_value}, host={workspace_host_value}"
             )
-            self._workspace_client = WorkspaceClient(
+            return WorkspaceClient(
                 host=workspace_host_value,
                 client_id=client_id_value,
                 client_secret=client_secret_value,
                 auth_type="oauth-m2m",
             )
-            return self._workspace_client
 
         # Check for PAT authentication
         pat_value: str | None = value_of(self.pat) if self.pat else None
@@ -331,20 +347,83 @@ class IsDatabricksResource(ABC, BaseModel):
             logger.debug(
                 f"Creating WorkspaceClient for {self.__class__.__name__} with PAT"
             )
-            self._workspace_client = WorkspaceClient(
+            return WorkspaceClient(
                 host=workspace_host_value,
                 token=pat_value,
                 auth_type="pat",
             )
-            return self._workspace_client
 
         # Default: use ambient authentication
         logger.debug(
             f"Creating WorkspaceClient for {self.__class__.__name__} "
             "with default/ambient authentication"
         )
-        self._workspace_client = WorkspaceClient()
-        return self._workspace_client
+        return WorkspaceClient()
+
+    def workspace_client_from(self, context: "Context | None") -> WorkspaceClient:
+        """
+        Get a WorkspaceClient using headers from the provided Context.
+
+        Use this method from tools that have access to ToolRuntime[Context].
+        This allows OBO authentication to work in Databricks Apps where headers
+        are captured at request entry and passed through the Context.
+
+        Args:
+            context: Runtime context containing headers for OBO auth.
+                If None or no headers, falls back to workspace_client property.
+
+        Returns:
+            WorkspaceClient configured with appropriate authentication.
+        """
+        from dao_ai.utils import normalize_host
+
+        logger.trace(
+            "workspace_client_from called",
+            context=context,
+            on_behalf_of_user=self.on_behalf_of_user,
+        )
+
+        # Check if we have headers in context for OBO
+        if context and context.headers and self.on_behalf_of_user:
+            headers = context.headers
+            # Try both lowercase and title-case header names (HTTP headers are case-insensitive)
+            forwarded_token: str = headers.get(
+                "x-forwarded-access-token"
+            ) or headers.get("X-Forwarded-Access-Token")
+
+            if forwarded_token:
+                forwarded_user = headers.get("x-forwarded-user") or headers.get(
+                    "X-Forwarded-User", "unknown"
+                )
+                logger.debug(
+                    f"Creating WorkspaceClient for {self.__class__.__name__} "
+                    f"with OBO using forwarded token from Context",
+                    forwarded_user=forwarded_user,
+                )
+                # Use workspace_host if configured, otherwise SDK will auto-detect
+                workspace_host_value: str | None = (
+                    normalize_host(value_of(self.workspace_host))
+                    if self.workspace_host
+                    else None
+                )
+                return WorkspaceClient(
+                    host=workspace_host_value,
+                    token=forwarded_token,
+                    auth_type="pat",
+                )
+
+        # Fall back to existing workspace_client property
+        return self.workspace_client
+
+
+class DeploymentTarget(str, Enum):
+    """Target platform for agent deployment."""
+
+    MODEL_SERVING = "model_serving"
+    """Deploy to Databricks Model Serving endpoint."""
+
+    APPS = "apps"
+    """Deploy as a Databricks App."""
 
 
 class Privilege(str, Enum):
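The new `workspace_client_from` hook is intended for tool code that receives the runtime `Context`. A minimal usage sketch, assuming a resource configured with `on_behalf_of_user=True`; the tool function and its wiring are hypothetical, and only `workspace_client_from` and `Context.headers` come from the diff above:

```python
from databricks.sdk import WorkspaceClient

def whoami_tool(resource, context) -> str:
    # resource: any IsDatabricksResource subclass with on_behalf_of_user=True.
    # In Databricks Apps, context.headers carries X-Forwarded-Access-Token,
    # so the returned client acts as the calling user.
    client: WorkspaceClient = resource.workspace_client_from(context)
    me = client.current_user.me()
    return f"Acting as {me.user_name}"
```

When `context` is None or carries no headers, the call degrades to the plain `workspace_client` property, so the same tool code works unchanged in Model Serving.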
@@ -391,10 +470,17 @@ class PermissionModel(BaseModel):
 
 class SchemaModel(BaseModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    catalog_name: str
-    schema_name: str
+    catalog_name: AnyVariable
+    schema_name: AnyVariable
     permissions: Optional[list[PermissionModel]] = Field(default_factory=list)
 
+    @model_validator(mode="after")
+    def resolve_variables(self) -> Self:
+        """Resolve AnyVariable fields to their actual string values."""
+        self.catalog_name = value_of(self.catalog_name)
+        self.schema_name = value_of(self.schema_name)
+        return self
+
     @property
     def full_name(self) -> str:
         return f"{self.catalog_name}.{self.schema_name}"
@@ -408,9 +494,44 @@ class SchemaModel(BaseModel, HasFullName):
 
 
 class DatabricksAppModel(IsDatabricksResource, HasFullName):
+    """
+    Configuration for a Databricks App resource.
+
+    The `name` is the unique instance name of the Databricks App within the workspace.
+    The `url` is dynamically retrieved from the workspace client by calling
+    `apps.get(name)` and returning the app's URL.
+
+    Example:
+    ```yaml
+    resources:
+      apps:
+        my_app:
+          name: my-databricks-app
+    ```
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
-    url: str
+    """The unique instance name of the Databricks App in the workspace."""
+
+    @property
+    def url(self) -> str:
+        """
+        Retrieve the URL of the Databricks App from the workspace.
+
+        Returns:
+            The URL of the deployed Databricks App.
+
+        Raises:
+            RuntimeError: If the app is not found or URL is not available.
+        """
+        app: App = self.workspace_client.apps.get(self.name)
+        if app.url is None:
+            raise RuntimeError(
+                f"Databricks App '{self.name}' does not have a URL. "
+                "The app may not be deployed yet."
+            )
+        return app.url
 
     @property
     def full_name(self) -> str:
@@ -432,7 +553,7 @@ class TableModel(IsDatabricksResource, HasFullName):
     name: Optional[str] = None
 
     @model_validator(mode="after")
-    def validate_name_or_schema_required(self) -> "TableModel":
+    def validate_name_or_schema_required(self) -> Self:
         if not self.name and not self.schema_model:
             raise ValueError(
                 "Either 'name' or 'schema_model' must be provided for TableModel"
@@ -601,6 +722,8 @@ class VectorSearchEndpoint(BaseModel):
 
 
 class IndexModel(IsDatabricksResource, HasFullName):
+    """Model representing a Databricks Vector Search index."""
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: str
@@ -624,6 +747,22 @@ class IndexModel(IsDatabricksResource, HasFullName):
         )
     ]
 
+    def exists(self) -> bool:
+        """Check if this vector search index exists.
+
+        Returns:
+            True if the index exists, False otherwise.
+        """
+        try:
+            self.workspace_client.vector_search_indexes.get_index(self.full_name)
+            return True
+        except NotFound:
+            logger.debug(f"Index not found: {self.full_name}")
+            return False
+        except Exception as e:
+            logger.warning(f"Error checking index existence for {self.full_name}: {e}")
+            return False
+
 
 class FunctionModel(IsDatabricksResource, HasFullName):
@@ -699,11 +838,20 @@ class FunctionModel(IsDatabricksResource, HasFullName):
 
 
 class WarehouseModel(IsDatabricksResource):
-    model_config = ConfigDict()
-    name: str
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: Optional[str] = None
     description: Optional[str] = None
     warehouse_id: AnyVariable
 
+    _warehouse_details: Optional[GetWarehouseResponse] = PrivateAttr(default=None)
+
+    def _get_warehouse_details(self) -> GetWarehouseResponse:
+        if self._warehouse_details is None:
+            self._warehouse_details = self.workspace_client.warehouses.get(
+                id=value_of(self.warehouse_id)
+            )
+        return self._warehouse_details
+
     @property
     def api_scopes(self) -> Sequence[str]:
         return [
@@ -724,10 +872,22 @@ class WarehouseModel(IsDatabricksResource):
         self.warehouse_id = value_of(self.warehouse_id)
         return self
 
+    @model_validator(mode="after")
+    def populate_name(self) -> Self:
+        """Populate name from warehouse details if not provided."""
+        if self.warehouse_id and not self.name:
+            try:
+                warehouse_details = self._get_warehouse_details()
+                if warehouse_details.name:
+                    self.name = warehouse_details.name
+            except Exception as e:
+                logger.debug(f"Could not fetch details from warehouse: {e}")
+        return self
+
 
 class GenieRoomModel(IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    name: str
+    name: Optional[str] = None
     description: Optional[str] = None
     space_id: AnyVariable
 
@@ -783,10 +943,6 @@ class GenieRoomModel(IsDatabricksResource):
                 pat=self.pat,
             )
 
-            # Share the cached workspace client if available
-            if self._workspace_client is not None:
-                warehouse_model._workspace_client = self._workspace_client
-
             return warehouse_model
         except Exception as e:
             logger.warning(
@@ -830,9 +986,6 @@ class GenieRoomModel(IsDatabricksResource):
                     workspace_host=self.workspace_host,
                     pat=self.pat,
                 )
-                # Share the cached workspace client if available
-                if self._workspace_client is not None:
-                    table_model._workspace_client = self._workspace_client
 
                 # Verify the table exists before adding
                 if not table_model.exists():
@@ -870,9 +1023,6 @@ class GenieRoomModel(IsDatabricksResource):
                     workspace_host=self.workspace_host,
                     pat=self.pat,
                 )
-                # Share the cached workspace client if available
-                if self._workspace_client is not None:
-                    function_model._workspace_client = self._workspace_client
 
                 # Verify the function exists before adding
                 if not function_model.exists():
@@ -936,15 +1086,17 @@ class GenieRoomModel(IsDatabricksResource):
         return self
 
     @model_validator(mode="after")
-    def update_description_from_space(self) -> Self:
-        """Populate description from GenieSpace if not provided."""
-        if not self.description:
+    def populate_name_and_description(self) -> Self:
+        """Populate name and description from GenieSpace if not provided."""
+        if self.space_id and (not self.name or not self.description):
             try:
                 space_details = self._get_space_details()
-                if space_details.description:
+                if not self.name and space_details.title:
+                    self.name = space_details.title
+                if not self.description and space_details.description:
                     self.description = space_details.description
             except Exception as e:
-                logger.debug(f"Could not fetch description from Genie space: {e}")
+                logger.debug(f"Could not fetch details from Genie space: {e}")
         return self
 
 
@@ -980,7 +1132,7 @@ class VolumePathModel(BaseModel, HasFullName):
     path: Optional[str] = None
 
     @model_validator(mode="after")
-    def validate_path_or_volume(self) -> "VolumePathModel":
+    def validate_path_or_volume(self) -> Self:
         if not self.volume and not self.path:
             raise ValueError("Either 'volume' or 'path' must be provided")
         return self
@@ -1009,27 +1161,92 @@ class VolumePathModel(BaseModel, HasFullName):
 
 
 class VectorStoreModel(IsDatabricksResource):
+    """
+    Configuration model for a Databricks Vector Search store.
+
+    Supports two modes:
+    1. **Use Existing Index**: Provide only `index` (fully qualified name).
+       Used for querying an existing vector search index at runtime.
+    2. **Provisioning Mode**: Provide `source_table` + `embedding_source_column`.
+       Used for creating a new vector search index.
+
+    Examples:
+        Minimal configuration (use existing index):
+        ```yaml
+        vector_stores:
+          products_search:
+            index:
+              name: catalog.schema.my_index
+        ```
+
+        Full provisioning configuration:
+        ```yaml
+        vector_stores:
+          products_search:
+            source_table:
+              schema: *my_schema
+              name: products
+            embedding_source_column: description
+            endpoint:
+              name: my_endpoint
+        ```
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    embedding_model: Optional[LLMModel] = None
+
+    # RUNTIME: Only index is truly required for querying existing indexes
     index: Optional[IndexModel] = None
+
+    # PROVISIONING ONLY: Required when creating a new index
+    source_table: Optional[TableModel] = None
+    embedding_source_column: Optional[str] = None
+    embedding_model: Optional[LLMModel] = None
     endpoint: Optional[VectorSearchEndpoint] = None
-    source_table: TableModel
+
+    # OPTIONAL: For both modes
     source_path: Optional[VolumePathModel] = None
     checkpoint_path: Optional[VolumePathModel] = None
     primary_key: Optional[str] = None
     columns: Optional[list[str]] = Field(default_factory=list)
     doc_uri: Optional[str] = None
-    embedding_source_column: str
+
+    @model_validator(mode="after")
+    def validate_configuration_mode(self) -> Self:
+        """
+        Validate that configuration is valid for either:
+        - Use existing mode: index is provided
+        - Provisioning mode: source_table + embedding_source_column provided
+        """
+        has_index = self.index is not None
+        has_source_table = self.source_table is not None
+        has_embedding_col = self.embedding_source_column is not None
+
+        # Must have at least index OR source_table
+        if not has_index and not has_source_table:
+            raise ValueError(
+                "Either 'index' (for existing indexes) or 'source_table' "
+                "(for provisioning) must be provided"
+            )
+
+        # If provisioning mode, need embedding_source_column
+        if has_source_table and not has_embedding_col:
+            raise ValueError(
+                "embedding_source_column is required when source_table is provided (provisioning mode)"
+            )
+
+        return self
 
     @model_validator(mode="after")
     def set_default_embedding_model(self) -> Self:
-        if not self.embedding_model:
+        # Only set default embedding model in provisioning mode
+        if self.source_table is not None and not self.embedding_model:
            self.embedding_model = LLMModel(name="databricks-gte-large-en")
        return self
 
     @model_validator(mode="after")
     def set_default_primary_key(self) -> Self:
-        if self.primary_key is None:
+        # Only auto-discover primary key in provisioning mode
+        if self.primary_key is None and self.source_table is not None:
             from dao_ai.providers.databricks import DatabricksProvider
 
             provider: DatabricksProvider = DatabricksProvider()
@@ -1050,14 +1267,16 @@ class VectorStoreModel(IsDatabricksResource):
 
     @model_validator(mode="after")
     def set_default_index(self) -> Self:
-        if self.index is None:
+        # Only generate index from source_table in provisioning mode
+        if self.index is None and self.source_table is not None:
             name: str = f"{self.source_table.name}_index"
             self.index = IndexModel(schema=self.source_table.schema_model, name=name)
         return self
 
     @model_validator(mode="after")
     def set_default_endpoint(self) -> Self:
-        if self.endpoint is None:
+        # Only find/create endpoint in provisioning mode
+        if self.endpoint is None and self.source_table is not None:
             from dao_ai.providers.databricks import (
                 DatabricksProvider,
                 with_available_indexes,
@@ -1092,18 +1311,60 @@ class VectorStoreModel(IsDatabricksResource):
         return self.index.as_resources()
 
     def as_index(self, vsc: VectorSearchClient | None = None) -> VectorSearchIndex:
-        from dao_ai.providers.base import ServiceProvider
         from dao_ai.providers.databricks import DatabricksProvider
 
-        provider: ServiceProvider = DatabricksProvider(vsc=vsc)
+        provider: DatabricksProvider = DatabricksProvider(vsc=vsc)
         index: VectorSearchIndex = provider.get_vector_index(self)
         return index
 
     def create(self, vsc: VectorSearchClient | None = None) -> None:
-        from dao_ai.providers.base import ServiceProvider
+        """
+        Create or validate the vector search index.
+
+        Behavior depends on configuration mode:
+        - **Provisioning Mode** (source_table provided): Creates the index
+        - **Use Existing Mode** (only index provided): Validates the index exists
+
+        Args:
+            vsc: Optional VectorSearchClient instance
+
+        Raises:
+            ValueError: If configuration is invalid or index doesn't exist
+        """
         from dao_ai.providers.databricks import DatabricksProvider
 
-        provider: ServiceProvider = DatabricksProvider(vsc=vsc)
+        provider: DatabricksProvider = DatabricksProvider(vsc=vsc)
+
+        if self.source_table is not None:
+            self._create_new_index(provider)
+        else:
+            self._validate_existing_index(provider)
+
+    def _validate_existing_index(self, provider: Any) -> None:
+        """Validate that an existing index is accessible."""
+        if self.index is None:
+            raise ValueError("index is required for 'use existing' mode")
+
+        if self.index.exists():
+            logger.info(
+                "Vector search index exists and ready",
+                index_name=self.index.full_name,
+            )
+        else:
+            raise ValueError(
+                f"Index '{self.index.full_name}' does not exist. "
+                "Provide 'source_table' to provision it."
+            )
+
+    def _create_new_index(self, provider: Any) -> None:
+        """Create a new vector search index from source table."""
+        if self.embedding_source_column is None:
+            raise ValueError("embedding_source_column is required for provisioning")
+        if self.endpoint is None:
+            raise ValueError("endpoint is required for provisioning")
+        if self.index is None:
+            raise ValueError("index is required for provisioning")
+
         provider.create_vector_store(self)
 
 
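`create()` now dispatches on configuration mode rather than always provisioning. A minimal usage sketch of the two modes, assuming models constructed directly instead of parsed from YAML; the names are illustrative and the `create()` call requires workspace credentials:

```python
from dao_ai.config import IndexModel, VectorStoreModel

# Use-existing mode: only an index is configured, so create() just validates it
# and raises ValueError if the index is missing from the workspace.
existing = VectorStoreModel(index=IndexModel(name="my_index", schema=my_schema))
existing.create()

# Provisioning mode instead requires source_table + embedding_source_column,
# and create() builds the index via DatabricksProvider.create_vector_store().
```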
@@ -1145,13 +1406,20 @@ class DatabaseModel(IsDatabricksResource):
     - Databricks Lakebase: Provide `instance_name` (authentication optional, supports ambient auth)
     - Standard PostgreSQL: Provide `host` (authentication required via user/password)
 
-    Note: `instance_name` and `host` are mutually exclusive. Provide one or the other.
+    Note: For Lakebase connections, `name` is optional and defaults to `instance_name`.
+    For PostgreSQL connections, `name` is required.
+
+    Example Databricks Lakebase (minimal):
+    ```yaml
+    databases:
+      my_lakebase:
+        instance_name: my-lakebase-instance  # name defaults to instance_name
+    ```
 
     Example Databricks Lakebase with Service Principal:
     ```yaml
     databases:
       my_lakebase:
-        name: my-database
         instance_name: my-lakebase-instance
         service_principal:
           client_id:
@@ -1167,7 +1435,6 @@ class DatabaseModel(IsDatabricksResource):
     ```yaml
     databases:
       my_lakebase:
-        name: my-database
         instance_name: my-lakebase-instance
         on_behalf_of_user: true
     ```
@@ -1187,7 +1454,7 @@ class DatabaseModel(IsDatabricksResource):
     """
 
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    name: str
+    name: Optional[str] = None
     instance_name: Optional[str] = None
     description: Optional[str] = None
     host: Optional[AnyVariable] = None
@@ -1236,6 +1503,17 @@ class DatabaseModel(IsDatabricksResource):
         )
         return self
 
+    @model_validator(mode="after")
+    def populate_name_from_instance_name(self) -> Self:
+        """Populate name from instance_name if not provided for Lakebase connections."""
+        if self.name is None and self.instance_name:
+            self.name = self.instance_name
+        elif self.name is None:
+            raise ValueError(
+                "Either 'name' or 'instance_name' must be provided for DatabaseModel."
+            )
+        return self
+
     @model_validator(mode="after")
     def update_user(self) -> Self:
         # Skip if using OBO (passive auth), explicit credentials, or explicit user
@@ -1266,32 +1544,12 @@ class DatabaseModel(IsDatabricksResource):
 
     @model_validator(mode="after")
     def update_host(self) -> Self:
-        if self.host is not None:
+        # Lakebase uses instance_name directly via databricks_langchain - host not needed
+        if self.is_lakebase:
             return self
 
-        # If instance_name is provided (Lakebase), try to fetch host from existing instance
-        # This may fail for OBO/ambient auth during model logging (before deployment)
-        if self.is_lakebase:
-            try:
-                existing_instance: DatabaseInstance = (
-                    self.workspace_client.database.get_database_instance(
-                        name=self.instance_name
-                    )
-                )
-                self.host = existing_instance.read_write_dns
-            except Exception as e:
-                # For Lakebase with OBO/ambient auth, we can't fetch at config time
-                # The host will need to be provided explicitly or fetched at runtime
-                if self.on_behalf_of_user:
-                    logger.debug(
-                        f"Could not fetch host for database {self.instance_name} "
-                        f"(Lakebase with OBO mode - will be resolved at runtime): {e}"
-                    )
-                else:
-                    raise ValueError(
-                        f"Could not fetch host for database {self.instance_name}. "
-                        f"Please provide the 'host' explicitly or ensure the instance exists: {e}"
-                    )
+        # For standard PostgreSQL, host must be provided by the user
+        # (enforced by validate_connection_type)
         return self
 
     @model_validator(mode="after")
@@ -1353,10 +1611,10 @@ class DatabaseModel(IsDatabricksResource):
         username: str | None = None
         password_value: str | None = None
 
-        # Resolve host - may need to fetch at runtime for OBO mode
+        # Resolve host - fetch from API at runtime for Lakebase if not provided
         host_value: Any = self.host
-        if host_value is None and self.is_lakebase and self.on_behalf_of_user:
-            # Fetch host at runtime for OBO mode
+        if host_value is None and self.is_lakebase:
+            # Fetch host from Lakebase instance API
             existing_instance: DatabaseInstance = (
                 self.workspace_client.database.get_database_instance(
                     name=self.instance_name
@@ -1456,7 +1714,7 @@ class GenieLRUCacheParametersModel(BaseModel):
     warehouse: WarehouseModel
 
 
-class GenieSemanticCacheParametersModel(BaseModel):
+class GenieContextAwareCacheParametersModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     time_to_live_seconds: int | None = (
         60 * 60 * 24
@@ -1474,6 +1732,116 @@ class GenieSemanticCacheParametersModel(BaseModel):
     database: DatabaseModel
     warehouse: WarehouseModel
     table_name: str = "genie_semantic_cache"
+    context_window_size: int = 2  # Number of previous turns to include for context
+    max_context_tokens: int = (
+        2000  # Maximum context length to prevent extremely long embeddings
+    )
+    # Prompt history configuration
+    # Prompt history is always enabled - it stores all user prompts to maintain
+    # conversation context for accurate semantic matching even when cache hits occur
+    prompt_history_table: str = "genie_prompt_history"  # Table name for prompt history
+    max_prompt_history_length: int = 50  # Maximum prompts to keep per conversation
+    use_genie_api_for_history: bool = (
+        False  # Fallback to Genie API if local history empty
+    )
+    prompt_history_ttl_seconds: int | None = (
+        None  # TTL for prompts (None = use cache TTL)
+    )
+
+    @model_validator(mode="after")
+    def compute_and_validate_weights(self) -> Self:
+        """
+        Compute missing weight and validate that question_weight + context_weight = 1.0.
+
+        Either question_weight or context_weight (or both) can be provided.
+        The missing one will be computed as 1.0 - provided_weight.
+        If both are provided, they must sum to 1.0.
+        """
+        if self.question_weight is None and self.context_weight is None:
+            # Both missing - use defaults
+            self.question_weight = 0.6
+            self.context_weight = 0.4
+        elif self.question_weight is None:
+            # Compute question_weight from context_weight
+            if not (0.0 <= self.context_weight <= 1.0):
+                raise ValueError(
+                    f"context_weight must be between 0.0 and 1.0, got {self.context_weight}"
+                )
+            self.question_weight = 1.0 - self.context_weight
+        elif self.context_weight is None:
+            # Compute context_weight from question_weight
+            if not (0.0 <= self.question_weight <= 1.0):
+                raise ValueError(
+                    f"question_weight must be between 0.0 and 1.0, got {self.question_weight}"
+                )
+            self.context_weight = 1.0 - self.question_weight
+        else:
+            # Both provided - validate they sum to 1.0
+            total_weight = self.question_weight + self.context_weight
+            if not abs(total_weight - 1.0) < 0.0001:  # Allow small floating point error
+                raise ValueError(
+                    f"question_weight ({self.question_weight}) + context_weight ({self.context_weight}) "
+                    f"must equal 1.0 (got {total_weight}). These weights determine the relative importance "
+                    f"of question vs context similarity in the combined score."
+                )
+
+        return self
+
+
+# Memory estimation for capacity planning:
+# - Each entry: ~20KB (8KB question embedding + 8KB context embedding + 4KB strings/overhead)
+# - 1,000 entries: ~20MB (0.4% of 8GB)
+# - 5,000 entries: ~100MB (2% of 8GB)
+# - 10,000 entries: ~200MB (4-5% of 8GB) - default for ~30 users
+# - 20,000 entries: ~400MB (8-10% of 8GB)
+# Default 10,000 entries provides ~330 queries per user for 30 users.
+class GenieInMemorySemanticCacheParametersModel(BaseModel):
+    """
+    Configuration for in-memory semantic cache (no database required).
+
+    This cache stores embeddings and cache entries entirely in memory, providing
+    semantic similarity matching without requiring external database dependencies
+    like PostgreSQL or Databricks Lakebase.
+
+    Default settings are tuned for ~30 users on an 8GB machine:
+    - Capacity: 10,000 entries (~200MB memory, ~330 queries per user)
+    - Eviction: LRU (Least Recently Used) - keeps frequently accessed queries
+    - TTL: 1 week (accommodates weekly work patterns and batch jobs)
+    - Memory overhead: ~4-5% of 8GB system
+
+    The LRU eviction strategy ensures hot queries stay cached while cold queries
+    are evicted, providing better hit rates than FIFO eviction.
+
+    For larger deployments or memory-constrained environments, adjust capacity and TTL accordingly.
+
+    Use this when:
+    - No external database access is available
+    - Single-instance deployments (cache not shared across instances)
+    - Cache persistence across restarts is not required
+    - Cache sizes are moderate (hundreds to low thousands of entries)
+
+    For multi-instance deployments or large cache sizes, use GenieContextAwareCacheParametersModel
+    with PostgreSQL backend instead.
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    time_to_live_seconds: int | None = (
+        60 * 60 * 24 * 7
+    )  # 1 week default (604800 seconds), None or negative = never expires
+    similarity_threshold: float = 0.85  # Minimum similarity for question matching (L2 distance converted to 0-1 scale)
+    context_similarity_threshold: float = 0.80  # Minimum similarity for context matching (L2 distance converted to 0-1 scale)
+    question_weight: Optional[float] = (
+        0.6  # Weight for question similarity in combined score (0-1). If not provided, computed as 1 - context_weight
+    )
+    context_weight: Optional[float] = (
+        None  # Weight for context similarity in combined score (0-1). If not provided, computed as 1 - question_weight
+    )
+    embedding_model: str | LLMModel = "databricks-gte-large-en"
+    embedding_dims: int | None = None  # Auto-detected if None
+    warehouse: WarehouseModel
+    capacity: int | None = (
+        10000  # Maximum cache entries. ~200MB for 10000 entries (1024-dim embeddings). LRU eviction when full. None = unlimited (not recommended for production).
+    )
     context_window_size: int = 3  # Number of previous turns to include for context
     max_context_tokens: int = (
         2000  # Maximum context length to prevent extremely long embeddings
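The weight handling in `compute_and_validate_weights` reduces to one rule: a missing weight is the complement of the provided one. A standalone sketch of that rule (not the pydantic validator itself):

```python
def resolve_weights(question_weight=None, context_weight=None):
    """Missing weight = 1.0 - the other; both missing = documented defaults."""
    if question_weight is None and context_weight is None:
        return 0.6, 0.4
    if question_weight is None:
        return 1.0 - context_weight, context_weight
    if context_weight is None:
        return question_weight, 1.0 - question_weight
    assert abs(question_weight + context_weight - 1.0) < 0.0001
    return question_weight, context_weight

print(resolve_weights(context_weight=0.3))  # (0.7, 0.3)

# Capacity-planning arithmetic behind the comments above: a 1024-dim float64
# embedding is 1024 * 8 bytes = 8 KB, so two embeddings plus ~4 KB of strings
# and overhead gives the ~20 KB/entry estimate, i.e. ~200 MB for 10,000 entries.
```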
@@ -1526,41 +1894,83 @@ class SearchParametersModel(BaseModel):
     query_type: Optional[str] = "ANN"
 
 
+class InstructionAwareRerankModel(BaseModel):
+    """
+    LLM-based reranking considering user instructions and constraints.
+
+    Use fast models (GPT-3.5, Haiku, Llama 3 8B) to minimize latency (~100ms).
+    Runs AFTER FlashRank as an additional constraint-aware reranking stage.
+    Skipped for 'standard' mode when auto_bypass=true in router config.
+
+    Example:
+    ```yaml
+    rerank:
+      model: ms-marco-MiniLM-L-12-v2
+      top_n: 20
+      instruction_aware:
+        model: *fast_llm
+        instructions: |
+          Prioritize results matching price and brand constraints.
+        top_n: 10
+    ```
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+
+    model: Optional["LLMModel"] = Field(
+        default=None,
+        description="LLM for instruction reranking (fast model recommended)",
+    )
+    instructions: Optional[str] = Field(
+        default=None,
+        description="Custom reranking instructions for constraint prioritization",
+    )
+    top_n: Optional[int] = Field(
+        default=None,
+        description="Number of documents to return after instruction reranking",
+    )
+
+
 class RerankParametersModel(BaseModel):
     """
-    Configuration for reranking retrieved documents using FlashRank.
+    Configuration for reranking retrieved documents.
 
-    FlashRank provides fast, local reranking without API calls using lightweight
-    cross-encoder models. Reranking improves retrieval quality by reordering results
-    based on semantic relevance to the query.
+    Supports three reranking options that can be combined:
+    1. FlashRank (local cross-encoder) - set `model`
+    2. Databricks server-side reranking - set `columns`
+    3. LLM instruction-aware reranking - set `instruction_aware`
 
-    Typical workflow:
-    1. Retrieve more documents than needed (e.g., 50 via num_results)
-    2. Rerank all retrieved documents
-    3. Return top_n best matches (e.g., 5)
+    Example with Databricks columns + instruction-aware (no FlashRank):
+    ```yaml
+    rerank:
+      columns:  # Databricks server-side reranking
+        - product_name
+        - brand_name
+      instruction_aware:  # LLM-based constraint reranking
+        model: *fast_llm
+        instructions: "Prioritize by brand preferences"
+        top_n: 10
+    ```
 
-    Example:
+    Example with FlashRank:
     ```yaml
-    retriever:
-      search_parameters:
-        num_results: 50  # Retrieve more candidates
-      rerank:
-        model: ms-marco-MiniLM-L-12-v2
-        top_n: 5  # Return top 5 after reranking
+    rerank:
+      model: ms-marco-MiniLM-L-12-v2  # FlashRank model
+      top_n: 10
     ```
 
-    Available models (from fastest to most accurate):
-    - "ms-marco-TinyBERT-L-2-v2" (fastest, smallest)
-    - "ms-marco-MiniLM-L-6-v2"
-    - "ms-marco-MiniLM-L-12-v2" (default, good balance)
-    - "rank-T5-flan" (most accurate, slower)
+    Available FlashRank models (see https://github.com/PrithivirajDamodaran/FlashRank):
+    - "ms-marco-TinyBERT-L-2-v2" (~4MB, fastest)
+    - "ms-marco-MiniLM-L-12-v2" (~34MB, best cross-encoder)
+    - "rank-T5-flan" (~110MB, best non cross-encoder)
+    - "ms-marco-MultiBERT-L-12" (~150MB, multilingual 100+ languages)
     """
 
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
 
-    model: str = Field(
-        default="ms-marco-MiniLM-L-12-v2",
-        description="FlashRank model name. Default provides good balance of speed and accuracy.",
+    model: Optional[str] = Field(
+        default=None,
+        description="FlashRank model name. If None, FlashRank is not used (use columns for Databricks reranking).",
     )
     top_n: Optional[int] = Field(
         default=None,
@@ -1573,6 +1983,289 @@ class RerankParametersModel(BaseModel):
     )
     columns: Optional[list[str]] = Field(
         default_factory=list, description="Columns to rerank using DatabricksReranker"
     )
+    instruction_aware: Optional[InstructionAwareRerankModel] = Field(
+        default=None,
+        description="Optional LLM-based reranking stage after FlashRank",
+    )
+
+
+class FilterItem(BaseModel):
+    """A metadata filter for vector search.
+
+    Filters constrain search results by matching column values.
+    Use column names from the provided schema description.
+    """
+
+    model_config = ConfigDict(extra="forbid")
+    key: str = Field(
+        description=(
+            "Column name with optional operator suffix. "
+            "Operators: (none) for equality, NOT for exclusion, "
+            "< <= > >= for numeric comparison, "
+            "LIKE for token match, NOT LIKE to exclude tokens."
+        )
+    )
+    value: Union[str, int, float, bool, list[Union[str, int, float, bool]]] = Field(
+        description=(
+            "The filter value matching the column type. "
+            "Use an array for IN-style matching multiple values."
+        )
+    )
+
+
+class SearchQuery(BaseModel):
+    """A single search query with optional metadata filters.
+
+    Represents one focused search intent extracted from the user's request.
+    The text should be a natural language query optimized for semantic search.
+    Filters constrain results to match specific metadata values.
+    """
+
+    model_config = ConfigDict(extra="forbid")
+    text: str = Field(
+        description=(
+            "Natural language search query text optimized for semantic similarity. "
+            "Should be focused on a single search intent. "
+            "Do NOT include filter criteria in the text; use the filters field instead."
+        )
+    )
+    filters: Optional[list[FilterItem]] = Field(
+        default=None,
+        description=(
+            "Metadata filters to constrain search results. "
+            "Set to null if no filters apply. "
+            "Extract filter values from explicit constraints in the user query."
+        ),
+    )
+
+
+class DecomposedQueries(BaseModel):
+    """Decomposed search queries extracted from a user request.
+
+    Break down complex user queries into multiple focused search queries.
+    Each query targets a distinct search intent with appropriate filters.
+    Generate 1-3 queries depending on the complexity of the user request.
+    """
+
+    model_config = ConfigDict(extra="forbid")
+    queries: list[SearchQuery] = Field(
+        description=(
+            "List of search queries extracted from the user request. "
+            "Each query should target a distinct search intent. "
+            "Order queries by importance, with the most relevant first."
+        )
+    )
+
+
+class ColumnInfo(BaseModel):
+    """Column metadata for dynamic schema generation in structured output.
+
+    When provided, column information is embedded directly into the JSON schema
+    that with_structured_output sends to the LLM, improving filter accuracy.
+    """
+
+    model_config = ConfigDict(extra="forbid")
+
+    name: str = Field(description="Column name as it appears in the database")
+    type: Literal["string", "number", "boolean", "datetime"] = Field(
+        default="string",
+        description="Column data type for value validation",
+    )
+    operators: list[str] = Field(
+        default=["", "NOT", "<", "<=", ">", ">=", "LIKE", "NOT LIKE"],
+        description="Valid filter operators for this column",
+    )
+
+
+class InstructedRetrieverModel(BaseModel):
+    """
+    Configuration for instructed retrieval with query decomposition and RRF merging.
+
+    Instructed retrieval decomposes user queries into multiple subqueries with
+    metadata filters, executes them in parallel, and merges results using
+    Reciprocal Rank Fusion (RRF) before reranking.
+
+    Example:
+    ```yaml
+    retriever:
+      vector_store: *products_vector_store
+      instructed:
+        decomposition_model: *fast_llm
+        schema_description: |
+          Products table: product_id, brand_name, category, price, updated_at
+          Filter operators: {"col": val}, {"col >": val}, {"col NOT": val}
+        columns:
+          - name: brand_name
+            type: string
+          - name: price
+            type: number
+            operators: ["", "<", "<=", ">", ">="]
+        constraints:
+          - "Prefer recent products"
+        max_subqueries: 3
+        examples:
+          - query: "cheap drills"
+            filters: {"price <": 100}
+    ```
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+
+    decomposition_model: Optional["LLMModel"] = Field(
+        default=None,
+        description="LLM for query decomposition (smaller/faster model recommended)",
+    )
+    schema_description: str = Field(
+        description="Column names, types, and valid filter syntax for the LLM"
+    )
+    columns: Optional[list[ColumnInfo]] = Field(
+        default=None,
+        description=(
+            "Structured column info for dynamic schema generation. "
+            "When provided, column names are embedded in the JSON schema for better LLM accuracy."
+        ),
+    )
+    constraints: Optional[list[str]] = Field(
+        default=None, description="Default constraints to always apply"
+    )
+    max_subqueries: int = Field(
+        default=3, description="Maximum number of parallel subqueries"
+    )
+    rrf_k: int = Field(
+        default=60,
+        description="RRF constant (lower values weight top ranks more heavily)",
+    )
+    examples: Optional[list[dict[str, Any]]] = Field(
+        default=None,
+        description="Few-shot examples for domain-specific filter translation",
+    )
+    normalize_filter_case: Optional[Literal["uppercase", "lowercase"]] = Field(
+        default=None,
+        description="Auto-normalize filter string values to uppercase or lowercase",
+    )
+
+
+class RouterModel(BaseModel):
+    """
+    Select internal execution mode based on query characteristics.
+
+    Use fast models (GPT-3.5, Haiku, Llama 3 8B) to minimize latency (~50-100ms).
+    Routes to internal modes within the same retriever, not external retrievers.
+    Cross-index routing belongs at the agent/tool-selection level.
+
+    Execution Modes:
+    - "standard": Single similarity_search() for simple keyword/product searches
+    - "instructed": Decompose -> Parallel Search -> RRF for constrained queries
+
+    Example:
+    ```yaml
+    retriever:
+      router:
+        model: *fast_llm
+        default_mode: standard
+        auto_bypass: true
+    ```
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+
+    model: Optional["LLMModel"] = Field(
+        default=None,
+        description="LLM for routing decision (fast model recommended)",
+    )
+    default_mode: Literal["standard", "instructed"] = Field(
+        default="standard",
+        description="Fallback mode if routing fails",
+    )
+    auto_bypass: bool = Field(
+        default=True,
+        description="Skip Instruction Reranker and Verifier for standard mode",
+    )
+
+
+class VerificationResult(BaseModel):
+    """Verification of whether search results satisfy the user's constraints.
+
+    Analyze the retrieved results against the original query and any explicit
+    constraints to determine if a retry with modified filters is needed.
+    """
+
+    model_config = ConfigDict(extra="forbid")
+
+    passed: bool = Field(
+        description="True if results satisfy the user's query intent and constraints."
+    )
+    confidence: float = Field(
+        ge=0.0,
+        le=1.0,
+        description="Confidence in the verification decision, from 0.0 (uncertain) to 1.0 (certain).",
+    )
+    feedback: Optional[str] = Field(
+        default=None,
+        description="Explanation of why verification passed or failed. Include specific issues found.",
+    )
+    suggested_filter_relaxation: Optional[dict[str, Any]] = Field(
+        default=None,
+        description=(
+            "Suggested filter modifications for retry. "
+            "Keys are column names, values indicate changes (e.g., 'REMOVE', 'WIDEN', or new values)."
+        ),
+    )
+    unmet_constraints: Optional[list[str]] = Field(
+        default=None,
+        description="List of user constraints that the results failed to satisfy.",
+    )
+
+
+class VerifierModel(BaseModel):
+    """
+    Validate results against user constraints with structured feedback.
+
+    Use fast models (GPT-3.5, Haiku, Llama 3 8B) to minimize latency (~50-100ms).
+    Skipped for 'standard' mode when auto_bypass=true in router config.
+    Returns structured feedback for intelligent retry, not blind retry.
+
+    Example:
+    ```yaml
+    retriever:
+      verifier:
+        model: *fast_llm
+        on_failure: warn_and_retry
+        max_retries: 1
+    ```
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+
+    model: Optional["LLMModel"] = Field(
+        default=None,
+        description="LLM for verification (fast model recommended)",
+    )
+    on_failure: Literal["warn", "retry", "warn_and_retry"] = Field(
+        default="warn",
+        description="Behavior when verification fails",
+    )
+    max_retries: int = Field(
+        default=1,
+        description="Maximum retry attempts before returning with warning",
+    )
+
+
+class RankedDocument(BaseModel):
+    """Single ranked document."""
+
+    index: int = Field(description="Document index from input list")
+    score: float = Field(description="0.0-1.0 relevance score")
+    reason: str = Field(default="", description="Why this score")
+
+
+class RankingResult(BaseModel):
+    """Reranking output."""
+
+    rankings: list[RankedDocument] = Field(
+        default_factory=list,
+        description="Ranked documents, highest score first",
+    )
 
 
 class RetrieverModel(BaseModel):
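Reciprocal Rank Fusion, the merge step controlled by `rrf_k` above, is compact enough to show inline. A minimal sketch of the algorithm (illustrative only; the shipped implementation lives in dao_ai/tools/instructed_retriever.py and may differ in detail):

```python
from collections import defaultdict

def rrf_merge(result_lists: list[list[str]], k: int = 60) -> list[str]:
    """Merge ranked lists: score(d) = sum over lists of 1 / (k + rank)."""
    scores: dict[str, float] = defaultdict(float)
    for results in result_lists:
        for rank, doc_id in enumerate(results, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)

# Two subqueries agree on "doc-7", so it wins despite never ranking first:
print(rrf_merge([["doc-1", "doc-7"], ["doc-2", "doc-7"]], k=60))
```

Lower `k` makes top ranks dominate; the default of 60 is the commonly used value from the original RRF paper.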
@@ -1582,10 +2275,22 @@ class RetrieverModel(BaseModel):
1582
2275
  search_parameters: SearchParametersModel = Field(
1583
2276
  default_factory=SearchParametersModel
1584
2277
  )
2278
+ router: Optional[RouterModel] = Field(
2279
+ default=None,
2280
+ description="Optional query router for selecting execution mode (standard vs instructed).",
2281
+ )
1585
2282
  rerank: Optional[RerankParametersModel | bool] = Field(
1586
2283
  default=None,
1587
2284
  description="Optional reranking configuration. Set to true for defaults, or provide ReRankParametersModel for custom settings.",
1588
2285
  )
2286
+ instructed: Optional[InstructedRetrieverModel] = Field(
2287
+ default=None,
2288
+ description="Optional instructed retrieval with query decomposition and RRF merging.",
2289
+ )
2290
+ verifier: Optional[VerifierModel] = Field(
2291
+ default=None,
2292
+ description="Optional result verification with structured feedback for retry.",
2293
+ )
1589
2294
 
1590
2295
  @model_validator(mode="after")
1591
2296
  def set_default_columns(self) -> Self:
@@ -1596,9 +2301,13 @@ class RetrieverModel(BaseModel):
1596
2301
 
1597
2302
  @model_validator(mode="after")
1598
2303
  def set_default_reranker(self) -> Self:
1599
- """Convert bool to ReRankParametersModel with defaults."""
2304
+ """Convert bool to ReRankParametersModel with defaults.
2305
+
2306
+ When rerank: true is used, sets the default FlashRank model
2307
+ (ms-marco-MiniLM-L-12-v2) to enable reranking.
2308
+ """
1600
2309
  if isinstance(self.rerank, bool) and self.rerank:
1601
- self.rerank = RerankParametersModel()
2310
+ self.rerank = RerankParametersModel(model="ms-marco-MiniLM-L-12-v2")
1602
2311
  return self
1603
2312
 
1604
2313
 
@@ -1731,11 +2440,32 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
     headers: dict[str, AnyVariable] = Field(default_factory=dict)
     args: list[str] = Field(default_factory=list)
     # MCP-specific fields
+    app: Optional[DatabricksAppModel] = None
     connection: Optional[ConnectionModel] = None
     functions: Optional[SchemaModel] = None
     genie_room: Optional[GenieRoomModel] = None
     sql: Optional[bool] = None
     vector_search: Optional[VectorStoreModel] = None
+    # Tool filtering
+    include_tools: Optional[list[str]] = Field(
+        default=None,
+        description=(
+            "Optional list of tool names or glob patterns to include from the MCP server. "
+            "If specified, only tools matching these patterns will be loaded. "
+            "Supports glob patterns: * (any chars), ? (single char), [abc] (char set). "
+            "Examples: ['execute_query', 'list_*', 'get_?_data']"
+        ),
+    )
+    exclude_tools: Optional[list[str]] = Field(
+        default=None,
+        description=(
+            "Optional list of tool names or glob patterns to exclude from the MCP server. "
+            "Tools matching these patterns will not be loaded. "
+            "Takes precedence over include_tools. "
+            "Supports glob patterns: * (any chars), ? (single char), [abc] (char set). "
+            "Examples: ['drop_*', 'delete_*', 'execute_ddl']"
+        ),
+    )

     @property
     def api_scopes(self) -> Sequence[str]:
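The glob semantics in these descriptions map directly onto Python's fnmatch; the following sketch illustrates the documented filtering rule (exclude takes precedence over include) and is not the package's actual loader code:

    from fnmatch import fnmatch

    def tool_allowed(
        name: str,
        include_tools: list[str] | None = None,
        exclude_tools: list[str] | None = None,
    ) -> bool:
        # exclude_tools takes precedence over include_tools
        if exclude_tools and any(fnmatch(name, pat) for pat in exclude_tools):
            return False
        # no include list means every non-excluded tool is loaded
        if include_tools is None:
            return True
        return any(fnmatch(name, pat) for pat in include_tools)

    assert tool_allowed("list_tables", include_tools=["list_*"])
    assert not tool_allowed("drop_table", include_tools=["*"], exclude_tools=["drop_*"])
    assert not tool_allowed("anything", include_tools=[])  # empty include loads nothing

Note how an explicit empty include list admits nothing, which is exactly the situation the validator further down warns about.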
@@ -1798,6 +2528,7 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
 
         Returns the URL based on the configured source:
         - If url is set, returns it directly
+        - If app is set, retrieves URL from Databricks App via workspace client
         - If connection is set, constructs URL from connection
         - If genie_room is set, constructs Genie MCP URL
         - If sql is set, constructs DBSQL MCP URL (serverless)
@@ -1810,6 +2541,7 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
         - Vector Search: https://{host}/api/2.0/mcp/vector-search/{catalog}/{schema}
         - UC Functions: https://{host}/api/2.0/mcp/functions/{catalog}/{schema}
         - Connection: https://{host}/api/2.0/mcp/external/{connection_name}
+        - Databricks App: Retrieved dynamically from workspace
         """
         # Direct URL provided
         if self.url:
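The documented endpoint shapes are plain string construction; for example (host, catalog, schema, and connection values are invented):

    workspace_host = "https://example.cloud.databricks.com"  # illustrative host
    catalog, schema = "main", "sales"

    vector_search_url = f"{workspace_host}/api/2.0/mcp/vector-search/{catalog}/{schema}"
    functions_url = f"{workspace_host}/api/2.0/mcp/functions/{catalog}/{schema}"
    connection_url = f"{workspace_host}/api/2.0/mcp/external/my_connection"

    assert vector_search_url.endswith("/mcp/vector-search/main/sales")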
@@ -1832,6 +2564,49 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
         if self.sql:
             return f"{workspace_host}/api/2.0/mcp/sql"
 
+        # Databricks App - MCP endpoint is at {app_url}/mcp
+        # Try McpFunctionModel's workspace_client first (which may have credentials),
+        # then fall back to DatabricksAppModel.url property (which uses its own workspace_client)
+        if self.app:
+            from databricks.sdk.service.apps import App
+
+            app_url: str | None = None
+
+            # First, try using McpFunctionModel's workspace_client
+            try:
+                app: App = self.workspace_client.apps.get(self.app.name)
+                app_url = app.url
+                logger.trace(
+                    "Got app URL using McpFunctionModel workspace_client",
+                    app_name=self.app.name,
+                    url=app_url,
+                )
+            except Exception as e:
+                logger.debug(
+                    "Failed to get app URL using McpFunctionModel workspace_client, "
+                    "trying DatabricksAppModel.url property",
+                    app_name=self.app.name,
+                    error=str(e),
+                )
+
+            # Fall back to DatabricksAppModel.url property
+            if not app_url:
+                try:
+                    app_url = self.app.url
+                    logger.trace(
+                        "Got app URL using DatabricksAppModel.url property",
+                        app_name=self.app.name,
+                        url=app_url,
+                    )
+                except Exception as e:
+                    raise RuntimeError(
+                        f"Databricks App '{self.app.name}' does not have a URL. "
+                        "The app may not be deployed yet, or credentials may be invalid. "
+                        f"Error: {e}"
+                    ) from e
+
+            return f"{app_url.rstrip('/')}/mcp"
+
         # Vector Search
         if self.vector_search:
             if (
@@ -1841,33 +2616,35 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
                 raise ValueError(
                     "vector_search must have an index with a schema (catalog/schema) configured"
                 )
-            catalog: str = self.vector_search.index.schema_model.catalog_name
-            schema: str = self.vector_search.index.schema_model.schema_name
+            catalog: str = value_of(self.vector_search.index.schema_model.catalog_name)
+            schema: str = value_of(self.vector_search.index.schema_model.schema_name)
             return f"{workspace_host}/api/2.0/mcp/vector-search/{catalog}/{schema}"
 
         # UC Functions MCP server
         if self.functions:
-            catalog: str = self.functions.catalog_name
-            schema: str = self.functions.schema_name
+            catalog: str = value_of(self.functions.catalog_name)
+            schema: str = value_of(self.functions.schema_name)
             return f"{workspace_host}/api/2.0/mcp/functions/{catalog}/{schema}"
 
         raise ValueError(
-            "No URL source configured. Provide one of: url, connection, genie_room, "
+            "No URL source configured. Provide one of: url, app, connection, genie_room, "
             "sql, vector_search, or functions"
         )
 
     @field_serializer("transport")
-    def serialize_transport(self, value) -> str:
+    def serialize_transport(self, value: TransportType) -> str:
+        """Serialize transport enum to string."""
         if isinstance(value, TransportType):
             return value.value
         return str(value)
 
     @model_validator(mode="after")
-    def validate_mutually_exclusive(self) -> "McpFunctionModel":
+    def validate_mutually_exclusive(self) -> Self:
         """Validate that exactly one URL source is provided."""
         # Count how many URL sources are provided
         url_sources: list[tuple[str, Any]] = [
             ("url", self.url),
+            ("app", self.app),
             ("connection", self.connection),
             ("genie_room", self.genie_room),
             ("sql", self.sql),
@@ -1883,13 +2660,13 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
         if len(provided_sources) == 0:
             raise ValueError(
                 "For STREAMABLE_HTTP transport, exactly one of the following must be provided: "
-                "url, connection, genie_room, sql, vector_search, or functions"
+                "url, app, connection, genie_room, sql, vector_search, or functions"
             )
         if len(provided_sources) > 1:
             raise ValueError(
                 f"For STREAMABLE_HTTP transport, only one URL source can be provided. "
                 f"Found: {', '.join(provided_sources)}. "
-                f"Please provide only one of: url, connection, genie_room, sql, vector_search, or functions"
+                f"Please provide only one of: url, app, connection, genie_room, sql, vector_search, or functions"
             )
 
         if self.transport == TransportType.STDIO:
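A compact sketch of the exactly-one-source rule the validator enforces; the helper below is illustrative, with field names taken from the diff:

    from typing import Any

    def check_one_source(**sources: Any) -> str:
        # Keep only sources that were actually provided (None means unset)
        provided = [name for name, value in sources.items() if value is not None]
        if len(provided) != 1:
            raise ValueError(f"Exactly one URL source required, found: {provided or 'none'}")
        return provided[0]

    assert check_one_source(url=None, app=None, genie_room="room-123", sql=None) == "genie_room"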
@@ -1901,14 +2678,41 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
         return self
 
     @model_validator(mode="after")
-    def update_url(self) -> "McpFunctionModel":
-        self.url = value_of(self.url)
+    def update_url(self) -> Self:
+        """Resolve AnyVariable to concrete value for URL."""
+        if self.url is not None:
+            resolved_value: Any = value_of(self.url)
+            # Cast to string since URL must be a string
+            self.url = str(resolved_value) if resolved_value else None
         return self
 
     @model_validator(mode="after")
-    def update_headers(self) -> "McpFunctionModel":
+    def update_headers(self) -> Self:
+        """Resolve AnyVariable to concrete values for headers."""
         for key, value in self.headers.items():
-            self.headers[key] = value_of(value)
+            resolved_value: Any = value_of(value)
+            # Headers must be strings
+            self.headers[key] = str(resolved_value) if resolved_value else ""
+        return self
+
+    @model_validator(mode="after")
+    def validate_tool_filters(self) -> Self:
+        """Validate tool filter configuration."""
+        from loguru import logger
+
+        # Warn when a filter was set to an explicit empty list (allowed but pointless)
+        if self.include_tools is not None and len(self.include_tools) == 0:
+            logger.warning(
+                "include_tools is empty list - no tools will be loaded. "
+                "Remove field to load all tools."
+            )
+
+        if self.exclude_tools is not None and len(self.exclude_tools) == 0:
+            logger.warning(
+                "exclude_tools is empty list - has no effect. "
+                "Remove field or add patterns."
+            )
+
         return self
 
     def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
@@ -2316,7 +3120,6 @@ class SupervisorModel(BaseModel):
 
 class SwarmModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    model: LLMModel
     default_agent: Optional[AgentModel | str] = None
     middleware: list[MiddlewareModel] = Field(
         default_factory=list,
@@ -2330,11 +3133,17 @@ class SwarmModel(BaseModel):
 class OrchestrationModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     supervisor: Optional[SupervisorModel] = None
-    swarm: Optional[SwarmModel] = None
+    swarm: Optional[SwarmModel | Literal[True]] = None
     memory: Optional[MemoryModel] = None
 
     @model_validator(mode="after")
-    def validate_mutually_exclusive(self) -> Self:
+    def validate_and_normalize(self) -> Self:
+        """Validate orchestration and normalize swarm shorthand."""
+        # Convert swarm: true to SwarmModel()
+        if self.swarm is True:
+            self.swarm = SwarmModel()
+
+        # Validate mutually exclusive
         if self.supervisor is not None and self.swarm is not None:
             raise ValueError("Cannot specify both supervisor and swarm")
         if self.supervisor is None and self.swarm is None:
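With the Literal[True] union, `swarm: true` in YAML (or swarm=True in code) now normalizes to a default SwarmModel. A quick sketch, assuming the dao-ai package is installed and these classes are importable from dao_ai.config:

    from dao_ai.config import OrchestrationModel, SwarmModel

    # The swarm: true shorthand is normalized by the validator into a default SwarmModel
    orchestration = OrchestrationModel(swarm=True)
    assert isinstance(orchestration.swarm, SwarmModel)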
@@ -2544,6 +3353,11 @@ class AppModel(BaseModel):
         "which is supported by Databricks Model Serving. This allows deploying from "
         "environments with different Python versions (e.g., Databricks Apps with 3.11).",
     )
+    deployment_target: Optional[DeploymentTarget] = Field(
+        default=None,
+        description="Default deployment target. If not specified, defaults to MODEL_SERVING. "
+        "Can be overridden via CLI --target flag. Options: 'model_serving' or 'apps'.",
+    )
 
     @model_validator(mode="after")
     def set_databricks_env_vars(self) -> Self:
@@ -2601,9 +3415,7 @@ class AppModel(BaseModel):
         elif len(self.agents) == 1:
             default_agent: AgentModel = self.agents[0]
             self.orchestration = OrchestrationModel(
-                swarm=SwarmModel(
-                    model=default_agent.model, default_agent=default_agent
-                )
+                swarm=SwarmModel(default_agent=default_agent)
             )
         else:
             raise ValueError("At least one agent must be specified")
@@ -2643,8 +3455,24 @@ class GuidelineModel(BaseModel):
 
 
 class EvaluationModel(BaseModel):
+    """
+    Configuration for MLflow GenAI evaluation.
+
+    Attributes:
+        model: LLM model used as the judge for LLM-based scorers (e.g., Guidelines, Safety).
+            This model evaluates agent responses during evaluation.
+        table: Table to store evaluation results.
+        num_evals: Number of evaluation samples to generate.
+        agent_description: Description of the agent for evaluation data generation.
+        question_guidelines: Guidelines for generating evaluation questions.
+        custom_inputs: Custom inputs to pass to the agent during evaluation.
+        guidelines: List of guideline configurations for Guidelines scorers.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    model: LLMModel
+    model: LLMModel = Field(
+        ..., description="LLM model used as the judge for LLM-based evaluation scorers"
+    )
     table: TableModel
     num_evals: int
     agent_description: Optional[str] = None
@@ -2652,6 +3480,16 @@ class EvaluationModel(BaseModel):
     custom_inputs: dict[str, Any] = Field(default_factory=dict)
     guidelines: list[GuidelineModel] = Field(default_factory=list)
 
+    @property
+    def judge_model_endpoint(self) -> str:
+        """
+        Get the judge model endpoint string for MLflow scorers.
+
+        Returns:
+            Endpoint string in format 'databricks:/model-name'
+        """
+        return f"databricks:/{self.model.name}"
+
 
 class EvaluationDatasetExpectationsModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
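The endpoint convention the property encodes is a simple prefix; a trivial standalone illustration (the model name is an example):

    def judge_model_endpoint(model_name: str) -> str:
        # Mirrors the property above: judge endpoints use a databricks:/ prefix
        return f"databricks:/{model_name}"

    assert (
        judge_model_endpoint("databricks-meta-llama-3-3-70b-instruct")
        == "databricks:/databricks-meta-llama-3-3-70b-instruct"
    )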
@@ -2849,6 +3687,165 @@ class OptimizationsModel(BaseModel):
         return results
 
 
+class SemanticCacheEvalEntryModel(BaseModel):
+    """Single evaluation entry for semantic cache threshold optimization.
+
+    Represents a pair of question/context combinations to evaluate
+    whether the cache should return a hit or miss.
+
+    Example:
+        entry:
+          question: "What are total sales?"
+          question_embedding: [0.1, 0.2, ...]  # Pre-computed
+          context: "Previous: Show me revenue"
+          context_embedding: [0.1, 0.2, ...]
+          cached_question: "Show total sales"
+          cached_question_embedding: [0.1, 0.2, ...]
+          cached_context: "Previous: Show me revenue"
+          cached_context_embedding: [0.1, 0.2, ...]
+          expected_match: true
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    question: str
+    question_embedding: list[float]
+    context: str = ""
+    context_embedding: list[float] = Field(default_factory=list)
+    cached_question: str
+    cached_question_embedding: list[float]
+    cached_context: str = ""
+    cached_context_embedding: list[float] = Field(default_factory=list)
+    expected_match: Optional[bool] = None  # None = use LLM judge
+
+
+class SemanticCacheEvalDatasetModel(BaseModel):
+    """Dataset for semantic cache threshold optimization.
+
+    Contains pairs of questions/contexts to evaluate whether thresholds
+    correctly identify semantic matches.
+
+    Example:
+        dataset:
+          name: my_cache_eval_dataset
+          description: "Evaluation data for cache tuning"
+          entries:
+            - question: "What are total sales?"
+              # ... entry fields
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    description: str = ""
+    entries: list[SemanticCacheEvalEntryModel] = Field(default_factory=list)
+
+    def as_eval_dataset(self) -> "SemanticCacheEvalDataset":
+        """Convert to internal evaluation dataset format."""
+        from dao_ai.genie.cache.optimization import (
+            SemanticCacheEvalDataset,
+            SemanticCacheEvalEntry,
+        )
+
+        entries = [
+            SemanticCacheEvalEntry(
+                question=e.question,
+                question_embedding=e.question_embedding,
+                context=e.context,
+                context_embedding=e.context_embedding,
+                cached_question=e.cached_question,
+                cached_question_embedding=e.cached_question_embedding,
+                cached_context=e.cached_context,
+                cached_context_embedding=e.cached_context_embedding,
+                expected_match=e.expected_match,
+            )
+            for e in self.entries
+        ]
+
+        return SemanticCacheEvalDataset(
+            name=self.name,
+            entries=entries,
+            description=self.description,
+        )
+
+
+class SemanticCacheThresholdOptimizationModel(BaseModel):
+    """Configuration for semantic cache threshold optimization.
+
+    Uses Optuna Bayesian optimization to find optimal threshold values
+    that maximize cache hit accuracy (F1 score by default).
+
+    Example:
+        threshold_optimization:
+          name: optimize_cache_thresholds
+          cache_parameters: *my_cache_params
+          dataset: *my_eval_dataset
+          judge_model: databricks-meta-llama-3-3-70b-instruct
+          n_trials: 50
+          metric: f1
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    cache_parameters: Optional[GenieContextAwareCacheParametersModel] = None
+    dataset: SemanticCacheEvalDatasetModel
+    judge_model: Optional[LLMModel | str] = "databricks-meta-llama-3-3-70b-instruct"
+    n_trials: int = 50
+    metric: Literal["f1", "precision", "recall", "fbeta"] = "f1"
+    beta: float = 1.0  # For fbeta metric
+    seed: Optional[int] = None
+
+    def optimize(
+        self, w: WorkspaceClient | None = None
+    ) -> "ThresholdOptimizationResult":
+        """
+        Optimize semantic cache thresholds.
+
+        Args:
+            w: Optional WorkspaceClient (not used, kept for API compatibility)
+
+        Returns:
+            ThresholdOptimizationResult with optimized thresholds
+        """
+        from dao_ai.genie.cache.optimization import (
+            ThresholdOptimizationResult,
+            optimize_semantic_cache_thresholds,
+        )
+
+        # Convert dataset
+        eval_dataset = self.dataset.as_eval_dataset()
+
+        # Get original thresholds from cache_parameters
+        original_thresholds: dict[str, float] | None = None
+        if self.cache_parameters:
+            original_thresholds = {
+                "similarity_threshold": self.cache_parameters.similarity_threshold,
+                "context_similarity_threshold": self.cache_parameters.context_similarity_threshold,
+                "question_weight": self.cache_parameters.question_weight or 0.6,
+            }
+
+        # Get judge model
+        judge_model_name: str
+        if isinstance(self.judge_model, str):
+            judge_model_name = self.judge_model
+        elif self.judge_model:
+            judge_model_name = self.judge_model.uri
+        else:
+            judge_model_name = "databricks-meta-llama-3-3-70b-instruct"
+
+        result: ThresholdOptimizationResult = optimize_semantic_cache_thresholds(
+            dataset=eval_dataset,
+            original_thresholds=original_thresholds,
+            judge_model=judge_model_name,
+            n_trials=self.n_trials,
+            metric=self.metric,
+            beta=self.beta,
+            register_if_improved=True,
+            study_name=self.name,
+            seed=self.seed,
+        )
+
+        return result
+
+
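Putting the three models together, a hedged usage sketch; it assumes dao-ai and its Optuna dependency are installed, and uses placeholder embeddings in place of real pre-computed ones:

    from dao_ai.config import (
        SemanticCacheEvalDatasetModel,
        SemanticCacheEvalEntryModel,
        SemanticCacheThresholdOptimizationModel,
    )

    # Placeholder 4-dim embeddings; real entries would use the cache's embedding model
    entry = SemanticCacheEvalEntryModel(
        question="What are total sales?",
        question_embedding=[0.1, 0.2, 0.3, 0.4],
        cached_question="Show total sales",
        cached_question_embedding=[0.11, 0.19, 0.31, 0.39],
        expected_match=True,  # labeled, so no LLM judge is needed for this entry
    )
    dataset = SemanticCacheEvalDatasetModel(name="cache_eval", entries=[entry])

    optimization = SemanticCacheThresholdOptimizationModel(
        name="optimize_cache_thresholds", dataset=dataset, n_trials=10, metric="f1"
    )
    result = optimization.optimize()  # runs Optuna trials; may call the judge model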
 class DatasetFormat(str, Enum):
     CSV = "csv"
     DELTA = "delta"
@@ -3024,6 +4021,7 @@ class ResourcesModel(BaseModel):
 
 class AppConfig(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    version: Optional[str] = None
     variables: dict[str, AnyVariable] = Field(default_factory=dict)
     service_principals: dict[str, ServicePrincipalModel] = Field(default_factory=dict)
     schemas: dict[str, SchemaModel] = Field(default_factory=dict)
@@ -3044,6 +4042,9 @@ class AppConfig(BaseModel):
     )
     providers: Optional[dict[type | str, Any]] = None
 
+    # Private attribute to track the source config file path (set by from_file)
+    _source_config_path: str | None = None
+
     @classmethod
     def from_file(cls, path: PathLike) -> "AppConfig":
         path = Path(path).as_posix()
@@ -3051,12 +4052,20 @@ class AppConfig(BaseModel):
         model_config: ModelConfig = ModelConfig(development_config=path)
         config: AppConfig = AppConfig(**model_config.to_dict())
 
+        # Store the source config path for later use (e.g., Apps deployment)
+        config._source_config_path = path
+
         config.initialize()
 
         atexit.register(config.shutdown)
 
         return config
 
+    @property
+    def source_config_path(self) -> str | None:
+        """Get the source config file path if loaded via from_file."""
+        return self._source_config_path
+
     def initialize(self) -> None:
         from dao_ai.hooks.core import create_hooks
         from dao_ai.logging import configure_logging
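Minimal usage of the new accessor, assuming dao-ai is installed and a config file exists at the (hypothetical) path below:

    from dao_ai.config import AppConfig

    config = AppConfig.from_file("configs/agent.yaml")  # hypothetical path
    # The loader records where the config came from, e.g. for Apps deployment
    assert config.source_config_path == "configs/agent.yaml"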
@@ -3127,6 +4136,7 @@ class AppConfig(BaseModel):
 
     def deploy_agent(
         self,
+        target: DeploymentTarget | None = None,
         w: WorkspaceClient | None = None,
         vsc: "VectorSearchClient | None" = None,
         pat: str | None = None,
@@ -3134,9 +4144,39 @@ class AppConfig(BaseModel):
         client_secret: str | None = None,
         workspace_host: str | None = None,
     ) -> None:
+        """
+        Deploy the agent to the specified target.
+
+        Target resolution follows this priority:
+        1. Explicit `target` parameter (if provided)
+        2. `app.deployment_target` from config file (if set)
+        3. Default: MODEL_SERVING
+
+        Args:
+            target: The deployment target (MODEL_SERVING or APPS). If None, uses
+                config.app.deployment_target or defaults to MODEL_SERVING.
+            w: Optional WorkspaceClient instance
+            vsc: Optional VectorSearchClient instance
+            pat: Optional personal access token for authentication
+            client_id: Optional client ID for service principal authentication
+            client_secret: Optional client secret for service principal authentication
+            workspace_host: Optional workspace host URL
+        """
         from dao_ai.providers.base import ServiceProvider
         from dao_ai.providers.databricks import DatabricksProvider
 
+        # Resolve target using hybrid logic:
+        # 1. Explicit parameter takes precedence
+        # 2. Fall back to config.app.deployment_target
+        # 3. Default to MODEL_SERVING
+        resolved_target: DeploymentTarget
+        if target is not None:
+            resolved_target = target
+        elif self.app is not None and self.app.deployment_target is not None:
+            resolved_target = self.app.deployment_target
+        else:
+            resolved_target = DeploymentTarget.MODEL_SERVING
+
         provider: ServiceProvider = DatabricksProvider(
             w=w,
             vsc=vsc,
@@ -3145,7 +4185,7 @@ class AppConfig(BaseModel):
             client_secret=client_secret,
             workspace_host=workspace_host,
         )
-        provider.deploy_agent(self)
+        provider.deploy_agent(self, target=resolved_target)
 
     def find_agents(
         self, predicate: Callable[[AgentModel], bool] | None = None
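The resolution order (explicit argument, then config default, then MODEL_SERVING) can be sketched standalone; the enum values below are taken from the deployment_target field description, and the stand-in enum is not the package's actual DeploymentTarget:

    from enum import Enum

    class DeploymentTarget(str, Enum):
        MODEL_SERVING = "model_serving"
        APPS = "apps"

    def resolve_target(
        explicit: DeploymentTarget | None,
        configured: DeploymentTarget | None,
    ) -> DeploymentTarget:
        # Explicit call-site/CLI argument wins, then the config default, then MODEL_SERVING.
        # Enum members are always truthy, so `or` chaining matches the None checks above.
        return explicit or configured or DeploymentTarget.MODEL_SERVING

    assert resolve_target(None, None) is DeploymentTarget.MODEL_SERVING
    assert resolve_target(None, DeploymentTarget.APPS) is DeploymentTarget.APPS
    assert (
        resolve_target(DeploymentTarget.MODEL_SERVING, DeploymentTarget.APPS)
        is DeploymentTarget.MODEL_SERVING
    )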