dao-ai 0.0.28__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (63)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +2 -5
  3. dao_ai/cli.py +245 -40
  4. dao_ai/config.py +1491 -370
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -253
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +125 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/core.py +67 -0
  25. dao_ai/middleware/guardrails.py +420 -0
  26. dao_ai/middleware/human_in_the_loop.py +232 -0
  27. dao_ai/middleware/message_validation.py +586 -0
  28. dao_ai/middleware/summarization.py +197 -0
  29. dao_ai/models.py +1306 -114
  30. dao_ai/nodes.py +245 -159
  31. dao_ai/optimization.py +674 -0
  32. dao_ai/orchestration/__init__.py +52 -0
  33. dao_ai/orchestration/core.py +294 -0
  34. dao_ai/orchestration/supervisor.py +278 -0
  35. dao_ai/orchestration/swarm.py +271 -0
  36. dao_ai/prompts.py +128 -31
  37. dao_ai/providers/databricks.py +573 -601
  38. dao_ai/state.py +157 -21
  39. dao_ai/tools/__init__.py +13 -5
  40. dao_ai/tools/agent.py +1 -3
  41. dao_ai/tools/core.py +64 -11
  42. dao_ai/tools/email.py +232 -0
  43. dao_ai/tools/genie.py +144 -294
  44. dao_ai/tools/mcp.py +223 -155
  45. dao_ai/tools/memory.py +50 -0
  46. dao_ai/tools/python.py +9 -14
  47. dao_ai/tools/search.py +14 -0
  48. dao_ai/tools/slack.py +22 -10
  49. dao_ai/tools/sql.py +202 -0
  50. dao_ai/tools/time.py +30 -7
  51. dao_ai/tools/unity_catalog.py +165 -88
  52. dao_ai/tools/vector_search.py +331 -221
  53. dao_ai/utils.py +166 -20
  54. dao_ai-0.1.2.dist-info/METADATA +455 -0
  55. dao_ai-0.1.2.dist-info/RECORD +64 -0
  56. dao_ai/chat_models.py +0 -204
  57. dao_ai/guardrails.py +0 -112
  58. dao_ai/tools/human_in_the_loop.py +0 -100
  59. dao_ai-0.0.28.dist-info/METADATA +0 -1168
  60. dao_ai-0.0.28.dist-info/RECORD +0 -41
  61. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +0 -0
  62. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
  63. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/__init__.py CHANGED
@@ -0,0 +1,29 @@
+ """
+ DAO AI - A framework for building AI agents with Databricks.
+
+ This module configures package-level settings including warning filters
+ for expected runtime warnings that don't indicate actual problems.
+ """
+
+ import warnings
+
+ # Suppress Pydantic serialization warnings for Context objects during checkpointing.
+ # This warning occurs because LangGraph's checkpointer serializes the context_schema
+ # and Pydantic reports that serialization may not be as expected. This is benign
+ # since Context is only used at runtime and doesn't need to be persisted.
+ #
+ # The warning looks like:
+ #   PydanticSerializationUnexpectedValue(Expected `none` - serialized value may not
+ #   be as expected [field_name='context', input_value=Context(...), input_type=Context])
+ warnings.filterwarnings(
+     "ignore",
+     message=r".*Pydantic serializer warnings.*",
+     category=UserWarning,
+ )
+
+ # Also filter the specific PydanticSerializationUnexpectedValue warning
+ warnings.filterwarnings(
+     "ignore",
+     message=r".*PydanticSerializationUnexpectedValue.*",
+     category=UserWarning,
+ )
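Note: `warnings.filterwarnings` treats `message` as a regex matched against the start of the warning text, which is why the patterns above carry `.*` prefixes. A minimal standalone sketch of the resulting behavior (illustration only, not part of the package):

```python
import warnings

warnings.filterwarnings(
    "ignore",
    message=r".*Pydantic serializer warnings.*",
    category=UserWarning,
)

# This matching warning is suppressed...
warnings.warn("Pydantic serializer warnings: unexpected value", UserWarning)
# ...while unrelated warnings still surface.
warnings.warn("something else entirely", UserWarning)
```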
dao_ai/agent_as_code.py CHANGED
@@ -1,11 +1,9 @@
- import sys
-
  import mlflow
- from loguru import logger
  from mlflow.models import ModelConfig
  from mlflow.pyfunc import ResponsesAgent

  from dao_ai.config import AppConfig
+ from dao_ai.logging import configure_logging

  mlflow.set_registry_uri("databricks-uc")
  mlflow.set_tracking_uri("databricks")
@@ -17,8 +15,7 @@ config: AppConfig = AppConfig(**model_config.to_dict())

  log_level: str = config.app.log_level

- logger.remove()
- logger.add(sys.stderr, level=log_level)
+ configure_logging(level=log_level)

  app: ResponsesAgent = config.as_responses_agent()

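The new `dao_ai/logging.py` (+56 lines in the summary above) is not shown in this diff, but judging from the loguru calls it replaces here and in `cli.py`, a plausible minimal sketch of `configure_logging` might look like this; the body is an assumption, not the package's actual implementation:

```python
import sys

from loguru import logger


def configure_logging(level: str = "INFO") -> None:
    # Hypothetical sketch: reset loguru and attach a single stderr sink,
    # mirroring the logger.remove()/logger.add() pairs this release removes.
    logger.remove()
    logger.add(sys.stderr, level=level)
```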
dao_ai/cli.py CHANGED
@@ -1,4 +1,5 @@
  import argparse
+ import getpass
  import json
  import os
  import subprocess
@@ -13,11 +14,37 @@ from loguru import logger

  from dao_ai.config import AppConfig
  from dao_ai.graph import create_dao_ai_graph
+ from dao_ai.logging import configure_logging
  from dao_ai.models import save_image
  from dao_ai.utils import normalize_name

- logger.remove()
- logger.add(sys.stderr, level="ERROR")
+ configure_logging(level="ERROR")
+
+
+ def get_default_user_id() -> str:
+     """
+     Get the default user ID for the CLI session.
+
+     Tries to get the current user from Databricks, falls back to local user.
+
+     Returns:
+         User ID string (Databricks username or local username)
+     """
+     try:
+         # Try to get current user from Databricks SDK
+         from databricks.sdk import WorkspaceClient
+
+         w = WorkspaceClient()
+         current_user = w.current_user.me()
+         user_id = current_user.user_name
+         logger.debug(f"Using Databricks user: {user_id}")
+         return user_id
+     except Exception as e:
+         # Fall back to local system user
+         logger.debug(f"Could not get Databricks user, using local user: {e}")
+         local_user = getpass.getuser()
+         logger.debug(f"Using local user: {local_user}")
+         return local_user


  env_path: str = find_dotenv()
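As the hunk above shows, `get_default_user_id()` prefers the Databricks workspace identity and only falls back to `getpass.getuser()`. A quick usage sketch (hypothetical session; the output depends on whether Databricks credentials are configured):

```python
from dao_ai.cli import get_default_user_id

# With a configured Databricks profile this returns the workspace user name,
# e.g. "john.doe@company.com"; without one it returns the local OS username.
print(get_default_user_id())
```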
@@ -240,9 +267,9 @@ Use Ctrl-C to interrupt and exit immediately.
      """,
      epilog="""
  Examples:
-   dao-ai chat -c config/model_config.yaml  # Start chat with default settings
+   dao-ai chat -c config/model_config.yaml  # Start chat (auto-detects user)
    dao-ai chat -c config/retail.yaml --custom-input store_num=87887  # Chat with custom store number
-   dao-ai chat -c config/prod.yaml --user-id john123  # Chat with specific user ID
+   dao-ai chat -c config/prod.yaml --user-id john.doe@company.com  # Chat with specific user ID
    dao-ai chat -c config/retail.yaml --custom-input store_num=123 --custom-input region=west  # Multiple custom inputs
  """,
      formatter_class=argparse.RawDescriptionHelpFormatter,
@@ -264,28 +291,38 @@ Examples:
      chat_parser.add_argument(
          "--user-id",
          type=str,
-         default="my_user_id",
+         default=None,  # Will be set to actual user in handle_chat_command
          metavar="ID",
-         help="User ID for the chat session (default: my_user_id)",
+         help="User ID for the chat session (default: current Databricks user or local username)",
      )
      chat_parser.add_argument(
          "--thread-id",
          type=str,
-         default="1",
+         default=None,
          metavar="ID",
-         help="Thread ID for the chat session (default: 1)",
+         help="Thread ID for the chat session (default: auto-generated UUID)",
      )

      options = parser.parse_args(args)

+     # Generate a new thread_id UUID if not provided (only for chat command)
+     if hasattr(options, "thread_id") and options.thread_id is None:
+         import uuid
+
+         options.thread_id = str(uuid.uuid4())
+
      return options


  def handle_chat_command(options: Namespace) -> None:
-     """Interactive chat REPL with the DAO AI system."""
+     """Interactive chat REPL with the DAO AI system with Human-in-the-Loop support."""
      logger.debug("Starting chat session with DAO AI system...")

      try:
+         # Set default user_id if not provided
+         if options.user_id is None:
+             options.user_id = get_default_user_id()
+
          config: AppConfig = AppConfig.from_file(options.config)
          app = create_dao_ai_graph(config)
@@ -305,9 +342,7 @@ def handle_chat_command(options: Namespace) -> None:
          print("-" * 50)

          # Import streaming function and interrupt handling
-         from langchain_core.messages import HumanMessage
-
-         from dao_ai.models import process_messages_stream
+         from langchain_core.messages import AIMessage, HumanMessage

          # Conversation history
          messages = []
@@ -350,47 +385,217 @@ def handle_chat_command(options: Namespace) -> None:
                  )
                  continue

-             # Prepare custom inputs for the agent
-             custom_inputs = {"configurable": configurable}
+             # Create Context object from configurable dict
+             from dao_ai.state import Context
+
+             context = Context(**configurable)
+
+             # Prepare config with thread_id for checkpointer
+             # Note: thread_id is needed in config for checkpointer/memory
+             config = {"configurable": {"thread_id": options.thread_id}}
+
+             # Invoke the graph and handle interrupts (HITL)
+             # Wrap in async function to maintain connection pool throughout
+             logger.debug(f"Invoking graph with {len(messages)} messages")
+             logger.debug(f"Context: {context}")
+             logger.debug(f"Config: {config}")
+
+             import asyncio
+
+             from langgraph.types import Command
+
+             async def _invoke_with_hitl():
+                 """Invoke graph and handle HITL interrupts in single async context."""
+                 result = await app.ainvoke(
+                     {"messages": messages},
+                     config=config,
+                     context=context,  # Pass context as separate parameter
+                 )
+
+                 # Check for interrupts (Human-in-the-Loop) using __interrupt__
+                 # This is the modern LangChain pattern
+                 while "__interrupt__" in result:
+                     interrupts = result["__interrupt__"]
+                     logger.info(f"HITL: {len(interrupts)} interrupt(s) detected")
+
+                     # Collect decisions for all interrupts
+                     decisions = []
+
+                     for interrupt in interrupts:
+                         interrupt_value = interrupt.value
+                         action_requests = interrupt_value.get("action_requests", [])
+
+                         for action_request in action_requests:
+                             # Display interrupt information
+                             print("\n⚠️ Human in the Loop - Tool Approval Required")
+                             print(f"{'=' * 60}")
+
+                             tool_name = action_request.get("name", "unknown")
+                             tool_args = action_request.get("arguments", {})
+                             description = action_request.get("description", "")
+
+                             print(f"Tool: {tool_name}")
+                             if description:
+                                 print(f"\n{description}\n")
+
+                             print("Arguments:")
+                             for arg_name, arg_value in tool_args.items():
+                                 # Truncate long values
+                                 arg_str = str(arg_value)
+                                 if len(arg_str) > 100:
+                                     arg_str = arg_str[:97] + "..."
+                                 print(f"  - {arg_name}: {arg_str}")
+
+                             print(f"{'=' * 60}")
+
+                             # Prompt user for decision
+                             while True:
+                                 decision_input = (
+                                     input(
+                                         "\nAction? (a)pprove / (r)eject / (e)dit / (h)elp: "
+                                     )
+                                     .strip()
+                                     .lower()
+                                 )
+
+                                 if decision_input in ["a", "approve"]:
+                                     logger.info("User approved tool call")
+                                     print("✅ Approved - continuing execution...")
+                                     decisions.append({"type": "approve"})
+                                     break
+                                 elif decision_input in ["r", "reject"]:
+                                     logger.info("User rejected tool call")
+                                     feedback = input(
+                                         "  Feedback for agent (optional): "
+                                     ).strip()
+                                     if feedback:
+                                         decisions.append(
+                                             {"type": "reject", "message": feedback}
+                                         )
+                                     else:
+                                         decisions.append(
+                                             {
+                                                 "type": "reject",
+                                                 "message": "Tool call rejected by user",
+                                             }
+                                         )
+                                     print(
+                                         "❌ Rejected - agent will receive feedback..."
+                                     )
+                                     break
+                                 elif decision_input in ["e", "edit"]:
+                                     print(
+                                         "ℹ️ Edit functionality not yet implemented in CLI"
+                                     )
+                                     print("   Please approve or reject.")
+                                     continue
+                                 elif decision_input in ["h", "help"]:
+                                     print("\nAvailable actions:")
+                                     print(
+                                         "  (a)pprove - Execute the tool call as shown"
+                                     )
+                                     print(
+                                         "  (r)eject - Cancel the tool call with optional feedback"
+                                     )
+                                     print(
+                                         "  (e)dit - Modify arguments (not yet implemented)"
+                                     )
+                                     print("  (h)elp - Show this help message")
+                                     continue
+                                 else:
+                                     print("Invalid option. Type 'h' for help.")
+                                     continue
+
+                     # Resume execution with decisions using Command
+                     # This is the modern LangChain pattern
+                     logger.debug(f"Resuming with {len(decisions)} decision(s)")
+                     result = await app.ainvoke(
+                         Command(resume={"decisions": decisions}),
+                         config=config,
+                         context=context,
+                     )
+
+                 return result
+
+             try:
+                 # Use async invoke - keep connection pool alive throughout HITL
+                 loop = asyncio.get_event_loop()
+             except RuntimeError:
+                 loop = asyncio.new_event_loop()
+                 asyncio.set_event_loop(loop)
+
+             try:
+                 result = loop.run_until_complete(_invoke_with_hitl())
+             except Exception as e:
+                 logger.error(f"Error invoking graph: {e}")
+                 print(f"\n❌ Error: {e}")
+                 continue

+             # After all interrupts handled, display the final response
              print("\n🤖 Assistant: ", end="", flush=True)

-             # Stream the response
              response_content = ""
+             structured_response = None
              try:
-                 for chunk in process_messages_stream(app, messages, custom_inputs):
-                     # Handle different chunk types
-                     if hasattr(chunk, "content") and chunk.content:
-                         content = chunk.content
-                         print(content, end="", flush=True)
-                         response_content += content
-                     elif hasattr(chunk, "choices") and chunk.choices:
-                         # Handle ChatCompletionChunk format
-                         for choice in chunk.choices:
-                             if (
-                                 hasattr(choice, "delta")
-                                 and choice.delta
-                                 and choice.delta.content
-                             ):
-                                 content = choice.delta.content
-                                 print(content, end="", flush=True)
-                                 response_content += content
-
-                 print()  # New line after streaming
+                 # Debug: Log what's in the result
+                 logger.debug(f"Result keys: {result.keys() if result else 'None'}")
+                 if result:
+                     for key in result.keys():
+                         logger.debug(f"Result['{key}'] type: {type(result[key])}")
+
+                 # Get the latest messages from the result
+                 if result and "messages" in result:
+                     latest_messages = result["messages"]
+                     # Find the last AI message
+                     for msg in reversed(latest_messages):
+                         if isinstance(msg, AIMessage):
+                             logger.debug(f"AI message content: {msg.content}")
+                             logger.debug(
+                                 f"AI message has tool_calls: {hasattr(msg, 'tool_calls')}"
+                             )
+                             if hasattr(msg, "tool_calls"):
+                                 logger.debug(f"Tool calls: {msg.tool_calls}")
+
+                             if hasattr(msg, "content") and msg.content:
+                                 response_content = msg.content
+                                 print(response_content, end="", flush=True)
+                             break
+
+                 # Check for structured output and display it separately
+                 if result and "structured_response" in result:
+                     structured_response = result["structured_response"]
+                     import json
+
+                     structured_json = json.dumps(
+                         structured_response.model_dump()
+                         if hasattr(structured_response, "model_dump")
+                         else structured_response,
+                         indent=2,
+                     )
+
+                     # If there was message content, add separator
+                     if response_content.strip():
+                         print("\n\n📊 Structured Output:")
+                         print(structured_json)
+                     else:
+                         # No message content, just show structured output
+                         print(structured_json, end="", flush=True)
+
+                     response_content = response_content or structured_json
+
+                 print()  # New line after response

                  # Add assistant response to history if we got content
                  if response_content.strip():
-                     from langchain_core.messages import AIMessage
-
                      assistant_message = AIMessage(content=response_content)
                      messages.append(assistant_message)
                  else:
                      print("(No response content generated)")

              except Exception as e:
-                 print(f"\n❌ Error during streaming: {e}")
+                 print(f"\n❌ Error processing response: {e}")
                  print(f"Stack trace:\n{traceback.format_exc()}")
-                 logger.error(f"Streaming error: {e}")
+                 logger.error(f"Response processing error: {e}")
                  logger.error(f"Stack trace: {traceback.format_exc()}")

          except EOFError:
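For reference, the `__interrupt__` / `Command(resume=...)` handshake the REPL consumes above has a producer side inside the graph. A minimal hedged sketch, assuming a LangGraph tool that pauses for approval; the tool name and payload shape here are illustrative, chosen only to match what the CLI reads back, and are not taken from the package:

```python
from langgraph.types import interrupt


def send_email(to: str, body: str) -> str:
    # interrupt() pauses the graph run; its payload surfaces to the caller
    # under result["__interrupt__"], and it returns whatever the caller
    # later supplies via Command(resume=...).
    resume = interrupt(
        {"action_requests": [{"name": "send_email", "arguments": {"to": to}}]}
    )
    decision = resume["decisions"][0]
    if decision["type"] != "approve":
        return decision.get("message", "Tool call rejected by user")
    return f"Sent to {to}: {body}"
```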
@@ -404,6 +609,7 @@ def handle_chat_command(options: Namespace) -> None:
          except Exception as e:
              print(f"\n❌ Error: {e}")
              logger.error(f"Chat error: {e}")
+             traceback.print_exc()

      except Exception as e:
          logger.error(f"Failed to initialize chat session: {e}")
@@ -448,7 +654,6 @@ def handle_validate_command(options: Namespace) -> None:


  def setup_logging(verbosity: int) -> None:
-     logger.remove()
      levels: dict[int, str] = {
          0: "ERROR",
          1: "WARNING",
@@ -457,7 +662,7 @@ def setup_logging(verbosity: int) -> None:
          4: "TRACE",
      }
      level: str = levels.get(verbosity, "TRACE")
-     logger.add(sys.stderr, level=level)
+     configure_logging(level=level)


  def generate_bundle_from_template(config_path: Path, app_name: str) -> Path:
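The verbosity table in `setup_logging` presumably maps repeated `-v` flags to loguru levels (an argparse count action is an assumption here; the parser setup is outside the hunks shown, and entries 2 and 3 fall between them). An illustration of how `levels.get(verbosity, "TRACE")` clamps high counts:

```python
# Illustration only: INFO/DEBUG for counts 2 and 3 are assumed.
levels = {0: "ERROR", 1: "WARNING", 2: "INFO", 3: "DEBUG", 4: "TRACE"}
for verbosity in range(6):
    print(verbosity, "->", levels.get(verbosity, "TRACE"))  # 5 -> TRACE
```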