dao-ai 0.1.5__py3-none-any.whl → 0.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/apps/__init__.py +24 -0
- dao_ai/apps/handlers.py +105 -0
- dao_ai/apps/model_serving.py +29 -0
- dao_ai/apps/resources.py +1122 -0
- dao_ai/apps/server.py +39 -0
- dao_ai/cli.py +446 -16
- dao_ai/config.py +1034 -103
- dao_ai/evaluation.py +543 -0
- dao_ai/genie/__init__.py +55 -7
- dao_ai/genie/cache/__init__.py +34 -7
- dao_ai/genie/cache/base.py +143 -2
- dao_ai/genie/cache/context_aware/__init__.py +31 -0
- dao_ai/genie/cache/context_aware/base.py +1151 -0
- dao_ai/genie/cache/context_aware/in_memory.py +609 -0
- dao_ai/genie/cache/context_aware/persistent.py +802 -0
- dao_ai/genie/cache/context_aware/postgres.py +1166 -0
- dao_ai/genie/cache/core.py +1 -1
- dao_ai/genie/cache/lru.py +257 -75
- dao_ai/genie/cache/optimization.py +890 -0
- dao_ai/genie/core.py +235 -11
- dao_ai/memory/postgres.py +175 -39
- dao_ai/middleware/__init__.py +5 -0
- dao_ai/middleware/tool_selector.py +129 -0
- dao_ai/models.py +327 -370
- dao_ai/nodes.py +4 -4
- dao_ai/orchestration/core.py +33 -9
- dao_ai/orchestration/supervisor.py +23 -8
- dao_ai/orchestration/swarm.py +6 -1
- dao_ai/{prompts.py → prompts/__init__.py} +12 -61
- dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
- dao_ai/prompts/instruction_reranker.yaml +14 -0
- dao_ai/prompts/router.yaml +37 -0
- dao_ai/prompts/verifier.yaml +46 -0
- dao_ai/providers/base.py +28 -2
- dao_ai/providers/databricks.py +352 -33
- dao_ai/state.py +1 -0
- dao_ai/tools/__init__.py +5 -3
- dao_ai/tools/genie.py +103 -26
- dao_ai/tools/instructed_retriever.py +366 -0
- dao_ai/tools/instruction_reranker.py +202 -0
- dao_ai/tools/mcp.py +539 -97
- dao_ai/tools/router.py +89 -0
- dao_ai/tools/slack.py +13 -2
- dao_ai/tools/sql.py +7 -3
- dao_ai/tools/unity_catalog.py +32 -10
- dao_ai/tools/vector_search.py +493 -160
- dao_ai/tools/verifier.py +159 -0
- dao_ai/utils.py +182 -2
- dao_ai/vector_search.py +9 -1
- {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/METADATA +10 -8
- dao_ai-0.1.20.dist-info/RECORD +89 -0
- dao_ai/agent_as_code.py +0 -22
- dao_ai/genie/cache/semantic.py +0 -970
- dao_ai-0.1.5.dist-info/RECORD +0 -70
- {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py
CHANGED
@@ -7,6 +7,7 @@ from enum import Enum
 from os import PathLike
 from pathlib import Path
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Iterator,
@@ -18,12 +19,20 @@ from typing import (
     Union,
 )
 
+if TYPE_CHECKING:
+    from dao_ai.genie.cache.optimization import (
+        SemanticCacheEvalDataset,
+        ThresholdOptimizationResult,
+    )
+    from dao_ai.state import Context
+
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.credentials_provider import (
     CredentialsStrategy,
     ModelServingUserCredentials,
 )
 from databricks.sdk.errors.platform import NotFound
+from databricks.sdk.service.apps import App
 from databricks.sdk.service.catalog import FunctionInfo, TableInfo
 from databricks.sdk.service.dashboards import GenieSpace
 from databricks.sdk.service.database import DatabaseInstance
@@ -147,7 +156,7 @@ class PrimitiveVariableModel(BaseModel, HasValue):
         return str(value)
 
     @model_validator(mode="after")
-    def validate_value(self) ->
+    def validate_value(self) -> Self:
         if not isinstance(self.as_value(), (str, int, float, bool)):
             raise ValueError("Value must be a primitive type (str, int, float, bool)")
         return self
@@ -207,7 +216,9 @@ class IsDatabricksResource(ABC, BaseModel):
     Authentication Options:
     ----------------------
     1. **On-Behalf-Of User (OBO)**: Set on_behalf_of_user=True to use the
-       calling user's identity
+       calling user's identity. Implementation varies by deployment:
+       - Databricks Apps: Uses X-Forwarded-Access-Token from request headers
+       - Model Serving: Uses ModelServingUserCredentials
 
     2. **Service Principal (OAuth M2M)**: Provide service_principal or
        (client_id + client_secret + workspace_host) for service principal auth.
@@ -220,9 +231,17 @@ class IsDatabricksResource(ABC, BaseModel):
 
     Authentication Priority:
     1. OBO (on_behalf_of_user=True)
+       - Checks for forwarded headers (Databricks Apps)
+       - Falls back to ModelServingUserCredentials (Model Serving)
     2. Service Principal (client_id + client_secret + workspace_host)
     3. PAT (pat + workspace_host)
     4. Ambient/default authentication
+
+    Note: When on_behalf_of_user=True, the agent acts as the calling user regardless
+    of deployment target. In Databricks Apps, this uses X-Forwarded-Access-Token
+    automatically captured by MLflow AgentServer. In Model Serving, this uses
+    ModelServingUserCredentials. Forwarded headers are ONLY used when
+    on_behalf_of_user=True.
     """
 
     model_config = ConfigDict(use_enum_values=True)
@@ -234,9 +253,6 @@ class IsDatabricksResource(ABC, BaseModel):
     workspace_host: Optional[AnyVariable] = None
     pat: Optional[AnyVariable] = None
 
-    # Private attribute to cache the workspace client (lazy instantiation)
-    _workspace_client: Optional[WorkspaceClient] = PrivateAttr(default=None)
-
     @abstractmethod
     def as_resources(self) -> Sequence[DatabricksResource]: ...
 
@@ -272,19 +288,16 @@ class IsDatabricksResource(ABC, BaseModel):
         """
         Get a WorkspaceClient configured with the appropriate authentication.
 
-
+        A new client is created on each access.
 
         Authentication priority:
-        1.
-
-
-
-
+        1. On-Behalf-Of User (on_behalf_of_user=True):
+           - Uses ModelServingUserCredentials (Model Serving)
+           - For Databricks Apps with headers, use workspace_client_from(context)
+        2. Service Principal (client_id + client_secret + workspace_host)
+        3. PAT (pat + workspace_host)
+        4. Ambient/default authentication
         """
-        # Return cached client if already instantiated
-        if self._workspace_client is not None:
-            return self._workspace_client
-
         from dao_ai.utils import normalize_host
 
         # Check for OBO first (highest priority)
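
Note: the `workspace_client` property no longer caches; every access constructs a fresh `WorkspaceClient`. A minimal sketch of the implication for callers, assuming `resource` is any `IsDatabricksResource` subclass instance and using an illustrative Unity Catalog call:

```python
# Sketch: hold one client instead of re-reading the property in a loop,
# since each property access now builds a new WorkspaceClient.
client = resource.workspace_client  # construct once
for full_name in ["main.sales.orders", "main.sales.customers"]:
    client.tables.get(full_name=full_name)  # reuse the same client
```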
@@ -292,12 +305,9 @@ class IsDatabricksResource(ABC, BaseModel):
             credentials_strategy: CredentialsStrategy = ModelServingUserCredentials()
             logger.debug(
                 f"Creating WorkspaceClient for {self.__class__.__name__} "
-                f"with OBO credentials strategy"
-            )
-            self._workspace_client = WorkspaceClient(
-                credentials_strategy=credentials_strategy
+                f"with OBO credentials strategy (Model Serving)"
             )
-            return
+            return WorkspaceClient(credentials_strategy=credentials_strategy)
 
         # Check for service principal credentials
         client_id_value: str | None = (
@@ -312,18 +322,24 @@ class IsDatabricksResource(ABC, BaseModel):
             else None
         )
 
-        if client_id_value and client_secret_value
+        if client_id_value and client_secret_value:
+            # If workspace_host is not provided, check DATABRICKS_HOST env var first,
+            # then fall back to WorkspaceClient().config.host
+            if not workspace_host_value:
+                workspace_host_value = os.getenv("DATABRICKS_HOST")
+            if not workspace_host_value:
+                workspace_host_value = WorkspaceClient().config.host
+
             logger.debug(
                 f"Creating WorkspaceClient for {self.__class__.__name__} with service principal: "
                 f"client_id={client_id_value}, host={workspace_host_value}"
            )
-
+            return WorkspaceClient(
                 host=workspace_host_value,
                 client_id=client_id_value,
                 client_secret=client_secret_value,
                 auth_type="oauth-m2m",
             )
-            return self._workspace_client
 
         # Check for PAT authentication
         pat_value: str | None = value_of(self.pat) if self.pat else None
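
Note: the service-principal branch now resolves a missing `workspace_host` in a fixed order. A sketch of that order (calling `WorkspaceClient()` here requires ambient Databricks configuration, as in the real code path):

```python
import os

from databricks.sdk import WorkspaceClient


def resolve_host(configured: str | None) -> str | None:
    # Mirrors the fallback added in this hunk: explicit config value,
    # then the DATABRICKS_HOST environment variable, then the host the
    # SDK resolves from ambient configuration.
    return (
        configured
        or os.getenv("DATABRICKS_HOST")
        or WorkspaceClient().config.host
    )
```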
@@ -331,20 +347,83 @@ class IsDatabricksResource(ABC, BaseModel):
             logger.debug(
                 f"Creating WorkspaceClient for {self.__class__.__name__} with PAT"
             )
-
+            return WorkspaceClient(
                 host=workspace_host_value,
                 token=pat_value,
                 auth_type="pat",
             )
-            return self._workspace_client
 
         # Default: use ambient authentication
         logger.debug(
             f"Creating WorkspaceClient for {self.__class__.__name__} "
             "with default/ambient authentication"
         )
-
-
+        return WorkspaceClient()
+
+    def workspace_client_from(self, context: "Context | None") -> WorkspaceClient:
+        """
+        Get a WorkspaceClient using headers from the provided Context.
+
+        Use this method from tools that have access to ToolRuntime[Context].
+        This allows OBO authentication to work in Databricks Apps where headers
+        are captured at request entry and passed through the Context.
+
+        Args:
+            context: Runtime context containing headers for OBO auth.
+                If None or no headers, falls back to workspace_client property.
+
+        Returns:
+            WorkspaceClient configured with appropriate authentication.
+        """
+        from dao_ai.utils import normalize_host
+
+        logger.trace(
+            "workspace_client_from called",
+            context=context,
+            on_behalf_of_user=self.on_behalf_of_user,
+        )
+
+        # Check if we have headers in context for OBO
+        if context and context.headers and self.on_behalf_of_user:
+            headers = context.headers
+            # Try both lowercase and title-case header names (HTTP headers are case-insensitive)
+            forwarded_token: str = headers.get(
+                "x-forwarded-access-token"
+            ) or headers.get("X-Forwarded-Access-Token")
+
+            if forwarded_token:
+                forwarded_user = headers.get("x-forwarded-user") or headers.get(
+                    "X-Forwarded-User", "unknown"
+                )
+                logger.debug(
+                    f"Creating WorkspaceClient for {self.__class__.__name__} "
+                    f"with OBO using forwarded token from Context",
+                    forwarded_user=forwarded_user,
+                )
+                # Use workspace_host if configured, otherwise SDK will auto-detect
+                workspace_host_value: str | None = (
+                    normalize_host(value_of(self.workspace_host))
+                    if self.workspace_host
+                    else None
+                )
+                return WorkspaceClient(
+                    host=workspace_host_value,
+                    token=forwarded_token,
+                    auth_type="pat",
+                )
+
+        # Fall back to existing workspace_client property
+        return self.workspace_client
+
+
+class DeploymentTarget(str, Enum):
+    """Target platform for agent deployment."""
+
+    MODEL_SERVING = "model_serving"
+    """Deploy to Databricks Model Serving endpoint."""
+
+    APPS = "apps"
+    """Deploy as a Databricks App."""
 
 
 class Privilege(str, Enum):
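
Note: the new `workspace_client_from` hook is what lets OBO work inside Databricks Apps. A hedged usage sketch, assuming `resource` is an `IsDatabricksResource` configured with `on_behalf_of_user=True` and that `Context` (from `dao_ai.state`) carries the request headers:

```python
# Sketch: a tool receiving the runtime Context can act as the calling user.
from dao_ai.state import Context


def whoami(resource, context: Context) -> str:
    client = resource.workspace_client_from(context)  # OBO via forwarded token
    return client.current_user.me().user_name  # runs as the calling user
```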
@@ -391,10 +470,17 @@ class PermissionModel(BaseModel):
 
 class SchemaModel(BaseModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    catalog_name:
-    schema_name:
+    catalog_name: AnyVariable
+    schema_name: AnyVariable
     permissions: Optional[list[PermissionModel]] = Field(default_factory=list)
 
+    @model_validator(mode="after")
+    def resolve_variables(self) -> Self:
+        """Resolve AnyVariable fields to their actual string values."""
+        self.catalog_name = value_of(self.catalog_name)
+        self.schema_name = value_of(self.schema_name)
+        return self
+
     @property
     def full_name(self) -> str:
         return f"{self.catalog_name}.{self.schema_name}"
@@ -408,9 +494,44 @@ class SchemaModel(BaseModel, HasFullName):
 
 
 class DatabricksAppModel(IsDatabricksResource, HasFullName):
+    """
+    Configuration for a Databricks App resource.
+
+    The `name` is the unique instance name of the Databricks App within the workspace.
+    The `url` is dynamically retrieved from the workspace client by calling
+    `apps.get(name)` and returning the app's URL.
+
+    Example:
+        ```yaml
+        resources:
+          apps:
+            my_app:
+              name: my-databricks-app
+        ```
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
-
+    """The unique instance name of the Databricks App in the workspace."""
+
+    @property
+    def url(self) -> str:
+        """
+        Retrieve the URL of the Databricks App from the workspace.
+
+        Returns:
+            The URL of the deployed Databricks App.
+
+        Raises:
+            RuntimeError: If the app is not found or URL is not available.
+        """
+        app: App = self.workspace_client.apps.get(self.name)
+        if app.url is None:
+            raise RuntimeError(
+                f"Databricks App '{self.name}' does not have a URL. "
+                "The app may not be deployed yet."
+            )
+        return app.url
 
     @property
     def full_name(self) -> str:
@@ -432,7 +553,7 @@ class TableModel(IsDatabricksResource, HasFullName):
     name: Optional[str] = None
 
     @model_validator(mode="after")
-    def validate_name_or_schema_required(self) ->
+    def validate_name_or_schema_required(self) -> Self:
         if not self.name and not self.schema_model:
             raise ValueError(
                 "Either 'name' or 'schema_model' must be provided for TableModel"
@@ -717,11 +838,20 @@ class FunctionModel(IsDatabricksResource, HasFullName):
 
 
 class WarehouseModel(IsDatabricksResource):
-    model_config = ConfigDict()
-    name: str
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: Optional[str] = None
     description: Optional[str] = None
     warehouse_id: AnyVariable
 
+    _warehouse_details: Optional[GetWarehouseResponse] = PrivateAttr(default=None)
+
+    def _get_warehouse_details(self) -> GetWarehouseResponse:
+        if self._warehouse_details is None:
+            self._warehouse_details = self.workspace_client.warehouses.get(
+                id=value_of(self.warehouse_id)
+            )
+        return self._warehouse_details
+
     @property
     def api_scopes(self) -> Sequence[str]:
         return [
@@ -742,10 +872,22 @@ class WarehouseModel(IsDatabricksResource):
         self.warehouse_id = value_of(self.warehouse_id)
         return self
 
+    @model_validator(mode="after")
+    def populate_name(self) -> Self:
+        """Populate name from warehouse details if not provided."""
+        if self.warehouse_id and not self.name:
+            try:
+                warehouse_details = self._get_warehouse_details()
+                if warehouse_details.name:
+                    self.name = warehouse_details.name
+            except Exception as e:
+                logger.debug(f"Could not fetch details from warehouse: {e}")
+        return self
+
 
 class GenieRoomModel(IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    name: str
+    name: Optional[str] = None
     description: Optional[str] = None
     space_id: AnyVariable
 
@@ -801,10 +943,6 @@ class GenieRoomModel(IsDatabricksResource):
                 pat=self.pat,
             )
 
-            # Share the cached workspace client if available
-            if self._workspace_client is not None:
-                warehouse_model._workspace_client = self._workspace_client
-
             return warehouse_model
         except Exception as e:
             logger.warning(
@@ -848,9 +986,6 @@ class GenieRoomModel(IsDatabricksResource):
                 workspace_host=self.workspace_host,
                 pat=self.pat,
             )
-            # Share the cached workspace client if available
-            if self._workspace_client is not None:
-                table_model._workspace_client = self._workspace_client
 
             # Verify the table exists before adding
             if not table_model.exists():
@@ -888,9 +1023,6 @@ class GenieRoomModel(IsDatabricksResource):
                 workspace_host=self.workspace_host,
                 pat=self.pat,
             )
-            # Share the cached workspace client if available
-            if self._workspace_client is not None:
-                function_model._workspace_client = self._workspace_client
 
             # Verify the function exists before adding
             if not function_model.exists():
@@ -954,15 +1086,17 @@ class GenieRoomModel(IsDatabricksResource):
         return self
 
     @model_validator(mode="after")
-    def
-        """Populate description from GenieSpace if not provided."""
-        if not self.description:
+    def populate_name_and_description(self) -> Self:
+        """Populate name and description from GenieSpace if not provided."""
+        if self.space_id and (not self.name or not self.description):
             try:
                 space_details = self._get_space_details()
-                if space_details.
+                if not self.name and space_details.title:
+                    self.name = space_details.title
+                if not self.description and space_details.description:
                     self.description = space_details.description
             except Exception as e:
-                logger.debug(f"Could not fetch
+                logger.debug(f"Could not fetch details from Genie space: {e}")
         return self
 
 
@@ -998,7 +1132,7 @@ class VolumePathModel(BaseModel, HasFullName):
     path: Optional[str] = None
 
     @model_validator(mode="after")
-    def validate_path_or_volume(self) ->
+    def validate_path_or_volume(self) -> Self:
         if not self.volume and not self.path:
             raise ValueError("Either 'volume' or 'path' must be provided")
         return self
@@ -1272,13 +1406,20 @@ class DatabaseModel(IsDatabricksResource):
     - Databricks Lakebase: Provide `instance_name` (authentication optional, supports ambient auth)
     - Standard PostgreSQL: Provide `host` (authentication required via user/password)
 
-    Note:
+    Note: For Lakebase connections, `name` is optional and defaults to `instance_name`.
+    For PostgreSQL connections, `name` is required.
+
+    Example Databricks Lakebase (minimal):
+    ```yaml
+    databases:
+      my_lakebase:
+        instance_name: my-lakebase-instance  # name defaults to instance_name
+    ```
 
     Example Databricks Lakebase with Service Principal:
     ```yaml
     databases:
       my_lakebase:
-        name: my-database
         instance_name: my-lakebase-instance
         service_principal:
           client_id:
@@ -1294,7 +1435,6 @@ class DatabaseModel(IsDatabricksResource):
    ```yaml
    databases:
      my_lakebase:
-        name: my-database
        instance_name: my-lakebase-instance
        on_behalf_of_user: true
    ```
@@ -1314,7 +1454,7 @@ class DatabaseModel(IsDatabricksResource):
     """
 
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    name: str
+    name: Optional[str] = None
     instance_name: Optional[str] = None
     description: Optional[str] = None
     host: Optional[AnyVariable] = None
@@ -1363,6 +1503,17 @@ class DatabaseModel(IsDatabricksResource):
         )
         return self
 
+    @model_validator(mode="after")
+    def populate_name_from_instance_name(self) -> Self:
+        """Populate name from instance_name if not provided for Lakebase connections."""
+        if self.name is None and self.instance_name:
+            self.name = self.instance_name
+        elif self.name is None:
+            raise ValueError(
+                "Either 'name' or 'instance_name' must be provided for DatabaseModel."
+            )
+        return self
+
     @model_validator(mode="after")
     def update_user(self) -> Self:
         # Skip if using OBO (passive auth), explicit credentials, or explicit user
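
Note: the net effect of `populate_name_from_instance_name` on configuration, as a sketch. Field values are illustrative, and constructing a `DatabaseModel` may still require ambient Databricks credentials since other validators can call the workspace API:

```python
# Sketch: Lakebase configs may now omit `name`; the validator copies it
# from `instance_name`. Values here are illustrative.
from dao_ai.config import DatabaseModel

db = DatabaseModel(instance_name="my-lakebase-instance")
assert db.name == "my-lakebase-instance"

# Omitting both still fails fast:
# DatabaseModel()  # ValueError: Either 'name' or 'instance_name' must be provided
```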
@@ -1460,10 +1611,10 @@ class DatabaseModel(IsDatabricksResource):
         username: str | None = None
         password_value: str | None = None
 
-        # Resolve host -
+        # Resolve host - fetch from API at runtime for Lakebase if not provided
         host_value: Any = self.host
-        if host_value is None and self.is_lakebase
-            # Fetch host
+        if host_value is None and self.is_lakebase:
+            # Fetch host from Lakebase instance API
             existing_instance: DatabaseInstance = (
                 self.workspace_client.database.get_database_instance(
                     name=self.instance_name
@@ -1563,7 +1714,7 @@ class GenieLRUCacheParametersModel(BaseModel):
     warehouse: WarehouseModel
 
 
-class
+class GenieContextAwareCacheParametersModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     time_to_live_seconds: int | None = (
         60 * 60 * 24
@@ -1581,6 +1732,116 @@ class GenieSemanticCacheParametersModel(BaseModel):
     database: DatabaseModel
     warehouse: WarehouseModel
     table_name: str = "genie_semantic_cache"
+    context_window_size: int = 2  # Number of previous turns to include for context
+    max_context_tokens: int = (
+        2000  # Maximum context length to prevent extremely long embeddings
+    )
+    # Prompt history configuration
+    # Prompt history is always enabled - it stores all user prompts to maintain
+    # conversation context for accurate semantic matching even when cache hits occur
+    prompt_history_table: str = "genie_prompt_history"  # Table name for prompt history
+    max_prompt_history_length: int = 50  # Maximum prompts to keep per conversation
+    use_genie_api_for_history: bool = (
+        False  # Fallback to Genie API if local history empty
+    )
+    prompt_history_ttl_seconds: int | None = (
+        None  # TTL for prompts (None = use cache TTL)
+    )
+
+    @model_validator(mode="after")
+    def compute_and_validate_weights(self) -> Self:
+        """
+        Compute missing weight and validate that question_weight + context_weight = 1.0.
+
+        Either question_weight or context_weight (or both) can be provided.
+        The missing one will be computed as 1.0 - provided_weight.
+        If both are provided, they must sum to 1.0.
+        """
+        if self.question_weight is None and self.context_weight is None:
+            # Both missing - use defaults
+            self.question_weight = 0.6
+            self.context_weight = 0.4
+        elif self.question_weight is None:
+            # Compute question_weight from context_weight
+            if not (0.0 <= self.context_weight <= 1.0):
+                raise ValueError(
+                    f"context_weight must be between 0.0 and 1.0, got {self.context_weight}"
+                )
+            self.question_weight = 1.0 - self.context_weight
+        elif self.context_weight is None:
+            # Compute context_weight from question_weight
+            if not (0.0 <= self.question_weight <= 1.0):
+                raise ValueError(
+                    f"question_weight must be between 0.0 and 1.0, got {self.question_weight}"
+                )
+            self.context_weight = 1.0 - self.question_weight
+        else:
+            # Both provided - validate they sum to 1.0
+            total_weight = self.question_weight + self.context_weight
+            if not abs(total_weight - 1.0) < 0.0001:  # Allow small floating point error
+                raise ValueError(
+                    f"question_weight ({self.question_weight}) + context_weight ({self.context_weight}) "
+                    f"must equal 1.0 (got {total_weight}). These weights determine the relative importance "
+                    f"of question vs context similarity in the combined score."
+                )
+
+        return self
+
+
+# Memory estimation for capacity planning:
+# - Each entry: ~20KB (8KB question embedding + 8KB context embedding + 4KB strings/overhead)
+# - 1,000 entries: ~20MB (0.4% of 8GB)
+# - 5,000 entries: ~100MB (2% of 8GB)
+# - 10,000 entries: ~200MB (4-5% of 8GB) - default for ~30 users
+# - 20,000 entries: ~400MB (8-10% of 8GB)
+# Default 10,000 entries provides ~330 queries per user for 30 users.
+class GenieInMemorySemanticCacheParametersModel(BaseModel):
+    """
+    Configuration for in-memory semantic cache (no database required).
+
+    This cache stores embeddings and cache entries entirely in memory, providing
+    semantic similarity matching without requiring external database dependencies
+    like PostgreSQL or Databricks Lakebase.
+
+    Default settings are tuned for ~30 users on an 8GB machine:
+    - Capacity: 10,000 entries (~200MB memory, ~330 queries per user)
+    - Eviction: LRU (Least Recently Used) - keeps frequently accessed queries
+    - TTL: 1 week (accommodates weekly work patterns and batch jobs)
+    - Memory overhead: ~4-5% of 8GB system
+
+    The LRU eviction strategy ensures hot queries stay cached while cold queries
+    are evicted, providing better hit rates than FIFO eviction.
+
+    For larger deployments or memory-constrained environments, adjust capacity and TTL accordingly.
+
+    Use this when:
+    - No external database access is available
+    - Single-instance deployments (cache not shared across instances)
+    - Cache persistence across restarts is not required
+    - Cache sizes are moderate (hundreds to low thousands of entries)
+
+    For multi-instance deployments or large cache sizes, use GenieContextAwareCacheParametersModel
+    with PostgreSQL backend instead.
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    time_to_live_seconds: int | None = (
+        60 * 60 * 24 * 7
+    )  # 1 week default (604800 seconds), None or negative = never expires
+    similarity_threshold: float = 0.85  # Minimum similarity for question matching (L2 distance converted to 0-1 scale)
+    context_similarity_threshold: float = 0.80  # Minimum similarity for context matching (L2 distance converted to 0-1 scale)
+    question_weight: Optional[float] = (
+        0.6  # Weight for question similarity in combined score (0-1). If not provided, computed as 1 - context_weight
+    )
+    context_weight: Optional[float] = (
+        None  # Weight for context similarity in combined score (0-1). If not provided, computed as 1 - question_weight
+    )
+    embedding_model: str | LLMModel = "databricks-gte-large-en"
+    embedding_dims: int | None = None  # Auto-detected if None
+    warehouse: WarehouseModel
+    capacity: int | None = (
+        10000  # Maximum cache entries. ~200MB for 10000 entries (1024-dim embeddings). LRU eviction when full. None = unlimited (not recommended for production).
+    )
     context_window_size: int = 3  # Number of previous turns to include for context
     max_context_tokens: int = (
         2000  # Maximum context length to prevent extremely long embeddings
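
Note: the weights checked by `compute_and_validate_weights` combine the two similarity channels linearly. A sketch of the implied scoring rule (not a copy of the cache implementation):

```python
def combined_score(
    question_sim: float,
    context_sim: float,
    question_weight: float = 0.6,
    context_weight: float = 0.4,
) -> float:
    # The validator guarantees question_weight + context_weight == 1.0,
    # so the result stays on the same 0-1 scale as the inputs.
    return question_weight * question_sim + context_weight * context_sim


# With the defaults: 0.6 * 0.9 + 0.4 * 0.7 == 0.82
assert abs(combined_score(0.9, 0.7) - 0.82) < 1e-9
```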
@@ -1633,43 +1894,83 @@ class SearchParametersModel(BaseModel):
     query_type: Optional[str] = "ANN"
 
 
+class InstructionAwareRerankModel(BaseModel):
+    """
+    LLM-based reranking considering user instructions and constraints.
+
+    Use fast models (GPT-3.5, Haiku, Llama 3 8B) to minimize latency (~100ms).
+    Runs AFTER FlashRank as an additional constraint-aware reranking stage.
+    Skipped for 'standard' mode when auto_bypass=true in router config.
+
+    Example:
+        ```yaml
+        rerank:
+          model: ms-marco-MiniLM-L-12-v2
+          top_n: 20
+          instruction_aware:
+            model: *fast_llm
+            instructions: |
+              Prioritize results matching price and brand constraints.
+            top_n: 10
+        ```
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+
+    model: Optional["LLMModel"] = Field(
+        default=None,
+        description="LLM for instruction reranking (fast model recommended)",
+    )
+    instructions: Optional[str] = Field(
+        default=None,
+        description="Custom reranking instructions for constraint prioritization",
+    )
+    top_n: Optional[int] = Field(
+        default=None,
+        description="Number of documents to return after instruction reranking",
+    )
+
+
 class RerankParametersModel(BaseModel):
     """
-    Configuration for reranking retrieved documents
+    Configuration for reranking retrieved documents.
 
-
-
-
+    Supports three reranking options that can be combined:
+    1. FlashRank (local cross-encoder) - set `model`
+    2. Databricks server-side reranking - set `columns`
+    3. LLM instruction-aware reranking - set `instruction_aware`
 
-
-
-
-
+    Example with Databricks columns + instruction-aware (no FlashRank):
+    ```yaml
+    rerank:
+      columns:  # Databricks server-side reranking
+        - product_name
+        - brand_name
+      instruction_aware:  # LLM-based constraint reranking
+        model: *fast_llm
+        instructions: "Prioritize by brand preferences"
+        top_n: 10
+    ```
 
-    Example:
+    Example with FlashRank:
     ```yaml
-
-
-
-    rerank:
-      model: ms-marco-MiniLM-L-12-v2
-      top_n: 5  # Return top 5 after reranking
+    rerank:
+      model: ms-marco-MiniLM-L-12-v2  # FlashRank model
+      top_n: 10
     ```
 
-    Available models (see https://github.com/PrithivirajDamodaran/FlashRank):
+    Available FlashRank models (see https://github.com/PrithivirajDamodaran/FlashRank):
     - "ms-marco-TinyBERT-L-2-v2" (~4MB, fastest)
-    - "ms-marco-MiniLM-L-12-v2" (~34MB, best cross-encoder
+    - "ms-marco-MiniLM-L-12-v2" (~34MB, best cross-encoder)
     - "rank-T5-flan" (~110MB, best non cross-encoder)
     - "ms-marco-MultiBERT-L-12" (~150MB, multilingual 100+ languages)
-    - "ce-esci-MiniLM-L12-v2" (e-commerce optimized, Amazon ESCI)
-    - "miniReranker_arabic_v1" (Arabic language)
     """
 
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
 
-    model: str = Field(
-        default=
-        description="FlashRank model name.
+    model: Optional[str] = Field(
+        default=None,
+        description="FlashRank model name. If None, FlashRank is not used (use columns for Databricks reranking).",
     )
     top_n: Optional[int] = Field(
         default=None,
@@ -1682,6 +1983,289 @@ class RerankParametersModel(BaseModel):
     columns: Optional[list[str]] = Field(
         default_factory=list, description="Columns to rerank using DatabricksReranker"
     )
+    instruction_aware: Optional[InstructionAwareRerankModel] = Field(
+        default=None,
+        description="Optional LLM-based reranking stage after FlashRank",
+    )
+
+
+class FilterItem(BaseModel):
+    """A metadata filter for vector search.
+
+    Filters constrain search results by matching column values.
+    Use column names from the provided schema description.
+    """
+
+    model_config = ConfigDict(extra="forbid")
+    key: str = Field(
+        description=(
+            "Column name with optional operator suffix. "
+            "Operators: (none) for equality, NOT for exclusion, "
+            "< <= > >= for numeric comparison, "
+            "LIKE for token match, NOT LIKE to exclude tokens."
+        )
+    )
+    value: Union[str, int, float, bool, list[Union[str, int, float, bool]]] = Field(
+        description=(
+            "The filter value matching the column type. "
+            "Use an array for IN-style matching multiple values."
+        )
+    )
+
+
+class SearchQuery(BaseModel):
+    """A single search query with optional metadata filters.
+
+    Represents one focused search intent extracted from the user's request.
+    The text should be a natural language query optimized for semantic search.
+    Filters constrain results to match specific metadata values.
+    """
+
+    model_config = ConfigDict(extra="forbid")
+    text: str = Field(
+        description=(
+            "Natural language search query text optimized for semantic similarity. "
+            "Should be focused on a single search intent. "
+            "Do NOT include filter criteria in the text; use the filters field instead."
+        )
+    )
+    filters: Optional[list[FilterItem]] = Field(
+        default=None,
+        description=(
+            "Metadata filters to constrain search results. "
+            "Set to null if no filters apply. "
+            "Extract filter values from explicit constraints in the user query."
+        ),
+    )
+
+
+class DecomposedQueries(BaseModel):
+    """Decomposed search queries extracted from a user request.
+
+    Break down complex user queries into multiple focused search queries.
+    Each query targets a distinct search intent with appropriate filters.
+    Generate 1-3 queries depending on the complexity of the user request.
+    """
+
+    model_config = ConfigDict(extra="forbid")
+    queries: list[SearchQuery] = Field(
+        description=(
+            "List of search queries extracted from the user request. "
+            "Each query should target a distinct search intent. "
+            "Order queries by importance, with the most relevant first."
+        )
+    )
+
+
+class ColumnInfo(BaseModel):
+    """Column metadata for dynamic schema generation in structured output.
+
+    When provided, column information is embedded directly into the JSON schema
+    that with_structured_output sends to the LLM, improving filter accuracy.
+    """
+
+    model_config = ConfigDict(extra="forbid")
+
+    name: str = Field(description="Column name as it appears in the database")
+    type: Literal["string", "number", "boolean", "datetime"] = Field(
+        default="string",
+        description="Column data type for value validation",
+    )
+    operators: list[str] = Field(
+        default=["", "NOT", "<", "<=", ">", ">=", "LIKE", "NOT LIKE"],
+        description="Valid filter operators for this column",
+    )
+
+
+class InstructedRetrieverModel(BaseModel):
+    """
+    Configuration for instructed retrieval with query decomposition and RRF merging.
+
+    Instructed retrieval decomposes user queries into multiple subqueries with
+    metadata filters, executes them in parallel, and merges results using
+    Reciprocal Rank Fusion (RRF) before reranking.
+
+    Example:
+        ```yaml
+        retriever:
+          vector_store: *products_vector_store
+          instructed:
+            decomposition_model: *fast_llm
+            schema_description: |
+              Products table: product_id, brand_name, category, price, updated_at
+              Filter operators: {"col": val}, {"col >": val}, {"col NOT": val}
+            columns:
+              - name: brand_name
+                type: string
+              - name: price
+                type: number
+                operators: ["", "<", "<=", ">", ">="]
+            constraints:
+              - "Prefer recent products"
+            max_subqueries: 3
+            examples:
+              - query: "cheap drills"
+                filters: {"price <": 100}
+        ```
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+
+    decomposition_model: Optional["LLMModel"] = Field(
+        default=None,
+        description="LLM for query decomposition (smaller/faster model recommended)",
+    )
+    schema_description: str = Field(
+        description="Column names, types, and valid filter syntax for the LLM"
+    )
+    columns: Optional[list[ColumnInfo]] = Field(
+        default=None,
+        description=(
+            "Structured column info for dynamic schema generation. "
+            "When provided, column names are embedded in the JSON schema for better LLM accuracy."
+        ),
+    )
+    constraints: Optional[list[str]] = Field(
+        default=None, description="Default constraints to always apply"
+    )
+    max_subqueries: int = Field(
+        default=3, description="Maximum number of parallel subqueries"
+    )
+    rrf_k: int = Field(
+        default=60,
+        description="RRF constant (lower values weight top ranks more heavily)",
+    )
+    examples: Optional[list[dict[str, Any]]] = Field(
+        default=None,
+        description="Few-shot examples for domain-specific filter translation",
+    )
+    normalize_filter_case: Optional[Literal["uppercase", "lowercase"]] = Field(
+        default=None,
+        description="Auto-normalize filter string values to uppercase or lowercase",
+    )
+
+
+class RouterModel(BaseModel):
+    """
+    Select internal execution mode based on query characteristics.
+
+    Use fast models (GPT-3.5, Haiku, Llama 3 8B) to minimize latency (~50-100ms).
+    Routes to internal modes within the same retriever, not external retrievers.
+    Cross-index routing belongs at the agent/tool-selection level.
+
+    Execution Modes:
+    - "standard": Single similarity_search() for simple keyword/product searches
+    - "instructed": Decompose -> Parallel Search -> RRF for constrained queries
+
+    Example:
+        ```yaml
+        retriever:
+          router:
+            model: *fast_llm
+            default_mode: standard
+            auto_bypass: true
+        ```
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+
+    model: Optional["LLMModel"] = Field(
+        default=None,
+        description="LLM for routing decision (fast model recommended)",
+    )
+    default_mode: Literal["standard", "instructed"] = Field(
+        default="standard",
+        description="Fallback mode if routing fails",
+    )
+    auto_bypass: bool = Field(
+        default=True,
+        description="Skip Instruction Reranker and Verifier for standard mode",
+    )
+
+
+class VerificationResult(BaseModel):
+    """Verification of whether search results satisfy the user's constraints.
+
+    Analyze the retrieved results against the original query and any explicit
+    constraints to determine if a retry with modified filters is needed.
+    """
+
+    model_config = ConfigDict(extra="forbid")
+
+    passed: bool = Field(
+        description="True if results satisfy the user's query intent and constraints."
+    )
+    confidence: float = Field(
+        ge=0.0,
+        le=1.0,
+        description="Confidence in the verification decision, from 0.0 (uncertain) to 1.0 (certain).",
+    )
+    feedback: Optional[str] = Field(
+        default=None,
+        description="Explanation of why verification passed or failed. Include specific issues found.",
+    )
+    suggested_filter_relaxation: Optional[dict[str, Any]] = Field(
+        default=None,
+        description=(
+            "Suggested filter modifications for retry. "
+            "Keys are column names, values indicate changes (e.g., 'REMOVE', 'WIDEN', or new values)."
+        ),
+    )
+    unmet_constraints: Optional[list[str]] = Field(
+        default=None,
+        description="List of user constraints that the results failed to satisfy.",
+    )
+
+
+class VerifierModel(BaseModel):
+    """
+    Validate results against user constraints with structured feedback.
+
+    Use fast models (GPT-3.5, Haiku, Llama 3 8B) to minimize latency (~50-100ms).
+    Skipped for 'standard' mode when auto_bypass=true in router config.
+    Returns structured feedback for intelligent retry, not blind retry.
+
+    Example:
+        ```yaml
+        retriever:
+          verifier:
+            model: *fast_llm
+            on_failure: warn_and_retry
+            max_retries: 1
+        ```
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+
+    model: Optional["LLMModel"] = Field(
+        default=None,
+        description="LLM for verification (fast model recommended)",
+    )
+    on_failure: Literal["warn", "retry", "warn_and_retry"] = Field(
+        default="warn",
+        description="Behavior when verification fails",
+    )
+    max_retries: int = Field(
+        default=1,
+        description="Maximum retry attempts before returning with warning",
+    )
+
+
+class RankedDocument(BaseModel):
+    """Single ranked document."""
+
+    index: int = Field(description="Document index from input list")
+    score: float = Field(description="0.0-1.0 relevance score")
+    reason: str = Field(default="", description="Why this score")
+
+
+class RankingResult(BaseModel):
+    """Reranking output."""
+
+    rankings: list[RankedDocument] = Field(
+        default_factory=list,
+        description="Ranked documents, highest score first",
+    )
 
 
 class RetrieverModel(BaseModel):
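
Note: `InstructedRetrieverModel.rrf_k` refers to standard Reciprocal Rank Fusion, where each document scores `sum(1 / (k + rank))` across the parallel subquery result lists. A minimal sketch of the merge (not the package's exact implementation):

```python
from collections import defaultdict


def rrf_merge(result_lists: list[list[str]], k: int = 60) -> list[str]:
    """Fuse ranked lists of doc ids: score(d) = sum over lists of 1 / (k + rank)."""
    scores: defaultdict[str, float] = defaultdict(float)
    for results in result_lists:
        for rank, doc_id in enumerate(results, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    # A lower k weights top ranks more heavily, matching the field description.
    return sorted(scores, key=scores.__getitem__, reverse=True)
```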
@@ -1691,10 +2275,22 @@ class RetrieverModel(BaseModel):
     search_parameters: SearchParametersModel = Field(
         default_factory=SearchParametersModel
     )
+    router: Optional[RouterModel] = Field(
+        default=None,
+        description="Optional query router for selecting execution mode (standard vs instructed).",
+    )
     rerank: Optional[RerankParametersModel | bool] = Field(
         default=None,
         description="Optional reranking configuration. Set to true for defaults, or provide ReRankParametersModel for custom settings.",
     )
+    instructed: Optional[InstructedRetrieverModel] = Field(
+        default=None,
+        description="Optional instructed retrieval with query decomposition and RRF merging.",
+    )
+    verifier: Optional[VerifierModel] = Field(
+        default=None,
+        description="Optional result verification with structured feedback for retry.",
+    )
 
     @model_validator(mode="after")
     def set_default_columns(self) -> Self:
|
|
|
1705
2301
|
|
|
1706
2302
|
@model_validator(mode="after")
|
|
1707
2303
|
def set_default_reranker(self) -> Self:
|
|
1708
|
-
"""Convert bool to ReRankParametersModel with defaults.
|
|
2304
|
+
"""Convert bool to ReRankParametersModel with defaults.
|
|
2305
|
+
|
|
2306
|
+
When rerank: true is used, sets the default FlashRank model
|
|
2307
|
+
(ms-marco-MiniLM-L-12-v2) to enable reranking.
|
|
2308
|
+
"""
|
|
1709
2309
|
if isinstance(self.rerank, bool) and self.rerank:
|
|
1710
|
-
self.rerank = RerankParametersModel()
|
|
2310
|
+
self.rerank = RerankParametersModel(model="ms-marco-MiniLM-L-12-v2")
|
|
1711
2311
|
return self
|
|
1712
2312
|
|
|
1713
2313
|
|
|
@@ -1840,11 +2440,32 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
     headers: dict[str, AnyVariable] = Field(default_factory=dict)
     args: list[str] = Field(default_factory=list)
     # MCP-specific fields
+    app: Optional[DatabricksAppModel] = None
     connection: Optional[ConnectionModel] = None
     functions: Optional[SchemaModel] = None
     genie_room: Optional[GenieRoomModel] = None
     sql: Optional[bool] = None
     vector_search: Optional[VectorStoreModel] = None
+    # Tool filtering
+    include_tools: Optional[list[str]] = Field(
+        default=None,
+        description=(
+            "Optional list of tool names or glob patterns to include from the MCP server. "
+            "If specified, only tools matching these patterns will be loaded. "
+            "Supports glob patterns: * (any chars), ? (single char), [abc] (char set). "
+            "Examples: ['execute_query', 'list_*', 'get_?_data']"
+        ),
+    )
+    exclude_tools: Optional[list[str]] = Field(
+        default=None,
+        description=(
+            "Optional list of tool names or glob patterns to exclude from the MCP server. "
+            "Tools matching these patterns will not be loaded. "
+            "Takes precedence over include_tools. "
+            "Supports glob patterns: * (any chars), ? (single char), [abc] (char set). "
+            "Examples: ['drop_*', 'delete_*', 'execute_ddl']"
+        ),
+    )
 
     @property
     def api_scopes(self) -> Sequence[str]:
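
Note: the include/exclude semantics described in the field docs match ordinary glob filtering with exclude taking precedence. A sketch of the selection rule, using `fnmatch` as the presumed matcher (the actual loader in `dao_ai/tools/mcp.py` may differ):

```python
from fnmatch import fnmatch


def tool_allowed(
    name: str,
    include: list[str] | None,
    exclude: list[str] | None,
) -> bool:
    # exclude_tools wins over include_tools, per the field description.
    if exclude and any(fnmatch(name, pat) for pat in exclude):
        return False
    if include is not None:
        return any(fnmatch(name, pat) for pat in include)
    return True  # no include list means all tools load


assert tool_allowed("list_tables", ["list_*"], ["drop_*"])
assert not tool_allowed("drop_table", ["*"], ["drop_*"])
```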
@@ -1907,6 +2528,7 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
 
         Returns the URL based on the configured source:
         - If url is set, returns it directly
+        - If app is set, retrieves URL from Databricks App via workspace client
         - If connection is set, constructs URL from connection
         - If genie_room is set, constructs Genie MCP URL
         - If sql is set, constructs DBSQL MCP URL (serverless)
@@ -1919,6 +2541,7 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
         - Vector Search: https://{host}/api/2.0/mcp/vector-search/{catalog}/{schema}
         - UC Functions: https://{host}/api/2.0/mcp/functions/{catalog}/{schema}
         - Connection: https://{host}/api/2.0/mcp/external/{connection_name}
+        - Databricks App: Retrieved dynamically from workspace
         """
         # Direct URL provided
         if self.url:
@@ -1941,6 +2564,49 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
         if self.sql:
             return f"{workspace_host}/api/2.0/mcp/sql"
 
+        # Databricks App - MCP endpoint is at {app_url}/mcp
+        # Try McpFunctionModel's workspace_client first (which may have credentials),
+        # then fall back to DatabricksAppModel.url property (which uses its own workspace_client)
+        if self.app:
+            from databricks.sdk.service.apps import App
+
+            app_url: str | None = None
+
+            # First, try using McpFunctionModel's workspace_client
+            try:
+                app: App = self.workspace_client.apps.get(self.app.name)
+                app_url = app.url
+                logger.trace(
+                    "Got app URL using McpFunctionModel workspace_client",
+                    app_name=self.app.name,
+                    url=app_url,
+                )
+            except Exception as e:
+                logger.debug(
+                    "Failed to get app URL using McpFunctionModel workspace_client, "
+                    "trying DatabricksAppModel.url property",
+                    app_name=self.app.name,
+                    error=str(e),
+                )
+
+            # Fall back to DatabricksAppModel.url property
+            if not app_url:
+                try:
+                    app_url = self.app.url
+                    logger.trace(
+                        "Got app URL using DatabricksAppModel.url property",
+                        app_name=self.app.name,
+                        url=app_url,
+                    )
+                except Exception as e:
+                    raise RuntimeError(
+                        f"Databricks App '{self.app.name}' does not have a URL. "
+                        "The app may not be deployed yet, or credentials may be invalid. "
+                        f"Error: {e}"
+                    ) from e
+
+            return f"{app_url.rstrip('/')}/mcp"
+
         # Vector Search
         if self.vector_search:
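
Note: for quick reference, the URL shapes this method now produces, as a sketch with an illustrative host and names; the patterns are those listed in the docstring above, and the app form is the one added in this release:

```python
# Illustrative host and names only.
host = "https://example.cloud.databricks.com"
app_url = "https://my-app.databricksapps.com"

sql_url = f"{host}/api/2.0/mcp/sql"
vs_url = f"{host}/api/2.0/mcp/vector-search/main/products"
fn_url = f"{host}/api/2.0/mcp/functions/main/tools"
conn_url = f"{host}/api/2.0/mcp/external/my_connection"
app_mcp_url = f"{app_url.rstrip('/')}/mcp"  # app-backed MCP endpoint
```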
@@ -1950,33 +2616,35 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
                 raise ValueError(
                     "vector_search must have an index with a schema (catalog/schema) configured"
                 )
-            catalog: str = self.vector_search.index.schema_model.catalog_name
-            schema: str = self.vector_search.index.schema_model.schema_name
+            catalog: str = value_of(self.vector_search.index.schema_model.catalog_name)
+            schema: str = value_of(self.vector_search.index.schema_model.schema_name)
             return f"{workspace_host}/api/2.0/mcp/vector-search/{catalog}/{schema}"
 
         # UC Functions MCP server
         if self.functions:
-            catalog: str = self.functions.catalog_name
-            schema: str = self.functions.schema_name
+            catalog: str = value_of(self.functions.catalog_name)
+            schema: str = value_of(self.functions.schema_name)
             return f"{workspace_host}/api/2.0/mcp/functions/{catalog}/{schema}"
 
         raise ValueError(
-            "No URL source configured. Provide one of: url, connection, genie_room, "
+            "No URL source configured. Provide one of: url, app, connection, genie_room, "
             "sql, vector_search, or functions"
         )
 
     @field_serializer("transport")
-    def serialize_transport(self, value) -> str:
+    def serialize_transport(self, value: TransportType) -> str:
+        """Serialize transport enum to string."""
         if isinstance(value, TransportType):
             return value.value
         return str(value)
 
     @model_validator(mode="after")
-    def validate_mutually_exclusive(self) ->
+    def validate_mutually_exclusive(self) -> Self:
         """Validate that exactly one URL source is provided."""
         # Count how many URL sources are provided
         url_sources: list[tuple[str, Any]] = [
             ("url", self.url),
+            ("app", self.app),
             ("connection", self.connection),
             ("genie_room", self.genie_room),
             ("sql", self.sql),
@@ -1992,13 +2660,13 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
         if len(provided_sources) == 0:
             raise ValueError(
                 "For STREAMABLE_HTTP transport, exactly one of the following must be provided: "
-                "url, connection, genie_room, sql, vector_search, or functions"
+                "url, app, connection, genie_room, sql, vector_search, or functions"
             )
         if len(provided_sources) > 1:
             raise ValueError(
                 f"For STREAMABLE_HTTP transport, only one URL source can be provided. "
                 f"Found: {', '.join(provided_sources)}. "
-                f"Please provide only one of: url, connection, genie_room, sql, vector_search, or functions"
+                f"Please provide only one of: url, app, connection, genie_room, sql, vector_search, or functions"
             )

         if self.transport == TransportType.STDIO:
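Note: a hedged sketch of the mutual-exclusion rule enforced by this validator; required fields beyond the URL sources are omitted, and TransportType usage mirrors the diff.

    # Exactly one source is accepted for STREAMABLE_HTTP:
    McpFunctionModel(transport=TransportType.STREAMABLE_HTTP, url="https://example.com/mcp")

    # Zero sources, or two sources such as url + app together, raise ValueError
    # with the messages shown above, which now list "app" among the options.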
@@ -2010,14 +2678,41 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
         return self

     @model_validator(mode="after")
-    def update_url(self) ->
-
+    def update_url(self) -> Self:
+        """Resolve AnyVariable to concrete value for URL."""
+        if self.url is not None:
+            resolved_value: Any = value_of(self.url)
+            # Cast to string since URL must be a string
+            self.url = str(resolved_value) if resolved_value else None
         return self

     @model_validator(mode="after")
-    def update_headers(self) ->
+    def update_headers(self) -> Self:
+        """Resolve AnyVariable to concrete values for headers."""
         for key, value in self.headers.items():
-
+            resolved_value: Any = value_of(value)
+            # Headers must be strings
+            self.headers[key] = str(resolved_value) if resolved_value else ""
+        return self
+
+    @model_validator(mode="after")
+    def validate_tool_filters(self) -> Self:
+        """Validate tool filter configuration."""
+        from loguru import logger
+
+        # Warn if both are empty lists (explicit but pointless)
+        if self.include_tools is not None and len(self.include_tools) == 0:
+            logger.warning(
+                "include_tools is empty list - no tools will be loaded. "
+                "Remove field to load all tools."
+            )
+
+        if self.exclude_tools is not None and len(self.exclude_tools) == 0:
+            logger.warning(
+                "exclude_tools is empty list - has no effect. "
+                "Remove field or add patterns."
+            )
+
         return self

     def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
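Note: validate_tool_filters only warns, it never rejects the config. A sketch of the two warning cases, with other required fields omitted:

    # Explicit empty include list: legal, but no tools will be loaded.
    McpFunctionModel(url="https://example.com/mcp", include_tools=[])
    # -> loguru warning: "include_tools is empty list - no tools will be loaded. ..."

    # Explicit empty exclude list: legal, but has no effect.
    McpFunctionModel(url="https://example.com/mcp", exclude_tools=[])
    # -> loguru warning: "exclude_tools is empty list - has no effect. ..."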
@@ -2425,7 +3120,6 @@ class SupervisorModel(BaseModel):

 class SwarmModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    model: LLMModel
     default_agent: Optional[AgentModel | str] = None
     middleware: list[MiddlewareModel] = Field(
         default_factory=list,
@@ -2439,11 +3133,17 @@ class SwarmModel(BaseModel):
 class OrchestrationModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     supervisor: Optional[SupervisorModel] = None
-    swarm: Optional[SwarmModel] = None
+    swarm: Optional[SwarmModel | Literal[True]] = None
     memory: Optional[MemoryModel] = None

     @model_validator(mode="after")
-    def
+    def validate_and_normalize(self) -> Self:
+        """Validate orchestration and normalize swarm shorthand."""
+        # Convert swarm: true to SwarmModel()
+        if self.swarm is True:
+            self.swarm = SwarmModel()
+
+        # Validate mutually exclusive
         if self.supervisor is not None and self.swarm is not None:
             raise ValueError("Cannot specify both supervisor and swarm")
         if self.supervisor is None and self.swarm is None:
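Note: the Literal[True] union plus validate_and_normalize let a config enable swarm orchestration with no settings at all. A minimal sketch:

    # `swarm: true` in YAML arrives here as the literal True and is normalized:
    orchestration = OrchestrationModel(swarm=True)
    assert isinstance(orchestration.swarm, SwarmModel)

    # supervisor and swarm remain mutually exclusive:
    # OrchestrationModel(supervisor=..., swarm=True) raises
    # ValueError("Cannot specify both supervisor and swarm").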
@@ -2653,6 +3353,11 @@ class AppModel(BaseModel):
         "which is supported by Databricks Model Serving. This allows deploying from "
         "environments with different Python versions (e.g., Databricks Apps with 3.11).",
     )
+    deployment_target: Optional[DeploymentTarget] = Field(
+        default=None,
+        description="Default deployment target. If not specified, defaults to MODEL_SERVING. "
+        "Can be overridden via CLI --target flag. Options: 'model_serving' or 'apps'.",
+    )

     @model_validator(mode="after")
     def set_databricks_env_vars(self) -> Self:
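Note: the new deployment_target field gives a config its default deploy destination. A hypothetical sketch; the allowed values come from the Field description above, while the enclosing `app:` block is assumed:

    # YAML fragment:
    #   app:
    #     deployment_target: apps   # or: model_serving (the default)
    #
    # A CLI --target flag, when provided, takes precedence at deploy time.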
@@ -2710,9 +3415,7 @@ class AppModel(BaseModel):
         elif len(self.agents) == 1:
             default_agent: AgentModel = self.agents[0]
             self.orchestration = OrchestrationModel(
-                swarm=SwarmModel(
-                    model=default_agent.model, default_agent=default_agent
-                )
+                swarm=SwarmModel(default_agent=default_agent)
             )
         else:
             raise ValueError("At least one agent must be specified")
@@ -2752,8 +3455,24 @@ class GuidelineModel(BaseModel):


 class EvaluationModel(BaseModel):
+    """
+    Configuration for MLflow GenAI evaluation.
+
+    Attributes:
+        model: LLM model used as the judge for LLM-based scorers (e.g., Guidelines, Safety).
+            This model evaluates agent responses during evaluation.
+        table: Table to store evaluation results.
+        num_evals: Number of evaluation samples to generate.
+        agent_description: Description of the agent for evaluation data generation.
+        question_guidelines: Guidelines for generating evaluation questions.
+        custom_inputs: Custom inputs to pass to the agent during evaluation.
+        guidelines: List of guideline configurations for Guidelines scorers.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    model: LLMModel
+    model: LLMModel = Field(
+        ..., description="LLM model used as the judge for LLM-based evaluation scorers"
+    )
     table: TableModel
     num_evals: int
     agent_description: Optional[str] = None
@@ -2761,6 +3480,16 @@ class EvaluationModel(BaseModel):
     custom_inputs: dict[str, Any] = Field(default_factory=dict)
     guidelines: list[GuidelineModel] = Field(default_factory=list)

+    @property
+    def judge_model_endpoint(self) -> str:
+        """
+        Get the judge model endpoint string for MLflow scorers.
+
+        Returns:
+            Endpoint string in format 'databricks:/model-name'
+        """
+        return f"databricks:/{self.model.name}"
+

 class EvaluationDatasetExpectationsModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
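Note: the property simply prefixes the judge model's name. A sketch with an illustrative model name:

    # With evaluation.model.name == "databricks-meta-llama-3-3-70b-instruct":
    evaluation.judge_model_endpoint
    # -> "databricks:/databricks-meta-llama-3-3-70b-instruct"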
@@ -2958,6 +3687,165 @@ class OptimizationsModel(BaseModel):
         return results


+class SemanticCacheEvalEntryModel(BaseModel):
+    """Single evaluation entry for semantic cache threshold optimization.
+
+    Represents a pair of question/context combinations to evaluate
+    whether the cache should return a hit or miss.
+
+    Example:
+        entry:
+          question: "What are total sales?"
+          question_embedding: [0.1, 0.2, ...]  # Pre-computed
+          context: "Previous: Show me revenue"
+          context_embedding: [0.1, 0.2, ...]
+          cached_question: "Show total sales"
+          cached_question_embedding: [0.1, 0.2, ...]
+          cached_context: "Previous: Show me revenue"
+          cached_context_embedding: [0.1, 0.2, ...]
+          expected_match: true
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    question: str
+    question_embedding: list[float]
+    context: str = ""
+    context_embedding: list[float] = Field(default_factory=list)
+    cached_question: str
+    cached_question_embedding: list[float]
+    cached_context: str = ""
+    cached_context_embedding: list[float] = Field(default_factory=list)
+    expected_match: Optional[bool] = None  # None = use LLM judge
+
+
+class SemanticCacheEvalDatasetModel(BaseModel):
+    """Dataset for semantic cache threshold optimization.
+
+    Contains pairs of questions/contexts to evaluate whether thresholds
+    correctly identify semantic matches.
+
+    Example:
+        dataset:
+          name: my_cache_eval_dataset
+          description: "Evaluation data for cache tuning"
+          entries:
+            - question: "What are total sales?"
+              # ... entry fields
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    description: str = ""
+    entries: list[SemanticCacheEvalEntryModel] = Field(default_factory=list)
+
+    def as_eval_dataset(self) -> "SemanticCacheEvalDataset":
+        """Convert to internal evaluation dataset format."""
+        from dao_ai.genie.cache.optimization import (
+            SemanticCacheEvalDataset,
+            SemanticCacheEvalEntry,
+        )
+
+        entries = [
+            SemanticCacheEvalEntry(
+                question=e.question,
+                question_embedding=e.question_embedding,
+                context=e.context,
+                context_embedding=e.context_embedding,
+                cached_question=e.cached_question,
+                cached_question_embedding=e.cached_question_embedding,
+                cached_context=e.cached_context,
+                cached_context_embedding=e.cached_context_embedding,
+                expected_match=e.expected_match,
+            )
+            for e in self.entries
+        ]
+
+        return SemanticCacheEvalDataset(
+            name=self.name,
+            entries=entries,
+            description=self.description,
+        )
+
+
+class SemanticCacheThresholdOptimizationModel(BaseModel):
+    """Configuration for semantic cache threshold optimization.
+
+    Uses Optuna Bayesian optimization to find optimal threshold values
+    that maximize cache hit accuracy (F1 score by default).
+
+    Example:
+        threshold_optimization:
+          name: optimize_cache_thresholds
+          cache_parameters: *my_cache_params
+          dataset: *my_eval_dataset
+          judge_model: databricks-meta-llama-3-3-70b-instruct
+          n_trials: 50
+          metric: f1
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    cache_parameters: Optional[GenieContextAwareCacheParametersModel] = None
+    dataset: SemanticCacheEvalDatasetModel
+    judge_model: Optional[LLMModel | str] = "databricks-meta-llama-3-3-70b-instruct"
+    n_trials: int = 50
+    metric: Literal["f1", "precision", "recall", "fbeta"] = "f1"
+    beta: float = 1.0  # For fbeta metric
+    seed: Optional[int] = None
+
+    def optimize(
+        self, w: WorkspaceClient | None = None
+    ) -> "ThresholdOptimizationResult":
+        """
+        Optimize semantic cache thresholds.
+
+        Args:
+            w: Optional WorkspaceClient (not used, kept for API compatibility)
+
+        Returns:
+            ThresholdOptimizationResult with optimized thresholds
+        """
+        from dao_ai.genie.cache.optimization import (
+            ThresholdOptimizationResult,
+            optimize_semantic_cache_thresholds,
+        )
+
+        # Convert dataset
+        eval_dataset = self.dataset.as_eval_dataset()
+
+        # Get original thresholds from cache_parameters
+        original_thresholds: dict[str, float] | None = None
+        if self.cache_parameters:
+            original_thresholds = {
+                "similarity_threshold": self.cache_parameters.similarity_threshold,
+                "context_similarity_threshold": self.cache_parameters.context_similarity_threshold,
+                "question_weight": self.cache_parameters.question_weight or 0.6,
+            }
+
+        # Get judge model
+        judge_model_name: str
+        if isinstance(self.judge_model, str):
+            judge_model_name = self.judge_model
+        elif self.judge_model:
+            judge_model_name = self.judge_model.uri
+        else:
+            judge_model_name = "databricks-meta-llama-3-3-70b-instruct"
+
+        result: ThresholdOptimizationResult = optimize_semantic_cache_thresholds(
+            dataset=eval_dataset,
+            original_thresholds=original_thresholds,
+            judge_model=judge_model_name,
+            n_trials=self.n_trials,
+            metric=self.metric,
+            beta=self.beta,
+            register_if_improved=True,
+            study_name=self.name,
+            seed=self.seed,
+        )
+
+        return result
+
+
 class DatasetFormat(str, Enum):
     CSV = "csv"
     DELTA = "delta"
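Note: a hedged end-to-end sketch of the new optimization models; the embedding vectors are tiny placeholders, and real entries need full pre-computed embeddings:

    entry = SemanticCacheEvalEntryModel(
        question="What are total sales?",
        question_embedding=[0.1, 0.2, 0.3],          # placeholder vector
        cached_question="Show total sales",
        cached_question_embedding=[0.1, 0.2, 0.31],  # placeholder vector
        expected_match=True,  # ground truth supplied, so no LLM judge needed
    )
    dataset = SemanticCacheEvalDatasetModel(name="my_cache_eval_dataset", entries=[entry])
    opt = SemanticCacheThresholdOptimizationModel(
        name="optimize_cache_thresholds",
        dataset=dataset,
        n_trials=50,
        metric="f1",
    )
    result = opt.optimize()  # ThresholdOptimizationResult with tuned thresholds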
@@ -3133,6 +4021,7 @@ class ResourcesModel(BaseModel):

 class AppConfig(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    version: Optional[str] = None
     variables: dict[str, AnyVariable] = Field(default_factory=dict)
     service_principals: dict[str, ServicePrincipalModel] = Field(default_factory=dict)
     schemas: dict[str, SchemaModel] = Field(default_factory=dict)
@@ -3153,6 +4042,9 @@ class AppConfig(BaseModel):
     )
     providers: Optional[dict[type | str, Any]] = None

+    # Private attribute to track the source config file path (set by from_file)
+    _source_config_path: str | None = None
+
     @classmethod
     def from_file(cls, path: PathLike) -> "AppConfig":
         path = Path(path).as_posix()
@@ -3160,12 +4052,20 @@ class AppConfig(BaseModel):
         model_config: ModelConfig = ModelConfig(development_config=path)
         config: AppConfig = AppConfig(**model_config.to_dict())

+        # Store the source config path for later use (e.g., Apps deployment)
+        config._source_config_path = path
+
         config.initialize()

         atexit.register(config.shutdown)

         return config

+    @property
+    def source_config_path(self) -> str | None:
+        """Get the source config file path if loaded via from_file."""
+        return self._source_config_path
+
     def initialize(self) -> None:
         from dao_ai.hooks.core import create_hooks
         from dao_ai.logging import configure_logging
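Note: from_file now records where the config came from, and the new property exposes it. A short usage sketch with an illustrative path:

    config = AppConfig.from_file("config/agent.yaml")
    config.source_config_path  # -> "config/agent.yaml", available for Apps deployment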
@@ -3236,6 +4136,7 @@ class AppConfig(BaseModel):

     def deploy_agent(
         self,
+        target: DeploymentTarget | None = None,
         w: WorkspaceClient | None = None,
         vsc: "VectorSearchClient | None" = None,
         pat: str | None = None,
@@ -3243,9 +4144,39 @@ class AppConfig(BaseModel):
         client_secret: str | None = None,
         workspace_host: str | None = None,
     ) -> None:
+        """
+        Deploy the agent to the specified target.
+
+        Target resolution follows this priority:
+        1. Explicit `target` parameter (if provided)
+        2. `app.deployment_target` from config file (if set)
+        3. Default: MODEL_SERVING
+
+        Args:
+            target: The deployment target (MODEL_SERVING or APPS). If None, uses
+                config.app.deployment_target or defaults to MODEL_SERVING.
+            w: Optional WorkspaceClient instance
+            vsc: Optional VectorSearchClient instance
+            pat: Optional personal access token for authentication
+            client_id: Optional client ID for service principal authentication
+            client_secret: Optional client secret for service principal authentication
+            workspace_host: Optional workspace host URL
+        """
         from dao_ai.providers.base import ServiceProvider
         from dao_ai.providers.databricks import DatabricksProvider

+        # Resolve target using hybrid logic:
+        # 1. Explicit parameter takes precedence
+        # 2. Fall back to config.app.deployment_target
+        # 3. Default to MODEL_SERVING
+        resolved_target: DeploymentTarget
+        if target is not None:
+            resolved_target = target
+        elif self.app is not None and self.app.deployment_target is not None:
+            resolved_target = self.app.deployment_target
+        else:
+            resolved_target = DeploymentTarget.MODEL_SERVING
+
         provider: ServiceProvider = DatabricksProvider(
             w=w,
             vsc=vsc,
@@ -3254,7 +4185,7 @@ class AppConfig(BaseModel):
             client_secret=client_secret,
             workspace_host=workspace_host,
         )
-        provider.deploy_agent(self)
+        provider.deploy_agent(self, target=resolved_target)

     def find_agents(
         self, predicate: Callable[[AgentModel], bool] | None = None