dao-ai 0.0.36__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/cli.py +195 -30
- dao_ai/config.py +770 -244
- dao_ai/genie/__init__.py +1 -22
- dao_ai/genie/cache/__init__.py +1 -2
- dao_ai/genie/cache/base.py +20 -70
- dao_ai/genie/cache/core.py +75 -0
- dao_ai/genie/cache/lru.py +44 -21
- dao_ai/genie/cache/semantic.py +390 -109
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +22 -190
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +23 -5
- dao_ai/memory/databricks.py +389 -0
- dao_ai/memory/postgres.py +2 -2
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +778 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +61 -0
- dao_ai/middleware/guardrails.py +415 -0
- dao_ai/middleware/human_in_the_loop.py +228 -0
- dao_ai/middleware/message_validation.py +554 -0
- dao_ai/middleware/summarization.py +192 -0
- dao_ai/models.py +1177 -108
- dao_ai/nodes.py +118 -161
- dao_ai/optimization.py +664 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +287 -0
- dao_ai/orchestration/supervisor.py +264 -0
- dao_ai/orchestration/swarm.py +226 -0
- dao_ai/prompts.py +126 -29
- dao_ai/providers/databricks.py +126 -381
- dao_ai/state.py +139 -21
- dao_ai/tools/__init__.py +8 -5
- dao_ai/tools/core.py +57 -4
- dao_ai/tools/email.py +280 -0
- dao_ai/tools/genie.py +47 -24
- dao_ai/tools/mcp.py +4 -3
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +4 -12
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +1 -1
- dao_ai/tools/unity_catalog.py +8 -6
- dao_ai/tools/vector_search.py +16 -9
- dao_ai/utils.py +72 -8
- dao_ai-0.1.0.dist-info/METADATA +1878 -0
- dao_ai-0.1.0.dist-info/RECORD +62 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/genie/__init__.py +0 -236
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.36.dist-info/METADATA +0 -951
- dao_ai-0.0.36.dist-info/RECORD +0 -47
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/licenses/LICENSE +0 -0
dao_ai/__init__.py
CHANGED
@@ -0,0 +1,29 @@
+"""
+DAO AI - A framework for building AI agents with Databricks.
+
+This module configures package-level settings including warning filters
+for expected runtime warnings that don't indicate actual problems.
+"""
+
+import warnings
+
+# Suppress Pydantic serialization warnings for Context objects during checkpointing.
+# This warning occurs because LangGraph's checkpointer serializes the context_schema
+# and Pydantic reports that serialization may not be as expected. This is benign
+# since Context is only used at runtime and doesn't need to be persisted.
+#
+# The warning looks like:
+#   PydanticSerializationUnexpectedValue(Expected `none` - serialized value may not
+#   be as expected [field_name='context', input_value=Context(...), input_type=Context])
+warnings.filterwarnings(
+    "ignore",
+    message=r".*Pydantic serializer warnings.*",
+    category=UserWarning,
+)
+
+# Also filter the specific PydanticSerializationUnexpectedValue warning
+warnings.filterwarnings(
+    "ignore",
+    message=r".*PydanticSerializationUnexpectedValue.*",
+    category=UserWarning,
+)
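For reference: `warnings.filterwarnings` treats `message` as a regular expression matched against the warning text, so the two filters above suppress only matching `UserWarning`s and leave everything else visible. A standalone sketch of that behavior (the warning strings are illustrative, not actual dao-ai output):

import warnings

warnings.filterwarnings(
    "ignore",
    message=r".*PydanticSerializationUnexpectedValue.*",
    category=UserWarning,
)

# Suppressed: matches both the regex and the category.
warnings.warn("PydanticSerializationUnexpectedValue: Expected `none`", UserWarning)
# Still emitted: same category, but the message does not match.
warnings.warn("unrelated warning", UserWarning)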
dao_ai/cli.py
CHANGED
@@ -271,18 +271,24 @@ Examples:
     chat_parser.add_argument(
         "--thread-id",
         type=str,
-        default=
+        default=None,
         metavar="ID",
-        help="Thread ID for the chat session (default:
+        help="Thread ID for the chat session (default: auto-generated UUID)",
     )

     options = parser.parse_args(args)

+    # Generate a new thread_id UUID if not provided (only for chat command)
+    if hasattr(options, "thread_id") and options.thread_id is None:
+        import uuid
+
+        options.thread_id = str(uuid.uuid4())
+
     return options


 def handle_chat_command(options: Namespace) -> None:
-    """Interactive chat REPL with the DAO AI system."""
+    """Interactive chat REPL with the DAO AI system with Human-in-the-Loop support."""
     logger.debug("Starting chat session with DAO AI system...")

     try:
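The hunk above swaps a hard-coded `--thread-id` default for `None` plus post-parse generation, so each session gets a fresh conversation thread unless the user pins one explicitly. The same pattern in isolation (the flag name mirrors the diff; the `parse` helper is hypothetical):

import uuid
from argparse import ArgumentParser, Namespace

parser = ArgumentParser()
parser.add_argument("--thread-id", type=str, default=None, metavar="ID")

def parse(args=None) -> Namespace:
    options = parser.parse_args(args)
    # Fall back to a random UUID when the flag was omitted.
    if options.thread_id is None:
        options.thread_id = str(uuid.uuid4())
    return options

print(parse([]).thread_id)                      # random UUID each run
print(parse(["--thread-id", "t1"]).thread_id)   # "t1"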
@@ -305,9 +311,7 @@ def handle_chat_command(options: Namespace) -> None:
         print("-" * 50)

         # Import streaming function and interrupt handling
-        from langchain_core.messages import HumanMessage
-
-        from dao_ai.models import process_messages_stream
+        from langchain_core.messages import AIMessage, HumanMessage

         # Conversation history
         messages = []
@@ -353,44 +357,204 @@
                 # Prepare custom inputs for the agent
                 custom_inputs = {"configurable": configurable}

+                # Invoke the graph and handle interrupts (HITL)
+                # Wrap in async function to maintain connection pool throughout
+                logger.debug(f"Invoking graph with {len(messages)} messages")
+
+                import asyncio
+
+                from langgraph.types import Command
+
+                async def _invoke_with_hitl():
+                    """Invoke graph and handle HITL interrupts in single async context."""
+                    result = await app.ainvoke(
+                        {"messages": messages},
+                        config=custom_inputs,
+                    )
+
+                    # Check for interrupts (Human-in-the-Loop) using __interrupt__
+                    # This is the modern LangChain pattern
+                    while "__interrupt__" in result:
+                        interrupts = result["__interrupt__"]
+                        logger.info(f"HITL: {len(interrupts)} interrupt(s) detected")
+
+                        # Collect decisions for all interrupts
+                        decisions = []
+
+                        for interrupt in interrupts:
+                            interrupt_value = interrupt.value
+                            action_requests = interrupt_value.get("action_requests", [])
+
+                            for action_request in action_requests:
+                                # Display interrupt information
+                                print("\n⚠️ Human in the Loop - Tool Approval Required")
+                                print(f"{'=' * 60}")
+
+                                tool_name = action_request.get("name", "unknown")
+                                tool_args = action_request.get("arguments", {})
+                                description = action_request.get("description", "")
+
+                                print(f"Tool: {tool_name}")
+                                if description:
+                                    print(f"\n{description}\n")
+
+                                print("Arguments:")
+                                for arg_name, arg_value in tool_args.items():
+                                    # Truncate long values
+                                    arg_str = str(arg_value)
+                                    if len(arg_str) > 100:
+                                        arg_str = arg_str[:97] + "..."
+                                    print(f"  - {arg_name}: {arg_str}")
+
+                                print(f"{'=' * 60}")
+
+                                # Prompt user for decision
+                                while True:
+                                    decision_input = (
+                                        input(
+                                            "\nAction? (a)pprove / (r)eject / (e)dit / (h)elp: "
+                                        )
+                                        .strip()
+                                        .lower()
+                                    )
+
+                                    if decision_input in ["a", "approve"]:
+                                        logger.info("User approved tool call")
+                                        print("✅ Approved - continuing execution...")
+                                        decisions.append({"type": "approve"})
+                                        break
+                                    elif decision_input in ["r", "reject"]:
+                                        logger.info("User rejected tool call")
+                                        feedback = input(
+                                            "  Feedback for agent (optional): "
+                                        ).strip()
+                                        if feedback:
+                                            decisions.append(
+                                                {"type": "reject", "message": feedback}
+                                            )
+                                        else:
+                                            decisions.append(
+                                                {
+                                                    "type": "reject",
+                                                    "message": "Tool call rejected by user",
+                                                }
+                                            )
+                                        print(
+                                            "❌ Rejected - agent will receive feedback..."
+                                        )
+                                        break
+                                    elif decision_input in ["e", "edit"]:
+                                        print(
+                                            "ℹ️ Edit functionality not yet implemented in CLI"
+                                        )
+                                        print("  Please approve or reject.")
+                                        continue
+                                    elif decision_input in ["h", "help"]:
+                                        print("\nAvailable actions:")
+                                        print(
+                                            "  (a)pprove - Execute the tool call as shown"
+                                        )
+                                        print(
+                                            "  (r)eject - Cancel the tool call with optional feedback"
+                                        )
+                                        print(
+                                            "  (e)dit - Modify arguments (not yet implemented)"
+                                        )
+                                        print("  (h)elp - Show this help message")
+                                        continue
+                                    else:
+                                        print("Invalid option. Type 'h' for help.")
+                                        continue
+
+                        # Resume execution with decisions using Command
+                        # This is the modern LangChain pattern
+                        logger.debug(f"Resuming with {len(decisions)} decision(s)")
+                        result = await app.ainvoke(
+                            Command(resume={"decisions": decisions}),
+                            config=custom_inputs,
+                        )
+
+                    return result
+
+                try:
+                    # Use async invoke - keep connection pool alive throughout HITL
+                    loop = asyncio.get_event_loop()
+                except RuntimeError:
+                    loop = asyncio.new_event_loop()
+                    asyncio.set_event_loop(loop)
+
+                try:
+                    result = loop.run_until_complete(_invoke_with_hitl())
+                except Exception as e:
+                    logger.error(f"Error invoking graph: {e}")
+                    print(f"\n❌ Error: {e}")
+                    continue
+
+                # After all interrupts handled, display the final response
                 print("\n🤖 Assistant: ", end="", flush=True)

-                # Stream the response
                 response_content = ""
+                structured_response = None
                 try:
-                    … (19 removed lines: the old streaming response handler; content not captured in this diff view)
+                    # Debug: Log what's in the result
+                    logger.debug(f"Result keys: {result.keys() if result else 'None'}")
+                    if result:
+                        for key in result.keys():
+                            logger.debug(f"Result['{key}'] type: {type(result[key])}")
+
+                    # Get the latest messages from the result
+                    if result and "messages" in result:
+                        latest_messages = result["messages"]
+                        # Find the last AI message
+                        for msg in reversed(latest_messages):
+                            if isinstance(msg, AIMessage):
+                                logger.debug(f"AI message content: {msg.content}")
+                                logger.debug(
+                                    f"AI message has tool_calls: {hasattr(msg, 'tool_calls')}"
+                                )
+                                if hasattr(msg, "tool_calls"):
+                                    logger.debug(f"Tool calls: {msg.tool_calls}")
+
+                                if hasattr(msg, "content") and msg.content:
+                                    response_content = msg.content
+                                    print(response_content, end="", flush=True)
+                                break
+
+                    # Check for structured output and display it separately
+                    if result and "structured_response" in result:
+                        structured_response = result["structured_response"]
+                        import json
+
+                        structured_json = json.dumps(
+                            structured_response.model_dump()
+                            if hasattr(structured_response, "model_dump")
+                            else structured_response,
+                            indent=2,
+                        )
+
+                        # If there was message content, add separator
+                        if response_content.strip():
+                            print("\n\n📊 Structured Output:")
+                            print(structured_json)
+                        else:
+                            # No message content, just show structured output
+                            print(structured_json, end="", flush=True)
+
+                        response_content = response_content or structured_json
+
                     print()  # New line after response

                 # Add assistant response to history if we got content
                 if response_content.strip():
-                    from langchain_core.messages import AIMessage
-
                     assistant_message = AIMessage(content=response_content)
                     messages.append(assistant_message)
                 else:
                     print("(No response content generated)")

                 except Exception as e:
-                    print(f"\n❌ Error
+                    print(f"\n❌ Error processing response: {e}")
                     print(f"Stack trace:\n{traceback.format_exc()}")
-                    logger.error(f"
+                    logger.error(f"Response processing error: {e}")
                     logger.error(f"Stack trace: {traceback.format_exc()}")

             except EOFError:
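One piece of this hunk worth isolating is the get-or-create event loop dance: reusing a single loop across REPL turns keeps loop-bound resources such as connection pools alive, which `asyncio.run()` would tear down after every call. A stdlib-only sketch (note that `asyncio.get_event_loop()` warns, and in newer Python versions raises, when no current loop exists, which is what the `RuntimeError` branch covers):

import asyncio

async def turn(n: int) -> str:
    return f"turn {n} handled"

try:
    loop = asyncio.get_event_loop()
except RuntimeError:
    # No current loop in this thread; create and install one.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

# Reusing the same loop across calls preserves loop-bound state,
# unlike asyncio.run(), which creates and closes a fresh loop each time.
for i in range(2):
    print(loop.run_until_complete(turn(i)))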
@@ -404,6 +568,7 @@ def handle_chat_command(options: Namespace) -> None:
             except Exception as e:
                 print(f"\n❌ Error: {e}")
                 logger.error(f"Chat error: {e}")
+                traceback.print_exc()

     except Exception as e:
         logger.error(f"Failed to initialize chat session: {e}")