dao-ai 0.0.25-py3-none-any.whl → 0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +5 -5
  3. dao_ai/cli.py +245 -40
  4. dao_ai/config.py +1863 -338
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -228
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +125 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/core.py +67 -0
  25. dao_ai/middleware/guardrails.py +420 -0
  26. dao_ai/middleware/human_in_the_loop.py +232 -0
  27. dao_ai/middleware/message_validation.py +586 -0
  28. dao_ai/middleware/summarization.py +197 -0
  29. dao_ai/models.py +1306 -114
  30. dao_ai/nodes.py +261 -166
  31. dao_ai/optimization.py +674 -0
  32. dao_ai/orchestration/__init__.py +52 -0
  33. dao_ai/orchestration/core.py +294 -0
  34. dao_ai/orchestration/supervisor.py +278 -0
  35. dao_ai/orchestration/swarm.py +271 -0
  36. dao_ai/prompts.py +128 -31
  37. dao_ai/providers/databricks.py +645 -172
  38. dao_ai/state.py +157 -21
  39. dao_ai/tools/__init__.py +13 -5
  40. dao_ai/tools/agent.py +1 -3
  41. dao_ai/tools/core.py +64 -11
  42. dao_ai/tools/email.py +232 -0
  43. dao_ai/tools/genie.py +144 -295
  44. dao_ai/tools/mcp.py +220 -133
  45. dao_ai/tools/memory.py +50 -0
  46. dao_ai/tools/python.py +9 -14
  47. dao_ai/tools/search.py +14 -0
  48. dao_ai/tools/slack.py +22 -10
  49. dao_ai/tools/sql.py +202 -0
  50. dao_ai/tools/time.py +30 -7
  51. dao_ai/tools/unity_catalog.py +165 -88
  52. dao_ai/tools/vector_search.py +360 -40
  53. dao_ai/utils.py +218 -16
  54. dao_ai-0.1.2.dist-info/METADATA +455 -0
  55. dao_ai-0.1.2.dist-info/RECORD +64 -0
  56. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +1 -1
  57. dao_ai/chat_models.py +0 -204
  58. dao_ai/guardrails.py +0 -112
  59. dao_ai/tools/human_in_the_loop.py +0 -100
  60. dao_ai-0.0.25.dist-info/METADATA +0 -1165
  61. dao_ai-0.0.25.dist-info/RECORD +0 -41
  62. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
  63. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/models.py CHANGED
@@ -1,11 +1,34 @@
1
1
  import uuid
2
2
  from os import PathLike
3
3
  from pathlib import Path
4
- from typing import Any, Generator, Optional, Sequence, Union
5
-
6
- from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
4
+ from typing import TYPE_CHECKING, Any, Generator, Literal, Optional, Sequence, Union
5
+
6
+ from databricks_langchain import ChatDatabricks
7
+
8
+ if TYPE_CHECKING:
9
+ pass
10
+
11
+ # Import official LangChain HITL TypedDict definitions
12
+ # Reference: https://docs.langchain.com/oss/python/langchain/human-in-the-loop
13
+ from langchain.agents.middleware.human_in_the_loop import (
14
+ ActionRequest,
15
+ Decision,
16
+ EditDecision,
17
+ HITLRequest,
18
+ RejectDecision,
19
+ ReviewConfig,
20
+ )
21
+ from langchain_community.adapters.openai import convert_openai_messages
22
+ from langchain_core.language_models import LanguageModelLike
23
+ from langchain_core.messages import (
24
+ AIMessage,
25
+ AIMessageChunk,
26
+ BaseMessage,
27
+ HumanMessage,
28
+ SystemMessage,
29
+ )
7
30
  from langgraph.graph.state import CompiledStateGraph
8
- from langgraph.types import StateSnapshot
31
+ from langgraph.types import Interrupt, StateSnapshot
9
32
  from loguru import logger
10
33
  from mlflow import MlflowClient
11
34
  from mlflow.pyfunc import ChatAgent, ChatModel, ResponsesAgent
@@ -28,11 +51,13 @@ from mlflow.types.responses_helpers import (
28
51
  Message,
29
52
  ResponseInputTextParam,
30
53
  )
54
+ from pydantic import BaseModel, Field, create_model
31
55
 
32
56
  from dao_ai.messages import (
33
57
  has_langchain_messages,
34
58
  has_mlflow_messages,
35
59
  has_mlflow_responses_messages,
60
+ last_human_message,
36
61
  )
37
62
  from dao_ai.state import Context
38
63
 
@@ -54,12 +79,37 @@ def get_latest_model_version(model_name: str) -> int:
54
79
  mlflow_client: MlflowClient = MlflowClient()
55
80
  latest_version: int = 1
56
81
  for mv in mlflow_client.search_model_versions(f"name='{model_name}'"):
57
- version_int = int(mv.version)
82
+ version_int: int = int(mv.version)
58
83
  if version_int > latest_version:
59
84
  latest_version = version_int
60
85
  return latest_version
61
86
 
62
87
 
88
+ def is_interrupted(snapshot: StateSnapshot) -> bool:
89
+ """
90
+ Check if the graph state is currently interrupted (paused for human-in-the-loop).
91
+
92
+ Based on LangChain documentation:
93
+ - StateSnapshot has an `interrupts` attribute which is a tuple
94
+ - When interrupted, the tuple contains Interrupt objects
95
+ - When not interrupted, it's an empty tuple ()
96
+
97
+ Args:
98
+ snapshot: The StateSnapshot to check
99
+
100
+ Returns:
101
+ True if the graph is interrupted (has pending HITL actions), False otherwise
102
+
103
+ Example:
104
+ >>> snapshot = await graph.aget_state(config)
105
+ >>> if is_interrupted(snapshot):
106
+ ... print("Graph is waiting for human input")
107
+ """
108
+ # Check if snapshot has any interrupts
109
+ # According to LangChain docs, interrupts is a tuple that's empty () when no interrupts
110
+ return bool(snapshot.interrupts)
111
+
112
+
63
113
  async def get_state_snapshot_async(
64
114
  graph: CompiledStateGraph, thread_id: str
65
115
  ) -> Optional[StateSnapshot]:
@@ -76,11 +126,11 @@ async def get_state_snapshot_async(
76
126
  Returns:
77
127
  StateSnapshot if found, None otherwise
78
128
  """
79
- logger.debug(f"Retrieving state snapshot for thread_id: {thread_id}")
129
+ logger.trace("Retrieving state snapshot", thread_id=thread_id)
80
130
  try:
81
131
  # Check if graph has a checkpointer
82
132
  if graph.checkpointer is None:
83
- logger.debug("No checkpointer available in graph")
133
+ logger.trace("No checkpointer available in graph")
84
134
  return None
85
135
 
86
136
  # Get the current state from the checkpointer (use async version)
@@ -88,13 +138,15 @@ async def get_state_snapshot_async(
88
138
  state_snapshot: Optional[StateSnapshot] = await graph.aget_state(config)
89
139
 
90
140
  if state_snapshot is None:
91
- logger.debug(f"No state found for thread_id: {thread_id}")
141
+ logger.trace("No state found for thread", thread_id=thread_id)
92
142
  return None
93
143
 
94
144
  return state_snapshot
95
145
 
96
146
  except Exception as e:
97
- logger.warning(f"Error retrieving state snapshot for thread {thread_id}: {e}")
147
+ logger.warning(
148
+ "Error retrieving state snapshot", thread_id=thread_id, error=str(e)
149
+ )
98
150
  return None
99
151
 
100
152
 
@@ -125,7 +177,7 @@ def get_state_snapshot(
125
177
  try:
126
178
  return loop.run_until_complete(get_state_snapshot_async(graph, thread_id))
127
179
  except Exception as e:
128
- logger.warning(f"Error in synchronous state snapshot retrieval: {e}")
180
+ logger.warning("Error in synchronous state snapshot retrieval", error=str(e))
129
181
  return None
130
182
 
131
183
 
@@ -157,16 +209,125 @@ def get_genie_conversation_ids_from_state(
157
209
  )
158
210
 
159
211
  if genie_conversation_ids:
160
- logger.debug(f"Retrieved genie_conversation_ids: {genie_conversation_ids}")
212
+ logger.trace(
213
+ "Retrieved genie conversation IDs", count=len(genie_conversation_ids)
214
+ )
161
215
  return genie_conversation_ids
162
216
 
163
217
  return {}
164
218
 
165
219
  except Exception as e:
166
- logger.warning(f"Error extracting genie_conversation_ids from state: {e}")
220
+ logger.warning(
221
+ "Error extracting genie conversation IDs from state", error=str(e)
222
+ )
167
223
  return {}
168
224
 
169
225
 
226
+ def _extract_interrupt_value(interrupt: Interrupt) -> HITLRequest:
227
+ """
228
+ Extract the HITL request from a LangGraph Interrupt object.
229
+
230
+ Following LangChain patterns, the Interrupt object has a .value attribute
231
+ containing the HITLRequest structure with action_requests and review_configs.
232
+
233
+ Args:
234
+ interrupt: Interrupt object from LangGraph with .value and .id attributes
235
+
236
+ Returns:
237
+ HITLRequest with action_requests and review_configs
238
+ """
239
+ # Interrupt.value is typed as Any, but for HITL it should be a HITLRequest dict
240
+ if isinstance(interrupt.value, dict):
241
+ # Return as HITLRequest TypedDict
242
+ return interrupt.value # type: ignore[return-value]
243
+
244
+ # Fallback: return empty structure if value is not a dict
245
+ return {"action_requests": [], "review_configs": []}
246
+
247
+
248
+ def _format_action_requests_message(interrupt_data: list[HITLRequest]) -> str:
249
+ """
250
+ Format action requests from interrupts into a simple, user-friendly message.
251
+
252
+ Since we now use LLM-based parsing, users can respond in natural language.
253
+ This function just shows WHAT actions are pending, not HOW to respond.
254
+
255
+ Args:
256
+ interrupt_data: List of HITLRequest structures containing action_requests and review_configs
257
+
258
+ Returns:
259
+ Simple formatted message describing the pending actions
260
+ """
261
+ if not interrupt_data:
262
+ return ""
263
+
264
+ # Collect all action requests and review configs from all interrupts
265
+ all_actions: list[ActionRequest] = []
266
+ review_configs_map: dict[str, ReviewConfig] = {}
267
+
268
+ for hitl_request in interrupt_data:
269
+ all_actions.extend(hitl_request.get("action_requests", []))
270
+ for review_config in hitl_request.get("review_configs", []):
271
+ action_name = review_config.get("action_name", "")
272
+ if action_name:
273
+ review_configs_map[action_name] = review_config
274
+
275
+ if not all_actions:
276
+ return ""
277
+
278
+ # Build simple, clean message
279
+ lines = ["⏸️ **Action Approval Required**", ""]
280
+ lines.append(
281
+ f"The assistant wants to perform {len(all_actions)} action(s) that require your approval:"
282
+ )
283
+ lines.append("")
284
+
285
+ for i, action in enumerate(all_actions, 1):
286
+ tool_name = action.get("name", "unknown")
287
+ args = action.get("args", {})
288
+ description = action.get("description")
289
+
290
+ lines.append(f"**{i}. {tool_name}**")
291
+
292
+ # Show review prompt/description if available
293
+ if description:
294
+ lines.append(f" • **Review:** {description}")
295
+
296
+ if args:
297
+ # Format args nicely, truncating long values
298
+ for key, value in args.items():
299
+ value_str = str(value)
300
+ if len(value_str) > 100:
301
+ value_str = value_str[:100] + "..."
302
+ lines.append(f" • {key}: `{value_str}`")
303
+ else:
304
+ lines.append(" • (no arguments)")
305
+
306
+ # Show allowed decisions
307
+ review_config = review_configs_map.get(tool_name)
308
+ if review_config:
309
+ allowed_decisions = review_config.get("allowed_decisions", [])
310
+ if allowed_decisions:
311
+ decisions_str = ", ".join(allowed_decisions)
312
+ lines.append(f" • **Options:** {decisions_str}")
313
+
314
+ lines.append("")
315
+
316
+ lines.append("---")
317
+ lines.append("")
318
+ lines.append(
319
+ "**You can respond in natural language** (e.g., 'approve both', 'reject the first one', "
320
+ "'change the email to new@example.com')"
321
+ )
322
+ lines.append("")
323
+ lines.append(
324
+ "Or provide structured decisions in `custom_inputs` with key `decisions`: "
325
+ '`[{"type": "approve"}, {"type": "reject", "message": "reason"}]`'
326
+ )
327
+
328
+ return "\n".join(lines)
329
+
330
+
170
331
  class LanggraphChatModel(ChatModel):
171
332
  """
172
333
  ChatModel that delegates requests to a LangGraph CompiledStateGraph.
@@ -178,7 +339,11 @@ class LanggraphChatModel(ChatModel):
178
339
  def predict(
179
340
  self, context, messages: list[ChatMessage], params: Optional[ChatParams] = None
180
341
  ) -> ChatCompletionResponse:
181
- logger.debug(f"messages: {messages}, params: {params}")
342
+ logger.trace(
343
+ "Predict called",
344
+ messages_count=len(messages),
345
+ has_params=params is not None,
346
+ )
182
347
  if not messages:
183
348
  raise ValueError("Message list is empty.")
184
349
 
@@ -200,7 +365,10 @@ class LanggraphChatModel(ChatModel):
200
365
  _async_invoke()
201
366
  )
202
367
 
203
- logger.trace(f"response: {response}")
368
+ logger.trace(
369
+ "Predict response received",
370
+ messages_count=len(response.get("messages", [])),
371
+ )
204
372
 
205
373
  last_message: BaseMessage = response["messages"][-1]
206
374
 
@@ -216,28 +384,43 @@ class LanggraphChatModel(ChatModel):
216
384
 
217
385
  configurable: dict[str, Any] = {}
218
386
  if "configurable" in input_data:
219
- configurable: dict[str, Any] = input_data.pop("configurable")
387
+ configurable = input_data.pop("configurable")
220
388
  if "custom_inputs" in input_data:
221
389
  custom_inputs: dict[str, Any] = input_data.pop("custom_inputs")
222
390
  if "configurable" in custom_inputs:
223
- configurable: dict[str, Any] = custom_inputs.pop("configurable")
224
-
225
- if "user_id" in configurable:
226
- configurable["user_id"] = configurable["user_id"].replace(".", "_")
227
-
228
- if "conversation_id" in configurable and "thread_id" not in configurable:
229
- configurable["thread_id"] = configurable["conversation_id"]
230
-
231
- if "thread_id" not in configurable:
232
- configurable["thread_id"] = str(uuid.uuid4())
233
-
234
- context: Context = Context(**configurable)
235
- return context
391
+ configurable = custom_inputs.pop("configurable")
392
+
393
+ # Extract known Context fields
394
+ user_id: str | None = configurable.pop("user_id", None)
395
+ if user_id:
396
+ user_id = user_id.replace(".", "_")
397
+
398
+ # Accept either thread_id or conversation_id (interchangeable)
399
+ # conversation_id takes precedence (Databricks vocabulary)
400
+ thread_id: str | None = configurable.pop("thread_id", None)
401
+ conversation_id: str | None = configurable.pop("conversation_id", None)
402
+
403
+ # conversation_id takes precedence if both provided
404
+ if conversation_id:
405
+ thread_id = conversation_id
406
+ if not thread_id:
407
+ thread_id = str(uuid.uuid4())
408
+
409
+ # All remaining configurable values become top-level context attributes
410
+ return Context(
411
+ user_id=user_id,
412
+ thread_id=thread_id,
413
+ **configurable, # Extra fields become top-level attributes
414
+ )
236
415
 
237
416
  def predict_stream(
238
417
  self, context, messages: list[ChatMessage], params: ChatParams
239
418
  ) -> Generator[ChatCompletionChunk, None, None]:
240
- logger.debug(f"messages: {messages}, params: {params}")
419
+ logger.trace(
420
+ "Predict stream called",
421
+ messages_count=len(messages),
422
+ has_params=params is not None,
423
+ )
241
424
  if not messages:
242
425
  raise ValueError("Message list is empty.")
243
426
 
@@ -261,7 +444,10 @@ class LanggraphChatModel(ChatModel):
261
444
  stream_mode: str
262
445
  messages_batch: Sequence[BaseMessage]
263
446
  logger.trace(
264
- f"nodes: {nodes}, stream_mode: {stream_mode}, messages: {messages_batch}"
447
+ "Stream batch received",
448
+ nodes=nodes,
449
+ stream_mode=stream_mode,
450
+ messages_count=len(messages_batch),
265
451
  )
266
452
  for message in messages_batch:
267
453
  if (
@@ -307,6 +493,324 @@ class LanggraphChatModel(ChatModel):
307
493
  return [m.to_dict() for m in messages]
308
494
 
309
495
 
496
+ def _create_decision_schema(interrupt_data: list[HITLRequest]) -> type[BaseModel]:
497
+ """
498
+ Dynamically create a Pydantic model for structured output based on interrupt actions.
499
+
500
+ This creates a schema that matches the expected decision format for the interrupted actions.
501
+ Each action gets a corresponding decision field that can be approve, edit, or reject.
502
+ Includes validation fields to ensure the response is complete and valid.
503
+
504
+ Args:
505
+ interrupt_data: List of HITL interrupt requests containing action_requests and review_configs
506
+
507
+ Returns:
508
+ A dynamically created Pydantic BaseModel class for structured output
509
+
510
+ Example:
511
+ For two actions (send_email, execute_sql), creates a model like:
512
+ class Decisions(BaseModel):
513
+ is_valid: bool
514
+ validation_message: Optional[str]
515
+ decision_1: Literal["approve", "edit", "reject"]
516
+ decision_1_message: Optional[str] # For reject
517
+ decision_1_edited_args: Optional[dict] # For edit
518
+ decision_2: Literal["approve", "edit", "reject"]
519
+ ...
520
+ """
521
+ # Collect all actions
522
+ all_actions: list[ActionRequest] = []
523
+ review_configs_map: dict[str, ReviewConfig] = {}
524
+
525
+ for hitl_request in interrupt_data:
526
+ all_actions.extend(hitl_request.get("action_requests", []))
527
+ review_config: ReviewConfig
528
+ for review_config in hitl_request.get("review_configs", []):
529
+ action_name: str = review_config.get("action_name", "")
530
+ if action_name:
531
+ review_configs_map[action_name] = review_config
532
+
533
+ # Build fields for the dynamic model
534
+ # Start with validation fields
535
+ fields: dict[str, Any] = {
536
+ "is_valid": (
537
+ bool,
538
+ Field(
539
+ description="Whether the user's response provides valid decisions for ALL actions. "
540
+ "Set to False if the user's message is unclear, ambiguous, or doesn't provide decisions for all actions."
541
+ ),
542
+ ),
543
+ "validation_message": (
544
+ Optional[str],
545
+ Field(
546
+ None,
547
+ description="If is_valid is False, explain what is missing or unclear. "
548
+ "Be specific about which action(s) need clarification.",
549
+ ),
550
+ ),
551
+ }
552
+
553
+ i: int
554
+ action: ActionRequest
555
+ for i, action in enumerate(all_actions, 1):
556
+ tool_name: str = action.get("name", "unknown")
557
+ review_config: Optional[ReviewConfig] = review_configs_map.get(tool_name)
558
+ allowed_decisions: list[str] = (
559
+ review_config.get("allowed_decisions", ["approve", "reject"])
560
+ if review_config
561
+ else ["approve", "reject"]
562
+ )
563
+
564
+ # Create a Literal type for allowed decisions
565
+ decision_literal: type = Literal[tuple(allowed_decisions)] # type: ignore
566
+
567
+ # Add decision field
568
+ fields[f"decision_{i}"] = (
569
+ decision_literal,
570
+ Field(
571
+ description=f"Decision for action {i} ({tool_name}): {', '.join(allowed_decisions)}"
572
+ ),
573
+ )
574
+
575
+ # Add optional message field for reject
576
+ if "reject" in allowed_decisions:
577
+ fields[f"decision_{i}_message"] = (
578
+ Optional[str],
579
+ Field(
580
+ None,
581
+ description=f"Optional message if rejecting action {i}",
582
+ ),
583
+ )
584
+
585
+ # Add optional edited_args field for edit
586
+ if "edit" in allowed_decisions:
587
+ fields[f"decision_{i}_edited_args"] = (
588
+ Optional[dict[str, Any]],
589
+ Field(
590
+ None,
591
+ description=f"Modified arguments if editing action {i}. Only provide fields that need to change.",
592
+ ),
593
+ )
594
+
595
+ # Create the dynamic model
596
+ DecisionsModel = create_model(
597
+ "InterruptDecisions",
598
+ __doc__="Decisions for each interrupted action, in order.",
599
+ **fields,
600
+ )
601
+
602
+ return DecisionsModel
603
+
604
+
605
+ def _convert_schema_to_decisions(
606
+ parsed_output: BaseModel,
607
+ interrupt_data: list[HITLRequest],
608
+ ) -> list[Decision]:
609
+ """
610
+ Convert the parsed structured output into LangChain Decision objects.
611
+
612
+ Args:
613
+ parsed_output: The Pydantic model instance from structured output
614
+ interrupt_data: Original interrupt data for context
615
+
616
+ Returns:
617
+ List of Decision dictionaries compatible with Command(resume={"decisions": ...})
618
+ """
619
+ # Collect all actions to know how many decisions we need
620
+ all_actions: list[ActionRequest] = []
621
+ hitl_request: HITLRequest
622
+ for hitl_request in interrupt_data:
623
+ all_actions.extend(hitl_request.get("action_requests", []))
624
+
625
+ decisions: list[Decision] = []
626
+
627
+ i: int
628
+ for i in range(1, len(all_actions) + 1):
629
+ decision_type: str = getattr(parsed_output, f"decision_{i}")
630
+
631
+ if decision_type == "approve":
632
+ decisions.append({"type": "approve"}) # type: ignore
633
+ elif decision_type == "reject":
634
+ message: Optional[str] = getattr(
635
+ parsed_output, f"decision_{i}_message", None
636
+ )
637
+ reject_decision: RejectDecision = {"type": "reject"}
638
+ if message:
639
+ reject_decision["message"] = message
640
+ decisions.append(reject_decision) # type: ignore
641
+ elif decision_type == "edit":
642
+ edited_args: Optional[dict[str, Any]] = getattr(
643
+ parsed_output, f"decision_{i}_edited_args", None
644
+ )
645
+ action: ActionRequest = all_actions[i - 1]
646
+ tool_name: str = action.get("name", "")
647
+ original_args: dict[str, Any] = action.get("args", {})
648
+
649
+ # Merge original args with edited args
650
+ final_args: dict[str, Any] = {**original_args, **(edited_args or {})}
651
+
652
+ edit_decision: EditDecision = {
653
+ "type": "edit",
654
+ "edited_action": {
655
+ "name": tool_name,
656
+ "args": final_args,
657
+ },
658
+ }
659
+ decisions.append(edit_decision) # type: ignore
660
+
661
+ return decisions
662
+
663
+
664
+ def handle_interrupt_response(
665
+ snapshot: StateSnapshot,
666
+ messages: list[BaseMessage],
667
+ model: Optional[LanguageModelLike] = None,
668
+ ) -> dict[str, Any]:
669
+ """
670
+ Parse user's natural language response to interrupts using LLM with structured output.
671
+
672
+ This function uses an LLM to understand the user's intent and extract structured decisions
673
+ for each pending action. The schema is dynamically created based on the pending actions.
674
+ Includes validation to ensure the response is complete and valid.
675
+
676
+ Args:
677
+ snapshot: The current state snapshot containing interrupts
678
+ messages: List of messages, from which the last human message will be extracted
679
+ model: Optional LLM to use for parsing. Defaults to ChatDatabricks with the databricks-claude-sonnet-4 endpoint
680
+
681
+ Returns:
682
+ Dictionary with:
683
+ - "is_valid": bool indicating if the response is valid
684
+ - "validation_message": Optional message if invalid, explaining what's missing
685
+ - "decisions": list of Decision objects (empty if invalid)
686
+
687
+ Example:
688
+ Valid: {"is_valid": True, "validation_message": None, "decisions": [{"type": "approve"}]}
689
+ Invalid: {"is_valid": False, "validation_message": "Please specify...", "decisions": []}
690
+ """
691
+ # Extract the last human message
692
+ user_message_obj: Optional[HumanMessage] = last_human_message(messages)
693
+
694
+ if not user_message_obj:
695
+ logger.warning("HITL: No human message found in interrupt response")
696
+ return {
697
+ "is_valid": False,
698
+ "validation_message": "No user message found. Please provide a response to the pending action(s).",
699
+ "decisions": [],
700
+ }
701
+
702
+ user_message: str = str(user_message_obj.content)
703
+ logger.info(
704
+ "HITL: Parsing user interrupt response", message_preview=user_message[:100]
705
+ )
706
+
707
+ if not model:
708
+ model = ChatDatabricks(
709
+ endpoint="databricks-claude-sonnet-4",
710
+ temperature=0,
711
+ )
712
+
713
+ # Extract interrupt data
714
+ if not snapshot.interrupts:
715
+ logger.warning("HITL: No interrupts found in snapshot")
716
+ return {"decisions": []}
717
+
718
+ interrupt_data: list[HITLRequest] = [
719
+ _extract_interrupt_value(interrupt) for interrupt in snapshot.interrupts
720
+ ]
721
+
722
+ # Collect all actions for context
723
+ all_actions: list[ActionRequest] = []
724
+ hitl_request: HITLRequest
725
+ for hitl_request in interrupt_data:
726
+ all_actions.extend(hitl_request.get("action_requests", []))
727
+
728
+ if not all_actions:
729
+ logger.warning("HITL: No actions found in interrupts")
730
+ return {"decisions": []}
731
+
732
+ # Create dynamic schema
733
+ DecisionsModel: type[BaseModel] = _create_decision_schema(interrupt_data)
734
+
735
+ # Create structured LLM
736
+ structured_llm: LanguageModelLike = model.with_structured_output(DecisionsModel)
737
+
738
+ # Format action context for the LLM
739
+ action_descriptions: list[str] = []
740
+ i: int
741
+ action: ActionRequest
742
+ for i, action in enumerate(all_actions, 1):
743
+ tool_name: str = action.get("name", "unknown")
744
+ args: dict[str, Any] = action.get("args", {})
745
+ args_str: str = (
746
+ ", ".join(f"{k}={v}" for k, v in args.items()) if args else "(no args)"
747
+ )
748
+ action_descriptions.append(f"Action {i}: {tool_name}({args_str})")
749
+
750
+ system_prompt: str = f"""You are parsing a user's response to interrupted agent actions.
751
+
752
+ The following actions are pending approval:
753
+ {chr(10).join(action_descriptions)}
754
+
755
+ Your task is to extract the user's decision for EACH action in order. The user may:
756
+ - Approve: Accept the action as-is
757
+ - Reject: Cancel the action (optionally with a reason/message)
758
+ - Edit: Modify the arguments before executing
759
+
760
+ VALIDATION:
761
+ - Set is_valid=True only if you can confidently extract decisions for ALL actions
762
+ - Set is_valid=False if the user's message is:
763
+ * Unclear or ambiguous
764
+ * Missing decisions for some actions
765
+ * Asking a question instead of providing decisions
766
+ * Not addressing the actions at all
767
+ - If is_valid=False, provide a clear validation_message explaining what is needed
768
+
769
+ FLEXIBILITY:
770
+ - Be flexible in parsing informal language like "yes", "no", "ok", "change X to Y"
771
+ - If the user doesn't explicitly mention an action, assume they want to approve it
772
+ - Only mark as invalid if the message is genuinely unclear or incomplete"""
773
+
774
+ try:
775
+ # Invoke LLM with structured output
776
+ parsed: BaseModel = structured_llm.invoke(
777
+ [
778
+ SystemMessage(content=system_prompt),
779
+ HumanMessage(content=user_message),
780
+ ]
781
+ )
782
+
783
+ # Check validation first
784
+ is_valid: bool = getattr(parsed, "is_valid", True)
785
+ validation_message: Optional[str] = getattr(parsed, "validation_message", None)
786
+
787
+ if not is_valid:
788
+ logger.warning(
789
+ "HITL: Invalid user response", reason=validation_message or "Unknown"
790
+ )
791
+ return {
792
+ "is_valid": False,
793
+ "validation_message": validation_message
794
+ or "Your response was unclear. Please provide a clear decision for each action.",
795
+ "decisions": [],
796
+ }
797
+
798
+ # Convert to Decision format
799
+ decisions: list[Decision] = _convert_schema_to_decisions(parsed, interrupt_data)
800
+
801
+ logger.info("HITL: Parsed interrupt decisions", decisions_count=len(decisions))
802
+ return {"is_valid": True, "validation_message": None, "decisions": decisions}
803
+
804
+ except Exception as e:
805
+ logger.error("HITL: Failed to parse interrupt response", error=str(e))
806
+ # Return invalid response on parsing failure
807
+ return {
808
+ "is_valid": False,
809
+ "validation_message": f"Failed to parse your response: {str(e)}. Please provide a clear decision for each action.",
810
+ "decisions": [],
811
+ }
812
+
813
+
310
814
  class LanggraphResponsesAgent(ResponsesAgent):
311
815
  """
312
816
  ResponsesAgent that delegates requests to a LangGraph CompiledStateGraph.
@@ -315,38 +819,191 @@ class LanggraphResponsesAgent(ResponsesAgent):
315
819
  support for streaming, tool calling, and async execution.
316
820
  """
317
821
 
318
- def __init__(self, graph: CompiledStateGraph) -> None:
822
+ def __init__(
823
+ self,
824
+ graph: CompiledStateGraph,
825
+ ) -> None:
319
826
  self.graph = graph
320
827
 
321
828
  def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
322
829
  """
323
830
  Process a ResponsesAgentRequest and return a ResponsesAgentResponse.
831
+
832
+ Input structure (custom_inputs):
833
+ configurable:
834
+ thread_id: "abc-123" # Or conversation_id (aliases, conversation_id takes precedence)
835
+ user_id: "nate.fleming"
836
+ store_num: "87887"
837
+ session: # Paste from previous output
838
+ conversation_id: "abc-123" # Alias of thread_id
839
+ genie:
840
+ spaces:
841
+ space_123: {conversation_id: "conv_456", ...}
842
+ decisions: # For resuming interrupted graphs (HITL)
843
+ - type: "approve"
844
+ - type: "reject"
845
+ message: "Not authorized"
846
+
847
+ Output structure (custom_outputs):
848
+ configurable:
849
+ thread_id: "abc-123" # Only thread_id in configurable
850
+ user_id: "nate.fleming"
851
+ store_num: "87887"
852
+ session:
853
+ conversation_id: "abc-123" # conversation_id in session
854
+ genie:
855
+ spaces:
856
+ space_123: {conversation_id: "conv_456", ...}
857
+ pending_actions: # If HITL interrupt occurred
858
+ - name: "send_email"
859
+ arguments: {...}
860
+ description: "..."
324
861
  """
325
- logger.debug(f"ResponsesAgent request: {request}")
862
+ # Extract conversation_id for logging (from context or custom_inputs)
863
+ conversation_id_for_log: str | None = None
864
+ if request.context and hasattr(request.context, "conversation_id"):
865
+ conversation_id_for_log = request.context.conversation_id
866
+ elif request.custom_inputs:
867
+ # Check configurable or session for conversation_id
868
+ if "configurable" in request.custom_inputs and isinstance(
869
+ request.custom_inputs["configurable"], dict
870
+ ):
871
+ conversation_id_for_log = request.custom_inputs["configurable"].get(
872
+ "conversation_id"
873
+ )
874
+ if (
875
+ conversation_id_for_log is None
876
+ and "session" in request.custom_inputs
877
+ and isinstance(request.custom_inputs["session"], dict)
878
+ ):
879
+ conversation_id_for_log = request.custom_inputs["session"].get(
880
+ "conversation_id"
881
+ )
882
+
883
+ logger.debug(
884
+ "ResponsesAgent predict called",
885
+ conversation_id=conversation_id_for_log
886
+ if conversation_id_for_log
887
+ else "new",
888
+ )
326
889
 
327
890
  # Convert ResponsesAgent input to LangChain messages
328
- messages = self._convert_request_to_langchain_messages(request)
891
+ messages: list[dict[str, Any]] = self._convert_request_to_langchain_messages(
892
+ request
893
+ )
329
894
 
330
- # Prepare context
895
+ # Prepare context (conversation_id -> thread_id mapping happens here)
331
896
  context: Context = self._convert_request_to_context(request)
332
897
  custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}
333
898
 
899
+ # Extract session state from request
900
+ session_input: dict[str, Any] = self._extract_session_from_request(request)
901
+
334
902
  # Use async ainvoke internally for parallel execution
335
903
  import asyncio
336
904
 
905
+ from langgraph.types import Command
906
+
337
907
  async def _async_invoke():
338
908
  try:
909
+ # Check if this is a resume request (HITL)
910
+ # Two ways to resume:
911
+ # 1. Explicit decisions in custom_inputs (structured)
912
+ # 2. Natural language message when graph is interrupted (LLM-parsed)
913
+
914
+ if request.custom_inputs and "decisions" in request.custom_inputs:
915
+ # Explicit structured decisions
916
+ decisions: list[Decision] = request.custom_inputs["decisions"]
917
+ logger.info(
918
+ "HITL: Resuming with explicit decisions",
919
+ decisions_count=len(decisions),
920
+ )
921
+
922
+ # Resume interrupted graph with decisions
923
+ return await self.graph.ainvoke(
924
+ Command(resume={"decisions": decisions}),
925
+ context=context,
926
+ config=custom_inputs,
927
+ )
928
+
929
+ # Check if graph is currently interrupted (only if checkpointer is configured)
930
+ # aget_state requires a checkpointer
931
+ if self.graph.checkpointer:
932
+ snapshot: StateSnapshot = await self.graph.aget_state(
933
+ config=custom_inputs
934
+ )
935
+ if is_interrupted(snapshot):
936
+ logger.info(
937
+ "HITL: Graph interrupted, checking for user response"
938
+ )
939
+
940
+ # Convert message dicts to BaseMessage objects
941
+ message_objects: list[BaseMessage] = convert_openai_messages(
942
+ messages
943
+ )
944
+
945
+ # Parse user's message with LLM to extract decisions
946
+ parsed_result: dict[str, Any] = handle_interrupt_response(
947
+ snapshot=snapshot,
948
+ messages=message_objects,
949
+ model=None, # Uses default model
950
+ )
951
+
952
+ # Check if the response was valid
953
+ if not parsed_result.get("is_valid", False):
954
+ validation_message: str = parsed_result.get(
955
+ "validation_message",
956
+ "Your response was unclear. Please provide a clear decision for each action.",
957
+ )
958
+ logger.warning(
959
+ "HITL: Invalid response from user",
960
+ validation_message=validation_message,
961
+ )
962
+
963
+ # Return error message to user instead of resuming
964
+ # Don't resume the graph - stay interrupted so user can try again
965
+ return {
966
+ "messages": [
967
+ AIMessage(
968
+ content=f"❌ **Invalid Response**\n\n{validation_message}"
969
+ )
970
+ ]
971
+ }
972
+
973
+ decisions: list[Decision] = parsed_result.get("decisions", [])
974
+ logger.info(
975
+ "HITL: LLM parsed valid decisions from user message",
976
+ decisions_count=len(decisions),
977
+ )
978
+
979
+ # Resume interrupted graph with parsed decisions
980
+ return await self.graph.ainvoke(
981
+ Command(resume={"decisions": decisions}),
982
+ context=context,
983
+ config=custom_inputs,
984
+ )
985
+
986
+ # Normal invocation - build the graph input state
987
+ graph_input: dict[str, Any] = {"messages": messages}
988
+ if "genie_conversation_ids" in session_input:
989
+ graph_input["genie_conversation_ids"] = session_input[
990
+ "genie_conversation_ids"
991
+ ]
992
+ logger.trace(
993
+ "Including genie conversation IDs in graph input",
994
+ count=len(graph_input["genie_conversation_ids"]),
995
+ )
996
+
339
997
  return await self.graph.ainvoke(
340
- {"messages": messages}, context=context, config=custom_inputs
998
+ graph_input, context=context, config=custom_inputs
341
999
  )
342
1000
  except Exception as e:
343
- logger.error(f"Error in graph.ainvoke: {e}")
1001
+ logger.error("Error in graph invocation", error=str(e))
344
1002
  raise
345
1003
 
346
1004
  try:
347
1005
  loop = asyncio.get_event_loop()
348
1006
  except RuntimeError:
349
- # Handle case where no event loop exists (common in some deployment scenarios)
350
1007
  loop = asyncio.new_event_loop()
351
1008
  asyncio.set_event_loop(loop)
352
1009
 
@@ -355,28 +1012,93 @@ class LanggraphResponsesAgent(ResponsesAgent):
355
1012
  _async_invoke()
356
1013
  )
357
1014
  except Exception as e:
358
- logger.error(f"Error in async execution: {e}")
1015
+ logger.error("Error in async execution", error=str(e))
359
1016
  raise
360
1017
 
361
1018
  # Convert response to ResponsesAgent format
362
1019
  last_message: BaseMessage = response["messages"][-1]
363
1020
 
364
- output_item = self.create_text_output_item(
365
- text=last_message.content, id=f"msg_{uuid.uuid4().hex[:8]}"
1021
+ # Build custom_outputs that can be copy-pasted as next request's custom_inputs
1022
+ custom_outputs: dict[str, Any] = self._build_custom_outputs(
1023
+ context=context,
1024
+ thread_id=context.thread_id,
1025
+ loop=loop,
366
1026
  )
367
1027
 
368
- # Retrieve genie_conversation_ids from state if available
369
- custom_outputs: dict[str, Any] = custom_inputs.copy()
370
- thread_id: Optional[str] = context.thread_id
371
- if thread_id:
372
- state_snapshot: Optional[StateSnapshot] = loop.run_until_complete(
373
- get_state_snapshot_async(self.graph, thread_id)
1028
+ # Handle structured_response if present
1029
+ if "structured_response" in response:
1030
+ from dataclasses import asdict, is_dataclass
1031
+
1032
+ from pydantic import BaseModel
1033
+
1034
+ structured_response = response["structured_response"]
1035
+ logger.trace(
1036
+ "Processing structured response",
1037
+ response_type=type(structured_response).__name__,
374
1038
  )
375
- genie_conversation_ids: dict[str, str] = (
376
- get_genie_conversation_ids_from_state(state_snapshot)
1039
+
1040
+ # Serialize to dict for JSON compatibility using type hints
1041
+ if isinstance(structured_response, BaseModel):
1042
+ # Pydantic model
1043
+ serialized: dict[str, Any] = structured_response.model_dump()
1044
+ elif is_dataclass(structured_response):
1045
+ # Dataclass
1046
+ serialized = asdict(structured_response)
1047
+ elif isinstance(structured_response, dict):
1048
+ # Already a dict
1049
+ serialized = structured_response
1050
+ else:
1051
+ # Unknown type, convert to dict if possible
1052
+ serialized = (
1053
+ dict(structured_response)
1054
+ if hasattr(structured_response, "__dict__")
1055
+ else structured_response
1056
+ )
1057
+
1058
+ # Place structured output in message content as JSON
1059
+ import json
1060
+
1061
+ structured_text: str = json.dumps(serialized, indent=2)
1062
+ output_item = self.create_text_output_item(
1063
+ text=structured_text, id=f"msg_{uuid.uuid4().hex[:8]}"
377
1064
  )
378
- if genie_conversation_ids:
379
- custom_outputs["genie_conversation_ids"] = genie_conversation_ids
1065
+ logger.trace("Structured response placed in message content")
1066
+ else:
1067
+ # No structured response, use text content
1068
+ output_item = self.create_text_output_item(
1069
+ text=last_message.content, id=f"msg_{uuid.uuid4().hex[:8]}"
1070
+ )
1071
+
1072
+ # Include interrupt structure if HITL occurred (following LangChain pattern)
1073
+ if "__interrupt__" in response:
1074
+ interrupts: list[Interrupt] = response["__interrupt__"]
1075
+ logger.info("HITL: Interrupts detected", interrupts_count=len(interrupts))
1076
+
1077
+ # Extract HITLRequest structures from interrupts (deduplicate by ID)
1078
+ seen_interrupt_ids: set[str] = set()
1079
+ interrupt_data: list[HITLRequest] = []
1080
+ interrupt: Interrupt
1081
+ for interrupt in interrupts:
1082
+ # Only process each unique interrupt once
1083
+ if interrupt.id not in seen_interrupt_ids:
1084
+ seen_interrupt_ids.add(interrupt.id)
1085
+ interrupt_data.append(_extract_interrupt_value(interrupt))
1086
+ logger.trace(
1087
+ "HITL: Added interrupt to response", interrupt_id=interrupt.id
1088
+ )
1089
+
1090
+ custom_outputs["interrupts"] = interrupt_data
1091
+ logger.debug(
1092
+ "HITL: Included interrupts in response",
1093
+ interrupts_count=len(interrupt_data),
1094
+ )
1095
+
1096
+ # Add user-facing message about the pending actions
1097
+ action_message: str = _format_action_requests_message(interrupt_data)
1098
+ if action_message:
1099
+ output_item = self.create_text_output_item(
1100
+ text=action_message, id=f"msg_{uuid.uuid4().hex[:8]}"
1101
+ )
380
1102
 
381
1103
  return ResponsesAgentResponse(
382
1104
  output=[output_item], custom_outputs=custom_outputs
@@ -387,90 +1109,354 @@ class LanggraphResponsesAgent(ResponsesAgent):
387
1109
  ) -> Generator[ResponsesAgentStreamEvent, None, None]:
388
1110
  """
389
1111
  Process a ResponsesAgentRequest and yield ResponsesAgentStreamEvent objects.
1112
+
1113
+ Uses same input/output structure as predict() for consistency.
1114
+ Supports Human-in-the-Loop (HITL) interrupts.
390
1115
  """
391
- logger.debug(f"ResponsesAgent stream request: {request}")
1116
+ # Extract conversation_id for logging (from context or custom_inputs)
1117
+ conversation_id_for_log: str | None = None
1118
+ if request.context and hasattr(request.context, "conversation_id"):
1119
+ conversation_id_for_log = request.context.conversation_id
1120
+ elif request.custom_inputs:
1121
+ # Check configurable or session for conversation_id
1122
+ if "configurable" in request.custom_inputs and isinstance(
1123
+ request.custom_inputs["configurable"], dict
1124
+ ):
1125
+ conversation_id_for_log = request.custom_inputs["configurable"].get(
1126
+ "conversation_id"
1127
+ )
1128
+ if (
1129
+ conversation_id_for_log is None
1130
+ and "session" in request.custom_inputs
1131
+ and isinstance(request.custom_inputs["session"], dict)
1132
+ ):
1133
+ conversation_id_for_log = request.custom_inputs["session"].get(
1134
+ "conversation_id"
1135
+ )
1136
+
1137
+ logger.debug(
1138
+ "ResponsesAgent predict_stream called",
1139
+ conversation_id=conversation_id_for_log
1140
+ if conversation_id_for_log
1141
+ else "new",
1142
+ )
392
1143
 
393
1144
  # Convert ResponsesAgent input to LangChain messages
394
- messages: list[BaseMessage] = self._convert_request_to_langchain_messages(
1145
+ messages: list[dict[str, Any]] = self._convert_request_to_langchain_messages(
395
1146
  request
396
1147
  )
397
1148
 
398
- # Prepare context
1149
+ # Prepare context (conversation_id -> thread_id mapping happens here)
399
1150
  context: Context = self._convert_request_to_context(request)
400
1151
  custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}
401
1152
 
1153
+ # Extract session state from request
1154
+ session_input: dict[str, Any] = self._extract_session_from_request(request)
1155
+
402
1156
  # Use async astream internally for parallel execution
403
1157
  import asyncio
404
1158
 
1159
+ from langgraph.types import Command
1160
+
405
1161
  async def _async_stream():
406
- item_id = f"msg_{uuid.uuid4().hex[:8]}"
407
- accumulated_content = ""
1162
+ item_id: str = f"msg_{uuid.uuid4().hex[:8]}"
1163
+ accumulated_content: str = ""
1164
+ interrupt_data: list[HITLRequest] = []
1165
+ seen_interrupt_ids: set[str] = set() # Track processed interrupt IDs
1166
+ structured_response: Any = None # Track structured output from stream
408
1167
 
409
1168
  try:
410
- async for nodes, stream_mode, messages_batch in self.graph.astream(
411
- {"messages": messages},
1169
+ # Check if this is a resume request (HITL)
1170
+ # Two ways to resume:
1171
+ # 1. Explicit decisions in custom_inputs (structured)
1172
+ # 2. Natural language message when graph is interrupted (LLM-parsed)
1173
+
1174
+ if request.custom_inputs and "decisions" in request.custom_inputs:
1175
+ # Explicit structured decisions
1176
+ decisions: list[Decision] = request.custom_inputs["decisions"]
1177
+ logger.info(
1178
+ "HITL: Resuming stream with explicit decisions",
1179
+ decisions_count=len(decisions),
1180
+ )
1181
+ stream_input: Command | dict[str, Any] = Command(
1182
+ resume={"decisions": decisions}
1183
+ )
1184
+ elif self.graph.checkpointer:
1185
+ # Check if graph is currently interrupted (only if checkpointer is configured)
1186
+ # aget_state requires a checkpointer
1187
+ snapshot: StateSnapshot = await self.graph.aget_state(
1188
+ config=custom_inputs
1189
+ )
1190
+ if is_interrupted(snapshot):
1191
+ logger.info(
1192
+ "HITL: Graph interrupted, checking for user response in stream"
1193
+ )
1194
+
1195
+ # Convert message dicts to BaseMessage objects
1196
+ message_objects: list[BaseMessage] = convert_openai_messages(
1197
+ messages
1198
+ )
1199
+
1200
+ # Parse user's message with LLM to extract decisions
1201
+ parsed_result: dict[str, Any] = handle_interrupt_response(
1202
+ snapshot=snapshot,
1203
+ messages=message_objects,
1204
+ model=None, # Uses default model
1205
+ )
1206
+
1207
+ # Check if the response was valid
1208
+ if not parsed_result.get("is_valid", False):
1209
+ validation_message: str = parsed_result.get(
1210
+ "validation_message",
1211
+ "Your response was unclear. Please provide a clear decision for each action.",
1212
+ )
1213
+ logger.warning(
1214
+ "HITL: Invalid response from user in stream",
1215
+ validation_message=validation_message,
1216
+ )
1217
+
1218
+ # Build custom_outputs before returning
1219
+ custom_outputs: dict[
1220
+ str, Any
1221
+ ] = await self._build_custom_outputs_async(
1222
+ context=context,
1223
+ thread_id=context.thread_id,
1224
+ )
1225
+
1226
+ # Yield error message to user - don't resume graph
1227
+ error_message: str = (
1228
+ f"❌ **Invalid Response**\n\n{validation_message}"
1229
+ )
1230
+ accumulated_content = error_message
1231
+ yield ResponsesAgentStreamEvent(
1232
+ type="response.output_item.done",
1233
+ item=self.create_text_output_item(
1234
+ text=error_message, id=item_id
1235
+ ),
1236
+ custom_outputs=custom_outputs,
1237
+ )
1238
+ return # Don't resume - stay interrupted
1239
+
1240
+ decisions: list[Decision] = parsed_result.get("decisions", [])
1241
+ logger.info(
1242
+ "HITL: LLM parsed valid decisions from user message in stream",
1243
+ decisions_count=len(decisions),
1244
+ )
1245
+
1246
+ # Resume interrupted graph with parsed decisions
1247
+ stream_input: Command | dict[str, Any] = Command(
1248
+ resume={"decisions": decisions}
1249
+ )
1250
+ else:
1251
+ # Graph not interrupted, use normal invocation
1252
+ graph_input: dict[str, Any] = {"messages": messages}
1253
+ if "genie_conversation_ids" in session_input:
1254
+ graph_input["genie_conversation_ids"] = session_input[
1255
+ "genie_conversation_ids"
1256
+ ]
1257
+ stream_input: Command | dict[str, Any] = graph_input
1258
+ else:
1259
+ # No checkpointer, use normal invocation
1260
+ graph_input: dict[str, Any] = {"messages": messages}
1261
+ if "genie_conversation_ids" in session_input:
1262
+ graph_input["genie_conversation_ids"] = session_input[
1263
+ "genie_conversation_ids"
1264
+ ]
1265
+ stream_input: Command | dict[str, Any] = graph_input
1266
+
1267
+ # Stream the graph execution with both messages and updates modes to capture interrupts
1268
+ async for nodes, stream_mode, data in self.graph.astream(
1269
+ stream_input,
412
1270
  context=context,
413
1271
  config=custom_inputs,
414
- stream_mode=["messages", "custom"],
1272
+ stream_mode=["messages", "updates"],
415
1273
  subgraphs=True,
416
1274
  ):
417
1275
  nodes: tuple[str, ...]
418
1276
  stream_mode: str
419
- messages_batch: Sequence[BaseMessage]
420
1277
 
421
- for message in messages_batch:
422
- if (
423
- isinstance(
424
- message,
425
- (
426
- AIMessageChunk,
427
- AIMessage,
428
- ),
429
- )
430
- and message.content
431
- and "summarization" not in nodes
432
- ):
433
- content = message.content
434
- accumulated_content += content
1278
+ # Handle message streaming
1279
+ if stream_mode == "messages":
1280
+ messages_batch: Sequence[BaseMessage] = data
1281
+ message: BaseMessage
1282
+ for message in messages_batch:
1283
+ if (
1284
+ isinstance(
1285
+ message,
1286
+ (
1287
+ AIMessageChunk,
1288
+ AIMessage,
1289
+ ),
1290
+ )
1291
+ and message.content
1292
+ and "summarization" not in nodes
1293
+ ):
1294
+ content: str = message.content
1295
+ accumulated_content += content
1296
+
1297
+ # Yield streaming delta
1298
+ yield ResponsesAgentStreamEvent(
1299
+ **self.create_text_delta(
1300
+ delta=content, item_id=item_id
1301
+ )
1302
+ )
1303
+
1304
+ # Handle interrupts (HITL) and state updates
1305
+ elif stream_mode == "updates":
1306
+ updates: dict[str, Any] = data
1307
+ source: str
1308
+ update: Any
1309
+ for source, update in updates.items():
1310
+ if source == "__interrupt__":
1311
+ interrupts: list[Interrupt] = update
1312
+ logger.info(
1313
+ "HITL: Interrupts detected during streaming",
1314
+ interrupts_count=len(interrupts),
1315
+ )
1316
+
1317
+ # Extract interrupt values (deduplicate by ID)
1318
+ interrupt: Interrupt
1319
+ for interrupt in interrupts:
1320
+ # Only process each unique interrupt once
1321
+ if interrupt.id not in seen_interrupt_ids:
1322
+ seen_interrupt_ids.add(interrupt.id)
1323
+ interrupt_data.append(
1324
+ _extract_interrupt_value(interrupt)
1325
+ )
1326
+ logger.trace(
1327
+ "HITL: Added interrupt to response",
1328
+ interrupt_id=interrupt.id,
1329
+ )
1330
+ elif (
1331
+ isinstance(update, dict)
1332
+ and "structured_response" in update
1333
+ ):
1334
+ # Capture structured_response from stream updates
1335
+ structured_response = update["structured_response"]
1336
+ logger.trace(
1337
+ "Captured structured response from stream",
1338
+ response_type=type(structured_response).__name__,
1339
+ )
1340
+
1341
+ # Get final state to extract structured_response (only if checkpointer available)
1342
+ if self.graph.checkpointer:
1343
+ final_state: StateSnapshot = await self.graph.aget_state(
1344
+ config=custom_inputs
1345
+ )
1346
+ # Extract structured_response from state if not already captured
1347
+ if (
1348
+ "structured_response" in final_state.values
1349
+ and not structured_response
1350
+ ):
1351
+ structured_response = final_state.values["structured_response"]
435
1352
 
436
- # Yield streaming delta
437
- yield ResponsesAgentStreamEvent(
438
- **self.create_text_delta(delta=content, item_id=item_id)
439
- )
1353
+ # Build custom_outputs
1354
+ custom_outputs: dict[str, Any] = await self._build_custom_outputs_async(
1355
+ context=context,
1356
+ thread_id=context.thread_id,
1357
+ )
440
1358
 
441
- # Retrieve genie_conversation_ids from state if available
442
- custom_outputs: dict[str, Any] = custom_inputs.copy()
443
- thread_id: Optional[str] = context.thread_id
1359
+ # Handle structured_response in streaming if present
1360
+ output_text: str = accumulated_content
1361
+ if structured_response:
1362
+ from dataclasses import asdict, is_dataclass
444
1363
 
445
- if thread_id:
446
- state_snapshot: Optional[
447
- StateSnapshot
448
- ] = await get_state_snapshot_async(self.graph, thread_id)
449
- genie_conversation_ids: dict[str, str] = (
450
- get_genie_conversation_ids_from_state(state_snapshot)
1364
+ from pydantic import BaseModel
1365
+
1366
+ logger.trace(
1367
+ "Processing structured response in streaming",
1368
+ response_type=type(structured_response).__name__,
451
1369
  )
452
- if genie_conversation_ids:
453
- custom_outputs["genie_conversation_ids"] = (
454
- genie_conversation_ids
1370
+
1371
+ # Serialize to dict for JSON compatibility using type hints
1372
+ if isinstance(structured_response, BaseModel):
1373
+ serialized: dict[str, Any] = structured_response.model_dump()
1374
+ elif is_dataclass(structured_response):
1375
+ serialized = asdict(structured_response)
1376
+ elif isinstance(structured_response, dict):
1377
+ serialized = structured_response
1378
+ else:
1379
+ serialized = (
1380
+ dict(structured_response)
1381
+ if hasattr(structured_response, "__dict__")
1382
+ else structured_response
1383
+ )
1384
+
1385
+ # Place structured output in message content - stream as JSON
1386
+ import json
1387
+
1388
+ structured_text: str = json.dumps(serialized, indent=2)
1389
+
1390
+ # If we streamed text, append structured; if no text, use structured only
1391
+ if accumulated_content.strip():
1392
+ # Stream separator and structured output
1393
+ yield ResponsesAgentStreamEvent(
1394
+ **self.create_text_delta(delta="\n\n", item_id=item_id)
1395
+ )
1396
+ yield ResponsesAgentStreamEvent(
1397
+ **self.create_text_delta(
1398
+ delta=structured_text, item_id=item_id
1399
+ )
1400
+ )
1401
+ output_text = f"{accumulated_content}\n\n{structured_text}"
1402
+ else:
1403
+ # No text content, stream structured output
1404
+ yield ResponsesAgentStreamEvent(
1405
+ **self.create_text_delta(
1406
+ delta=structured_text, item_id=item_id
1407
+ )
455
1408
  )
1409
+ output_text = structured_text
1410
+
1411
+ logger.trace("Streamed structured response in message content")
1412
+
1413
+ # Include interrupt structure if HITL occurred
1414
+ if interrupt_data:
1415
+ custom_outputs["interrupts"] = interrupt_data
1416
+ logger.info(
1417
+ "HITL: Included interrupts in streaming response",
1418
+ interrupts_count=len(interrupt_data),
1419
+ )
1420
+
1421
+ # Add user-facing message about the pending actions
1422
+ action_message = _format_action_requests_message(interrupt_data)
1423
+ if action_message:
1424
+ # If we haven't streamed any content yet, stream the action message
1425
+ if not accumulated_content:
1426
+ output_text = action_message
1427
+ # Stream the action message
1428
+ yield ResponsesAgentStreamEvent(
1429
+ **self.create_text_delta(
1430
+ delta=action_message, item_id=item_id
1431
+ )
1432
+ )
1433
+ else:
1434
+ # Append action message after accumulated content
1435
+ output_text = f"{accumulated_content}\n\n{action_message}"
1436
+ # Stream the separator and action message
1437
+ yield ResponsesAgentStreamEvent(
1438
+ **self.create_text_delta(delta="\n\n", item_id=item_id)
1439
+ )
1440
+ yield ResponsesAgentStreamEvent(
1441
+ **self.create_text_delta(
1442
+ delta=action_message, item_id=item_id
1443
+ )
1444
+ )
456
1445
 
457
1446
  # Yield final output item
458
1447
  yield ResponsesAgentStreamEvent(
459
1448
  type="response.output_item.done",
460
- item=self.create_text_output_item(
461
- text=accumulated_content, id=item_id
462
- ),
1449
+ item=self.create_text_output_item(text=output_text, id=item_id),
463
1450
  custom_outputs=custom_outputs,
464
1451
  )
465
1452
  except Exception as e:
466
- logger.error(f"Error in graph.astream: {e}")
1453
+ logger.error("Error in graph streaming", error=str(e))
467
1454
  raise
468
1455
 
469
1456
  # Convert async generator to sync generator
470
1457
  try:
471
1458
  loop = asyncio.get_event_loop()
472
1459
  except RuntimeError:
473
- # Handle case where no event loop exists (common in some deployment scenarios)
474
1460
  loop = asyncio.new_event_loop()
475
1461
  asyncio.set_event_loop(loop)
476
1462
 
@@ -484,13 +1470,13 @@ class LanggraphResponsesAgent(ResponsesAgent):
484
1470
  except StopAsyncIteration:
485
1471
  break
486
1472
  except Exception as e:
487
- logger.error(f"Error in streaming: {e}")
1473
+ logger.error("Error in streaming", error=str(e))
488
1474
  raise
489
1475
  finally:
490
1476
  try:
491
1477
  loop.run_until_complete(async_gen.aclose())
492
1478
  except Exception as e:
493
- logger.warning(f"Error closing async generator: {e}")
1479
+ logger.warning("Error closing async generator", error=str(e))
494
1480
 
495
1481
  def _extract_text_from_content(
496
1482
  self,
@@ -555,15 +1541,27 @@ class LanggraphResponsesAgent(ResponsesAgent):
555
1541
  return messages
556
1542
 
557
1543
  def _convert_request_to_context(self, request: ResponsesAgentRequest) -> Context:
558
- """Convert ResponsesAgent context to internal Context."""
1544
+ """Convert ResponsesAgent context to internal Context.
1545
+
1546
+ Handles the input structure:
1547
+ - custom_inputs.configurable: Configuration (thread_id, user_id, store_num, etc.)
1548
+ - custom_inputs.session: Accumulated state (conversation_id, genie conversations, etc.)
559
1549
 
560
- logger.debug(f"request.context: {request.context}")
561
- logger.debug(f"request.custom_inputs: {request.custom_inputs}")
1550
+ Maps conversation_id -> thread_id for LangGraph compatibility.
1551
+ conversation_id can be provided in either configurable or session.
1552
+ Normalizes user_id (replaces . with _) for memory namespace compatibility.
1553
+ """
1554
+ logger.trace(
1555
+ "Converting request to context",
1556
+ has_context=request.context is not None,
1557
+ has_custom_inputs=request.custom_inputs is not None,
1558
+ )
562
1559
 
563
1560
  configurable: dict[str, Any] = {}
1561
+ session: dict[str, Any] = {}
564
1562
 
565
1563
  # Process context values first (lower priority)
566
- # Use strong typing with forward-declared type hints instead of hasattr checks
1564
+ # These come from Databricks ResponsesAgent ChatContext
567
1565
  chat_context: Optional[ChatContext] = request.context
568
1566
  if chat_context is not None:
569
1567
  conversation_id: Optional[str] = chat_context.conversation_id
@@ -571,27 +1569,189 @@ class LanggraphResponsesAgent(ResponsesAgent):
571
1569
 
572
1570
  if conversation_id is not None:
573
1571
  configurable["conversation_id"] = conversation_id
574
- configurable["thread_id"] = conversation_id
575
1572
 
576
1573
  if user_id is not None:
577
1574
  configurable["user_id"] = user_id
578
1575
 
579
1576
  # Process custom_inputs after context so they can override context values (higher priority)
580
1577
  if request.custom_inputs:
1578
+ # Extract configurable section (user config)
581
1579
  if "configurable" in request.custom_inputs:
582
- configurable.update(request.custom_inputs.pop("configurable"))
1580
+ configurable.update(request.custom_inputs["configurable"])
1581
+
1582
+ # Extract session section
1583
+ if "session" in request.custom_inputs:
1584
+ session_input = request.custom_inputs["session"]
1585
+ if isinstance(session_input, dict):
1586
+ session = session_input
1587
+
1588
+ # Handle legacy flat structure (backwards compatibility)
1589
+ # If user passes keys directly in custom_inputs, merge them
1590
+ for key in list(request.custom_inputs.keys()):
1591
+ if key not in ("configurable", "session"):
1592
+ configurable[key] = request.custom_inputs[key]
1593
+
1594
+ # Extract known Context fields
1595
+ user_id_value: str | None = configurable.pop("user_id", None)
1596
+ if user_id_value:
1597
+ # Normalize user_id for memory namespace (replace . with _)
1598
+ user_id_value = user_id_value.replace(".", "_")
1599
+
1600
+ # Accept thread_id from configurable, or conversation_id from configurable or session
1601
+ # Priority: configurable.conversation_id > session.conversation_id > configurable.thread_id
1602
+ thread_id: str | None = configurable.pop("thread_id", None)
1603
+ conversation_id: str | None = configurable.pop("conversation_id", None)
1604
+
1605
+ # Also check session for conversation_id (the previous response's custom_outputs places it there)
1606
+ if conversation_id is None and "conversation_id" in session:
1607
+ conversation_id = session.get("conversation_id")
1608
+
1609
+ # conversation_id takes precedence if provided
1610
+ if conversation_id:
1611
+ thread_id = conversation_id
1612
+ if not thread_id:
1613
+ # Generate new thread_id if neither provided
1614
+ thread_id = str(uuid.uuid4())
1615
+
1616
+ # All remaining configurable values become top-level context attributes
1617
+ logger.trace(
1618
+ "Creating context",
1619
+ user_id=user_id_value,
1620
+ thread_id=thread_id,
1621
+ extra_keys=list(configurable.keys()) if configurable else [],
1622
+ )
1623
+
1624
+ return Context(
1625
+ user_id=user_id_value,
1626
+ thread_id=thread_id,
1627
+ **configurable, # Pass remaining configurable values as context attributes
1628
+ )
1629
+
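To make the docstring's input structure concrete, a hypothetical custom_inputs payload; all values are invented:

custom_inputs = {
    "configurable": {
        "thread_id": "abc-123",   # or rely on conversation_id below
        "user_id": "jane.doe",    # normalized to "jane_doe" for memory namespaces
        "store_num": "87887",     # extra keys become Context attributes
    },
    "session": {
        "conversation_id": "abc-123",  # also accepted here; takes precedence over thread_id
    },
}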
1630
+ def _extract_session_from_request(
1631
+ self, request: ResponsesAgentRequest
1632
+ ) -> dict[str, Any]:
1633
+ """Extract session state from request for passing to graph.
1634
+
1635
+ Handles:
1636
+ - New structure: custom_inputs.session.genie
1637
+ - Legacy structure: custom_inputs.genie_conversation_ids
1638
+ """
1639
+ session: dict[str, Any] = {}
1640
+
1641
+ if not request.custom_inputs:
1642
+ return session
1643
+
1644
+ # New structure: session.genie
1645
+ if "session" in request.custom_inputs:
1646
+ session_input = request.custom_inputs["session"]
1647
+ if isinstance(session_input, dict) and "genie" in session_input:
1648
+ genie_state = session_input["genie"]
1649
+ # Extract conversation IDs from the new structure
1650
+ if isinstance(genie_state, dict) and "spaces" in genie_state:
1651
+ genie_conversation_ids = {}
1652
+ for space_id, space_state in genie_state["spaces"].items():
1653
+ if (
1654
+ isinstance(space_state, dict)
1655
+ and "conversation_id" in space_state
1656
+ ):
1657
+ genie_conversation_ids[space_id] = space_state[
1658
+ "conversation_id"
1659
+ ]
1660
+ if genie_conversation_ids:
1661
+ session["genie_conversation_ids"] = genie_conversation_ids
1662
+
1663
+ # Legacy structure: genie_conversation_ids at top level
1664
+ if "genie_conversation_ids" in request.custom_inputs:
1665
+ session["genie_conversation_ids"] = request.custom_inputs[
1666
+ "genie_conversation_ids"
1667
+ ]
1668
+
1669
+ # Also check inside configurable for legacy support
1670
+ if "configurable" in request.custom_inputs:
1671
+ cfg = request.custom_inputs["configurable"]
1672
+ if isinstance(cfg, dict) and "genie_conversation_ids" in cfg:
1673
+ session["genie_conversation_ids"] = cfg["genie_conversation_ids"]
1674
+
1675
+ return session
1676
+
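An illustration of the flattening this helper performs, using invented identifiers:

custom_inputs = {
    "session": {
        "genie": {
            "spaces": {
                "space_123": {"conversation_id": "conv_456"},
            }
        }
    }
}
# _extract_session_from_request returns the flattened form consumed by the graph:
# {"genie_conversation_ids": {"space_123": "conv_456"}}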
1677
+ def _build_custom_outputs(
1678
+ self,
1679
+ context: Context,
1680
+ thread_id: Optional[str],
1681
+ loop: Any, # asyncio.AbstractEventLoop
1682
+ ) -> dict[str, Any]:
1683
+ """Build custom_outputs that can be copy-pasted as next request's custom_inputs.
1684
+
1685
+ Output structure:
1686
+ configurable:
1687
+ thread_id: "abc-123" # Thread identifier (conversation_id is alias)
1688
+ user_id: "nate.fleming" # De-normalized (no underscore replacement)
1689
+ store_num: "87887" # Any custom fields
1690
+ session:
1691
+ conversation_id: "abc-123" # Alias of thread_id for Databricks compatibility
1692
+ genie:
1693
+ spaces:
1694
+ space_123: {conversation_id: "conv_456", cache_hit: false, ...}
1695
+ """
1696
+ return loop.run_until_complete(
1697
+ self._build_custom_outputs_async(context=context, thread_id=thread_id)
1698
+ )
583
1699
 
584
- configurable.update(request.custom_inputs)
1700
+ async def _build_custom_outputs_async(
1701
+ self,
1702
+ context: Context,
1703
+ thread_id: Optional[str],
1704
+ ) -> dict[str, Any]:
1705
+ """Async version of _build_custom_outputs."""
1706
+ # Build configurable section
1707
+ # Note: only thread_id is included here (conversation_id goes in session)
1708
+ configurable: dict[str, Any] = {}
1709
+
1710
+ if thread_id:
1711
+ configurable["thread_id"] = thread_id
585
1712
 
586
- if "user_id" in configurable:
587
- configurable["user_id"] = configurable["user_id"].replace(".", "_")
1713
+ # Include user_id (keep normalized form for consistency)
1714
+ if context.user_id:
1715
+ configurable["user_id"] = context.user_id
588
1716
 
589
- if "thread_id" not in configurable:
590
- configurable["thread_id"] = str(uuid.uuid4())
1717
+ # Include all extra fields from context (beyond user_id and thread_id)
1718
+ context_dict = context.model_dump()
1719
+ for key, value in context_dict.items():
1720
+ if key not in {"user_id", "thread_id"} and value is not None:
1721
+ configurable[key] = value
591
1722
 
592
- logger.debug(f"Creating context from: {configurable}")
1723
+ # Build session section with accumulated state
1724
+ # Note: conversation_id is included here as an alias of thread_id
1725
+ session: dict[str, Any] = {}
593
1726
 
594
- return Context(**configurable)
1727
+ if thread_id:
1728
+ # Include conversation_id in session (alias of thread_id)
1729
+ session["conversation_id"] = thread_id
1730
+
1731
+ state_snapshot: Optional[StateSnapshot] = await get_state_snapshot_async(
1732
+ self.graph, thread_id
1733
+ )
1734
+ genie_conversation_ids: dict[str, str] = (
1735
+ get_genie_conversation_ids_from_state(state_snapshot)
1736
+ )
1737
+ if genie_conversation_ids:
1738
+ # Convert flat genie_conversation_ids to new session.genie.spaces structure
1739
+ session["genie"] = {
1740
+ "spaces": {
1741
+ space_id: {
1742
+ "conversation_id": conv_id,
1743
+ # Note: cache_hit, follow_up_questions populated by Genie tool
1744
+ "cache_hit": False,
1745
+ "follow_up_questions": [],
1746
+ }
1747
+ for space_id, conv_id in genie_conversation_ids.items()
1748
+ }
1749
+ }
1750
+
1751
+ return {
1752
+ "configurable": configurable,
1753
+ "session": session,
1754
+ }
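A sketch of the resulting round trip: per the docstring of _build_custom_outputs above, the dict returned here can be passed back verbatim as the next request's custom_inputs. Values are invented:

custom_outputs = {
    "configurable": {"thread_id": "abc-123", "user_id": "jane_doe", "store_num": "87887"},
    "session": {
        "conversation_id": "abc-123",
        "genie": {
            "spaces": {
                "space_123": {
                    "conversation_id": "conv_456",
                    "cache_hit": False,
                    "follow_up_questions": [],
                }
            }
        },
    },
}
next_custom_inputs = custom_outputs  # copy-paste continuation of the same conversation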
595
1755
 
596
1756
 
597
1757
  def create_agent(graph: CompiledStateGraph) -> ChatAgent:
@@ -610,7 +1770,9 @@ def create_agent(graph: CompiledStateGraph) -> ChatAgent:
610
1770
  return LanggraphChatModel(graph)
611
1771
 
612
1772
 
613
- def create_responses_agent(graph: CompiledStateGraph) -> ResponsesAgent:
1773
+ def create_responses_agent(
1774
+ graph: CompiledStateGraph,
1775
+ ) -> ResponsesAgent:
614
1776
  """
615
1777
  Create an MLflow-compatible ResponsesAgent from a LangGraph state machine.
616
1778
 
@@ -645,6 +1807,29 @@ def _process_langchain_messages(
645
1807
  return loop.run_until_complete(_async_invoke())
646
1808
 
647
1809
 
1810
+ def _configurable_to_context(configurable: dict[str, Any]) -> Context:
1811
+ """Convert a configurable dict to a Context object."""
1812
+ configurable = configurable.copy()
1813
+
1814
+ # Extract known Context fields
1815
+ user_id: str | None = configurable.pop("user_id", None)
1816
+ if user_id:
1817
+ user_id = user_id.replace(".", "_")
1818
+
1819
+ thread_id: str | None = configurable.pop("thread_id", None)
1820
+ if "conversation_id" in configurable and not thread_id:
1821
+ thread_id = configurable.pop("conversation_id")
1822
+ if not thread_id:
1823
+ thread_id = str(uuid.uuid4())
1824
+
1825
+ # All remaining values become top-level context attributes
1826
+ return Context(
1827
+ user_id=user_id,
1828
+ thread_id=thread_id,
1829
+ **configurable, # Extra fields become top-level attributes
1830
+ )
1831
+
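A quick illustration of the precedence implemented above; values are invented, and the extra-attribute access assumes Context accepts arbitrary fields, as the ** expansion implies:

ctx = _configurable_to_context(
    {
        "user_id": "jane.doe",         # normalized to "jane_doe"
        "conversation_id": "abc-123",  # used as thread_id because none was supplied
        "store_num": "87887",          # kept as an extra Context attribute
    }
)
# ctx.user_id == "jane_doe"; ctx.thread_id == "abc-123"; ctx.store_num == "87887"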
1832
+
648
1833
  def _process_langchain_messages_stream(
649
1834
  app: LanggraphChatModel | CompiledStateGraph,
650
1835
  messages: Sequence[BaseMessage],
@@ -656,10 +1841,14 @@ def _process_langchain_messages_stream(
656
1841
  if isinstance(app, LanggraphChatModel):
657
1842
  app = app.graph
658
1843
 
659
- logger.debug(f"Processing messages: {messages}, custom_inputs: {custom_inputs}")
1844
+ logger.trace(
1845
+ "Processing messages for stream",
1846
+ messages_count=len(messages),
1847
+ has_custom_inputs=custom_inputs is not None,
1848
+ )
660
1849
 
661
- custom_inputs = custom_inputs.get("configurable", custom_inputs or {})
662
- context: Context = Context(**custom_inputs)
1850
+ configurable = (custom_inputs or {}).get("configurable", custom_inputs or {})
1851
+ context: Context = _configurable_to_context(configurable)
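The lookup above accepts either a nested "configurable" block or a flat dict; for example, with invented values:

nested = {"configurable": {"user_id": "jane.doe", "thread_id": "abc-123"}}
flat = {"user_id": "jane.doe", "thread_id": "abc-123"}

# Both resolve to the same configurable dict before being converted to a Context.
assert (nested or {}).get("configurable", nested or {}) == flat
assert (flat or {}).get("configurable", flat or {}) == flat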
663
1852
 
664
1853
  # Use async astream internally for parallel execution
665
1854
  async def _async_stream():
@@ -674,7 +1863,10 @@ def _process_langchain_messages_stream(
674
1863
  stream_mode: str
675
1864
  stream_messages: Sequence[BaseMessage]
676
1865
  logger.trace(
677
- f"nodes: {nodes}, stream_mode: {stream_mode}, messages: {stream_messages}"
1866
+ "Stream batch received",
1867
+ nodes=nodes,
1868
+ stream_mode=stream_mode,
1869
+ messages_count=len(stream_messages),
678
1870
  )
679
1871
  for message in stream_messages:
680
1872
  if (