shotgun-sh 0.2.6.dev5__py3-none-any.whl → 0.2.7.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of shotgun-sh might be problematic.

@@ -1,10 +1,21 @@
 """Agent manager for coordinating multiple AI agents with shared message history."""
 
+import json
 import logging
 from collections.abc import AsyncIterable, Sequence
 from dataclasses import dataclass, field, is_dataclass, replace
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, cast
 
+import logfire
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception,
+    stop_after_attempt,
+    wait_exponential,
+)
+
 if TYPE_CHECKING:
     from shotgun.agents.conversation_history import ConversationState
 
@@ -52,6 +63,35 @@ from .tasks import create_tasks_agent
 logger = logging.getLogger(__name__)
 
 
+def _is_retryable_error(exception: BaseException) -> bool:
+    """Check if exception should trigger a retry.
+
+    Args:
+        exception: The exception to check.
+
+    Returns:
+        True if the exception is a transient error that should be retried.
+    """
+    # ValueError for truncated/incomplete JSON
+    if isinstance(exception, ValueError):
+        error_str = str(exception)
+        return "EOF while parsing" in error_str or (
+            "JSON" in error_str and "parsing" in error_str
+        )
+
+    # API errors (overload, rate limits)
+    exception_name = type(exception).__name__
+    if "APIStatusError" in exception_name:
+        error_str = str(exception)
+        return "overload" in error_str.lower() or "rate" in error_str.lower()
+
+    # Network errors
+    if "ConnectionError" in exception_name or "TimeoutError" in exception_name:
+        return True
+
+    return False
+
+
 class MessageHistoryUpdated(Message):
     """Event posted when the message history is updated."""
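The `_is_retryable_error` predicate above is plugged into tenacity via `retry_if_exception` in a later hunk. A minimal, self-contained sketch of that pattern (the names `looks_transient` and `flaky_call` are illustrative, not from the package):

    import logging

    from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential

    logger = logging.getLogger(__name__)


    def looks_transient(exc: BaseException) -> bool:
        # Same classify-by-type-and-message idea as _is_retryable_error above.
        return isinstance(exc, TimeoutError) or "rate" in str(exc).lower()


    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=1, max=8),
        retry=retry_if_exception(looks_transient),
        reraise=True,
    )
    def flaky_call(attempts: list[int]) -> str:
        # Fails twice with a retryable error, then succeeds on the third attempt.
        attempts.append(1)
        if len(attempts) < 3:
            raise TimeoutError("simulated transient failure")
        return "ok"


    if __name__ == "__main__":
        tries: list[int] = []
        print(flaky_call(tries), "after", len(tries), "attempts")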
 
@@ -265,6 +305,49 @@ class AgentManager(Widget):
                 f"Invalid agent type: {agent_type}. Must be one of: {', '.join(e.value for e in AgentType)}"
             ) from None
 
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_exponential(multiplier=1, min=1, max=8),
+        retry=retry_if_exception(_is_retryable_error),
+        before_sleep=before_sleep_log(logger, logging.WARNING),
+        reraise=True,
+    )
+    async def _run_agent_with_retry(
+        self,
+        agent: Agent[AgentDeps, AgentResponse],
+        prompt: str | None,
+        deps: AgentDeps,
+        usage_limits: UsageLimits | None,
+        message_history: list[ModelMessage],
+        event_stream_handler: Any,
+        **kwargs: Any,
+    ) -> AgentRunResult[AgentResponse]:
+        """Run agent with automatic retry on transient errors.
+
+        Args:
+            agent: The agent to run.
+            prompt: Optional prompt to send to the agent.
+            deps: Agent dependencies.
+            usage_limits: Optional usage limits.
+            message_history: Message history to provide to agent.
+            event_stream_handler: Event handler for streaming.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            The agent run result.
+
+        Raises:
+            Various exceptions if all retries fail.
+        """
+        return await agent.run(
+            prompt,
+            deps=deps,
+            usage_limits=usage_limits,
+            message_history=message_history,
+            event_stream_handler=event_stream_handler,
+            **kwargs,
+        )
+
     async def run(
         self,
         prompt: str | None = None,
@@ -391,8 +474,9 @@ class AgentManager(Widget):
         )
 
         try:
-            result: AgentRunResult[AgentResponse] = await self.current_agent.run(
-                prompt,
+            result: AgentRunResult[AgentResponse] = await self._run_agent_with_retry(
+                agent=self.current_agent,
+                prompt=prompt,
                 deps=deps,
                 usage_limits=usage_limits,
                 message_history=message_history,
@@ -401,18 +485,93 @@ class AgentManager(Widget):
                 else None,
                 **kwargs,
             )
+        except ValueError as e:
+            # Handle truncated/incomplete JSON in tool calls specifically
+            error_str = str(e)
+            if "EOF while parsing" in error_str or (
+                "JSON" in error_str and "parsing" in error_str
+            ):
+                logger.error(
+                    "Tool call with truncated/incomplete JSON arguments detected",
+                    extra={
+                        "agent_mode": self._current_agent_type.value,
+                        "model_name": model_name,
+                        "error": error_str,
+                    },
+                )
+                logfire.error(
+                    "Tool call with truncated JSON arguments",
+                    agent_mode=self._current_agent_type.value,
+                    model_name=model_name,
+                    error=error_str,
+                )
+                # Add helpful hint message for the user
+                self.ui_message_history.append(
+                    HintMessage(
+                        message="⚠️ The agent attempted an operation with arguments that were too large (truncated JSON). "
+                        "Try breaking your request into smaller steps or more focused contracts."
+                    )
+                )
+                self._post_messages_updated()
+            # Re-raise to maintain error visibility
+            raise
+        except Exception as e:
+            # Log the error with full stack trace to shotgun.log and Logfire
+            logger.exception(
+                "Agent execution failed",
+                extra={
+                    "agent_mode": self._current_agent_type.value,
+                    "model_name": model_name,
+                    "error_type": type(e).__name__,
+                },
+            )
+            logfire.exception(
+                "Agent execution failed",
+                agent_mode=self._current_agent_type.value,
+                model_name=model_name,
+                error_type=type(e).__name__,
+            )
+            # Re-raise to let TUI handle user messaging
+            raise
         finally:
             self._stream_state = None
 
         # Agent ALWAYS returns AgentResponse with structured output
         agent_response = result.output
-        logger.debug("Agent returned structured AgentResponse")
+        logger.debug(
+            "Agent returned structured AgentResponse",
+            extra={
+                "has_response": agent_response.response is not None,
+                "response_length": len(agent_response.response)
+                if agent_response.response
+                else 0,
+                "response_preview": agent_response.response[:100] + "..."
+                if agent_response.response and len(agent_response.response) > 100
+                else agent_response.response or "(empty)",
+                "has_clarifying_questions": bool(agent_response.clarifying_questions),
+                "num_clarifying_questions": len(agent_response.clarifying_questions)
+                if agent_response.clarifying_questions
+                else 0,
+            },
+        )
 
         # Always add the agent's response messages to maintain conversation history
         self.ui_message_history = original_messages + cast(
             list[ModelRequest | ModelResponse | HintMessage], result.new_messages()
         )
 
+        # Get file operations early so we can use them for contextual messages
+        file_operations = deps.file_tracker.operations.copy()
+        self.recently_change_files = file_operations
+
+        logger.debug(
+            "File operations tracked",
+            extra={
+                "num_file_operations": len(file_operations),
+                "operation_files": [Path(op.file_path).name for op in file_operations],
+            },
+        )
+
         # Check if there are clarifying questions
         if agent_response.clarifying_questions:
             logger.info(
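The new `except ValueError` branch keys off the exception text. A quick stand-alone check of that matching logic, assuming made-up error messages:

    def is_truncated_json_error(exc: ValueError) -> bool:
        msg = str(exc)
        return "EOF while parsing" in msg or ("JSON" in msg and "parsing" in msg)

    print(is_truncated_json_error(ValueError("EOF while parsing a string at line 1 column 58")))  # True
    print(is_truncated_json_error(ValueError("invalid literal for int()")))  # False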
@@ -459,27 +618,93 @@ class AgentManager(Widget):
                     response_text=agent_response.response,
                 )
             )
+
+            # Post UI update with hint messages and file operations
+            logger.debug(
+                "Posting UI update for Q&A mode with hint messages and file operations"
+            )
+            self._post_messages_updated(file_operations)
         else:
-            # No clarifying questions - just show the response if present
+            # No clarifying questions - show the response or a default success message
             if agent_response.response and agent_response.response.strip():
+                logger.debug(
+                    "Adding agent response as hint",
+                    extra={
+                        "response_preview": agent_response.response[:100] + "..."
+                        if len(agent_response.response) > 100
+                        else agent_response.response,
+                        "has_file_operations": len(file_operations) > 0,
+                    },
+                )
                 self.ui_message_history.append(
                     HintMessage(message=agent_response.response)
                 )
+            else:
+                # Fallback: response is empty or whitespace
+                logger.debug(
+                    "Agent response was empty, using fallback completion message",
+                    extra={"has_file_operations": len(file_operations) > 0},
+                )
+                # Show contextual message based on whether files were modified
+                if file_operations:
+                    self.ui_message_history.append(
+                        HintMessage(
+                            message="✅ Task completed - files have been modified"
+                        )
+                    )
+                else:
+                    self.ui_message_history.append(
+                        HintMessage(message="✅ Task completed")
+                    )
+
+            # Post UI update immediately so user sees the response without delay
+            logger.debug(
+                "Posting immediate UI update with hint message and file operations"
+            )
+            self._post_messages_updated(file_operations)
 
         # Apply compaction to persistent message history to prevent cascading growth
         all_messages = result.all_messages()
-        self.message_history = await apply_persistent_compaction(all_messages, deps)
-        usage = result.usage()
-        deps.usage_manager.add_usage(
-            usage, model_name=deps.llm_model.name, provider=deps.llm_model.provider
-        )
+        try:
+            logger.debug(
+                "Starting message history compaction",
+                extra={"message_count": len(all_messages)},
+            )
+            self.message_history = await apply_persistent_compaction(all_messages, deps)
+            logger.debug(
+                "Completed message history compaction",
+                extra={
+                    "original_count": len(all_messages),
+                    "compacted_count": len(self.message_history),
+                },
+            )
+        except Exception as e:
+            # If compaction fails, log full error with stack trace and use uncompacted messages
+            logger.error(
+                "Failed to compact message history - using uncompacted messages",
+                exc_info=True,
+                extra={
+                    "error": str(e),
+                    "message_count": len(all_messages),
+                    "agent_mode": self._current_agent_type.value,
+                },
+            )
+            # Fallback: use uncompacted messages to prevent data loss
+            self.message_history = all_messages
 
-        # Log file operations summary if any files were modified
-        file_operations = deps.file_tracker.operations.copy()
-        self.recently_change_files = file_operations
+        usage = result.usage()
+        if hasattr(deps, "llm_model") and deps.llm_model is not None:
+            deps.usage_manager.add_usage(
+                usage, model_name=deps.llm_model.name, provider=deps.llm_model.provider
+            )
+        else:
+            logger.warning(
+                "llm_model is None, skipping usage tracking",
+                extra={"agent_mode": self._current_agent_type.value},
+            )
 
-        # Post message history update (hints are now added synchronously above)
-        self._post_messages_updated(file_operations)
+        # UI updates are now posted immediately in each branch (Q&A or non-Q&A)
+        # before compaction, so no duplicate posting needed here
 
         return result
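A rough sketch of the compaction fallback introduced above: if compaction raises, the uncompacted history is kept so no messages are lost. `summarize` is a hypothetical stand-in for `apply_persistent_compaction`:

    import asyncio
    import logging

    logging.basicConfig(level=logging.WARNING)
    logger = logging.getLogger(__name__)


    async def summarize(messages: list[str]) -> list[str]:
        raise RuntimeError("simulated compaction failure")


    async def compact_with_fallback(messages: list[str]) -> list[str]:
        try:
            return await summarize(messages)
        except Exception:
            # Keep the uncompacted messages rather than dropping history.
            logger.error("compaction failed, keeping %d messages", len(messages), exc_info=True)
            return messages


    print(asyncio.run(compact_with_fallback(["m1", "m2", "m3"])))  # ['m1', 'm2', 'm3']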
 
@@ -554,6 +779,39 @@ class AgentManager(Widget):
         # Detect source from call stack
         source = detect_source()
 
+        # Log if tool call has incomplete args (for debugging truncated JSON)
+        if isinstance(event.part.args, str):
+            try:
+                json.loads(event.part.args)
+            except (json.JSONDecodeError, ValueError):
+                args_preview = (
+                    event.part.args[:100] + "..."
+                    if len(event.part.args) > 100
+                    else event.part.args
+                )
+                logger.warning(
+                    "FunctionToolCallEvent received with incomplete JSON args",
+                    extra={
+                        "tool_name": event.part.tool_name,
+                        "tool_call_id": event.part.tool_call_id,
+                        "args_preview": args_preview,
+                        "args_length": len(event.part.args)
+                        if event.part.args
+                        else 0,
+                        "agent_mode": self._current_agent_type.value,
+                    },
+                )
+                logfire.warn(
+                    "FunctionToolCallEvent received with incomplete JSON args",
+                    tool_name=event.part.tool_name,
+                    tool_call_id=event.part.tool_call_id,
+                    args_preview=args_preview,
+                    args_length=len(event.part.args)
+                    if event.part.args
+                    else 0,
+                    agent_mode=self._current_agent_type.value,
+                )
+
         track_event(
             "tool_called",
             {
shotgun/agents/common.py CHANGED
@@ -384,23 +384,48 @@ def get_agent_existing_files(agent_mode: AgentType | None = None) -> list[str]:
                 relative_path = file_path.relative_to(base_path)
                 existing_files.append(str(relative_path))
     else:
-        # For other agents, check both .md file and directory with same name
-        allowed_file = AGENT_DIRECTORIES[agent_mode]
-
-        # Check for the .md file
-        md_file_path = base_path / allowed_file
-        if md_file_path.exists():
-            existing_files.append(allowed_file)
-
-        # Check for directory with same base name (e.g., research/ for research.md)
-        base_name = allowed_file.replace(".md", "")
-        dir_path = base_path / base_name
-        if dir_path.exists() and dir_path.is_dir():
-            # List all files in the directory
-            for file_path in dir_path.rglob("*"):
-                if file_path.is_file():
-                    relative_path = file_path.relative_to(base_path)
-                    existing_files.append(str(relative_path))
+        # For other agents, check files/directories they have access to
+        allowed_paths_raw = AGENT_DIRECTORIES[agent_mode]
+
+        # Convert single Path/string to list of Paths for uniform handling
+        if isinstance(allowed_paths_raw, str):
+            # Special case: "*" means export agent (shouldn't reach here but handle it)
+            allowed_paths = (
+                [Path(allowed_paths_raw)] if allowed_paths_raw != "*" else []
+            )
+        elif isinstance(allowed_paths_raw, Path):
+            allowed_paths = [allowed_paths_raw]
+        else:
+            # Already a list
+            allowed_paths = allowed_paths_raw
+
+        # Check each allowed path
+        for allowed_path in allowed_paths:
+            allowed_str = str(allowed_path)
+
+            # Check if it's a directory (no .md suffix)
+            if not allowed_path.suffix or not allowed_str.endswith(".md"):
+                # It's a directory - list all files within it
+                dir_path = base_path / allowed_str
+                if dir_path.exists() and dir_path.is_dir():
+                    for file_path in dir_path.rglob("*"):
+                        if file_path.is_file():
+                            relative_path = file_path.relative_to(base_path)
+                            existing_files.append(str(relative_path))
+            else:
+                # It's a file - check if it exists
+                file_path = base_path / allowed_str
+                if file_path.exists():
+                    existing_files.append(allowed_str)
+
+                # Also check for associated directory (e.g., research/ for research.md)
+                base_name = allowed_str.replace(".md", "")
+                dir_path = base_path / base_name
+                if dir_path.exists() and dir_path.is_dir():
+                    for file_path in dir_path.rglob("*"):
+                        if file_path.is_file():
+                            relative_path = file_path.relative_to(base_path)
+                            existing_files.append(str(relative_path))
 
     return existing_files
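The rewritten branch normalizes `AGENT_DIRECTORIES` values that may be a string, a single `Path`, or a list of `Path`s. A compact sketch of that normalization (the sample values are illustrative, not actual agent configuration):

    from pathlib import Path


    def normalize_allowed_paths(raw: str | Path | list[Path]) -> list[Path]:
        # Mirrors the branching above: "*" means unrestricted, single values become lists.
        if isinstance(raw, str):
            return [] if raw == "*" else [Path(raw)]
        if isinstance(raw, Path):
            return [raw]
        return list(raw)


    print(normalize_allowed_paths("research.md"))    # [Path('research.md')]
    print(normalize_allowed_paths("*"))              # []
    print(normalize_allowed_paths([Path("specs")]))  # [Path('specs')]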
 
@@ -1,5 +1,7 @@
 """Models and utilities for persisting TUI conversation history."""
 
+import json
+import logging
 from datetime import datetime
 from typing import Any, cast
 
@@ -7,14 +9,106 @@ from pydantic import BaseModel, ConfigDict, Field
 from pydantic_ai.messages import (
     ModelMessage,
     ModelMessagesTypeAdapter,
+    ModelResponse,
+    ToolCallPart,
 )
 from pydantic_core import to_jsonable_python
 
 from shotgun.tui.screens.chat_screen.hint_message import HintMessage
 
+logger = logging.getLogger(__name__)
+
 SerializedMessage = dict[str, Any]
 
 
+def is_tool_call_complete(tool_call: ToolCallPart) -> bool:
+    """Check if a tool call has valid, complete JSON arguments.
+
+    Args:
+        tool_call: The tool call part to validate
+
+    Returns:
+        True if the tool call args are valid JSON, False otherwise
+    """
+    if tool_call.args is None:
+        return True  # No args is valid
+
+    if isinstance(tool_call.args, dict):
+        return True  # Already parsed dict is valid
+
+    if not isinstance(tool_call.args, str):
+        return False
+
+    # Try to parse the JSON string
+    try:
+        json.loads(tool_call.args)
+        return True
+    except (json.JSONDecodeError, ValueError) as e:
+        # Log incomplete tool call detection
+        args_preview = (
+            tool_call.args[:100] + "..."
+            if len(tool_call.args) > 100
+            else tool_call.args
+        )
+        logger.info(
+            "Detected incomplete tool call in validation",
+            extra={
+                "tool_name": tool_call.tool_name,
+                "tool_call_id": tool_call.tool_call_id,
+                "args_preview": args_preview,
+                "error": str(e),
+            },
+        )
+        return False
+
+
+def filter_incomplete_messages(messages: list[ModelMessage]) -> list[ModelMessage]:
+    """Filter out messages with incomplete tool calls.
+
+    Args:
+        messages: List of messages to filter
+
+    Returns:
+        List of messages with only complete tool calls
+    """
+    filtered: list[ModelMessage] = []
+    filtered_count = 0
+    filtered_tool_names: list[str] = []
+
+    for message in messages:
+        # Only check ModelResponse messages for tool calls
+        if not isinstance(message, ModelResponse):
+            filtered.append(message)
+            continue
+
+        # Check if any tool calls are incomplete
+        has_incomplete_tool_call = False
+        for part in message.parts:
+            if isinstance(part, ToolCallPart) and not is_tool_call_complete(part):
+                has_incomplete_tool_call = True
+                filtered_tool_names.append(part.tool_name)
+                break
+
+        # Only include messages without incomplete tool calls
+        if not has_incomplete_tool_call:
+            filtered.append(message)
+        else:
+            filtered_count += 1
+
+    # Log if any messages were filtered
+    if filtered_count > 0:
+        logger.info(
+            "Filtered incomplete messages before saving",
+            extra={
+                "filtered_count": filtered_count,
+                "total_messages": len(messages),
+                "filtered_tool_names": filtered_tool_names,
+            },
+        )
+
+    return filtered
+
+
 class ConversationState(BaseModel):
     """Represents the complete state of a conversation in memory."""
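A self-contained illustration of the save-time filtering idea: entries whose tool-call args do not parse as JSON are dropped, everything else is kept in order. The dicts below are stand-ins for pydantic-ai message objects, not the real types:

    import json

    history = [
        {"kind": "request", "text": "write the notes file"},
        {"kind": "response", "tool_args": '{"path": "notes.md", "content": "done"}'},
        {"kind": "response", "tool_args": '{"path": "notes.md", "content": "trunca'},
    ]


    def args_ok(entry: dict) -> bool:
        args = entry.get("tool_args")
        if args is None:
            return True
        try:
            json.loads(args)
            return True
        except json.JSONDecodeError:
            return False


    kept = [entry for entry in history if args_ok(entry)]
    print(len(kept))  # 2 - the truncated tool call is dropped before persisting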
 
@@ -46,14 +140,41 @@ class ConversationHistory(BaseModel):
         Args:
             messages: List of ModelMessage objects to serialize and store
         """
+        # Filter out messages with incomplete tool calls to prevent corruption
+        filtered_messages = filter_incomplete_messages(messages)
+
         # Serialize ModelMessage list to JSON-serializable format
         self.agent_history = to_jsonable_python(
-            messages, fallback=lambda x: str(x), exclude_none=True
+            filtered_messages, fallback=lambda x: str(x), exclude_none=True
         )
 
     def set_ui_messages(self, messages: list[ModelMessage | HintMessage]) -> None:
         """Set ui_history from a list of UI messages."""
 
+        # Filter out ModelMessages with incomplete tool calls (keep all HintMessages)
+        # We need to maintain message order, so we'll check each message individually
+        filtered_messages: list[ModelMessage | HintMessage] = []
+
+        for msg in messages:
+            if isinstance(msg, HintMessage):
+                # Always keep hint messages
+                filtered_messages.append(msg)
+            elif isinstance(msg, ModelResponse):
+                # Check if this ModelResponse has incomplete tool calls
+                has_incomplete = False
+                for part in msg.parts:
+                    if isinstance(part, ToolCallPart) and not is_tool_call_complete(
+                        part
+                    ):
+                        has_incomplete = True
+                        break
+
+                if not has_incomplete:
+                    filtered_messages.append(msg)
+            else:
+                # Keep all other ModelMessage types (ModelRequest, etc.)
+                filtered_messages.append(msg)
+
         def _serialize_message(
             message: ModelMessage | HintMessage,
         ) -> Any:
@@ -68,7 +189,7 @@ class ConversationHistory(BaseModel):
             payload.setdefault("message_type", "model")
             return payload
 
-        self.ui_history = [_serialize_message(msg) for msg in messages]
+        self.ui_history = [_serialize_message(msg) for msg in filtered_messages]
 
     def get_agent_messages(self) -> list[ModelMessage]:
         """Get agent_history as a list of ModelMessage objects.
@@ -1,6 +1,7 @@
 """Manager for handling conversation persistence operations."""
 
 import json
+import shutil
 from pathlib import Path
 
 from shotgun.logging_config import get_logger
@@ -77,9 +78,30 @@ class ConversationManager:
             )
             return conversation
 
-        except Exception as e:
+        except (json.JSONDecodeError, ValueError) as e:
+            # Handle corrupted JSON or validation errors
+            logger.error(
+                "Corrupted conversation file at %s: %s. Creating backup and starting fresh.",
+                self.conversation_path,
+                e,
+            )
+
+            # Create a backup of the corrupted file for debugging
+            backup_path = self.conversation_path.with_suffix(".json.backup")
+            try:
+                shutil.copy2(self.conversation_path, backup_path)
+                logger.info("Backed up corrupted conversation to %s", backup_path)
+            except Exception as backup_error:  # pragma: no cover
+                logger.warning("Failed to backup corrupted file: %s", backup_error)
+
+            return None
+
+        except Exception as e:  # pragma: no cover
+            # Catch-all for unexpected errors
             logger.error(
-                "Failed to load conversation from %s: %s", self.conversation_path, e
+                "Unexpected error loading conversation from %s: %s",
+                self.conversation_path,
+                e,
             )
             return None
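Sketch of the new load-time recovery path: unreadable JSON gets backed up next to the original and the caller starts fresh. `load_or_reset` is a hypothetical stand-in for the manager's load method:

    import json
    import shutil
    from pathlib import Path


    def load_or_reset(path: Path) -> dict | None:
        try:
            return json.loads(path.read_text())
        except (json.JSONDecodeError, ValueError):
            # Preserve the corrupted file for debugging, then signal "start fresh".
            shutil.copy2(path, path.with_suffix(".json.backup"))
            return None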