dao-ai 0.1.2__py3-none-any.whl → 0.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/apps/__init__.py +24 -0
- dao_ai/apps/handlers.py +105 -0
- dao_ai/apps/model_serving.py +29 -0
- dao_ai/apps/resources.py +1122 -0
- dao_ai/apps/server.py +39 -0
- dao_ai/cli.py +546 -37
- dao_ai/config.py +1179 -139
- dao_ai/evaluation.py +543 -0
- dao_ai/genie/__init__.py +55 -7
- dao_ai/genie/cache/__init__.py +34 -7
- dao_ai/genie/cache/base.py +143 -2
- dao_ai/genie/cache/context_aware/__init__.py +31 -0
- dao_ai/genie/cache/context_aware/base.py +1151 -0
- dao_ai/genie/cache/context_aware/in_memory.py +609 -0
- dao_ai/genie/cache/context_aware/persistent.py +802 -0
- dao_ai/genie/cache/context_aware/postgres.py +1166 -0
- dao_ai/genie/cache/core.py +1 -1
- dao_ai/genie/cache/lru.py +257 -75
- dao_ai/genie/cache/optimization.py +890 -0
- dao_ai/genie/core.py +235 -11
- dao_ai/memory/postgres.py +175 -39
- dao_ai/middleware/__init__.py +38 -0
- dao_ai/middleware/assertions.py +3 -3
- dao_ai/middleware/context_editing.py +230 -0
- dao_ai/middleware/core.py +4 -4
- dao_ai/middleware/guardrails.py +3 -3
- dao_ai/middleware/human_in_the_loop.py +3 -2
- dao_ai/middleware/message_validation.py +4 -4
- dao_ai/middleware/model_call_limit.py +77 -0
- dao_ai/middleware/model_retry.py +121 -0
- dao_ai/middleware/pii.py +157 -0
- dao_ai/middleware/summarization.py +1 -1
- dao_ai/middleware/tool_call_limit.py +210 -0
- dao_ai/middleware/tool_retry.py +174 -0
- dao_ai/middleware/tool_selector.py +129 -0
- dao_ai/models.py +327 -370
- dao_ai/nodes.py +9 -16
- dao_ai/orchestration/core.py +33 -9
- dao_ai/orchestration/supervisor.py +29 -13
- dao_ai/orchestration/swarm.py +6 -1
- dao_ai/{prompts.py → prompts/__init__.py} +12 -61
- dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
- dao_ai/prompts/instruction_reranker.yaml +14 -0
- dao_ai/prompts/router.yaml +37 -0
- dao_ai/prompts/verifier.yaml +46 -0
- dao_ai/providers/base.py +28 -2
- dao_ai/providers/databricks.py +363 -33
- dao_ai/state.py +1 -0
- dao_ai/tools/__init__.py +5 -3
- dao_ai/tools/genie.py +103 -26
- dao_ai/tools/instructed_retriever.py +366 -0
- dao_ai/tools/instruction_reranker.py +202 -0
- dao_ai/tools/mcp.py +539 -97
- dao_ai/tools/router.py +89 -0
- dao_ai/tools/slack.py +13 -2
- dao_ai/tools/sql.py +7 -3
- dao_ai/tools/unity_catalog.py +32 -10
- dao_ai/tools/vector_search.py +493 -160
- dao_ai/tools/verifier.py +159 -0
- dao_ai/utils.py +182 -2
- dao_ai/vector_search.py +46 -1
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/METADATA +45 -9
- dao_ai-0.1.20.dist-info/RECORD +89 -0
- dao_ai/agent_as_code.py +0 -22
- dao_ai/genie/cache/semantic.py +0 -970
- dao_ai-0.1.2.dist-info/RECORD +0 -64
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/licenses/LICENSE +0 -0
dao_ai/models.py
CHANGED
@@ -1,7 +1,16 @@
 import uuid
 from os import PathLike
 from pathlib import Path
-from typing import
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncGenerator,
+    Generator,
+    Literal,
+    Optional,
+    Sequence,
+    Union,
+)
 
 from databricks_langchain import ChatDatabricks
 
@@ -825,13 +834,16 @@ class LanggraphResponsesAgent(ResponsesAgent):
     ) -> None:
         self.graph = graph
 
-    def
+    async def apredict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
         """
+        Async version of predict - primary implementation for Databricks Apps.
+
         Process a ResponsesAgentRequest and return a ResponsesAgentResponse.
+        This method can be awaited directly in async contexts (e.g., MLflow AgentServer).
 
         Input structure (custom_inputs):
             configurable:
-                thread_id: "abc-123"  # Or conversation_id (aliases
+                thread_id: "abc-123"  # Or conversation_id (aliases)
                 user_id: "nate.fleming"
                 store_num: "87887"
             session:  # Paste from previous output
@@ -846,11 +858,11 @@ class LanggraphResponsesAgent(ResponsesAgent):
 
         Output structure (custom_outputs):
             configurable:
-                thread_id: "abc-123"
+                thread_id: "abc-123"
                 user_id: "nate.fleming"
                 store_num: "87887"
             session:
-                conversation_id: "abc-123"
+                conversation_id: "abc-123"
                 genie:
                     spaces:
                         space_123: {conversation_id: "conv_456", ...}
@@ -859,12 +871,13 @@ class LanggraphResponsesAgent(ResponsesAgent):
                     arguments: {...}
                     description: "..."
         """
-
+        from langgraph.types import Command
+
+        # Extract conversation_id for logging
         conversation_id_for_log: str | None = None
         if request.context and hasattr(request.context, "conversation_id"):
             conversation_id_for_log = request.context.conversation_id
         elif request.custom_inputs:
-            # Check configurable or session for conversation_id
             if "configurable" in request.custom_inputs and isinstance(
                 request.custom_inputs["configurable"], dict
             ):
@@ -881,7 +894,7 @@ class LanggraphResponsesAgent(ResponsesAgent):
             )
 
         logger.debug(
-            "ResponsesAgent
+            "ResponsesAgent apredict called",
             conversation_id=conversation_id_for_log
             if conversation_id_for_log
             else "new",
@@ -899,130 +912,106 @@ class LanggraphResponsesAgent(ResponsesAgent):
         # Extract session state from request
         session_input: dict[str, Any] = self._extract_session_from_request(request)
 
-
-
-
-
+        try:
+            # Check if this is a resume request (HITL)
+            if request.custom_inputs and "decisions" in request.custom_inputs:
+                # Explicit structured decisions
+                decisions: list[Decision] = request.custom_inputs["decisions"]
+                logger.info(
+                    "HITL: Resuming with explicit decisions",
+                    decisions_count=len(decisions),
+                )
 
-
-
-
-
-
-
-
-        if
-
-
-
-
-
-        )
+                # Resume interrupted graph with decisions
+                response = await self.graph.ainvoke(
+                    Command(resume={"decisions": decisions}),
+                    context=context,
+                    config=custom_inputs,
+                )
+            elif self.graph.checkpointer:
+                # Check if graph is currently interrupted
+                snapshot: StateSnapshot = await self.graph.aget_state(
+                    config=custom_inputs
+                )
+                if is_interrupted(snapshot):
+                    logger.info("HITL: Graph interrupted, checking for user response")
 
-            #
-
-
-            context=context,
-            config=custom_inputs,
+                    # Convert message dicts to BaseMessage objects
+                    message_objects: list[BaseMessage] = convert_openai_messages(
+                        messages
                     )
 
-
-
-
-
-
+                    # Parse user's message with LLM to extract decisions
+                    parsed_result: dict[str, Any] = handle_interrupt_response(
+                        snapshot=snapshot,
+                        messages=message_objects,
+                        model=None,
                     )
-        if is_interrupted(snapshot):
-            logger.info(
-                "HITL: Graph interrupted, checking for user response"
-            )
 
-
-
-
+                    if not parsed_result.get("is_valid", False):
+                        validation_message: str = parsed_result.get(
+                            "validation_message",
+                            "Your response was unclear. Please provide a clear decision for each action.",
                         )
-
-
-
-            snapshot=snapshot,
-            messages=message_objects,
-            model=None,  # Uses default model
+                        logger.warning(
+                            "HITL: Invalid response from user",
+                            validation_message=validation_message,
                         )
 
-        #
-
-
-
-
-
-
-
-
-
-
-            # Return error message to user instead of resuming
-            # Don't resume the graph - stay interrupted so user can try again
-            return {
-                "messages": [
-                    AIMessage(
-                        content=f"❌ **Invalid Response**\n\n{validation_message}"
-                    )
-                ]
-            }
-
-        decisions: list[Decision] = parsed_result.get("decisions", [])
+                        # Return error message without resuming
+                        response = {
+                            "messages": [
+                                AIMessage(
+                                    content=f"❌ **Invalid Response**\n\n{validation_message}"
+                                )
+                            ]
+                        }
+                    else:
+                        decisions = parsed_result.get("decisions", [])
                         logger.info(
                             "HITL: LLM parsed valid decisions from user message",
                             decisions_count=len(decisions),
                         )
 
                         # Resume interrupted graph with parsed decisions
-
+                        response = await self.graph.ainvoke(
                             Command(resume={"decisions": decisions}),
                             context=context,
                             config=custom_inputs,
                         )
+                else:
+                    # Normal invocation
+                    graph_input: dict[str, Any] = {"messages": messages}
+                    if "genie_conversation_ids" in session_input:
+                        graph_input["genie_conversation_ids"] = session_input[
+                            "genie_conversation_ids"
+                        ]
 
-
-
+                    response = await self.graph.ainvoke(
+                        graph_input, context=context, config=custom_inputs
+                    )
+            else:
+                # No checkpointer, use normal invocation
+                graph_input = {"messages": messages}
                 if "genie_conversation_ids" in session_input:
                     graph_input["genie_conversation_ids"] = session_input[
                         "genie_conversation_ids"
                     ]
-            logger.trace(
-                "Including genie conversation IDs in graph input",
-                count=len(graph_input["genie_conversation_ids"]),
-            )
 
-
+                response = await self.graph.ainvoke(
                     graph_input, context=context, config=custom_inputs
                 )
-        except Exception as e:
-            logger.error("Error in graph invocation", error=str(e))
-            raise
-
-        try:
-            loop = asyncio.get_event_loop()
-        except RuntimeError:
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-
-        try:
-            response: dict[str, Sequence[BaseMessage]] = loop.run_until_complete(
-                _async_invoke()
-            )
         except Exception as e:
-            logger.error("Error in
+            logger.error("Error in graph invocation", error=str(e))
             raise
 
         # Convert response to ResponsesAgent format
         last_message: BaseMessage = response["messages"][-1]
 
-        # Build custom_outputs
-        custom_outputs: dict[str, Any] = self.
+        # Build custom_outputs
+        custom_outputs: dict[str, Any] = await self._build_custom_outputs_async(
            context=context,
            thread_id=context.thread_id,
-            loop=loop,
        )
 
        # Handle structured_response if present
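Reading the new control flow: an interrupted (HITL) run is resumed either by passing structured decisions in custom_inputs["decisions"], which are forwarded verbatim to Command(resume=...), or by replying in natural language, which handle_interrupt_response() parses into decisions. A sketch of the explicit path follows; the shape of each decision entry is an assumption, since only the "decisions" key is visible in the diff:

    from mlflow.types.responses import ResponsesAgentRequest

    resume_request = ResponsesAgentRequest(
        input=[{"role": "user", "content": "approve"}],
        custom_inputs={
            # Same thread_id as the interrupted run so the checkpointer can find it.
            "configurable": {"thread_id": "abc-123"},
            # Hypothetical decision payload; one entry per pending action request.
            "decisions": [{"action": "approve"}],
        },
    )
    # In an async context:
    #     response = await agent.apredict(resume_request)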
@@ -1037,25 +1026,19 @@ class LanggraphResponsesAgent(ResponsesAgent):
                 response_type=type(structured_response).__name__,
             )
 
-            # Serialize to dict for JSON compatibility using type hints
             if isinstance(structured_response, BaseModel):
-                # Pydantic model
                 serialized: dict[str, Any] = structured_response.model_dump()
             elif is_dataclass(structured_response):
-                # Dataclass
                 serialized = asdict(structured_response)
             elif isinstance(structured_response, dict):
-                # Already a dict
                 serialized = structured_response
             else:
-                # Unknown type, convert to dict if possible
                 serialized = (
                     dict(structured_response)
                     if hasattr(structured_response, "__dict__")
                     else structured_response
                 )
 
-            # Place structured output in message content as JSON
             import json
 
             structured_text: str = json.dumps(serialized, indent=2)
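The serialization chain above normalizes whatever structured output the graph produced into a plain dict before json.dumps(); as a small illustration (the Forecast model is hypothetical):

    import json
    from pydantic import BaseModel

    class Forecast(BaseModel):  # hypothetical structured-output type
        store_num: str
        units: int

    # Pydantic models go through model_dump(); dataclasses through asdict();
    # plain dicts pass straight through to json.dumps().
    serialized = Forecast(store_num="87887", units=1200).model_dump()
    print(json.dumps(serialized, indent=2))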
@@ -1064,22 +1047,18 @@ class LanggraphResponsesAgent(ResponsesAgent):
             )
             logger.trace("Structured response placed in message content")
         else:
-            # No structured response, use text content
             output_item = self.create_text_output_item(
                 text=last_message.content, id=f"msg_{uuid.uuid4().hex[:8]}"
             )
 
-        # Include interrupt structure if HITL occurred
+        # Include interrupt structure if HITL occurred
         if "__interrupt__" in response:
             interrupts: list[Interrupt] = response["__interrupt__"]
             logger.info("HITL: Interrupts detected", interrupts_count=len(interrupts))
 
-            # Extract HITLRequest structures from interrupts (deduplicate by ID)
             seen_interrupt_ids: set[str] = set()
             interrupt_data: list[HITLRequest] = []
-            interrupt: Interrupt
             for interrupt in interrupts:
-                # Only process each unique interrupt once
                 if interrupt.id not in seen_interrupt_ids:
                     seen_interrupt_ids.add(interrupt.id)
                     interrupt_data.append(_extract_interrupt_value(interrupt))
@@ -1093,7 +1072,6 @@ class LanggraphResponsesAgent(ResponsesAgent):
                 interrupts_count=len(interrupt_data),
             )
 
-            # Add user-facing message about the pending actions
             action_message: str = _format_action_requests_message(interrupt_data)
             if action_message:
                 output_item = self.create_text_output_item(
@@ -1104,21 +1082,25 @@ class LanggraphResponsesAgent(ResponsesAgent):
             output=[output_item], custom_outputs=custom_outputs
         )
 
-    def
+    async def apredict_stream(
         self, request: ResponsesAgentRequest
-    ) ->
+    ) -> AsyncGenerator[ResponsesAgentStreamEvent, None]:
         """
+        Async version of predict_stream - primary implementation for Databricks Apps.
+
         Process a ResponsesAgentRequest and yield ResponsesAgentStreamEvent objects.
+        This method can be used directly with async for loops in async contexts.
 
-        Uses same input/output structure as
+        Uses same input/output structure as apredict() for consistency.
         Supports Human-in-the-Loop (HITL) interrupts.
         """
-
+        from langgraph.types import Command
+
+        # Extract conversation_id for logging
         conversation_id_for_log: str | None = None
         if request.context and hasattr(request.context, "conversation_id"):
             conversation_id_for_log = request.context.conversation_id
         elif request.custom_inputs:
-            # Check configurable or session for conversation_id
             if "configurable" in request.custom_inputs and isinstance(
                 request.custom_inputs["configurable"], dict
             ):
@@ -1135,7 +1117,7 @@ class LanggraphResponsesAgent(ResponsesAgent):
             )
 
         logger.debug(
-            "ResponsesAgent
+            "ResponsesAgent apredict_stream called",
             conversation_id=conversation_id_for_log
             if conversation_id_for_log
             else "new",
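Because apredict_stream() is now an async generator, an async caller (for example a Databricks Apps request handler) can consume it directly with `async for`; a minimal sketch, assuming an already-built agent and request, and assuming MLflow's usual delta/done event type strings:

    async def stream_answer(agent, request):
        # Deltas carry incremental text; the final "response.output_item.done"
        # event carries custom_outputs (session state, HITL interrupts).
        async for event in agent.apredict_stream(request):
            if event.type == "response.output_text.delta":
                print(event.delta, end="", flush=True)
            elif event.type == "response.output_item.done":
                return event.custom_outputs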
@@ -1153,305 +1135,280 @@ class LanggraphResponsesAgent(ResponsesAgent):
         # Extract session state from request
         session_input: dict[str, Any] = self._extract_session_from_request(request)
 
-
-
-
-
-
-        async def _async_stream():
-            item_id: str = f"msg_{uuid.uuid4().hex[:8]}"
-            accumulated_content: str = ""
-            interrupt_data: list[HITLRequest] = []
-            seen_interrupt_ids: set[str] = set()  # Track processed interrupt IDs
-            structured_response: Any = None  # Track structured output from stream
+        item_id: str = f"msg_{uuid.uuid4().hex[:8]}"
+        accumulated_content: str = ""
+        interrupt_data: list[HITLRequest] = []
+        seen_interrupt_ids: set[str] = set()
+        structured_response: Any = None
 
-
-
-
-
-
-
-
-
-
+        try:
+            # Check if this is a resume request (HITL)
+            if request.custom_inputs and "decisions" in request.custom_inputs:
+                decisions: list[Decision] = request.custom_inputs["decisions"]
+                logger.info(
+                    "HITL: Resuming stream with explicit decisions",
+                    decisions_count=len(decisions),
+                )
+                stream_input: Command | dict[str, Any] = Command(
+                    resume={"decisions": decisions}
+                )
+            elif self.graph.checkpointer:
+                snapshot: StateSnapshot = await self.graph.aget_state(
+                    config=custom_inputs
+                )
+                if is_interrupted(snapshot):
                    logger.info(
-                    "HITL:
-                    decisions_count=len(decisions),
+                        "HITL: Graph interrupted, checking for user response in stream"
                    )
-
-
+
+                    message_objects: list[BaseMessage] = convert_openai_messages(
+                        messages
                    )
-
-
-
-
-
+
+                    parsed_result: dict[str, Any] = handle_interrupt_response(
+                        snapshot=snapshot,
+                        messages=message_objects,
+                        model=None,
                    )
-                if is_interrupted(snapshot):
-                    logger.info(
-                        "HITL: Graph interrupted, checking for user response in stream"
-                    )
 
-
-
-
+                    if not parsed_result.get("is_valid", False):
+                        validation_message: str = parsed_result.get(
+                            "validation_message",
+                            "Your response was unclear. Please provide a clear decision for each action.",
                        )
-
-
-
-                        snapshot=snapshot,
-                        messages=message_objects,
-                        model=None,  # Uses default model
+                        logger.warning(
+                            "HITL: Invalid response from user in stream",
+                            validation_message=validation_message,
                        )
 
-
-
-
-
-
-                        )
-                        logger.warning(
-                            "HITL: Invalid response from user in stream",
-                            validation_message=validation_message,
-                        )
-
-                        # Build custom_outputs before returning
-                        custom_outputs: dict[
-                            str, Any
-                        ] = await self._build_custom_outputs_async(
-                            context=context,
-                            thread_id=context.thread_id,
-                        )
-
-                        # Yield error message to user - don't resume graph
-                        error_message: str = (
-                            f"❌ **Invalid Response**\n\n{validation_message}"
-                        )
-                        accumulated_content = error_message
-                        yield ResponsesAgentStreamEvent(
-                            type="response.output_item.done",
-                            item=self.create_text_output_item(
-                                text=error_message, id=item_id
-                            ),
-                            custom_outputs=custom_outputs,
-                        )
-                        return  # Don't resume - stay interrupted
-
-                    decisions: list[Decision] = parsed_result.get("decisions", [])
-                    logger.info(
-                        "HITL: LLM parsed valid decisions from user message in stream",
-                        decisions_count=len(decisions),
+                        custom_outputs: dict[
+                            str, Any
+                        ] = await self._build_custom_outputs_async(
+                            context=context,
+                            thread_id=context.thread_id,
                        )
 
-
-
-                        resume={"decisions": decisions}
+                        error_message: str = (
+                            f"❌ **Invalid Response**\n\n{validation_message}"
                        )
-
-
-
-
-
-
-
-
+                        yield ResponsesAgentStreamEvent(
+                            type="response.output_item.done",
+                            item=self.create_text_output_item(
+                                text=error_message, id=item_id
+                            ),
+                            custom_outputs=custom_outputs,
+                        )
+                        return
+
+                    decisions = parsed_result.get("decisions", [])
+                    logger.info(
+                        "HITL: LLM parsed valid decisions from user message in stream",
+                        decisions_count=len(decisions),
+                    )
+
+                    stream_input = Command(resume={"decisions": decisions})
                else:
-            # No checkpointer, use normal invocation
                    graph_input: dict[str, Any] = {"messages": messages}
                    if "genie_conversation_ids" in session_input:
                        graph_input["genie_conversation_ids"] = session_input[
                            "genie_conversation_ids"
                        ]
-            stream_input
+                    stream_input = graph_input
+            else:
+                graph_input = {"messages": messages}
+                if "genie_conversation_ids" in session_input:
+                    graph_input["genie_conversation_ids"] = session_input[
+                        "genie_conversation_ids"
+                    ]
+                stream_input = graph_input
 
-
-
-
-
-
-
-
-
-
-
-
-            # Handle message streaming
-            if stream_mode == "messages":
-                messages_batch: Sequence[BaseMessage] = data
-                message: BaseMessage
-                for message in messages_batch:
-                    if (
-                        isinstance(
-                            message,
-                            (
-                                AIMessageChunk,
-                                AIMessage,
-                            ),
-                        )
-                        and message.content
-                        and "summarization" not in nodes
-                    ):
-                        content: str = message.content
-                        accumulated_content += content
-
-                        # Yield streaming delta
-                        yield ResponsesAgentStreamEvent(
-                            **self.create_text_delta(
-                                delta=content, item_id=item_id
-                            )
-                        )
+            # Stream the graph execution
+            async for nodes, stream_mode, data in self.graph.astream(
+                stream_input,
+                context=context,
+                config=custom_inputs,
+                stream_mode=["messages", "updates"],
+                subgraphs=True,
+            ):
+                nodes: tuple[str, ...]
+                stream_mode: str
 
-
-
-
-
-
-
-
-
-
-
-                        interrupts_count=len(interrupts),
-                    )
+                if stream_mode == "messages":
+                    messages_batch: Sequence[BaseMessage] = data
+                    for message in messages_batch:
+                        if (
+                            isinstance(message, (AIMessageChunk, AIMessage))
+                            and message.content
+                            and "summarization" not in nodes
+                        ):
+                            content: str = message.content
+                            accumulated_content += content
 
-
-
-
-                        # Only process each unique interrupt once
-                        if interrupt.id not in seen_interrupt_ids:
-                            seen_interrupt_ids.add(interrupt.id)
-                            interrupt_data.append(
-                                _extract_interrupt_value(interrupt)
-                            )
-                            logger.trace(
-                                "HITL: Added interrupt to response",
-                                interrupt_id=interrupt.id,
-                            )
-                    elif (
-                        isinstance(update, dict)
-                        and "structured_response" in update
-                    ):
-                        # Capture structured_response from stream updates
-                        structured_response = update["structured_response"]
-                        logger.trace(
-                            "Captured structured response from stream",
-                            response_type=type(structured_response).__name__,
-                        )
+                            yield ResponsesAgentStreamEvent(
+                                **self.create_text_delta(delta=content, item_id=item_id)
+                            )
 
-
-
-
-
-
-
-
-
-
-                ):
-                    structured_response = final_state.values["structured_response"]
+                elif stream_mode == "updates":
+                    updates: dict[str, Any] = data
+                    for source, update in updates.items():
+                        if source == "__interrupt__":
+                            interrupts: list[Interrupt] = update
+                            logger.info(
+                                "HITL: Interrupts detected during streaming",
+                                interrupts_count=len(interrupts),
+                            )
 
-
-
-
-
+                            for interrupt in interrupts:
+                                if interrupt.id not in seen_interrupt_ids:
+                                    seen_interrupt_ids.add(interrupt.id)
+                                    interrupt_data.append(
+                                        _extract_interrupt_value(interrupt)
+                                    )
+                                    logger.trace(
+                                        "HITL: Added interrupt to response",
+                                        interrupt_id=interrupt.id,
+                                    )
+                        elif (
+                            isinstance(update, dict) and "structured_response" in update
+                        ):
+                            structured_response = update["structured_response"]
+                            logger.trace(
+                                "Captured structured response from stream",
+                                response_type=type(structured_response).__name__,
+                            )
+
+            # Get final state if checkpointer available
+            if self.graph.checkpointer:
+                final_state: StateSnapshot = await self.graph.aget_state(
+                    config=custom_inputs
                )
+                if (
+                    "structured_response" in final_state.values
+                    and not structured_response
+                ):
+                    structured_response = final_state.values["structured_response"]
+
+            # Build custom_outputs
+            custom_outputs = await self._build_custom_outputs_async(
+                context=context,
+                thread_id=context.thread_id,
+            )
 
-
-
-
-
+            # Handle structured_response in streaming
+            output_text: str = accumulated_content
+            if structured_response:
+                from dataclasses import asdict, is_dataclass
 
-
+                from pydantic import BaseModel
 
-
-
-
+                logger.trace(
+                    "Processing structured response in streaming",
+                    response_type=type(structured_response).__name__,
+                )
+
+                if isinstance(structured_response, BaseModel):
+                    serialized: dict[str, Any] = structured_response.model_dump()
+                elif is_dataclass(structured_response):
+                    serialized = asdict(structured_response)
+                elif isinstance(structured_response, dict):
+                    serialized = structured_response
+                else:
+                    serialized = (
+                        dict(structured_response)
+                        if hasattr(structured_response, "__dict__")
+                        else structured_response
                    )
 
-
-                if isinstance(structured_response, BaseModel):
-                    serialized: dict[str, Any] = structured_response.model_dump()
-                elif is_dataclass(structured_response):
-                    serialized = asdict(structured_response)
-                elif isinstance(structured_response, dict):
-                    serialized = structured_response
-                else:
-                    serialized = (
-                        dict(structured_response)
-                        if hasattr(structured_response, "__dict__")
-                        else structured_response
-                    )
+                import json
 
-
-                import json
+                structured_text: str = json.dumps(serialized, indent=2)
 
-
+                if accumulated_content.strip():
+                    yield ResponsesAgentStreamEvent(
+                        **self.create_text_delta(delta="\n\n", item_id=item_id)
+                    )
+                    yield ResponsesAgentStreamEvent(
+                        **self.create_text_delta(delta=structured_text, item_id=item_id)
+                    )
+                    output_text = f"{accumulated_content}\n\n{structured_text}"
+                else:
+                    yield ResponsesAgentStreamEvent(
+                        **self.create_text_delta(delta=structured_text, item_id=item_id)
+                    )
+                    output_text = structured_text
 
-
-
-
-
-
-
+                logger.trace("Streamed structured response in message content")
+
+            # Include interrupt structure if HITL occurred
+            if interrupt_data:
+                custom_outputs["interrupts"] = interrupt_data
+                logger.info(
+                    "HITL: Included interrupts in streaming response",
+                    interrupts_count=len(interrupt_data),
+                )
+
+                action_message = _format_action_requests_message(interrupt_data)
+                if action_message:
+                    if not accumulated_content:
+                        output_text = action_message
                        yield ResponsesAgentStreamEvent(
                            **self.create_text_delta(
-                            delta=
+                                delta=action_message, item_id=item_id
                            )
                        )
-                    output_text = f"{accumulated_content}\n\n{structured_text}"
                    else:
-
+                        output_text = f"{accumulated_content}\n\n{action_message}"
+                        yield ResponsesAgentStreamEvent(
+                            **self.create_text_delta(delta="\n\n", item_id=item_id)
+                        )
                        yield ResponsesAgentStreamEvent(
                            **self.create_text_delta(
-                            delta=
+                                delta=action_message, item_id=item_id
                            )
                        )
-                    output_text = structured_text
 
-
+            # Yield final output item
+            yield ResponsesAgentStreamEvent(
+                type="response.output_item.done",
+                item=self.create_text_output_item(text=output_text, id=item_id),
+                custom_outputs=custom_outputs,
+            )
+        except Exception as e:
+            logger.error("Error in graph streaming", error=str(e))
+            raise
 
-
-
-
-            logger.info(
-                "HITL: Included interrupts in streaming response",
-                interrupts_count=len(interrupt_data),
-            )
+    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
+        """
+        Synchronous wrapper for apredict().
 
-
-
-            if action_message:
-                # If we haven't streamed any content yet, stream the action message
-                if not accumulated_content:
-                    output_text = action_message
-                # Stream the action message
-                yield ResponsesAgentStreamEvent(
-                    **self.create_text_delta(
-                        delta=action_message, item_id=item_id
-                    )
-                )
-            else:
-                # Append action message after accumulated content
-                output_text = f"{accumulated_content}\n\n{action_message}"
-                # Stream the separator and action message
-                yield ResponsesAgentStreamEvent(
-                    **self.create_text_delta(delta="\n\n", item_id=item_id)
-                )
-                yield ResponsesAgentStreamEvent(
-                    **self.create_text_delta(
-                        delta=action_message, item_id=item_id
-                    )
-                )
+        Process a ResponsesAgentRequest and return a ResponsesAgentResponse.
+        For async contexts (e.g., Databricks Apps), use apredict() directly.
 
-
-
-
-
-
-
-
-
-
+        Note: This method uses asyncio.run() internally, which will fail in contexts
+        where an event loop is already running (e.g., uvloop). For those cases,
+        use apredict() instead.
+        """
+        import asyncio
+
+        logger.debug("ResponsesAgent predict called (sync wrapper)")
+        return asyncio.run(self.apredict(request))
+
+    def predict_stream(
+        self, request: ResponsesAgentRequest
+    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
+        """
+        Synchronous wrapper for apredict_stream().
+
+        Process a ResponsesAgentRequest and yield ResponsesAgentStreamEvent objects.
+        For async contexts (e.g., Databricks Apps), use apredict_stream() directly.
+
+        Note: This method converts the async generator to a sync generator using
+        event loop manipulation. For contexts where an event loop is already running
+        (e.g., uvloop), use apredict_stream() instead.
+        """
+        import asyncio
+
+        logger.debug("ResponsesAgent predict_stream called (sync wrapper)")
 
         # Convert async generator to sync generator
         try:
@@ -1460,7 +1417,7 @@ class LanggraphResponsesAgent(ResponsesAgent):
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)
 
-        async_gen =
+        async_gen = self.apredict_stream(request)
 
         try:
             while True: