dao-ai 0.0.28__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +2 -5
- dao_ai/cli.py +245 -40
- dao_ai/config.py +1491 -370
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +232 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +245 -159
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +278 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +573 -601
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -294
- dao_ai/tools/mcp.py +223 -155
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +331 -221
- dao_ai/utils.py +166 -20
- dao_ai-0.1.2.dist-info/METADATA +455 -0
- dao_ai-0.1.2.dist-info/RECORD +64 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.28.dist-info/METADATA +0 -1168
- dao_ai-0.0.28.dist-info/RECORD +0 -41
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/models.py
CHANGED
@@ -1,11 +1,34 @@
 import uuid
 from os import PathLike
 from pathlib import Path
-from typing import Any, Generator, Optional, Sequence, Union
-
-from
+from typing import TYPE_CHECKING, Any, Generator, Literal, Optional, Sequence, Union
+
+from databricks_langchain import ChatDatabricks
+
+if TYPE_CHECKING:
+    pass
+
+# Import official LangChain HITL TypedDict definitions
+# Reference: https://docs.langchain.com/oss/python/langchain/human-in-the-loop
+from langchain.agents.middleware.human_in_the_loop import (
+    ActionRequest,
+    Decision,
+    EditDecision,
+    HITLRequest,
+    RejectDecision,
+    ReviewConfig,
+)
+from langchain_community.adapters.openai import convert_openai_messages
+from langchain_core.language_models import LanguageModelLike
+from langchain_core.messages import (
+    AIMessage,
+    AIMessageChunk,
+    BaseMessage,
+    HumanMessage,
+    SystemMessage,
+)
 from langgraph.graph.state import CompiledStateGraph
-from langgraph.types import StateSnapshot
+from langgraph.types import Interrupt, StateSnapshot
 from loguru import logger
 from mlflow import MlflowClient
 from mlflow.pyfunc import ChatAgent, ChatModel, ResponsesAgent
@@ -28,11 +51,13 @@ from mlflow.types.responses_helpers import (
     Message,
     ResponseInputTextParam,
 )
+from pydantic import BaseModel, Field, create_model
 
 from dao_ai.messages import (
     has_langchain_messages,
     has_mlflow_messages,
     has_mlflow_responses_messages,
+    last_human_message,
 )
 from dao_ai.state import Context
 
@@ -54,12 +79,37 @@ def get_latest_model_version(model_name: str) -> int:
     mlflow_client: MlflowClient = MlflowClient()
     latest_version: int = 1
     for mv in mlflow_client.search_model_versions(f"name='{model_name}'"):
-        version_int = int(mv.version)
+        version_int: int = int(mv.version)
         if version_int > latest_version:
             latest_version = version_int
     return latest_version
 
 
+def is_interrupted(snapshot: StateSnapshot) -> bool:
+    """
+    Check if the graph state is currently interrupted (paused for human-in-the-loop).
+
+    Based on LangChain documentation:
+    - StateSnapshot has an `interrupts` attribute which is a tuple
+    - When interrupted, the tuple contains Interrupt objects
+    - When not interrupted, it's an empty tuple ()
+
+    Args:
+        snapshot: The StateSnapshot to check
+
+    Returns:
+        True if the graph is interrupted (has pending HITL actions), False otherwise
+
+    Example:
+        >>> snapshot = await graph.aget_state(config)
+        >>> if is_interrupted(snapshot):
+        ...     print("Graph is waiting for human input")
+    """
+    # Check if snapshot has any interrupts
+    # According to LangChain docs, interrupts is a tuple that's empty () when no interrupts
+    return bool(snapshot.interrupts)
+
+
 async def get_state_snapshot_async(
     graph: CompiledStateGraph, thread_id: str
 ) -> Optional[StateSnapshot]:
@@ -76,11 +126,11 @@ async def get_state_snapshot_async(
     Returns:
         StateSnapshot if found, None otherwise
     """
-    logger.
+    logger.trace("Retrieving state snapshot", thread_id=thread_id)
     try:
         # Check if graph has a checkpointer
         if graph.checkpointer is None:
-            logger.
+            logger.trace("No checkpointer available in graph")
             return None
 
         # Get the current state from the checkpointer (use async version)
@@ -88,13 +138,15 @@
         state_snapshot: Optional[StateSnapshot] = await graph.aget_state(config)
 
         if state_snapshot is None:
-            logger.
+            logger.trace("No state found for thread", thread_id=thread_id)
             return None
 
         return state_snapshot
 
     except Exception as e:
-        logger.warning(
+        logger.warning(
+            "Error retrieving state snapshot", thread_id=thread_id, error=str(e)
+        )
         return None
 
 
@@ -125,7 +177,7 @@ def get_state_snapshot(
     try:
         return loop.run_until_complete(get_state_snapshot_async(graph, thread_id))
     except Exception as e:
-        logger.warning(
+        logger.warning("Error in synchronous state snapshot retrieval", error=str(e))
         return None
 
 
@@ -157,16 +209,125 @@ def get_genie_conversation_ids_from_state(
         )
 
         if genie_conversation_ids:
-            logger.
+            logger.trace(
+                "Retrieved genie conversation IDs", count=len(genie_conversation_ids)
+            )
             return genie_conversation_ids
 
         return {}
 
     except Exception as e:
-        logger.warning(
+        logger.warning(
+            "Error extracting genie conversation IDs from state", error=str(e)
+        )
         return {}
 
 
+def _extract_interrupt_value(interrupt: Interrupt) -> HITLRequest:
+    """
+    Extract the HITL request from a LangGraph Interrupt object.
+
+    Following LangChain patterns, the Interrupt object has a .value attribute
+    containing the HITLRequest structure with action_requests and review_configs.
+
+    Args:
+        interrupt: Interrupt object from LangGraph with .value and .id attributes
+
+    Returns:
+        HITLRequest with action_requests and review_configs
+    """
+    # Interrupt.value is typed as Any, but for HITL it should be a HITLRequest dict
+    if isinstance(interrupt.value, dict):
+        # Return as HITLRequest TypedDict
+        return interrupt.value  # type: ignore[return-value]
+
+    # Fallback: return empty structure if value is not a dict
+    return {"action_requests": [], "review_configs": []}
+
+
+def _format_action_requests_message(interrupt_data: list[HITLRequest]) -> str:
+    """
+    Format action requests from interrupts into a simple, user-friendly message.
+
+    Since we now use LLM-based parsing, users can respond in natural language.
+    This function just shows WHAT actions are pending, not HOW to respond.
+
+    Args:
+        interrupt_data: List of HITLRequest structures containing action_requests and review_configs
+
+    Returns:
+        Simple formatted message describing the pending actions
+    """
+    if not interrupt_data:
+        return ""
+
+    # Collect all action requests and review configs from all interrupts
+    all_actions: list[ActionRequest] = []
+    review_configs_map: dict[str, ReviewConfig] = {}
+
+    for hitl_request in interrupt_data:
+        all_actions.extend(hitl_request.get("action_requests", []))
+        for review_config in hitl_request.get("review_configs", []):
+            action_name = review_config.get("action_name", "")
+            if action_name:
+                review_configs_map[action_name] = review_config
+
+    if not all_actions:
+        return ""
+
+    # Build simple, clean message
+    lines = ["⏸️ **Action Approval Required**", ""]
+    lines.append(
+        f"The assistant wants to perform {len(all_actions)} action(s) that require your approval:"
+    )
+    lines.append("")
+
+    for i, action in enumerate(all_actions, 1):
+        tool_name = action.get("name", "unknown")
+        args = action.get("args", {})
+        description = action.get("description")
+
+        lines.append(f"**{i}. {tool_name}**")
+
+        # Show review prompt/description if available
+        if description:
+            lines.append(f" • **Review:** {description}")
+
+        if args:
+            # Format args nicely, truncating long values
+            for key, value in args.items():
+                value_str = str(value)
+                if len(value_str) > 100:
+                    value_str = value_str[:100] + "..."
+                lines.append(f" • {key}: `{value_str}`")
+        else:
+            lines.append(" • (no arguments)")
+
+        # Show allowed decisions
+        review_config = review_configs_map.get(tool_name)
+        if review_config:
+            allowed_decisions = review_config.get("allowed_decisions", [])
+            if allowed_decisions:
+                decisions_str = ", ".join(allowed_decisions)
+                lines.append(f" • **Options:** {decisions_str}")
+
+        lines.append("")
+
+    lines.append("---")
+    lines.append("")
+    lines.append(
+        "**You can respond in natural language** (e.g., 'approve both', 'reject the first one', "
+        "'change the email to new@example.com')"
+    )
+    lines.append("")
+    lines.append(
+        "Or provide structured decisions in `custom_inputs` with key `decisions`: "
+        '`[{"type": "approve"}, {"type": "reject", "message": "reason"}]`'
+    )
+
+    return "\n".join(lines)
+
+
 class LanggraphChatModel(ChatModel):
     """
     ChatModel that delegates requests to a LangGraph CompiledStateGraph.
@@ -178,7 +339,11 @@ class LanggraphChatModel(ChatModel):
     def predict(
         self, context, messages: list[ChatMessage], params: Optional[ChatParams] = None
     ) -> ChatCompletionResponse:
-        logger.
+        logger.trace(
+            "Predict called",
+            messages_count=len(messages),
+            has_params=params is not None,
+        )
         if not messages:
             raise ValueError("Message list is empty.")
 
@@ -200,7 +365,10 @@ class LanggraphChatModel(ChatModel):
             _async_invoke()
         )
 
-        logger.trace(
+        logger.trace(
+            "Predict response received",
+            messages_count=len(response.get("messages", [])),
+        )
 
         last_message: BaseMessage = response["messages"][-1]
 
@@ -216,28 +384,43 @@ class LanggraphChatModel(ChatModel):
 
         configurable: dict[str, Any] = {}
         if "configurable" in input_data:
-            configurable
+            configurable = input_data.pop("configurable")
         if "custom_inputs" in input_data:
            custom_inputs: dict[str, Any] = input_data.pop("custom_inputs")
            if "configurable" in custom_inputs:
-                configurable
-
-
-
-
-
-
-
-
-
-
-
-
+                configurable = custom_inputs.pop("configurable")
+
+        # Extract known Context fields
+        user_id: str | None = configurable.pop("user_id", None)
+        if user_id:
+            user_id = user_id.replace(".", "_")
+
+        # Accept either thread_id or conversation_id (interchangeable)
+        # conversation_id takes precedence (Databricks vocabulary)
+        thread_id: str | None = configurable.pop("thread_id", None)
+        conversation_id: str | None = configurable.pop("conversation_id", None)
+
+        # conversation_id takes precedence if both provided
+        if conversation_id:
+            thread_id = conversation_id
+        if not thread_id:
+            thread_id = str(uuid.uuid4())
+
+        # All remaining configurable values become top-level context attributes
+        return Context(
+            user_id=user_id,
+            thread_id=thread_id,
+            **configurable,  # Extra fields become top-level attributes
+        )
 
     def predict_stream(
         self, context, messages: list[ChatMessage], params: ChatParams
     ) -> Generator[ChatCompletionChunk, None, None]:
-        logger.
+        logger.trace(
+            "Predict stream called",
+            messages_count=len(messages),
+            has_params=params is not None,
+        )
         if not messages:
             raise ValueError("Message list is empty.")
 
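As a standalone illustration of the ID-handling rules in the hunk above (this is not the package's own code, and the values are made up): `conversation_id` takes precedence over `thread_id`, a missing ID falls back to a fresh UUID, and dots in `user_id` are normalized to underscores.

    # Sketch of the configurable precedence rules above; all values are hypothetical.
    import uuid

    configurable = {"conversation_id": "c-1", "thread_id": "t-9", "user_id": "jane.doe"}

    user_id = configurable.pop("user_id", None)
    if user_id:
        user_id = user_id.replace(".", "_")   # "jane.doe" -> "jane_doe"

    thread_id = configurable.pop("thread_id", None)
    conversation_id = configurable.pop("conversation_id", None)
    if conversation_id:                        # conversation_id wins over thread_id
        thread_id = conversation_id            # -> "c-1"
    if not thread_id:
        thread_id = str(uuid.uuid4())          # fresh thread when neither is supplied

    print(user_id, thread_id)                  # jane_doe c-1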
@@ -261,7 +444,10 @@ class LanggraphChatModel(ChatModel):
                 stream_mode: str
                 messages_batch: Sequence[BaseMessage]
                 logger.trace(
-
+                    "Stream batch received",
+                    nodes=nodes,
+                    stream_mode=stream_mode,
+                    messages_count=len(messages_batch),
                 )
                 for message in messages_batch:
                     if (
@@ -307,6 +493,324 @@ class LanggraphChatModel(ChatModel):
         return [m.to_dict() for m in messages]
 
 
+def _create_decision_schema(interrupt_data: list[HITLRequest]) -> type[BaseModel]:
+    """
+    Dynamically create a Pydantic model for structured output based on interrupt actions.
+
+    This creates a schema that matches the expected decision format for the interrupted actions.
+    Each action gets a corresponding decision field that can be approve, edit, or reject.
+    Includes validation fields to ensure the response is complete and valid.
+
+    Args:
+        interrupt_data: List of HITL interrupt requests containing action_requests and review_configs
+
+    Returns:
+        A dynamically created Pydantic BaseModel class for structured output
+
+    Example:
+        For two actions (send_email, execute_sql), creates a model like:
+        class Decisions(BaseModel):
+            is_valid: bool
+            validation_message: Optional[str]
+            decision_1: Literal["approve", "edit", "reject"]
+            decision_1_message: Optional[str]  # For reject
+            decision_1_edited_args: Optional[dict]  # For edit
+            decision_2: Literal["approve", "edit", "reject"]
+            ...
+    """
+    # Collect all actions
+    all_actions: list[ActionRequest] = []
+    review_configs_map: dict[str, ReviewConfig] = {}
+
+    for hitl_request in interrupt_data:
+        all_actions.extend(hitl_request.get("action_requests", []))
+        review_config: ReviewConfig
+        for review_config in hitl_request.get("review_configs", []):
+            action_name: str = review_config.get("action_name", "")
+            if action_name:
+                review_configs_map[action_name] = review_config
+
+    # Build fields for the dynamic model
+    # Start with validation fields
+    fields: dict[str, Any] = {
+        "is_valid": (
+            bool,
+            Field(
+                description="Whether the user's response provides valid decisions for ALL actions. "
+                "Set to False if the user's message is unclear, ambiguous, or doesn't provide decisions for all actions."
+            ),
+        ),
+        "validation_message": (
+            Optional[str],
+            Field(
+                None,
+                description="If is_valid is False, explain what is missing or unclear. "
+                "Be specific about which action(s) need clarification.",
+            ),
+        ),
+    }
+
+    i: int
+    action: ActionRequest
+    for i, action in enumerate(all_actions, 1):
+        tool_name: str = action.get("name", "unknown")
+        review_config: Optional[ReviewConfig] = review_configs_map.get(tool_name)
+        allowed_decisions: list[str] = (
+            review_config.get("allowed_decisions", ["approve", "reject"])
+            if review_config
+            else ["approve", "reject"]
+        )
+
+        # Create a Literal type for allowed decisions
+        decision_literal: type = Literal[tuple(allowed_decisions)]  # type: ignore
+
+        # Add decision field
+        fields[f"decision_{i}"] = (
+            decision_literal,
+            Field(
+                description=f"Decision for action {i} ({tool_name}): {', '.join(allowed_decisions)}"
+            ),
+        )
+
+        # Add optional message field for reject
+        if "reject" in allowed_decisions:
+            fields[f"decision_{i}_message"] = (
+                Optional[str],
+                Field(
+                    None,
+                    description=f"Optional message if rejecting action {i}",
+                ),
+            )
+
+        # Add optional edited_args field for edit
+        if "edit" in allowed_decisions:
+            fields[f"decision_{i}_edited_args"] = (
+                Optional[dict[str, Any]],
+                Field(
+                    None,
+                    description=f"Modified arguments if editing action {i}. Only provide fields that need to change.",
+                ),
+            )
+
+    # Create the dynamic model
+    DecisionsModel = create_model(
+        "InterruptDecisions",
+        __doc__="Decisions for each interrupted action, in order.",
+        **fields,
+    )
+
+    return DecisionsModel
+
+
+def _convert_schema_to_decisions(
+    parsed_output: BaseModel,
+    interrupt_data: list[HITLRequest],
+) -> list[Decision]:
+    """
+    Convert the parsed structured output into LangChain Decision objects.
+
+    Args:
+        parsed_output: The Pydantic model instance from structured output
+        interrupt_data: Original interrupt data for context
+
+    Returns:
+        List of Decision dictionaries compatible with Command(resume={"decisions": ...})
+    """
+    # Collect all actions to know how many decisions we need
+    all_actions: list[ActionRequest] = []
+    hitl_request: HITLRequest
+    for hitl_request in interrupt_data:
+        all_actions.extend(hitl_request.get("action_requests", []))
+
+    decisions: list[Decision] = []
+
+    i: int
+    for i in range(1, len(all_actions) + 1):
+        decision_type: str = getattr(parsed_output, f"decision_{i}")
+
+        if decision_type == "approve":
+            decisions.append({"type": "approve"})  # type: ignore
+        elif decision_type == "reject":
+            message: Optional[str] = getattr(
+                parsed_output, f"decision_{i}_message", None
+            )
+            reject_decision: RejectDecision = {"type": "reject"}
+            if message:
+                reject_decision["message"] = message
+            decisions.append(reject_decision)  # type: ignore
+        elif decision_type == "edit":
+            edited_args: Optional[dict[str, Any]] = getattr(
+                parsed_output, f"decision_{i}_edited_args", None
+            )
+            action: ActionRequest = all_actions[i - 1]
+            tool_name: str = action.get("name", "")
+            original_args: dict[str, Any] = action.get("args", {})
+
+            # Merge original args with edited args
+            final_args: dict[str, Any] = {**original_args, **(edited_args or {})}
+
+            edit_decision: EditDecision = {
+                "type": "edit",
+                "edited_action": {
+                    "name": tool_name,
+                    "args": final_args,
+                },
+            }
+            decisions.append(edit_decision)  # type: ignore
+
+    return decisions
+
+
+def handle_interrupt_response(
+    snapshot: StateSnapshot,
+    messages: list[BaseMessage],
+    model: Optional[LanguageModelLike] = None,
+) -> dict[str, Any]:
+    """
+    Parse user's natural language response to interrupts using LLM with structured output.
+
+    This function uses an LLM to understand the user's intent and extract structured decisions
+    for each pending action. The schema is dynamically created based on the pending actions.
+    Includes validation to ensure the response is complete and valid.
+
+    Args:
+        snapshot: The current state snapshot containing interrupts
+        messages: List of messages, from which the last human message will be extracted
+        model: Optional LLM to use for parsing. Defaults to Llama 3.1 70B
+
+    Returns:
+        Dictionary with:
+        - "is_valid": bool indicating if the response is valid
+        - "validation_message": Optional message if invalid, explaining what's missing
+        - "decisions": list of Decision objects (empty if invalid)
+
+    Example:
+        Valid: {"is_valid": True, "validation_message": None, "decisions": [{"type": "approve"}]}
+        Invalid: {"is_valid": False, "validation_message": "Please specify...", "decisions": []}
+    """
+    # Extract the last human message
+    user_message_obj: Optional[HumanMessage] = last_human_message(messages)
+
+    if not user_message_obj:
+        logger.warning("HITL: No human message found in interrupt response")
+        return {
+            "is_valid": False,
+            "validation_message": "No user message found. Please provide a response to the pending action(s).",
+            "decisions": [],
+        }
+
+    user_message: str = str(user_message_obj.content)
+    logger.info(
+        "HITL: Parsing user interrupt response", message_preview=user_message[:100]
+    )
+
+    if not model:
+        model = ChatDatabricks(
+            endpoint="databricks-claude-sonnet-4",
+            temperature=0,
+        )
+
+    # Extract interrupt data
+    if not snapshot.interrupts:
+        logger.warning("HITL: No interrupts found in snapshot")
+        return {"decisions": []}
+
+    interrupt_data: list[HITLRequest] = [
+        _extract_interrupt_value(interrupt) for interrupt in snapshot.interrupts
+    ]
+
+    # Collect all actions for context
+    all_actions: list[ActionRequest] = []
+    hitl_request: HITLRequest
+    for hitl_request in interrupt_data:
+        all_actions.extend(hitl_request.get("action_requests", []))
+
+    if not all_actions:
+        logger.warning("HITL: No actions found in interrupts")
+        return {"decisions": []}
+
+    # Create dynamic schema
+    DecisionsModel: type[BaseModel] = _create_decision_schema(interrupt_data)
+
+    # Create structured LLM
+    structured_llm: LanguageModelLike = model.with_structured_output(DecisionsModel)
+
+    # Format action context for the LLM
+    action_descriptions: list[str] = []
+    i: int
+    action: ActionRequest
+    for i, action in enumerate(all_actions, 1):
+        tool_name: str = action.get("name", "unknown")
+        args: dict[str, Any] = action.get("args", {})
+        args_str: str = (
+            ", ".join(f"{k}={v}" for k, v in args.items()) if args else "(no args)"
+        )
+        action_descriptions.append(f"Action {i}: {tool_name}({args_str})")
+
+    system_prompt: str = f"""You are parsing a user's response to interrupted agent actions.
+
+The following actions are pending approval:
+{chr(10).join(action_descriptions)}
+
+Your task is to extract the user's decision for EACH action in order. The user may:
+- Approve: Accept the action as-is
+- Reject: Cancel the action (optionally with a reason/message)
+- Edit: Modify the arguments before executing
+
+VALIDATION:
+- Set is_valid=True only if you can confidently extract decisions for ALL actions
+- Set is_valid=False if the user's message is:
+  * Unclear or ambiguous
+  * Missing decisions for some actions
+  * Asking a question instead of providing decisions
+  * Not addressing the actions at all
+- If is_valid=False, provide a clear validation_message explaining what is needed
+
+FLEXIBILITY:
+- Be flexible in parsing informal language like "yes", "no", "ok", "change X to Y"
+- If the user doesn't explicitly mention an action, assume they want to approve it
+- Only mark as invalid if the message is genuinely unclear or incomplete"""
+
+    try:
+        # Invoke LLM with structured output
+        parsed: BaseModel = structured_llm.invoke(
+            [
+                SystemMessage(content=system_prompt),
+                HumanMessage(content=user_message),
+            ]
+        )
+
+        # Check validation first
+        is_valid: bool = getattr(parsed, "is_valid", True)
+        validation_message: Optional[str] = getattr(parsed, "validation_message", None)
+
+        if not is_valid:
+            logger.warning(
+                "HITL: Invalid user response", reason=validation_message or "Unknown"
+            )
+            return {
+                "is_valid": False,
+                "validation_message": validation_message
+                or "Your response was unclear. Please provide a clear decision for each action.",
+                "decisions": [],
+            }
+
+        # Convert to Decision format
+        decisions: list[Decision] = _convert_schema_to_decisions(parsed, interrupt_data)
+
+        logger.info("HITL: Parsed interrupt decisions", decisions_count=len(decisions))
+        return {"is_valid": True, "validation_message": None, "decisions": decisions}
+
+    except Exception as e:
+        logger.error("HITL: Failed to parse interrupt response", error=str(e))
+        # Return invalid response on parsing failure
+        return {
+            "is_valid": False,
+            "validation_message": f"Failed to parse your response: {str(e)}. Please provide a clear decision for each action.",
+            "decisions": [],
+        }
+
+
 class LanggraphResponsesAgent(ResponsesAgent):
     """
     ResponsesAgent that delegates requests to a LangGraph CompiledStateGraph.
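The helpers above consume decisions in the LangChain HITL format, and the predict() docstring in the next hunk describes where a caller places them. As a rough, illustrative sketch only (not part of the package; the thread ID, user, message, and rejection reason are placeholders taken from the docstring examples), a structured resume request could carry a payload like this:

    # Hypothetical resume payload for an interrupted (HITL) run; identifiers are placeholders.
    import json

    resume_request = {
        "input": [{"role": "user", "content": "Looks good, go ahead."}],
        "custom_inputs": {
            "configurable": {"thread_id": "abc-123", "user_id": "nate.fleming"},
            # One decision per pending action, in order; omit "decisions" entirely to let
            # the LLM-based parser interpret the natural-language message instead.
            "decisions": [
                {"type": "approve"},
                {"type": "reject", "message": "Not authorized"},
            ],
        },
    }
    print(json.dumps(resume_request, indent=2))

Per the hunk above, an edit decision additionally carries an `edited_action` with the tool `name` and the merged `args`.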
@@ -315,38 +819,191 @@ class LanggraphResponsesAgent(ResponsesAgent):
     support for streaming, tool calling, and async execution.
     """
 
-    def __init__(
+    def __init__(
+        self,
+        graph: CompiledStateGraph,
+    ) -> None:
         self.graph = graph
 
     def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
         """
         Process a ResponsesAgentRequest and return a ResponsesAgentResponse.
+
+        Input structure (custom_inputs):
+            configurable:
+                thread_id: "abc-123"  # Or conversation_id (aliases, conversation_id takes precedence)
+                user_id: "nate.fleming"
+                store_num: "87887"
+            session:  # Paste from previous output
+                conversation_id: "abc-123"  # Alias of thread_id
+                genie:
+                    spaces:
+                        space_123: {conversation_id: "conv_456", ...}
+            decisions:  # For resuming interrupted graphs (HITL)
+                - type: "approve"
+                - type: "reject"
+                  message: "Not authorized"
+
+        Output structure (custom_outputs):
+            configurable:
+                thread_id: "abc-123"  # Only thread_id in configurable
+                user_id: "nate.fleming"
+                store_num: "87887"
+            session:
+                conversation_id: "abc-123"  # conversation_id in session
+                genie:
+                    spaces:
+                        space_123: {conversation_id: "conv_456", ...}
+            pending_actions:  # If HITL interrupt occurred
+                - name: "send_email"
+                  arguments: {...}
+                  description: "..."
         """
-
+        # Extract conversation_id for logging (from context or custom_inputs)
+        conversation_id_for_log: str | None = None
+        if request.context and hasattr(request.context, "conversation_id"):
+            conversation_id_for_log = request.context.conversation_id
+        elif request.custom_inputs:
+            # Check configurable or session for conversation_id
+            if "configurable" in request.custom_inputs and isinstance(
+                request.custom_inputs["configurable"], dict
+            ):
+                conversation_id_for_log = request.custom_inputs["configurable"].get(
+                    "conversation_id"
+                )
+            if (
+                conversation_id_for_log is None
+                and "session" in request.custom_inputs
+                and isinstance(request.custom_inputs["session"], dict)
+            ):
+                conversation_id_for_log = request.custom_inputs["session"].get(
+                    "conversation_id"
+                )
+
+        logger.debug(
+            "ResponsesAgent predict called",
+            conversation_id=conversation_id_for_log
+            if conversation_id_for_log
+            else "new",
+        )
 
         # Convert ResponsesAgent input to LangChain messages
-        messages = self._convert_request_to_langchain_messages(
+        messages: list[dict[str, Any]] = self._convert_request_to_langchain_messages(
             request
         )
-        # Prepare context
+        # Prepare context (conversation_id -> thread_id mapping happens here)
         context: Context = self._convert_request_to_context(request)
         custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}
 
+        # Extract session state from request
+        session_input: dict[str, Any] = self._extract_session_from_request(request)
+
         # Use async ainvoke internally for parallel execution
         import asyncio
 
+        from langgraph.types import Command
+
         async def _async_invoke():
             try:
+                # Check if this is a resume request (HITL)
+                # Two ways to resume:
+                # 1. Explicit decisions in custom_inputs (structured)
+                # 2. Natural language message when graph is interrupted (LLM-parsed)
+
+                if request.custom_inputs and "decisions" in request.custom_inputs:
+                    # Explicit structured decisions
+                    decisions: list[Decision] = request.custom_inputs["decisions"]
+                    logger.info(
+                        "HITL: Resuming with explicit decisions",
+                        decisions_count=len(decisions),
+                    )
+
+                    # Resume interrupted graph with decisions
+                    return await self.graph.ainvoke(
+                        Command(resume={"decisions": decisions}),
+                        context=context,
+                        config=custom_inputs,
+                    )
+
+                # Check if graph is currently interrupted (only if checkpointer is configured)
+                # aget_state requires a checkpointer
+                if self.graph.checkpointer:
+                    snapshot: StateSnapshot = await self.graph.aget_state(
+                        config=custom_inputs
+                    )
+                    if is_interrupted(snapshot):
+                        logger.info(
+                            "HITL: Graph interrupted, checking for user response"
+                        )
+
+                        # Convert message dicts to BaseMessage objects
+                        message_objects: list[BaseMessage] = convert_openai_messages(
+                            messages
+                        )
+
+                        # Parse user's message with LLM to extract decisions
+                        parsed_result: dict[str, Any] = handle_interrupt_response(
+                            snapshot=snapshot,
+                            messages=message_objects,
+                            model=None,  # Uses default model
+                        )
+
+                        # Check if the response was valid
+                        if not parsed_result.get("is_valid", False):
+                            validation_message: str = parsed_result.get(
+                                "validation_message",
+                                "Your response was unclear. Please provide a clear decision for each action.",
+                            )
+                            logger.warning(
+                                "HITL: Invalid response from user",
+                                validation_message=validation_message,
+                            )
+
+                            # Return error message to user instead of resuming
+                            # Don't resume the graph - stay interrupted so user can try again
+                            return {
+                                "messages": [
+                                    AIMessage(
+                                        content=f"❌ **Invalid Response**\n\n{validation_message}"
+                                    )
+                                ]
+                            }
+
+                        decisions: list[Decision] = parsed_result.get("decisions", [])
+                        logger.info(
+                            "HITL: LLM parsed valid decisions from user message",
+                            decisions_count=len(decisions),
+                        )
+
+                        # Resume interrupted graph with parsed decisions
+                        return await self.graph.ainvoke(
+                            Command(resume={"decisions": decisions}),
+                            context=context,
+                            config=custom_inputs,
+                        )
+
+                # Normal invocation - build the graph input state
+                graph_input: dict[str, Any] = {"messages": messages}
+                if "genie_conversation_ids" in session_input:
+                    graph_input["genie_conversation_ids"] = session_input[
+                        "genie_conversation_ids"
+                    ]
+                    logger.trace(
+                        "Including genie conversation IDs in graph input",
+                        count=len(graph_input["genie_conversation_ids"]),
+                    )
+
                 return await self.graph.ainvoke(
-
+                    graph_input, context=context, config=custom_inputs
                 )
             except Exception as e:
-                logger.error(
+                logger.error("Error in graph invocation", error=str(e))
                 raise
 
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
-            # Handle case where no event loop exists (common in some deployment scenarios)
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
 
@@ -355,28 +1012,93 @@ class LanggraphResponsesAgent(ResponsesAgent):
                 _async_invoke()
             )
         except Exception as e:
-            logger.error(
+            logger.error("Error in async execution", error=str(e))
             raise
 
         # Convert response to ResponsesAgent format
         last_message: BaseMessage = response["messages"][-1]
 
-
-
+        # Build custom_outputs that can be copy-pasted as next request's custom_inputs
+        custom_outputs: dict[str, Any] = self._build_custom_outputs(
+            context=context,
+            thread_id=context.thread_id,
+            loop=loop,
         )
 
-        #
-
-
-
-
-
+        # Handle structured_response if present
+        if "structured_response" in response:
+            from dataclasses import asdict, is_dataclass
+
+            from pydantic import BaseModel
+
+            structured_response = response["structured_response"]
+            logger.trace(
+                "Processing structured response",
+                response_type=type(structured_response).__name__,
             )
-
-
+
+            # Serialize to dict for JSON compatibility using type hints
+            if isinstance(structured_response, BaseModel):
+                # Pydantic model
+                serialized: dict[str, Any] = structured_response.model_dump()
+            elif is_dataclass(structured_response):
+                # Dataclass
+                serialized = asdict(structured_response)
+            elif isinstance(structured_response, dict):
+                # Already a dict
+                serialized = structured_response
+            else:
+                # Unknown type, convert to dict if possible
+                serialized = (
+                    dict(structured_response)
+                    if hasattr(structured_response, "__dict__")
+                    else structured_response
+                )
+
+            # Place structured output in message content as JSON
+            import json
+
+            structured_text: str = json.dumps(serialized, indent=2)
+            output_item = self.create_text_output_item(
+                text=structured_text, id=f"msg_{uuid.uuid4().hex[:8]}"
             )
-
-
+            logger.trace("Structured response placed in message content")
+        else:
+            # No structured response, use text content
+            output_item = self.create_text_output_item(
+                text=last_message.content, id=f"msg_{uuid.uuid4().hex[:8]}"
+            )
+
+        # Include interrupt structure if HITL occurred (following LangChain pattern)
+        if "__interrupt__" in response:
+            interrupts: list[Interrupt] = response["__interrupt__"]
+            logger.info("HITL: Interrupts detected", interrupts_count=len(interrupts))
+
+            # Extract HITLRequest structures from interrupts (deduplicate by ID)
+            seen_interrupt_ids: set[str] = set()
+            interrupt_data: list[HITLRequest] = []
+            interrupt: Interrupt
+            for interrupt in interrupts:
+                # Only process each unique interrupt once
+                if interrupt.id not in seen_interrupt_ids:
+                    seen_interrupt_ids.add(interrupt.id)
+                    interrupt_data.append(_extract_interrupt_value(interrupt))
+                    logger.trace(
+                        "HITL: Added interrupt to response", interrupt_id=interrupt.id
+                    )
+
+            custom_outputs["interrupts"] = interrupt_data
+            logger.debug(
+                "HITL: Included interrupts in response",
+                interrupts_count=len(interrupt_data),
+            )
+
+            # Add user-facing message about the pending actions
+            action_message: str = _format_action_requests_message(interrupt_data)
+            if action_message:
+                output_item = self.create_text_output_item(
+                    text=action_message, id=f"msg_{uuid.uuid4().hex[:8]}"
+                )
 
         return ResponsesAgentResponse(
             output=[output_item], custom_outputs=custom_outputs
@@ -387,90 +1109,354 @@ class LanggraphResponsesAgent(ResponsesAgent):
|
|
|
387
1109
|
) -> Generator[ResponsesAgentStreamEvent, None, None]:
|
|
388
1110
|
"""
|
|
389
1111
|
Process a ResponsesAgentRequest and yield ResponsesAgentStreamEvent objects.
|
|
1112
|
+
|
|
1113
|
+
Uses same input/output structure as predict() for consistency.
|
|
1114
|
+
Supports Human-in-the-Loop (HITL) interrupts.
|
|
390
1115
|
"""
|
|
391
|
-
|
|
1116
|
+
# Extract conversation_id for logging (from context or custom_inputs)
|
|
1117
|
+
conversation_id_for_log: str | None = None
|
|
1118
|
+
if request.context and hasattr(request.context, "conversation_id"):
|
|
1119
|
+
conversation_id_for_log = request.context.conversation_id
|
|
1120
|
+
elif request.custom_inputs:
|
|
1121
|
+
# Check configurable or session for conversation_id
|
|
1122
|
+
if "configurable" in request.custom_inputs and isinstance(
|
|
1123
|
+
request.custom_inputs["configurable"], dict
|
|
1124
|
+
):
|
|
1125
|
+
conversation_id_for_log = request.custom_inputs["configurable"].get(
|
|
1126
|
+
"conversation_id"
|
|
1127
|
+
)
|
|
1128
|
+
if (
|
|
1129
|
+
conversation_id_for_log is None
|
|
1130
|
+
and "session" in request.custom_inputs
|
|
1131
|
+
and isinstance(request.custom_inputs["session"], dict)
|
|
1132
|
+
):
|
|
1133
|
+
conversation_id_for_log = request.custom_inputs["session"].get(
|
|
1134
|
+
"conversation_id"
|
|
1135
|
+
)
|
|
1136
|
+
|
|
1137
|
+
logger.debug(
|
|
1138
|
+
"ResponsesAgent predict_stream called",
|
|
1139
|
+
conversation_id=conversation_id_for_log
|
|
1140
|
+
if conversation_id_for_log
|
|
1141
|
+
else "new",
|
|
1142
|
+
)
|
|
392
1143
|
|
|
393
1144
|
# Convert ResponsesAgent input to LangChain messages
|
|
394
|
-
messages: list[
|
|
1145
|
+
messages: list[dict[str, Any]] = self._convert_request_to_langchain_messages(
|
|
395
1146
|
request
|
|
396
1147
|
)
|
|
397
1148
|
|
|
398
|
-
# Prepare context
|
|
1149
|
+
# Prepare context (conversation_id -> thread_id mapping happens here)
|
|
399
1150
|
context: Context = self._convert_request_to_context(request)
|
|
400
1151
|
custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}
|
|
401
1152
|
|
|
1153
|
+
# Extract session state from request
|
|
1154
|
+
session_input: dict[str, Any] = self._extract_session_from_request(request)
|
|
1155
|
+
|
|
402
1156
|
# Use async astream internally for parallel execution
|
|
403
1157
|
import asyncio
|
|
404
1158
|
|
|
1159
|
+
from langgraph.types import Command
|
|
1160
|
+
|
|
405
1161
|
async def _async_stream():
|
|
406
|
-
item_id = f"msg_{uuid.uuid4().hex[:8]}"
|
|
407
|
-
accumulated_content = ""
|
|
1162
|
+
item_id: str = f"msg_{uuid.uuid4().hex[:8]}"
|
|
1163
|
+
accumulated_content: str = ""
|
|
1164
|
+
interrupt_data: list[HITLRequest] = []
|
|
1165
|
+
seen_interrupt_ids: set[str] = set() # Track processed interrupt IDs
|
|
1166
|
+
structured_response: Any = None # Track structured output from stream
|
|
408
1167
|
|
|
409
1168
|
try:
|
|
410
|
-
|
|
411
|
-
|
|
1169
|
+
# Check if this is a resume request (HITL)
|
|
1170
|
+
# Two ways to resume:
|
|
1171
|
+
# 1. Explicit decisions in custom_inputs (structured)
|
|
1172
|
+
# 2. Natural language message when graph is interrupted (LLM-parsed)
|
|
1173
|
+
|
|
1174
|
+
if request.custom_inputs and "decisions" in request.custom_inputs:
|
|
1175
|
+
# Explicit structured decisions
|
|
1176
|
+
decisions: list[Decision] = request.custom_inputs["decisions"]
|
|
1177
|
+
logger.info(
|
|
1178
|
+
"HITL: Resuming stream with explicit decisions",
|
|
1179
|
+
decisions_count=len(decisions),
|
|
1180
|
+
)
|
|
1181
|
+
stream_input: Command | dict[str, Any] = Command(
|
|
1182
|
+
resume={"decisions": decisions}
|
|
1183
|
+
)
|
|
1184
|
+
elif self.graph.checkpointer:
|
|
1185
|
+
# Check if graph is currently interrupted (only if checkpointer is configured)
|
|
1186
|
+
# aget_state requires a checkpointer
|
|
1187
|
+
snapshot: StateSnapshot = await self.graph.aget_state(
|
|
1188
|
+
config=custom_inputs
|
|
1189
|
+
)
|
|
1190
|
+
if is_interrupted(snapshot):
|
|
1191
|
+
logger.info(
|
|
1192
|
+
"HITL: Graph interrupted, checking for user response in stream"
|
|
1193
|
+
)
|
|
1194
|
+
|
|
1195
|
+
# Convert message dicts to BaseMessage objects
|
|
1196
|
+
message_objects: list[BaseMessage] = convert_openai_messages(
|
|
1197
|
+
messages
|
|
1198
|
+
)
|
|
1199
|
+
|
|
1200
|
+
# Parse user's message with LLM to extract decisions
|
|
1201
|
+
parsed_result: dict[str, Any] = handle_interrupt_response(
|
|
1202
|
+
snapshot=snapshot,
|
|
1203
|
+
messages=message_objects,
|
|
1204
|
+
model=None, # Uses default model
|
|
1205
|
+
)
|
|
1206
|
+
|
|
1207
|
+
# Check if the response was valid
|
|
1208
|
+
if not parsed_result.get("is_valid", False):
|
|
1209
|
+
validation_message: str = parsed_result.get(
|
|
1210
|
+
"validation_message",
|
|
1211
|
+
"Your response was unclear. Please provide a clear decision for each action.",
|
|
1212
|
+
)
|
|
1213
|
+
logger.warning(
|
|
1214
|
+
"HITL: Invalid response from user in stream",
|
|
1215
|
+
validation_message=validation_message,
|
|
1216
|
+
)
|
|
1217
|
+
|
|
1218
|
+
# Build custom_outputs before returning
|
|
1219
|
+
custom_outputs: dict[
|
|
1220
|
+
str, Any
|
|
1221
|
+
] = await self._build_custom_outputs_async(
|
|
1222
|
+
context=context,
|
|
1223
|
+
thread_id=context.thread_id,
|
|
1224
|
+
)
|
|
1225
|
+
|
|
1226
|
+
# Yield error message to user - don't resume graph
|
|
1227
|
+
error_message: str = (
|
|
1228
|
+
f"❌ **Invalid Response**\n\n{validation_message}"
|
|
1229
|
+
)
|
|
1230
|
+
accumulated_content = error_message
|
|
1231
|
+
yield ResponsesAgentStreamEvent(
|
|
1232
|
+
type="response.output_item.done",
|
|
1233
|
+
item=self.create_text_output_item(
|
|
1234
|
+
text=error_message, id=item_id
|
|
1235
|
+
),
|
|
1236
|
+
custom_outputs=custom_outputs,
|
|
1237
|
+
)
|
|
1238
|
+
return # Don't resume - stay interrupted
|
|
1239
|
+
|
|
1240
|
+
decisions: list[Decision] = parsed_result.get("decisions", [])
|
|
1241
|
+
logger.info(
|
|
1242
|
+
"HITL: LLM parsed valid decisions from user message in stream",
|
|
1243
|
+
decisions_count=len(decisions),
|
|
1244
|
+
)
|
|
1245
|
+
|
|
1246
|
+
# Resume interrupted graph with parsed decisions
|
|
1247
|
+
stream_input: Command | dict[str, Any] = Command(
|
|
1248
|
+
resume={"decisions": decisions}
|
|
1249
|
+
)
|
|
1250
|
+
else:
|
|
1251
|
+
# Graph not interrupted, use normal invocation
|
|
1252
|
+
graph_input: dict[str, Any] = {"messages": messages}
|
|
1253
|
+
if "genie_conversation_ids" in session_input:
|
|
1254
|
+
graph_input["genie_conversation_ids"] = session_input[
|
|
1255
|
+
"genie_conversation_ids"
|
|
1256
|
+
]
|
|
1257
|
+
stream_input: Command | dict[str, Any] = graph_input
|
|
1258
|
+
else:
|
|
1259
|
+
# No checkpointer, use normal invocation
|
|
1260
|
+
graph_input: dict[str, Any] = {"messages": messages}
|
|
1261
|
+
if "genie_conversation_ids" in session_input:
|
|
1262
|
+
graph_input["genie_conversation_ids"] = session_input[
|
|
1263
|
+
"genie_conversation_ids"
|
|
1264
|
+
]
|
|
1265
|
+
stream_input: Command | dict[str, Any] = graph_input
|
|
1266
|
+
|
|
1267
|
+
# Stream the graph execution with both messages and updates modes to capture interrupts
|
|
1268
|
+
async for nodes, stream_mode, data in self.graph.astream(
|
|
1269
|
+
stream_input,
|
|
412
1270
|
context=context,
|
|
413
1271
|
config=custom_inputs,
|
|
414
|
-
stream_mode=["messages", "
|
|
1272
|
+
stream_mode=["messages", "updates"],
|
|
415
1273
|
subgraphs=True,
|
|
416
1274
|
):
|
|
417
1275
|
nodes: tuple[str, ...]
|
|
418
1276
|
stream_mode: str
|
|
419
|
-
messages_batch: Sequence[BaseMessage]
|
|
420
1277
|
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
1278
|
+
# Handle message streaming
|
|
1279
|
+
if stream_mode == "messages":
|
|
1280
|
+
messages_batch: Sequence[BaseMessage] = data
|
|
1281
|
+
message: BaseMessage
|
|
1282
|
+
for message in messages_batch:
|
|
1283
|
+
if (
|
|
1284
|
+
isinstance(
|
|
1285
|
+
message,
|
|
1286
|
+
(
|
|
1287
|
+
AIMessageChunk,
|
|
1288
|
+
AIMessage,
|
|
1289
|
+
),
|
|
1290
|
+
)
|
|
1291
|
+
and message.content
|
|
1292
|
+
and "summarization" not in nodes
|
|
1293
|
+
):
|
|
1294
|
+
content: str = message.content
|
|
1295
|
+
accumulated_content += content
|
|
1296
|
+
|
|
1297
|
+
# Yield streaming delta
|
|
1298
|
+
yield ResponsesAgentStreamEvent(
|
|
1299
|
+
**self.create_text_delta(
|
|
1300
|
+
delta=content, item_id=item_id
|
|
1301
|
+
)
|
|
1302
|
+
)
|
|
1303
|
+
|
|
1304
|
+
# Handle interrupts (HITL) and state updates
|
|
1305
|
+
elif stream_mode == "updates":
|
|
1306
|
+
updates: dict[str, Any] = data
|
|
1307
|
+
source: str
|
|
1308
|
+
update: Any
|
|
1309
|
+
for source, update in updates.items():
|
|
1310
|
+
if source == "__interrupt__":
|
|
1311
|
+
interrupts: list[Interrupt] = update
|
|
1312
|
+
logger.info(
|
|
1313
|
+
"HITL: Interrupts detected during streaming",
|
|
1314
|
+
interrupts_count=len(interrupts),
|
|
1315
|
+
)
|
|
1316
|
+
|
|
1317
|
+
# Extract interrupt values (deduplicate by ID)
|
|
1318
|
+
interrupt: Interrupt
|
|
1319
|
+
for interrupt in interrupts:
|
|
1320
|
+
# Only process each unique interrupt once
|
|
1321
|
+
if interrupt.id not in seen_interrupt_ids:
|
|
1322
|
+
seen_interrupt_ids.add(interrupt.id)
|
|
1323
|
+
interrupt_data.append(
|
|
1324
|
+
_extract_interrupt_value(interrupt)
|
|
1325
|
+
)
|
|
1326
|
+
logger.trace(
|
|
1327
|
+
"HITL: Added interrupt to response",
|
|
1328
|
+
interrupt_id=interrupt.id,
|
|
1329
|
+
)
|
|
1330
|
+
elif (
|
|
1331
|
+
isinstance(update, dict)
|
|
1332
|
+
and "structured_response" in update
|
|
1333
|
+
):
|
|
1334
|
+
# Capture structured_response from stream updates
|
|
1335
|
+
structured_response = update["structured_response"]
|
|
1336
|
+
logger.trace(
|
|
1337
|
+
"Captured structured response from stream",
|
|
1338
|
+
response_type=type(structured_response).__name__,
|
|
1339
|
+
)
|
|
1340
|
+
|
|
1341
|
+
# Get final state to extract structured_response (only if checkpointer available)
|
|
1342
|
+
if self.graph.checkpointer:
|
|
1343
|
+
final_state: StateSnapshot = await self.graph.aget_state(
|
|
1344
|
+
config=custom_inputs
|
|
1345
|
+
)
|
|
1346
|
+
# Extract structured_response from state if not already captured
|
|
1347
|
+
if (
|
|
1348
|
+
"structured_response" in final_state.values
|
|
1349
|
+
and not structured_response
|
|
1350
|
+
):
|
|
1351
|
+
structured_response = final_state.values["structured_response"]
|
|
435
1352
|
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
1353
|
+
# Build custom_outputs
|
|
1354
|
+
custom_outputs: dict[str, Any] = await self._build_custom_outputs_async(
|
|
1355
|
+
context=context,
|
|
1356
|
+
thread_id=context.thread_id,
|
|
1357
|
+
)
|
|
440
1358
|
|
|
441
|
-
#
|
|
442
|
-
|
|
443
|
-
|
|
1359
|
+
# Handle structured_response in streaming if present
|
|
1360
|
+
output_text: str = accumulated_content
|
|
1361
|
+
if structured_response:
|
|
1362
|
+
from dataclasses import asdict, is_dataclass
|
|
444
1363
|
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
get_genie_conversation_ids_from_state(state_snapshot)
|
|
1364
|
+
from pydantic import BaseModel
|
|
1365
|
+
|
|
1366
|
+
logger.trace(
|
|
1367
|
+
"Processing structured response in streaming",
|
|
1368
|
+
response_type=type(structured_response).__name__,
|
|
451
1369
|
)
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
1370
|
+
|
|
1371
|
+
# Serialize to dict for JSON compatibility using type hints
|
|
1372
|
+
if isinstance(structured_response, BaseModel):
|
|
1373
|
+
serialized: dict[str, Any] = structured_response.model_dump()
|
|
1374
|
+
elif is_dataclass(structured_response):
|
|
1375
|
+
serialized = asdict(structured_response)
|
|
1376
|
+
elif isinstance(structured_response, dict):
|
|
1377
|
+
serialized = structured_response
|
|
1378
|
+
else:
|
|
1379
|
+
serialized = (
|
|
1380
|
+
dict(structured_response)
|
|
1381
|
+
if hasattr(structured_response, "__dict__")
|
|
1382
|
+
else structured_response
|
|
1383
|
+
)
|
|
1384
|
+
|
|
1385
|
+
# Place structured output in message content - stream as JSON
|
|
1386
|
+
import json
|
|
1387
|
+
|
|
1388
|
+
structured_text: str = json.dumps(serialized, indent=2)
|
|
1389
|
+
|
|
1390
|
+
# If we streamed text, append structured; if no text, use structured only
|
|
1391
|
+
if accumulated_content.strip():
|
|
1392
|
+
# Stream separator and structured output
|
|
1393
|
+
yield ResponsesAgentStreamEvent(
|
|
1394
|
+
**self.create_text_delta(delta="\n\n", item_id=item_id)
|
|
1395
|
+
)
|
|
1396
|
+
yield ResponsesAgentStreamEvent(
|
|
1397
|
+
**self.create_text_delta(
|
|
1398
|
+
delta=structured_text, item_id=item_id
|
|
1399
|
+
)
|
|
1400
|
+
)
|
|
1401
|
+
output_text = f"{accumulated_content}\n\n{structured_text}"
|
|
1402
|
+
else:
|
|
1403
|
+
# No text content, stream structured output
|
|
1404
|
+
yield ResponsesAgentStreamEvent(
|
|
1405
|
+
**self.create_text_delta(
|
|
1406
|
+
delta=structured_text, item_id=item_id
|
|
1407
|
+
)
|
|
455
1408
|
)
|
|
1409
|
+
output_text = structured_text
|
|
1410
|
+
|
|
1411
|
+
logger.trace("Streamed structured response in message content")
|
|
1412
|
+
|
|
1413
|
+
# Include interrupt structure if HITL occurred
|
|
1414
|
+
if interrupt_data:
|
|
1415
|
+
custom_outputs["interrupts"] = interrupt_data
|
|
1416
|
+
logger.info(
|
|
1417
|
+
"HITL: Included interrupts in streaming response",
|
|
1418
|
+
interrupts_count=len(interrupt_data),
|
|
1419
|
+
)
|
|
1420
|
+
|
|
1421
|
+
# Add user-facing message about the pending actions
|
|
1422
|
+
action_message = _format_action_requests_message(interrupt_data)
|
|
1423
|
+
if action_message:
|
|
1424
|
+
# If we haven't streamed any content yet, stream the action message
|
|
1425
|
+
if not accumulated_content:
|
|
1426
|
+
output_text = action_message
|
|
1427
|
+
# Stream the action message
|
|
1428
|
+
yield ResponsesAgentStreamEvent(
|
|
1429
|
+
**self.create_text_delta(
|
|
1430
|
+
delta=action_message, item_id=item_id
|
|
1431
|
+
)
|
|
1432
|
+
)
|
|
1433
|
+
else:
|
|
1434
|
+
# Append action message after accumulated content
|
|
1435
|
+
output_text = f"{accumulated_content}\n\n{action_message}"
|
|
1436
|
+
# Stream the separator and action message
|
|
1437
|
+
yield ResponsesAgentStreamEvent(
|
|
1438
|
+
**self.create_text_delta(delta="\n\n", item_id=item_id)
|
|
1439
|
+
)
|
|
1440
|
+
yield ResponsesAgentStreamEvent(
|
|
1441
|
+
**self.create_text_delta(
|
|
1442
|
+
delta=action_message, item_id=item_id
|
|
1443
|
+
)
|
|
1444
|
+
)
|
|
456
1445
|
|
|
457
1446
|
# Yield final output item
|
|
458
1447
|
yield ResponsesAgentStreamEvent(
|
|
459
1448
|
type="response.output_item.done",
|
|
460
|
-
item=self.create_text_output_item(
|
|
461
|
-
text=accumulated_content, id=item_id
|
|
462
|
-
),
|
|
1449
|
+
item=self.create_text_output_item(text=output_text, id=item_id),
|
|
463
1450
|
custom_outputs=custom_outputs,
|
|
464
1451
|
)
|
|
465
1452
|
except Exception as e:
|
|
466
|
-
logger.error(
|
|
1453
|
+
logger.error("Error in graph streaming", error=str(e))
|
|
467
1454
|
raise
|
|
468
1455
|
|
|
469
1456
|
# Convert async generator to sync generator
|
|
470
1457
|
try:
|
|
471
1458
|
loop = asyncio.get_event_loop()
|
|
472
1459
|
except RuntimeError:
|
|
473
|
-
# Handle case where no event loop exists (common in some deployment scenarios)
|
|
474
1460
|
loop = asyncio.new_event_loop()
|
|
475
1461
|
asyncio.set_event_loop(loop)
|
|
476
1462
|
|
|
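The serialization chain added above (Pydantic model_dump, then dataclasses.asdict, then a plain dict pass-through, then a __dict__ fallback) can be exercised on its own. A minimal, self-contained sketch, assuming Pydantic v2; the WeatherReport and Summary types are made-up stand-ins for whatever structured response type an agent is configured with, and the final fallback here uses vars() rather than dict() so the sketch stays runnable for arbitrary objects:

import json
from dataclasses import asdict, dataclass, is_dataclass
from typing import Any

from pydantic import BaseModel


@dataclass
class WeatherReport:  # hypothetical dataclass response type
    city: str
    temp_c: float


class Summary(BaseModel):  # hypothetical Pydantic response type
    title: str
    bullets: list[str]


def serialize_structured_response(structured_response: Any) -> str:
    # Same precedence as the streaming code: Pydantic -> dataclass -> dict -> __dict__.
    if isinstance(structured_response, BaseModel):
        serialized: Any = structured_response.model_dump()
    elif is_dataclass(structured_response):
        serialized = asdict(structured_response)
    elif isinstance(structured_response, dict):
        serialized = structured_response
    else:
        serialized = (
            vars(structured_response)
            if hasattr(structured_response, "__dict__")
            else structured_response
        )
    return json.dumps(serialized, indent=2)


print(serialize_structured_response(Summary(title="Q3 recap", bullets=["revenue up"])))
print(serialize_structured_response(WeatherReport(city="Oslo", temp_c=3.5)))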
@@ -484,13 +1470,13 @@ class LanggraphResponsesAgent(ResponsesAgent):
 except StopAsyncIteration:
 break
 except Exception as e:
-logger.error(
+logger.error("Error in streaming", error=str(e))
 raise
 finally:
 try:
 loop.run_until_complete(async_gen.aclose())
 except Exception as e:
-logger.warning(
+logger.warning("Error closing async generator", error=str(e))

 def _extract_text_from_content(
 self,
@@ -555,15 +1541,27 @@ class LanggraphResponsesAgent(ResponsesAgent):
 return messages

 def _convert_request_to_context(self, request: ResponsesAgentRequest) -> Context:
-"""Convert ResponsesAgent context to internal Context.
+"""Convert ResponsesAgent context to internal Context.
+
+Handles the input structure:
+- custom_inputs.configurable: Configuration (thread_id, user_id, store_num, etc.)
+- custom_inputs.session: Accumulated state (conversation_id, genie conversations, etc.)

-
-
+Maps conversation_id -> thread_id for LangGraph compatibility.
+conversation_id can be provided in either configurable or session.
+Normalizes user_id (replaces . with _) for memory namespace compatibility.
+"""
+logger.trace(
+"Converting request to context",
+has_context=request.context is not None,
+has_custom_inputs=request.custom_inputs is not None,
+)

 configurable: dict[str, Any] = {}
+session: dict[str, Any] = {}

 # Process context values first (lower priority)
-#
+# These come from Databricks ResponsesAgent ChatContext
 chat_context: Optional[ChatContext] = request.context
 if chat_context is not None:
 conversation_id: Optional[str] = chat_context.conversation_id
@@ -571,27 +1569,189 @@ class LanggraphResponsesAgent(ResponsesAgent):

 if conversation_id is not None:
 configurable["conversation_id"] = conversation_id
-configurable["thread_id"] = conversation_id

 if user_id is not None:
 configurable["user_id"] = user_id

 # Process custom_inputs after context so they can override context values (higher priority)
 if request.custom_inputs:
+# Extract configurable section (user config)
 if "configurable" in request.custom_inputs:
-configurable.update(request.custom_inputs
+configurable.update(request.custom_inputs["configurable"])
+
+# Extract session section
+if "session" in request.custom_inputs:
+session_input = request.custom_inputs["session"]
+if isinstance(session_input, dict):
+session = session_input
+
+# Handle legacy flat structure (backwards compatibility)
+# If user passes keys directly in custom_inputs, merge them
+for key in list(request.custom_inputs.keys()):
+if key not in ("configurable", "session"):
+configurable[key] = request.custom_inputs[key]
+
+# Extract known Context fields
+user_id_value: str | None = configurable.pop("user_id", None)
+if user_id_value:
+# Normalize user_id for memory namespace (replace . with _)
+user_id_value = user_id_value.replace(".", "_")
+
+# Accept thread_id from configurable, or conversation_id from configurable or session
+# Priority: configurable.conversation_id > session.conversation_id > configurable.thread_id
+thread_id: str | None = configurable.pop("thread_id", None)
+conversation_id: str | None = configurable.pop("conversation_id", None)
+
+# Also check session for conversation_id (output puts it there)
+if conversation_id is None and "conversation_id" in session:
+conversation_id = session.get("conversation_id")
+
+# conversation_id takes precedence if provided
+if conversation_id:
+thread_id = conversation_id
+if not thread_id:
+# Generate new thread_id if neither provided
+thread_id = str(uuid.uuid4())
+
+# All remaining configurable values become top-level context attributes
+logger.trace(
+"Creating context",
+user_id=user_id_value,
+thread_id=thread_id,
+extra_keys=list(configurable.keys()) if configurable else [],
+)
+
+return Context(
+user_id=user_id_value,
+thread_id=thread_id,
+**configurable, # Pass remaining configurable values as context attributes
+)
+
+def _extract_session_from_request(
+self, request: ResponsesAgentRequest
+) -> dict[str, Any]:
+"""Extract session state from request for passing to graph.
+
+Handles:
+- New structure: custom_inputs.session.genie
+- Legacy structure: custom_inputs.genie_conversation_ids
+"""
+session: dict[str, Any] = {}
+
+if not request.custom_inputs:
+return session
+
+# New structure: session.genie
+if "session" in request.custom_inputs:
+session_input = request.custom_inputs["session"]
+if isinstance(session_input, dict) and "genie" in session_input:
+genie_state = session_input["genie"]
+# Extract conversation IDs from the new structure
+if isinstance(genie_state, dict) and "spaces" in genie_state:
+genie_conversation_ids = {}
+for space_id, space_state in genie_state["spaces"].items():
+if (
+isinstance(space_state, dict)
+and "conversation_id" in space_state
+):
+genie_conversation_ids[space_id] = space_state[
+"conversation_id"
+]
+if genie_conversation_ids:
+session["genie_conversation_ids"] = genie_conversation_ids
+
+# Legacy structure: genie_conversation_ids at top level
+if "genie_conversation_ids" in request.custom_inputs:
+session["genie_conversation_ids"] = request.custom_inputs[
+"genie_conversation_ids"
+]
+
+# Also check inside configurable for legacy support
+if "configurable" in request.custom_inputs:
+cfg = request.custom_inputs["configurable"]
+if isinstance(cfg, dict) and "genie_conversation_ids" in cfg:
+session["genie_conversation_ids"] = cfg["genie_conversation_ids"]
+
+return session
+
+def _build_custom_outputs(
+self,
+context: Context,
+thread_id: Optional[str],
+loop: Any, # asyncio.AbstractEventLoop
+) -> dict[str, Any]:
+"""Build custom_outputs that can be copy-pasted as next request's custom_inputs.
+
+Output structure:
+configurable:
+thread_id: "abc-123" # Thread identifier (conversation_id is alias)
+user_id: "nate.fleming" # De-normalized (no underscore replacement)
+store_num: "87887" # Any custom fields
+session:
+conversation_id: "abc-123" # Alias of thread_id for Databricks compatibility
+genie:
+spaces:
+space_123: {conversation_id: "conv_456", cache_hit: false, ...}
+"""
+return loop.run_until_complete(
+self._build_custom_outputs_async(context=context, thread_id=thread_id)
+)

-
+async def _build_custom_outputs_async(
+self,
+context: Context,
+thread_id: Optional[str],
+) -> dict[str, Any]:
+"""Async version of _build_custom_outputs."""
+# Build configurable section
+# Note: only thread_id is included here (conversation_id goes in session)
+configurable: dict[str, Any] = {}
+
+if thread_id:
+configurable["thread_id"] = thread_id

-
-
+# Include user_id (keep normalized form for consistency)
+if context.user_id:
+configurable["user_id"] = context.user_id

-
-
+# Include all extra fields from context (beyond user_id and thread_id)
+context_dict = context.model_dump()
+for key, value in context_dict.items():
+if key not in {"user_id", "thread_id"} and value is not None:
+configurable[key] = value

-
+# Build session section with accumulated state
+# Note: conversation_id is included here as an alias of thread_id
+session: dict[str, Any] = {}

-
+if thread_id:
+# Include conversation_id in session (alias of thread_id)
+session["conversation_id"] = thread_id
+
+state_snapshot: Optional[StateSnapshot] = await get_state_snapshot_async(
+self.graph, thread_id
+)
+genie_conversation_ids: dict[str, str] = (
+get_genie_conversation_ids_from_state(state_snapshot)
+)
+if genie_conversation_ids:
+# Convert flat genie_conversation_ids to new session.genie.spaces structure
+session["genie"] = {
+"spaces": {
+space_id: {
+"conversation_id": conv_id,
+# Note: cache_hit, follow_up_questions populated by Genie tool
+"cache_hit": False,
+"follow_up_questions": [],
+}
+for space_id, conv_id in genie_conversation_ids.items()
+}
+}
+
+return {
+"configurable": configurable,
+"session": session,
+}


 def create_agent(graph: CompiledStateGraph) -> ChatAgent:
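Taken together, _convert_request_to_context, _extract_session_from_request, and _build_custom_outputs define a round-trip contract: the custom_outputs returned by one call are shaped so they can be sent back verbatim as custom_inputs on the next call to resume the same thread and the same Genie conversations. A sketch of that payload shape, using placeholder values drawn from the docstring above (store_num, space_123, conv_456 are illustrative, not real identifiers):

# Turn 1: caller supplies only configuration; no session yet.
first_custom_inputs = {
    "configurable": {
        "user_id": "nate.fleming",  # normalized to "nate_fleming" internally
        "store_num": "87887",       # extra keys become Context attributes
    },
}

# The response's custom_outputs come back shaped roughly like this:
returned_custom_outputs = {
    "configurable": {
        "thread_id": "abc-123",
        "user_id": "nate_fleming",  # normalized form passed through from Context
        "store_num": "87887",
    },
    "session": {
        "conversation_id": "abc-123",  # alias of thread_id
        "genie": {
            "spaces": {
                "space_123": {
                    "conversation_id": "conv_456",
                    "cache_hit": False,
                    "follow_up_questions": [],
                },
            },
        },
    },
}

# Turn 2: echo the outputs back to continue the same thread and reuse the
# existing Genie conversations instead of starting new ones.
second_custom_inputs = returned_custom_outputs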
@@ -610,7 +1770,9 @@ def create_agent(graph: CompiledStateGraph) -> ChatAgent:
 return LanggraphChatModel(graph)


-def create_responses_agent(
+def create_responses_agent(
+graph: CompiledStateGraph,
+) -> ResponsesAgent:
 """
 Create an MLflow-compatible ResponsesAgent from a LangGraph state machine.

@@ -645,6 +1807,29 @@ def _process_langchain_messages(
 return loop.run_until_complete(_async_invoke())


+def _configurable_to_context(configurable: dict[str, Any]) -> Context:
+"""Convert a configurable dict to a Context object."""
+configurable = configurable.copy()
+
+# Extract known Context fields
+user_id: str | None = configurable.pop("user_id", None)
+if user_id:
+user_id = user_id.replace(".", "_")
+
+thread_id: str | None = configurable.pop("thread_id", None)
+if "conversation_id" in configurable and not thread_id:
+thread_id = configurable.pop("conversation_id")
+if not thread_id:
+thread_id = str(uuid.uuid4())
+
+# All remaining values become top-level context attributes
+return Context(
+user_id=user_id,
+thread_id=thread_id,
+**configurable, # Extra fields become top-level attributes
+)
+
+
 def _process_langchain_messages_stream(
 app: LanggraphChatModel | CompiledStateGraph,
 messages: Sequence[BaseMessage],
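A quick sketch of the mapping _configurable_to_context performs: pop user_id and thread_id, fall back to conversation_id, normalize dots in user_id, and pass every remaining key through as a context attribute. The snippet below is self-contained, assuming Pydantic v2; the Context class here is a stand-in for the package's own Context (assumed to allow extra fields), not the real type:

import uuid
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict


class Context(BaseModel):  # stand-in for dao_ai's Context; extra fields allowed
    model_config = ConfigDict(extra="allow")
    user_id: Optional[str] = None
    thread_id: Optional[str] = None


def configurable_to_context(configurable: dict[str, Any]) -> Context:
    configurable = configurable.copy()
    user_id = configurable.pop("user_id", None)
    if user_id:
        user_id = user_id.replace(".", "_")  # memory-namespace-safe form
    thread_id = configurable.pop("thread_id", None)
    if "conversation_id" in configurable and not thread_id:
        thread_id = configurable.pop("conversation_id")  # conversation_id as alias
    if not thread_id:
        thread_id = str(uuid.uuid4())  # fresh thread when none supplied
    return Context(user_id=user_id, thread_id=thread_id, **configurable)


ctx = configurable_to_context(
    {"user_id": "nate.fleming", "conversation_id": "abc-123", "store_num": "87887"}
)
print(ctx.user_id, ctx.thread_id, ctx.store_num)  # -> nate_fleming abc-123 87887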
@@ -656,10 +1841,14 @@ def _process_langchain_messages_stream(
 if isinstance(app, LanggraphChatModel):
 app = app.graph

-logger.
+logger.trace(
+"Processing messages for stream",
+messages_count=len(messages),
+has_custom_inputs=custom_inputs is not None,
+)

-
-context: Context =
+configurable = (custom_inputs or {}).get("configurable", custom_inputs or {})
+context: Context = _configurable_to_context(configurable)

 # Use async astream internally for parallel execution
 async def _async_stream():
@@ -674,7 +1863,10 @@ def _process_langchain_messages_stream(
 stream_mode: str
 stream_messages: Sequence[BaseMessage]
 logger.trace(
-
+"Stream batch received",
+nodes=nodes,
+stream_mode=stream_mode,
+messages_count=len(stream_messages),
 )
 for message in stream_messages:
 if (