dao-ai 0.0.36__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/cli.py +195 -30
- dao_ai/config.py +770 -244
- dao_ai/genie/__init__.py +1 -22
- dao_ai/genie/cache/__init__.py +1 -2
- dao_ai/genie/cache/base.py +20 -70
- dao_ai/genie/cache/core.py +75 -0
- dao_ai/genie/cache/lru.py +44 -21
- dao_ai/genie/cache/semantic.py +390 -109
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +22 -190
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +23 -5
- dao_ai/memory/databricks.py +389 -0
- dao_ai/memory/postgres.py +2 -2
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +778 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +61 -0
- dao_ai/middleware/guardrails.py +415 -0
- dao_ai/middleware/human_in_the_loop.py +228 -0
- dao_ai/middleware/message_validation.py +554 -0
- dao_ai/middleware/summarization.py +192 -0
- dao_ai/models.py +1177 -108
- dao_ai/nodes.py +118 -161
- dao_ai/optimization.py +664 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +287 -0
- dao_ai/orchestration/supervisor.py +264 -0
- dao_ai/orchestration/swarm.py +226 -0
- dao_ai/prompts.py +126 -29
- dao_ai/providers/databricks.py +126 -381
- dao_ai/state.py +139 -21
- dao_ai/tools/__init__.py +8 -5
- dao_ai/tools/core.py +57 -4
- dao_ai/tools/email.py +280 -0
- dao_ai/tools/genie.py +47 -24
- dao_ai/tools/mcp.py +4 -3
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +4 -12
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +1 -1
- dao_ai/tools/unity_catalog.py +8 -6
- dao_ai/tools/vector_search.py +16 -9
- dao_ai/utils.py +72 -8
- dao_ai-0.1.0.dist-info/METADATA +1878 -0
- dao_ai-0.1.0.dist-info/RECORD +62 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/genie/__init__.py +0 -236
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.36.dist-info/METADATA +0 -951
- dao_ai-0.0.36.dist-info/RECORD +0 -47
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.0.dist-info}/licenses/LICENSE +0 -0
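One of the largest changes in 0.1.0 is human-in-the-loop (HITL) support in `dao_ai/models.py`, shown in the diff below. As a rough, unofficial sketch of the resume flow that the new `predict()` docstring describes — the `custom_inputs` keys (`configurable`, `decisions`) are taken from that docstring, while the serving-endpoint name and the MLflow deployments call are illustrative assumptions, not part of this package:

```python
# Hypothetical client-side sketch only; endpoint name is made up.
from mlflow.deployments import get_deploy_client

client = get_deploy_client("databricks")
response = client.predict(
    endpoint="my-dao-ai-agent",  # assumed serving endpoint for the deployed agent
    inputs={
        "input": [{"role": "user", "content": "Yes, go ahead and send it."}],
        "custom_inputs": {
            "configurable": {"thread_id": "abc-123", "user_id": "nate.fleming"},
            # Structured decisions resume an interrupted (HITL) graph directly;
            # alternatively, the free-text message above is parsed by an LLM.
            "decisions": [{"type": "approve"}],
        },
    },
)
print(response.get("custom_outputs", {}))
```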
dao_ai/models.py
CHANGED
@@ -1,11 +1,34 @@
 import uuid
 from os import PathLike
 from pathlib import Path
-from typing import Any, Generator, Optional, Sequence, Union
-
-from
+from typing import TYPE_CHECKING, Any, Generator, Literal, Optional, Sequence, Union
+
+from databricks_langchain import ChatDatabricks
+
+if TYPE_CHECKING:
+    pass
+
+# Import official LangChain HITL TypedDict definitions
+# Reference: https://docs.langchain.com/oss/python/langchain/human-in-the-loop
+from langchain.agents.middleware.human_in_the_loop import (
+    ActionRequest,
+    Decision,
+    EditDecision,
+    HITLRequest,
+    RejectDecision,
+    ReviewConfig,
+)
+from langchain_community.adapters.openai import convert_openai_messages
+from langchain_core.language_models import LanguageModelLike
+from langchain_core.messages import (
+    AIMessage,
+    AIMessageChunk,
+    BaseMessage,
+    HumanMessage,
+    SystemMessage,
+)
 from langgraph.graph.state import CompiledStateGraph
-from langgraph.types import StateSnapshot
+from langgraph.types import Interrupt, StateSnapshot
 from loguru import logger
 from mlflow import MlflowClient
 from mlflow.pyfunc import ChatAgent, ChatModel, ResponsesAgent
@@ -28,11 +51,13 @@ from mlflow.types.responses_helpers import (
     Message,
     ResponseInputTextParam,
 )
+from pydantic import BaseModel, Field, create_model
 
 from dao_ai.messages import (
     has_langchain_messages,
     has_mlflow_messages,
     has_mlflow_responses_messages,
+    last_human_message,
 )
 from dao_ai.state import Context
 
@@ -54,12 +79,37 @@ def get_latest_model_version(model_name: str) -> int:
     mlflow_client: MlflowClient = MlflowClient()
     latest_version: int = 1
     for mv in mlflow_client.search_model_versions(f"name='{model_name}'"):
-        version_int = int(mv.version)
+        version_int: int = int(mv.version)
         if version_int > latest_version:
            latest_version = version_int
     return latest_version
 
 
+def is_interrupted(snapshot: StateSnapshot) -> bool:
+    """
+    Check if the graph state is currently interrupted (paused for human-in-the-loop).
+
+    Based on LangChain documentation:
+    - StateSnapshot has an `interrupts` attribute which is a tuple
+    - When interrupted, the tuple contains Interrupt objects
+    - When not interrupted, it's an empty tuple ()
+
+    Args:
+        snapshot: The StateSnapshot to check
+
+    Returns:
+        True if the graph is interrupted (has pending HITL actions), False otherwise
+
+    Example:
+        >>> snapshot = await graph.aget_state(config)
+        >>> if is_interrupted(snapshot):
+        ...     print("Graph is waiting for human input")
+    """
+    # Check if snapshot has any interrupts
+    # According to LangChain docs, interrupts is a tuple that's empty () when no interrupts
+    return bool(snapshot.interrupts)
+
+
 async def get_state_snapshot_async(
     graph: CompiledStateGraph, thread_id: str
 ) -> Optional[StateSnapshot]:
@@ -167,6 +217,111 @@ def get_genie_conversation_ids_from_state(
     return {}
 
 
+def _extract_interrupt_value(interrupt: Interrupt) -> HITLRequest:
+    """
+    Extract the HITL request from a LangGraph Interrupt object.
+
+    Following LangChain patterns, the Interrupt object has a .value attribute
+    containing the HITLRequest structure with action_requests and review_configs.
+
+    Args:
+        interrupt: Interrupt object from LangGraph with .value and .id attributes
+
+    Returns:
+        HITLRequest with action_requests and review_configs
+    """
+    # Interrupt.value is typed as Any, but for HITL it should be a HITLRequest dict
+    if isinstance(interrupt.value, dict):
+        # Return as HITLRequest TypedDict
+        return interrupt.value  # type: ignore[return-value]
+
+    # Fallback: return empty structure if value is not a dict
+    return {"action_requests": [], "review_configs": []}
+
+
+def _format_action_requests_message(interrupt_data: list[HITLRequest]) -> str:
+    """
+    Format action requests from interrupts into a simple, user-friendly message.
+
+    Since we now use LLM-based parsing, users can respond in natural language.
+    This function just shows WHAT actions are pending, not HOW to respond.
+
+    Args:
+        interrupt_data: List of HITLRequest structures containing action_requests and review_configs
+
+    Returns:
+        Simple formatted message describing the pending actions
+    """
+    if not interrupt_data:
+        return ""
+
+    # Collect all action requests and review configs from all interrupts
+    all_actions: list[ActionRequest] = []
+    review_configs_map: dict[str, ReviewConfig] = {}
+
+    for hitl_request in interrupt_data:
+        all_actions.extend(hitl_request.get("action_requests", []))
+        for review_config in hitl_request.get("review_configs", []):
+            action_name = review_config.get("action_name", "")
+            if action_name:
+                review_configs_map[action_name] = review_config
+
+    if not all_actions:
+        return ""
+
+    # Build simple, clean message
+    lines = ["⏸️ **Action Approval Required**", ""]
+    lines.append(
+        f"The assistant wants to perform {len(all_actions)} action(s) that require your approval:"
+    )
+    lines.append("")
+
+    for i, action in enumerate(all_actions, 1):
+        tool_name = action.get("name", "unknown")
+        args = action.get("args", {})
+        description = action.get("description")
+
+        lines.append(f"**{i}. {tool_name}**")
+
+        # Show review prompt/description if available
+        if description:
+            lines.append(f" • **Review:** {description}")
+
+        if args:
+            # Format args nicely, truncating long values
+            for key, value in args.items():
+                value_str = str(value)
+                if len(value_str) > 100:
+                    value_str = value_str[:100] + "..."
+                lines.append(f" • {key}: `{value_str}`")
+        else:
+            lines.append(" • (no arguments)")
+
+        # Show allowed decisions
+        review_config = review_configs_map.get(tool_name)
+        if review_config:
+            allowed_decisions = review_config.get("allowed_decisions", [])
+            if allowed_decisions:
+                decisions_str = ", ".join(allowed_decisions)
+                lines.append(f" • **Options:** {decisions_str}")
+
+        lines.append("")
+
+    lines.append("---")
+    lines.append("")
+    lines.append(
+        "**You can respond in natural language** (e.g., 'approve both', 'reject the first one', "
+        "'change the email to new@example.com')"
+    )
+    lines.append("")
+    lines.append(
+        "Or provide structured decisions in `custom_inputs` with key `decisions`: "
+        '`[{"type": "approve"}, {"type": "reject", "message": "reason"}]`'
+    )
+
+    return "\n".join(lines)
+
+
 class LanggraphChatModel(ChatModel):
     """
     ChatModel that delegates requests to a LangGraph CompiledStateGraph.
@@ -216,22 +371,36 @@ class LanggraphChatModel(ChatModel):
 
         configurable: dict[str, Any] = {}
         if "configurable" in input_data:
-            configurable
+            configurable = input_data.pop("configurable")
         if "custom_inputs" in input_data:
             custom_inputs: dict[str, Any] = input_data.pop("custom_inputs")
             if "configurable" in custom_inputs:
-                configurable
-
-
-
-
-
-
-
-
-
-
-
+                configurable = custom_inputs.pop("configurable")
+
+        # Extract known Context fields
+        user_id: str | None = configurable.pop("user_id", None)
+        if user_id:
+            user_id = user_id.replace(".", "_")
+
+        # Accept either thread_id or conversation_id (interchangeable)
+        # conversation_id takes precedence (Databricks vocabulary)
+        thread_id: str | None = configurable.pop("thread_id", None)
+        conversation_id: str | None = configurable.pop("conversation_id", None)
+
+        # conversation_id takes precedence if both provided
+        if conversation_id:
+            thread_id = conversation_id
+        if not thread_id:
+            thread_id = str(uuid.uuid4())
+
+        # All remaining configurable values go into custom dict
+        custom: dict[str, Any] = configurable
+
+        context: Context = Context(
+            user_id=user_id,
+            thread_id=thread_id,
+            custom=custom,
+        )
         return context
 
     def predict_stream(
@@ -307,6 +476,322 @@ class LanggraphChatModel(ChatModel):
         return [m.to_dict() for m in messages]
 
 
+def _create_decision_schema(interrupt_data: list[HITLRequest]) -> type[BaseModel]:
+    """
+    Dynamically create a Pydantic model for structured output based on interrupt actions.
+
+    This creates a schema that matches the expected decision format for the interrupted actions.
+    Each action gets a corresponding decision field that can be approve, edit, or reject.
+    Includes validation fields to ensure the response is complete and valid.
+
+    Args:
+        interrupt_data: List of HITL interrupt requests containing action_requests and review_configs
+
+    Returns:
+        A dynamically created Pydantic BaseModel class for structured output
+
+    Example:
+        For two actions (send_email, execute_sql), creates a model like:
+        class Decisions(BaseModel):
+            is_valid: bool
+            validation_message: Optional[str]
+            decision_1: Literal["approve", "edit", "reject"]
+            decision_1_message: Optional[str]  # For reject
+            decision_1_edited_args: Optional[dict]  # For edit
+            decision_2: Literal["approve", "edit", "reject"]
+            ...
+    """
+    # Collect all actions
+    all_actions: list[ActionRequest] = []
+    review_configs_map: dict[str, ReviewConfig] = {}
+
+    for hitl_request in interrupt_data:
+        all_actions.extend(hitl_request.get("action_requests", []))
+        review_config: ReviewConfig
+        for review_config in hitl_request.get("review_configs", []):
+            action_name: str = review_config.get("action_name", "")
+            if action_name:
+                review_configs_map[action_name] = review_config
+
+    # Build fields for the dynamic model
+    # Start with validation fields
+    fields: dict[str, Any] = {
+        "is_valid": (
+            bool,
+            Field(
+                description="Whether the user's response provides valid decisions for ALL actions. "
+                "Set to False if the user's message is unclear, ambiguous, or doesn't provide decisions for all actions."
+            ),
+        ),
+        "validation_message": (
+            Optional[str],
+            Field(
+                None,
+                description="If is_valid is False, explain what is missing or unclear. "
+                "Be specific about which action(s) need clarification.",
+            ),
+        ),
+    }
+
+    i: int
+    action: ActionRequest
+    for i, action in enumerate(all_actions, 1):
+        tool_name: str = action.get("name", "unknown")
+        review_config: Optional[ReviewConfig] = review_configs_map.get(tool_name)
+        allowed_decisions: list[str] = (
+            review_config.get("allowed_decisions", ["approve", "reject"])
+            if review_config
+            else ["approve", "reject"]
+        )
+
+        # Create a Literal type for allowed decisions
+        decision_literal: type = Literal[tuple(allowed_decisions)]  # type: ignore
+
+        # Add decision field
+        fields[f"decision_{i}"] = (
+            decision_literal,
+            Field(
+                description=f"Decision for action {i} ({tool_name}): {', '.join(allowed_decisions)}"
+            ),
+        )
+
+        # Add optional message field for reject
+        if "reject" in allowed_decisions:
+            fields[f"decision_{i}_message"] = (
+                Optional[str],
+                Field(
+                    None,
+                    description=f"Optional message if rejecting action {i}",
+                ),
+            )
+
+        # Add optional edited_args field for edit
+        if "edit" in allowed_decisions:
+            fields[f"decision_{i}_edited_args"] = (
+                Optional[dict[str, Any]],
+                Field(
+                    None,
+                    description=f"Modified arguments if editing action {i}. Only provide fields that need to change.",
+                ),
+            )
+
+    # Create the dynamic model
+    DecisionsModel = create_model(
+        "InterruptDecisions",
+        __doc__="Decisions for each interrupted action, in order.",
+        **fields,
+    )
+
+    return DecisionsModel
+
+
+def _convert_schema_to_decisions(
+    parsed_output: BaseModel,
+    interrupt_data: list[HITLRequest],
+) -> list[Decision]:
+    """
+    Convert the parsed structured output into LangChain Decision objects.
+
+    Args:
+        parsed_output: The Pydantic model instance from structured output
+        interrupt_data: Original interrupt data for context
+
+    Returns:
+        List of Decision dictionaries compatible with Command(resume={"decisions": ...})
+    """
+    # Collect all actions to know how many decisions we need
+    all_actions: list[ActionRequest] = []
+    hitl_request: HITLRequest
+    for hitl_request in interrupt_data:
+        all_actions.extend(hitl_request.get("action_requests", []))
+
+    decisions: list[Decision] = []
+
+    i: int
+    for i in range(1, len(all_actions) + 1):
+        decision_type: str = getattr(parsed_output, f"decision_{i}")
+
+        if decision_type == "approve":
+            decisions.append({"type": "approve"})  # type: ignore
+        elif decision_type == "reject":
+            message: Optional[str] = getattr(
+                parsed_output, f"decision_{i}_message", None
+            )
+            reject_decision: RejectDecision = {"type": "reject"}
+            if message:
+                reject_decision["message"] = message
+            decisions.append(reject_decision)  # type: ignore
+        elif decision_type == "edit":
+            edited_args: Optional[dict[str, Any]] = getattr(
+                parsed_output, f"decision_{i}_edited_args", None
+            )
+            action: ActionRequest = all_actions[i - 1]
+            tool_name: str = action.get("name", "")
+            original_args: dict[str, Any] = action.get("args", {})
+
+            # Merge original args with edited args
+            final_args: dict[str, Any] = {**original_args, **(edited_args or {})}
+
+            edit_decision: EditDecision = {
+                "type": "edit",
+                "edited_action": {
+                    "name": tool_name,
+                    "args": final_args,
+                },
+            }
+            decisions.append(edit_decision)  # type: ignore
+
+    return decisions
+
+
+def handle_interrupt_response(
+    snapshot: StateSnapshot,
+    messages: list[BaseMessage],
+    model: Optional[LanguageModelLike] = None,
+) -> dict[str, Any]:
+    """
+    Parse user's natural language response to interrupts using LLM with structured output.
+
+    This function uses an LLM to understand the user's intent and extract structured decisions
+    for each pending action. The schema is dynamically created based on the pending actions.
+    Includes validation to ensure the response is complete and valid.
+
+    Args:
+        snapshot: The current state snapshot containing interrupts
+        messages: List of messages, from which the last human message will be extracted
+        model: Optional LLM to use for parsing. Defaults to Llama 3.1 70B
+
+    Returns:
+        Dictionary with:
+        - "is_valid": bool indicating if the response is valid
+        - "validation_message": Optional message if invalid, explaining what's missing
+        - "decisions": list of Decision objects (empty if invalid)
+
+    Example:
+        Valid: {"is_valid": True, "validation_message": None, "decisions": [{"type": "approve"}]}
+        Invalid: {"is_valid": False, "validation_message": "Please specify...", "decisions": []}
+    """
+    # Extract the last human message
+    user_message_obj: Optional[HumanMessage] = last_human_message(messages)
+
+    if not user_message_obj:
+        logger.warning("handle_interrupt_response called but no human message found")
+        return {
+            "is_valid": False,
+            "validation_message": "No user message found. Please provide a response to the pending action(s).",
+            "decisions": [],
+        }
+
+    user_message: str = str(user_message_obj.content)
+    logger.info(f"HITL: Parsing user message with LLM: {user_message[:100]}")
+
+    if not model:
+        model = ChatDatabricks(
+            endpoint="databricks-claude-sonnet-4",
+            temperature=0,
+        )
+
+    # Extract interrupt data
+    if not snapshot.interrupts:
+        logger.warning("handle_interrupt_response called but no interrupts in snapshot")
+        return {"decisions": []}
+
+    interrupt_data: list[HITLRequest] = [
+        _extract_interrupt_value(interrupt) for interrupt in snapshot.interrupts
+    ]
+
+    # Collect all actions for context
+    all_actions: list[ActionRequest] = []
+    hitl_request: HITLRequest
+    for hitl_request in interrupt_data:
+        all_actions.extend(hitl_request.get("action_requests", []))
+
+    if not all_actions:
+        logger.warning("handle_interrupt_response called but no actions in interrupts")
+        return {"decisions": []}
+
+    # Create dynamic schema
+    DecisionsModel: type[BaseModel] = _create_decision_schema(interrupt_data)
+
+    # Create structured LLM
+    structured_llm: LanguageModelLike = model.with_structured_output(DecisionsModel)
+
+    # Format action context for the LLM
+    action_descriptions: list[str] = []
+    i: int
+    action: ActionRequest
+    for i, action in enumerate(all_actions, 1):
+        tool_name: str = action.get("name", "unknown")
+        args: dict[str, Any] = action.get("args", {})
+        args_str: str = (
+            ", ".join(f"{k}={v}" for k, v in args.items()) if args else "(no args)"
+        )
+        action_descriptions.append(f"Action {i}: {tool_name}({args_str})")
+
+    system_prompt: str = f"""You are parsing a user's response to interrupted agent actions.
+
+The following actions are pending approval:
+{chr(10).join(action_descriptions)}
+
+Your task is to extract the user's decision for EACH action in order. The user may:
+- Approve: Accept the action as-is
+- Reject: Cancel the action (optionally with a reason/message)
+- Edit: Modify the arguments before executing
+
+VALIDATION:
+- Set is_valid=True only if you can confidently extract decisions for ALL actions
+- Set is_valid=False if the user's message is:
+  * Unclear or ambiguous
+  * Missing decisions for some actions
+  * Asking a question instead of providing decisions
+  * Not addressing the actions at all
+- If is_valid=False, provide a clear validation_message explaining what is needed
+
+FLEXIBILITY:
+- Be flexible in parsing informal language like "yes", "no", "ok", "change X to Y"
+- If the user doesn't explicitly mention an action, assume they want to approve it
+- Only mark as invalid if the message is genuinely unclear or incomplete"""
+
+    try:
+        # Invoke LLM with structured output
+        parsed: BaseModel = structured_llm.invoke(
+            [
+                SystemMessage(content=system_prompt),
+                HumanMessage(content=user_message),
+            ]
+        )
+
+        # Check validation first
+        is_valid: bool = getattr(parsed, "is_valid", True)
+        validation_message: Optional[str] = getattr(parsed, "validation_message", None)
+
+        if not is_valid:
+            logger.warning(
+                f"HITL: Invalid user response. Reason: {validation_message or 'Unknown'}"
+            )
+            return {
+                "is_valid": False,
+                "validation_message": validation_message
+                or "Your response was unclear. Please provide a clear decision for each action.",
+                "decisions": [],
+            }
+
+        # Convert to Decision format
+        decisions: list[Decision] = _convert_schema_to_decisions(parsed, interrupt_data)
+
+        logger.info(f"Parsed {len(decisions)} decisions from user message")
+        return {"is_valid": True, "validation_message": None, "decisions": decisions}
+
+    except Exception as e:
+        logger.error(f"Failed to parse interrupt response: {e}")
+        # Return invalid response on parsing failure
+        return {
+            "is_valid": False,
+            "validation_message": f"Failed to parse your response: {str(e)}. Please provide a clear decision for each action.",
+            "decisions": [],
+        }
+
+
 class LanggraphResponsesAgent(ResponsesAgent):
     """
     ResponsesAgent that delegates requests to a LangGraph CompiledStateGraph.
@@ -315,37 +800,151 @@ class LanggraphResponsesAgent(ResponsesAgent):
     support for streaming, tool calling, and async execution.
     """
 
-    def __init__(
+    def __init__(
+        self,
+        graph: CompiledStateGraph,
+    ) -> None:
         self.graph = graph
 
     def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
         """
         Process a ResponsesAgentRequest and return a ResponsesAgentResponse.
+
+        Input structure (custom_inputs):
+            configurable:
+                thread_id: "abc-123"        # Or conversation_id (aliases, conversation_id takes precedence)
+                user_id: "nate.fleming"
+                store_num: "87887"
+            session:                        # Paste from previous output
+                conversation_id: "abc-123"  # Alias of thread_id
+                genie:
+                    spaces:
+                        space_123: {conversation_id: "conv_456", ...}
+            decisions:                      # For resuming interrupted graphs (HITL)
+                - type: "approve"
+                - type: "reject"
+                  message: "Not authorized"
+
+        Output structure (custom_outputs):
+            configurable:
+                thread_id: "abc-123"        # Only thread_id in configurable
+                user_id: "nate.fleming"
+                store_num: "87887"
+            session:
+                conversation_id: "abc-123"  # conversation_id in session
+                genie:
+                    spaces:
+                        space_123: {conversation_id: "conv_456", ...}
+            pending_actions:                # If HITL interrupt occurred
+                - name: "send_email"
+                  arguments: {...}
+                  description: "..."
         """
         logger.debug(f"ResponsesAgent request: {request}")
 
         # Convert ResponsesAgent input to LangChain messages
-        messages = self._convert_request_to_langchain_messages(
+        messages: list[dict[str, Any]] = self._convert_request_to_langchain_messages(
+            request
+        )
 
-        # Prepare context
+        # Prepare context (conversation_id -> thread_id mapping happens here)
         context: Context = self._convert_request_to_context(request)
         custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}
 
-        #
-
-        if request.custom_inputs and "genie_conversation_ids" in request.custom_inputs:
-            graph_input["genie_conversation_ids"] = request.custom_inputs[
-                "genie_conversation_ids"
-            ]
-            logger.debug(
-                f"Including genie_conversation_ids in graph input: {graph_input['genie_conversation_ids']}"
-            )
+        # Extract session state from request
+        session_input: dict[str, Any] = self._extract_session_from_request(request)
 
         # Use async ainvoke internally for parallel execution
        import asyncio
 
+        from langgraph.types import Command
+
         async def _async_invoke():
             try:
+                # Check if this is a resume request (HITL)
+                # Two ways to resume:
+                # 1. Explicit decisions in custom_inputs (structured)
+                # 2. Natural language message when graph is interrupted (LLM-parsed)
+
+                if request.custom_inputs and "decisions" in request.custom_inputs:
+                    # Explicit structured decisions
+                    decisions: list[Decision] = request.custom_inputs["decisions"]
+                    logger.info(
+                        f"HITL: Resuming with {len(decisions)} explicit decision(s)"
+                    )
+
+                    # Resume interrupted graph with decisions
+                    return await self.graph.ainvoke(
+                        Command(resume={"decisions": decisions}),
+                        context=context,
+                        config=custom_inputs,
+                    )
+
+                # Check if graph is currently interrupted (only if checkpointer is configured)
+                # aget_state requires a checkpointer
+                if self.graph.checkpointer:
+                    snapshot: StateSnapshot = await self.graph.aget_state(
+                        config=custom_inputs
+                    )
+                    if is_interrupted(snapshot):
+                        logger.info(
+                            "HITL: Graph is interrupted, checking for user response"
+                        )
+
+                        # Convert message dicts to BaseMessage objects
+                        message_objects: list[BaseMessage] = convert_openai_messages(
+                            messages
+                        )
+
+                        # Parse user's message with LLM to extract decisions
+                        parsed_result: dict[str, Any] = handle_interrupt_response(
+                            snapshot=snapshot,
+                            messages=message_objects,
+                            model=None,  # Uses default model
+                        )
+
+                        # Check if the response was valid
+                        if not parsed_result.get("is_valid", False):
+                            validation_message: str = parsed_result.get(
+                                "validation_message",
+                                "Your response was unclear. Please provide a clear decision for each action.",
+                            )
+                            logger.warning(
+                                f"HITL: Invalid response - {validation_message}"
+                            )
+
+                            # Return error message to user instead of resuming
+                            # Don't resume the graph - stay interrupted so user can try again
+                            return {
+                                "messages": [
+                                    AIMessage(
+                                        content=f"❌ **Invalid Response**\n\n{validation_message}"
+                                    )
+                                ]
+                            }
+
+                        decisions: list[Decision] = parsed_result.get("decisions", [])
+                        logger.info(
+                            f"HITL: LLM parsed {len(decisions)} valid decision(s) from user message"
+                        )
+
+                        # Resume interrupted graph with parsed decisions
+                        return await self.graph.ainvoke(
+                            Command(resume={"decisions": decisions}),
+                            context=context,
+                            config=custom_inputs,
+                        )
+
+                # Normal invocation - build the graph input state
+                graph_input: dict[str, Any] = {"messages": messages}
+                if "genie_conversation_ids" in session_input:
+                    graph_input["genie_conversation_ids"] = session_input[
+                        "genie_conversation_ids"
+                    ]
+                    logger.debug(
+                        f"Including genie_conversation_ids in graph input: {graph_input['genie_conversation_ids']}"
+                    )
+
                 return await self.graph.ainvoke(
                     graph_input, context=context, config=custom_inputs
                 )
@@ -356,7 +955,6 @@ class LanggraphResponsesAgent(ResponsesAgent):
         try:
             loop = asyncio.get_event_loop()
         except RuntimeError:
-            # Handle case where no event loop exists (common in some deployment scenarios)
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)
 
@@ -371,22 +969,81 @@ class LanggraphResponsesAgent(ResponsesAgent):
         # Convert response to ResponsesAgent format
         last_message: BaseMessage = response["messages"][-1]
 
-
-
+        # Build custom_outputs that can be copy-pasted as next request's custom_inputs
+        custom_outputs: dict[str, Any] = self._build_custom_outputs(
+            context=context,
+            thread_id=context.thread_id,
+            loop=loop,
         )
 
-        #
-
-
-
-
-
+        # Handle structured_response if present
+        if "structured_response" in response:
+            from dataclasses import asdict, is_dataclass
+
+            from pydantic import BaseModel
+
+            structured_response = response["structured_response"]
+            logger.debug(f"Processing structured_response: {type(structured_response)}")
+
+            # Serialize to dict for JSON compatibility using type hints
+            if isinstance(structured_response, BaseModel):
+                # Pydantic model
+                serialized: dict[str, Any] = structured_response.model_dump()
+            elif is_dataclass(structured_response):
+                # Dataclass
+                serialized = asdict(structured_response)
+            elif isinstance(structured_response, dict):
+                # Already a dict
+                serialized = structured_response
+            else:
+                # Unknown type, convert to dict if possible
+                serialized = (
+                    dict(structured_response)
+                    if hasattr(structured_response, "__dict__")
+                    else structured_response
+                )
+
+            # Place structured output in message content as JSON
+            import json
+
+            structured_text: str = json.dumps(serialized, indent=2)
+            output_item = self.create_text_output_item(
+                text=structured_text, id=f"msg_{uuid.uuid4().hex[:8]}"
             )
-
-
+            logger.debug("Placed structured_response in message content")
+        else:
+            # No structured response, use text content
+            output_item = self.create_text_output_item(
+                text=last_message.content, id=f"msg_{uuid.uuid4().hex[:8]}"
             )
-
-
+
+        # Include interrupt structure if HITL occurred (following LangChain pattern)
+        if "__interrupt__" in response:
+            interrupts: list[Interrupt] = response["__interrupt__"]
+            logger.info(f"HITL: {len(interrupts)} interrupt(s) detected")
+
+            # Extract HITLRequest structures from interrupts (deduplicate by ID)
+            seen_interrupt_ids: set[str] = set()
+            interrupt_data: list[HITLRequest] = []
+            interrupt: Interrupt
+            for interrupt in interrupts:
+                # Only process each unique interrupt once
+                if interrupt.id not in seen_interrupt_ids:
+                    seen_interrupt_ids.add(interrupt.id)
+                    interrupt_data.append(_extract_interrupt_value(interrupt))
+                    logger.debug(f"HITL: Added interrupt {interrupt.id} to response")
+
+            custom_outputs["interrupts"] = interrupt_data
+            logger.debug(
+                f"HITL: Included {len(interrupt_data)} interrupt(s) in response"
+            )
+
+            # Add user-facing message about the pending actions
+            action_message: str = _format_action_requests_message(interrupt_data)
+            if action_message:
+                output_item = self.create_text_output_item(
+                    text=action_message, id=f"msg_{uuid.uuid4().hex[:8]}"
+                )
 
         return ResponsesAgentResponse(
             output=[output_item], custom_outputs=custom_outputs
@@ -397,89 +1054,310 @@ class LanggraphResponsesAgent(ResponsesAgent):
     ) -> Generator[ResponsesAgentStreamEvent, None, None]:
         """
         Process a ResponsesAgentRequest and yield ResponsesAgentStreamEvent objects.
+
+        Uses same input/output structure as predict() for consistency.
+        Supports Human-in-the-Loop (HITL) interrupts.
         """
         logger.debug(f"ResponsesAgent stream request: {request}")
 
         # Convert ResponsesAgent input to LangChain messages
-        messages: list[
+        messages: list[dict[str, Any]] = self._convert_request_to_langchain_messages(
            request
        )
 
-        # Prepare context
+        # Prepare context (conversation_id -> thread_id mapping happens here)
         context: Context = self._convert_request_to_context(request)
         custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}
 
-        #
-
-        if request.custom_inputs and "genie_conversation_ids" in request.custom_inputs:
-            graph_input["genie_conversation_ids"] = request.custom_inputs[
-                "genie_conversation_ids"
-            ]
-            logger.debug(
-                f"Including genie_conversation_ids in graph input: {graph_input['genie_conversation_ids']}"
-            )
+        # Extract session state from request
+        session_input: dict[str, Any] = self._extract_session_from_request(request)
 
         # Use async astream internally for parallel execution
         import asyncio
 
+        from langgraph.types import Command
+
         async def _async_stream():
-            item_id = f"msg_{uuid.uuid4().hex[:8]}"
-            accumulated_content = ""
+            item_id: str = f"msg_{uuid.uuid4().hex[:8]}"
+            accumulated_content: str = ""
+            interrupt_data: list[HITLRequest] = []
+            seen_interrupt_ids: set[str] = set()  # Track processed interrupt IDs
+            structured_response: Any = None  # Track structured output from stream
 
             try:
-
-
+                # Check if this is a resume request (HITL)
+                # Two ways to resume:
+                # 1. Explicit decisions in custom_inputs (structured)
+                # 2. Natural language message when graph is interrupted (LLM-parsed)
+
+                if request.custom_inputs and "decisions" in request.custom_inputs:
+                    # Explicit structured decisions
+                    decisions: list[Decision] = request.custom_inputs["decisions"]
+                    logger.info(
+                        f"HITL: Resuming with {len(decisions)} explicit decision(s)"
+                    )
+                    stream_input: Command | dict[str, Any] = Command(
+                        resume={"decisions": decisions}
+                    )
+                elif self.graph.checkpointer:
+                    # Check if graph is currently interrupted (only if checkpointer is configured)
+                    # aget_state requires a checkpointer
+                    snapshot: StateSnapshot = await self.graph.aget_state(
+                        config=custom_inputs
+                    )
+                    if is_interrupted(snapshot):
+                        logger.info(
+                            "HITL: Graph is interrupted, checking for user response"
+                        )
+
+                        # Convert message dicts to BaseMessage objects
+                        message_objects: list[BaseMessage] = convert_openai_messages(
+                            messages
+                        )
+
+                        # Parse user's message with LLM to extract decisions
+                        parsed_result: dict[str, Any] = handle_interrupt_response(
+                            snapshot=snapshot,
+                            messages=message_objects,
+                            model=None,  # Uses default model
+                        )
+
+                        # Check if the response was valid
+                        if not parsed_result.get("is_valid", False):
+                            validation_message: str = parsed_result.get(
+                                "validation_message",
+                                "Your response was unclear. Please provide a clear decision for each action.",
+                            )
+                            logger.warning(
+                                f"HITL: Invalid response - {validation_message}"
+                            )
+
+                            # Build custom_outputs before returning
+                            custom_outputs: dict[
+                                str, Any
+                            ] = await self._build_custom_outputs_async(
+                                context=context,
+                                thread_id=context.thread_id,
+                            )
+
+                            # Yield error message to user - don't resume graph
+                            error_message: str = (
+                                f"❌ **Invalid Response**\n\n{validation_message}"
+                            )
+                            accumulated_content = error_message
+                            yield ResponsesAgentStreamEvent(
+                                type="response.output_item.done",
+                                item=self.create_text_output_item(
+                                    text=error_message, id=item_id
+                                ),
+                                custom_outputs=custom_outputs,
+                            )
+                            return  # Don't resume - stay interrupted
+
+                        decisions: list[Decision] = parsed_result.get("decisions", [])
+                        logger.info(
+                            f"HITL: LLM parsed {len(decisions)} valid decision(s) from user message"
+                        )
+
+                        # Resume interrupted graph with parsed decisions
+                        stream_input: Command | dict[str, Any] = Command(
+                            resume={"decisions": decisions}
+                        )
+                    else:
+                        # Graph not interrupted, use normal invocation
+                        graph_input: dict[str, Any] = {"messages": messages}
+                        if "genie_conversation_ids" in session_input:
+                            graph_input["genie_conversation_ids"] = session_input[
+                                "genie_conversation_ids"
+                            ]
+                        stream_input: Command | dict[str, Any] = graph_input
+                else:
+                    # No checkpointer, use normal invocation
+                    graph_input: dict[str, Any] = {"messages": messages}
+                    if "genie_conversation_ids" in session_input:
+                        graph_input["genie_conversation_ids"] = session_input[
+                            "genie_conversation_ids"
+                        ]
+                    stream_input: Command | dict[str, Any] = graph_input
+
+                # Stream the graph execution with both messages and updates modes to capture interrupts
+                async for nodes, stream_mode, data in self.graph.astream(
+                    stream_input,
                     context=context,
                     config=custom_inputs,
-                    stream_mode=["messages", "
+                    stream_mode=["messages", "updates"],
                     subgraphs=True,
                 ):
                     nodes: tuple[str, ...]
                     stream_mode: str
-                    messages_batch: Sequence[BaseMessage]
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    # Handle message streaming
+                    if stream_mode == "messages":
+                        messages_batch: Sequence[BaseMessage] = data
+                        message: BaseMessage
+                        for message in messages_batch:
+                            if (
+                                isinstance(
+                                    message,
+                                    (
+                                        AIMessageChunk,
+                                        AIMessage,
+                                    ),
+                                )
+                                and message.content
+                                and "summarization" not in nodes
+                            ):
+                                content: str = message.content
+                                accumulated_content += content
+
+                                # Yield streaming delta
+                                yield ResponsesAgentStreamEvent(
+                                    **self.create_text_delta(
+                                        delta=content, item_id=item_id
+                                    )
+                                )
+
+                    # Handle interrupts (HITL) and state updates
+                    elif stream_mode == "updates":
+                        updates: dict[str, Any] = data
+                        source: str
+                        update: Any
+                        for source, update in updates.items():
+                            if source == "__interrupt__":
+                                interrupts: list[Interrupt] = update
+                                logger.info(
+                                    f"HITL: {len(interrupts)} interrupt(s) detected during streaming"
+                                )
+
+                                # Extract interrupt values (deduplicate by ID)
+                                interrupt: Interrupt
+                                for interrupt in interrupts:
+                                    # Only process each unique interrupt once
+                                    if interrupt.id not in seen_interrupt_ids:
+                                        seen_interrupt_ids.add(interrupt.id)
+                                        interrupt_data.append(
+                                            _extract_interrupt_value(interrupt)
+                                        )
+                                        logger.debug(
+                                            f"HITL: Added interrupt {interrupt.id} to response"
+                                        )
+                            elif (
+                                isinstance(update, dict)
+                                and "structured_response" in update
+                            ):
+                                # Capture structured_response from stream updates
+                                structured_response = update["structured_response"]
+                                logger.debug(
+                                    f"Captured structured_response from stream: {type(structured_response)}"
+                                )
+
+                # Get final state to extract structured_response (only if checkpointer available)
+                if self.graph.checkpointer:
+                    final_state: StateSnapshot = await self.graph.aget_state(
+                        config=custom_inputs
+                    )
+                    # Extract structured_response from state if not already captured
+                    if (
+                        "structured_response" in final_state.values
+                        and not structured_response
+                    ):
+                        structured_response = final_state.values["structured_response"]
 
-
-
-
-
+                # Build custom_outputs
+                custom_outputs: dict[str, Any] = await self._build_custom_outputs_async(
+                    context=context,
+                    thread_id=context.thread_id,
+                )
+
+                # Handle structured_response in streaming if present
+                output_text: str = accumulated_content
+                if structured_response:
+                    from dataclasses import asdict, is_dataclass
 
-
-                custom_outputs: dict[str, Any] = custom_inputs.copy()
-                thread_id: Optional[str] = context.thread_id
+                    from pydantic import BaseModel
 
-
-
-                    StateSnapshot
-                ] = await get_state_snapshot_async(self.graph, thread_id)
-                genie_conversation_ids: dict[str, str] = (
-                    get_genie_conversation_ids_from_state(state_snapshot)
+                    logger.debug(
+                        f"Processing structured_response in streaming: {type(structured_response)}"
                     )
-
-
-
+
+                    # Serialize to dict for JSON compatibility using type hints
+                    if isinstance(structured_response, BaseModel):
+                        serialized: dict[str, Any] = structured_response.model_dump()
+                    elif is_dataclass(structured_response):
+                        serialized = asdict(structured_response)
+                    elif isinstance(structured_response, dict):
+                        serialized = structured_response
+                    else:
+                        serialized = (
+                            dict(structured_response)
+                            if hasattr(structured_response, "__dict__")
+                            else structured_response
+                        )
+
+                    # Place structured output in message content - stream as JSON
+                    import json
+
+                    structured_text: str = json.dumps(serialized, indent=2)
+
+                    # If we streamed text, append structured; if no text, use structured only
+                    if accumulated_content.strip():
+                        # Stream separator and structured output
+                        yield ResponsesAgentStreamEvent(
+                            **self.create_text_delta(delta="\n\n", item_id=item_id)
                         )
+                        yield ResponsesAgentStreamEvent(
+                            **self.create_text_delta(
+                                delta=structured_text, item_id=item_id
+                            )
+                        )
+                        output_text = f"{accumulated_content}\n\n{structured_text}"
+                    else:
+                        # No text content, stream structured output
+                        yield ResponsesAgentStreamEvent(
+                            **self.create_text_delta(
+                                delta=structured_text, item_id=item_id
+                            )
+                        )
+                        output_text = structured_text
+
+                    logger.debug("Streamed structured_response in message content")
+
+                # Include interrupt structure if HITL occurred
+                if interrupt_data:
+                    custom_outputs["interrupts"] = interrupt_data
+                    logger.info(
+                        f"HITL: Included {len(interrupt_data)} interrupt(s) in streaming response"
+                    )
+
+                    # Add user-facing message about the pending actions
+                    action_message = _format_action_requests_message(interrupt_data)
+                    if action_message:
+                        # If we haven't streamed any content yet, stream the action message
+                        if not accumulated_content:
+                            output_text = action_message
+                            # Stream the action message
+                            yield ResponsesAgentStreamEvent(
+                                **self.create_text_delta(
+                                    delta=action_message, item_id=item_id
+                                )
+                            )
+                        else:
+                            # Append action message after accumulated content
+                            output_text = f"{accumulated_content}\n\n{action_message}"
+                            # Stream the separator and action message
+                            yield ResponsesAgentStreamEvent(
+                                **self.create_text_delta(delta="\n\n", item_id=item_id)
+                            )
+                            yield ResponsesAgentStreamEvent(
+                                **self.create_text_delta(
+                                    delta=action_message, item_id=item_id
+                                )
+                            )
 
                 # Yield final output item
                 yield ResponsesAgentStreamEvent(
                     type="response.output_item.done",
-                    item=self.create_text_output_item(
-                        text=accumulated_content, id=item_id
-                    ),
+                    item=self.create_text_output_item(text=output_text, id=item_id),
                     custom_outputs=custom_outputs,
                 )
             except Exception as e:
@@ -490,7 +1368,6 @@ class LanggraphResponsesAgent(ResponsesAgent):
             try:
                 loop = asyncio.get_event_loop()
             except RuntimeError:
-                # Handle case where no event loop exists (common in some deployment scenarios)
                 loop = asyncio.new_event_loop()
                 asyncio.set_event_loop(loop)
 
@@ -575,15 +1452,24 @@ class LanggraphResponsesAgent(ResponsesAgent):
         return messages
 
     def _convert_request_to_context(self, request: ResponsesAgentRequest) -> Context:
-        """Convert ResponsesAgent context to internal Context.
+        """Convert ResponsesAgent context to internal Context.
+
+        Handles the input structure:
+        - custom_inputs.configurable: Configuration (thread_id, user_id, store_num, etc.)
+        - custom_inputs.session: Accumulated state (conversation_id, genie conversations, etc.)
 
+        Maps conversation_id -> thread_id for LangGraph compatibility.
+        conversation_id can be provided in either configurable or session.
+        Normalizes user_id (replaces . with _) for memory namespace compatibility.
+        """
         logger.debug(f"request.context: {request.context}")
         logger.debug(f"request.custom_inputs: {request.custom_inputs}")
 
         configurable: dict[str, Any] = {}
+        session: dict[str, Any] = {}
 
         # Process context values first (lower priority)
-        #
+        # These come from Databricks ResponsesAgent ChatContext
         chat_context: Optional[ChatContext] = request.context
         if chat_context is not None:
             conversation_id: Optional[str] = chat_context.conversation_id
@@ -591,27 +1477,185 @@ class LanggraphResponsesAgent(ResponsesAgent):

        if conversation_id is not None:
            configurable["conversation_id"] = conversation_id
-            configurable["thread_id"] = conversation_id

        if user_id is not None:
            configurable["user_id"] = user_id

        # Process custom_inputs after context so they can override context values (higher priority)
        if request.custom_inputs:
+            # Extract configurable section (user config)
            if "configurable" in request.custom_inputs:
-                configurable.update(request.custom_inputs
+                configurable.update(request.custom_inputs["configurable"])
+
+            # Extract session section
+            if "session" in request.custom_inputs:
+                session_input = request.custom_inputs["session"]
+                if isinstance(session_input, dict):
+                    session = session_input
+
+            # Handle legacy flat structure (backwards compatibility)
+            # If user passes keys directly in custom_inputs, merge them
+            for key in list(request.custom_inputs.keys()):
+                if key not in ("configurable", "session"):
+                    configurable[key] = request.custom_inputs[key]
+
+        # Extract known Context fields
+        user_id_value: str | None = configurable.pop("user_id", None)
+        if user_id_value:
+            # Normalize user_id for memory namespace (replace . with _)
+            user_id_value = user_id_value.replace(".", "_")
+
+        # Accept thread_id from configurable, or conversation_id from configurable or session
+        # Priority: configurable.conversation_id > session.conversation_id > configurable.thread_id
+        thread_id: str | None = configurable.pop("thread_id", None)
+        conversation_id: str | None = configurable.pop("conversation_id", None)
+
+        # Also check session for conversation_id (output puts it there)
+        if conversation_id is None and "conversation_id" in session:
+            conversation_id = session.get("conversation_id")
+
+        # conversation_id takes precedence if provided
+        if conversation_id:
+            thread_id = conversation_id
+        if not thread_id:
+            # Generate new thread_id if neither provided
+            thread_id = str(uuid.uuid4())
+
+        # All remaining configurable values go into custom dict
+        custom: dict[str, Any] = configurable
+
+        logger.debug(
+            f"Creating context with user_id={user_id_value}, thread_id={thread_id}, custom={custom}"
+        )

-
+        return Context(
+            user_id=user_id_value,
+            thread_id=thread_id,
+            custom=custom,
+        )

-
-
+    def _extract_session_from_request(
+        self, request: ResponsesAgentRequest
+    ) -> dict[str, Any]:
+        """Extract session state from request for passing to graph.

-
-
+        Handles:
+        - New structure: custom_inputs.session.genie
+        - Legacy structure: custom_inputs.genie_conversation_ids
+        """
+        session: dict[str, Any] = {}
+
+        if not request.custom_inputs:
+            return session
+
+        # New structure: session.genie
+        if "session" in request.custom_inputs:
+            session_input = request.custom_inputs["session"]
+            if isinstance(session_input, dict) and "genie" in session_input:
+                genie_state = session_input["genie"]
+                # Extract conversation IDs from the new structure
+                if isinstance(genie_state, dict) and "spaces" in genie_state:
+                    genie_conversation_ids = {}
+                    for space_id, space_state in genie_state["spaces"].items():
+                        if (
+                            isinstance(space_state, dict)
+                            and "conversation_id" in space_state
+                        ):
+                            genie_conversation_ids[space_id] = space_state[
+                                "conversation_id"
+                            ]
+                    if genie_conversation_ids:
+                        session["genie_conversation_ids"] = genie_conversation_ids
+
+        # Legacy structure: genie_conversation_ids at top level
+        if "genie_conversation_ids" in request.custom_inputs:
+            session["genie_conversation_ids"] = request.custom_inputs[
+                "genie_conversation_ids"
+            ]
+
+        # Also check inside configurable for legacy support
+        if "configurable" in request.custom_inputs:
+            cfg = request.custom_inputs["configurable"]
+            if isinstance(cfg, dict) and "genie_conversation_ids" in cfg:
+                session["genie_conversation_ids"] = cfg["genie_conversation_ids"]
+
+        return session
+
+    def _build_custom_outputs(
+        self,
+        context: Context,
+        thread_id: Optional[str],
+        loop: Any,  # asyncio.AbstractEventLoop
+    ) -> dict[str, Any]:
+        """Build custom_outputs that can be copy-pasted as next request's custom_inputs.
+
+        Output structure:
+            configurable:
+                thread_id: "abc-123"        # Thread identifier (conversation_id is alias)
+                user_id: "nate.fleming"     # De-normalized (no underscore replacement)
+                store_num: "87887"          # Any custom fields
+            session:
+                conversation_id: "abc-123"  # Alias of thread_id for Databricks compatibility
+                genie:
+                    spaces:
+                        space_123: {conversation_id: "conv_456", cache_hit: false, ...}
+        """
+        return loop.run_until_complete(
+            self._build_custom_outputs_async(context=context, thread_id=thread_id)
+        )
+
+    async def _build_custom_outputs_async(
+        self,
+        context: Context,
+        thread_id: Optional[str],
+    ) -> dict[str, Any]:
+        """Async version of _build_custom_outputs."""
+        # Build configurable section
+        # Note: only thread_id is included here (conversation_id goes in session)
+        configurable: dict[str, Any] = {}
+
+        if thread_id:
+            configurable["thread_id"] = thread_id

-
+        # Include user_id (keep normalized form for consistency)
+        if context.user_id:
+            configurable["user_id"] = context.user_id

-
+        # Include all custom fields from context
+        configurable.update(context.custom)
+
+        # Build session section with accumulated state
+        # Note: conversation_id is included here as an alias of thread_id
+        session: dict[str, Any] = {}
+
+        if thread_id:
+            # Include conversation_id in session (alias of thread_id)
+            session["conversation_id"] = thread_id
+
+            state_snapshot: Optional[StateSnapshot] = await get_state_snapshot_async(
+                self.graph, thread_id
+            )
+            genie_conversation_ids: dict[str, str] = (
+                get_genie_conversation_ids_from_state(state_snapshot)
+            )
+            if genie_conversation_ids:
+                # Convert flat genie_conversation_ids to new session.genie.spaces structure
+                session["genie"] = {
+                    "spaces": {
+                        space_id: {
+                            "conversation_id": conv_id,
+                            # Note: cache_hit, follow_up_questions populated by Genie tool
+                            "cache_hit": False,
+                            "follow_up_questions": [],
+                        }
+                        for space_id, conv_id in genie_conversation_ids.items()
+                    }
+                }
+
+        return {
+            "configurable": configurable,
+            "session": session,
+        }


def create_agent(graph: CompiledStateGraph) -> ChatAgent:
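Taken together, the methods above define a symmetric payload shape: `custom_inputs` are parsed into a `Context` plus session state, and `_build_custom_outputs_async` returns the same structure so a client can feed a response's `custom_outputs` straight back in as the next request's `custom_inputs`. A sketch of that shape, using the placeholder values from the docstring above (store_num, space_123, conv_456 are illustrative):

```python
# What a client might send (parsed by the context/session extraction above):
custom_inputs = {
    "configurable": {
        "user_id": "nate.fleming",   # normalized to "nate_fleming" internally
        "thread_id": "abc-123",
        "store_num": "87887",        # any extra keys land in Context.custom
    },
    "session": {
        "conversation_id": "abc-123",  # accepted as an alias of thread_id
        "genie": {"spaces": {"space_123": {"conversation_id": "conv_456"}}},
    },
}

# Shape returned by _build_custom_outputs_async (the genie section comes from the
# graph's checkpointed state, not echoed from the request):
custom_outputs = {
    "configurable": {
        "thread_id": "abc-123",
        "user_id": "nate_fleming",
        "store_num": "87887",
    },
    "session": {
        "conversation_id": "abc-123",
        "genie": {
            "spaces": {
                "space_123": {
                    "conversation_id": "conv_456",
                    "cache_hit": False,
                    "follow_up_questions": [],
                }
            }
        },
    },
}
```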
@@ -630,7 +1674,9 @@ def create_agent(graph: CompiledStateGraph) -> ChatAgent:
    return LanggraphChatModel(graph)


-def create_responses_agent(
+def create_responses_agent(
+    graph: CompiledStateGraph,
+) -> ResponsesAgent:
    """
    Create an MLflow-compatible ResponsesAgent from a LangGraph state machine.

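Usage is unchanged apart from the reflowed signature; a minimal sketch, assuming `graph` is a `CompiledStateGraph` that has already been built elsewhere:

```python
from dao_ai.models import create_responses_agent

# `graph` is assumed to be a CompiledStateGraph constructed elsewhere.
responses_agent = create_responses_agent(graph)
```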
@@ -665,6 +1711,29 @@ def _process_langchain_messages(
    return loop.run_until_complete(_async_invoke())


+def _configurable_to_context(configurable: dict[str, Any]) -> Context:
+    """Convert a configurable dict to a Context object."""
+    configurable = configurable.copy()
+
+    # Extract known Context fields
+    user_id: str | None = configurable.pop("user_id", None)
+    if user_id:
+        user_id = user_id.replace(".", "_")
+
+    thread_id: str | None = configurable.pop("thread_id", None)
+    if "conversation_id" in configurable and not thread_id:
+        thread_id = configurable.pop("conversation_id")
+    if not thread_id:
+        thread_id = str(uuid.uuid4())
+
+    # All remaining values go into custom dict
+    return Context(
+        user_id=user_id,
+        thread_id=thread_id,
+        custom=configurable,
+    )
+
+
def _process_langchain_messages_stream(
    app: LanggraphChatModel | CompiledStateGraph,
    messages: Sequence[BaseMessage],
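The new helper makes the user-id normalization and thread-id fallback explicit; a small sketch of the mapping it performs (input values are illustrative, and the attribute names follow the `Context(...)` constructor used above):

```python
context = _configurable_to_context(
    {"user_id": "nate.fleming", "conversation_id": "abc-123", "store_num": "87887"}
)
# context.user_id   == "nate_fleming"          # dots replaced with underscores
# context.thread_id == "abc-123"               # conversation_id used when thread_id is absent
# context.custom    == {"store_num": "87887"}  # everything else is preserved
```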
@@ -678,8 +1747,8 @@ def _process_langchain_messages_stream(

    logger.debug(f"Processing messages: {messages}, custom_inputs: {custom_inputs}")

-
-    context: Context =
+    configurable = (custom_inputs or {}).get("configurable", custom_inputs or {})
+    context: Context = _configurable_to_context(configurable)

    # Use async astream internally for parallel execution
    async def _async_stream():