dao-ai 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/agent_as_code.py +2 -5
- dao_ai/cli.py +65 -15
- dao_ai/config.py +672 -218
- dao_ai/genie/cache/core.py +6 -2
- dao_ai/genie/cache/lru.py +29 -11
- dao_ai/genie/cache/semantic.py +95 -44
- dao_ai/hooks/core.py +5 -5
- dao_ai/logging.py +56 -0
- dao_ai/memory/core.py +61 -44
- dao_ai/memory/databricks.py +54 -41
- dao_ai/memory/postgres.py +77 -36
- dao_ai/middleware/assertions.py +45 -17
- dao_ai/middleware/core.py +13 -7
- dao_ai/middleware/guardrails.py +30 -25
- dao_ai/middleware/human_in_the_loop.py +9 -5
- dao_ai/middleware/message_validation.py +61 -29
- dao_ai/middleware/summarization.py +16 -11
- dao_ai/models.py +172 -69
- dao_ai/nodes.py +148 -19
- dao_ai/optimization.py +26 -16
- dao_ai/orchestration/core.py +15 -8
- dao_ai/orchestration/supervisor.py +22 -8
- dao_ai/orchestration/swarm.py +57 -12
- dao_ai/prompts.py +17 -17
- dao_ai/providers/databricks.py +365 -155
- dao_ai/state.py +24 -6
- dao_ai/tools/__init__.py +2 -0
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +7 -7
- dao_ai/tools/email.py +29 -77
- dao_ai/tools/genie.py +18 -13
- dao_ai/tools/mcp.py +223 -156
- dao_ai/tools/python.py +5 -2
- dao_ai/tools/search.py +1 -1
- dao_ai/tools/slack.py +21 -9
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +129 -86
- dao_ai/tools/vector_search.py +318 -244
- dao_ai/utils.py +15 -10
- dao_ai-0.1.3.dist-info/METADATA +455 -0
- dao_ai-0.1.3.dist-info/RECORD +64 -0
- dao_ai-0.1.1.dist-info/METADATA +0 -1878
- dao_ai-0.1.1.dist-info/RECORD +0 -62
- {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/licenses/LICENSE +0 -0
dao_ai/models.py
CHANGED
```diff
@@ -126,11 +126,11 @@ async def get_state_snapshot_async(
     Returns:
         StateSnapshot if found, None otherwise
     """
-    logger.
+    logger.trace("Retrieving state snapshot", thread_id=thread_id)
     try:
         # Check if graph has a checkpointer
         if graph.checkpointer is None:
-            logger.
+            logger.trace("No checkpointer available in graph")
             return None

         # Get the current state from the checkpointer (use async version)
@@ -138,13 +138,15 @@ async def get_state_snapshot_async(
         state_snapshot: Optional[StateSnapshot] = await graph.aget_state(config)

         if state_snapshot is None:
-            logger.
+            logger.trace("No state found for thread", thread_id=thread_id)
             return None

         return state_snapshot

     except Exception as e:
-        logger.warning(
+        logger.warning(
+            "Error retrieving state snapshot", thread_id=thread_id, error=str(e)
+        )
         return None

@@ -175,7 +177,7 @@ def get_state_snapshot(
     try:
         return loop.run_until_complete(get_state_snapshot_async(graph, thread_id))
     except Exception as e:
-        logger.warning(
+        logger.warning("Error in synchronous state snapshot retrieval", error=str(e))
         return None

@@ -207,13 +209,17 @@ def get_genie_conversation_ids_from_state(
         )

         if genie_conversation_ids:
-            logger.
+            logger.trace(
+                "Retrieved genie conversation IDs", count=len(genie_conversation_ids)
+            )
             return genie_conversation_ids

         return {}

     except Exception as e:
-        logger.warning(
+        logger.warning(
+            "Error extracting genie conversation IDs from state", error=str(e)
+        )
         return {}

```
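The old single-argument log calls (shown truncated to `logger.` by this rendering) are replaced throughout with message-plus-keyword calls. That style, together with the new `dao_ai/logging.py` module added in 0.1.3, suggests a loguru-like structured logger. A minimal sketch of how such calls surface their fields, assuming loguru as the backend (an assumption — the contents of `dao_ai/logging.py` are not shown in this diff):

```python
# Hedged sketch: assumes loguru; dao_ai/logging.py (new in 0.1.3) is not shown here.
import sys
from loguru import logger

logger.remove()
# Render the record's "extra" dict so keyword fields like thread_id are visible.
logger.add(sys.stderr, level="TRACE", format="{time} | {level} | {message} | {extra}")

# In recent loguru versions, surplus keyword arguments on a log call are merged
# into record["extra"], which matches the call shape used throughout this diff.
logger.trace("Retrieving state snapshot", thread_id="abc-123")
```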
```diff
@@ -333,7 +339,11 @@ class LanggraphChatModel(ChatModel):
     def predict(
         self, context, messages: list[ChatMessage], params: Optional[ChatParams] = None
     ) -> ChatCompletionResponse:
-        logger.
+        logger.trace(
+            "Predict called",
+            messages_count=len(messages),
+            has_params=params is not None,
+        )
         if not messages:
             raise ValueError("Message list is empty.")

@@ -355,7 +365,10 @@ class LanggraphChatModel(ChatModel):
             _async_invoke()
         )

-        logger.trace(
+        logger.trace(
+            "Predict response received",
+            messages_count=len(response.get("messages", [])),
+        )

         last_message: BaseMessage = response["messages"][-1]

@@ -393,20 +406,21 @@ class LanggraphChatModel(ChatModel):
         if not thread_id:
             thread_id = str(uuid.uuid4())

-        # All remaining configurable values
-
-
-        context: Context = Context(
+        # All remaining configurable values become top-level context attributes
+        return Context(
             user_id=user_id,
             thread_id=thread_id,
-
+            **configurable,  # Extra fields become top-level attributes
         )
-        return context

     def predict_stream(
         self, context, messages: list[ChatMessage], params: ChatParams
     ) -> Generator[ChatCompletionChunk, None, None]:
-        logger.
+        logger.trace(
+            "Predict stream called",
+            messages_count=len(messages),
+            has_params=params is not None,
+        )
         if not messages:
             raise ValueError("Message list is empty.")

```
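For `Context(**configurable)` to turn leftover configurable keys into top-level attributes, the `Context` model has to tolerate unknown fields. The `Context` definition is not part of this diff, so the following is a hypothetical sketch of a pydantic model with that behavior, not dao_ai's actual class:

```python
# Hypothetical sketch only — dao_ai's real Context class is defined elsewhere
# and may differ. It must accept unknown keys for Context(**configurable) to work.
from pydantic import BaseModel, ConfigDict


class Context(BaseModel):
    model_config = ConfigDict(extra="allow")  # keep unknown keys as attributes
    user_id: str | None = None
    thread_id: str | None = None


ctx = Context(user_id="u1", thread_id="t1", store_id="s-9")
print(ctx.store_id)      # "s-9" — the extra key became an attribute
print(ctx.model_dump())  # extras are included, which the later hunks rely on
```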
```diff
@@ -430,7 +444,10 @@ class LanggraphChatModel(ChatModel):
             stream_mode: str
             messages_batch: Sequence[BaseMessage]
             logger.trace(
-
+                "Stream batch received",
+                nodes=nodes,
+                stream_mode=stream_mode,
+                messages_count=len(messages_batch),
             )
             for message in messages_batch:
                 if (
@@ -675,7 +692,7 @@ def handle_interrupt_response(
     user_message_obj: Optional[HumanMessage] = last_human_message(messages)

     if not user_message_obj:
-        logger.warning("
+        logger.warning("HITL: No human message found in interrupt response")
         return {
             "is_valid": False,
             "validation_message": "No user message found. Please provide a response to the pending action(s).",
@@ -683,7 +700,9 @@ def handle_interrupt_response(
         }

     user_message: str = str(user_message_obj.content)
-    logger.info(
+    logger.info(
+        "HITL: Parsing user interrupt response", message_preview=user_message[:100]
+    )

     if not model:
         model = ChatDatabricks(
@@ -693,7 +712,7 @@ def handle_interrupt_response(

     # Extract interrupt data
     if not snapshot.interrupts:
-        logger.warning("
+        logger.warning("HITL: No interrupts found in snapshot")
         return {"decisions": []}

     interrupt_data: list[HITLRequest] = [
@@ -707,7 +726,7 @@ def handle_interrupt_response(
         all_actions.extend(hitl_request.get("action_requests", []))

     if not all_actions:
-        logger.warning("
+        logger.warning("HITL: No actions found in interrupts")
         return {"decisions": []}

     # Create dynamic schema
@@ -767,7 +786,7 @@ FLEXIBILITY:

         if not is_valid:
             logger.warning(
-
+                "HITL: Invalid user response", reason=validation_message or "Unknown"
             )
             return {
                 "is_valid": False,
@@ -779,11 +798,11 @@ FLEXIBILITY:
         # Convert to Decision format
         decisions: list[Decision] = _convert_schema_to_decisions(parsed, interrupt_data)

-        logger.info(
+        logger.info("HITL: Parsed interrupt decisions", decisions_count=len(decisions))
         return {"is_valid": True, "validation_message": None, "decisions": decisions}

     except Exception as e:
-        logger.error(
+        logger.error("HITL: Failed to parse interrupt response", error=str(e))
         # Return invalid response on parsing failure
         return {
             "is_valid": False,
@@ -840,7 +859,33 @@ class LanggraphResponsesAgent(ResponsesAgent):
                 arguments: {...}
                 description: "..."
         """
-
+        # Extract conversation_id for logging (from context or custom_inputs)
+        conversation_id_for_log: str | None = None
+        if request.context and hasattr(request.context, "conversation_id"):
+            conversation_id_for_log = request.context.conversation_id
+        elif request.custom_inputs:
+            # Check configurable or session for conversation_id
+            if "configurable" in request.custom_inputs and isinstance(
+                request.custom_inputs["configurable"], dict
+            ):
+                conversation_id_for_log = request.custom_inputs["configurable"].get(
+                    "conversation_id"
+                )
+            if (
+                conversation_id_for_log is None
+                and "session" in request.custom_inputs
+                and isinstance(request.custom_inputs["session"], dict)
+            ):
+                conversation_id_for_log = request.custom_inputs["session"].get(
+                    "conversation_id"
+                )
+
+        logger.debug(
+            "ResponsesAgent predict called",
+            conversation_id=conversation_id_for_log
+            if conversation_id_for_log
+            else "new",
+        )

         # Convert ResponsesAgent input to LangChain messages
         messages: list[dict[str, Any]] = self._convert_request_to_langchain_messages(
```
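The new lookup accepts a `conversation_id` from three places, in order of precedence: `request.context`, then `custom_inputs["configurable"]`, then `custom_inputs["session"]`. Illustrative payload shapes (the key names come from the diff; the values are made up):

```python
# Illustrative custom_inputs shapes the new lookup handles; values are made up.
custom_inputs_via_configurable = {
    "configurable": {"conversation_id": "conv-42"},
}
custom_inputs_via_session = {
    "session": {"conversation_id": "conv-42", "genie_conversation_ids": {}},
}
# A conversation_id on request.context, when present, wins over both of these.
```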
```diff
@@ -870,7 +915,8 @@ class LanggraphResponsesAgent(ResponsesAgent):
             # Explicit structured decisions
             decisions: list[Decision] = request.custom_inputs["decisions"]
             logger.info(
-
+                "HITL: Resuming with explicit decisions",
+                decisions_count=len(decisions),
             )

             # Resume interrupted graph with decisions
@@ -888,7 +934,7 @@ class LanggraphResponsesAgent(ResponsesAgent):
             )
             if is_interrupted(snapshot):
                 logger.info(
-                    "HITL: Graph
+                    "HITL: Graph interrupted, checking for user response"
                 )

                 # Convert message dicts to BaseMessage objects
@@ -910,7 +956,8 @@ class LanggraphResponsesAgent(ResponsesAgent):
                     "Your response was unclear. Please provide a clear decision for each action.",
                 )
                 logger.warning(
-
+                    "HITL: Invalid response from user",
+                    validation_message=validation_message,
                 )

                 # Return error message to user instead of resuming
@@ -925,7 +972,8 @@ class LanggraphResponsesAgent(ResponsesAgent):

                 decisions: list[Decision] = parsed_result.get("decisions", [])
                 logger.info(
-
+                    "HITL: LLM parsed valid decisions from user message",
+                    decisions_count=len(decisions),
                 )

                 # Resume interrupted graph with parsed decisions
@@ -941,15 +989,16 @@ class LanggraphResponsesAgent(ResponsesAgent):
                 graph_input["genie_conversation_ids"] = session_input[
                     "genie_conversation_ids"
                 ]
-                logger.
-
+                logger.trace(
+                    "Including genie conversation IDs in graph input",
+                    count=len(graph_input["genie_conversation_ids"]),
                 )

             return await self.graph.ainvoke(
                 graph_input, context=context, config=custom_inputs
             )
         except Exception as e:
-            logger.error(
+            logger.error("Error in graph invocation", error=str(e))
             raise

         try:
@@ -963,7 +1012,7 @@ class LanggraphResponsesAgent(ResponsesAgent):
                 _async_invoke()
             )
         except Exception as e:
-            logger.error(
+            logger.error("Error in async execution", error=str(e))
             raise

         # Convert response to ResponsesAgent format
@@ -983,7 +1032,10 @@ class LanggraphResponsesAgent(ResponsesAgent):
             from pydantic import BaseModel

             structured_response = response["structured_response"]
-            logger.
+            logger.trace(
+                "Processing structured response",
+                response_type=type(structured_response).__name__,
+            )

             # Serialize to dict for JSON compatibility using type hints
             if isinstance(structured_response, BaseModel):
@@ -1010,7 +1062,7 @@ class LanggraphResponsesAgent(ResponsesAgent):
                 output_item = self.create_text_output_item(
                     text=structured_text, id=f"msg_{uuid.uuid4().hex[:8]}"
                 )
-                logger.
+                logger.trace("Structured response placed in message content")
         else:
             # No structured response, use text content
             output_item = self.create_text_output_item(
@@ -1020,7 +1072,7 @@ class LanggraphResponsesAgent(ResponsesAgent):
         # Include interrupt structure if HITL occurred (following LangChain pattern)
         if "__interrupt__" in response:
             interrupts: list[Interrupt] = response["__interrupt__"]
-            logger.info(
+            logger.info("HITL: Interrupts detected", interrupts_count=len(interrupts))

             # Extract HITLRequest structures from interrupts (deduplicate by ID)
             seen_interrupt_ids: set[str] = set()
@@ -1031,11 +1083,14 @@ class LanggraphResponsesAgent(ResponsesAgent):
                 if interrupt.id not in seen_interrupt_ids:
                     seen_interrupt_ids.add(interrupt.id)
                     interrupt_data.append(_extract_interrupt_value(interrupt))
-                    logger.
+                    logger.trace(
+                        "HITL: Added interrupt to response", interrupt_id=interrupt.id
+                    )

             custom_outputs["interrupts"] = interrupt_data
             logger.debug(
-
+                "HITL: Included interrupts in response",
+                interrupts_count=len(interrupt_data),
             )

             # Add user-facing message about the pending actions
@@ -1058,7 +1113,33 @@ class LanggraphResponsesAgent(ResponsesAgent):
         Uses same input/output structure as predict() for consistency.
         Supports Human-in-the-Loop (HITL) interrupts.
         """
-
+        # Extract conversation_id for logging (from context or custom_inputs)
+        conversation_id_for_log: str | None = None
+        if request.context and hasattr(request.context, "conversation_id"):
+            conversation_id_for_log = request.context.conversation_id
+        elif request.custom_inputs:
+            # Check configurable or session for conversation_id
+            if "configurable" in request.custom_inputs and isinstance(
+                request.custom_inputs["configurable"], dict
+            ):
+                conversation_id_for_log = request.custom_inputs["configurable"].get(
+                    "conversation_id"
+                )
+            if (
+                conversation_id_for_log is None
+                and "session" in request.custom_inputs
+                and isinstance(request.custom_inputs["session"], dict)
+            ):
+                conversation_id_for_log = request.custom_inputs["session"].get(
+                    "conversation_id"
+                )
+
+        logger.debug(
+            "ResponsesAgent predict_stream called",
+            conversation_id=conversation_id_for_log
+            if conversation_id_for_log
+            else "new",
+        )

         # Convert ResponsesAgent input to LangChain messages
         messages: list[dict[str, Any]] = self._convert_request_to_langchain_messages(
@@ -1094,7 +1175,8 @@ class LanggraphResponsesAgent(ResponsesAgent):
             # Explicit structured decisions
             decisions: list[Decision] = request.custom_inputs["decisions"]
             logger.info(
-
+                "HITL: Resuming stream with explicit decisions",
+                decisions_count=len(decisions),
             )
             stream_input: Command | dict[str, Any] = Command(
                 resume={"decisions": decisions}
@@ -1107,7 +1189,7 @@ class LanggraphResponsesAgent(ResponsesAgent):
             )
             if is_interrupted(snapshot):
                 logger.info(
-                    "HITL: Graph
+                    "HITL: Graph interrupted, checking for user response in stream"
                 )

                 # Convert message dicts to BaseMessage objects
@@ -1129,7 +1211,8 @@ class LanggraphResponsesAgent(ResponsesAgent):
                     "Your response was unclear. Please provide a clear decision for each action.",
                 )
                 logger.warning(
-
+                    "HITL: Invalid response from user in stream",
+                    validation_message=validation_message,
                 )

                 # Build custom_outputs before returning
@@ -1156,7 +1239,8 @@ class LanggraphResponsesAgent(ResponsesAgent):

                 decisions: list[Decision] = parsed_result.get("decisions", [])
                 logger.info(
-
+                    "HITL: LLM parsed valid decisions from user message in stream",
+                    decisions_count=len(decisions),
                 )

                 # Resume interrupted graph with parsed decisions
@@ -1226,7 +1310,8 @@ class LanggraphResponsesAgent(ResponsesAgent):
                 if source == "__interrupt__":
                     interrupts: list[Interrupt] = update
                     logger.info(
-
+                        "HITL: Interrupts detected during streaming",
+                        interrupts_count=len(interrupts),
                     )

                     # Extract interrupt values (deduplicate by ID)
@@ -1238,8 +1323,9 @@ class LanggraphResponsesAgent(ResponsesAgent):
                             interrupt_data.append(
                                 _extract_interrupt_value(interrupt)
                             )
-                            logger.
-
+                            logger.trace(
+                                "HITL: Added interrupt to response",
+                                interrupt_id=interrupt.id,
                             )
                 elif (
                     isinstance(update, dict)
@@ -1247,8 +1333,9 @@ class LanggraphResponsesAgent(ResponsesAgent):
                 ):
                     # Capture structured_response from stream updates
                     structured_response = update["structured_response"]
-                    logger.
-
+                    logger.trace(
+                        "Captured structured response from stream",
+                        response_type=type(structured_response).__name__,
                     )

             # Get final state to extract structured_response (only if checkpointer available)
@@ -1276,8 +1363,9 @@ class LanggraphResponsesAgent(ResponsesAgent):

             from pydantic import BaseModel

-            logger.
-
+            logger.trace(
+                "Processing structured response in streaming",
+                response_type=type(structured_response).__name__,
             )

             # Serialize to dict for JSON compatibility using type hints
@@ -1320,13 +1408,14 @@ class LanggraphResponsesAgent(ResponsesAgent):
                 )
                 output_text = structured_text

-                logger.
+                logger.trace("Streamed structured response in message content")

         # Include interrupt structure if HITL occurred
         if interrupt_data:
             custom_outputs["interrupts"] = interrupt_data
             logger.info(
-
+                "HITL: Included interrupts in streaming response",
+                interrupts_count=len(interrupt_data),
             )

             # Add user-facing message about the pending actions
@@ -1361,7 +1450,7 @@ class LanggraphResponsesAgent(ResponsesAgent):
                 custom_outputs=custom_outputs,
             )
         except Exception as e:
-            logger.error(
+            logger.error("Error in graph streaming", error=str(e))
             raise

         # Convert async generator to sync generator
@@ -1381,13 +1470,13 @@ class LanggraphResponsesAgent(ResponsesAgent):
                 except StopAsyncIteration:
                     break
                 except Exception as e:
-                    logger.error(
+                    logger.error("Error in streaming", error=str(e))
                     raise
             finally:
                 try:
                     loop.run_until_complete(async_gen.aclose())
                 except Exception as e:
-                    logger.warning(
+                    logger.warning("Error closing async generator", error=str(e))

     def _extract_text_from_content(
         self,
@@ -1462,8 +1551,11 @@ class LanggraphResponsesAgent(ResponsesAgent):
         conversation_id can be provided in either configurable or session.
         Normalizes user_id (replaces . with _) for memory namespace compatibility.
         """
-        logger.
-
+        logger.trace(
+            "Converting request to context",
+            has_context=request.context is not None,
+            has_custom_inputs=request.custom_inputs is not None,
+        )

         configurable: dict[str, Any] = {}
         session: dict[str, Any] = {}
@@ -1521,17 +1613,18 @@ class LanggraphResponsesAgent(ResponsesAgent):
             # Generate new thread_id if neither provided
             thread_id = str(uuid.uuid4())

-        # All remaining configurable values
-
-
-
-
+        # All remaining configurable values become top-level context attributes
+        logger.trace(
+            "Creating context",
+            user_id=user_id_value,
+            thread_id=thread_id,
+            extra_keys=list(configurable.keys()) if configurable else [],
         )

         return Context(
             user_id=user_id_value,
             thread_id=thread_id,
-
+            **configurable,  # Pass remaining configurable values as context attributes
         )

     def _extract_session_from_request(
@@ -1621,8 +1714,11 @@ class LanggraphResponsesAgent(ResponsesAgent):
         if context.user_id:
             configurable["user_id"] = context.user_id

-        # Include all
-
+        # Include all extra fields from context (beyond user_id and thread_id)
+        context_dict = context.model_dump()
+        for key, value in context_dict.items():
+            if key not in {"user_id", "thread_id"} and value is not None:
+                configurable[key] = value

         # Build session section with accumulated state
         # Note: conversation_id is included here as an alias of thread_id
@@ -1726,11 +1822,11 @@ def _configurable_to_context(configurable: dict[str, Any]) -> Context:
     if not thread_id:
         thread_id = str(uuid.uuid4())

-    # All remaining values
+    # All remaining values become top-level context attributes
     return Context(
         user_id=user_id,
         thread_id=thread_id,
-
+        **configurable,  # Extra fields become top-level attributes
     )

```
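These last two changes are converses: extra configurable keys flow into `Context` attributes in `_configurable_to_context`, and back out via `context.model_dump()` in the loop above. A small round-trip sketch, reusing the hypothetical permissive `Context` model sketched earlier (so an assumption, not dao_ai's real class):

```python
# Round trip under the earlier hypothetical Context sketch: extra keys survive
# configurable -> Context -> configurable.
configurable = {"user_id": "u1", "thread_id": "t1", "store_id": "s-9"}
ctx = Context(**configurable)  # store_id becomes a top-level attribute

rebuilt = {
    key: value
    for key, value in ctx.model_dump().items()
    if key not in {"user_id", "thread_id"} and value is not None
}
assert rebuilt == {"store_id": "s-9"}
```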
```diff
@@ -1745,7 +1841,11 @@ def _process_langchain_messages_stream(
     if isinstance(app, LanggraphChatModel):
         app = app.graph

-    logger.
+    logger.trace(
+        "Processing messages for stream",
+        messages_count=len(messages),
+        has_custom_inputs=custom_inputs is not None,
+    )

     configurable = (custom_inputs or {}).get("configurable", custom_inputs or {})
     context: Context = _configurable_to_context(configurable)
@@ -1763,7 +1863,10 @@ def _process_langchain_messages_stream(
         stream_mode: str
         stream_messages: Sequence[BaseMessage]
         logger.trace(
-
+            "Stream batch received",
+            nodes=nodes,
+            stream_mode=stream_mode,
+            messages_count=len(stream_messages),
         )
         for message in stream_messages:
             if (
```