dao-ai 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dao_ai/apps/server.py ADDED
@@ -0,0 +1,39 @@
1
+ """
2
+ App server module for running dao-ai agents as Databricks Apps.
3
+
4
+ This module provides the entry point for deploying dao-ai agents as Databricks Apps
5
+ using MLflow's AgentServer. It follows the same pattern as model_serving.py but
6
+ uses the AgentServer for the Databricks Apps runtime.
7
+
8
+ Configuration Loading:
9
+ The config path is specified via the DAO_AI_CONFIG_PATH environment variable,
10
+ or defaults to model_config.yaml in the current directory.
11
+
12
+ Usage:
13
+ # With environment variable
14
+ DAO_AI_CONFIG_PATH=/path/to/config.yaml python -m dao_ai.apps.server
15
+
16
+ # With default model_config.yaml in current directory
17
+ python -m dao_ai.apps.server
18
+ """
19
+
20
+ from mlflow.genai.agent_server import AgentServer
21
+
22
+ # Import the agent handlers to register the invoke and stream decorators
23
+ # This MUST happen before creating the AgentServer instance
24
+ import dao_ai.apps.handlers # noqa: E402, F401
25
+
26
+ # Create the AgentServer instance
27
+ agent_server = AgentServer("ResponsesAgent", enable_chat_proxy=True)
28
+
29
+ # Define the app as a module level variable to enable multiple workers
30
+ app = agent_server.app
31
+
32
+
33
+ def main() -> None:
34
+ """Entry point for running the agent server."""
35
+ agent_server.run(app_import_string="dao_ai.apps.server:app")
36
+
37
+
38
+ if __name__ == "__main__":
39
+ main()
dao_ai/cli.py CHANGED
@@ -285,6 +285,15 @@ Examples:
285
285
  action="store_true",
286
286
  help="Perform a dry run without executing the deployment or run commands",
287
287
  )
288
+ bundle_parser.add_argument(
289
+ "--deployment-target",
290
+ type=str,
291
+ choices=["model_serving", "apps"],
292
+ default=None,
293
+ help="Agent deployment target: 'model_serving' or 'apps'. "
294
+ "If not specified, uses app.deployment_target from config file, "
295
+ "or defaults to 'model_serving'. Passed to the deploy notebook.",
296
+ )
288
297
 
289
298
  # Deploy command
290
299
  deploy_parser: ArgumentParser = subparsers.add_parser(
@@ -309,6 +318,16 @@ Examples:
309
318
  metavar="FILE",
310
319
  help="Path to the model configuration file to validate",
311
320
  )
321
+ deploy_parser.add_argument(
322
+ "-t",
323
+ "--target",
324
+ type=str,
325
+ choices=["model_serving", "apps"],
326
+ default=None,
327
+ help="Deployment target: 'model_serving' or 'apps'. "
328
+ "If not specified, uses app.deployment_target from config file, "
329
+ "or defaults to 'model_serving'.",
330
+ )
312
331
 
313
332
  # List MCP tools command
314
333
  list_mcp_parser: ArgumentParser = subparsers.add_parser(
@@ -729,11 +748,28 @@ def handle_graph_command(options: Namespace) -> None:
729
748
 
730
749
 
731
750
  def handle_deploy_command(options: Namespace) -> None:
751
+ from dao_ai.config import DeploymentTarget
752
+
732
753
  logger.debug(f"Validating configuration from {options.config}...")
733
754
  try:
734
755
  config: AppConfig = AppConfig.from_file(options.config)
756
+
757
+ # Hybrid target resolution:
758
+ # 1. CLI --target takes precedence
759
+ # 2. Fall back to config.app.deployment_target
760
+ # 3. Default to MODEL_SERVING (handled in deploy_agent)
761
+ target: DeploymentTarget | None = None
762
+ if options.target is not None:
763
+ target = DeploymentTarget(options.target)
764
+ logger.info(f"Using CLI-specified deployment target: {target.value}")
765
+ elif config.app is not None and config.app.deployment_target is not None:
766
+ target = config.app.deployment_target
767
+ logger.info(f"Using config file deployment target: {target.value}")
768
+ else:
769
+ logger.info("No deployment target specified, defaulting to model_serving")
770
+
735
771
  config.create_agent()
736
- config.deploy_agent()
772
+ config.deploy_agent(target=target)
737
773
  sys.exit(0)
738
774
  except Exception as e:
739
775
  logger.error(f"Deployment failed: {e}")
@@ -1083,6 +1119,7 @@ def run_databricks_command(
1083
1119
  target: Optional[str] = None,
1084
1120
  cloud: Optional[str] = None,
1085
1121
  dry_run: bool = False,
1122
+ deployment_target: Optional[str] = None,
1086
1123
  ) -> None:
1087
1124
  """Execute a databricks CLI command with optional profile, target, and cloud.
1088
1125
 
@@ -1093,6 +1130,8 @@ def run_databricks_command(
1093
1130
  target: Optional bundle target name (if not provided, auto-generated from app name and cloud)
1094
1131
  cloud: Optional cloud provider ('azure', 'aws', 'gcp'). Auto-detected if not specified.
1095
1132
  dry_run: If True, print the command without executing
1133
+ deployment_target: Optional agent deployment target ('model_serving' or 'apps').
1134
+ Passed to the deploy notebook via bundle variable.
1096
1135
  """
1097
1136
  config_path = Path(config) if config else None
1098
1137
 
@@ -1148,6 +1187,24 @@ def run_databricks_command(
1148
1187
 
1149
1188
  cmd.append(f'--var="config_path={relative_config}"')
1150
1189
 
1190
+ # Add deployment_target variable for notebooks (hybrid resolution)
1191
+ # Priority: CLI arg > config file > default (model_serving)
1192
+ resolved_deployment_target: str = "model_serving"
1193
+ if deployment_target is not None:
1194
+ resolved_deployment_target = deployment_target
1195
+ logger.debug(
1196
+ f"Using CLI-specified deployment target: {resolved_deployment_target}"
1197
+ )
1198
+ elif app_config and app_config.app and app_config.app.deployment_target:
1199
+ resolved_deployment_target = app_config.app.deployment_target.value
1200
+ logger.debug(
1201
+ f"Using config file deployment target: {resolved_deployment_target}"
1202
+ )
1203
+ else:
1204
+ logger.debug("Using default deployment target: model_serving")
1205
+
1206
+ cmd.append(f'--var="deployment_target={resolved_deployment_target}"')
1207
+
1151
1208
  logger.debug(f"Executing command: {' '.join(cmd)}")
1152
1209
 
1153
1210
  if dry_run:
@@ -1190,6 +1247,7 @@ def handle_bundle_command(options: Namespace) -> None:
1190
1247
  target: Optional[str] = options.target
1191
1248
  cloud: Optional[str] = options.cloud
1192
1249
  dry_run: bool = options.dry_run
1250
+ deployment_target: Optional[str] = options.deployment_target
1193
1251
 
1194
1252
  if options.deploy:
1195
1253
  logger.info("Deploying DAO AI asset bundle...")
@@ -1200,6 +1258,7 @@ def handle_bundle_command(options: Namespace) -> None:
1200
1258
  target=target,
1201
1259
  cloud=cloud,
1202
1260
  dry_run=dry_run,
1261
+ deployment_target=deployment_target,
1203
1262
  )
1204
1263
  if options.run:
1205
1264
  logger.info("Running DAO AI system with current configuration...")
@@ -1211,6 +1270,7 @@ def handle_bundle_command(options: Namespace) -> None:
1211
1270
  target=target,
1212
1271
  cloud=cloud,
1213
1272
  dry_run=dry_run,
1273
+ deployment_target=deployment_target,
1214
1274
  )
1215
1275
  if options.destroy:
1216
1276
  logger.info("Destroying DAO AI system with current configuration...")
@@ -1221,6 +1281,7 @@ def handle_bundle_command(options: Namespace) -> None:
1221
1281
  target=target,
1222
1282
  cloud=cloud,
1223
1283
  dry_run=dry_run,
1284
+ deployment_target=deployment_target,
1224
1285
  )
1225
1286
  else:
1226
1287
  logger.warning("No action specified. Use --deploy, --run or --destroy flags.")
dao_ai/config.py CHANGED
@@ -208,7 +208,9 @@ class IsDatabricksResource(ABC, BaseModel):
208
208
  Authentication Options:
209
209
  ----------------------
210
210
  1. **On-Behalf-Of User (OBO)**: Set on_behalf_of_user=True to use the
211
- calling user's identity via ModelServingUserCredentials.
211
+ calling user's identity. Implementation varies by deployment:
212
+ - Databricks Apps: Uses X-Forwarded-Access-Token from request headers
213
+ - Model Serving: Uses ModelServingUserCredentials
212
214
 
213
215
  2. **Service Principal (OAuth M2M)**: Provide service_principal or
214
216
  (client_id + client_secret + workspace_host) for service principal auth.
@@ -221,9 +223,17 @@ class IsDatabricksResource(ABC, BaseModel):
221
223
 
222
224
  Authentication Priority:
223
225
  1. OBO (on_behalf_of_user=True)
226
+ - Checks for forwarded headers (Databricks Apps)
227
+ - Falls back to ModelServingUserCredentials (Model Serving)
224
228
  2. Service Principal (client_id + client_secret + workspace_host)
225
229
  3. PAT (pat + workspace_host)
226
230
  4. Ambient/default authentication
231
+
232
+ Note: When on_behalf_of_user=True, the agent acts as the calling user regardless
233
+ of deployment target. In Databricks Apps, this uses X-Forwarded-Access-Token
234
+ automatically captured by MLflow AgentServer. In Model Serving, this uses
235
+ ModelServingUserCredentials. Forwarded headers are ONLY used when
236
+ on_behalf_of_user=True.
227
237
  """
228
238
 
229
239
  model_config = ConfigDict(use_enum_values=True)
@@ -235,9 +245,6 @@ class IsDatabricksResource(ABC, BaseModel):
235
245
  workspace_host: Optional[AnyVariable] = None
236
246
  pat: Optional[AnyVariable] = None
237
247
 
238
- # Private attribute to cache the workspace client (lazy instantiation)
239
- _workspace_client: Optional[WorkspaceClient] = PrivateAttr(default=None)
240
-
241
248
  @abstractmethod
242
249
  def as_resources(self) -> Sequence[DatabricksResource]: ...
243
250
 
@@ -273,32 +280,56 @@ class IsDatabricksResource(ABC, BaseModel):
273
280
  """
274
281
  Get a WorkspaceClient configured with the appropriate authentication.
275
282
 
276
- The client is lazily instantiated on first access and cached for subsequent calls.
283
+ A new client is created on each access.
277
284
 
278
285
  Authentication priority:
279
- 1. If on_behalf_of_user is True, uses ModelServingUserCredentials (OBO)
280
- 2. If service principal credentials are configured (client_id, client_secret,
281
- workspace_host), uses OAuth M2M
282
- 3. If PAT is configured, uses token authentication
283
- 4. Otherwise, uses default/ambient authentication
286
+ 1. On-Behalf-Of User (on_behalf_of_user=True):
287
+ - Forwarded headers (Databricks Apps)
288
+ - ModelServingUserCredentials (Model Serving)
289
+ 2. Service Principal (client_id + client_secret + workspace_host)
290
+ 3. PAT (pat + workspace_host)
291
+ 4. Ambient/default authentication
284
292
  """
285
- # Return cached client if already instantiated
286
- if self._workspace_client is not None:
287
- return self._workspace_client
288
-
289
293
  from dao_ai.utils import normalize_host
290
294
 
291
295
  # Check for OBO first (highest priority)
292
296
  if self.on_behalf_of_user:
297
+ # NEW: In Databricks Apps, use forwarded headers for per-user auth
298
+ try:
299
+ from mlflow.genai.agent_server import get_request_headers
300
+
301
+ headers = get_request_headers()
302
+ forwarded_token = headers.get("x-forwarded-access-token")
303
+
304
+ if forwarded_token:
305
+ forwarded_user = headers.get("x-forwarded-user", "unknown")
306
+ logger.debug(
307
+ f"Creating WorkspaceClient for {self.__class__.__name__} "
308
+ f"with OBO using forwarded token from Databricks Apps",
309
+ forwarded_user=forwarded_user,
310
+ )
311
+ # Use workspace_host if configured, otherwise SDK will auto-detect
312
+ workspace_host_value: str | None = (
313
+ normalize_host(value_of(self.workspace_host))
314
+ if self.workspace_host
315
+ else None
316
+ )
317
+ return WorkspaceClient(
318
+ host=workspace_host_value,
319
+ token=forwarded_token,
320
+ auth_type="pat",
321
+ )
322
+ except (ImportError, LookupError):
323
+ # mlflow not available or headers not set - fall through to Model Serving
324
+ pass
325
+
326
+ # Fall back to Model Serving OBO (existing behavior)
293
327
  credentials_strategy: CredentialsStrategy = ModelServingUserCredentials()
294
328
  logger.debug(
295
329
  f"Creating WorkspaceClient for {self.__class__.__name__} "
296
- f"with OBO credentials strategy"
297
- )
298
- self._workspace_client = WorkspaceClient(
299
- credentials_strategy=credentials_strategy
330
+ f"with OBO credentials strategy (Model Serving)"
300
331
  )
301
- return self._workspace_client
332
+ return WorkspaceClient(credentials_strategy=credentials_strategy)
302
333
 
303
334
  # Check for service principal credentials
304
335
  client_id_value: str | None = (
@@ -313,18 +344,24 @@ class IsDatabricksResource(ABC, BaseModel):
313
344
  else None
314
345
  )
315
346
 
316
- if client_id_value and client_secret_value and workspace_host_value:
347
+ if client_id_value and client_secret_value:
348
+ # If workspace_host is not provided, check DATABRICKS_HOST env var first,
349
+ # then fall back to WorkspaceClient().config.host
350
+ if not workspace_host_value:
351
+ workspace_host_value = os.getenv("DATABRICKS_HOST")
352
+ if not workspace_host_value:
353
+ workspace_host_value = WorkspaceClient().config.host
354
+
317
355
  logger.debug(
318
356
  f"Creating WorkspaceClient for {self.__class__.__name__} with service principal: "
319
357
  f"client_id={client_id_value}, host={workspace_host_value}"
320
358
  )
321
- self._workspace_client = WorkspaceClient(
359
+ return WorkspaceClient(
322
360
  host=workspace_host_value,
323
361
  client_id=client_id_value,
324
362
  client_secret=client_secret_value,
325
363
  auth_type="oauth-m2m",
326
364
  )
327
- return self._workspace_client
328
365
 
329
366
  # Check for PAT authentication
330
367
  pat_value: str | None = value_of(self.pat) if self.pat else None
@@ -332,20 +369,28 @@ class IsDatabricksResource(ABC, BaseModel):
332
369
  logger.debug(
333
370
  f"Creating WorkspaceClient for {self.__class__.__name__} with PAT"
334
371
  )
335
- self._workspace_client = WorkspaceClient(
372
+ return WorkspaceClient(
336
373
  host=workspace_host_value,
337
374
  token=pat_value,
338
375
  auth_type="pat",
339
376
  )
340
- return self._workspace_client
341
377
 
342
378
  # Default: use ambient authentication
343
379
  logger.debug(
344
380
  f"Creating WorkspaceClient for {self.__class__.__name__} "
345
381
  "with default/ambient authentication"
346
382
  )
347
- self._workspace_client = WorkspaceClient()
348
- return self._workspace_client
383
+ return WorkspaceClient()
384
+
385
+
386
class DeploymentTarget(str, Enum):
    """Target platform for agent deployment."""

    # Deploy to a Databricks Model Serving endpoint.
    MODEL_SERVING = "model_serving"
    # Deploy as a Databricks App.
    APPS = "apps"
349
394
 
350
395
 
351
396
  class Privilege(str, Enum):
@@ -865,10 +910,6 @@ class GenieRoomModel(IsDatabricksResource):
865
910
  pat=self.pat,
866
911
  )
867
912
 
868
- # Share the cached workspace client if available
869
- if self._workspace_client is not None:
870
- warehouse_model._workspace_client = self._workspace_client
871
-
872
913
  return warehouse_model
873
914
  except Exception as e:
874
915
  logger.warning(
@@ -912,9 +953,6 @@ class GenieRoomModel(IsDatabricksResource):
912
953
  workspace_host=self.workspace_host,
913
954
  pat=self.pat,
914
955
  )
915
- # Share the cached workspace client if available
916
- if self._workspace_client is not None:
917
- table_model._workspace_client = self._workspace_client
918
956
 
919
957
  # Verify the table exists before adding
920
958
  if not table_model.exists():
@@ -952,9 +990,6 @@ class GenieRoomModel(IsDatabricksResource):
952
990
  workspace_host=self.workspace_host,
953
991
  pat=self.pat,
954
992
  )
955
- # Share the cached workspace client if available
956
- if self._workspace_client is not None:
957
- function_model._workspace_client = self._workspace_client
958
993
 
959
994
  # Verify the function exists before adding
960
995
  if not function_model.exists():
@@ -2775,6 +2810,11 @@ class AppModel(BaseModel):
2775
2810
  "which is supported by Databricks Model Serving. This allows deploying from "
2776
2811
  "environments with different Python versions (e.g., Databricks Apps with 3.11).",
2777
2812
  )
2813
+ deployment_target: Optional[DeploymentTarget] = Field(
2814
+ default=None,
2815
+ description="Default deployment target. If not specified, defaults to MODEL_SERVING. "
2816
+ "Can be overridden via CLI --target flag. Options: 'model_serving' or 'apps'.",
2817
+ )
2778
2818
 
2779
2819
  @model_validator(mode="after")
2780
2820
  def set_databricks_env_vars(self) -> Self:
@@ -3255,6 +3295,7 @@ class ResourcesModel(BaseModel):
3255
3295
 
3256
3296
  class AppConfig(BaseModel):
3257
3297
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
3298
+ version: Optional[str] = None
3258
3299
  variables: dict[str, AnyVariable] = Field(default_factory=dict)
3259
3300
  service_principals: dict[str, ServicePrincipalModel] = Field(default_factory=dict)
3260
3301
  schemas: dict[str, SchemaModel] = Field(default_factory=dict)
@@ -3275,6 +3316,9 @@ class AppConfig(BaseModel):
3275
3316
  )
3276
3317
  providers: Optional[dict[type | str, Any]] = None
3277
3318
 
3319
+ # Private attribute to track the source config file path (set by from_file)
3320
+ _source_config_path: str | None = None
3321
+
3278
3322
  @classmethod
3279
3323
  def from_file(cls, path: PathLike) -> "AppConfig":
3280
3324
  path = Path(path).as_posix()
@@ -3282,12 +3326,20 @@ class AppConfig(BaseModel):
3282
3326
  model_config: ModelConfig = ModelConfig(development_config=path)
3283
3327
  config: AppConfig = AppConfig(**model_config.to_dict())
3284
3328
 
3329
+ # Store the source config path for later use (e.g., Apps deployment)
3330
+ config._source_config_path = path
3331
+
3285
3332
  config.initialize()
3286
3333
 
3287
3334
  atexit.register(config.shutdown)
3288
3335
 
3289
3336
  return config
3290
3337
 
3338
+ @property
3339
+ def source_config_path(self) -> str | None:
3340
+ """Get the source config file path if loaded via from_file."""
3341
+ return self._source_config_path
3342
+
3291
3343
  def initialize(self) -> None:
3292
3344
  from dao_ai.hooks.core import create_hooks
3293
3345
  from dao_ai.logging import configure_logging
@@ -3358,6 +3410,7 @@ class AppConfig(BaseModel):
3358
3410
 
3359
3411
  def deploy_agent(
3360
3412
  self,
3413
+ target: DeploymentTarget | None = None,
3361
3414
  w: WorkspaceClient | None = None,
3362
3415
  vsc: "VectorSearchClient | None" = None,
3363
3416
  pat: str | None = None,
@@ -3365,9 +3418,39 @@ class AppConfig(BaseModel):
3365
3418
  client_secret: str | None = None,
3366
3419
  workspace_host: str | None = None,
3367
3420
  ) -> None:
3421
+ """
3422
+ Deploy the agent to the specified target.
3423
+
3424
+ Target resolution follows this priority:
3425
+ 1. Explicit `target` parameter (if provided)
3426
+ 2. `app.deployment_target` from config file (if set)
3427
+ 3. Default: MODEL_SERVING
3428
+
3429
+ Args:
3430
+ target: The deployment target (MODEL_SERVING or APPS). If None, uses
3431
+ config.app.deployment_target or defaults to MODEL_SERVING.
3432
+ w: Optional WorkspaceClient instance
3433
+ vsc: Optional VectorSearchClient instance
3434
+ pat: Optional personal access token for authentication
3435
+ client_id: Optional client ID for service principal authentication
3436
+ client_secret: Optional client secret for service principal authentication
3437
+ workspace_host: Optional workspace host URL
3438
+ """
3368
3439
  from dao_ai.providers.base import ServiceProvider
3369
3440
  from dao_ai.providers.databricks import DatabricksProvider
3370
3441
 
3442
+ # Resolve target using hybrid logic:
3443
+ # 1. Explicit parameter takes precedence
3444
+ # 2. Fall back to config.app.deployment_target
3445
+ # 3. Default to MODEL_SERVING
3446
+ resolved_target: DeploymentTarget
3447
+ if target is not None:
3448
+ resolved_target = target
3449
+ elif self.app is not None and self.app.deployment_target is not None:
3450
+ resolved_target = self.app.deployment_target
3451
+ else:
3452
+ resolved_target = DeploymentTarget.MODEL_SERVING
3453
+
3371
3454
  provider: ServiceProvider = DatabricksProvider(
3372
3455
  w=w,
3373
3456
  vsc=vsc,
@@ -3376,7 +3459,7 @@ class AppConfig(BaseModel):
3376
3459
  client_secret=client_secret,
3377
3460
  workspace_host=workspace_host,
3378
3461
  )
3379
- provider.deploy_agent(self)
3462
+ provider.deploy_agent(self, target=resolved_target)
3380
3463
 
3381
3464
  def find_agents(
3382
3465
  self, predicate: Callable[[AgentModel], bool] | None = None
dao_ai/memory/postgres.py CHANGED
@@ -178,7 +178,20 @@ class AsyncPostgresStoreManager(StoreManagerBase):
178
178
  def _setup(self):
179
179
  if self._setup_complete:
180
180
  return
181
- asyncio.run(self._async_setup())
181
+ try:
182
+ # Check if we're already in an async context
183
+ asyncio.get_running_loop()
184
+ # If we get here, we're in an async context - raise to caller
185
+ raise RuntimeError(
186
+ "Cannot call sync _setup() from async context. "
187
+ "Use await _async_setup() instead."
188
+ )
189
+ except RuntimeError as e:
190
+ if "no running event loop" in str(e).lower():
191
+ # No event loop running - safe to use asyncio.run()
192
+ asyncio.run(self._async_setup())
193
+ else:
194
+ raise
182
195
 
183
196
  async def _async_setup(self):
184
197
  if self._setup_complete:
@@ -237,13 +250,25 @@ class AsyncPostgresCheckpointerManager(CheckpointManagerBase):
237
250
 
238
251
  def _setup(self):
239
252
  """
240
- Run the async setup. Works in both sync and async contexts when nest_asyncio is applied.
253
+ Run the async setup. For async contexts, use await _async_setup() directly.
241
254
  """
242
255
  if self._setup_complete:
243
256
  return
244
257
 
245
- # With nest_asyncio applied in notebooks, asyncio.run() works everywhere
246
- asyncio.run(self._async_setup())
258
+ try:
259
+ # Check if we're already in an async context
260
+ asyncio.get_running_loop()
261
+ # If we get here, we're in an async context - raise to caller
262
+ raise RuntimeError(
263
+ "Cannot call sync _setup() from async context. "
264
+ "Use await _async_setup() instead."
265
+ )
266
+ except RuntimeError as e:
267
+ if "no running event loop" in str(e).lower():
268
+ # No event loop running - safe to use asyncio.run()
269
+ asyncio.run(self._async_setup())
270
+ else:
271
+ raise
247
272
 
248
273
  async def _async_setup(self):
249
274
  """