dao-ai 0.1.2-py3-none-any.whl → 0.1.20-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. dao_ai/apps/__init__.py +24 -0
  2. dao_ai/apps/handlers.py +105 -0
  3. dao_ai/apps/model_serving.py +29 -0
  4. dao_ai/apps/resources.py +1122 -0
  5. dao_ai/apps/server.py +39 -0
  6. dao_ai/cli.py +546 -37
  7. dao_ai/config.py +1179 -139
  8. dao_ai/evaluation.py +543 -0
  9. dao_ai/genie/__init__.py +55 -7
  10. dao_ai/genie/cache/__init__.py +34 -7
  11. dao_ai/genie/cache/base.py +143 -2
  12. dao_ai/genie/cache/context_aware/__init__.py +31 -0
  13. dao_ai/genie/cache/context_aware/base.py +1151 -0
  14. dao_ai/genie/cache/context_aware/in_memory.py +609 -0
  15. dao_ai/genie/cache/context_aware/persistent.py +802 -0
  16. dao_ai/genie/cache/context_aware/postgres.py +1166 -0
  17. dao_ai/genie/cache/core.py +1 -1
  18. dao_ai/genie/cache/lru.py +257 -75
  19. dao_ai/genie/cache/optimization.py +890 -0
  20. dao_ai/genie/core.py +235 -11
  21. dao_ai/memory/postgres.py +175 -39
  22. dao_ai/middleware/__init__.py +38 -0
  23. dao_ai/middleware/assertions.py +3 -3
  24. dao_ai/middleware/context_editing.py +230 -0
  25. dao_ai/middleware/core.py +4 -4
  26. dao_ai/middleware/guardrails.py +3 -3
  27. dao_ai/middleware/human_in_the_loop.py +3 -2
  28. dao_ai/middleware/message_validation.py +4 -4
  29. dao_ai/middleware/model_call_limit.py +77 -0
  30. dao_ai/middleware/model_retry.py +121 -0
  31. dao_ai/middleware/pii.py +157 -0
  32. dao_ai/middleware/summarization.py +1 -1
  33. dao_ai/middleware/tool_call_limit.py +210 -0
  34. dao_ai/middleware/tool_retry.py +174 -0
  35. dao_ai/middleware/tool_selector.py +129 -0
  36. dao_ai/models.py +327 -370
  37. dao_ai/nodes.py +9 -16
  38. dao_ai/orchestration/core.py +33 -9
  39. dao_ai/orchestration/supervisor.py +29 -13
  40. dao_ai/orchestration/swarm.py +6 -1
  41. dao_ai/{prompts.py → prompts/__init__.py} +12 -61
  42. dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
  43. dao_ai/prompts/instruction_reranker.yaml +14 -0
  44. dao_ai/prompts/router.yaml +37 -0
  45. dao_ai/prompts/verifier.yaml +46 -0
  46. dao_ai/providers/base.py +28 -2
  47. dao_ai/providers/databricks.py +363 -33
  48. dao_ai/state.py +1 -0
  49. dao_ai/tools/__init__.py +5 -3
  50. dao_ai/tools/genie.py +103 -26
  51. dao_ai/tools/instructed_retriever.py +366 -0
  52. dao_ai/tools/instruction_reranker.py +202 -0
  53. dao_ai/tools/mcp.py +539 -97
  54. dao_ai/tools/router.py +89 -0
  55. dao_ai/tools/slack.py +13 -2
  56. dao_ai/tools/sql.py +7 -3
  57. dao_ai/tools/unity_catalog.py +32 -10
  58. dao_ai/tools/vector_search.py +493 -160
  59. dao_ai/tools/verifier.py +159 -0
  60. dao_ai/utils.py +182 -2
  61. dao_ai/vector_search.py +46 -1
  62. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/METADATA +45 -9
  63. dao_ai-0.1.20.dist-info/RECORD +89 -0
  64. dao_ai/agent_as_code.py +0 -22
  65. dao_ai/genie/cache/semantic.py +0 -970
  66. dao_ai-0.1.2.dist-info/RECORD +0 -64
  67. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/WHEEL +0 -0
  68. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/entry_points.txt +0 -0
  69. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/licenses/LICENSE +0 -0
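The largest functional change in this release is in dao_ai/models.py, shown in the diff below: LanggraphResponsesAgent becomes async-first, with apredict() and apredict_stream() as the primary implementations and predict()/predict_stream() reduced to thin asyncio.run()-based wrappers. The following sketch illustrates how that new surface might be driven from an async context; it assumes MLflow's ResponsesAgent request/response types (the import path and the `input` payload shape may vary by MLflow version), and the `agent` object is a placeholder rather than code from this package.

# Hedged sketch: exercising the async-first API visible in the dao_ai/models.py diff.
# Assumption: the request type comes from MLflow's ResponsesAgent interface; the
# exact import path and payload field names may differ across MLflow versions.
from mlflow.types.responses import ResponsesAgentRequest


async def drive(agent) -> None:
    """Call the async methods of a LanggraphResponsesAgent (placeholder `agent`)."""
    request = ResponsesAgentRequest(
        # Payload shape follows MLflow's Responses schema (assumed here).
        input=[{"role": "user", "content": "How many stores opened last quarter?"}],
        custom_inputs={"configurable": {"thread_id": "abc-123"}},
    )

    # Non-streaming: apredict() can be awaited directly inside a running event loop
    # (e.g. Databricks Apps / MLflow AgentServer).
    response = await agent.apredict(request)
    print(response.output[-1], response.custom_outputs)

    # Streaming: apredict_stream() is an async generator of ResponsesAgentStreamEvent.
    async for event in agent.apredict_stream(request):
        print(event.type)


# Outside any running event loop, the sync wrappers still work; per the diff they
# call asyncio.run() internally, so they must not be used under uvloop or inside
# an already-running loop. Example: asyncio.run(drive(agent)).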
dao_ai/models.py CHANGED
@@ -1,7 +1,16 @@
 import uuid
 from os import PathLike
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Generator, Literal, Optional, Sequence, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncGenerator,
+    Generator,
+    Literal,
+    Optional,
+    Sequence,
+    Union,
+)
 
 from databricks_langchain import ChatDatabricks
 
@@ -825,13 +834,16 @@ class LanggraphResponsesAgent(ResponsesAgent):
     ) -> None:
         self.graph = graph
 
-    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
+    async def apredict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
         """
+        Async version of predict - primary implementation for Databricks Apps.
+
         Process a ResponsesAgentRequest and return a ResponsesAgentResponse.
+        This method can be awaited directly in async contexts (e.g., MLflow AgentServer).
 
         Input structure (custom_inputs):
             configurable:
-                thread_id: "abc-123"  # Or conversation_id (aliases, conversation_id takes precedence)
+                thread_id: "abc-123"  # Or conversation_id (aliases)
                 user_id: "nate.fleming"
                 store_num: "87887"
             session:  # Paste from previous output
@@ -846,11 +858,11 @@ class LanggraphResponsesAgent(ResponsesAgent):
 
         Output structure (custom_outputs):
             configurable:
-                thread_id: "abc-123"  # Only thread_id in configurable
+                thread_id: "abc-123"
                 user_id: "nate.fleming"
                 store_num: "87887"
             session:
-                conversation_id: "abc-123"  # conversation_id in session
+                conversation_id: "abc-123"
             genie:
                 spaces:
                     space_123: {conversation_id: "conv_456", ...}
@@ -859,12 +871,13 @@ class LanggraphResponsesAgent(ResponsesAgent):
             arguments: {...}
             description: "..."
         """
-        # Extract conversation_id for logging (from context or custom_inputs)
+        from langgraph.types import Command
+
+        # Extract conversation_id for logging
         conversation_id_for_log: str | None = None
         if request.context and hasattr(request.context, "conversation_id"):
             conversation_id_for_log = request.context.conversation_id
         elif request.custom_inputs:
-            # Check configurable or session for conversation_id
             if "configurable" in request.custom_inputs and isinstance(
                 request.custom_inputs["configurable"], dict
             ):
@@ -881,7 +894,7 @@ class LanggraphResponsesAgent(ResponsesAgent):
             )
 
         logger.debug(
-            "ResponsesAgent predict called",
+            "ResponsesAgent apredict called",
             conversation_id=conversation_id_for_log
             if conversation_id_for_log
             else "new",
@@ -899,130 +912,106 @@ class LanggraphResponsesAgent(ResponsesAgent):
         # Extract session state from request
         session_input: dict[str, Any] = self._extract_session_from_request(request)
 
-        # Use async ainvoke internally for parallel execution
-        import asyncio
-
-        from langgraph.types import Command
+        try:
+            # Check if this is a resume request (HITL)
+            if request.custom_inputs and "decisions" in request.custom_inputs:
+                # Explicit structured decisions
+                decisions: list[Decision] = request.custom_inputs["decisions"]
+                logger.info(
+                    "HITL: Resuming with explicit decisions",
+                    decisions_count=len(decisions),
+                )
 
-        async def _async_invoke():
-            try:
-                # Check if this is a resume request (HITL)
-                # Two ways to resume:
-                # 1. Explicit decisions in custom_inputs (structured)
-                # 2. Natural language message when graph is interrupted (LLM-parsed)
-
-                if request.custom_inputs and "decisions" in request.custom_inputs:
-                    # Explicit structured decisions
-                    decisions: list[Decision] = request.custom_inputs["decisions"]
-                    logger.info(
-                        "HITL: Resuming with explicit decisions",
-                        decisions_count=len(decisions),
-                    )
+                # Resume interrupted graph with decisions
+                response = await self.graph.ainvoke(
+                    Command(resume={"decisions": decisions}),
+                    context=context,
+                    config=custom_inputs,
+                )
+            elif self.graph.checkpointer:
+                # Check if graph is currently interrupted
+                snapshot: StateSnapshot = await self.graph.aget_state(
+                    config=custom_inputs
+                )
+                if is_interrupted(snapshot):
+                    logger.info("HITL: Graph interrupted, checking for user response")
 
-                    # Resume interrupted graph with decisions
-                    return await self.graph.ainvoke(
-                        Command(resume={"decisions": decisions}),
-                        context=context,
-                        config=custom_inputs,
+                    # Convert message dicts to BaseMessage objects
+                    message_objects: list[BaseMessage] = convert_openai_messages(
+                        messages
                     )
 
-                # Check if graph is currently interrupted (only if checkpointer is configured)
-                # aget_state requires a checkpointer
-                if self.graph.checkpointer:
-                    snapshot: StateSnapshot = await self.graph.aget_state(
-                        config=custom_inputs
+                    # Parse user's message with LLM to extract decisions
+                    parsed_result: dict[str, Any] = handle_interrupt_response(
+                        snapshot=snapshot,
+                        messages=message_objects,
+                        model=None,
                     )
-                    if is_interrupted(snapshot):
-                        logger.info(
-                            "HITL: Graph interrupted, checking for user response"
-                        )
 
-                        # Convert message dicts to BaseMessage objects
-                        message_objects: list[BaseMessage] = convert_openai_messages(
-                            messages
+                    if not parsed_result.get("is_valid", False):
+                        validation_message: str = parsed_result.get(
+                            "validation_message",
+                            "Your response was unclear. Please provide a clear decision for each action.",
                         )
-
-                        # Parse user's message with LLM to extract decisions
-                        parsed_result: dict[str, Any] = handle_interrupt_response(
-                            snapshot=snapshot,
-                            messages=message_objects,
-                            model=None,  # Uses default model
+                        logger.warning(
+                            "HITL: Invalid response from user",
+                            validation_message=validation_message,
                         )
 
-                        # Check if the response was valid
-                        if not parsed_result.get("is_valid", False):
-                            validation_message: str = parsed_result.get(
-                                "validation_message",
-                                "Your response was unclear. Please provide a clear decision for each action.",
-                            )
-                            logger.warning(
-                                "HITL: Invalid response from user",
-                                validation_message=validation_message,
-                            )
-
-                            # Return error message to user instead of resuming
-                            # Don't resume the graph - stay interrupted so user can try again
-                            return {
-                                "messages": [
-                                    AIMessage(
-                                        content=f"❌ **Invalid Response**\n\n{validation_message}"
-                                    )
-                                ]
-                            }
-
-                        decisions: list[Decision] = parsed_result.get("decisions", [])
+                        # Return error message without resuming
+                        response = {
+                            "messages": [
+                                AIMessage(
+                                    content=f"❌ **Invalid Response**\n\n{validation_message}"
+                                )
+                            ]
+                        }
+                    else:
+                        decisions = parsed_result.get("decisions", [])
                         logger.info(
                             "HITL: LLM parsed valid decisions from user message",
                             decisions_count=len(decisions),
                         )
 
                         # Resume interrupted graph with parsed decisions
-                        return await self.graph.ainvoke(
+                        response = await self.graph.ainvoke(
                             Command(resume={"decisions": decisions}),
                             context=context,
                             config=custom_inputs,
                         )
+                else:
+                    # Normal invocation
+                    graph_input: dict[str, Any] = {"messages": messages}
+                    if "genie_conversation_ids" in session_input:
+                        graph_input["genie_conversation_ids"] = session_input[
+                            "genie_conversation_ids"
+                        ]
 
-                # Normal invocation - build the graph input state
-                graph_input: dict[str, Any] = {"messages": messages}
+                    response = await self.graph.ainvoke(
+                        graph_input, context=context, config=custom_inputs
+                    )
+            else:
+                # No checkpointer, use normal invocation
+                graph_input = {"messages": messages}
                 if "genie_conversation_ids" in session_input:
                     graph_input["genie_conversation_ids"] = session_input[
                         "genie_conversation_ids"
                     ]
-                    logger.trace(
-                        "Including genie conversation IDs in graph input",
-                        count=len(graph_input["genie_conversation_ids"]),
-                    )
 
-                return await self.graph.ainvoke(
+                response = await self.graph.ainvoke(
                     graph_input, context=context, config=custom_inputs
                 )
-            except Exception as e:
-                logger.error("Error in graph invocation", error=str(e))
-                raise
-
-        try:
-            loop = asyncio.get_event_loop()
-        except RuntimeError:
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-
-        try:
-            response: dict[str, Sequence[BaseMessage]] = loop.run_until_complete(
-                _async_invoke()
-            )
         except Exception as e:
-            logger.error("Error in async execution", error=str(e))
+            logger.error("Error in graph invocation", error=str(e))
             raise
 
         # Convert response to ResponsesAgent format
         last_message: BaseMessage = response["messages"][-1]
 
-        # Build custom_outputs that can be copy-pasted as next request's custom_inputs
-        custom_outputs: dict[str, Any] = self._build_custom_outputs(
+        # Build custom_outputs
+        custom_outputs: dict[str, Any] = await self._build_custom_outputs_async(
            context=context,
            thread_id=context.thread_id,
-            loop=loop,
         )
 
         # Handle structured_response if present
@@ -1037,25 +1026,19 @@ class LanggraphResponsesAgent(ResponsesAgent):
             response_type=type(structured_response).__name__,
         )
 
-        # Serialize to dict for JSON compatibility using type hints
         if isinstance(structured_response, BaseModel):
-            # Pydantic model
             serialized: dict[str, Any] = structured_response.model_dump()
         elif is_dataclass(structured_response):
-            # Dataclass
             serialized = asdict(structured_response)
         elif isinstance(structured_response, dict):
-            # Already a dict
             serialized = structured_response
         else:
-            # Unknown type, convert to dict if possible
             serialized = (
                 dict(structured_response)
                 if hasattr(structured_response, "__dict__")
                 else structured_response
             )
 
-        # Place structured output in message content as JSON
         import json
 
         structured_text: str = json.dumps(serialized, indent=2)
@@ -1064,22 +1047,18 @@ class LanggraphResponsesAgent(ResponsesAgent):
             )
             logger.trace("Structured response placed in message content")
         else:
-            # No structured response, use text content
             output_item = self.create_text_output_item(
                 text=last_message.content, id=f"msg_{uuid.uuid4().hex[:8]}"
             )
 
-        # Include interrupt structure if HITL occurred (following LangChain pattern)
+        # Include interrupt structure if HITL occurred
         if "__interrupt__" in response:
             interrupts: list[Interrupt] = response["__interrupt__"]
             logger.info("HITL: Interrupts detected", interrupts_count=len(interrupts))
 
-            # Extract HITLRequest structures from interrupts (deduplicate by ID)
             seen_interrupt_ids: set[str] = set()
             interrupt_data: list[HITLRequest] = []
-            interrupt: Interrupt
             for interrupt in interrupts:
-                # Only process each unique interrupt once
                 if interrupt.id not in seen_interrupt_ids:
                     seen_interrupt_ids.add(interrupt.id)
                     interrupt_data.append(_extract_interrupt_value(interrupt))
@@ -1093,7 +1072,6 @@ class LanggraphResponsesAgent(ResponsesAgent):
                 interrupts_count=len(interrupt_data),
             )
 
-            # Add user-facing message about the pending actions
             action_message: str = _format_action_requests_message(interrupt_data)
             if action_message:
                 output_item = self.create_text_output_item(
@@ -1104,21 +1082,25 @@ class LanggraphResponsesAgent(ResponsesAgent):
             output=[output_item], custom_outputs=custom_outputs
         )
 
-    def predict_stream(
+    async def apredict_stream(
         self, request: ResponsesAgentRequest
-    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
+    ) -> AsyncGenerator[ResponsesAgentStreamEvent, None]:
         """
+        Async version of predict_stream - primary implementation for Databricks Apps.
+
         Process a ResponsesAgentRequest and yield ResponsesAgentStreamEvent objects.
+        This method can be used directly with async for loops in async contexts.
 
-        Uses same input/output structure as predict() for consistency.
+        Uses same input/output structure as apredict() for consistency.
         Supports Human-in-the-Loop (HITL) interrupts.
         """
-        # Extract conversation_id for logging (from context or custom_inputs)
+        from langgraph.types import Command
+
+        # Extract conversation_id for logging
         conversation_id_for_log: str | None = None
         if request.context and hasattr(request.context, "conversation_id"):
             conversation_id_for_log = request.context.conversation_id
         elif request.custom_inputs:
-            # Check configurable or session for conversation_id
             if "configurable" in request.custom_inputs and isinstance(
                 request.custom_inputs["configurable"], dict
             ):
@@ -1135,7 +1117,7 @@ class LanggraphResponsesAgent(ResponsesAgent):
             )
 
         logger.debug(
-            "ResponsesAgent predict_stream called",
+            "ResponsesAgent apredict_stream called",
             conversation_id=conversation_id_for_log
             if conversation_id_for_log
             else "new",
@@ -1153,305 +1135,280 @@ class LanggraphResponsesAgent(ResponsesAgent):
         # Extract session state from request
         session_input: dict[str, Any] = self._extract_session_from_request(request)
 
-        # Use async astream internally for parallel execution
-        import asyncio
-
-        from langgraph.types import Command
-
-        async def _async_stream():
-            item_id: str = f"msg_{uuid.uuid4().hex[:8]}"
-            accumulated_content: str = ""
-            interrupt_data: list[HITLRequest] = []
-            seen_interrupt_ids: set[str] = set()  # Track processed interrupt IDs
-            structured_response: Any = None  # Track structured output from stream
+        item_id: str = f"msg_{uuid.uuid4().hex[:8]}"
+        accumulated_content: str = ""
+        interrupt_data: list[HITLRequest] = []
+        seen_interrupt_ids: set[str] = set()
+        structured_response: Any = None
 
-            try:
-                # Check if this is a resume request (HITL)
-                # Two ways to resume:
-                # 1. Explicit decisions in custom_inputs (structured)
-                # 2. Natural language message when graph is interrupted (LLM-parsed)
-
-                if request.custom_inputs and "decisions" in request.custom_inputs:
-                    # Explicit structured decisions
-                    decisions: list[Decision] = request.custom_inputs["decisions"]
+        try:
+            # Check if this is a resume request (HITL)
+            if request.custom_inputs and "decisions" in request.custom_inputs:
+                decisions: list[Decision] = request.custom_inputs["decisions"]
+                logger.info(
+                    "HITL: Resuming stream with explicit decisions",
+                    decisions_count=len(decisions),
+                )
+                stream_input: Command | dict[str, Any] = Command(
+                    resume={"decisions": decisions}
+                )
+            elif self.graph.checkpointer:
+                snapshot: StateSnapshot = await self.graph.aget_state(
+                    config=custom_inputs
+                )
+                if is_interrupted(snapshot):
                     logger.info(
-                        "HITL: Resuming stream with explicit decisions",
-                        decisions_count=len(decisions),
+                        "HITL: Graph interrupted, checking for user response in stream"
                     )
-                    stream_input: Command | dict[str, Any] = Command(
-                        resume={"decisions": decisions}
+
+                    message_objects: list[BaseMessage] = convert_openai_messages(
+                        messages
                     )
-                elif self.graph.checkpointer:
-                    # Check if graph is currently interrupted (only if checkpointer is configured)
-                    # aget_state requires a checkpointer
-                    snapshot: StateSnapshot = await self.graph.aget_state(
-                        config=custom_inputs
+
+                    parsed_result: dict[str, Any] = handle_interrupt_response(
+                        snapshot=snapshot,
+                        messages=message_objects,
+                        model=None,
                     )
-                    if is_interrupted(snapshot):
-                        logger.info(
-                            "HITL: Graph interrupted, checking for user response in stream"
-                        )
 
-                        # Convert message dicts to BaseMessage objects
-                        message_objects: list[BaseMessage] = convert_openai_messages(
-                            messages
+                    if not parsed_result.get("is_valid", False):
+                        validation_message: str = parsed_result.get(
+                            "validation_message",
+                            "Your response was unclear. Please provide a clear decision for each action.",
                         )
-
-                        # Parse user's message with LLM to extract decisions
-                        parsed_result: dict[str, Any] = handle_interrupt_response(
-                            snapshot=snapshot,
-                            messages=message_objects,
-                            model=None,  # Uses default model
+                        logger.warning(
+                            "HITL: Invalid response from user in stream",
+                            validation_message=validation_message,
                         )
 
-                        # Check if the response was valid
-                        if not parsed_result.get("is_valid", False):
-                            validation_message: str = parsed_result.get(
-                                "validation_message",
-                                "Your response was unclear. Please provide a clear decision for each action.",
-                            )
-                            logger.warning(
-                                "HITL: Invalid response from user in stream",
-                                validation_message=validation_message,
-                            )
-
-                            # Build custom_outputs before returning
-                            custom_outputs: dict[
-                                str, Any
-                            ] = await self._build_custom_outputs_async(
-                                context=context,
-                                thread_id=context.thread_id,
-                            )
-
-                            # Yield error message to user - don't resume graph
-                            error_message: str = (
-                                f"❌ **Invalid Response**\n\n{validation_message}"
-                            )
-                            accumulated_content = error_message
-                            yield ResponsesAgentStreamEvent(
-                                type="response.output_item.done",
-                                item=self.create_text_output_item(
-                                    text=error_message, id=item_id
-                                ),
-                                custom_outputs=custom_outputs,
-                            )
-                            return  # Don't resume - stay interrupted
-
-                        decisions: list[Decision] = parsed_result.get("decisions", [])
-                        logger.info(
-                            "HITL: LLM parsed valid decisions from user message in stream",
-                            decisions_count=len(decisions),
+                        custom_outputs: dict[
+                            str, Any
+                        ] = await self._build_custom_outputs_async(
+                            context=context,
+                            thread_id=context.thread_id,
                         )
 
-                        # Resume interrupted graph with parsed decisions
-                        stream_input: Command | dict[str, Any] = Command(
-                            resume={"decisions": decisions}
+                        error_message: str = (
+                            f"❌ **Invalid Response**\n\n{validation_message}"
                         )
-                    else:
-                        # Graph not interrupted, use normal invocation
-                        graph_input: dict[str, Any] = {"messages": messages}
-                        if "genie_conversation_ids" in session_input:
-                            graph_input["genie_conversation_ids"] = session_input[
-                                "genie_conversation_ids"
-                            ]
-                        stream_input: Command | dict[str, Any] = graph_input
+                        yield ResponsesAgentStreamEvent(
+                            type="response.output_item.done",
+                            item=self.create_text_output_item(
+                                text=error_message, id=item_id
+                            ),
+                            custom_outputs=custom_outputs,
+                        )
+                        return
+
+                    decisions = parsed_result.get("decisions", [])
+                    logger.info(
+                        "HITL: LLM parsed valid decisions from user message in stream",
+                        decisions_count=len(decisions),
+                    )
+
+                    stream_input = Command(resume={"decisions": decisions})
                 else:
-                    # No checkpointer, use normal invocation
                     graph_input: dict[str, Any] = {"messages": messages}
                     if "genie_conversation_ids" in session_input:
                         graph_input["genie_conversation_ids"] = session_input[
                             "genie_conversation_ids"
                         ]
-                    stream_input: Command | dict[str, Any] = graph_input
+                    stream_input = graph_input
+            else:
+                graph_input = {"messages": messages}
+                if "genie_conversation_ids" in session_input:
+                    graph_input["genie_conversation_ids"] = session_input[
+                        "genie_conversation_ids"
+                    ]
+                stream_input = graph_input
 
-                # Stream the graph execution with both messages and updates modes to capture interrupts
-                async for nodes, stream_mode, data in self.graph.astream(
-                    stream_input,
-                    context=context,
-                    config=custom_inputs,
-                    stream_mode=["messages", "updates"],
-                    subgraphs=True,
-                ):
-                    nodes: tuple[str, ...]
-                    stream_mode: str
-
-                    # Handle message streaming
-                    if stream_mode == "messages":
-                        messages_batch: Sequence[BaseMessage] = data
-                        message: BaseMessage
-                        for message in messages_batch:
-                            if (
-                                isinstance(
-                                    message,
-                                    (
-                                        AIMessageChunk,
-                                        AIMessage,
-                                    ),
-                                )
-                                and message.content
-                                and "summarization" not in nodes
-                            ):
-                                content: str = message.content
-                                accumulated_content += content
-
-                                # Yield streaming delta
-                                yield ResponsesAgentStreamEvent(
-                                    **self.create_text_delta(
-                                        delta=content, item_id=item_id
-                                    )
-                                )
+            # Stream the graph execution
+            async for nodes, stream_mode, data in self.graph.astream(
+                stream_input,
+                context=context,
+                config=custom_inputs,
+                stream_mode=["messages", "updates"],
+                subgraphs=True,
+            ):
+                nodes: tuple[str, ...]
+                stream_mode: str
 
-                    # Handle interrupts (HITL) and state updates
-                    elif stream_mode == "updates":
-                        updates: dict[str, Any] = data
-                        source: str
-                        update: Any
-                        for source, update in updates.items():
-                            if source == "__interrupt__":
-                                interrupts: list[Interrupt] = update
-                                logger.info(
-                                    "HITL: Interrupts detected during streaming",
-                                    interrupts_count=len(interrupts),
-                                )
+                if stream_mode == "messages":
+                    messages_batch: Sequence[BaseMessage] = data
+                    for message in messages_batch:
+                        if (
+                            isinstance(message, (AIMessageChunk, AIMessage))
+                            and message.content
+                            and "summarization" not in nodes
+                        ):
+                            content: str = message.content
+                            accumulated_content += content
 
-                                # Extract interrupt values (deduplicate by ID)
-                                interrupt: Interrupt
-                                for interrupt in interrupts:
-                                    # Only process each unique interrupt once
-                                    if interrupt.id not in seen_interrupt_ids:
-                                        seen_interrupt_ids.add(interrupt.id)
-                                        interrupt_data.append(
-                                            _extract_interrupt_value(interrupt)
-                                        )
-                                        logger.trace(
-                                            "HITL: Added interrupt to response",
-                                            interrupt_id=interrupt.id,
-                                        )
-                            elif (
-                                isinstance(update, dict)
-                                and "structured_response" in update
-                            ):
-                                # Capture structured_response from stream updates
-                                structured_response = update["structured_response"]
-                                logger.trace(
-                                    "Captured structured response from stream",
-                                    response_type=type(structured_response).__name__,
-                                )
+                            yield ResponsesAgentStreamEvent(
+                                **self.create_text_delta(delta=content, item_id=item_id)
+                            )
 
-                # Get final state to extract structured_response (only if checkpointer available)
-                if self.graph.checkpointer:
-                    final_state: StateSnapshot = await self.graph.aget_state(
-                        config=custom_inputs
-                    )
-                    # Extract structured_response from state if not already captured
-                    if (
-                        "structured_response" in final_state.values
-                        and not structured_response
-                    ):
-                        structured_response = final_state.values["structured_response"]
+                elif stream_mode == "updates":
+                    updates: dict[str, Any] = data
+                    for source, update in updates.items():
+                        if source == "__interrupt__":
+                            interrupts: list[Interrupt] = update
+                            logger.info(
+                                "HITL: Interrupts detected during streaming",
+                                interrupts_count=len(interrupts),
+                            )
 
-                # Build custom_outputs
-                custom_outputs: dict[str, Any] = await self._build_custom_outputs_async(
-                    context=context,
-                    thread_id=context.thread_id,
+                            for interrupt in interrupts:
+                                if interrupt.id not in seen_interrupt_ids:
+                                    seen_interrupt_ids.add(interrupt.id)
+                                    interrupt_data.append(
+                                        _extract_interrupt_value(interrupt)
+                                    )
+                                    logger.trace(
+                                        "HITL: Added interrupt to response",
+                                        interrupt_id=interrupt.id,
+                                    )
+                        elif (
+                            isinstance(update, dict) and "structured_response" in update
+                        ):
+                            structured_response = update["structured_response"]
+                            logger.trace(
+                                "Captured structured response from stream",
+                                response_type=type(structured_response).__name__,
+                            )
+
+            # Get final state if checkpointer available
+            if self.graph.checkpointer:
+                final_state: StateSnapshot = await self.graph.aget_state(
+                    config=custom_inputs
                 )
+                if (
+                    "structured_response" in final_state.values
+                    and not structured_response
+                ):
+                    structured_response = final_state.values["structured_response"]
+
+            # Build custom_outputs
+            custom_outputs = await self._build_custom_outputs_async(
+                context=context,
+                thread_id=context.thread_id,
+            )
 
-                # Handle structured_response in streaming if present
-                output_text: str = accumulated_content
-                if structured_response:
-                    from dataclasses import asdict, is_dataclass
+            # Handle structured_response in streaming
+            output_text: str = accumulated_content
+            if structured_response:
+                from dataclasses import asdict, is_dataclass
 
-                    from pydantic import BaseModel
+                from pydantic import BaseModel
 
-                    logger.trace(
-                        "Processing structured response in streaming",
-                        response_type=type(structured_response).__name__,
+                logger.trace(
+                    "Processing structured response in streaming",
+                    response_type=type(structured_response).__name__,
+                )
+
+                if isinstance(structured_response, BaseModel):
+                    serialized: dict[str, Any] = structured_response.model_dump()
+                elif is_dataclass(structured_response):
+                    serialized = asdict(structured_response)
+                elif isinstance(structured_response, dict):
+                    serialized = structured_response
+                else:
+                    serialized = (
+                        dict(structured_response)
+                        if hasattr(structured_response, "__dict__")
+                        else structured_response
                     )
 
-                    # Serialize to dict for JSON compatibility using type hints
-                    if isinstance(structured_response, BaseModel):
-                        serialized: dict[str, Any] = structured_response.model_dump()
-                    elif is_dataclass(structured_response):
-                        serialized = asdict(structured_response)
-                    elif isinstance(structured_response, dict):
-                        serialized = structured_response
-                    else:
-                        serialized = (
-                            dict(structured_response)
-                            if hasattr(structured_response, "__dict__")
-                            else structured_response
-                        )
+                import json
 
-                    # Place structured output in message content - stream as JSON
-                    import json
+                structured_text: str = json.dumps(serialized, indent=2)
 
-                    structured_text: str = json.dumps(serialized, indent=2)
+                if accumulated_content.strip():
+                    yield ResponsesAgentStreamEvent(
+                        **self.create_text_delta(delta="\n\n", item_id=item_id)
+                    )
+                    yield ResponsesAgentStreamEvent(
+                        **self.create_text_delta(delta=structured_text, item_id=item_id)
+                    )
+                    output_text = f"{accumulated_content}\n\n{structured_text}"
+                else:
+                    yield ResponsesAgentStreamEvent(
+                        **self.create_text_delta(delta=structured_text, item_id=item_id)
+                    )
+                    output_text = structured_text
 
-                    # If we streamed text, append structured; if no text, use structured only
-                    if accumulated_content.strip():
-                        # Stream separator and structured output
-                        yield ResponsesAgentStreamEvent(
-                            **self.create_text_delta(delta="\n\n", item_id=item_id)
-                        )
+                logger.trace("Streamed structured response in message content")
+
+            # Include interrupt structure if HITL occurred
+            if interrupt_data:
+                custom_outputs["interrupts"] = interrupt_data
+                logger.info(
+                    "HITL: Included interrupts in streaming response",
+                    interrupts_count=len(interrupt_data),
+                )
+
+                action_message = _format_action_requests_message(interrupt_data)
+                if action_message:
+                    if not accumulated_content:
+                        output_text = action_message
                         yield ResponsesAgentStreamEvent(
                             **self.create_text_delta(
-                                delta=structured_text, item_id=item_id
+                                delta=action_message, item_id=item_id
                             )
                         )
-                        output_text = f"{accumulated_content}\n\n{structured_text}"
                     else:
-                        # No text content, stream structured output
+                        output_text = f"{accumulated_content}\n\n{action_message}"
+                        yield ResponsesAgentStreamEvent(
+                            **self.create_text_delta(delta="\n\n", item_id=item_id)
+                        )
                         yield ResponsesAgentStreamEvent(
                             **self.create_text_delta(
-                                delta=structured_text, item_id=item_id
+                                delta=action_message, item_id=item_id
                             )
                         )
-                        output_text = structured_text
 
-                    logger.trace("Streamed structured response in message content")
+            # Yield final output item
+            yield ResponsesAgentStreamEvent(
+                type="response.output_item.done",
+                item=self.create_text_output_item(text=output_text, id=item_id),
+                custom_outputs=custom_outputs,
+            )
+        except Exception as e:
+            logger.error("Error in graph streaming", error=str(e))
+            raise
 
-                # Include interrupt structure if HITL occurred
-                if interrupt_data:
-                    custom_outputs["interrupts"] = interrupt_data
-                    logger.info(
-                        "HITL: Included interrupts in streaming response",
-                        interrupts_count=len(interrupt_data),
-                    )
+    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
+        """
+        Synchronous wrapper for apredict().
 
-                    # Add user-facing message about the pending actions
-                    action_message = _format_action_requests_message(interrupt_data)
-                    if action_message:
-                        # If we haven't streamed any content yet, stream the action message
-                        if not accumulated_content:
-                            output_text = action_message
-                            # Stream the action message
-                            yield ResponsesAgentStreamEvent(
-                                **self.create_text_delta(
-                                    delta=action_message, item_id=item_id
-                                )
-                            )
-                        else:
-                            # Append action message after accumulated content
-                            output_text = f"{accumulated_content}\n\n{action_message}"
-                            # Stream the separator and action message
-                            yield ResponsesAgentStreamEvent(
-                                **self.create_text_delta(delta="\n\n", item_id=item_id)
-                            )
-                            yield ResponsesAgentStreamEvent(
-                                **self.create_text_delta(
-                                    delta=action_message, item_id=item_id
-                                )
-                            )
+        Process a ResponsesAgentRequest and return a ResponsesAgentResponse.
+        For async contexts (e.g., Databricks Apps), use apredict() directly.
 
-                # Yield final output item
-                yield ResponsesAgentStreamEvent(
-                    type="response.output_item.done",
-                    item=self.create_text_output_item(text=output_text, id=item_id),
-                    custom_outputs=custom_outputs,
-                )
-            except Exception as e:
-                logger.error("Error in graph streaming", error=str(e))
-                raise
+        Note: This method uses asyncio.run() internally, which will fail in contexts
+        where an event loop is already running (e.g., uvloop). For those cases,
+        use apredict() instead.
+        """
+        import asyncio
+
+        logger.debug("ResponsesAgent predict called (sync wrapper)")
+        return asyncio.run(self.apredict(request))
+
+    def predict_stream(
+        self, request: ResponsesAgentRequest
+    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
+        """
+        Synchronous wrapper for apredict_stream().
+
+        Process a ResponsesAgentRequest and yield ResponsesAgentStreamEvent objects.
+        For async contexts (e.g., Databricks Apps), use apredict_stream() directly.
+
+        Note: This method converts the async generator to a sync generator using
+        event loop manipulation. For contexts where an event loop is already running
+        (e.g., uvloop), use apredict_stream() instead.
+        """
+        import asyncio
+
+        logger.debug("ResponsesAgent predict_stream called (sync wrapper)")
 
         # Convert async generator to sync generator
         try:
@@ -1460,7 +1417,7 @@ class LanggraphResponsesAgent(ResponsesAgent):
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)
 
-        async_gen = _async_stream()
+        async_gen = self.apredict_stream(request)
 
         try:
             while True: