dao-ai 0.1.2__py3-none-any.whl → 0.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/apps/__init__.py +24 -0
- dao_ai/apps/handlers.py +105 -0
- dao_ai/apps/model_serving.py +29 -0
- dao_ai/apps/resources.py +1122 -0
- dao_ai/apps/server.py +39 -0
- dao_ai/cli.py +546 -37
- dao_ai/config.py +1179 -139
- dao_ai/evaluation.py +543 -0
- dao_ai/genie/__init__.py +55 -7
- dao_ai/genie/cache/__init__.py +34 -7
- dao_ai/genie/cache/base.py +143 -2
- dao_ai/genie/cache/context_aware/__init__.py +31 -0
- dao_ai/genie/cache/context_aware/base.py +1151 -0
- dao_ai/genie/cache/context_aware/in_memory.py +609 -0
- dao_ai/genie/cache/context_aware/persistent.py +802 -0
- dao_ai/genie/cache/context_aware/postgres.py +1166 -0
- dao_ai/genie/cache/core.py +1 -1
- dao_ai/genie/cache/lru.py +257 -75
- dao_ai/genie/cache/optimization.py +890 -0
- dao_ai/genie/core.py +235 -11
- dao_ai/memory/postgres.py +175 -39
- dao_ai/middleware/__init__.py +38 -0
- dao_ai/middleware/assertions.py +3 -3
- dao_ai/middleware/context_editing.py +230 -0
- dao_ai/middleware/core.py +4 -4
- dao_ai/middleware/guardrails.py +3 -3
- dao_ai/middleware/human_in_the_loop.py +3 -2
- dao_ai/middleware/message_validation.py +4 -4
- dao_ai/middleware/model_call_limit.py +77 -0
- dao_ai/middleware/model_retry.py +121 -0
- dao_ai/middleware/pii.py +157 -0
- dao_ai/middleware/summarization.py +1 -1
- dao_ai/middleware/tool_call_limit.py +210 -0
- dao_ai/middleware/tool_retry.py +174 -0
- dao_ai/middleware/tool_selector.py +129 -0
- dao_ai/models.py +327 -370
- dao_ai/nodes.py +9 -16
- dao_ai/orchestration/core.py +33 -9
- dao_ai/orchestration/supervisor.py +29 -13
- dao_ai/orchestration/swarm.py +6 -1
- dao_ai/{prompts.py → prompts/__init__.py} +12 -61
- dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
- dao_ai/prompts/instruction_reranker.yaml +14 -0
- dao_ai/prompts/router.yaml +37 -0
- dao_ai/prompts/verifier.yaml +46 -0
- dao_ai/providers/base.py +28 -2
- dao_ai/providers/databricks.py +363 -33
- dao_ai/state.py +1 -0
- dao_ai/tools/__init__.py +5 -3
- dao_ai/tools/genie.py +103 -26
- dao_ai/tools/instructed_retriever.py +366 -0
- dao_ai/tools/instruction_reranker.py +202 -0
- dao_ai/tools/mcp.py +539 -97
- dao_ai/tools/router.py +89 -0
- dao_ai/tools/slack.py +13 -2
- dao_ai/tools/sql.py +7 -3
- dao_ai/tools/unity_catalog.py +32 -10
- dao_ai/tools/vector_search.py +493 -160
- dao_ai/tools/verifier.py +159 -0
- dao_ai/utils.py +182 -2
- dao_ai/vector_search.py +46 -1
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/METADATA +45 -9
- dao_ai-0.1.20.dist-info/RECORD +89 -0
- dao_ai/agent_as_code.py +0 -22
- dao_ai/genie/cache/semantic.py +0 -970
- dao_ai-0.1.2.dist-info/RECORD +0 -64
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py
CHANGED
@@ -7,6 +7,7 @@ from enum import Enum
 from os import PathLike
 from pathlib import Path
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Iterator,
@@ -18,12 +19,20 @@ from typing import (
     Union,
 )

+if TYPE_CHECKING:
+    from dao_ai.genie.cache.optimization import (
+        SemanticCacheEvalDataset,
+        ThresholdOptimizationResult,
+    )
+    from dao_ai.state import Context
+
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.credentials_provider import (
     CredentialsStrategy,
     ModelServingUserCredentials,
 )
 from databricks.sdk.errors.platform import NotFound
+from databricks.sdk.service.apps import App
 from databricks.sdk.service.catalog import FunctionInfo, TableInfo
 from databricks.sdk.service.dashboards import GenieSpace
 from databricks.sdk.service.database import DatabaseInstance
@@ -147,7 +156,7 @@ class PrimitiveVariableModel(BaseModel, HasValue):
         return str(value)

     @model_validator(mode="after")
-    def validate_value(self) ->
+    def validate_value(self) -> Self:
         if not isinstance(self.as_value(), (str, int, float, bool)):
             raise ValueError("Value must be a primitive type (str, int, float, bool)")
         return self
@@ -207,7 +216,9 @@ class IsDatabricksResource(ABC, BaseModel):
     Authentication Options:
     ----------------------
     1. **On-Behalf-Of User (OBO)**: Set on_behalf_of_user=True to use the
-       calling user's identity
+       calling user's identity. Implementation varies by deployment:
+       - Databricks Apps: Uses X-Forwarded-Access-Token from request headers
+       - Model Serving: Uses ModelServingUserCredentials

     2. **Service Principal (OAuth M2M)**: Provide service_principal or
        (client_id + client_secret + workspace_host) for service principal auth.
@@ -220,9 +231,17 @@ class IsDatabricksResource(ABC, BaseModel):

     Authentication Priority:
     1. OBO (on_behalf_of_user=True)
+       - Checks for forwarded headers (Databricks Apps)
+       - Falls back to ModelServingUserCredentials (Model Serving)
     2. Service Principal (client_id + client_secret + workspace_host)
     3. PAT (pat + workspace_host)
     4. Ambient/default authentication
+
+    Note: When on_behalf_of_user=True, the agent acts as the calling user regardless
+    of deployment target. In Databricks Apps, this uses X-Forwarded-Access-Token
+    automatically captured by MLflow AgentServer. In Model Serving, this uses
+    ModelServingUserCredentials. Forwarded headers are ONLY used when
+    on_behalf_of_user=True.
     """

     model_config = ConfigDict(use_enum_values=True)
@@ -234,9 +253,6 @@ class IsDatabricksResource(ABC, BaseModel):
     workspace_host: Optional[AnyVariable] = None
     pat: Optional[AnyVariable] = None

-    # Private attribute to cache the workspace client (lazy instantiation)
-    _workspace_client: Optional[WorkspaceClient] = PrivateAttr(default=None)
-
     @abstractmethod
     def as_resources(self) -> Sequence[DatabricksResource]: ...

@@ -272,19 +288,16 @@ class IsDatabricksResource(ABC, BaseModel):
         """
         Get a WorkspaceClient configured with the appropriate authentication.

-
+        A new client is created on each access.

         Authentication priority:
-        1.
-
-
-
-
+        1. On-Behalf-Of User (on_behalf_of_user=True):
+           - Uses ModelServingUserCredentials (Model Serving)
+           - For Databricks Apps with headers, use workspace_client_from(context)
+        2. Service Principal (client_id + client_secret + workspace_host)
+        3. PAT (pat + workspace_host)
+        4. Ambient/default authentication
         """
-        # Return cached client if already instantiated
-        if self._workspace_client is not None:
-            return self._workspace_client
-
         from dao_ai.utils import normalize_host

         # Check for OBO first (highest priority)
@@ -292,12 +305,9 @@ class IsDatabricksResource(ABC, BaseModel):
             credentials_strategy: CredentialsStrategy = ModelServingUserCredentials()
             logger.debug(
                 f"Creating WorkspaceClient for {self.__class__.__name__} "
-                f"with OBO credentials strategy"
-            )
-            self._workspace_client = WorkspaceClient(
-                credentials_strategy=credentials_strategy
+                f"with OBO credentials strategy (Model Serving)"
             )
-            return
+            return WorkspaceClient(credentials_strategy=credentials_strategy)

         # Check for service principal credentials
         client_id_value: str | None = (
@@ -312,18 +322,24 @@ class IsDatabricksResource(ABC, BaseModel):
             else None
         )

-        if client_id_value and client_secret_value
+        if client_id_value and client_secret_value:
+            # If workspace_host is not provided, check DATABRICKS_HOST env var first,
+            # then fall back to WorkspaceClient().config.host
+            if not workspace_host_value:
+                workspace_host_value = os.getenv("DATABRICKS_HOST")
+            if not workspace_host_value:
+                workspace_host_value = WorkspaceClient().config.host
+
             logger.debug(
                 f"Creating WorkspaceClient for {self.__class__.__name__} with service principal: "
                 f"client_id={client_id_value}, host={workspace_host_value}"
            )
-            self._workspace_client = WorkspaceClient(
+            return WorkspaceClient(
                 host=workspace_host_value,
                 client_id=client_id_value,
                 client_secret=client_secret_value,
                 auth_type="oauth-m2m",
             )
-            return self._workspace_client

         # Check for PAT authentication
         pat_value: str | None = value_of(self.pat) if self.pat else None
@@ -331,20 +347,83 @@ class IsDatabricksResource(ABC, BaseModel):
             logger.debug(
                 f"Creating WorkspaceClient for {self.__class__.__name__} with PAT"
             )
-            self._workspace_client = WorkspaceClient(
+            return WorkspaceClient(
                 host=workspace_host_value,
                 token=pat_value,
                 auth_type="pat",
             )
-            return self._workspace_client

         # Default: use ambient authentication
         logger.debug(
             f"Creating WorkspaceClient for {self.__class__.__name__} "
             "with default/ambient authentication"
         )
-        self._workspace_client = WorkspaceClient()
-        return self._workspace_client
+        return WorkspaceClient()
+
+    def workspace_client_from(self, context: "Context | None") -> WorkspaceClient:
+        """
+        Get a WorkspaceClient using headers from the provided Context.
+
+        Use this method from tools that have access to ToolRuntime[Context].
+        This allows OBO authentication to work in Databricks Apps where headers
+        are captured at request entry and passed through the Context.
+
+        Args:
+            context: Runtime context containing headers for OBO auth.
+                If None or no headers, falls back to workspace_client property.
+
+        Returns:
+            WorkspaceClient configured with appropriate authentication.
+        """
+        from dao_ai.utils import normalize_host
+
+        logger.trace(
+            "workspace_client_from called",
+            context=context,
+            on_behalf_of_user=self.on_behalf_of_user,
+        )
+
+        # Check if we have headers in context for OBO
+        if context and context.headers and self.on_behalf_of_user:
+            headers = context.headers
+            # Try both lowercase and title-case header names (HTTP headers are case-insensitive)
+            forwarded_token: str = headers.get(
+                "x-forwarded-access-token"
+            ) or headers.get("X-Forwarded-Access-Token")
+
+            if forwarded_token:
+                forwarded_user = headers.get("x-forwarded-user") or headers.get(
+                    "X-Forwarded-User", "unknown"
+                )
+                logger.debug(
+                    f"Creating WorkspaceClient for {self.__class__.__name__} "
+                    f"with OBO using forwarded token from Context",
+                    forwarded_user=forwarded_user,
+                )
+                # Use workspace_host if configured, otherwise SDK will auto-detect
+                workspace_host_value: str | None = (
+                    normalize_host(value_of(self.workspace_host))
+                    if self.workspace_host
+                    else None
+                )
+                return WorkspaceClient(
+                    host=workspace_host_value,
+                    token=forwarded_token,
+                    auth_type="pat",
+                )
+
+        # Fall back to existing workspace_client property
+        return self.workspace_client
+
+
+class DeploymentTarget(str, Enum):
+    """Target platform for agent deployment."""
+
+    MODEL_SERVING = "model_serving"
+    """Deploy to Databricks Model Serving endpoint."""
+
+    APPS = "apps"
+    """Deploy as a Databricks App."""


 class Privilege(str, Enum):
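The new `workspace_client_from(context)` method above is the Apps-side counterpart of the `workspace_client` property: OBO in Model Serving keeps using `ModelServingUserCredentials`, while in Databricks Apps the forwarded token travels through the runtime `Context`. A minimal sketch of how a tool might use it, assuming a runtime object that exposes the `Context` (the tool wiring here is illustrative, not dao_ai's documented API):

```python
# Hypothetical usage sketch; GenieRoomModel, Context, and workspace_client_from
# come from this diff, the runtime object and space_id value are assumptions.
from dao_ai.config import GenieRoomModel
from dao_ai.state import Context

genie = GenieRoomModel(space_id="01ef0123", on_behalf_of_user=True)

def whoami_tool(runtime) -> str:
    context: Context = runtime.context  # headers captured at request entry
    client = genie.workspace_client_from(context)  # OBO via X-Forwarded-Access-Token
    return client.current_user.me().user_name  # acts as the calling user
```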
@@ -391,10 +470,17 @@ class PermissionModel(BaseModel):

 class SchemaModel(BaseModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    catalog_name:
-    schema_name:
+    catalog_name: AnyVariable
+    schema_name: AnyVariable
     permissions: Optional[list[PermissionModel]] = Field(default_factory=list)

+    @model_validator(mode="after")
+    def resolve_variables(self) -> Self:
+        """Resolve AnyVariable fields to their actual string values."""
+        self.catalog_name = value_of(self.catalog_name)
+        self.schema_name = value_of(self.schema_name)
+        return self
+
     @property
     def full_name(self) -> str:
         return f"{self.catalog_name}.{self.schema_name}"
@@ -408,9 +494,44 @@ class SchemaModel(BaseModel, HasFullName):


 class DatabricksAppModel(IsDatabricksResource, HasFullName):
+    """
+    Configuration for a Databricks App resource.
+
+    The `name` is the unique instance name of the Databricks App within the workspace.
+    The `url` is dynamically retrieved from the workspace client by calling
+    `apps.get(name)` and returning the app's URL.
+
+    Example:
+        ```yaml
+        resources:
+          apps:
+            my_app:
+              name: my-databricks-app
+        ```
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
-
+    """The unique instance name of the Databricks App in the workspace."""
+
+    @property
+    def url(self) -> str:
+        """
+        Retrieve the URL of the Databricks App from the workspace.
+
+        Returns:
+            The URL of the deployed Databricks App.
+
+        Raises:
+            RuntimeError: If the app is not found or URL is not available.
+        """
+        app: App = self.workspace_client.apps.get(self.name)
+        if app.url is None:
+            raise RuntimeError(
+                f"Databricks App '{self.name}' does not have a URL. "
+                "The app may not be deployed yet."
+            )
+        return app.url

     @property
     def full_name(self) -> str:
@@ -432,7 +553,7 @@ class TableModel(IsDatabricksResource, HasFullName):
     name: Optional[str] = None

     @model_validator(mode="after")
-    def validate_name_or_schema_required(self) ->
+    def validate_name_or_schema_required(self) -> Self:
         if not self.name and not self.schema_model:
             raise ValueError(
                 "Either 'name' or 'schema_model' must be provided for TableModel"
@@ -601,6 +722,8 @@ class VectorSearchEndpoint(BaseModel):


 class IndexModel(IsDatabricksResource, HasFullName):
+    """Model representing a Databricks Vector Search index."""
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: str
@@ -624,6 +747,22 @@ class IndexModel(IsDatabricksResource, HasFullName):
         )
     ]

+    def exists(self) -> bool:
+        """Check if this vector search index exists.
+
+        Returns:
+            True if the index exists, False otherwise.
+        """
+        try:
+            self.workspace_client.vector_search_indexes.get_index(self.full_name)
+            return True
+        except NotFound:
+            logger.debug(f"Index not found: {self.full_name}")
+            return False
+        except Exception as e:
+            logger.warning(f"Error checking index existence for {self.full_name}: {e}")
+            return False
+

 class FunctionModel(IsDatabricksResource, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -699,11 +838,20 @@ class FunctionModel(IsDatabricksResource, HasFullName):


 class WarehouseModel(IsDatabricksResource):
-    model_config = ConfigDict()
-    name: str
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: Optional[str] = None
     description: Optional[str] = None
     warehouse_id: AnyVariable

+    _warehouse_details: Optional[GetWarehouseResponse] = PrivateAttr(default=None)
+
+    def _get_warehouse_details(self) -> GetWarehouseResponse:
+        if self._warehouse_details is None:
+            self._warehouse_details = self.workspace_client.warehouses.get(
+                id=value_of(self.warehouse_id)
+            )
+        return self._warehouse_details
+
     @property
     def api_scopes(self) -> Sequence[str]:
         return [
@@ -724,10 +872,22 @@ class WarehouseModel(IsDatabricksResource):
         self.warehouse_id = value_of(self.warehouse_id)
         return self

+    @model_validator(mode="after")
+    def populate_name(self) -> Self:
+        """Populate name from warehouse details if not provided."""
+        if self.warehouse_id and not self.name:
+            try:
+                warehouse_details = self._get_warehouse_details()
+                if warehouse_details.name:
+                    self.name = warehouse_details.name
+            except Exception as e:
+                logger.debug(f"Could not fetch details from warehouse: {e}")
+        return self
+

 class GenieRoomModel(IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    name: str
+    name: Optional[str] = None
     description: Optional[str] = None
     space_id: AnyVariable

@@ -783,10 +943,6 @@ class GenieRoomModel(IsDatabricksResource):
                 pat=self.pat,
             )

-            # Share the cached workspace client if available
-            if self._workspace_client is not None:
-                warehouse_model._workspace_client = self._workspace_client
-
             return warehouse_model
         except Exception as e:
             logger.warning(
@@ -830,9 +986,6 @@ class GenieRoomModel(IsDatabricksResource):
                 workspace_host=self.workspace_host,
                 pat=self.pat,
             )
-            # Share the cached workspace client if available
-            if self._workspace_client is not None:
-                table_model._workspace_client = self._workspace_client

             # Verify the table exists before adding
             if not table_model.exists():
@@ -870,9 +1023,6 @@ class GenieRoomModel(IsDatabricksResource):
                 workspace_host=self.workspace_host,
                 pat=self.pat,
             )
-            # Share the cached workspace client if available
-            if self._workspace_client is not None:
-                function_model._workspace_client = self._workspace_client

             # Verify the function exists before adding
             if not function_model.exists():
@@ -936,15 +1086,17 @@ class GenieRoomModel(IsDatabricksResource):
         return self

     @model_validator(mode="after")
-    def
-        """Populate description from GenieSpace if not provided."""
-        if not self.description:
+    def populate_name_and_description(self) -> Self:
+        """Populate name and description from GenieSpace if not provided."""
+        if self.space_id and (not self.name or not self.description):
             try:
                 space_details = self._get_space_details()
-                if space_details.
+                if not self.name and space_details.title:
+                    self.name = space_details.title
+                if not self.description and space_details.description:
                     self.description = space_details.description
             except Exception as e:
-                logger.debug(f"Could not fetch
+                logger.debug(f"Could not fetch details from Genie space: {e}")
         return self


@@ -980,7 +1132,7 @@ class VolumePathModel(BaseModel, HasFullName):
     path: Optional[str] = None

     @model_validator(mode="after")
-    def validate_path_or_volume(self) ->
+    def validate_path_or_volume(self) -> Self:
         if not self.volume and not self.path:
             raise ValueError("Either 'volume' or 'path' must be provided")
         return self
@@ -1009,27 +1161,92 @@ class VolumePathModel(BaseModel, HasFullName):


 class VectorStoreModel(IsDatabricksResource):
+    """
+    Configuration model for a Databricks Vector Search store.
+
+    Supports two modes:
+    1. **Use Existing Index**: Provide only `index` (fully qualified name).
+       Used for querying an existing vector search index at runtime.
+    2. **Provisioning Mode**: Provide `source_table` + `embedding_source_column`.
+       Used for creating a new vector search index.
+
+    Examples:
+        Minimal configuration (use existing index):
+        ```yaml
+        vector_stores:
+          products_search:
+            index:
+              name: catalog.schema.my_index
+        ```
+
+        Full provisioning configuration:
+        ```yaml
+        vector_stores:
+          products_search:
+            source_table:
+              schema: *my_schema
+              name: products
+            embedding_source_column: description
+            endpoint:
+              name: my_endpoint
+        ```
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-
+
+    # RUNTIME: Only index is truly required for querying existing indexes
     index: Optional[IndexModel] = None
+
+    # PROVISIONING ONLY: Required when creating a new index
+    source_table: Optional[TableModel] = None
+    embedding_source_column: Optional[str] = None
+    embedding_model: Optional[LLMModel] = None
     endpoint: Optional[VectorSearchEndpoint] = None
-
+
+    # OPTIONAL: For both modes
     source_path: Optional[VolumePathModel] = None
     checkpoint_path: Optional[VolumePathModel] = None
     primary_key: Optional[str] = None
     columns: Optional[list[str]] = Field(default_factory=list)
     doc_uri: Optional[str] = None
-
+
+    @model_validator(mode="after")
+    def validate_configuration_mode(self) -> Self:
+        """
+        Validate that configuration is valid for either:
+        - Use existing mode: index is provided
+        - Provisioning mode: source_table + embedding_source_column provided
+        """
+        has_index = self.index is not None
+        has_source_table = self.source_table is not None
+        has_embedding_col = self.embedding_source_column is not None
+
+        # Must have at least index OR source_table
+        if not has_index and not has_source_table:
+            raise ValueError(
+                "Either 'index' (for existing indexes) or 'source_table' "
+                "(for provisioning) must be provided"
+            )
+
+        # If provisioning mode, need embedding_source_column
+        if has_source_table and not has_embedding_col:
+            raise ValueError(
+                "embedding_source_column is required when source_table is provided (provisioning mode)"
+            )
+
+        return self

     @model_validator(mode="after")
     def set_default_embedding_model(self) -> Self:
-
+        # Only set default embedding model in provisioning mode
+        if self.source_table is not None and not self.embedding_model:
             self.embedding_model = LLMModel(name="databricks-gte-large-en")
         return self

     @model_validator(mode="after")
     def set_default_primary_key(self) -> Self:
-
+        # Only auto-discover primary key in provisioning mode
+        if self.primary_key is None and self.source_table is not None:
             from dao_ai.providers.databricks import DatabricksProvider

             provider: DatabricksProvider = DatabricksProvider()
@@ -1050,14 +1267,16 @@ class VectorStoreModel(IsDatabricksResource):

     @model_validator(mode="after")
     def set_default_index(self) -> Self:
-
+        # Only generate index from source_table in provisioning mode
+        if self.index is None and self.source_table is not None:
             name: str = f"{self.source_table.name}_index"
             self.index = IndexModel(schema=self.source_table.schema_model, name=name)
         return self

     @model_validator(mode="after")
     def set_default_endpoint(self) -> Self:
-
+        # Only find/create endpoint in provisioning mode
+        if self.endpoint is None and self.source_table is not None:
             from dao_ai.providers.databricks import (
                 DatabricksProvider,
                 with_available_indexes,
@@ -1092,18 +1311,60 @@ class VectorStoreModel(IsDatabricksResource):
         return self.index.as_resources()

     def as_index(self, vsc: VectorSearchClient | None = None) -> VectorSearchIndex:
-        from dao_ai.providers.base import ServiceProvider
         from dao_ai.providers.databricks import DatabricksProvider

-        provider:
+        provider: DatabricksProvider = DatabricksProvider(vsc=vsc)
         index: VectorSearchIndex = provider.get_vector_index(self)
         return index

     def create(self, vsc: VectorSearchClient | None = None) -> None:
-
+        """
+        Create or validate the vector search index.
+
+        Behavior depends on configuration mode:
+        - **Provisioning Mode** (source_table provided): Creates the index
+        - **Use Existing Mode** (only index provided): Validates the index exists
+
+        Args:
+            vsc: Optional VectorSearchClient instance
+
+        Raises:
+            ValueError: If configuration is invalid or index doesn't exist
+        """
         from dao_ai.providers.databricks import DatabricksProvider

-        provider:
+        provider: DatabricksProvider = DatabricksProvider(vsc=vsc)
+
+        if self.source_table is not None:
+            self._create_new_index(provider)
+        else:
+            self._validate_existing_index(provider)
+
+    def _validate_existing_index(self, provider: Any) -> None:
+        """Validate that an existing index is accessible."""
+        if self.index is None:
+            raise ValueError("index is required for 'use existing' mode")
+
+        if self.index.exists():
+            logger.info(
+                "Vector search index exists and ready",
+                index_name=self.index.full_name,
+            )
+        else:
+            raise ValueError(
+                f"Index '{self.index.full_name}' does not exist. "
+                "Provide 'source_table' to provision it."
+            )
+
+    def _create_new_index(self, provider: Any) -> None:
+        """Create a new vector search index from source table."""
+        if self.embedding_source_column is None:
+            raise ValueError("embedding_source_column is required for provisioning")
+        if self.endpoint is None:
+            raise ValueError("endpoint is required for provisioning")
+        if self.index is None:
+            raise ValueError("index is required for provisioning")
+
         provider.create_vector_store(self)

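With the `create()` split above, the same config object either provisions an index or merely validates one. A rough sketch of the two call patterns, using only names from this diff (the catalog, schema, and index values are placeholders):

```python
# Sketch of VectorStoreModel's two modes, per the validators in this diff.
from dao_ai.config import IndexModel, VectorStoreModel

# Use-existing mode: only `index` is set; create() validates that it exists
# and raises ValueError if it does not.
store = VectorStoreModel(index=IndexModel(name="catalog.schema.my_index"))
store.create()

# Provisioning mode: supplying source_table + embedding_source_column instead
# makes create() build the index; endpoint/index defaults are filled in by the
# set_default_* validators.
```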
@@ -1145,13 +1406,20 @@ class DatabaseModel(IsDatabricksResource):
     - Databricks Lakebase: Provide `instance_name` (authentication optional, supports ambient auth)
     - Standard PostgreSQL: Provide `host` (authentication required via user/password)

-    Note:
+    Note: For Lakebase connections, `name` is optional and defaults to `instance_name`.
+    For PostgreSQL connections, `name` is required.
+
+    Example Databricks Lakebase (minimal):
+    ```yaml
+    databases:
+      my_lakebase:
+        instance_name: my-lakebase-instance  # name defaults to instance_name
+    ```

     Example Databricks Lakebase with Service Principal:
     ```yaml
     databases:
       my_lakebase:
-        name: my-database
         instance_name: my-lakebase-instance
         service_principal:
           client_id:
@@ -1167,7 +1435,6 @@ class DatabaseModel(IsDatabricksResource):
     ```yaml
     databases:
       my_lakebase:
-        name: my-database
         instance_name: my-lakebase-instance
         on_behalf_of_user: true
     ```
@@ -1187,7 +1454,7 @@ class DatabaseModel(IsDatabricksResource):
     """

     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    name: str
+    name: Optional[str] = None
     instance_name: Optional[str] = None
     description: Optional[str] = None
     host: Optional[AnyVariable] = None
@@ -1236,6 +1503,17 @@ class DatabaseModel(IsDatabricksResource):
         )
         return self

+    @model_validator(mode="after")
+    def populate_name_from_instance_name(self) -> Self:
+        """Populate name from instance_name if not provided for Lakebase connections."""
+        if self.name is None and self.instance_name:
+            self.name = self.instance_name
+        elif self.name is None:
+            raise ValueError(
+                "Either 'name' or 'instance_name' must be provided for DatabaseModel."
+            )
+        return self
+
     @model_validator(mode="after")
     def update_user(self) -> Self:
         # Skip if using OBO (passive auth), explicit credentials, or explicit user
@@ -1266,32 +1544,12 @@ class DatabaseModel(IsDatabricksResource):

     @model_validator(mode="after")
     def update_host(self) -> Self:
-
+        # Lakebase uses instance_name directly via databricks_langchain - host not needed
+        if self.is_lakebase:
             return self

-        #
-        #
-        if self.is_lakebase:
-            try:
-                existing_instance: DatabaseInstance = (
-                    self.workspace_client.database.get_database_instance(
-                        name=self.instance_name
-                    )
-                )
-                self.host = existing_instance.read_write_dns
-            except Exception as e:
-                # For Lakebase with OBO/ambient auth, we can't fetch at config time
-                # The host will need to be provided explicitly or fetched at runtime
-                if self.on_behalf_of_user:
-                    logger.debug(
-                        f"Could not fetch host for database {self.instance_name} "
-                        f"(Lakebase with OBO mode - will be resolved at runtime): {e}"
-                    )
-                else:
-                    raise ValueError(
-                        f"Could not fetch host for database {self.instance_name}. "
-                        f"Please provide the 'host' explicitly or ensure the instance exists: {e}"
-                    )
+        # For standard PostgreSQL, host must be provided by the user
+        # (enforced by validate_connection_type)
         return self
@@ -1353,10 +1611,10 @@ class DatabaseModel(IsDatabricksResource):
         username: str | None = None
         password_value: str | None = None

-        # Resolve host -
+        # Resolve host - fetch from API at runtime for Lakebase if not provided
         host_value: Any = self.host
-        if host_value is None and self.is_lakebase
-            # Fetch host
+        if host_value is None and self.is_lakebase:
+            # Fetch host from Lakebase instance API
             existing_instance: DatabaseInstance = (
                 self.workspace_client.database.get_database_instance(
                     name=self.instance_name
@@ -1456,7 +1714,7 @@ class GenieLRUCacheParametersModel(BaseModel):
     warehouse: WarehouseModel


-class
+class GenieContextAwareCacheParametersModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     time_to_live_seconds: int | None = (
         60 * 60 * 24
@@ -1474,6 +1732,116 @@ class GenieSemanticCacheParametersModel(BaseModel):
     database: DatabaseModel
     warehouse: WarehouseModel
     table_name: str = "genie_semantic_cache"
+    context_window_size: int = 2  # Number of previous turns to include for context
+    max_context_tokens: int = (
+        2000  # Maximum context length to prevent extremely long embeddings
+    )
+    # Prompt history configuration
+    # Prompt history is always enabled - it stores all user prompts to maintain
+    # conversation context for accurate semantic matching even when cache hits occur
+    prompt_history_table: str = "genie_prompt_history"  # Table name for prompt history
+    max_prompt_history_length: int = 50  # Maximum prompts to keep per conversation
+    use_genie_api_for_history: bool = (
+        False  # Fallback to Genie API if local history empty
+    )
+    prompt_history_ttl_seconds: int | None = (
+        None  # TTL for prompts (None = use cache TTL)
+    )
+
+    @model_validator(mode="after")
+    def compute_and_validate_weights(self) -> Self:
+        """
+        Compute missing weight and validate that question_weight + context_weight = 1.0.
+
+        Either question_weight or context_weight (or both) can be provided.
+        The missing one will be computed as 1.0 - provided_weight.
+        If both are provided, they must sum to 1.0.
+        """
+        if self.question_weight is None and self.context_weight is None:
+            # Both missing - use defaults
+            self.question_weight = 0.6
+            self.context_weight = 0.4
+        elif self.question_weight is None:
+            # Compute question_weight from context_weight
+            if not (0.0 <= self.context_weight <= 1.0):
+                raise ValueError(
+                    f"context_weight must be between 0.0 and 1.0, got {self.context_weight}"
+                )
+            self.question_weight = 1.0 - self.context_weight
+        elif self.context_weight is None:
+            # Compute context_weight from question_weight
+            if not (0.0 <= self.question_weight <= 1.0):
+                raise ValueError(
+                    f"question_weight must be between 0.0 and 1.0, got {self.question_weight}"
+                )
+            self.context_weight = 1.0 - self.question_weight
+        else:
+            # Both provided - validate they sum to 1.0
+            total_weight = self.question_weight + self.context_weight
+            if not abs(total_weight - 1.0) < 0.0001:  # Allow small floating point error
+                raise ValueError(
+                    f"question_weight ({self.question_weight}) + context_weight ({self.context_weight}) "
+                    f"must equal 1.0 (got {total_weight}). These weights determine the relative importance "
+                    f"of question vs context similarity in the combined score."
+                )
+
+        return self
+
+
+# Memory estimation for capacity planning:
+# - Each entry: ~20KB (8KB question embedding + 8KB context embedding + 4KB strings/overhead)
+# - 1,000 entries: ~20MB (0.4% of 8GB)
+# - 5,000 entries: ~100MB (2% of 8GB)
+# - 10,000 entries: ~200MB (4-5% of 8GB) - default for ~30 users
+# - 20,000 entries: ~400MB (8-10% of 8GB)
+# Default 10,000 entries provides ~330 queries per user for 30 users.
+class GenieInMemorySemanticCacheParametersModel(BaseModel):
+    """
+    Configuration for in-memory semantic cache (no database required).
+
+    This cache stores embeddings and cache entries entirely in memory, providing
+    semantic similarity matching without requiring external database dependencies
+    like PostgreSQL or Databricks Lakebase.
+
+    Default settings are tuned for ~30 users on an 8GB machine:
+    - Capacity: 10,000 entries (~200MB memory, ~330 queries per user)
+    - Eviction: LRU (Least Recently Used) - keeps frequently accessed queries
+    - TTL: 1 week (accommodates weekly work patterns and batch jobs)
+    - Memory overhead: ~4-5% of 8GB system
+
+    The LRU eviction strategy ensures hot queries stay cached while cold queries
+    are evicted, providing better hit rates than FIFO eviction.
+
+    For larger deployments or memory-constrained environments, adjust capacity and TTL accordingly.
+
+    Use this when:
+    - No external database access is available
+    - Single-instance deployments (cache not shared across instances)
+    - Cache persistence across restarts is not required
+    - Cache sizes are moderate (hundreds to low thousands of entries)
+
+    For multi-instance deployments or large cache sizes, use GenieContextAwareCacheParametersModel
+    with PostgreSQL backend instead.
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    time_to_live_seconds: int | None = (
+        60 * 60 * 24 * 7
+    )  # 1 week default (604800 seconds), None or negative = never expires
+    similarity_threshold: float = 0.85  # Minimum similarity for question matching (L2 distance converted to 0-1 scale)
+    context_similarity_threshold: float = 0.80  # Minimum similarity for context matching (L2 distance converted to 0-1 scale)
+    question_weight: Optional[float] = (
+        0.6  # Weight for question similarity in combined score (0-1). If not provided, computed as 1 - context_weight
+    )
+    context_weight: Optional[float] = (
+        None  # Weight for context similarity in combined score (0-1). If not provided, computed as 1 - question_weight
+    )
+    embedding_model: str | LLMModel = "databricks-gte-large-en"
+    embedding_dims: int | None = None  # Auto-detected if None
+    warehouse: WarehouseModel
+    capacity: int | None = (
+        10000  # Maximum cache entries. ~200MB for 10000 entries (1024-dim embeddings). LRU eviction when full. None = unlimited (not recommended for production).
+    )
     context_window_size: int = 3  # Number of previous turns to include for context
     max_context_tokens: int = (
         2000  # Maximum context length to prevent extremely long embeddings
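The weight validator above normalizes `question_weight` and `context_weight` so they always sum to 1.0; per the field comments, they blend question and context similarity into one combined score. A small standalone illustration of that semantics (plain Python, inferred from the comments in this diff rather than taken from the package):

```python
# Combined-score semantics implied by compute_and_validate_weights.
def combined_score(
    question_sim: float,
    context_sim: float,
    question_weight: float = 0.6,
    context_weight: float = 0.4,
) -> float:
    # The validator guarantees the weights sum to 1.0 (within float tolerance).
    assert abs(question_weight + context_weight - 1.0) < 0.0001
    return question_weight * question_sim + context_weight * context_sim

# Defaults (0.6 / 0.4): strong question match, weak context match
print(combined_score(0.95, 0.50))  # 0.77
# Supplying only context_weight=0.3 would make question_weight = 1 - 0.3 = 0.7
```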
@@ -1526,41 +1894,83 @@ class SearchParametersModel(BaseModel):
     query_type: Optional[str] = "ANN"


+class InstructionAwareRerankModel(BaseModel):
+    """
+    LLM-based reranking considering user instructions and constraints.
+
+    Use fast models (GPT-3.5, Haiku, Llama 3 8B) to minimize latency (~100ms).
+    Runs AFTER FlashRank as an additional constraint-aware reranking stage.
+    Skipped for 'standard' mode when auto_bypass=true in router config.
+
+    Example:
+        ```yaml
+        rerank:
+          model: ms-marco-MiniLM-L-12-v2
+          top_n: 20
+          instruction_aware:
+            model: *fast_llm
+            instructions: |
+              Prioritize results matching price and brand constraints.
+            top_n: 10
+        ```
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+
+    model: Optional["LLMModel"] = Field(
+        default=None,
+        description="LLM for instruction reranking (fast model recommended)",
+    )
+    instructions: Optional[str] = Field(
+        default=None,
+        description="Custom reranking instructions for constraint prioritization",
+    )
+    top_n: Optional[int] = Field(
+        default=None,
+        description="Number of documents to return after instruction reranking",
+    )
+
+
 class RerankParametersModel(BaseModel):
     """
-    Configuration for reranking retrieved documents
+    Configuration for reranking retrieved documents.

-
-
-
+    Supports three reranking options that can be combined:
+    1. FlashRank (local cross-encoder) - set `model`
+    2. Databricks server-side reranking - set `columns`
+    3. LLM instruction-aware reranking - set `instruction_aware`

-
-
-
-
+    Example with Databricks columns + instruction-aware (no FlashRank):
+    ```yaml
+    rerank:
+      columns:  # Databricks server-side reranking
+        - product_name
+        - brand_name
+      instruction_aware:  # LLM-based constraint reranking
+        model: *fast_llm
+        instructions: "Prioritize by brand preferences"
+        top_n: 10
+    ```

-    Example:
+    Example with FlashRank:
     ```yaml
-
-
-
-    rerank:
-      model: ms-marco-MiniLM-L-12-v2
-      top_n: 5  # Return top 5 after reranking
+    rerank:
+      model: ms-marco-MiniLM-L-12-v2  # FlashRank model
+      top_n: 10
     ```

-    Available models (
-    - "ms-marco-TinyBERT-L-2-v2" (
-    - "ms-marco-MiniLM-L-
-    - "
-    - "
+    Available FlashRank models (see https://github.com/PrithivirajDamodaran/FlashRank):
+    - "ms-marco-TinyBERT-L-2-v2" (~4MB, fastest)
+    - "ms-marco-MiniLM-L-12-v2" (~34MB, best cross-encoder)
+    - "rank-T5-flan" (~110MB, best non cross-encoder)
+    - "ms-marco-MultiBERT-L-12" (~150MB, multilingual 100+ languages)
     """

     model_config = ConfigDict(use_enum_values=True, extra="forbid")

-    model: str = Field(
-        default=
-        description="FlashRank model name.
+    model: Optional[str] = Field(
+        default=None,
+        description="FlashRank model name. If None, FlashRank is not used (use columns for Databricks reranking).",
     )
     top_n: Optional[int] = Field(
         default=None,
@@ -1573,6 +1983,289 @@ class RerankParametersModel(BaseModel):
|
|
|
1573
1983
|
columns: Optional[list[str]] = Field(
|
|
1574
1984
|
default_factory=list, description="Columns to rerank using DatabricksReranker"
|
|
1575
1985
|
)
|
|
1986
|
+
instruction_aware: Optional[InstructionAwareRerankModel] = Field(
|
|
1987
|
+
default=None,
|
|
1988
|
+
description="Optional LLM-based reranking stage after FlashRank",
|
|
1989
|
+
)
|
|
1990
|
+
|
|
1991
|
+
|
|
1992
|
+
class FilterItem(BaseModel):
|
|
1993
|
+
"""A metadata filter for vector search.
|
|
1994
|
+
|
|
1995
|
+
Filters constrain search results by matching column values.
|
|
1996
|
+
Use column names from the provided schema description.
|
|
1997
|
+
"""
|
|
1998
|
+
|
|
1999
|
+
model_config = ConfigDict(extra="forbid")
|
|
2000
|
+
key: str = Field(
|
|
2001
|
+
description=(
|
|
2002
|
+
"Column name with optional operator suffix. "
|
|
2003
|
+
"Operators: (none) for equality, NOT for exclusion, "
|
|
2004
|
+
"< <= > >= for numeric comparison, "
|
|
2005
|
+
"LIKE for token match, NOT LIKE to exclude tokens."
|
|
2006
|
+
)
|
|
2007
|
+
)
|
|
2008
|
+
value: Union[str, int, float, bool, list[Union[str, int, float, bool]]] = Field(
|
|
2009
|
+
description=(
|
|
2010
|
+
"The filter value matching the column type. "
|
|
2011
|
+
"Use an array for IN-style matching multiple values."
|
|
2012
|
+
)
|
|
2013
|
+
)
|
|
2014
|
+
|
|
2015
|
+
|
|
2016
|
+
class SearchQuery(BaseModel):
|
|
2017
|
+
"""A single search query with optional metadata filters.
|
|
2018
|
+
|
|
2019
|
+
Represents one focused search intent extracted from the user's request.
|
|
2020
|
+
The text should be a natural language query optimized for semantic search.
|
|
2021
|
+
Filters constrain results to match specific metadata values.
|
|
2022
|
+
"""
|
|
2023
|
+
|
|
2024
|
+
model_config = ConfigDict(extra="forbid")
|
|
2025
|
+
text: str = Field(
|
|
2026
|
+
description=(
|
|
2027
|
+
"Natural language search query text optimized for semantic similarity. "
|
|
2028
|
+
"Should be focused on a single search intent. "
|
|
2029
|
+
"Do NOT include filter criteria in the text; use the filters field instead."
|
|
2030
|
+
)
|
|
2031
|
+
)
|
|
2032
|
+
filters: Optional[list[FilterItem]] = Field(
|
|
2033
|
+
default=None,
|
|
2034
|
+
description=(
|
|
2035
|
+
"Metadata filters to constrain search results. "
|
|
2036
|
+
"Set to null if no filters apply. "
|
|
2037
|
+
"Extract filter values from explicit constraints in the user query."
|
|
2038
|
+
),
|
|
2039
|
+
)
|
|
2040
|
+
|
|
2041
|
+
|
|
2042
|
+
class DecomposedQueries(BaseModel):
|
|
2043
|
+
"""Decomposed search queries extracted from a user request.
|
|
2044
|
+
|
|
2045
|
+
Break down complex user queries into multiple focused search queries.
|
|
2046
|
+
Each query targets a distinct search intent with appropriate filters.
|
|
2047
|
+
Generate 1-3 queries depending on the complexity of the user request.
|
|
2048
|
+
"""
|
|
2049
|
+
|
|
2050
|
+
model_config = ConfigDict(extra="forbid")
|
|
2051
|
+
queries: list[SearchQuery] = Field(
|
|
2052
|
+
description=(
|
|
2053
|
+
"List of search queries extracted from the user request. "
|
|
2054
|
+
"Each query should target a distinct search intent. "
|
|
2055
|
+
"Order queries by importance, with the most relevant first."
|
|
2056
|
+
)
|
|
2057
|
+
)
|
|
2058
|
+
|
|
2059
|
+
|
|
2060
|
+
class ColumnInfo(BaseModel):
|
|
2061
|
+
"""Column metadata for dynamic schema generation in structured output.
|
|
2062
|
+
|
|
2063
|
+
When provided, column information is embedded directly into the JSON schema
|
|
2064
|
+
that with_structured_output sends to the LLM, improving filter accuracy.
|
|
2065
|
+
"""
|
|
2066
|
+
|
|
2067
|
+
model_config = ConfigDict(extra="forbid")
|
|
2068
|
+
|
|
2069
|
+
name: str = Field(description="Column name as it appears in the database")
|
|
2070
|
+
type: Literal["string", "number", "boolean", "datetime"] = Field(
|
|
2071
|
+
default="string",
|
|
2072
|
+
description="Column data type for value validation",
|
|
2073
|
+
)
|
|
2074
|
+
operators: list[str] = Field(
|
|
2075
|
+
default=["", "NOT", "<", "<=", ">", ">=", "LIKE", "NOT LIKE"],
|
|
2076
|
+
description="Valid filter operators for this column",
|
|
2077
|
+
)
|
|
2078
|
+
|
|
2079
|
+
|
|
2080
|
+
class InstructedRetrieverModel(BaseModel):
|
|
2081
|
+
"""
|
|
2082
|
+
Configuration for instructed retrieval with query decomposition and RRF merging.
|
|
2083
|
+
|
|
2084
|
+
Instructed retrieval decomposes user queries into multiple subqueries with
|
|
2085
|
+
metadata filters, executes them in parallel, and merges results using
|
|
2086
|
+
Reciprocal Rank Fusion (RRF) before reranking.
|
|
2087
|
+
|
|
2088
|
+
Example:
|
|
2089
|
+
```yaml
|
|
2090
|
+
retriever:
|
|
2091
|
+
vector_store: *products_vector_store
|
|
2092
|
+
instructed:
|
|
2093
|
+
decomposition_model: *fast_llm
|
|
2094
|
+
schema_description: |
|
|
2095
|
+
Products table: product_id, brand_name, category, price, updated_at
|
|
2096
|
+
Filter operators: {"col": val}, {"col >": val}, {"col NOT": val}
|
|
2097
|
+
columns:
|
|
2098
|
+
- name: brand_name
|
|
2099
|
+
type: string
|
|
2100
|
+
- name: price
|
|
2101
|
+
type: number
|
|
2102
|
+
operators: ["", "<", "<=", ">", ">="]
|
|
2103
|
+
constraints:
|
|
2104
|
+
- "Prefer recent products"
|
|
2105
|
+
max_subqueries: 3
|
|
2106
|
+
examples:
|
|
2107
|
+
- query: "cheap drills"
|
|
2108
|
+
filters: {"price <": 100}
|
|
2109
|
+
```
|
|
2110
|
+
"""
|
|
2111
|
+
|
|
2112
|
+
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
2113
|
+
|
|
2114
|
+
decomposition_model: Optional["LLMModel"] = Field(
|
|
2115
|
+
default=None,
|
|
2116
|
+
description="LLM for query decomposition (smaller/faster model recommended)",
|
|
2117
|
+
)
|
|
2118
|
+
schema_description: str = Field(
|
|
2119
|
+
description="Column names, types, and valid filter syntax for the LLM"
|
|
2120
|
+
)
|
|
2121
|
+
columns: Optional[list[ColumnInfo]] = Field(
|
|
2122
|
+
default=None,
|
|
2123
|
+
description=(
|
|
2124
|
+
"Structured column info for dynamic schema generation. "
|
|
2125
|
+
"When provided, column names are embedded in the JSON schema for better LLM accuracy."
|
|
2126
|
+
),
|
|
2127
|
+
)
|
|
2128
|
+
constraints: Optional[list[str]] = Field(
|
|
2129
|
+
default=None, description="Default constraints to always apply"
|
|
2130
|
+
)
|
|
2131
|
+
max_subqueries: int = Field(
|
|
2132
|
+
default=3, description="Maximum number of parallel subqueries"
|
|
2133
|
+
)
|
|
2134
|
+
rrf_k: int = Field(
|
|
2135
|
+
default=60,
|
|
2136
|
+
description="RRF constant (lower values weight top ranks more heavily)",
|
|
2137
|
+
)
|
|
2138
|
+
examples: Optional[list[dict[str, Any]]] = Field(
|
|
2139
|
+
default=None,
|
|
2140
|
+
description="Few-shot examples for domain-specific filter translation",
|
|
2141
|
+
)
|
|
2142
|
+
normalize_filter_case: Optional[Literal["uppercase", "lowercase"]] = Field(
|
|
2143
|
+
default=None,
|
|
2144
|
+
description="Auto-normalize filter string values to uppercase or lowercase",
|
|
2145
|
+
)
|
|
2146
|
+
|
|
2147
|
+
|
|
2148
|
+
class RouterModel(BaseModel):
|
|
2149
|
+
"""
|
|
2150
|
+
Select internal execution mode based on query characteristics.
|
|
2151
|
+
|
|
2152
|
+
Use fast models (GPT-3.5, Haiku, Llama 3 8B) to minimize latency (~50-100ms).
|
|
2153
|
+
Routes to internal modes within the same retriever, not external retrievers.
|
|
2154
|
+
Cross-index routing belongs at the agent/tool-selection level.
|
|
2155
|
+
|
|
2156
|
+
Execution Modes:
|
|
2157
|
+
- "standard": Single similarity_search() for simple keyword/product searches
|
|
2158
|
+
- "instructed": Decompose -> Parallel Search -> RRF for constrained queries
|
|
2159
|
+
|
|
2160
|
+
Example:
|
|
2161
|
+
```yaml
|
|
2162
|
+
retriever:
|
|
2163
|
+
router:
|
|
2164
|
+
model: *fast_llm
|
|
2165
|
+
default_mode: standard
|
|
2166
|
+
auto_bypass: true
|
|
2167
|
+
```
|
|
2168
|
+
"""
|
|
2169
|
+
|
|
2170
|
+
model_config = ConfigDict(use_enum_values=True, extra="forbid")
|
|
2171
|
+
|
|
2172
|
+
model: Optional["LLMModel"] = Field(
|
|
2173
|
+
default=None,
|
|
2174
|
+
description="LLM for routing decision (fast model recommended)",
|
|
2175
|
+
)
|
|
2176
|
+
default_mode: Literal["standard", "instructed"] = Field(
|
|
2177
|
+
default="standard",
|
|
2178
|
+
description="Fallback mode if routing fails",
|
|
2179
|
+
)
|
|
2180
|
+
auto_bypass: bool = Field(
|
|
2181
|
+
default=True,
|
|
2182
|
+
description="Skip Instruction Reranker and Verifier for standard mode",
|
|
2183
|
+
)
|
|
2184
|
+
|
|
2185
|
+
|
|
2186
|
+
class VerificationResult(BaseModel):
|
|
2187
|
+
"""Verification of whether search results satisfy the user's constraints.
|
|
2188
|
+
|
|
2189
|
+
Analyze the retrieved results against the original query and any explicit
|
|
2190
|
+
constraints to determine if a retry with modified filters is needed.
|
|
2191
|
+
"""
|
|
2192
|
+
|
|
2193
|
+
model_config = ConfigDict(extra="forbid")
|
|
2194
|
+
|
|
2195
|
+
passed: bool = Field(
|
|
2196
|
+
description="True if results satisfy the user's query intent and constraints."
|
|
2197
|
+
)
|
|
2198
|
+
confidence: float = Field(
|
|
2199
|
+
ge=0.0,
|
|
2200
|
+
le=1.0,
|
|
2201
|
+
description="Confidence in the verification decision, from 0.0 (uncertain) to 1.0 (certain).",
|
|
2202
|
+
)
|
|
2203
|
+
feedback: Optional[str] = Field(
|
|
2204
|
+
default=None,
|
|
2205
|
+
description="Explanation of why verification passed or failed. Include specific issues found.",
|
|
2206
|
+
)
|
|
2207
|
+
suggested_filter_relaxation: Optional[dict[str, Any]] = Field(
|
|
2208
|
+
default=None,
|
|
2209
|
+
description=(
|
|
2210
|
+
"Suggested filter modifications for retry. "
|
|
2211
|
+
"Keys are column names, values indicate changes (e.g., 'REMOVE', 'WIDEN', or new values)."
|
|
2212
|
+
),
|
|
2213
|
+
)
|
|
2214
|
+
unmet_constraints: Optional[list[str]] = Field(
|
|
2215
|
+
default=None,
|
|
2216
|
+
description="List of user constraints that the results failed to satisfy.",
|
|
2217
|
+
)
|
|
2218
|
+
|
|
2219
|
+
|
|
+class VerifierModel(BaseModel):
+    """
+    Validate results against user constraints with structured feedback.
+
+    Use fast models (GPT-3.5, Haiku, Llama 3 8B) to minimize latency (~50-100ms).
+    Skipped for 'standard' mode when auto_bypass=true in router config.
+    Returns structured feedback for intelligent retry, not blind retry.
+
+    Example:
+        ```yaml
+        retriever:
+          verifier:
+            model: *fast_llm
+            on_failure: warn_and_retry
+            max_retries: 1
+        ```
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+
+    model: Optional["LLMModel"] = Field(
+        default=None,
+        description="LLM for verification (fast model recommended)",
+    )
+    on_failure: Literal["warn", "retry", "warn_and_retry"] = Field(
+        default="warn",
+        description="Behavior when verification fails",
+    )
+    max_retries: int = Field(
+        default=1,
+        description="Maximum retry attempts before returning with warning",
+    )
+
+
+class RankedDocument(BaseModel):
+    """Single ranked document."""
+
+    index: int = Field(description="Document index from input list")
+    score: float = Field(description="0.0-1.0 relevance score")
+    reason: str = Field(default="", description="Why this score")
+
+
+class RankingResult(BaseModel):
+    """Reranking output."""
+
+    rankings: list[RankedDocument] = Field(
+        default_factory=list,
+        description="Ranked documents, highest score first",
+    )
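
`RankedDocument` and `RankingResult` form the structured output an LLM reranker returns; a short sketch of consuming one (document contents are placeholders):

```python
from dao_ai.config import RankedDocument, RankingResult  # import path assumed

ranking = RankingResult(
    rankings=[
        RankedDocument(index=2, score=0.93, reason="Matches brand and price"),
        RankedDocument(index=0, score=0.41, reason="Brand match only"),
    ]
)

# Reorder the original candidates by the returned indices.
docs = ["doc-a", "doc-b", "doc-c"]
reranked = [docs[r.index] for r in ranking.rankings]
assert reranked == ["doc-c", "doc-a"]
```
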
 
 
 class RetrieverModel(BaseModel):
@@ -1582,10 +2275,22 @@ class RetrieverModel(BaseModel):
     search_parameters: SearchParametersModel = Field(
         default_factory=SearchParametersModel
     )
+    router: Optional[RouterModel] = Field(
+        default=None,
+        description="Optional query router for selecting execution mode (standard vs instructed).",
+    )
     rerank: Optional[RerankParametersModel | bool] = Field(
         default=None,
         description="Optional reranking configuration. Set to true for defaults, or provide ReRankParametersModel for custom settings.",
     )
+    instructed: Optional[InstructedRetrieverModel] = Field(
+        default=None,
+        description="Optional instructed retrieval with query decomposition and RRF merging.",
+    )
+    verifier: Optional[VerifierModel] = Field(
+        default=None,
+        description="Optional result verification with structured feedback for retry.",
+    )
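
Putting the new optional stages together, the retriever section of a parsed config could look like the dict below. This is a sketch: required `RetrieverModel` fields not shown in this diff (the vector store, columns, and so on) are omitted.

```python
# Keys mirror the fields added above; `rerank: True` is the bool shorthand
# that set_default_reranker() expands into a RerankParametersModel.
retriever_config = {
    "router": {"default_mode": "standard", "auto_bypass": True},
    "rerank": True,
    "instructed": None,  # or an InstructedRetrieverModel config
    "verifier": {"on_failure": "warn_and_retry", "max_retries": 1},
}
```
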
 
     @model_validator(mode="after")
     def set_default_columns(self) -> Self:
@@ -1596,9 +2301,13 @@ class RetrieverModel(BaseModel):
 
     @model_validator(mode="after")
     def set_default_reranker(self) -> Self:
-        """Convert bool to ReRankParametersModel with defaults."""
+        """Convert bool to ReRankParametersModel with defaults.
+
+        When rerank: true is used, sets the default FlashRank model
+        (ms-marco-MiniLM-L-12-v2) to enable reranking.
+        """
         if isinstance(self.rerank, bool) and self.rerank:
-            self.rerank = RerankParametersModel()
+            self.rerank = RerankParametersModel(model="ms-marco-MiniLM-L-12-v2")
         return self
 
 
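
The normalization above means the boolean shorthand and an explicit model are interchangeable after validation. A standalone sketch of the conversion, with a dataclass standing in for `RerankParametersModel`:

```python
from dataclasses import dataclass

@dataclass
class RerankParams:  # stand-in for RerankParametersModel
    model: str = ""

def normalize_rerank(rerank):
    # Mirrors set_default_reranker: only a literal True is expanded;
    # False, None, and explicit models pass through unchanged.
    if isinstance(rerank, bool) and rerank:
        return RerankParams(model="ms-marco-MiniLM-L-12-v2")
    return rerank

assert normalize_rerank(True) == RerankParams(model="ms-marco-MiniLM-L-12-v2")
assert normalize_rerank(None) is None
```
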
@@ -1731,11 +2440,32 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
     headers: dict[str, AnyVariable] = Field(default_factory=dict)
     args: list[str] = Field(default_factory=list)
     # MCP-specific fields
+    app: Optional[DatabricksAppModel] = None
     connection: Optional[ConnectionModel] = None
     functions: Optional[SchemaModel] = None
     genie_room: Optional[GenieRoomModel] = None
     sql: Optional[bool] = None
     vector_search: Optional[VectorStoreModel] = None
+    # Tool filtering
+    include_tools: Optional[list[str]] = Field(
+        default=None,
+        description=(
+            "Optional list of tool names or glob patterns to include from the MCP server. "
+            "If specified, only tools matching these patterns will be loaded. "
+            "Supports glob patterns: * (any chars), ? (single char), [abc] (char set). "
+            "Examples: ['execute_query', 'list_*', 'get_?_data']"
+        ),
+    )
+    exclude_tools: Optional[list[str]] = Field(
+        default=None,
+        description=(
+            "Optional list of tool names or glob patterns to exclude from the MCP server. "
+            "Tools matching these patterns will not be loaded. "
+            "Takes precedence over include_tools. "
+            "Supports glob patterns: * (any chars), ? (single char), [abc] (char set). "
+            "Examples: ['drop_*', 'delete_*', 'execute_ddl']"
+        ),
+    )
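
The include/exclude semantics described in those field docs (exclude wins, glob syntax) can be reproduced with the standard library's `fnmatch`. A sketch, not the package's actual filter code:

```python
from fnmatch import fnmatch

def keep_tool(name: str, include: list | None, exclude: list | None) -> bool:
    # exclude_tools takes precedence over include_tools.
    if exclude and any(fnmatch(name, pat) for pat in exclude):
        return False
    if include is not None:
        return any(fnmatch(name, pat) for pat in include)
    return True  # no include list: everything not excluded is loaded

tools = ["execute_query", "list_tables", "drop_table"]
kept = [t for t in tools
        if keep_tool(t, include=["execute_query", "list_*"], exclude=["drop_*"])]
assert kept == ["execute_query", "list_tables"]
```
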
 
     @property
     def api_scopes(self) -> Sequence[str]:
@@ -1798,6 +2528,7 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
 
         Returns the URL based on the configured source:
         - If url is set, returns it directly
+        - If app is set, retrieves URL from Databricks App via workspace client
         - If connection is set, constructs URL from connection
         - If genie_room is set, constructs Genie MCP URL
         - If sql is set, constructs DBSQL MCP URL (serverless)
@@ -1810,6 +2541,7 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
         - Vector Search: https://{host}/api/2.0/mcp/vector-search/{catalog}/{schema}
         - UC Functions: https://{host}/api/2.0/mcp/functions/{catalog}/{schema}
         - Connection: https://{host}/api/2.0/mcp/external/{connection_name}
+        - Databricks App: Retrieved dynamically from workspace
         """
         # Direct URL provided
         if self.url:
@@ -1832,6 +2564,49 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
         if self.sql:
             return f"{workspace_host}/api/2.0/mcp/sql"
 
+        # Databricks App - MCP endpoint is at {app_url}/mcp
+        # Try McpFunctionModel's workspace_client first (which may have credentials),
+        # then fall back to DatabricksAppModel.url property (which uses its own workspace_client)
+        if self.app:
+            from databricks.sdk.service.apps import App
+
+            app_url: str | None = None
+
+            # First, try using McpFunctionModel's workspace_client
+            try:
+                app: App = self.workspace_client.apps.get(self.app.name)
+                app_url = app.url
+                logger.trace(
+                    "Got app URL using McpFunctionModel workspace_client",
+                    app_name=self.app.name,
+                    url=app_url,
+                )
+            except Exception as e:
+                logger.debug(
+                    "Failed to get app URL using McpFunctionModel workspace_client, "
+                    "trying DatabricksAppModel.url property",
+                    app_name=self.app.name,
+                    error=str(e),
+                )
+
+            # Fall back to DatabricksAppModel.url property
+            if not app_url:
+                try:
+                    app_url = self.app.url
+                    logger.trace(
+                        "Got app URL using DatabricksAppModel.url property",
+                        app_name=self.app.name,
+                        url=app_url,
+                    )
+                except Exception as e:
+                    raise RuntimeError(
+                        f"Databricks App '{self.app.name}' does not have a URL. "
+                        "The app may not be deployed yet, or credentials may be invalid. "
+                        f"Error: {e}"
+                    ) from e
+
+            return f"{app_url.rstrip('/')}/mcp"
+
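
For orientation, the dispatch this property implements can be summarized as below; hosts, catalog/schema names, and the app URL are placeholders:

```python
# Standalone sketch of the URL shapes listed in the docstring above.
def mcp_url(host: str, source: str, **kw: str) -> str:
    if source == "sql":
        return f"{host}/api/2.0/mcp/sql"
    if source == "vector_search":
        return f"{host}/api/2.0/mcp/vector-search/{kw['catalog']}/{kw['schema']}"
    if source == "functions":
        return f"{host}/api/2.0/mcp/functions/{kw['catalog']}/{kw['schema']}"
    if source == "connection":
        return f"{host}/api/2.0/mcp/external/{kw['connection_name']}"
    if source == "app":
        # The app URL is looked up at runtime, then suffixed with /mcp.
        return f"{kw['app_url'].rstrip('/')}/mcp"
    raise ValueError(f"unknown source: {source}")

assert (mcp_url("https://ws.example.com", "app", app_url="https://my-app.example/")
        == "https://my-app.example/mcp")
```
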
         # Vector Search
         if self.vector_search:
             if (
@@ -1841,33 +2616,35 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
                 raise ValueError(
                     "vector_search must have an index with a schema (catalog/schema) configured"
                 )
-            catalog: str = self.vector_search.index.schema_model.catalog_name
-            schema: str = self.vector_search.index.schema_model.schema_name
+            catalog: str = value_of(self.vector_search.index.schema_model.catalog_name)
+            schema: str = value_of(self.vector_search.index.schema_model.schema_name)
             return f"{workspace_host}/api/2.0/mcp/vector-search/{catalog}/{schema}"
 
         # UC Functions MCP server
         if self.functions:
-            catalog: str = self.functions.catalog_name
-            schema: str = self.functions.schema_name
+            catalog: str = value_of(self.functions.catalog_name)
+            schema: str = value_of(self.functions.schema_name)
             return f"{workspace_host}/api/2.0/mcp/functions/{catalog}/{schema}"
 
         raise ValueError(
-            "No URL source configured. Provide one of: url, connection, genie_room, "
+            "No URL source configured. Provide one of: url, app, connection, genie_room, "
             "sql, vector_search, or functions"
         )
 
     @field_serializer("transport")
-    def serialize_transport(self, value) -> str:
+    def serialize_transport(self, value: TransportType) -> str:
+        """Serialize transport enum to string."""
         if isinstance(value, TransportType):
             return value.value
         return str(value)
 
     @model_validator(mode="after")
-    def validate_mutually_exclusive(self) ->
+    def validate_mutually_exclusive(self) -> Self:
         """Validate that exactly one URL source is provided."""
         # Count how many URL sources are provided
         url_sources: list[tuple[str, Any]] = [
             ("url", self.url),
+            ("app", self.app),
             ("connection", self.connection),
             ("genie_room", self.genie_room),
             ("sql", self.sql),
@@ -1883,13 +2660,13 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
         if len(provided_sources) == 0:
             raise ValueError(
                 "For STREAMABLE_HTTP transport, exactly one of the following must be provided: "
-                "url, connection, genie_room, sql, vector_search, or functions"
+                "url, app, connection, genie_room, sql, vector_search, or functions"
             )
         if len(provided_sources) > 1:
             raise ValueError(
                 f"For STREAMABLE_HTTP transport, only one URL source can be provided. "
                 f"Found: {', '.join(provided_sources)}. "
-                f"Please provide only one of: url, connection, genie_room, sql, vector_search, or functions"
+                f"Please provide only one of: url, app, connection, genie_room, sql, vector_search, or functions"
             )
 
         if self.transport == TransportType.STDIO:
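
The exactly-one-source rule is easy to state on its own. A sketch mirroring the validator's counting logic (not the package's code; names match the fields above):

```python
def check_exactly_one_source(**sources) -> str:
    provided = [name for name, value in sources.items() if value]
    if len(provided) == 0:
        raise ValueError("exactly one URL source must be provided")
    if len(provided) > 1:
        raise ValueError(f"only one URL source allowed, found: {', '.join(provided)}")
    return provided[0]

assert check_exactly_one_source(url="https://x", app=None, sql=None) == "url"
# check_exactly_one_source(url="https://x", sql=True) raises ValueError
```
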
@@ -1901,14 +2678,41 @@ class McpFunctionModel(BaseFunctionModel, IsDatabricksResource):
         return self
 
     @model_validator(mode="after")
-    def update_url(self) ->
-
+    def update_url(self) -> Self:
+        """Resolve AnyVariable to concrete value for URL."""
+        if self.url is not None:
+            resolved_value: Any = value_of(self.url)
+            # Cast to string since URL must be a string
+            self.url = str(resolved_value) if resolved_value else None
         return self
 
     @model_validator(mode="after")
-    def update_headers(self) ->
+    def update_headers(self) -> Self:
+        """Resolve AnyVariable to concrete values for headers."""
         for key, value in self.headers.items():
-
+            resolved_value: Any = value_of(value)
+            # Headers must be strings
+            self.headers[key] = str(resolved_value) if resolved_value else ""
+        return self
+
+    @model_validator(mode="after")
+    def validate_tool_filters(self) -> Self:
+        """Validate tool filter configuration."""
+        from loguru import logger
+
+        # Warn if both are empty lists (explicit but pointless)
+        if self.include_tools is not None and len(self.include_tools) == 0:
+            logger.warning(
+                "include_tools is empty list - no tools will be loaded. "
+                "Remove field to load all tools."
+            )
+
+        if self.exclude_tools is not None and len(self.exclude_tools) == 0:
+            logger.warning(
+                "exclude_tools is empty list - has no effect. "
+                "Remove field or add patterns."
+            )
+
         return self
 
     def as_tools(self, **kwargs: Any) -> Sequence[RunnableLike]:
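
Both `update_url` and `update_headers` coerce resolved variable values to strings, mapping falsy results to `None` or `""`. A sketch of that coercion; `value_of` is stubbed here as an identity function standing in for the package's variable resolver:

```python
from typing import Any

def value_of(v: Any) -> Any:
    # Stand-in: the real resolver dereferences AnyVariable entries
    # (environment variables, secrets, literals).
    return v

headers: dict[str, Any] = {"Authorization": "Bearer abc", "X-Trace": None}
for key, value in headers.items():
    resolved = value_of(value)
    headers[key] = str(resolved) if resolved else ""

assert headers == {"Authorization": "Bearer abc", "X-Trace": ""}
```
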
@@ -2316,7 +3120,6 @@ class SupervisorModel(BaseModel):
 
 class SwarmModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    model: LLMModel
     default_agent: Optional[AgentModel | str] = None
     middleware: list[MiddlewareModel] = Field(
         default_factory=list,
@@ -2330,11 +3133,17 @@ class SwarmModel(BaseModel):
 class OrchestrationModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     supervisor: Optional[SupervisorModel] = None
-    swarm: Optional[SwarmModel] = None
+    swarm: Optional[SwarmModel | Literal[True]] = None
     memory: Optional[MemoryModel] = None
 
     @model_validator(mode="after")
-    def
+    def validate_and_normalize(self) -> Self:
+        """Validate orchestration and normalize swarm shorthand."""
+        # Convert swarm: true to SwarmModel()
+        if self.swarm is True:
+            self.swarm = SwarmModel()
+
+        # Validate mutually exclusive
         if self.supervisor is not None and self.swarm is not None:
             raise ValueError("Cannot specify both supervisor and swarm")
         if self.supervisor is None and self.swarm is None:
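
With the `Literal[True]` union, `swarm: true` in a YAML config now parses, and the validator normalizes it to a default `SwarmModel`. A sketch, assuming these models import from `dao_ai.config` as this diff shows:

```python
from dao_ai.config import OrchestrationModel, SwarmModel  # import path assumed

# YAML `swarm: true` arrives as the literal True and is normalized:
orch = OrchestrationModel(swarm=True)
assert isinstance(orch.swarm, SwarmModel)

# Setting both supervisor and swarm, or neither, fails the validator above.
```
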
@@ -2544,6 +3353,11 @@ class AppModel(BaseModel):
         "which is supported by Databricks Model Serving. This allows deploying from "
         "environments with different Python versions (e.g., Databricks Apps with 3.11).",
     )
+    deployment_target: Optional[DeploymentTarget] = Field(
+        default=None,
+        description="Default deployment target. If not specified, defaults to MODEL_SERVING. "
+        "Can be overridden via CLI --target flag. Options: 'model_serving' or 'apps'.",
+    )
 
     @model_validator(mode="after")
     def set_databricks_env_vars(self) -> Self:
@@ -2601,9 +3415,7 @@ class AppModel(BaseModel):
         elif len(self.agents) == 1:
             default_agent: AgentModel = self.agents[0]
             self.orchestration = OrchestrationModel(
-                swarm=SwarmModel(
-                    model=default_agent.model, default_agent=default_agent
-                )
+                swarm=SwarmModel(default_agent=default_agent)
             )
         else:
             raise ValueError("At least one agent must be specified")
@@ -2643,8 +3455,24 @@ class GuidelineModel(BaseModel):
 
 
 class EvaluationModel(BaseModel):
+    """
+    Configuration for MLflow GenAI evaluation.
+
+    Attributes:
+        model: LLM model used as the judge for LLM-based scorers (e.g., Guidelines, Safety).
+            This model evaluates agent responses during evaluation.
+        table: Table to store evaluation results.
+        num_evals: Number of evaluation samples to generate.
+        agent_description: Description of the agent for evaluation data generation.
+        question_guidelines: Guidelines for generating evaluation questions.
+        custom_inputs: Custom inputs to pass to the agent during evaluation.
+        guidelines: List of guideline configurations for Guidelines scorers.
+    """
+
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    model: LLMModel
+    model: LLMModel = Field(
+        ..., description="LLM model used as the judge for LLM-based evaluation scorers"
+    )
     table: TableModel
     num_evals: int
     agent_description: Optional[str] = None
@@ -2652,6 +3480,16 @@ class EvaluationModel(BaseModel):
     custom_inputs: dict[str, Any] = Field(default_factory=dict)
     guidelines: list[GuidelineModel] = Field(default_factory=list)
 
+    @property
+    def judge_model_endpoint(self) -> str:
+        """
+        Get the judge model endpoint string for MLflow scorers.
+
+        Returns:
+            Endpoint string in format 'databricks:/model-name'
+        """
+        return f"databricks:/{self.model.name}"
+
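
So a judge model named, say, `databricks-meta-llama-3-3-70b-instruct` (an illustrative name) yields:

```python
# Format produced by judge_model_endpoint:
model_name = "databricks-meta-llama-3-3-70b-instruct"
endpoint = f"databricks:/{model_name}"
assert endpoint == "databricks:/databricks-meta-llama-3-3-70b-instruct"
```
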
 
 class EvaluationDatasetExpectationsModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -2849,6 +3687,165 @@ class OptimizationsModel(BaseModel):
         return results
 
 
+class SemanticCacheEvalEntryModel(BaseModel):
+    """Single evaluation entry for semantic cache threshold optimization.
+
+    Represents a pair of question/context combinations to evaluate
+    whether the cache should return a hit or miss.
+
+    Example:
+        entry:
+          question: "What are total sales?"
+          question_embedding: [0.1, 0.2, ...]  # Pre-computed
+          context: "Previous: Show me revenue"
+          context_embedding: [0.1, 0.2, ...]
+          cached_question: "Show total sales"
+          cached_question_embedding: [0.1, 0.2, ...]
+          cached_context: "Previous: Show me revenue"
+          cached_context_embedding: [0.1, 0.2, ...]
+          expected_match: true
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    question: str
+    question_embedding: list[float]
+    context: str = ""
+    context_embedding: list[float] = Field(default_factory=list)
+    cached_question: str
+    cached_question_embedding: list[float]
+    cached_context: str = ""
+    cached_context_embedding: list[float] = Field(default_factory=list)
+    expected_match: Optional[bool] = None  # None = use LLM judge
+
+
+class SemanticCacheEvalDatasetModel(BaseModel):
+    """Dataset for semantic cache threshold optimization.
+
+    Contains pairs of questions/contexts to evaluate whether thresholds
+    correctly identify semantic matches.
+
+    Example:
+        dataset:
+          name: my_cache_eval_dataset
+          description: "Evaluation data for cache tuning"
+          entries:
+            - question: "What are total sales?"
+              # ... entry fields
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    description: str = ""
+    entries: list[SemanticCacheEvalEntryModel] = Field(default_factory=list)
+
+    def as_eval_dataset(self) -> "SemanticCacheEvalDataset":
+        """Convert to internal evaluation dataset format."""
+        from dao_ai.genie.cache.optimization import (
+            SemanticCacheEvalDataset,
+            SemanticCacheEvalEntry,
+        )
+
+        entries = [
+            SemanticCacheEvalEntry(
+                question=e.question,
+                question_embedding=e.question_embedding,
+                context=e.context,
+                context_embedding=e.context_embedding,
+                cached_question=e.cached_question,
+                cached_question_embedding=e.cached_question_embedding,
+                cached_context=e.cached_context,
+                cached_context_embedding=e.cached_context_embedding,
+                expected_match=e.expected_match,
+            )
+            for e in self.entries
+        ]
+
+        return SemanticCacheEvalDataset(
+            name=self.name,
+            entries=entries,
+            description=self.description,
+        )
+
+
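
A minimal dataset instance using the fields above; the embeddings are toy 2-d vectors for readability (real entries carry pre-computed embedding vectors):

```python
from dao_ai.config import (  # import path assumed
    SemanticCacheEvalDatasetModel,
    SemanticCacheEvalEntryModel,
)

dataset = SemanticCacheEvalDatasetModel(
    name="my_cache_eval_dataset",
    description="Evaluation data for cache tuning",
    entries=[
        SemanticCacheEvalEntryModel(
            question="What are total sales?",
            question_embedding=[0.12, 0.88],
            cached_question="Show total sales",
            cached_question_embedding=[0.11, 0.87],
            expected_match=True,  # omit (None) to defer to the LLM judge
        )
    ],
)
```
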
+class SemanticCacheThresholdOptimizationModel(BaseModel):
+    """Configuration for semantic cache threshold optimization.
+
+    Uses Optuna Bayesian optimization to find optimal threshold values
+    that maximize cache hit accuracy (F1 score by default).
+
+    Example:
+        threshold_optimization:
+          name: optimize_cache_thresholds
+          cache_parameters: *my_cache_params
+          dataset: *my_eval_dataset
+          judge_model: databricks-meta-llama-3-3-70b-instruct
+          n_trials: 50
+          metric: f1
+    """
+
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    cache_parameters: Optional[GenieContextAwareCacheParametersModel] = None
+    dataset: SemanticCacheEvalDatasetModel
+    judge_model: Optional[LLMModel | str] = "databricks-meta-llama-3-3-70b-instruct"
+    n_trials: int = 50
+    metric: Literal["f1", "precision", "recall", "fbeta"] = "f1"
+    beta: float = 1.0  # For fbeta metric
+    seed: Optional[int] = None
+
+    def optimize(
+        self, w: WorkspaceClient | None = None
+    ) -> "ThresholdOptimizationResult":
+        """
+        Optimize semantic cache thresholds.
+
+        Args:
+            w: Optional WorkspaceClient (not used, kept for API compatibility)
+
+        Returns:
+            ThresholdOptimizationResult with optimized thresholds
+        """
+        from dao_ai.genie.cache.optimization import (
+            ThresholdOptimizationResult,
+            optimize_semantic_cache_thresholds,
+        )
+
+        # Convert dataset
+        eval_dataset = self.dataset.as_eval_dataset()
+
+        # Get original thresholds from cache_parameters
+        original_thresholds: dict[str, float] | None = None
+        if self.cache_parameters:
+            original_thresholds = {
+                "similarity_threshold": self.cache_parameters.similarity_threshold,
+                "context_similarity_threshold": self.cache_parameters.context_similarity_threshold,
+                "question_weight": self.cache_parameters.question_weight or 0.6,
+            }
+
+        # Get judge model
+        judge_model_name: str
+        if isinstance(self.judge_model, str):
+            judge_model_name = self.judge_model
+        elif self.judge_model:
+            judge_model_name = self.judge_model.uri
+        else:
+            judge_model_name = "databricks-meta-llama-3-3-70b-instruct"
+
+        result: ThresholdOptimizationResult = optimize_semantic_cache_thresholds(
+            dataset=eval_dataset,
+            original_thresholds=original_thresholds,
+            judge_model=judge_model_name,
+            n_trials=self.n_trials,
+            metric=self.metric,
+            beta=self.beta,
+            register_if_improved=True,
+            study_name=self.name,
+            seed=self.seed,
+        )
+
+        return result
+
+
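
The objective the optimizer maximizes is one of the listed classification metrics computed over hit/miss decisions. A self-contained sketch of the f1/fbeta arithmetic; this mirrors the standard definitions, not the package's internal scorer:

```python
def fbeta_score(tp: int, fp: int, fn: int, beta: float = 1.0) -> float:
    # tp: predicted hit, expected hit; fp: predicted hit, expected miss;
    # fn: predicted miss, expected hit. beta=1.0 gives F1.
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    if precision + recall == 0.0:
        return 0.0
    b2 = beta * beta
    return (1 + b2) * precision * recall / (b2 * precision + recall)

# beta > 1 favors recall (more cache hits); beta < 1 favors precision
# (fewer wrong hits).
assert round(fbeta_score(tp=8, fp=2, fn=4), 3) == 0.727
```
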
 class DatasetFormat(str, Enum):
     CSV = "csv"
     DELTA = "delta"
@@ -3024,6 +4021,7 @@ class ResourcesModel(BaseModel):
 
 class AppConfig(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    version: Optional[str] = None
     variables: dict[str, AnyVariable] = Field(default_factory=dict)
     service_principals: dict[str, ServicePrincipalModel] = Field(default_factory=dict)
     schemas: dict[str, SchemaModel] = Field(default_factory=dict)
@@ -3044,6 +4042,9 @@ class AppConfig(BaseModel):
     )
     providers: Optional[dict[type | str, Any]] = None
 
+    # Private attribute to track the source config file path (set by from_file)
+    _source_config_path: str | None = None
+
     @classmethod
     def from_file(cls, path: PathLike) -> "AppConfig":
         path = Path(path).as_posix()
@@ -3051,12 +4052,20 @@ class AppConfig(BaseModel):
         model_config: ModelConfig = ModelConfig(development_config=path)
         config: AppConfig = AppConfig(**model_config.to_dict())
 
+        # Store the source config path for later use (e.g., Apps deployment)
+        config._source_config_path = path
+
         config.initialize()
 
         atexit.register(config.shutdown)
 
         return config
 
+    @property
+    def source_config_path(self) -> str | None:
+        """Get the source config file path if loaded via from_file."""
+        return self._source_config_path
+
     def initialize(self) -> None:
         from dao_ai.hooks.core import create_hooks
         from dao_ai.logging import configure_logging
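
Loading a config now records where it came from, and later steps (such as Apps deployment) can read the path back. A usage sketch; the YAML path is a placeholder:

```python
from dao_ai.config import AppConfig  # import path assumed

config = AppConfig.from_file("configs/my_agent.yaml")  # placeholder path
print(config.source_config_path)  # -> "configs/my_agent.yaml"
```
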
@@ -3127,6 +4136,7 @@ class AppConfig(BaseModel):
 
     def deploy_agent(
         self,
+        target: DeploymentTarget | None = None,
         w: WorkspaceClient | None = None,
         vsc: "VectorSearchClient | None" = None,
         pat: str | None = None,
@@ -3134,9 +4144,39 @@ class AppConfig(BaseModel):
         client_secret: str | None = None,
         workspace_host: str | None = None,
     ) -> None:
+        """
+        Deploy the agent to the specified target.
+
+        Target resolution follows this priority:
+        1. Explicit `target` parameter (if provided)
+        2. `app.deployment_target` from config file (if set)
+        3. Default: MODEL_SERVING
+
+        Args:
+            target: The deployment target (MODEL_SERVING or APPS). If None, uses
+                config.app.deployment_target or defaults to MODEL_SERVING.
+            w: Optional WorkspaceClient instance
+            vsc: Optional VectorSearchClient instance
+            pat: Optional personal access token for authentication
+            client_id: Optional client ID for service principal authentication
+            client_secret: Optional client secret for service principal authentication
+            workspace_host: Optional workspace host URL
+        """
         from dao_ai.providers.base import ServiceProvider
         from dao_ai.providers.databricks import DatabricksProvider
 
+        # Resolve target using hybrid logic:
+        # 1. Explicit parameter takes precedence
+        # 2. Fall back to config.app.deployment_target
+        # 3. Default to MODEL_SERVING
+        resolved_target: DeploymentTarget
+        if target is not None:
+            resolved_target = target
+        elif self.app is not None and self.app.deployment_target is not None:
+            resolved_target = self.app.deployment_target
+        else:
+            resolved_target = DeploymentTarget.MODEL_SERVING
+
         provider: ServiceProvider = DatabricksProvider(
             w=w,
             vsc=vsc,
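
The three-step precedence is easy to verify in isolation. A sketch with a stand-in enum (per the field description, `DeploymentTarget`'s options are model_serving and apps):

```python
from enum import Enum

class Target(Enum):  # stand-in for DeploymentTarget
    MODEL_SERVING = "model_serving"
    APPS = "apps"

def resolve_target(explicit: Target | None, config_default: Target | None) -> Target:
    # 1. explicit argument, 2. config's app.deployment_target, 3. default
    if explicit is not None:
        return explicit
    if config_default is not None:
        return config_default
    return Target.MODEL_SERVING

assert resolve_target(Target.APPS, Target.MODEL_SERVING) is Target.APPS
assert resolve_target(None, Target.APPS) is Target.APPS
assert resolve_target(None, None) is Target.MODEL_SERVING
```
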
@@ -3145,7 +4185,7 @@ class AppConfig(BaseModel):
             client_secret=client_secret,
             workspace_host=workspace_host,
         )
-        provider.deploy_agent(self)
+        provider.deploy_agent(self, target=resolved_target)
 
     def find_agents(
         self, predicate: Callable[[AgentModel], bool] | None = None