shotgun-sh 0.2.6.dev5__py3-none-any.whl → 0.2.7.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of shotgun-sh has been flagged as potentially problematic; see the package's registry page for more details.
- shotgun/agents/agent_manager.py +272 -14
- shotgun/agents/common.py +42 -17
- shotgun/agents/conversation_history.py +123 -2
- shotgun/agents/conversation_manager.py +24 -2
- shotgun/agents/history/context_extraction.py +93 -6
- shotgun/agents/tools/file_management.py +55 -9
- shotgun/prompts/agents/specify.j2 +270 -3
- shotgun/tui/screens/chat.py +54 -13
- shotgun_sh-0.2.7.dev2.dist-info/METADATA +126 -0
- {shotgun_sh-0.2.6.dev5.dist-info → shotgun_sh-0.2.7.dev2.dist-info}/RECORD +13 -13
- shotgun_sh-0.2.6.dev5.dist-info/METADATA +0 -467
- {shotgun_sh-0.2.6.dev5.dist-info → shotgun_sh-0.2.7.dev2.dist-info}/WHEEL +0 -0
- {shotgun_sh-0.2.6.dev5.dist-info → shotgun_sh-0.2.7.dev2.dist-info}/entry_points.txt +0 -0
- {shotgun_sh-0.2.6.dev5.dist-info → shotgun_sh-0.2.7.dev2.dist-info}/licenses/LICENSE +0 -0
shotgun/agents/agent_manager.py
CHANGED
|
@@ -1,10 +1,21 @@
|
|
|
1
1
|
"""Agent manager for coordinating multiple AI agents with shared message history."""
|
|
2
2
|
|
|
3
|
+
import json
|
|
3
4
|
import logging
|
|
4
5
|
from collections.abc import AsyncIterable, Sequence
|
|
5
6
|
from dataclasses import dataclass, field, is_dataclass, replace
|
|
7
|
+
from pathlib import Path
|
|
6
8
|
from typing import TYPE_CHECKING, Any, cast
|
|
7
9
|
|
|
10
|
+
import logfire
|
|
11
|
+
from tenacity import (
|
|
12
|
+
before_sleep_log,
|
|
13
|
+
retry,
|
|
14
|
+
retry_if_exception,
|
|
15
|
+
stop_after_attempt,
|
|
16
|
+
wait_exponential,
|
|
17
|
+
)
|
|
18
|
+
|
|
8
19
|
if TYPE_CHECKING:
|
|
9
20
|
from shotgun.agents.conversation_history import ConversationState
|
|
10
21
|
|
|
@@ -52,6 +63,35 @@ from .tasks import create_tasks_agent
|
|
|
52
63
|
logger = logging.getLogger(__name__)
|
|
53
64
|
|
|
54
65
|
|
|
66
|
+
def _is_retryable_error(exception: BaseException) -> bool:
    """Check if exception should trigger a retry.

    Three families of errors are treated as transient:

    * ``ValueError`` whose message indicates truncated/incomplete JSON
      (e.g. "EOF while parsing ..."), as produced by a cut-off tool call.
    * API status errors (any class in the exception's hierarchy whose name
      contains ``APIStatusError``) that signal overload or rate limiting,
      either via ``status_code`` 429/529 or via a specific message phrase.
    * Network errors (any class in the hierarchy whose name contains
      ``ConnectionError`` or ``TimeoutError``).

    Args:
        exception: The exception to check.

    Returns:
        True if the exception is a transient error that should be retried.
    """
    # ValueError for truncated/incomplete JSON
    if isinstance(exception, ValueError):
        error_str = str(exception)
        return "EOF while parsing" in error_str or (
            "JSON" in error_str and "parsing" in error_str
        )

    # Match against the whole class hierarchy rather than only the concrete
    # class name: subclasses such as ``RateLimitError(APIStatusError)`` or the
    # builtin ``ConnectionResetError(ConnectionError)`` would otherwise be
    # missed and never retried.
    class_names = [cls.__name__ for cls in type(exception).__mro__]

    # API errors (overload, rate limits)
    if any("APIStatusError" in name for name in class_names):
        # Prefer the HTTP status when the client exposes one
        # (429 = too many requests, 529 = provider overloaded).
        if getattr(exception, "status_code", None) in (429, 529):
            return True
        # Fall back to specific phrases. A bare "rate" substring is avoided on
        # purpose: it also matches "generate", "accurate", ... and would retry
        # permanent request errors.
        error_lower = str(exception).lower()
        return any(
            marker in error_lower
            for marker in (
                "overload",
                "rate limit",
                "rate_limit",
                "ratelimit",
                "too many requests",
            )
        )

    # Network errors
    return any(
        "ConnectionError" in name or "TimeoutError" in name for name in class_names
    )
|
|
93
|
+
|
|
94
|
+
|
|
55
95
|
class MessageHistoryUpdated(Message):
|
|
56
96
|
"""Event posted when the message history is updated."""
|
|
57
97
|
|
|
@@ -265,6 +305,49 @@ class AgentManager(Widget):
|
|
|
265
305
|
f"Invalid agent type: {agent_type}. Must be one of: {', '.join(e.value for e in AgentType)}"
|
|
266
306
|
) from None
|
|
267
307
|
|
|
308
|
+
@retry(
    reraise=True,
    retry=retry_if_exception(_is_retryable_error),
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=1, max=8),
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
async def _run_agent_with_retry(
    self,
    agent: Agent[AgentDeps, AgentResponse],
    prompt: str | None,
    deps: AgentDeps,
    usage_limits: UsageLimits | None,
    message_history: list[ModelMessage],
    event_stream_handler: Any,
    **kwargs: Any,
) -> AgentRunResult[AgentResponse]:
    """Execute ``agent.run`` and retry it when it fails with a transient error.

    The tenacity decorator makes at most 3 attempts in total. Between attempts
    it waits with exponential backoff clamped to 1-8 seconds and logs a WARNING
    before each sleep. Only exceptions accepted by ``_is_retryable_error`` are
    retried. After the final failed attempt the original exception is
    re-raised (``reraise=True``).

    NOTE(review): every attempt calls ``agent.run`` again with the very same
    ``message_history``, ``deps`` and ``event_stream_handler``. Anything a
    failed attempt already did (tool calls, streamed events, state recorded on
    ``deps``) is not rolled back. Confirm that re-executing those effects is
    acceptable.

    Args:
        agent: The agent to run.
        prompt: Optional prompt to send to the agent.
        deps: Agent dependencies.
        usage_limits: Optional usage limits.
        message_history: Message history to provide to agent.
        event_stream_handler: Event handler for streaming.
        **kwargs: Extra keyword arguments forwarded verbatim to ``agent.run``.

    Returns:
        The result of the first successful ``agent.run`` call.

    Raises:
        Whatever ``agent.run`` raised on the last attempt, or immediately for
        non-retryable errors.
    """
    run_result = await agent.run(
        prompt,
        message_history=message_history,
        deps=deps,
        usage_limits=usage_limits,
        event_stream_handler=event_stream_handler,
        **kwargs,
    )
    return run_result
|
|
350
|
+
|
|
268
351
|
async def run(
|
|
269
352
|
self,
|
|
270
353
|
prompt: str | None = None,
|
|
@@ -391,8 +474,9 @@ class AgentManager(Widget):
|
|
|
391
474
|
)
|
|
392
475
|
|
|
393
476
|
try:
|
|
394
|
-
result: AgentRunResult[AgentResponse] = await self.
|
|
395
|
-
|
|
477
|
+
result: AgentRunResult[AgentResponse] = await self._run_agent_with_retry(
|
|
478
|
+
agent=self.current_agent,
|
|
479
|
+
prompt=prompt,
|
|
396
480
|
deps=deps,
|
|
397
481
|
usage_limits=usage_limits,
|
|
398
482
|
message_history=message_history,
|
|
@@ -401,18 +485,93 @@ class AgentManager(Widget):
|
|
|
401
485
|
else None,
|
|
402
486
|
**kwargs,
|
|
403
487
|
)
|
|
488
|
+
except ValueError as e:
|
|
489
|
+
# Handle truncated/incomplete JSON in tool calls specifically
|
|
490
|
+
error_str = str(e)
|
|
491
|
+
if "EOF while parsing" in error_str or (
|
|
492
|
+
"JSON" in error_str and "parsing" in error_str
|
|
493
|
+
):
|
|
494
|
+
logger.error(
|
|
495
|
+
"Tool call with truncated/incomplete JSON arguments detected",
|
|
496
|
+
extra={
|
|
497
|
+
"agent_mode": self._current_agent_type.value,
|
|
498
|
+
"model_name": model_name,
|
|
499
|
+
"error": error_str,
|
|
500
|
+
},
|
|
501
|
+
)
|
|
502
|
+
logfire.error(
|
|
503
|
+
"Tool call with truncated JSON arguments",
|
|
504
|
+
agent_mode=self._current_agent_type.value,
|
|
505
|
+
model_name=model_name,
|
|
506
|
+
error=error_str,
|
|
507
|
+
)
|
|
508
|
+
# Add helpful hint message for the user
|
|
509
|
+
self.ui_message_history.append(
|
|
510
|
+
HintMessage(
|
|
511
|
+
message="⚠️ The agent attempted an operation with arguments that were too large (truncated JSON). "
|
|
512
|
+
"Try breaking your request into smaller steps or more focused contracts."
|
|
513
|
+
)
|
|
514
|
+
)
|
|
515
|
+
self._post_messages_updated()
|
|
516
|
+
# Re-raise to maintain error visibility
|
|
517
|
+
raise
|
|
518
|
+
except Exception as e:
|
|
519
|
+
# Log the error with full stack trace to shotgun.log and Logfire
|
|
520
|
+
logger.exception(
|
|
521
|
+
"Agent execution failed",
|
|
522
|
+
extra={
|
|
523
|
+
"agent_mode": self._current_agent_type.value,
|
|
524
|
+
"model_name": model_name,
|
|
525
|
+
"error_type": type(e).__name__,
|
|
526
|
+
},
|
|
527
|
+
)
|
|
528
|
+
logfire.exception(
|
|
529
|
+
"Agent execution failed",
|
|
530
|
+
agent_mode=self._current_agent_type.value,
|
|
531
|
+
model_name=model_name,
|
|
532
|
+
error_type=type(e).__name__,
|
|
533
|
+
)
|
|
534
|
+
# Re-raise to let TUI handle user messaging
|
|
535
|
+
raise
|
|
404
536
|
finally:
|
|
405
537
|
self._stream_state = None
|
|
406
538
|
|
|
407
539
|
# Agent ALWAYS returns AgentResponse with structured output
|
|
408
540
|
agent_response = result.output
|
|
409
|
-
logger.debug(
|
|
541
|
+
logger.debug(
|
|
542
|
+
"Agent returned structured AgentResponse",
|
|
543
|
+
extra={
|
|
544
|
+
"has_response": agent_response.response is not None,
|
|
545
|
+
"response_length": len(agent_response.response)
|
|
546
|
+
if agent_response.response
|
|
547
|
+
else 0,
|
|
548
|
+
"response_preview": agent_response.response[:100] + "..."
|
|
549
|
+
if agent_response.response and len(agent_response.response) > 100
|
|
550
|
+
else agent_response.response or "(empty)",
|
|
551
|
+
"has_clarifying_questions": bool(agent_response.clarifying_questions),
|
|
552
|
+
"num_clarifying_questions": len(agent_response.clarifying_questions)
|
|
553
|
+
if agent_response.clarifying_questions
|
|
554
|
+
else 0,
|
|
555
|
+
},
|
|
556
|
+
)
|
|
410
557
|
|
|
411
558
|
# Always add the agent's response messages to maintain conversation history
|
|
412
559
|
self.ui_message_history = original_messages + cast(
|
|
413
560
|
list[ModelRequest | ModelResponse | HintMessage], result.new_messages()
|
|
414
561
|
)
|
|
415
562
|
|
|
563
|
+
# Get file operations early so we can use them for contextual messages
|
|
564
|
+
file_operations = deps.file_tracker.operations.copy()
|
|
565
|
+
self.recently_change_files = file_operations
|
|
566
|
+
|
|
567
|
+
logger.debug(
|
|
568
|
+
"File operations tracked",
|
|
569
|
+
extra={
|
|
570
|
+
"num_file_operations": len(file_operations),
|
|
571
|
+
"operation_files": [Path(op.file_path).name for op in file_operations],
|
|
572
|
+
},
|
|
573
|
+
)
|
|
574
|
+
|
|
416
575
|
# Check if there are clarifying questions
|
|
417
576
|
if agent_response.clarifying_questions:
|
|
418
577
|
logger.info(
|
|
@@ -459,27 +618,93 @@ class AgentManager(Widget):
|
|
|
459
618
|
response_text=agent_response.response,
|
|
460
619
|
)
|
|
461
620
|
)
|
|
621
|
+
|
|
622
|
+
# Post UI update with hint messages and file operations
|
|
623
|
+
logger.debug(
|
|
624
|
+
"Posting UI update for Q&A mode with hint messages and file operations"
|
|
625
|
+
)
|
|
626
|
+
self._post_messages_updated(file_operations)
|
|
462
627
|
else:
|
|
463
|
-
# No clarifying questions -
|
|
628
|
+
# No clarifying questions - show the response or a default success message
|
|
464
629
|
if agent_response.response and agent_response.response.strip():
|
|
630
|
+
logger.debug(
|
|
631
|
+
"Adding agent response as hint",
|
|
632
|
+
extra={
|
|
633
|
+
"response_preview": agent_response.response[:100] + "..."
|
|
634
|
+
if len(agent_response.response) > 100
|
|
635
|
+
else agent_response.response,
|
|
636
|
+
"has_file_operations": len(file_operations) > 0,
|
|
637
|
+
},
|
|
638
|
+
)
|
|
465
639
|
self.ui_message_history.append(
|
|
466
640
|
HintMessage(message=agent_response.response)
|
|
467
641
|
)
|
|
642
|
+
else:
|
|
643
|
+
# Fallback: response is empty or whitespace
|
|
644
|
+
logger.debug(
|
|
645
|
+
"Agent response was empty, using fallback completion message",
|
|
646
|
+
extra={"has_file_operations": len(file_operations) > 0},
|
|
647
|
+
)
|
|
648
|
+
# Show contextual message based on whether files were modified
|
|
649
|
+
if file_operations:
|
|
650
|
+
self.ui_message_history.append(
|
|
651
|
+
HintMessage(
|
|
652
|
+
message="✅ Task completed - files have been modified"
|
|
653
|
+
)
|
|
654
|
+
)
|
|
655
|
+
else:
|
|
656
|
+
self.ui_message_history.append(
|
|
657
|
+
HintMessage(message="✅ Task completed")
|
|
658
|
+
)
|
|
659
|
+
|
|
660
|
+
# Post UI update immediately so user sees the response without delay
|
|
661
|
+
logger.debug(
|
|
662
|
+
"Posting immediate UI update with hint message and file operations"
|
|
663
|
+
)
|
|
664
|
+
self._post_messages_updated(file_operations)
|
|
468
665
|
|
|
469
666
|
# Apply compaction to persistent message history to prevent cascading growth
|
|
470
667
|
all_messages = result.all_messages()
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
668
|
+
try:
|
|
669
|
+
logger.debug(
|
|
670
|
+
"Starting message history compaction",
|
|
671
|
+
extra={"message_count": len(all_messages)},
|
|
672
|
+
)
|
|
673
|
+
self.message_history = await apply_persistent_compaction(all_messages, deps)
|
|
674
|
+
logger.debug(
|
|
675
|
+
"Completed message history compaction",
|
|
676
|
+
extra={
|
|
677
|
+
"original_count": len(all_messages),
|
|
678
|
+
"compacted_count": len(self.message_history),
|
|
679
|
+
},
|
|
680
|
+
)
|
|
681
|
+
except Exception as e:
|
|
682
|
+
# If compaction fails, log full error with stack trace and use uncompacted messages
|
|
683
|
+
logger.error(
|
|
684
|
+
"Failed to compact message history - using uncompacted messages",
|
|
685
|
+
exc_info=True,
|
|
686
|
+
extra={
|
|
687
|
+
"error": str(e),
|
|
688
|
+
"message_count": len(all_messages),
|
|
689
|
+
"agent_mode": self._current_agent_type.value,
|
|
690
|
+
},
|
|
691
|
+
)
|
|
692
|
+
# Fallback: use uncompacted messages to prevent data loss
|
|
693
|
+
self.message_history = all_messages
|
|
476
694
|
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
695
|
+
usage = result.usage()
|
|
696
|
+
if hasattr(deps, "llm_model") and deps.llm_model is not None:
|
|
697
|
+
deps.usage_manager.add_usage(
|
|
698
|
+
usage, model_name=deps.llm_model.name, provider=deps.llm_model.provider
|
|
699
|
+
)
|
|
700
|
+
else:
|
|
701
|
+
logger.warning(
|
|
702
|
+
"llm_model is None, skipping usage tracking",
|
|
703
|
+
extra={"agent_mode": self._current_agent_type.value},
|
|
704
|
+
)
|
|
480
705
|
|
|
481
|
-
#
|
|
482
|
-
|
|
706
|
+
# UI updates are now posted immediately in each branch (Q&A or non-Q&A)
|
|
707
|
+
# before compaction, so no duplicate posting needed here
|
|
483
708
|
|
|
484
709
|
return result
|
|
485
710
|
|
|
@@ -554,6 +779,39 @@ class AgentManager(Widget):
|
|
|
554
779
|
# Detect source from call stack
|
|
555
780
|
source = detect_source()
|
|
556
781
|
|
|
782
|
+
# Log if tool call has incomplete args (for debugging truncated JSON)
|
|
783
|
+
if isinstance(event.part.args, str):
|
|
784
|
+
try:
|
|
785
|
+
json.loads(event.part.args)
|
|
786
|
+
except (json.JSONDecodeError, ValueError):
|
|
787
|
+
args_preview = (
|
|
788
|
+
event.part.args[:100] + "..."
|
|
789
|
+
if len(event.part.args) > 100
|
|
790
|
+
else event.part.args
|
|
791
|
+
)
|
|
792
|
+
logger.warning(
|
|
793
|
+
"FunctionToolCallEvent received with incomplete JSON args",
|
|
794
|
+
extra={
|
|
795
|
+
"tool_name": event.part.tool_name,
|
|
796
|
+
"tool_call_id": event.part.tool_call_id,
|
|
797
|
+
"args_preview": args_preview,
|
|
798
|
+
"args_length": len(event.part.args)
|
|
799
|
+
if event.part.args
|
|
800
|
+
else 0,
|
|
801
|
+
"agent_mode": self._current_agent_type.value,
|
|
802
|
+
},
|
|
803
|
+
)
|
|
804
|
+
logfire.warn(
|
|
805
|
+
"FunctionToolCallEvent received with incomplete JSON args",
|
|
806
|
+
tool_name=event.part.tool_name,
|
|
807
|
+
tool_call_id=event.part.tool_call_id,
|
|
808
|
+
args_preview=args_preview,
|
|
809
|
+
args_length=len(event.part.args)
|
|
810
|
+
if event.part.args
|
|
811
|
+
else 0,
|
|
812
|
+
agent_mode=self._current_agent_type.value,
|
|
813
|
+
)
|
|
814
|
+
|
|
557
815
|
track_event(
|
|
558
816
|
"tool_called",
|
|
559
817
|
{
|
shotgun/agents/common.py
CHANGED
|
@@ -384,23 +384,48 @@ def get_agent_existing_files(agent_mode: AgentType | None = None) -> list[str]:
|
|
|
384
384
|
relative_path = file_path.relative_to(base_path)
|
|
385
385
|
existing_files.append(str(relative_path))
|
|
386
386
|
else:
|
|
387
|
-
# For other agents, check
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
#
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
#
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
387
|
+
# For other agents, check files/directories they have access to
|
|
388
|
+
allowed_paths_raw = AGENT_DIRECTORIES[agent_mode]
|
|
389
|
+
|
|
390
|
+
# Convert single Path/string to list of Paths for uniform handling
|
|
391
|
+
if isinstance(allowed_paths_raw, str):
|
|
392
|
+
# Special case: "*" means export agent (shouldn't reach here but handle it)
|
|
393
|
+
allowed_paths = (
|
|
394
|
+
[Path(allowed_paths_raw)] if allowed_paths_raw != "*" else []
|
|
395
|
+
)
|
|
396
|
+
elif isinstance(allowed_paths_raw, Path):
|
|
397
|
+
allowed_paths = [allowed_paths_raw]
|
|
398
|
+
else:
|
|
399
|
+
# Already a list
|
|
400
|
+
allowed_paths = allowed_paths_raw
|
|
401
|
+
|
|
402
|
+
# Check each allowed path
|
|
403
|
+
for allowed_path in allowed_paths:
|
|
404
|
+
allowed_str = str(allowed_path)
|
|
405
|
+
|
|
406
|
+
# Check if it's a directory (no .md suffix)
|
|
407
|
+
if not allowed_path.suffix or not allowed_str.endswith(".md"):
|
|
408
|
+
# It's a directory - list all files within it
|
|
409
|
+
dir_path = base_path / allowed_str
|
|
410
|
+
if dir_path.exists() and dir_path.is_dir():
|
|
411
|
+
for file_path in dir_path.rglob("*"):
|
|
412
|
+
if file_path.is_file():
|
|
413
|
+
relative_path = file_path.relative_to(base_path)
|
|
414
|
+
existing_files.append(str(relative_path))
|
|
415
|
+
else:
|
|
416
|
+
# It's a file - check if it exists
|
|
417
|
+
file_path = base_path / allowed_str
|
|
418
|
+
if file_path.exists():
|
|
419
|
+
existing_files.append(allowed_str)
|
|
420
|
+
|
|
421
|
+
# Also check for associated directory (e.g., research/ for research.md)
|
|
422
|
+
base_name = allowed_str.replace(".md", "")
|
|
423
|
+
dir_path = base_path / base_name
|
|
424
|
+
if dir_path.exists() and dir_path.is_dir():
|
|
425
|
+
for file_path in dir_path.rglob("*"):
|
|
426
|
+
if file_path.is_file():
|
|
427
|
+
relative_path = file_path.relative_to(base_path)
|
|
428
|
+
existing_files.append(str(relative_path))
|
|
404
429
|
|
|
405
430
|
return existing_files
|
|
406
431
|
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
"""Models and utilities for persisting TUI conversation history."""
|
|
2
2
|
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
3
5
|
from datetime import datetime
|
|
4
6
|
from typing import Any, cast
|
|
5
7
|
|
|
@@ -7,14 +9,106 @@ from pydantic import BaseModel, ConfigDict, Field
|
|
|
7
9
|
from pydantic_ai.messages import (
|
|
8
10
|
ModelMessage,
|
|
9
11
|
ModelMessagesTypeAdapter,
|
|
12
|
+
ModelResponse,
|
|
13
|
+
ToolCallPart,
|
|
10
14
|
)
|
|
11
15
|
from pydantic_core import to_jsonable_python
|
|
12
16
|
|
|
13
17
|
from shotgun.tui.screens.chat_screen.hint_message import HintMessage
|
|
14
18
|
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
15
21
|
SerializedMessage = dict[str, Any]
|
|
16
22
|
|
|
17
23
|
|
|
24
|
+
def is_tool_call_complete(tool_call: ToolCallPart) -> bool:
    """Check if a tool call has valid, complete JSON arguments.

    Args:
        tool_call: The tool call part to validate

    Returns:
        True when the call has no arguments (``None`` or an empty string), has
        already-parsed ``dict`` arguments, or has a string that parses as JSON.
        False for any other argument type, or for a string that fails to parse
        (e.g. arguments truncated mid-stream). Failures are logged at INFO.
    """
    args = tool_call.args

    # No args, or args already parsed into a dict, are valid.
    if args is None or isinstance(args, dict):
        return True

    if not isinstance(args, str):
        return False

    # An empty string means "no arguments". pydantic-ai's
    # ToolCallPart.args_as_dict() maps falsy args to {}, so a legitimate
    # zero-argument tool call must not be reported as incomplete. Dropping it
    # would orphan the tool result that follows it in the history.
    if not args:
        return True

    # Try to parse the JSON string
    try:
        json.loads(args)
    except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
        # Log incomplete tool call detection
        args_preview = args[:100] + "..." if len(args) > 100 else args
        logger.info(
            "Detected incomplete tool call in validation",
            extra={
                "tool_name": tool_call.tool_name,
                "tool_call_id": tool_call.tool_call_id,
                "args_preview": args_preview,
                "error": str(e),
            },
        )
        return False
    return True
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def filter_incomplete_messages(messages: list[ModelMessage]) -> list[ModelMessage]:
    """Return ``messages`` minus every ModelResponse that holds an incomplete tool call.

    Messages that are not ``ModelResponse`` instances are always kept. A
    ``ModelResponse`` is dropped as soon as one of its ``ToolCallPart`` entries
    fails ``is_tool_call_complete``. Later parts of that response are not
    inspected. The relative order of the kept messages is preserved. When
    anything is dropped, one INFO summary is logged that lists the first
    offending tool name of each dropped response.

    NOTE(review): dropping a response does not touch later messages that may
    still reference its tool_call_id. Confirm that downstream consumers
    tolerate such references.

    Args:
        messages: List of messages to filter

    Returns:
        List of messages with only complete tool calls
    """
    kept: list[ModelMessage] = []
    dropped_tool_names: list[str] = []

    for msg in messages:
        # Only ModelResponse messages carry tool calls; find the first broken one.
        first_broken_call = (
            next(
                (
                    part
                    for part in msg.parts
                    if isinstance(part, ToolCallPart)
                    and not is_tool_call_complete(part)
                ),
                None,
            )
            if isinstance(msg, ModelResponse)
            else None
        )

        if first_broken_call is None:
            kept.append(msg)
        else:
            dropped_tool_names.append(first_broken_call.tool_name)

    # Exactly one name is recorded per dropped message, so its length is the count.
    if dropped_tool_names:
        logger.info(
            "Filtered incomplete messages before saving",
            extra={
                "filtered_count": len(dropped_tool_names),
                "total_messages": len(messages),
                "filtered_tool_names": dropped_tool_names,
            },
        )

    return kept
|
|
110
|
+
|
|
111
|
+
|
|
18
112
|
class ConversationState(BaseModel):
|
|
19
113
|
"""Represents the complete state of a conversation in memory."""
|
|
20
114
|
|
|
@@ -46,14 +140,41 @@ class ConversationHistory(BaseModel):
|
|
|
46
140
|
Args:
|
|
47
141
|
messages: List of ModelMessage objects to serialize and store
|
|
48
142
|
"""
|
|
143
|
+
# Filter out messages with incomplete tool calls to prevent corruption
|
|
144
|
+
filtered_messages = filter_incomplete_messages(messages)
|
|
145
|
+
|
|
49
146
|
# Serialize ModelMessage list to JSON-serializable format
|
|
50
147
|
self.agent_history = to_jsonable_python(
|
|
51
|
-
|
|
148
|
+
filtered_messages, fallback=lambda x: str(x), exclude_none=True
|
|
52
149
|
)
|
|
53
150
|
|
|
54
151
|
def set_ui_messages(self, messages: list[ModelMessage | HintMessage]) -> None:
|
|
55
152
|
"""Set ui_history from a list of UI messages."""
|
|
56
153
|
|
|
154
|
+
# Filter out ModelMessages with incomplete tool calls (keep all HintMessages)
|
|
155
|
+
# We need to maintain message order, so we'll check each message individually
|
|
156
|
+
filtered_messages: list[ModelMessage | HintMessage] = []
|
|
157
|
+
|
|
158
|
+
for msg in messages:
|
|
159
|
+
if isinstance(msg, HintMessage):
|
|
160
|
+
# Always keep hint messages
|
|
161
|
+
filtered_messages.append(msg)
|
|
162
|
+
elif isinstance(msg, ModelResponse):
|
|
163
|
+
# Check if this ModelResponse has incomplete tool calls
|
|
164
|
+
has_incomplete = False
|
|
165
|
+
for part in msg.parts:
|
|
166
|
+
if isinstance(part, ToolCallPart) and not is_tool_call_complete(
|
|
167
|
+
part
|
|
168
|
+
):
|
|
169
|
+
has_incomplete = True
|
|
170
|
+
break
|
|
171
|
+
|
|
172
|
+
if not has_incomplete:
|
|
173
|
+
filtered_messages.append(msg)
|
|
174
|
+
else:
|
|
175
|
+
# Keep all other ModelMessage types (ModelRequest, etc.)
|
|
176
|
+
filtered_messages.append(msg)
|
|
177
|
+
|
|
57
178
|
def _serialize_message(
|
|
58
179
|
message: ModelMessage | HintMessage,
|
|
59
180
|
) -> Any:
|
|
@@ -68,7 +189,7 @@ class ConversationHistory(BaseModel):
|
|
|
68
189
|
payload.setdefault("message_type", "model")
|
|
69
190
|
return payload
|
|
70
191
|
|
|
71
|
-
self.ui_history = [_serialize_message(msg) for msg in
|
|
192
|
+
self.ui_history = [_serialize_message(msg) for msg in filtered_messages]
|
|
72
193
|
|
|
73
194
|
def get_agent_messages(self) -> list[ModelMessage]:
|
|
74
195
|
"""Get agent_history as a list of ModelMessage objects.
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""Manager for handling conversation persistence operations."""
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
+
import shutil
|
|
4
5
|
from pathlib import Path
|
|
5
6
|
|
|
6
7
|
from shotgun.logging_config import get_logger
|
|
@@ -77,9 +78,30 @@ class ConversationManager:
|
|
|
77
78
|
)
|
|
78
79
|
return conversation
|
|
79
80
|
|
|
80
|
-
except
|
|
81
|
+
except (json.JSONDecodeError, ValueError) as e:
|
|
82
|
+
# Handle corrupted JSON or validation errors
|
|
83
|
+
logger.error(
|
|
84
|
+
"Corrupted conversation file at %s: %s. Creating backup and starting fresh.",
|
|
85
|
+
self.conversation_path,
|
|
86
|
+
e,
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
# Create a backup of the corrupted file for debugging
|
|
90
|
+
backup_path = self.conversation_path.with_suffix(".json.backup")
|
|
91
|
+
try:
|
|
92
|
+
shutil.copy2(self.conversation_path, backup_path)
|
|
93
|
+
logger.info("Backed up corrupted conversation to %s", backup_path)
|
|
94
|
+
except Exception as backup_error: # pragma: no cover
|
|
95
|
+
logger.warning("Failed to backup corrupted file: %s", backup_error)
|
|
96
|
+
|
|
97
|
+
return None
|
|
98
|
+
|
|
99
|
+
except Exception as e: # pragma: no cover
|
|
100
|
+
# Catch-all for unexpected errors
|
|
81
101
|
logger.error(
|
|
82
|
-
"
|
|
102
|
+
"Unexpected error loading conversation from %s: %s",
|
|
103
|
+
self.conversation_path,
|
|
104
|
+
e,
|
|
83
105
|
)
|
|
84
106
|
return None
|
|
85
107
|
|