dao-ai 0.1.5__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. dao_ai/apps/__init__.py +24 -0
  2. dao_ai/apps/handlers.py +105 -0
  3. dao_ai/apps/model_serving.py +29 -0
  4. dao_ai/apps/resources.py +1122 -0
  5. dao_ai/apps/server.py +39 -0
  6. dao_ai/cli.py +446 -16
  7. dao_ai/config.py +1034 -103
  8. dao_ai/evaluation.py +543 -0
  9. dao_ai/genie/__init__.py +55 -7
  10. dao_ai/genie/cache/__init__.py +34 -7
  11. dao_ai/genie/cache/base.py +143 -2
  12. dao_ai/genie/cache/context_aware/__init__.py +31 -0
  13. dao_ai/genie/cache/context_aware/base.py +1151 -0
  14. dao_ai/genie/cache/context_aware/in_memory.py +609 -0
  15. dao_ai/genie/cache/context_aware/persistent.py +802 -0
  16. dao_ai/genie/cache/context_aware/postgres.py +1166 -0
  17. dao_ai/genie/cache/core.py +1 -1
  18. dao_ai/genie/cache/lru.py +257 -75
  19. dao_ai/genie/cache/optimization.py +890 -0
  20. dao_ai/genie/core.py +235 -11
  21. dao_ai/memory/postgres.py +175 -39
  22. dao_ai/middleware/__init__.py +5 -0
  23. dao_ai/middleware/tool_selector.py +129 -0
  24. dao_ai/models.py +327 -370
  25. dao_ai/nodes.py +4 -4
  26. dao_ai/orchestration/core.py +33 -9
  27. dao_ai/orchestration/supervisor.py +23 -8
  28. dao_ai/orchestration/swarm.py +6 -1
  29. dao_ai/{prompts.py → prompts/__init__.py} +12 -61
  30. dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
  31. dao_ai/prompts/instruction_reranker.yaml +14 -0
  32. dao_ai/prompts/router.yaml +37 -0
  33. dao_ai/prompts/verifier.yaml +46 -0
  34. dao_ai/providers/base.py +28 -2
  35. dao_ai/providers/databricks.py +352 -33
  36. dao_ai/state.py +1 -0
  37. dao_ai/tools/__init__.py +5 -3
  38. dao_ai/tools/genie.py +103 -26
  39. dao_ai/tools/instructed_retriever.py +366 -0
  40. dao_ai/tools/instruction_reranker.py +202 -0
  41. dao_ai/tools/mcp.py +539 -97
  42. dao_ai/tools/router.py +89 -0
  43. dao_ai/tools/slack.py +13 -2
  44. dao_ai/tools/sql.py +7 -3
  45. dao_ai/tools/unity_catalog.py +32 -10
  46. dao_ai/tools/vector_search.py +493 -160
  47. dao_ai/tools/verifier.py +159 -0
  48. dao_ai/utils.py +182 -2
  49. dao_ai/vector_search.py +9 -1
  50. {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/METADATA +10 -8
  51. dao_ai-0.1.20.dist-info/RECORD +89 -0
  52. dao_ai/agent_as_code.py +0 -22
  53. dao_ai/genie/cache/semantic.py +0 -970
  54. dao_ai-0.1.5.dist-info/RECORD +0 -70
  55. {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/WHEEL +0 -0
  56. {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/entry_points.txt +0 -0
  57. {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/licenses/LICENSE +0 -0
dao_ai/providers/databricks.py CHANGED
@@ -23,7 +23,7 @@ from databricks.sdk.service.catalog import (
 )
 from databricks.sdk.service.database import DatabaseCredential
 from databricks.sdk.service.iam import User
-from databricks.sdk.service.workspace import GetSecretResponse
+from databricks.sdk.service.workspace import GetSecretResponse, ImportFormat
 from databricks.vector_search.client import VectorSearchClient
 from databricks.vector_search.index import VectorSearchIndex
 from loguru import logger
@@ -48,6 +48,7 @@ from dao_ai.config import (
     DatabaseModel,
     DatabricksAppModel,
     DatasetModel,
+    DeploymentTarget,
     FunctionModel,
     GenieRoomModel,
     HasFullName,
@@ -151,25 +152,77 @@ class DatabricksProvider(ServiceProvider):
         client_secret: str | None = None,
         workspace_host: str | None = None,
     ) -> None:
-        if w is None:
-            w = _workspace_client(
-                pat=pat,
-                client_id=client_id,
-                client_secret=client_secret,
-                workspace_host=workspace_host,
+        # Store credentials for lazy initialization
+        self._pat = pat
+        self._client_id = client_id
+        self._client_secret = client_secret
+        self._workspace_host = workspace_host
+
+        # Lazy initialization for WorkspaceClient
+        self._w: WorkspaceClient | None = w
+        self._w_initialized = w is not None
+
+        # Lazy initialization for VectorSearchClient - only create when needed
+        # This avoids authentication errors in Databricks Apps where VSC
+        # requires explicit credentials but the platform uses ambient auth
+        self._vsc: VectorSearchClient | None = vsc
+        self._vsc_initialized = vsc is not None
+
+        # Lazy initialization for DatabricksFunctionClient
+        self._dfs: DatabricksFunctionClient | None = dfs
+        self._dfs_initialized = dfs is not None
+
+    @property
+    def w(self) -> WorkspaceClient:
+        """Lazy initialization of WorkspaceClient."""
+        if not self._w_initialized:
+            self._w = _workspace_client(
+                pat=self._pat,
+                client_id=self._client_id,
+                client_secret=self._client_secret,
+                workspace_host=self._workspace_host,
             )
-        if vsc is None:
-            vsc = _vector_search_client(
-                pat=pat,
-                client_id=client_id,
-                client_secret=client_secret,
-                workspace_host=workspace_host,
+            self._w_initialized = True
+        return self._w  # type: ignore[return-value]
+
+    @w.setter
+    def w(self, value: WorkspaceClient) -> None:
+        """Set WorkspaceClient and mark as initialized."""
+        self._w = value
+        self._w_initialized = True
+
+    @property
+    def vsc(self) -> VectorSearchClient:
+        """Lazy initialization of VectorSearchClient."""
+        if not self._vsc_initialized:
+            self._vsc = _vector_search_client(
+                pat=self._pat,
+                client_id=self._client_id,
+                client_secret=self._client_secret,
+                workspace_host=self._workspace_host,
             )
-        if dfs is None:
-            dfs = _function_client(w=w)
-        self.w = w
-        self.vsc = vsc
-        self.dfs = dfs
+            self._vsc_initialized = True
+        return self._vsc  # type: ignore[return-value]
+
+    @vsc.setter
+    def vsc(self, value: VectorSearchClient) -> None:
+        """Set VectorSearchClient and mark as initialized."""
+        self._vsc = value
+        self._vsc_initialized = True
+
+    @property
+    def dfs(self) -> DatabricksFunctionClient:
+        """Lazy initialization of DatabricksFunctionClient."""
+        if not self._dfs_initialized:
+            self._dfs = _function_client(w=self.w)
+            self._dfs_initialized = True
+        return self._dfs  # type: ignore[return-value]
+
+    @dfs.setter
+    def dfs(self, value: DatabricksFunctionClient) -> None:
+        """Set DatabricksFunctionClient and mark as initialized."""
+        self._dfs = value
+        self._dfs_initialized = True

     def experiment_name(self, config: AppConfig) -> str:
         current_user: User = self.w.current_user.me()
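
The refactor above is the standard lazy-property pattern: nothing is constructed in __init__, the getter builds its client on first attribute access, and the setter short-circuits the factory. A minimal sketch of the pattern, with hypothetical names (make_client stands in for _workspace_client and friends):

# Minimal sketch of the lazy-client pattern; Provider and make_client
# are hypothetical stand-ins, not dao-ai APIs.
class Provider:
    def __init__(self, client=None):
        self._client = client
        self._client_initialized = client is not None

    @property
    def client(self):
        if not self._client_initialized:
            self._client = make_client()  # deferred until first attribute access
            self._client_initialized = True
        return self._client

    @client.setter
    def client(self, value):
        self._client = value
        self._client_initialized = True


def make_client():
    # hypothetical expensive factory (auth happens here, not at import time)
    return object()

Because construction is deferred, importing and instantiating the provider in a Databricks App never triggers VectorSearchClient authentication unless vector search is actually used.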
@@ -326,7 +379,7 @@ class DatabricksProvider(ServiceProvider):
             raise FileNotFoundError(f"Code path does not exist: {path}")

         model_root_path: Path = Path(dao_ai.__file__).parent
-        model_path: Path = model_root_path / "agent_as_code.py"
+        model_path: Path = model_root_path / "apps" / "model_serving.py"

         pip_requirements: Sequence[str] = config.app.pip_requirements

@@ -344,6 +397,8 @@ class DatabricksProvider(ServiceProvider):

         pip_requirements += get_installed_packages()

+        code_paths = list(dict.fromkeys(code_paths))
+
         logger.trace("Pip requirements prepared", count=len(pip_requirements))
         logger.trace("Code paths prepared", count=len(code_paths))

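The new code_paths line deduplicates while preserving first-seen order, which a plain set() round-trip would not guarantee:

# dict.fromkeys keeps the first occurrence of each entry in order;
# set() would deduplicate but scramble ordering.
code_paths = ["./src", "./config", "./src"]
assert list(dict.fromkeys(code_paths)) == ["./src", "./config"]
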
@@ -381,19 +436,38 @@ class DatabricksProvider(ServiceProvider):
             pip_packages_count=len(pip_requirements),
         )

-        with mlflow.start_run(run_name=run_name):
-            mlflow.set_tag("type", "agent")
-            mlflow.set_tag("dao_ai", dao_ai_version())
-            logged_agent_info: ModelInfo = mlflow.pyfunc.log_model(
-                python_model=model_path.as_posix(),
-                code_paths=code_paths,
-                model_config=config.model_dump(mode="json", by_alias=True),
-                name="agent",
-                conda_env=conda_env,
-                input_example=input_example,
-                # resources=all_resources,
-                auth_policy=auth_policy,
+        # End any stale runs before starting to ensure clean state on retry
+        if mlflow.active_run():
+            logger.warning(
+                "Ending stale MLflow run before creating new agent",
+                run_id=mlflow.active_run().info.run_id,
+            )
+            mlflow.end_run()
+
+        try:
+            with mlflow.start_run(run_name=run_name):
+                mlflow.set_tag("type", "agent")
+                mlflow.set_tag("dao_ai", dao_ai_version())
+                logged_agent_info: ModelInfo = mlflow.pyfunc.log_model(
+                    python_model=model_path.as_posix(),
+                    code_paths=code_paths,
+                    model_config=config.model_dump(mode="json", by_alias=True),
+                    name="agent",
+                    conda_env=conda_env,
+                    input_example=input_example,
+                    # resources=all_resources,
+                    auth_policy=auth_policy,
+                )
+        except Exception as e:
+            # Ensure run is ended on failure to prevent stale state on retry
+            if mlflow.active_run():
+                mlflow.end_run(status="FAILED")
+            logger.error(
+                "Failed to log model",
+                run_name=run_name,
+                error=str(e),
             )
+            raise

         registered_model_name: str = config.app.registered_model.full_name

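The run-lifecycle guard added here is a retry-safety idiom for MLflow: close any run left open by a previous failed attempt, then make sure a failing run ends as FAILED instead of lingering as RUNNING. A condensed sketch of the same idiom (note that mlflow.start_run used as a context manager also ends the run on exceptions, so the explicit end_run in the except block is defensive):

import mlflow

if mlflow.active_run():
    mlflow.end_run()  # close a run left dangling by an earlier failure

try:
    with mlflow.start_run(run_name="log-agent"):
        ...  # log model, tags, artifacts
except Exception:
    if mlflow.active_run():
        mlflow.end_run(status="FAILED")  # defensive: no RUNNING leftover
    raise
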
@@ -439,8 +513,19 @@ class DatabricksProvider(ServiceProvider):
             version=aliased_model.version,
         )

-    def deploy_agent(self, config: AppConfig) -> None:
-        logger.info("Deploying agent", endpoint_name=config.app.endpoint_name)
+    def deploy_model_serving_agent(self, config: AppConfig) -> None:
+        """
+        Deploy agent to Databricks Model Serving endpoint.
+
+        This is the original deployment method that creates/updates a Model Serving
+        endpoint with the registered model.
+
+        Args:
+            config: The AppConfig containing deployment configuration
+        """
+        logger.info(
+            "Deploying agent to Model Serving", endpoint_name=config.app.endpoint_name
+        )
         mlflow.set_registry_uri("databricks-uc")

         endpoint_name: str = config.app.endpoint_name
@@ -499,6 +584,240 @@ class DatabricksProvider(ServiceProvider):
             permission_level=PermissionLevel[entitlement],
         )

+    def deploy_apps_agent(self, config: AppConfig) -> None:
+        """
+        Deploy agent as a Databricks App.
+
+        This method creates or updates a Databricks App that serves the agent
+        using the app_server module.
+
+        The deployment process:
+        1. Determine the workspace source path for the app
+        2. Upload the configuration file to the workspace
+        3. Create the app if it doesn't exist
+        4. Deploy the app
+
+        Args:
+            config: The AppConfig containing deployment configuration
+
+        Note:
+            The config file must be loaded via AppConfig.from_file() so that
+            the source_config_path is available for upload.
+        """
+        import io
+
+        from databricks.sdk.service.apps import (
+            App,
+            AppDeployment,
+            AppDeploymentMode,
+            AppDeploymentState,
+        )
+
+        # Normalize app name: lowercase, replace underscores with dashes
+        raw_name: str = config.app.name
+        app_name: str = raw_name.lower().replace("_", "-")
+        if app_name != raw_name:
+            logger.info(
+                "Normalized app name for Databricks Apps",
+                original=raw_name,
+                normalized=app_name,
+            )
+        logger.info("Deploying agent to Databricks Apps", app_name=app_name)
+
+        # Use convention-based workspace path: /Workspace/Users/{user}/apps/{app_name}
+        current_user: User = self.w.current_user.me()
+        user_name: str = current_user.user_name or "default"
+        source_path: str = f"/Workspace/Users/{user_name}/apps/{app_name}"
+
+        logger.info("Using workspace source path", source_path=source_path)
+
+        # Get or create experiment for this app (for tracing and tracking)
+        from mlflow.entities import Experiment
+
+        experiment: Experiment = self.get_or_create_experiment(config)
+        logger.info(
+            "Using MLflow experiment for app",
+            experiment_name=experiment.name,
+            experiment_id=experiment.experiment_id,
+        )
+
+        # Upload the configuration file to the workspace
+        source_config_path: str | None = config.source_config_path
+        if source_config_path:
+            # Read the config file and upload to workspace
+            config_file_name: str = "dao_ai.yaml"
+            workspace_config_path: str = f"{source_path}/{config_file_name}"
+
+            logger.info(
+                "Uploading config file to workspace",
+                source=source_config_path,
+                destination=workspace_config_path,
+            )
+
+            # Read the source config file
+            with open(source_config_path, "rb") as f:
+                config_content: bytes = f.read()
+
+            # Create the directory if it doesn't exist and upload the file
+            try:
+                self.w.workspace.mkdirs(source_path)
+            except Exception as e:
+                logger.debug(f"Directory may already exist: {e}")
+
+            # Upload the config file
+            self.w.workspace.upload(
+                path=workspace_config_path,
+                content=io.BytesIO(config_content),
+                format=ImportFormat.AUTO,
+                overwrite=True,
+            )
+            logger.info("Config file uploaded", path=workspace_config_path)
+        else:
+            logger.warning(
+                "No source config path available. "
+                "Ensure DAO_AI_CONFIG_PATH is set in the app environment or "
+                "dao_ai.yaml exists in the app source directory."
+            )
+
+        # Generate and upload app.yaml with dynamically discovered resources
+        from dao_ai.apps.resources import generate_app_yaml
+
+        app_yaml_content: str = generate_app_yaml(
+            config,
+            command=[
+                "/bin/bash",
+                "-c",
+                "pip install dao-ai && python -m dao_ai.apps.server",
+            ],
+            include_resources=True,
+        )
+
+        app_yaml_path: str = f"{source_path}/app.yaml"
+        self.w.workspace.upload(
+            path=app_yaml_path,
+            content=io.BytesIO(app_yaml_content.encode("utf-8")),
+            format=ImportFormat.AUTO,
+            overwrite=True,
+        )
+        logger.info("app.yaml with resources uploaded", path=app_yaml_path)
+
+        # Generate SDK resources from the config (including experiment)
+        from dao_ai.apps.resources import (
+            generate_sdk_resources,
+            generate_user_api_scopes,
+        )
+
+        sdk_resources = generate_sdk_resources(
+            config, experiment_id=experiment.experiment_id
+        )
+        if sdk_resources:
+            logger.info(
+                "Discovered app resources from config",
+                resource_count=len(sdk_resources),
+                resources=[r.name for r in sdk_resources],
+            )
+
+        # Generate user API scopes for on-behalf-of-user resources
+        user_api_scopes = generate_user_api_scopes(config)
+        if user_api_scopes:
+            logger.info(
+                "Discovered user API scopes for OBO resources",
+                scopes=user_api_scopes,
+            )
+
+        # Check if app exists
+        app_exists: bool = False
+        try:
+            existing_app: App = self.w.apps.get(name=app_name)
+            app_exists = True
+            logger.debug("App already exists, updating", app_name=app_name)
+        except NotFound:
+            logger.debug("Creating new app", app_name=app_name)
+
+        # Create or update the app with resources and user_api_scopes
+        if not app_exists:
+            logger.info("Creating Databricks App", app_name=app_name)
+            app_spec = App(
+                name=app_name,
+                description=config.app.description or f"DAO AI Agent: {app_name}",
+                resources=sdk_resources if sdk_resources else None,
+                user_api_scopes=user_api_scopes if user_api_scopes else None,
+            )
+            app: App = self.w.apps.create_and_wait(app=app_spec)
+            logger.info("App created", app_name=app.name, app_url=app.url)
+        else:
+            app = existing_app
+            # Update resources and scopes on existing app
+            if sdk_resources or user_api_scopes:
+                logger.info("Updating app resources and scopes", app_name=app_name)
+                updated_app = App(
+                    name=app_name,
+                    description=config.app.description or app.description,
+                    resources=sdk_resources if sdk_resources else None,
+                    user_api_scopes=user_api_scopes if user_api_scopes else None,
+                )
+                app = self.w.apps.update(name=app_name, app=updated_app)
+                logger.info("App resources and scopes updated", app_name=app_name)
+
+        # Deploy the app with source code
+        # The app will use the dao_ai.apps.server module as the entry point
+        logger.info("Deploying app", app_name=app_name)
+
+        # Create deployment configuration
+        app_deployment = AppDeployment(
+            mode=AppDeploymentMode.SNAPSHOT,
+            source_code_path=source_path,
+        )
+
+        # Deploy the app
+        deployment: AppDeployment = self.w.apps.deploy_and_wait(
+            app_name=app_name,
+            app_deployment=app_deployment,
+        )
+
+        if (
+            deployment.status
+            and deployment.status.state == AppDeploymentState.SUCCEEDED
+        ):
+            logger.info(
+                "App deployed successfully",
+                app_name=app_name,
+                deployment_id=deployment.deployment_id,
+                app_url=app.url if app else None,
+            )
+        else:
+            status_message: str = (
+                deployment.status.message if deployment.status else "Unknown error"
+            )
+            logger.error(
+                "App deployment failed",
+                app_name=app_name,
+                status=status_message,
+            )
+            raise RuntimeError(f"App deployment failed: {status_message}")
+
+    def deploy_agent(
+        self,
+        config: AppConfig,
+        target: DeploymentTarget = DeploymentTarget.MODEL_SERVING,
+    ) -> None:
+        """
+        Deploy agent to the specified target.
+
+        This is the main deployment method that routes to the appropriate
+        deployment implementation based on the target.
+
+        Args:
+            config: The AppConfig containing deployment configuration
+            target: The deployment target (MODEL_SERVING or APPS)
+        """
+        if target == DeploymentTarget.MODEL_SERVING:
+            self.deploy_model_serving_agent(config)
+        elif target == DeploymentTarget.APPS:
+            self.deploy_apps_agent(config)
+        else:
+            raise ValueError(f"Unknown deployment target: {target}")
+
     def create_catalog(self, schema: SchemaModel) -> CatalogInfo:
         catalog_info: CatalogInfo
         try:
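
With this change, deploy_agent becomes a thin router over the two targets. A usage sketch under the entry points shown in this diff (the constructor arguments and config filename are illustrative):

from dao_ai.config import AppConfig, DeploymentTarget
from dao_ai.providers.databricks import DatabricksProvider

# from_file() populates source_config_path, which deploy_apps_agent uploads
config = AppConfig.from_file("dao_ai.yaml")
provider = DatabricksProvider()

provider.deploy_agent(config)  # defaults to DeploymentTarget.MODEL_SERVING
provider.deploy_agent(config, target=DeploymentTarget.APPS)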
dao_ai/state.py CHANGED
@@ -164,6 +164,7 @@ class Context(BaseModel):

     user_id: str | None = None
     thread_id: str | None = None
+    headers: dict[str, Any] | None = None

     @classmethod
     def from_runnable_config(cls, config: dict[str, Any]) -> "Context":
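
The new optional headers field gives tools access to inbound request headers; the genie changes below use the runtime context for on-behalf-of-user auth. A hedged sketch of constructing such a context (the header name and downstream plumbing are assumptions for illustration, not a dao-ai API guarantee):

from dao_ai.state import Context

# Hypothetical: forward the caller's token so downstream tools can act
# on behalf of the user; the exact header consumed is deployment-specific.
context = Context(
    user_id="user@example.com",
    thread_id="thread-123",
    headers={"x-forwarded-access-token": "<user-token>"},
)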
dao_ai/tools/__init__.py CHANGED
@@ -1,10 +1,10 @@
-from dao_ai.genie.cache import LRUCacheService, SemanticCacheService
+from dao_ai.genie.cache import LRUCacheService, PostgresContextAwareGenieService
 from dao_ai.hooks.core import create_hooks
 from dao_ai.tools.agent import create_agent_endpoint_tool
 from dao_ai.tools.core import create_tools, say_hello_tool
 from dao_ai.tools.email import create_send_email_tool
 from dao_ai.tools.genie import create_genie_tool
-from dao_ai.tools.mcp import create_mcp_tools
+from dao_ai.tools.mcp import MCPToolInfo, create_mcp_tools, list_mcp_tools
 from dao_ai.tools.memory import create_search_memory_tool
 from dao_ai.tools.python import create_factory_tool, create_python_tool
 from dao_ai.tools.search import create_search_tool
@@ -30,6 +30,8 @@ __all__ = [
     "create_genie_tool",
     "create_hooks",
     "create_mcp_tools",
+    "list_mcp_tools",
+    "MCPToolInfo",
     "create_python_tool",
     "create_search_memory_tool",
     "create_search_tool",
@@ -42,8 +44,8 @@ __all__ = [
     "format_time_tool",
     "is_business_hours_tool",
     "LRUCacheService",
+    "PostgresContextAwareGenieService",
     "say_hello_tool",
-    "SemanticCacheService",
     "time_difference_tool",
     "time_in_timezone_tool",
     "time_until_tool",
dao_ai/tools/genie.py CHANGED
@@ -6,7 +6,7 @@ interact with Databricks Genie.

 For the core Genie service and cache implementations, see:
 - dao_ai.genie: GenieService, GenieServiceBase
-- dao_ai.genie.cache: LRUCacheService, SemanticCacheService
+- dao_ai.genie.cache: LRUCacheService, PostgresContextAwareGenieService, InMemoryContextAwareGenieService
 """

 import json
@@ -25,13 +25,19 @@ from pydantic import BaseModel
 from dao_ai.config import (
     AnyVariable,
     CompositeVariableModel,
+    GenieContextAwareCacheParametersModel,
+    GenieInMemorySemanticCacheParametersModel,
     GenieLRUCacheParametersModel,
     GenieRoomModel,
-    GenieSemanticCacheParametersModel,
     value_of,
 )
 from dao_ai.genie import GenieService, GenieServiceBase
-from dao_ai.genie.cache import CacheResult, LRUCacheService, SemanticCacheService
+from dao_ai.genie.cache import (
+    CacheResult,
+    InMemoryContextAwareGenieService,
+    LRUCacheService,
+    PostgresContextAwareGenieService,
+)
 from dao_ai.state import AgentState, Context, SessionState


@@ -64,7 +70,10 @@ def create_genie_tool(
     persist_conversation: bool = True,
     truncate_results: bool = False,
     lru_cache_parameters: GenieLRUCacheParametersModel | dict[str, Any] | None = None,
-    semantic_cache_parameters: GenieSemanticCacheParametersModel
+    semantic_cache_parameters: GenieContextAwareCacheParametersModel
+    | dict[str, Any]
+    | None = None,
+    in_memory_semantic_cache_parameters: GenieInMemorySemanticCacheParametersModel
     | dict[str, Any]
     | None = None,
 ) -> Callable[..., Command]:
@@ -84,7 +93,9 @@ def create_genie_tool(
         truncate_results: Whether to truncate large query results to fit token limits
         lru_cache_parameters: Optional LRU cache configuration for SQL query caching
         semantic_cache_parameters: Optional semantic cache configuration using pg_vector
-            for similarity-based query matching
+            for similarity-based query matching (requires PostgreSQL/Lakebase)
+        in_memory_semantic_cache_parameters: Optional in-memory semantic cache configuration
+            for similarity-based query matching (no database required)

     Returns:
         A LangGraph tool that processes natural language queries through Genie
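
An illustrative call wiring up the new parameters (the dict forms are coerced to the pydantic models, as the hunks below show; the individual keys here are assumptions, not the models' documented schema):

from dao_ai.tools.genie import create_genie_tool

genie_tool = create_genie_tool(
    genie_room={"space_id": "01ef-example"},                    # hypothetical value
    lru_cache_parameters={"capacity": 128},                     # hypothetical keys
    in_memory_semantic_cache_parameters={"similarity_threshold": 0.85},
)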
@@ -97,6 +108,7 @@ def create_genie_tool(
         name=name,
         has_lru_cache=lru_cache_parameters is not None,
         has_semantic_cache=semantic_cache_parameters is not None,
+        has_in_memory_semantic_cache=in_memory_semantic_cache_parameters is not None,
     )

     if isinstance(genie_room, dict):
@@ -106,10 +118,15 @@ def create_genie_tool(
         lru_cache_parameters = GenieLRUCacheParametersModel(**lru_cache_parameters)

     if isinstance(semantic_cache_parameters, dict):
-        semantic_cache_parameters = GenieSemanticCacheParametersModel(
+        semantic_cache_parameters = GenieContextAwareCacheParametersModel(
             **semantic_cache_parameters
         )

+    if isinstance(in_memory_semantic_cache_parameters, dict):
+        in_memory_semantic_cache_parameters = GenieInMemorySemanticCacheParametersModel(
+            **in_memory_semantic_cache_parameters
+        )
+
     space_id: AnyVariable = genie_room.space_id or os.environ.get(
         "DATABRICKS_GENIE_SPACE_ID"
     )
@@ -139,29 +156,61 @@ Returns:
     GenieResponse: A response object containing the conversation ID and result from Genie."""
     tool_description = tool_description + function_docs

-    genie: Genie = Genie(
-        space_id=space_id,
-        client=genie_room.workspace_client,
-        truncate_results=truncate_results,
-    )
+    # Cache for genie service - created lazily on first call
+    # This allows us to use workspace_client_from with runtime context for OBO
+    _cached_genie_service: GenieServiceBase | None = None
+
+    def _get_genie_service(context: Context | None) -> GenieServiceBase:
+        """Get or create the Genie service, using context for OBO auth if available."""
+        nonlocal _cached_genie_service
+
+        # Use cached service if available (for non-OBO or after first call)
+        # For OBO, we need fresh workspace client each time to use the user's token
+        if _cached_genie_service is not None and not genie_room.on_behalf_of_user:
+            return _cached_genie_service

-    genie_service: GenieServiceBase = GenieService(genie)
-
-    # Wrap with semantic cache first (checked second due to decorator pattern)
-    if semantic_cache_parameters is not None:
-        genie_service = SemanticCacheService(
-            impl=genie_service,
-            parameters=semantic_cache_parameters,
-            workspace_client=genie_room.workspace_client,  # Pass workspace client for conversation history
-        ).initialize()  # Eagerly initialize to fail fast and create table
-
-    # Wrap with LRU cache last (checked first - fast O(1) exact match)
-    if lru_cache_parameters is not None:
-        genie_service = LRUCacheService(
-            impl=genie_service,
-            parameters=lru_cache_parameters,
+        # Get workspace client using context for OBO support
+        from databricks.sdk import WorkspaceClient
+
+        workspace_client: WorkspaceClient = genie_room.workspace_client_from(context)
+
+        genie: Genie = Genie(
+            space_id=space_id,
+            client=workspace_client,
+            truncate_results=truncate_results,
        )

+        genie_service: GenieServiceBase = GenieService(genie)
+
+        # Wrap with context-aware cache first (checked second/third due to decorator pattern)
+        if semantic_cache_parameters is not None:
+            genie_service = PostgresContextAwareGenieService(
+                impl=genie_service,
+                parameters=semantic_cache_parameters,
+                workspace_client=workspace_client,
+            ).initialize()
+
+        # Wrap with in-memory context-aware cache (alternative to PostgreSQL context-aware cache)
+        if in_memory_semantic_cache_parameters is not None:
+            genie_service = InMemoryContextAwareGenieService(
+                impl=genie_service,
+                parameters=in_memory_semantic_cache_parameters,
+                workspace_client=workspace_client,
+            ).initialize()
+
+        # Wrap with LRU cache last (checked first - fast O(1) exact match)
+        if lru_cache_parameters is not None:
+            genie_service = LRUCacheService(
+                impl=genie_service,
+                parameters=lru_cache_parameters,
+            )
+
+        # Cache for non-OBO scenarios
+        if not genie_room.on_behalf_of_user:
+            _cached_genie_service = genie_service
+
+        return genie_service
+
     @tool(
         name_or_callable=tool_name,
         description=tool_description,
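
The construction order matters because each cache wraps the previous service (decorator pattern): the last wrapper applied is the first consulted on a lookup, so the LRU exact-match layer runs before the semantic layer, which runs before the real Genie call. A toy sketch of that inside-out/outside-in relationship (names are stand-ins, not dao-ai classes):

# Each layer delegates inward on a miss, so the last wrapper applied
# is the first one consulted.
class BaseService:
    def ask(self, q):
        return f"genie({q})"

class CacheLayer:
    def __init__(self, impl, tag):
        self.impl, self.tag = impl, tag

    def ask(self, q):
        print(f"checked {self.tag} cache")  # real layers do a lookup here
        return self.impl.ask(q)             # miss: delegate to inner service

svc = BaseService()
svc = CacheLayer(svc, "semantic")  # wrapped first  -> consulted second
svc = CacheLayer(svc, "lru")       # wrapped last   -> consulted first
svc.ask("total sales")             # prints lru, then semantic, then calls genie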
@@ -177,6 +226,10 @@ GenieResponse: A response object containing the conversation ID and result from
         # Access state through runtime
         state: AgentState = runtime.state
         tool_call_id: str = runtime.tool_call_id
+        context: Context | None = runtime.context
+
+        # Get genie service with OBO support via context
+        genie_service: GenieServiceBase = _get_genie_service(context)

         # Ensure space_id is a string for state keys
         space_id_str: str = str(space_id)
@@ -194,6 +247,14 @@ GenieResponse: A response object containing the conversation ID and result from
             conversation_id=existing_conversation_id,
         )

+        # Log the prompt being sent to Genie
+        logger.trace(
+            "Sending prompt to Genie",
+            space_id=space_id_str,
+            conversation_id=existing_conversation_id,
+            prompt=question[:500] + "..." if len(question) > 500 else question,
+        )
+
         # Call ask_question which always returns CacheResult with cache metadata
         cache_result: CacheResult = genie_service.ask_question(
             question, conversation_id=existing_conversation_id
@@ -211,6 +272,22 @@ GenieResponse: A response object containing the conversation ID and result from
             cache_key=cache_key,
         )

+        # Log truncated response for debugging
+        result_preview: str = str(genie_response.result)
+        if len(result_preview) > 500:
+            result_preview = result_preview[:500] + "..."
+        logger.trace(
+            "Genie response content",
+            question=question[:100] + "..." if len(question) > 100 else question,
+            query=genie_response.query,
+            description=(
+                genie_response.description[:200] + "..."
+                if genie_response.description and len(genie_response.description) > 200
+                else genie_response.description
+            ),
+            result_preview=result_preview,
+        )
+
         # Update session state with cache information
         if persist_conversation:
             session.genie.update_space(