dao-ai 0.0.36__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/cli.py +195 -30
  3. dao_ai/config.py +770 -244
  4. dao_ai/genie/__init__.py +1 -22
  5. dao_ai/genie/cache/__init__.py +1 -2
  6. dao_ai/genie/cache/base.py +20 -70
  7. dao_ai/genie/cache/core.py +75 -0
  8. dao_ai/genie/cache/lru.py +44 -21
  9. dao_ai/genie/cache/semantic.py +390 -109
  10. dao_ai/genie/core.py +35 -0
  11. dao_ai/graph.py +27 -253
  12. dao_ai/hooks/__init__.py +9 -6
  13. dao_ai/hooks/core.py +22 -190
  14. dao_ai/memory/__init__.py +10 -0
  15. dao_ai/memory/core.py +23 -5
  16. dao_ai/memory/databricks.py +389 -0
  17. dao_ai/memory/postgres.py +2 -2
  18. dao_ai/messages.py +6 -4
  19. dao_ai/middleware/__init__.py +125 -0
  20. dao_ai/middleware/assertions.py +778 -0
  21. dao_ai/middleware/base.py +50 -0
  22. dao_ai/middleware/core.py +61 -0
  23. dao_ai/middleware/guardrails.py +415 -0
  24. dao_ai/middleware/human_in_the_loop.py +228 -0
  25. dao_ai/middleware/message_validation.py +554 -0
  26. dao_ai/middleware/summarization.py +192 -0
  27. dao_ai/models.py +1177 -108
  28. dao_ai/nodes.py +118 -161
  29. dao_ai/optimization.py +664 -0
  30. dao_ai/orchestration/__init__.py +52 -0
  31. dao_ai/orchestration/core.py +287 -0
  32. dao_ai/orchestration/supervisor.py +264 -0
  33. dao_ai/orchestration/swarm.py +226 -0
  34. dao_ai/prompts.py +126 -29
  35. dao_ai/providers/databricks.py +126 -381
  36. dao_ai/state.py +139 -21
  37. dao_ai/tools/__init__.py +8 -5
  38. dao_ai/tools/core.py +57 -4
  39. dao_ai/tools/email.py +280 -0
  40. dao_ai/tools/genie.py +47 -24
  41. dao_ai/tools/mcp.py +4 -3
  42. dao_ai/tools/memory.py +50 -0
  43. dao_ai/tools/python.py +4 -12
  44. dao_ai/tools/search.py +14 -0
  45. dao_ai/tools/slack.py +1 -1
  46. dao_ai/tools/unity_catalog.py +8 -6
  47. dao_ai/tools/vector_search.py +16 -9
  48. dao_ai/utils.py +72 -8
  49. dao_ai-0.1.0.dist-info/METADATA +1878 -0
  50. dao_ai-0.1.0.dist-info/RECORD +62 -0
  51. dao_ai/chat_models.py +0 -204
  52. dao_ai/guardrails.py +0 -112
  53. dao_ai/tools/genie/__init__.py +0 -236
  54. dao_ai/tools/human_in_the_loop.py +0 -100
  55. dao_ai-0.0.36.dist-info/METADATA +0 -951
  56. dao_ai-0.0.36.dist-info/RECORD +0 -47
  57. {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/WHEEL +0 -0
  58. {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/entry_points.txt +0 -0
  59. {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/licenses/LICENSE +0 -0
dao_ai/models.py CHANGED
@@ -1,11 +1,34 @@
1
1
  import uuid
2
2
  from os import PathLike
3
3
  from pathlib import Path
4
- from typing import Any, Generator, Optional, Sequence, Union
5
-
6
- from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
4
+ from typing import TYPE_CHECKING, Any, Generator, Literal, Optional, Sequence, Union
5
+
6
+ from databricks_langchain import ChatDatabricks
7
+
8
+ if TYPE_CHECKING:
9
+ pass
10
+
11
+ # Import official LangChain HITL TypedDict definitions
12
+ # Reference: https://docs.langchain.com/oss/python/langchain/human-in-the-loop
13
+ from langchain.agents.middleware.human_in_the_loop import (
14
+ ActionRequest,
15
+ Decision,
16
+ EditDecision,
17
+ HITLRequest,
18
+ RejectDecision,
19
+ ReviewConfig,
20
+ )
21
+ from langchain_community.adapters.openai import convert_openai_messages
22
+ from langchain_core.language_models import LanguageModelLike
23
+ from langchain_core.messages import (
24
+ AIMessage,
25
+ AIMessageChunk,
26
+ BaseMessage,
27
+ HumanMessage,
28
+ SystemMessage,
29
+ )
7
30
  from langgraph.graph.state import CompiledStateGraph
8
- from langgraph.types import StateSnapshot
31
+ from langgraph.types import Interrupt, StateSnapshot
9
32
  from loguru import logger
10
33
  from mlflow import MlflowClient
11
34
  from mlflow.pyfunc import ChatAgent, ChatModel, ResponsesAgent
@@ -28,11 +51,13 @@ from mlflow.types.responses_helpers import (
28
51
  Message,
29
52
  ResponseInputTextParam,
30
53
  )
54
+ from pydantic import BaseModel, Field, create_model
31
55
 
32
56
  from dao_ai.messages import (
33
57
  has_langchain_messages,
34
58
  has_mlflow_messages,
35
59
  has_mlflow_responses_messages,
60
+ last_human_message,
36
61
  )
37
62
  from dao_ai.state import Context
38
63
 
@@ -54,12 +79,37 @@ def get_latest_model_version(model_name: str) -> int:
54
79
  mlflow_client: MlflowClient = MlflowClient()
55
80
  latest_version: int = 1
56
81
  for mv in mlflow_client.search_model_versions(f"name='{model_name}'"):
57
- version_int = int(mv.version)
82
+ version_int: int = int(mv.version)
58
83
  if version_int > latest_version:
59
84
  latest_version = version_int
60
85
  return latest_version
61
86
 
62
87
 
88
def is_interrupted(snapshot: StateSnapshot) -> bool:
    """
    Check if the graph state is currently interrupted (paused for human-in-the-loop).

    ``StateSnapshot.interrupts`` is a tuple: it holds the pending ``Interrupt``
    objects while the graph is paused and is the empty tuple ``()`` otherwise.

    Args:
        snapshot: The StateSnapshot to check. A missing snapshot (``None``) or
            an object without an ``interrupts`` attribute is treated as
            "not interrupted" instead of raising ``AttributeError``.

    Returns:
        True if the graph is interrupted (has pending HITL actions), False otherwise

    Example:
        >>> snapshot = await graph.aget_state(config)
        >>> if is_interrupted(snapshot):
        ...     print("Graph is waiting for human input")
    """
    # getattr keeps the check safe for Optional snapshots; an empty tuple,
    # None, or a missing attribute are all falsy.
    return bool(getattr(snapshot, "interrupts", None))
111
+
112
+
63
113
  async def get_state_snapshot_async(
64
114
  graph: CompiledStateGraph, thread_id: str
65
115
  ) -> Optional[StateSnapshot]:
@@ -167,6 +217,111 @@ def get_genie_conversation_ids_from_state(
167
217
  return {}
168
218
 
169
219
 
220
+ def _extract_interrupt_value(interrupt: Interrupt) -> HITLRequest:
221
+ """
222
+ Extract the HITL request from a LangGraph Interrupt object.
223
+
224
+ Following LangChain patterns, the Interrupt object has a .value attribute
225
+ containing the HITLRequest structure with action_requests and review_configs.
226
+
227
+ Args:
228
+ interrupt: Interrupt object from LangGraph with .value and .id attributes
229
+
230
+ Returns:
231
+ HITLRequest with action_requests and review_configs
232
+ """
233
+ # Interrupt.value is typed as Any, but for HITL it should be a HITLRequest dict
234
+ if isinstance(interrupt.value, dict):
235
+ # Return as HITLRequest TypedDict
236
+ return interrupt.value # type: ignore[return-value]
237
+
238
+ # Fallback: return empty structure if value is not a dict
239
+ return {"action_requests": [], "review_configs": []}
240
+
241
+
242
+ def _format_action_requests_message(interrupt_data: list[HITLRequest]) -> str:
243
+ """
244
+ Format action requests from interrupts into a simple, user-friendly message.
245
+
246
+ Since we now use LLM-based parsing, users can respond in natural language.
247
+ This function just shows WHAT actions are pending, not HOW to respond.
248
+
249
+ Args:
250
+ interrupt_data: List of HITLRequest structures containing action_requests and review_configs
251
+
252
+ Returns:
253
+ Simple formatted message describing the pending actions
254
+ """
255
+ if not interrupt_data:
256
+ return ""
257
+
258
+ # Collect all action requests and review configs from all interrupts
259
+ all_actions: list[ActionRequest] = []
260
+ review_configs_map: dict[str, ReviewConfig] = {}
261
+
262
+ for hitl_request in interrupt_data:
263
+ all_actions.extend(hitl_request.get("action_requests", []))
264
+ for review_config in hitl_request.get("review_configs", []):
265
+ action_name = review_config.get("action_name", "")
266
+ if action_name:
267
+ review_configs_map[action_name] = review_config
268
+
269
+ if not all_actions:
270
+ return ""
271
+
272
+ # Build simple, clean message
273
+ lines = ["⏸️ **Action Approval Required**", ""]
274
+ lines.append(
275
+ f"The assistant wants to perform {len(all_actions)} action(s) that require your approval:"
276
+ )
277
+ lines.append("")
278
+
279
+ for i, action in enumerate(all_actions, 1):
280
+ tool_name = action.get("name", "unknown")
281
+ args = action.get("args", {})
282
+ description = action.get("description")
283
+
284
+ lines.append(f"**{i}. {tool_name}**")
285
+
286
+ # Show review prompt/description if available
287
+ if description:
288
+ lines.append(f" • **Review:** {description}")
289
+
290
+ if args:
291
+ # Format args nicely, truncating long values
292
+ for key, value in args.items():
293
+ value_str = str(value)
294
+ if len(value_str) > 100:
295
+ value_str = value_str[:100] + "..."
296
+ lines.append(f" • {key}: `{value_str}`")
297
+ else:
298
+ lines.append(" • (no arguments)")
299
+
300
+ # Show allowed decisions
301
+ review_config = review_configs_map.get(tool_name)
302
+ if review_config:
303
+ allowed_decisions = review_config.get("allowed_decisions", [])
304
+ if allowed_decisions:
305
+ decisions_str = ", ".join(allowed_decisions)
306
+ lines.append(f" • **Options:** {decisions_str}")
307
+
308
+ lines.append("")
309
+
310
+ lines.append("---")
311
+ lines.append("")
312
+ lines.append(
313
+ "**You can respond in natural language** (e.g., 'approve both', 'reject the first one', "
314
+ "'change the email to new@example.com')"
315
+ )
316
+ lines.append("")
317
+ lines.append(
318
+ "Or provide structured decisions in `custom_inputs` with key `decisions`: "
319
+ '`[{"type": "approve"}, {"type": "reject", "message": "reason"}]`'
320
+ )
321
+
322
+ return "\n".join(lines)
323
+
324
+
170
325
  class LanggraphChatModel(ChatModel):
171
326
  """
172
327
  ChatModel that delegates requests to a LangGraph CompiledStateGraph.
@@ -216,22 +371,36 @@ class LanggraphChatModel(ChatModel):
216
371
 
217
372
  configurable: dict[str, Any] = {}
218
373
  if "configurable" in input_data:
219
- configurable: dict[str, Any] = input_data.pop("configurable")
374
+ configurable = input_data.pop("configurable")
220
375
  if "custom_inputs" in input_data:
221
376
  custom_inputs: dict[str, Any] = input_data.pop("custom_inputs")
222
377
  if "configurable" in custom_inputs:
223
- configurable: dict[str, Any] = custom_inputs.pop("configurable")
224
-
225
- if "user_id" in configurable:
226
- configurable["user_id"] = configurable["user_id"].replace(".", "_")
227
-
228
- if "conversation_id" in configurable and "thread_id" not in configurable:
229
- configurable["thread_id"] = configurable["conversation_id"]
230
-
231
- if "thread_id" not in configurable:
232
- configurable["thread_id"] = str(uuid.uuid4())
233
-
234
- context: Context = Context(**configurable)
378
+ configurable = custom_inputs.pop("configurable")
379
+
380
+ # Extract known Context fields
381
+ user_id: str | None = configurable.pop("user_id", None)
382
+ if user_id:
383
+ user_id = user_id.replace(".", "_")
384
+
385
+ # Accept either thread_id or conversation_id (interchangeable)
386
+ # conversation_id takes precedence (Databricks vocabulary)
387
+ thread_id: str | None = configurable.pop("thread_id", None)
388
+ conversation_id: str | None = configurable.pop("conversation_id", None)
389
+
390
+ # conversation_id takes precedence if both provided
391
+ if conversation_id:
392
+ thread_id = conversation_id
393
+ if not thread_id:
394
+ thread_id = str(uuid.uuid4())
395
+
396
+ # All remaining configurable values go into custom dict
397
+ custom: dict[str, Any] = configurable
398
+
399
+ context: Context = Context(
400
+ user_id=user_id,
401
+ thread_id=thread_id,
402
+ custom=custom,
403
+ )
235
404
  return context
236
405
 
237
406
  def predict_stream(
@@ -307,6 +476,322 @@ class LanggraphChatModel(ChatModel):
307
476
  return [m.to_dict() for m in messages]
308
477
 
309
478
 
479
def _create_decision_schema(interrupt_data: list[HITLRequest]) -> type[BaseModel]:
    """
    Dynamically create a Pydantic model for structured output based on interrupt actions.

    The model has two validation fields (``is_valid``, ``validation_message``)
    followed by one group of fields per pending action, numbered from 1 in the
    order the actions appear across ``interrupt_data``:

    - ``decision_<i>``: a ``Literal`` of the decisions allowed for the action
    - ``decision_<i>_message``: optional text, only when "reject" is allowed
    - ``decision_<i>_edited_args``: optional dict, only when "edit" is allowed

    Allowed decisions come from the review config whose ``action_name`` matches
    the action's name. When there is no such config, or it lists no decisions,
    the action defaults to "approve"/"reject".

    Args:
        interrupt_data: List of HITL interrupt requests containing action_requests and review_configs

    Returns:
        A dynamically created Pydantic BaseModel class for structured output

    Example:
        For two actions (send_email, execute_sql), creates a model like:
            class InterruptDecisions(BaseModel):
                is_valid: bool
                validation_message: Optional[str]
                decision_1: Literal["approve", "edit", "reject"]
                decision_1_message: Optional[str]  # For reject
                decision_1_edited_args: Optional[dict]  # For edit
                decision_2: Literal["approve", "edit", "reject"]
                ...
    """
    default_decisions: list[str] = ["approve", "reject"]

    # Collect all actions and index the review configs by action name
    all_actions: list[ActionRequest] = []
    review_configs_map: dict[str, ReviewConfig] = {}

    for hitl_request in interrupt_data:
        all_actions.extend(hitl_request.get("action_requests", []))
        for config in hitl_request.get("review_configs", []):
            action_name: str = config.get("action_name", "")
            if action_name:
                review_configs_map[action_name] = config

    # Build fields for the dynamic model, starting with the validation fields
    fields: dict[str, Any] = {
        "is_valid": (
            bool,
            Field(
                description="Whether the user's response provides valid decisions for ALL actions. "
                "Set to False if the user's message is unclear, ambiguous, or doesn't provide decisions for all actions."
            ),
        ),
        "validation_message": (
            Optional[str],
            Field(
                None,
                description="If is_valid is False, explain what is missing or unclear. "
                "Be specific about which action(s) need clarification.",
            ),
        ),
    }

    for i, action in enumerate(all_actions, 1):
        tool_name: str = action.get("name", "unknown")
        matched_config: Optional[ReviewConfig] = review_configs_map.get(tool_name)

        # An empty or missing list falls back to the defaults: Literal[()]
        # would be a field that no value can ever satisfy.
        allowed_decisions: list[str] = list(
            (matched_config or {}).get("allowed_decisions") or default_decisions
        )

        # Create a Literal type for allowed decisions
        decision_literal: type = Literal[tuple(allowed_decisions)]  # type: ignore

        # Add decision field
        fields[f"decision_{i}"] = (
            decision_literal,
            Field(
                description=f"Decision for action {i} ({tool_name}): {', '.join(allowed_decisions)}"
            ),
        )

        # Add optional message field for reject
        if "reject" in allowed_decisions:
            fields[f"decision_{i}_message"] = (
                Optional[str],
                Field(
                    None,
                    description=f"Optional message if rejecting action {i}",
                ),
            )

        # Add optional edited_args field for edit
        if "edit" in allowed_decisions:
            fields[f"decision_{i}_edited_args"] = (
                Optional[dict[str, Any]],
                Field(
                    None,
                    description=f"Modified arguments if editing action {i}. Only provide fields that need to change.",
                ),
            )

    # Create the dynamic model
    DecisionsModel = create_model(
        "InterruptDecisions",
        __doc__="Decisions for each interrupted action, in order.",
        **fields,
    )

    return DecisionsModel
586
+
587
+
588
+ def _convert_schema_to_decisions(
589
+ parsed_output: BaseModel,
590
+ interrupt_data: list[HITLRequest],
591
+ ) -> list[Decision]:
592
+ """
593
+ Convert the parsed structured output into LangChain Decision objects.
594
+
595
+ Args:
596
+ parsed_output: The Pydantic model instance from structured output
597
+ interrupt_data: Original interrupt data for context
598
+
599
+ Returns:
600
+ List of Decision dictionaries compatible with Command(resume={"decisions": ...})
601
+ """
602
+ # Collect all actions to know how many decisions we need
603
+ all_actions: list[ActionRequest] = []
604
+ hitl_request: HITLRequest
605
+ for hitl_request in interrupt_data:
606
+ all_actions.extend(hitl_request.get("action_requests", []))
607
+
608
+ decisions: list[Decision] = []
609
+
610
+ i: int
611
+ for i in range(1, len(all_actions) + 1):
612
+ decision_type: str = getattr(parsed_output, f"decision_{i}")
613
+
614
+ if decision_type == "approve":
615
+ decisions.append({"type": "approve"}) # type: ignore
616
+ elif decision_type == "reject":
617
+ message: Optional[str] = getattr(
618
+ parsed_output, f"decision_{i}_message", None
619
+ )
620
+ reject_decision: RejectDecision = {"type": "reject"}
621
+ if message:
622
+ reject_decision["message"] = message
623
+ decisions.append(reject_decision) # type: ignore
624
+ elif decision_type == "edit":
625
+ edited_args: Optional[dict[str, Any]] = getattr(
626
+ parsed_output, f"decision_{i}_edited_args", None
627
+ )
628
+ action: ActionRequest = all_actions[i - 1]
629
+ tool_name: str = action.get("name", "")
630
+ original_args: dict[str, Any] = action.get("args", {})
631
+
632
+ # Merge original args with edited args
633
+ final_args: dict[str, Any] = {**original_args, **(edited_args or {})}
634
+
635
+ edit_decision: EditDecision = {
636
+ "type": "edit",
637
+ "edited_action": {
638
+ "name": tool_name,
639
+ "args": final_args,
640
+ },
641
+ }
642
+ decisions.append(edit_decision) # type: ignore
643
+
644
+ return decisions
645
+
646
+
647
def handle_interrupt_response(
    snapshot: StateSnapshot,
    messages: list[BaseMessage],
    model: Optional[LanguageModelLike] = None,
) -> dict[str, Any]:
    """
    Parse user's natural language response to interrupts using LLM with structured output.

    The last human message is sent to an LLM together with a description of
    every pending action. The LLM answers through a schema built dynamically
    from those actions, and the answer is converted into one decision per
    action.

    Args:
        snapshot: The current state snapshot containing interrupts
        messages: List of messages, from which the last human message will be extracted
        model: Optional LLM to use for parsing. When omitted, a ChatDatabricks
            client for the "databricks-claude-sonnet-4" endpoint (temperature 0)
            is created.

    Returns:
        Dictionary that always carries the same three keys:
        - "is_valid": bool indicating if the response is valid
        - "validation_message": Optional message if invalid, explaining what's missing
        - "decisions": list of Decision objects (empty if invalid)

        Parsing failures (including exceptions raised while invoking the LLM
        or converting its answer) are reported as an invalid result, not raised.

    Example:
        Valid: {"is_valid": True, "validation_message": None, "decisions": [{"type": "approve"}]}
        Invalid: {"is_valid": False, "validation_message": "Please specify...", "decisions": []}
    """
    # Extract the last human message
    user_message_obj: Optional[HumanMessage] = last_human_message(messages)

    if not user_message_obj:
        logger.warning("handle_interrupt_response called but no human message found")
        return {
            "is_valid": False,
            "validation_message": "No user message found. Please provide a response to the pending action(s).",
            "decisions": [],
        }

    user_message: str = str(user_message_obj.content)
    logger.info(f"HITL: Parsing user message with LLM: {user_message[:100]}")

    # Nothing pending: report it with the documented result shape
    if not snapshot.interrupts:
        logger.warning("handle_interrupt_response called but no interrupts in snapshot")
        return {
            "is_valid": False,
            "validation_message": "There are no pending actions to respond to.",
            "decisions": [],
        }

    interrupt_data: list[HITLRequest] = [
        _extract_interrupt_value(interrupt) for interrupt in snapshot.interrupts
    ]

    # Collect all actions for context
    all_actions: list[ActionRequest] = [
        action
        for hitl_request in interrupt_data
        for action in hitl_request.get("action_requests", [])
    ]

    if not all_actions:
        logger.warning("handle_interrupt_response called but no actions in interrupts")
        return {
            "is_valid": False,
            "validation_message": "There are no pending actions to respond to.",
            "decisions": [],
        }

    # Build the default model only once we know there is something to parse
    if not model:
        model = ChatDatabricks(
            endpoint="databricks-claude-sonnet-4",
            temperature=0,
        )

    # Create dynamic schema
    DecisionsModel: type[BaseModel] = _create_decision_schema(interrupt_data)

    # Create structured LLM
    structured_llm: LanguageModelLike = model.with_structured_output(DecisionsModel)

    # Format action context for the LLM
    action_descriptions: list[str] = []
    for i, action in enumerate(all_actions, 1):
        tool_name: str = action.get("name", "unknown")
        args: dict[str, Any] = action.get("args", {})
        args_str: str = (
            ", ".join(f"{k}={v}" for k, v in args.items()) if args else "(no args)"
        )
        action_descriptions.append(f"Action {i}: {tool_name}({args_str})")

    # NOTE(review): the FLEXIBILITY section below tells the model to approve
    # actions the user did not mention; confirm that default is acceptable for
    # a human-approval gate.
    system_prompt: str = f"""You are parsing a user's response to interrupted agent actions.

The following actions are pending approval:
{chr(10).join(action_descriptions)}

Your task is to extract the user's decision for EACH action in order. The user may:
- Approve: Accept the action as-is
- Reject: Cancel the action (optionally with a reason/message)
- Edit: Modify the arguments before executing

VALIDATION:
- Set is_valid=True only if you can confidently extract decisions for ALL actions
- Set is_valid=False if the user's message is:
* Unclear or ambiguous
* Missing decisions for some actions
* Asking a question instead of providing decisions
* Not addressing the actions at all
- If is_valid=False, provide a clear validation_message explaining what is needed

FLEXIBILITY:
- Be flexible in parsing informal language like "yes", "no", "ok", "change X to Y"
- If the user doesn't explicitly mention an action, assume they want to approve it
- Only mark as invalid if the message is genuinely unclear or incomplete"""

    try:
        # Invoke LLM with structured output
        parsed: BaseModel = structured_llm.invoke(
            [
                SystemMessage(content=system_prompt),
                HumanMessage(content=user_message),
            ]
        )

        # Check validation first
        is_valid: bool = getattr(parsed, "is_valid", True)
        validation_message: Optional[str] = getattr(parsed, "validation_message", None)

        if not is_valid:
            logger.warning(
                f"HITL: Invalid user response. Reason: {validation_message or 'Unknown'}"
            )
            return {
                "is_valid": False,
                "validation_message": validation_message
                or "Your response was unclear. Please provide a clear decision for each action.",
                "decisions": [],
            }

        # Convert to Decision format
        decisions: list[Decision] = _convert_schema_to_decisions(parsed, interrupt_data)

        logger.info(f"Parsed {len(decisions)} decisions from user message")
        return {"is_valid": True, "validation_message": None, "decisions": decisions}

    except Exception as e:
        # Boundary handler: any failure becomes an invalid result so the caller
        # can ask the user to try again instead of crashing the request.
        logger.error(f"Failed to parse interrupt response: {e}")
        return {
            "is_valid": False,
            "validation_message": f"Failed to parse your response: {str(e)}. Please provide a clear decision for each action.",
            "decisions": [],
        }
793
+
794
+
310
795
  class LanggraphResponsesAgent(ResponsesAgent):
311
796
  """
312
797
  ResponsesAgent that delegates requests to a LangGraph CompiledStateGraph.
@@ -315,37 +800,151 @@ class LanggraphResponsesAgent(ResponsesAgent):
315
800
  support for streaming, tool calling, and async execution.
316
801
  """
317
802
 
318
- def __init__(self, graph: CompiledStateGraph) -> None:
803
+ def __init__(
804
+ self,
805
+ graph: CompiledStateGraph,
806
+ ) -> None:
319
807
  self.graph = graph
320
808
 
321
809
  def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
322
810
  """
323
811
  Process a ResponsesAgentRequest and return a ResponsesAgentResponse.
812
+
813
+ Input structure (custom_inputs):
814
+ configurable:
815
+ thread_id: "abc-123" # Or conversation_id (aliases, conversation_id takes precedence)
816
+ user_id: "nate.fleming"
817
+ store_num: "87887"
818
+ session: # Paste from previous output
819
+ conversation_id: "abc-123" # Alias of thread_id
820
+ genie:
821
+ spaces:
822
+ space_123: {conversation_id: "conv_456", ...}
823
+ decisions: # For resuming interrupted graphs (HITL)
824
+ - type: "approve"
825
+ - type: "reject"
826
+ message: "Not authorized"
827
+
828
+ Output structure (custom_outputs):
829
+ configurable:
830
+ thread_id: "abc-123" # Only thread_id in configurable
831
+ user_id: "nate.fleming"
832
+ store_num: "87887"
833
+ session:
834
+ conversation_id: "abc-123" # conversation_id in session
835
+ genie:
836
+ spaces:
837
+ space_123: {conversation_id: "conv_456", ...}
838
+ pending_actions: # If HITL interrupt occurred
839
+ - name: "send_email"
840
+ arguments: {...}
841
+ description: "..."
324
842
  """
325
843
  logger.debug(f"ResponsesAgent request: {request}")
326
844
 
327
845
  # Convert ResponsesAgent input to LangChain messages
328
- messages = self._convert_request_to_langchain_messages(request)
846
+ messages: list[dict[str, Any]] = self._convert_request_to_langchain_messages(
847
+ request
848
+ )
329
849
 
330
- # Prepare context
850
+ # Prepare context (conversation_id -> thread_id mapping happens here)
331
851
  context: Context = self._convert_request_to_context(request)
332
852
  custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}
333
853
 
334
- # Build the graph input state, including genie_conversation_ids if provided
335
- graph_input: dict[str, Any] = {"messages": messages}
336
- if request.custom_inputs and "genie_conversation_ids" in request.custom_inputs:
337
- graph_input["genie_conversation_ids"] = request.custom_inputs[
338
- "genie_conversation_ids"
339
- ]
340
- logger.debug(
341
- f"Including genie_conversation_ids in graph input: {graph_input['genie_conversation_ids']}"
342
- )
854
+ # Extract session state from request
855
+ session_input: dict[str, Any] = self._extract_session_from_request(request)
343
856
 
344
857
  # Use async ainvoke internally for parallel execution
345
858
  import asyncio
346
859
 
860
+ from langgraph.types import Command
861
+
347
862
  async def _async_invoke():
348
863
  try:
864
+ # Check if this is a resume request (HITL)
865
+ # Two ways to resume:
866
+ # 1. Explicit decisions in custom_inputs (structured)
867
+ # 2. Natural language message when graph is interrupted (LLM-parsed)
868
+
869
+ if request.custom_inputs and "decisions" in request.custom_inputs:
870
+ # Explicit structured decisions
871
+ decisions: list[Decision] = request.custom_inputs["decisions"]
872
+ logger.info(
873
+ f"HITL: Resuming with {len(decisions)} explicit decision(s)"
874
+ )
875
+
876
+ # Resume interrupted graph with decisions
877
+ return await self.graph.ainvoke(
878
+ Command(resume={"decisions": decisions}),
879
+ context=context,
880
+ config=custom_inputs,
881
+ )
882
+
883
+ # Check if graph is currently interrupted (only if checkpointer is configured)
884
+ # aget_state requires a checkpointer
885
+ if self.graph.checkpointer:
886
+ snapshot: StateSnapshot = await self.graph.aget_state(
887
+ config=custom_inputs
888
+ )
889
+ if is_interrupted(snapshot):
890
+ logger.info(
891
+ "HITL: Graph is interrupted, checking for user response"
892
+ )
893
+
894
+ # Convert message dicts to BaseMessage objects
895
+ message_objects: list[BaseMessage] = convert_openai_messages(
896
+ messages
897
+ )
898
+
899
+ # Parse user's message with LLM to extract decisions
900
+ parsed_result: dict[str, Any] = handle_interrupt_response(
901
+ snapshot=snapshot,
902
+ messages=message_objects,
903
+ model=None, # Uses default model
904
+ )
905
+
906
+ # Check if the response was valid
907
+ if not parsed_result.get("is_valid", False):
908
+ validation_message: str = parsed_result.get(
909
+ "validation_message",
910
+ "Your response was unclear. Please provide a clear decision for each action.",
911
+ )
912
+ logger.warning(
913
+ f"HITL: Invalid response - {validation_message}"
914
+ )
915
+
916
+ # Return error message to user instead of resuming
917
+ # Don't resume the graph - stay interrupted so user can try again
918
+ return {
919
+ "messages": [
920
+ AIMessage(
921
+ content=f"❌ **Invalid Response**\n\n{validation_message}"
922
+ )
923
+ ]
924
+ }
925
+
926
+ decisions: list[Decision] = parsed_result.get("decisions", [])
927
+ logger.info(
928
+ f"HITL: LLM parsed {len(decisions)} valid decision(s) from user message"
929
+ )
930
+
931
+ # Resume interrupted graph with parsed decisions
932
+ return await self.graph.ainvoke(
933
+ Command(resume={"decisions": decisions}),
934
+ context=context,
935
+ config=custom_inputs,
936
+ )
937
+
938
+ # Normal invocation - build the graph input state
939
+ graph_input: dict[str, Any] = {"messages": messages}
940
+ if "genie_conversation_ids" in session_input:
941
+ graph_input["genie_conversation_ids"] = session_input[
942
+ "genie_conversation_ids"
943
+ ]
944
+ logger.debug(
945
+ f"Including genie_conversation_ids in graph input: {graph_input['genie_conversation_ids']}"
946
+ )
947
+
349
948
  return await self.graph.ainvoke(
350
949
  graph_input, context=context, config=custom_inputs
351
950
  )
@@ -356,7 +955,6 @@ class LanggraphResponsesAgent(ResponsesAgent):
356
955
  try:
357
956
  loop = asyncio.get_event_loop()
358
957
  except RuntimeError:
359
- # Handle case where no event loop exists (common in some deployment scenarios)
360
958
  loop = asyncio.new_event_loop()
361
959
  asyncio.set_event_loop(loop)
362
960
 
@@ -371,22 +969,81 @@ class LanggraphResponsesAgent(ResponsesAgent):
371
969
  # Convert response to ResponsesAgent format
372
970
  last_message: BaseMessage = response["messages"][-1]
373
971
 
374
- output_item = self.create_text_output_item(
375
- text=last_message.content, id=f"msg_{uuid.uuid4().hex[:8]}"
972
+ # Build custom_outputs that can be copy-pasted as next request's custom_inputs
973
+ custom_outputs: dict[str, Any] = self._build_custom_outputs(
974
+ context=context,
975
+ thread_id=context.thread_id,
976
+ loop=loop,
376
977
  )
377
978
 
378
- # Retrieve genie_conversation_ids from state if available
379
- custom_outputs: dict[str, Any] = custom_inputs.copy()
380
- thread_id: Optional[str] = context.thread_id
381
- if thread_id:
382
- state_snapshot: Optional[StateSnapshot] = loop.run_until_complete(
383
- get_state_snapshot_async(self.graph, thread_id)
979
+ # Handle structured_response if present
980
+ if "structured_response" in response:
981
+ from dataclasses import asdict, is_dataclass
982
+
983
+ from pydantic import BaseModel
984
+
985
+ structured_response = response["structured_response"]
986
+ logger.debug(f"Processing structured_response: {type(structured_response)}")
987
+
988
+ # Serialize to dict for JSON compatibility using type hints
989
+ if isinstance(structured_response, BaseModel):
990
+ # Pydantic model
991
+ serialized: dict[str, Any] = structured_response.model_dump()
992
+ elif is_dataclass(structured_response):
993
+ # Dataclass
994
+ serialized = asdict(structured_response)
995
+ elif isinstance(structured_response, dict):
996
+ # Already a dict
997
+ serialized = structured_response
998
+ else:
999
+ # Unknown type, convert to dict if possible
1000
+ serialized = (
1001
+ dict(structured_response)
1002
+ if hasattr(structured_response, "__dict__")
1003
+ else structured_response
1004
+ )
1005
+
1006
+ # Place structured output in message content as JSON
1007
+ import json
1008
+
1009
+ structured_text: str = json.dumps(serialized, indent=2)
1010
+ output_item = self.create_text_output_item(
1011
+ text=structured_text, id=f"msg_{uuid.uuid4().hex[:8]}"
384
1012
  )
385
- genie_conversation_ids: dict[str, str] = (
386
- get_genie_conversation_ids_from_state(state_snapshot)
1013
+ logger.debug("Placed structured_response in message content")
1014
+ else:
1015
+ # No structured response, use text content
1016
+ output_item = self.create_text_output_item(
1017
+ text=last_message.content, id=f"msg_{uuid.uuid4().hex[:8]}"
387
1018
  )
388
- if genie_conversation_ids:
389
- custom_outputs["genie_conversation_ids"] = genie_conversation_ids
1019
+
1020
+ # Include interrupt structure if HITL occurred (following LangChain pattern)
1021
+ if "__interrupt__" in response:
1022
+ interrupts: list[Interrupt] = response["__interrupt__"]
1023
+ logger.info(f"HITL: {len(interrupts)} interrupt(s) detected")
1024
+
1025
+ # Extract HITLRequest structures from interrupts (deduplicate by ID)
1026
+ seen_interrupt_ids: set[str] = set()
1027
+ interrupt_data: list[HITLRequest] = []
1028
+ interrupt: Interrupt
1029
+ for interrupt in interrupts:
1030
+ # Only process each unique interrupt once
1031
+ if interrupt.id not in seen_interrupt_ids:
1032
+ seen_interrupt_ids.add(interrupt.id)
1033
+ interrupt_data.append(_extract_interrupt_value(interrupt))
1034
+ logger.debug(f"HITL: Added interrupt {interrupt.id} to response")
1035
+
1036
+ custom_outputs["interrupts"] = interrupt_data
1037
+ logger.debug(
1038
+ f"HITL: Included {len(interrupt_data)} interrupt(s) in response"
1039
+ )
1040
+
1041
+ # Add user-facing message about the pending actions
1042
+ action_message: str = _format_action_requests_message(interrupt_data)
1043
+ if action_message:
1044
+ output_item = self.create_text_output_item(
1045
+ text=action_message, id=f"msg_{uuid.uuid4().hex[:8]}"
1046
+ )
390
1047
 
391
1048
  return ResponsesAgentResponse(
392
1049
  output=[output_item], custom_outputs=custom_outputs
@@ -397,89 +1054,310 @@ class LanggraphResponsesAgent(ResponsesAgent):
397
1054
  ) -> Generator[ResponsesAgentStreamEvent, None, None]:
398
1055
  """
399
1056
  Process a ResponsesAgentRequest and yield ResponsesAgentStreamEvent objects.
1057
+
1058
+ Uses same input/output structure as predict() for consistency.
1059
+ Supports Human-in-the-Loop (HITL) interrupts.
400
1060
  """
401
1061
  logger.debug(f"ResponsesAgent stream request: {request}")
402
1062
 
403
1063
  # Convert ResponsesAgent input to LangChain messages
404
- messages: list[BaseMessage] = self._convert_request_to_langchain_messages(
1064
+ messages: list[dict[str, Any]] = self._convert_request_to_langchain_messages(
405
1065
  request
406
1066
  )
407
1067
 
408
- # Prepare context
1068
+ # Prepare context (conversation_id -> thread_id mapping happens here)
409
1069
  context: Context = self._convert_request_to_context(request)
410
1070
  custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}
411
1071
 
412
- # Build the graph input state, including genie_conversation_ids if provided
413
- graph_input: dict[str, Any] = {"messages": messages}
414
- if request.custom_inputs and "genie_conversation_ids" in request.custom_inputs:
415
- graph_input["genie_conversation_ids"] = request.custom_inputs[
416
- "genie_conversation_ids"
417
- ]
418
- logger.debug(
419
- f"Including genie_conversation_ids in graph input: {graph_input['genie_conversation_ids']}"
420
- )
1072
+ # Extract session state from request
1073
+ session_input: dict[str, Any] = self._extract_session_from_request(request)
421
1074
 
422
1075
  # Use async astream internally for parallel execution
423
1076
  import asyncio
424
1077
 
1078
+ from langgraph.types import Command
1079
+
425
1080
  async def _async_stream():
426
- item_id = f"msg_{uuid.uuid4().hex[:8]}"
427
- accumulated_content = ""
1081
+ item_id: str = f"msg_{uuid.uuid4().hex[:8]}"
1082
+ accumulated_content: str = ""
1083
+ interrupt_data: list[HITLRequest] = []
1084
+ seen_interrupt_ids: set[str] = set() # Track processed interrupt IDs
1085
+ structured_response: Any = None # Track structured output from stream
428
1086
 
429
1087
  try:
430
- async for nodes, stream_mode, messages_batch in self.graph.astream(
431
- graph_input,
1088
+ # Check if this is a resume request (HITL)
1089
+ # Two ways to resume:
1090
+ # 1. Explicit decisions in custom_inputs (structured)
1091
+ # 2. Natural language message when graph is interrupted (LLM-parsed)
1092
+
1093
+ if request.custom_inputs and "decisions" in request.custom_inputs:
1094
+ # Explicit structured decisions
1095
+ decisions: list[Decision] = request.custom_inputs["decisions"]
1096
+ logger.info(
1097
+ f"HITL: Resuming with {len(decisions)} explicit decision(s)"
1098
+ )
1099
+ stream_input: Command | dict[str, Any] = Command(
1100
+ resume={"decisions": decisions}
1101
+ )
1102
+ elif self.graph.checkpointer:
1103
+ # Check if graph is currently interrupted (only if checkpointer is configured)
1104
+ # aget_state requires a checkpointer
1105
+ snapshot: StateSnapshot = await self.graph.aget_state(
1106
+ config=custom_inputs
1107
+ )
1108
+ if is_interrupted(snapshot):
1109
+ logger.info(
1110
+ "HITL: Graph is interrupted, checking for user response"
1111
+ )
1112
+
1113
+ # Convert message dicts to BaseMessage objects
1114
+ message_objects: list[BaseMessage] = convert_openai_messages(
1115
+ messages
1116
+ )
1117
+
1118
+ # Parse user's message with LLM to extract decisions
1119
+ parsed_result: dict[str, Any] = handle_interrupt_response(
1120
+ snapshot=snapshot,
1121
+ messages=message_objects,
1122
+ model=None, # Uses default model
1123
+ )
1124
+
1125
+ # Check if the response was valid
1126
+ if not parsed_result.get("is_valid", False):
1127
+ validation_message: str = parsed_result.get(
1128
+ "validation_message",
1129
+ "Your response was unclear. Please provide a clear decision for each action.",
1130
+ )
1131
+ logger.warning(
1132
+ f"HITL: Invalid response - {validation_message}"
1133
+ )
1134
+
1135
+ # Build custom_outputs before returning
1136
+ custom_outputs: dict[
1137
+ str, Any
1138
+ ] = await self._build_custom_outputs_async(
1139
+ context=context,
1140
+ thread_id=context.thread_id,
1141
+ )
1142
+
1143
+ # Yield error message to user - don't resume graph
1144
+ error_message: str = (
1145
+ f"❌ **Invalid Response**\n\n{validation_message}"
1146
+ )
1147
+ accumulated_content = error_message
1148
+ yield ResponsesAgentStreamEvent(
1149
+ type="response.output_item.done",
1150
+ item=self.create_text_output_item(
1151
+ text=error_message, id=item_id
1152
+ ),
1153
+ custom_outputs=custom_outputs,
1154
+ )
1155
+ return # Don't resume - stay interrupted
1156
+
1157
+ decisions: list[Decision] = parsed_result.get("decisions", [])
1158
+ logger.info(
1159
+ f"HITL: LLM parsed {len(decisions)} valid decision(s) from user message"
1160
+ )
1161
+
1162
+ # Resume interrupted graph with parsed decisions
1163
+ stream_input: Command | dict[str, Any] = Command(
1164
+ resume={"decisions": decisions}
1165
+ )
1166
+ else:
1167
+ # Graph not interrupted, use normal invocation
1168
+ graph_input: dict[str, Any] = {"messages": messages}
1169
+ if "genie_conversation_ids" in session_input:
1170
+ graph_input["genie_conversation_ids"] = session_input[
1171
+ "genie_conversation_ids"
1172
+ ]
1173
+ stream_input: Command | dict[str, Any] = graph_input
1174
+ else:
1175
+ # No checkpointer, use normal invocation
1176
+ graph_input: dict[str, Any] = {"messages": messages}
1177
+ if "genie_conversation_ids" in session_input:
1178
+ graph_input["genie_conversation_ids"] = session_input[
1179
+ "genie_conversation_ids"
1180
+ ]
1181
+ stream_input: Command | dict[str, Any] = graph_input
1182
+
1183
+ # Stream the graph execution with both messages and updates modes to capture interrupts
1184
+ async for nodes, stream_mode, data in self.graph.astream(
1185
+ stream_input,
432
1186
  context=context,
433
1187
  config=custom_inputs,
434
- stream_mode=["messages", "custom"],
1188
+ stream_mode=["messages", "updates"],
435
1189
  subgraphs=True,
436
1190
  ):
437
1191
  nodes: tuple[str, ...]
438
1192
  stream_mode: str
439
- messages_batch: Sequence[BaseMessage]
440
1193
 
441
- for message in messages_batch:
442
- if (
443
- isinstance(
444
- message,
445
- (
446
- AIMessageChunk,
447
- AIMessage,
448
- ),
449
- )
450
- and message.content
451
- and "summarization" not in nodes
452
- ):
453
- content = message.content
454
- accumulated_content += content
1194
+ # Handle message streaming
1195
+ if stream_mode == "messages":
1196
+ messages_batch: Sequence[BaseMessage] = data
1197
+ message: BaseMessage
1198
+ for message in messages_batch:
1199
+ if (
1200
+ isinstance(
1201
+ message,
1202
+ (
1203
+ AIMessageChunk,
1204
+ AIMessage,
1205
+ ),
1206
+ )
1207
+ and message.content
1208
+ and "summarization" not in nodes
1209
+ ):
1210
+ content: str = message.content
1211
+ accumulated_content += content
1212
+
1213
+ # Yield streaming delta
1214
+ yield ResponsesAgentStreamEvent(
1215
+ **self.create_text_delta(
1216
+ delta=content, item_id=item_id
1217
+ )
1218
+ )
1219
+
1220
+ # Handle interrupts (HITL) and state updates
1221
+ elif stream_mode == "updates":
1222
+ updates: dict[str, Any] = data
1223
+ source: str
1224
+ update: Any
1225
+ for source, update in updates.items():
1226
+ if source == "__interrupt__":
1227
+ interrupts: list[Interrupt] = update
1228
+ logger.info(
1229
+ f"HITL: {len(interrupts)} interrupt(s) detected during streaming"
1230
+ )
1231
+
1232
+ # Extract interrupt values (deduplicate by ID)
1233
+ interrupt: Interrupt
1234
+ for interrupt in interrupts:
1235
+ # Only process each unique interrupt once
1236
+ if interrupt.id not in seen_interrupt_ids:
1237
+ seen_interrupt_ids.add(interrupt.id)
1238
+ interrupt_data.append(
1239
+ _extract_interrupt_value(interrupt)
1240
+ )
1241
+ logger.debug(
1242
+ f"HITL: Added interrupt {interrupt.id} to response"
1243
+ )
1244
+ elif (
1245
+ isinstance(update, dict)
1246
+ and "structured_response" in update
1247
+ ):
1248
+ # Capture structured_response from stream updates
1249
+ structured_response = update["structured_response"]
1250
+ logger.debug(
1251
+ f"Captured structured_response from stream: {type(structured_response)}"
1252
+ )
1253
+
1254
+ # Get final state to extract structured_response (only if checkpointer available)
1255
+ if self.graph.checkpointer:
1256
+ final_state: StateSnapshot = await self.graph.aget_state(
1257
+ config=custom_inputs
1258
+ )
1259
+ # Extract structured_response from state if not already captured
1260
+ if (
1261
+ "structured_response" in final_state.values
1262
+ and not structured_response
1263
+ ):
1264
+ structured_response = final_state.values["structured_response"]
455
1265
 
456
- # Yield streaming delta
457
- yield ResponsesAgentStreamEvent(
458
- **self.create_text_delta(delta=content, item_id=item_id)
459
- )
1266
+ # Build custom_outputs
1267
+ custom_outputs: dict[str, Any] = await self._build_custom_outputs_async(
1268
+ context=context,
1269
+ thread_id=context.thread_id,
1270
+ )
1271
+
1272
+ # Handle structured_response in streaming if present
1273
+ output_text: str = accumulated_content
1274
+ if structured_response:
1275
+ from dataclasses import asdict, is_dataclass
460
1276
 
461
- # Retrieve genie_conversation_ids from state if available
462
- custom_outputs: dict[str, Any] = custom_inputs.copy()
463
- thread_id: Optional[str] = context.thread_id
1277
+ from pydantic import BaseModel
464
1278
 
465
- if thread_id:
466
- state_snapshot: Optional[
467
- StateSnapshot
468
- ] = await get_state_snapshot_async(self.graph, thread_id)
469
- genie_conversation_ids: dict[str, str] = (
470
- get_genie_conversation_ids_from_state(state_snapshot)
1279
+ logger.debug(
1280
+ f"Processing structured_response in streaming: {type(structured_response)}"
471
1281
  )
472
- if genie_conversation_ids:
473
- custom_outputs["genie_conversation_ids"] = (
474
- genie_conversation_ids
1282
+
1283
+ # Serialize to dict for JSON compatibility using type hints
1284
+ if isinstance(structured_response, BaseModel):
1285
+ serialized: dict[str, Any] = structured_response.model_dump()
1286
+ elif is_dataclass(structured_response):
1287
+ serialized = asdict(structured_response)
1288
+ elif isinstance(structured_response, dict):
1289
+ serialized = structured_response
1290
+ else:
1291
+ serialized = (
1292
+ dict(structured_response)
1293
+ if hasattr(structured_response, "__dict__")
1294
+ else structured_response
1295
+ )
1296
+
1297
+ # Place structured output in message content - stream as JSON
1298
+ import json
1299
+
1300
+ structured_text: str = json.dumps(serialized, indent=2)
1301
+
1302
+ # If we streamed text, append structured; if no text, use structured only
1303
+ if accumulated_content.strip():
1304
+ # Stream separator and structured output
1305
+ yield ResponsesAgentStreamEvent(
1306
+ **self.create_text_delta(delta="\n\n", item_id=item_id)
475
1307
  )
1308
+ yield ResponsesAgentStreamEvent(
1309
+ **self.create_text_delta(
1310
+ delta=structured_text, item_id=item_id
1311
+ )
1312
+ )
1313
+ output_text = f"{accumulated_content}\n\n{structured_text}"
1314
+ else:
1315
+ # No text content, stream structured output
1316
+ yield ResponsesAgentStreamEvent(
1317
+ **self.create_text_delta(
1318
+ delta=structured_text, item_id=item_id
1319
+ )
1320
+ )
1321
+ output_text = structured_text
1322
+
1323
+ logger.debug("Streamed structured_response in message content")
1324
+
1325
+ # Include interrupt structure if HITL occurred
1326
+ if interrupt_data:
1327
+ custom_outputs["interrupts"] = interrupt_data
1328
+ logger.info(
1329
+ f"HITL: Included {len(interrupt_data)} interrupt(s) in streaming response"
1330
+ )
1331
+
1332
+ # Add user-facing message about the pending actions
1333
+ action_message = _format_action_requests_message(interrupt_data)
1334
+ if action_message:
1335
+ # If we haven't streamed any content yet, stream the action message
1336
+ if not accumulated_content:
1337
+ output_text = action_message
1338
+ # Stream the action message
1339
+ yield ResponsesAgentStreamEvent(
1340
+ **self.create_text_delta(
1341
+ delta=action_message, item_id=item_id
1342
+ )
1343
+ )
1344
+ else:
1345
+ # Append action message after accumulated content
1346
+ output_text = f"{accumulated_content}\n\n{action_message}"
1347
+ # Stream the separator and action message
1348
+ yield ResponsesAgentStreamEvent(
1349
+ **self.create_text_delta(delta="\n\n", item_id=item_id)
1350
+ )
1351
+ yield ResponsesAgentStreamEvent(
1352
+ **self.create_text_delta(
1353
+ delta=action_message, item_id=item_id
1354
+ )
1355
+ )
476
1356
 
477
1357
  # Yield final output item
478
1358
  yield ResponsesAgentStreamEvent(
479
1359
  type="response.output_item.done",
480
- item=self.create_text_output_item(
481
- text=accumulated_content, id=item_id
482
- ),
1360
+ item=self.create_text_output_item(text=output_text, id=item_id),
483
1361
  custom_outputs=custom_outputs,
484
1362
  )
485
1363
  except Exception as e:
@@ -490,7 +1368,6 @@ class LanggraphResponsesAgent(ResponsesAgent):
490
1368
  try:
491
1369
  loop = asyncio.get_event_loop()
492
1370
  except RuntimeError:
493
- # Handle case where no event loop exists (common in some deployment scenarios)
494
1371
  loop = asyncio.new_event_loop()
495
1372
  asyncio.set_event_loop(loop)
496
1373
 
@@ -575,15 +1452,24 @@ class LanggraphResponsesAgent(ResponsesAgent):
575
1452
  return messages
576
1453
 
577
1454
  def _convert_request_to_context(self, request: ResponsesAgentRequest) -> Context:
578
- """Convert ResponsesAgent context to internal Context."""
1455
+ """Convert ResponsesAgent context to internal Context.
1456
+
1457
+ Handles the input structure:
1458
+ - custom_inputs.configurable: Configuration (thread_id, user_id, store_num, etc.)
1459
+ - custom_inputs.session: Accumulated state (conversation_id, genie conversations, etc.)
579
1460
 
1461
+ Maps conversation_id -> thread_id for LangGraph compatibility.
1462
+ conversation_id can be provided in either configurable or session.
1463
+ Normalizes user_id (replaces . with _) for memory namespace compatibility.
1464
+ """
580
1465
  logger.debug(f"request.context: {request.context}")
581
1466
  logger.debug(f"request.custom_inputs: {request.custom_inputs}")
582
1467
 
583
1468
  configurable: dict[str, Any] = {}
1469
+ session: dict[str, Any] = {}
584
1470
 
585
1471
  # Process context values first (lower priority)
586
- # Use strong typing with forward-declared type hints instead of hasattr checks
1472
+ # These come from Databricks ResponsesAgent ChatContext
587
1473
  chat_context: Optional[ChatContext] = request.context
588
1474
  if chat_context is not None:
589
1475
  conversation_id: Optional[str] = chat_context.conversation_id
@@ -591,27 +1477,185 @@ class LanggraphResponsesAgent(ResponsesAgent):
591
1477
 
592
1478
  if conversation_id is not None:
593
1479
  configurable["conversation_id"] = conversation_id
594
- configurable["thread_id"] = conversation_id
595
1480
 
596
1481
  if user_id is not None:
597
1482
  configurable["user_id"] = user_id
598
1483
 
599
1484
  # Process custom_inputs after context so they can override context values (higher priority)
600
1485
  if request.custom_inputs:
1486
+ # Extract configurable section (user config)
601
1487
  if "configurable" in request.custom_inputs:
602
- configurable.update(request.custom_inputs.pop("configurable"))
1488
+ configurable.update(request.custom_inputs["configurable"])
1489
+
1490
+ # Extract session section
1491
+ if "session" in request.custom_inputs:
1492
+ session_input = request.custom_inputs["session"]
1493
+ if isinstance(session_input, dict):
1494
+ session = session_input
1495
+
1496
+ # Handle legacy flat structure (backwards compatibility)
1497
+ # If user passes keys directly in custom_inputs, merge them
1498
+ for key in list(request.custom_inputs.keys()):
1499
+ if key not in ("configurable", "session"):
1500
+ configurable[key] = request.custom_inputs[key]
1501
+
1502
+ # Extract known Context fields
1503
+ user_id_value: str | None = configurable.pop("user_id", None)
1504
+ if user_id_value:
1505
+ # Normalize user_id for memory namespace (replace . with _)
1506
+ user_id_value = user_id_value.replace(".", "_")
1507
+
1508
+ # Accept thread_id from configurable, or conversation_id from configurable or session
1509
+ # Priority: configurable.conversation_id > session.conversation_id > configurable.thread_id
1510
+ thread_id: str | None = configurable.pop("thread_id", None)
1511
+ conversation_id: str | None = configurable.pop("conversation_id", None)
1512
+
1513
+ # Also check session for conversation_id (output puts it there)
1514
+ if conversation_id is None and "conversation_id" in session:
1515
+ conversation_id = session.get("conversation_id")
1516
+
1517
+ # conversation_id takes precedence if provided
1518
+ if conversation_id:
1519
+ thread_id = conversation_id
1520
+ if not thread_id:
1521
+ # Generate new thread_id if neither provided
1522
+ thread_id = str(uuid.uuid4())
1523
+
1524
+ # All remaining configurable values go into custom dict
1525
+ custom: dict[str, Any] = configurable
1526
+
1527
+ logger.debug(
1528
+ f"Creating context with user_id={user_id_value}, thread_id={thread_id}, custom={custom}"
1529
+ )
603
1530
 
604
- configurable.update(request.custom_inputs)
1531
+ return Context(
1532
+ user_id=user_id_value,
1533
+ thread_id=thread_id,
1534
+ custom=custom,
1535
+ )
605
1536
 
606
- if "user_id" in configurable:
607
- configurable["user_id"] = configurable["user_id"].replace(".", "_")
1537
def _extract_session_from_request(
    self, request: ResponsesAgentRequest
) -> dict[str, Any]:
    """Extract session state from ``request.custom_inputs`` for the graph input.

    Three locations are consulted, in this order; a later one overwrites an
    earlier one when several are present:

    1. New structure: ``custom_inputs["session"]["genie"]["spaces"]``, a mapping
       of ``space_id`` to a dict carrying ``"conversation_id"``.
    2. Legacy structure: ``custom_inputs["genie_conversation_ids"]``.
    3. Legacy structure: ``custom_inputs["configurable"]["genie_conversation_ids"]``.

    Args:
        request: The incoming request; only ``request.custom_inputs`` is read.

    Returns:
        A dict that is either empty or holds a single
        ``"genie_conversation_ids"`` entry.
    """
    session: dict[str, Any] = {}

    custom_inputs = request.custom_inputs
    if not custom_inputs:
        return session

    # New structure: walk session -> genie -> spaces, tolerating a missing or
    # malformed (non-dict) value at every level instead of raising.
    session_input = custom_inputs.get("session")
    genie_state = (
        session_input.get("genie") if isinstance(session_input, dict) else None
    )
    spaces = genie_state.get("spaces") if isinstance(genie_state, dict) else None
    if isinstance(spaces, dict):
        genie_conversation_ids: dict[str, Any] = {
            space_id: space_state["conversation_id"]
            for space_id, space_state in spaces.items()
            if isinstance(space_state, dict) and "conversation_id" in space_state
        }
        if genie_conversation_ids:
            session["genie_conversation_ids"] = genie_conversation_ids

    # Legacy structure: genie_conversation_ids at the top level of custom_inputs.
    if "genie_conversation_ids" in custom_inputs:
        session["genie_conversation_ids"] = custom_inputs["genie_conversation_ids"]

    # Legacy structure: genie_conversation_ids nested inside configurable.
    cfg = custom_inputs.get("configurable")
    if isinstance(cfg, dict) and "genie_conversation_ids" in cfg:
        session["genie_conversation_ids"] = cfg["genie_conversation_ids"]

    return session
1583
+
1584
def _build_custom_outputs(
    self,
    context: Context,
    thread_id: Optional[str],
    loop: Any,  # asyncio.AbstractEventLoop
) -> dict[str, Any]:
    """Build custom_outputs that can be copy-pasted as next request's custom_inputs.

    Synchronous wrapper: drives ``_build_custom_outputs_async`` to completion
    on ``loop`` and returns its result unchanged.

    Output structure:
        configurable:
            thread_id: "abc-123"        # Thread identifier
            user_id: "nate_fleming"     # Emitted exactly as stored on the context
            store_num: "87887"          # Any custom fields from context.custom
        session:
            conversation_id: "abc-123"  # Alias of thread_id
            genie:
                spaces:
                    space_123: {conversation_id: "conv_456", cache_hit: false, ...}

    Args:
        context: Context whose ``user_id`` and ``custom`` fields are echoed back.
        thread_id: Thread identifier to report; may be ``None``.
        loop: Event loop used to run the async builder. It must not already
            be running, because ``run_until_complete`` is called on it.

    Returns:
        A dict with ``"configurable"`` and ``"session"`` sections.
    """
    return loop.run_until_complete(
        self._build_custom_outputs_async(context=context, thread_id=thread_id)
    )
1606
+
1607
async def _build_custom_outputs_async(
    self,
    context: Context,
    thread_id: Optional[str],
) -> dict[str, Any]:
    """Assemble the ``custom_outputs`` payload (async variant).

    The result has two sections:

    * ``configurable``: ``thread_id`` (when truthy), ``user_id`` (when truthy,
      emitted exactly as stored on ``context``), then every entry of
      ``context.custom`` merged on top.
    * ``session``: ``conversation_id`` (set to ``thread_id``) and, when the
      graph state for that thread holds Genie conversation ids, a
      ``genie.spaces`` mapping with one entry per space.

    Args:
        context: Source of ``user_id`` and the ``custom`` fields.
        thread_id: Thread identifier; when falsy no state lookup is performed
            and ``session`` is returned empty.

    Returns:
        ``{"configurable": {...}, "session": {...}}``.
    """
    has_thread: bool = bool(thread_id)

    # configurable section: thread_id only (conversation_id lives in session).
    configurable_out: dict[str, Any] = {"thread_id": thread_id} if has_thread else {}
    if context.user_id:
        # user_id is passed through as-is (no de-normalization is attempted).
        configurable_out["user_id"] = context.user_id
    configurable_out.update(context.custom)

    # session section: accumulated state keyed off the thread.
    session_out: dict[str, Any] = {}
    if has_thread:
        session_out["conversation_id"] = thread_id

        snapshot: Optional[StateSnapshot] = await get_state_snapshot_async(
            self.graph, thread_id
        )
        conversation_ids: dict[str, str] = get_genie_conversation_ids_from_state(
            snapshot
        )
        if conversation_ids:
            # Expand the flat {space_id: conversation_id} mapping into the
            # nested session.genie.spaces structure. cache_hit and
            # follow_up_questions are emitted here as fixed placeholders.
            spaces: dict[str, dict[str, Any]] = {}
            for space_id, conv_id in conversation_ids.items():
                spaces[space_id] = {
                    "conversation_id": conv_id,
                    "cache_hit": False,
                    "follow_up_questions": [],
                }
            session_out["genie"] = {"spaces": spaces}

    return {"configurable": configurable_out, "session": session_out}
615
1659
 
616
1660
 
617
1661
  def create_agent(graph: CompiledStateGraph) -> ChatAgent:
@@ -630,7 +1674,9 @@ def create_agent(graph: CompiledStateGraph) -> ChatAgent:
630
1674
  return LanggraphChatModel(graph)
631
1675
 
632
1676
 
633
- def create_responses_agent(graph: CompiledStateGraph) -> ResponsesAgent:
1677
+ def create_responses_agent(
1678
+ graph: CompiledStateGraph,
1679
+ ) -> ResponsesAgent:
634
1680
  """
635
1681
  Create an MLflow-compatible ResponsesAgent from a LangGraph state machine.
636
1682
 
@@ -665,6 +1711,29 @@ def _process_langchain_messages(
665
1711
  return loop.run_until_complete(_async_invoke())
666
1712
 
667
1713
 
1714
def _configurable_to_context(configurable: dict[str, Any]) -> Context:
    """Convert a configurable dict to a Context object.

    The input is never mutated; a shallow copy is taken first. A ``None``
    input is treated as an empty dict.

    Field handling:

    * ``user_id``: popped; when truthy, ``.`` is replaced with ``_``.
    * ``thread_id``: popped and used when truthy.
    * ``conversation_id``: always popped so it never leaks into ``custom``;
      used as the thread id only when ``thread_id`` is missing or falsy.
    * When neither id is usable, a fresh UUID4 string is generated.
    * Everything left over is passed through as ``custom``.

    Args:
        configurable: Flat mapping of configuration values (may be ``None``).

    Returns:
        A ``Context`` built from the extracted fields.
    """
    # Shallow copy so pops below do not touch the caller's dict; tolerate None.
    remaining: dict[str, Any] = dict(configurable or {})

    user_id: str | None = remaining.pop("user_id", None)
    if user_id:
        user_id = user_id.replace(".", "_")

    thread_id: str | None = remaining.pop("thread_id", None)
    # Pop unconditionally: when thread_id is also given, conversation_id would
    # otherwise be left behind and forwarded inside `custom`.
    conversation_id: str | None = remaining.pop("conversation_id", None)
    if not thread_id:
        thread_id = conversation_id or str(uuid.uuid4())

    # All remaining values go into the custom dict.
    return Context(
        user_id=user_id,
        thread_id=thread_id,
        custom=remaining,
    )
1735
+
1736
+
668
1737
  def _process_langchain_messages_stream(
669
1738
  app: LanggraphChatModel | CompiledStateGraph,
670
1739
  messages: Sequence[BaseMessage],
@@ -678,8 +1747,8 @@ def _process_langchain_messages_stream(
678
1747
 
679
1748
  logger.debug(f"Processing messages: {messages}, custom_inputs: {custom_inputs}")
680
1749
 
681
- custom_inputs = custom_inputs.get("configurable", custom_inputs or {})
682
- context: Context = Context(**custom_inputs)
1750
+ configurable = (custom_inputs or {}).get("configurable", custom_inputs or {})
1751
+ context: Context = _configurable_to_context(configurable)
683
1752
 
684
1753
  # Use async astream internally for parallel execution
685
1754
  async def _async_stream():