dao-ai: dao_ai-0.0.28-py3-none-any.whl → dao_ai-0.1.5-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (70):
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +2 -5
  3. dao_ai/cli.py +342 -58
  4. dao_ai/config.py +1610 -380
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -253
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +158 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/context_editing.py +230 -0
  25. dao_ai/middleware/core.py +67 -0
  26. dao_ai/middleware/guardrails.py +420 -0
  27. dao_ai/middleware/human_in_the_loop.py +233 -0
  28. dao_ai/middleware/message_validation.py +586 -0
  29. dao_ai/middleware/model_call_limit.py +77 -0
  30. dao_ai/middleware/model_retry.py +121 -0
  31. dao_ai/middleware/pii.py +157 -0
  32. dao_ai/middleware/summarization.py +197 -0
  33. dao_ai/middleware/tool_call_limit.py +210 -0
  34. dao_ai/middleware/tool_retry.py +174 -0
  35. dao_ai/models.py +1306 -114
  36. dao_ai/nodes.py +240 -161
  37. dao_ai/optimization.py +674 -0
  38. dao_ai/orchestration/__init__.py +52 -0
  39. dao_ai/orchestration/core.py +294 -0
  40. dao_ai/orchestration/supervisor.py +279 -0
  41. dao_ai/orchestration/swarm.py +271 -0
  42. dao_ai/prompts.py +128 -31
  43. dao_ai/providers/databricks.py +584 -601
  44. dao_ai/state.py +157 -21
  45. dao_ai/tools/__init__.py +13 -5
  46. dao_ai/tools/agent.py +1 -3
  47. dao_ai/tools/core.py +64 -11
  48. dao_ai/tools/email.py +232 -0
  49. dao_ai/tools/genie.py +144 -294
  50. dao_ai/tools/mcp.py +223 -155
  51. dao_ai/tools/memory.py +50 -0
  52. dao_ai/tools/python.py +9 -14
  53. dao_ai/tools/search.py +14 -0
  54. dao_ai/tools/slack.py +22 -10
  55. dao_ai/tools/sql.py +202 -0
  56. dao_ai/tools/time.py +30 -7
  57. dao_ai/tools/unity_catalog.py +165 -88
  58. dao_ai/tools/vector_search.py +331 -221
  59. dao_ai/utils.py +166 -20
  60. dao_ai/vector_search.py +37 -0
  61. dao_ai-0.1.5.dist-info/METADATA +489 -0
  62. dao_ai-0.1.5.dist-info/RECORD +70 -0
  63. dao_ai/chat_models.py +0 -204
  64. dao_ai/guardrails.py +0 -112
  65. dao_ai/tools/human_in_the_loop.py +0 -100
  66. dao_ai-0.0.28.dist-info/METADATA +0 -1168
  67. dao_ai-0.0.28.dist-info/RECORD +0 -41
  68. {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/WHEEL +0 -0
  69. {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/entry_points.txt +0 -0
  70. {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/licenses/LICENSE +0 -0
dao_ai/__init__.py CHANGED
@@ -0,0 +1,29 @@
1
+ """
2
+ DAO AI - A framework for building AI agents with Databricks.
3
+
4
+ This module configures package-level settings including warning filters
5
+ for expected runtime warnings that don't indicate actual problems.
6
+ """
7
+
8
+ import warnings
9
+
10
+ # Suppress Pydantic serialization warnings for Context objects during checkpointing.
11
+ # This warning occurs because LangGraph's checkpointer serializes the context_schema
12
+ # and Pydantic reports that serialization may not be as expected. This is benign
13
+ # since Context is only used at runtime and doesn't need to be persisted.
14
+ #
15
+ # The warning looks like:
16
+ # PydanticSerializationUnexpectedValue(Expected `none` - serialized value may not
17
+ # be as expected [field_name='context', input_value=Context(...), input_type=Context])
18
+ warnings.filterwarnings(
19
+ "ignore",
20
+ message=r".*Pydantic serializer warnings.*",
21
+ category=UserWarning,
22
+ )
23
+
24
+ # Also filter the specific PydanticSerializationUnexpectedValue warning
25
+ warnings.filterwarnings(
26
+ "ignore",
27
+ message=r".*PydanticSerializationUnexpectedValue.*",
28
+ category=UserWarning,
29
+ )
dao_ai/agent_as_code.py CHANGED
@@ -1,11 +1,9 @@
1
- import sys
2
-
3
1
  import mlflow
4
- from loguru import logger
5
2
  from mlflow.models import ModelConfig
6
3
  from mlflow.pyfunc import ResponsesAgent
7
4
 
8
5
  from dao_ai.config import AppConfig
6
+ from dao_ai.logging import configure_logging
9
7
 
10
8
  mlflow.set_registry_uri("databricks-uc")
11
9
  mlflow.set_tracking_uri("databricks")
@@ -17,8 +15,7 @@ config: AppConfig = AppConfig(**model_config.to_dict())
17
15
 
18
16
  log_level: str = config.app.log_level
19
17
 
20
- logger.remove()
21
- logger.add(sys.stderr, level=log_level)
18
+ configure_logging(level=log_level)
22
19
 
23
20
  app: ResponsesAgent = config.as_responses_agent()
24
21
 
dao_ai/cli.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import argparse
2
+ import getpass
2
3
  import json
3
4
  import os
4
5
  import subprocess
@@ -13,11 +14,88 @@ from loguru import logger
13
14
 
14
15
  from dao_ai.config import AppConfig
15
16
  from dao_ai.graph import create_dao_ai_graph
17
+ from dao_ai.logging import configure_logging
16
18
  from dao_ai.models import save_image
17
19
  from dao_ai.utils import normalize_name
18
20
 
19
- logger.remove()
20
- logger.add(sys.stderr, level="ERROR")
21
+ configure_logging(level="ERROR")
22
+
23
+
24
+ def get_default_user_id() -> str:
25
+ """
26
+ Get the default user ID for the CLI session.
27
+
28
+ Tries to get the current user from Databricks, falls back to local user.
29
+
30
+ Returns:
31
+ User ID string (Databricks username or local username)
32
+ """
33
+ try:
34
+ # Try to get current user from Databricks SDK
35
+ from databricks.sdk import WorkspaceClient
36
+
37
+ w = WorkspaceClient()
38
+ current_user = w.current_user.me()
39
+ user_id = current_user.user_name
40
+ logger.debug(f"Using Databricks user: {user_id}")
41
+ return user_id
42
+ except Exception as e:
43
+ # Fall back to local system user
44
+ logger.debug(f"Could not get Databricks user, using local user: {e}")
45
+ local_user = getpass.getuser()
46
+ logger.debug(f"Using local user: {local_user}")
47
+ return local_user
48
+
49
+
50
+ def detect_cloud_provider(profile: Optional[str] = None) -> Optional[str]:
51
+ """
52
+ Detect the cloud provider from the Databricks workspace URL.
53
+
54
+ The cloud provider is determined by the workspace URL pattern:
55
+ - Azure: *.azuredatabricks.net
56
+ - AWS: *.cloud.databricks.com (without gcp subdomain)
57
+ - GCP: *.gcp.databricks.com
58
+
59
+ Args:
60
+ profile: Optional Databricks CLI profile name
61
+
62
+ Returns:
63
+ Cloud provider string ('azure', 'aws', 'gcp') or None if detection fails
64
+ """
65
+ try:
66
+ from databricks.sdk import WorkspaceClient
67
+
68
+ # Create workspace client with optional profile
69
+ if profile:
70
+ w = WorkspaceClient(profile=profile)
71
+ else:
72
+ w = WorkspaceClient()
73
+
74
+ # Get the workspace URL from config
75
+ host = w.config.host
76
+ if not host:
77
+ logger.warning("Could not determine workspace URL for cloud detection")
78
+ return None
79
+
80
+ host_lower = host.lower()
81
+
82
+ if "azuredatabricks.net" in host_lower:
83
+ logger.debug(f"Detected Azure cloud from workspace URL: {host}")
84
+ return "azure"
85
+ elif ".gcp.databricks.com" in host_lower:
86
+ logger.debug(f"Detected GCP cloud from workspace URL: {host}")
87
+ return "gcp"
88
+ elif ".cloud.databricks.com" in host_lower or "databricks.com" in host_lower:
89
+ # AWS uses *.cloud.databricks.com or regional patterns
90
+ logger.debug(f"Detected AWS cloud from workspace URL: {host}")
91
+ return "aws"
92
+ else:
93
+ logger.warning(f"Could not determine cloud provider from URL: {host}")
94
+ return None
95
+
96
+ except Exception as e:
97
+ logger.warning(f"Could not detect cloud provider: {e}")
98
+ return None
21
99
 
22
100
 
23
101
  env_path: str = find_dotenv()
@@ -193,6 +271,13 @@ Examples:
193
271
  "-t",
194
272
  "--target",
195
273
  type=str,
274
+ help="Bundle target name (default: auto-generated from app name and cloud)",
275
+ )
276
+ bundle_parser.add_argument(
277
+ "--cloud",
278
+ type=str,
279
+ choices=["azure", "aws", "gcp"],
280
+ help="Cloud provider (auto-detected from workspace URL if not specified)",
196
281
  )
197
282
  bundle_parser.add_argument(
198
283
  "--dry-run",
@@ -240,9 +325,9 @@ Use Ctrl-C to interrupt and exit immediately.
240
325
  """,
241
326
  epilog="""
242
327
  Examples:
243
- dao-ai chat -c config/model_config.yaml # Start chat with default settings
328
+ dao-ai chat -c config/model_config.yaml # Start chat (auto-detects user)
244
329
  dao-ai chat -c config/retail.yaml --custom-input store_num=87887 # Chat with custom store number
245
- dao-ai chat -c config/prod.yaml --user-id john123 # Chat with specific user ID
330
+ dao-ai chat -c config/prod.yaml --user-id john.doe@company.com # Chat with specific user ID
246
331
  dao-ai chat -c config/retail.yaml --custom-input store_num=123 --custom-input region=west # Multiple custom inputs
247
332
  """,
248
333
  formatter_class=argparse.RawDescriptionHelpFormatter,
@@ -264,28 +349,38 @@ Examples:
264
349
  chat_parser.add_argument(
265
350
  "--user-id",
266
351
  type=str,
267
- default="my_user_id",
352
+ default=None, # Will be set to actual user in handle_chat_command
268
353
  metavar="ID",
269
- help="User ID for the chat session (default: my_user_id)",
354
+ help="User ID for the chat session (default: current Databricks user or local username)",
270
355
  )
271
356
  chat_parser.add_argument(
272
357
  "--thread-id",
273
358
  type=str,
274
- default="1",
359
+ default=None,
275
360
  metavar="ID",
276
- help="Thread ID for the chat session (default: 1)",
361
+ help="Thread ID for the chat session (default: auto-generated UUID)",
277
362
  )
278
363
 
279
364
  options = parser.parse_args(args)
280
365
 
366
+ # Generate a new thread_id UUID if not provided (only for chat command)
367
+ if hasattr(options, "thread_id") and options.thread_id is None:
368
+ import uuid
369
+
370
+ options.thread_id = str(uuid.uuid4())
371
+
281
372
  return options
282
373
 
283
374
 
284
375
  def handle_chat_command(options: Namespace) -> None:
285
- """Interactive chat REPL with the DAO AI system."""
376
+ """Interactive chat REPL with the DAO AI system with Human-in-the-Loop support."""
286
377
  logger.debug("Starting chat session with DAO AI system...")
287
378
 
288
379
  try:
380
+ # Set default user_id if not provided
381
+ if options.user_id is None:
382
+ options.user_id = get_default_user_id()
383
+
289
384
  config: AppConfig = AppConfig.from_file(options.config)
290
385
  app = create_dao_ai_graph(config)
291
386
 
@@ -305,9 +400,7 @@ def handle_chat_command(options: Namespace) -> None:
305
400
  print("-" * 50)
306
401
 
307
402
  # Import streaming function and interrupt handling
308
- from langchain_core.messages import HumanMessage
309
-
310
- from dao_ai.models import process_messages_stream
403
+ from langchain_core.messages import AIMessage, HumanMessage
311
404
 
312
405
  # Conversation history
313
406
  messages = []
@@ -350,47 +443,210 @@ def handle_chat_command(options: Namespace) -> None:
350
443
  )
351
444
  continue
352
445
 
353
- # Prepare custom inputs for the agent
354
- custom_inputs = {"configurable": configurable}
446
+ # Create Context object from configurable dict
447
+ from dao_ai.state import Context
448
+
449
+ context = Context(**configurable)
450
+
451
+ # Prepare config with thread_id for checkpointer
452
+ # Note: thread_id is needed in config for checkpointer/memory
453
+ config = {"configurable": {"thread_id": options.thread_id}}
454
+
455
+ # Invoke the graph and handle interrupts (HITL)
456
+ # Wrap in async function to maintain connection pool throughout
457
+ logger.debug(f"Invoking graph with {len(messages)} messages")
458
+ logger.debug(f"Context: {context}")
459
+ logger.debug(f"Config: {config}")
460
+
461
+ import asyncio
462
+
463
+ from langgraph.types import Command
464
+
465
+ async def _invoke_with_hitl():
466
+ """Invoke graph and handle HITL interrupts in single async context."""
467
+ result = await app.ainvoke(
468
+ {"messages": messages},
469
+ config=config,
470
+ context=context, # Pass context as separate parameter
471
+ )
472
+
473
+ # Check for interrupts (Human-in-the-Loop) using __interrupt__
474
+ # This is the modern LangChain pattern
475
+ while "__interrupt__" in result:
476
+ interrupts = result["__interrupt__"]
477
+ logger.info(f"HITL: {len(interrupts)} interrupt(s) detected")
478
+
479
+ # Collect decisions for all interrupts
480
+ decisions = []
481
+
482
+ for interrupt in interrupts:
483
+ interrupt_value = interrupt.value
484
+ action_requests = interrupt_value.get("action_requests", [])
485
+
486
+ for action_request in action_requests:
487
+ # Display interrupt information
488
+ print("\n⚠️ Human in the Loop - Tool Approval Required")
489
+ print(f"{'=' * 60}")
490
+
491
+ tool_name = action_request.get("name", "unknown")
492
+ tool_args = action_request.get("arguments", {})
493
+ description = action_request.get("description", "")
494
+
495
+ print(f"Tool: {tool_name}")
496
+ if description:
497
+ print(f"\n{description}\n")
498
+
499
+ print("Arguments:")
500
+ for arg_name, arg_value in tool_args.items():
501
+ # Truncate long values
502
+ arg_str = str(arg_value)
503
+ if len(arg_str) > 100:
504
+ arg_str = arg_str[:97] + "..."
505
+ print(f" - {arg_name}: {arg_str}")
506
+
507
+ print(f"{'=' * 60}")
508
+
509
+ # Prompt user for decision
510
+ while True:
511
+ decision_input = (
512
+ input(
513
+ "\nAction? (a)pprove / (r)eject / (e)dit / (h)elp: "
514
+ )
515
+ .strip()
516
+ .lower()
517
+ )
518
+
519
+ if decision_input in ["a", "approve"]:
520
+ logger.info("User approved tool call")
521
+ print("✅ Approved - continuing execution...")
522
+ decisions.append({"type": "approve"})
523
+ break
524
+ elif decision_input in ["r", "reject"]:
525
+ logger.info("User rejected tool call")
526
+ feedback = input(
527
+ " Feedback for agent (optional): "
528
+ ).strip()
529
+ if feedback:
530
+ decisions.append(
531
+ {"type": "reject", "message": feedback}
532
+ )
533
+ else:
534
+ decisions.append(
535
+ {
536
+ "type": "reject",
537
+ "message": "Tool call rejected by user",
538
+ }
539
+ )
540
+ print(
541
+ "❌ Rejected - agent will receive feedback..."
542
+ )
543
+ break
544
+ elif decision_input in ["e", "edit"]:
545
+ print(
546
+ "ℹ️ Edit functionality not yet implemented in CLI"
547
+ )
548
+ print(" Please approve or reject.")
549
+ continue
550
+ elif decision_input in ["h", "help"]:
551
+ print("\nAvailable actions:")
552
+ print(
553
+ " (a)pprove - Execute the tool call as shown"
554
+ )
555
+ print(
556
+ " (r)eject - Cancel the tool call with optional feedback"
557
+ )
558
+ print(
559
+ " (e)dit - Modify arguments (not yet implemented)"
560
+ )
561
+ print(" (h)elp - Show this help message")
562
+ continue
563
+ else:
564
+ print("Invalid option. Type 'h' for help.")
565
+ continue
566
+
567
+ # Resume execution with decisions using Command
568
+ # This is the modern LangChain pattern
569
+ logger.debug(f"Resuming with {len(decisions)} decision(s)")
570
+ result = await app.ainvoke(
571
+ Command(resume={"decisions": decisions}),
572
+ config=config,
573
+ context=context,
574
+ )
575
+
576
+ return result
577
+
578
+ try:
579
+ # Use async invoke - keep connection pool alive throughout HITL
580
+ loop = asyncio.get_event_loop()
581
+ except RuntimeError:
582
+ loop = asyncio.new_event_loop()
583
+ asyncio.set_event_loop(loop)
584
+
585
+ try:
586
+ result = loop.run_until_complete(_invoke_with_hitl())
587
+ except Exception as e:
588
+ logger.error(f"Error invoking graph: {e}")
589
+ print(f"\n❌ Error: {e}")
590
+ continue
355
591
 
592
+ # After all interrupts handled, display the final response
356
593
  print("\n🤖 Assistant: ", end="", flush=True)
357
594
 
358
- # Stream the response
359
595
  response_content = ""
596
+ structured_response = None
360
597
  try:
361
- for chunk in process_messages_stream(app, messages, custom_inputs):
362
- # Handle different chunk types
363
- if hasattr(chunk, "content") and chunk.content:
364
- content = chunk.content
365
- print(content, end="", flush=True)
366
- response_content += content
367
- elif hasattr(chunk, "choices") and chunk.choices:
368
- # Handle ChatCompletionChunk format
369
- for choice in chunk.choices:
370
- if (
371
- hasattr(choice, "delta")
372
- and choice.delta
373
- and choice.delta.content
374
- ):
375
- content = choice.delta.content
376
- print(content, end="", flush=True)
377
- response_content += content
378
-
379
- print() # New line after streaming
598
+ # Debug: Log what's in the result
599
+ logger.debug(f"Result keys: {result.keys() if result else 'None'}")
600
+ if result:
601
+ for key in result.keys():
602
+ logger.debug(f"Result['{key}'] type: {type(result[key])}")
603
+
604
+ # Get the latest messages from the result
605
+ if result and "messages" in result:
606
+ latest_messages = result["messages"]
607
+ # Find the last AI message
608
+ for msg in reversed(latest_messages):
609
+ if isinstance(msg, AIMessage):
610
+ if hasattr(msg, "content") and msg.content:
611
+ response_content = msg.content
612
+ print(response_content, end="", flush=True)
613
+ break
614
+
615
+ # Check for structured output and display it separately
616
+ if result and "structured_response" in result:
617
+ structured_response = result["structured_response"]
618
+ import json
619
+
620
+ structured_json = json.dumps(
621
+ structured_response.model_dump()
622
+ if hasattr(structured_response, "model_dump")
623
+ else structured_response,
624
+ indent=2,
625
+ )
626
+
627
+ # If there was message content, add separator
628
+ if response_content.strip():
629
+ print("\n\n📊 Structured Output:")
630
+ print(structured_json)
631
+ else:
632
+ # No message content, just show structured output
633
+ print(structured_json, end="", flush=True)
634
+
635
+ response_content = response_content or structured_json
636
+
637
+ print() # New line after response
380
638
 
381
639
  # Add assistant response to history if we got content
382
640
  if response_content.strip():
383
- from langchain_core.messages import AIMessage
384
-
385
641
  assistant_message = AIMessage(content=response_content)
386
642
  messages.append(assistant_message)
387
643
  else:
388
644
  print("(No response content generated)")
389
645
 
390
646
  except Exception as e:
391
- print(f"\n❌ Error during streaming: {e}")
647
+ print(f"\n❌ Error processing response: {e}")
392
648
  print(f"Stack trace:\n{traceback.format_exc()}")
393
- logger.error(f"Streaming error: {e}")
649
+ logger.error(f"Response processing error: {e}")
394
650
  logger.error(f"Stack trace: {traceback.format_exc()}")
395
651
 
396
652
  except EOFError:
@@ -404,6 +660,7 @@ def handle_chat_command(options: Namespace) -> None:
404
660
  except Exception as e:
405
661
  print(f"\n❌ Error: {e}")
406
662
  logger.error(f"Chat error: {e}")
663
+ traceback.print_exc()
407
664
 
408
665
  except Exception as e:
409
666
  logger.error(f"Failed to initialize chat session: {e}")
@@ -448,7 +705,6 @@ def handle_validate_command(options: Namespace) -> None:
448
705
 
449
706
 
450
707
  def setup_logging(verbosity: int) -> None:
451
- logger.remove()
452
708
  levels: dict[int, str] = {
453
709
  0: "ERROR",
454
710
  1: "WARNING",
@@ -457,7 +713,7 @@ def setup_logging(verbosity: int) -> None:
457
713
  4: "TRACE",
458
714
  }
459
715
  level: str = levels.get(verbosity, "TRACE")
460
- logger.add(sys.stderr, level=level)
716
+ configure_logging(level=level)
461
717
 
462
718
 
463
719
  def generate_bundle_from_template(config_path: Path, app_name: str) -> Path:
@@ -471,7 +727,7 @@ def generate_bundle_from_template(config_path: Path, app_name: str) -> Path:
471
727
  4. Returns the path to the generated file
472
728
 
473
729
  The generated databricks.yaml is overwritten on each deployment and is not tracked in git.
474
- Schema reference remains pointing to ./schemas/bundle_config_schema.json.
730
+ The template contains cloud-specific targets (azure, aws, gcp) with appropriate node types.
475
731
 
476
732
  Args:
477
733
  config_path: Path to the app config file
@@ -508,39 +764,59 @@ def run_databricks_command(
508
764
  profile: Optional[str] = None,
509
765
  config: Optional[str] = None,
510
766
  target: Optional[str] = None,
767
+ cloud: Optional[str] = None,
511
768
  dry_run: bool = False,
512
769
  ) -> None:
513
- """Execute a databricks CLI command with optional profile and target."""
770
+ """Execute a databricks CLI command with optional profile, target, and cloud.
771
+
772
+ Args:
773
+ command: The databricks CLI command to execute (e.g., ["bundle", "deploy"])
774
+ profile: Optional Databricks CLI profile name
775
+ config: Optional path to the configuration file
776
+ target: Optional bundle target name (if not provided, auto-generated from app name and cloud)
777
+ cloud: Optional cloud provider ('azure', 'aws', 'gcp'). Auto-detected if not specified.
778
+ dry_run: If True, print the command without executing
779
+ """
514
780
  config_path = Path(config) if config else None
515
781
 
516
782
  if config_path and not config_path.exists():
517
783
  logger.error(f"Configuration file {config_path} does not exist.")
518
784
  sys.exit(1)
519
785
 
520
- # Load app config and generate bundle from template
786
+ # Load app config
521
787
  app_config: AppConfig = AppConfig.from_file(config_path) if config_path else None
522
788
  normalized_name: str = normalize_name(app_config.app.name) if app_config else None
523
789
 
790
+ # Auto-detect cloud provider if not specified
791
+ if not cloud:
792
+ cloud = detect_cloud_provider(profile)
793
+ if cloud:
794
+ logger.info(f"Auto-detected cloud provider: {cloud}")
795
+ else:
796
+ logger.warning("Could not detect cloud provider. Defaulting to 'azure'.")
797
+ cloud = "azure"
798
+
524
799
  # Generate app-specific bundle from template (overwrites databricks.yaml temporarily)
525
800
  if config_path and app_config:
526
801
  generate_bundle_from_template(config_path, normalized_name)
527
802
 
528
- # Use app name as target if not explicitly provided
529
- # This ensures each app gets its own Terraform state in .databricks/bundle/<app-name>/
530
- if not target and normalized_name:
531
- target = normalized_name
532
- logger.debug(f"Using app-specific target: {target}")
803
+ # Use cloud as target (azure, aws, gcp) - can be overridden with explicit --target
804
+ if not target:
805
+ target = cloud
806
+ logger.debug(f"Using cloud-based target: {target}")
533
807
 
534
- # Build databricks command (no -c flag needed, uses databricks.yaml in current dir)
808
+ # Build databricks command
809
+ # --profile is a global flag, --target is a subcommand flag for 'bundle'
535
810
  cmd = ["databricks"]
536
811
  if profile:
537
812
  cmd.extend(["--profile", profile])
538
813
 
814
+ cmd.extend(command)
815
+
816
+ # --target must come after the bundle subcommand (it's a subcommand-specific flag)
539
817
  if target:
540
818
  cmd.extend(["--target", target])
541
819
 
542
- cmd.extend(command)
543
-
544
820
  # Add config_path variable for notebooks
545
821
  if config_path and app_config:
546
822
  # Calculate relative path from notebooks directory to config file
@@ -595,30 +871,38 @@ def handle_bundle_command(options: Namespace) -> None:
595
871
  profile: Optional[str] = options.profile
596
872
  config: Optional[str] = options.config
597
873
  target: Optional[str] = options.target
874
+ cloud: Optional[str] = options.cloud
598
875
  dry_run: bool = options.dry_run
599
876
 
600
877
  if options.deploy:
601
878
  logger.info("Deploying DAO AI asset bundle...")
602
879
  run_databricks_command(
603
- ["bundle", "deploy"], profile, config, target, dry_run=dry_run
880
+ ["bundle", "deploy"],
881
+ profile=profile,
882
+ config=config,
883
+ target=target,
884
+ cloud=cloud,
885
+ dry_run=dry_run,
604
886
  )
605
887
  if options.run:
606
888
  logger.info("Running DAO AI system with current configuration...")
607
889
  # Use static job resource key that matches databricks.yaml (resources.jobs.deploy_job)
608
890
  run_databricks_command(
609
891
  ["bundle", "run", "deploy_job"],
610
- profile,
611
- config,
612
- target,
892
+ profile=profile,
893
+ config=config,
894
+ target=target,
895
+ cloud=cloud,
613
896
  dry_run=dry_run,
614
897
  )
615
898
  if options.destroy:
616
899
  logger.info("Destroying DAO AI system with current configuration...")
617
900
  run_databricks_command(
618
901
  ["bundle", "destroy", "--auto-approve"],
619
- profile,
620
- config,
621
- target,
902
+ profile=profile,
903
+ config=config,
904
+ target=target,
905
+ cloud=cloud,
622
906
  dry_run=dry_run,
623
907
  )
624
908
  else: