fast-agent-mcp 0.2.33__py3-none-any.whl → 0.2.35__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,193 @@
+"""
+Utility module for displaying usage statistics in a consistent format.
+Consolidates the usage display logic that was duplicated between fastagent.py and interactive_prompt.py.
+"""
+
+from typing import Any, Dict, Optional
+
+from rich.console import Console
+
+
+def display_usage_report(
+    agents: Dict[str, Any], show_if_progress_disabled: bool = False, subdued_colors: bool = False
+) -> None:
+    """
+    Display a formatted table of token usage for all agents.
+
+    Args:
+        agents: Dictionary of agent name -> agent object
+        show_if_progress_disabled: If True, show even when progress display is disabled
+        subdued_colors: If True, use dim styling for a more subdued appearance
+    """
+    # Check if progress display is enabled (only relevant for fastagent context)
+    if not show_if_progress_disabled:
+        try:
+            from mcp_agent import config
+
+            settings = config.get_settings()
+            if not settings.logger.progress_display:
+                return
+        except (ImportError, AttributeError):
+            # If we can't check settings, assume we should display
+            pass
+
+    # Collect usage data from all agents
+    usage_data = []
+    total_input = 0
+    total_output = 0
+    total_tokens = 0
+
+    for agent_name, agent in agents.items():
+        if agent.usage_accumulator:
+            summary = agent.usage_accumulator.get_summary()
+            if summary["turn_count"] > 0:
+                input_tokens = summary["cumulative_input_tokens"]
+                output_tokens = summary["cumulative_output_tokens"]
+                billing_tokens = summary["cumulative_billing_tokens"]
+                turns = summary["turn_count"]
+
+                # Get context percentage for this agent
+                context_percentage = agent.usage_accumulator.context_usage_percentage
+
+                # Get model name from LLM's default_request_params
+                model = "unknown"
+                if hasattr(agent, "_llm") and agent._llm:
+                    llm = agent._llm
+                    if (
+                        hasattr(llm, "default_request_params")
+                        and llm.default_request_params
+                        and hasattr(llm.default_request_params, "model")
+                    ):
+                        model = llm.default_request_params.model or "unknown"
+
+                # Standardize model name truncation - use consistent 25 char width with 22+... truncation
+                if len(model) > 25:
+                    model = model[:22] + "..."
+
+                usage_data.append(
+                    {
+                        "name": agent_name,
+                        "model": model,
+                        "input": input_tokens,
+                        "output": output_tokens,
+                        "total": billing_tokens,
+                        "turns": turns,
+                        "context": context_percentage,
+                    }
+                )
+
+                total_input += input_tokens
+                total_output += output_tokens
+                total_tokens += billing_tokens
+
+    if not usage_data:
+        return
+
+    # Calculate dynamic agent column width (max 15)
+    max_agent_width = min(15, max(len(data["name"]) for data in usage_data) if usage_data else 8)
+    agent_width = max(max_agent_width, 5)  # Minimum of 5 for "Agent" header
+
+    # Display the table
+    console = Console()
+    console.print()
+    console.print("[dim]Usage Summary (Cumulative)[/dim]")
+
+    # Print header with proper spacing
+    console.print(
+        f"[dim]{'Agent':<{agent_width}} {'Input':>9} {'Output':>9} {'Total':>9} {'Turns':>6} {'Context%':>9} {'Model':<25}[/dim]"
+    )
+
+    # Print agent rows - use styling based on subdued_colors flag
+    for data in usage_data:
+        input_str = f"{data['input']:,}"
+        output_str = f"{data['output']:,}"
+        total_str = f"{data['total']:,}"
+        turns_str = str(data["turns"])
+        context_str = f"{data['context']:.1f}%" if data["context"] is not None else "-"
+
+        # Truncate agent name if needed
+        agent_name = data["name"]
+        if len(agent_name) > agent_width:
+            agent_name = agent_name[: agent_width - 3] + "..."
+
+        if subdued_colors:
+            # Original fastagent.py style with dim wrapper
+            console.print(
+                f"[dim]{agent_name:<{agent_width}} "
+                f"{input_str:>9} "
+                f"{output_str:>9} "
+                f"[bold]{total_str:>9}[/bold] "
+                f"{turns_str:>6} "
+                f"{context_str:>9} "
+                f"{data['model']:<25}[/dim]"
+            )
+        else:
+            # Original interactive_prompt.py style
+            console.print(
+                f"{agent_name:<{agent_width}} "
+                f"{input_str:>9} "
+                f"{output_str:>9} "
+                f"[bold]{total_str:>9}[/bold] "
+                f"{turns_str:>6} "
+                f"{context_str:>9} "
+                f"[dim]{data['model']:<25}[/dim]"
+            )
+
+    # Add total row if multiple agents
+    if len(usage_data) > 1:
+        console.print()
+        total_input_str = f"{total_input:,}"
+        total_output_str = f"{total_output:,}"
+        total_tokens_str = f"{total_tokens:,}"
+
+        if subdued_colors:
+            # Original fastagent.py style with dim wrapper on bold
+            console.print(
+                f"[bold dim]{'TOTAL':<{agent_width}} "
+                f"{total_input_str:>9} "
+                f"{total_output_str:>9} "
+                f"[bold]{total_tokens_str:>9}[/bold] "
+                f"{'':<6} "
+                f"{'':<9} "
+                f"{'':<25}[/bold dim]"
+            )
+        else:
+            # Original interactive_prompt.py style
+            console.print(
+                f"[bold]{'TOTAL':<{agent_width}}[/bold] "
+                f"[bold]{total_input_str:>9}[/bold] "
+                f"[bold]{total_output_str:>9}[/bold] "
+                f"[bold]{total_tokens_str:>9}[/bold] "
+                f"{'':<6} "
+                f"{'':<9} "
+                f"{'':<25}"
+            )
+
+    console.print()
+
+
+def collect_agents_from_provider(
+    prompt_provider: Any, agent_name: Optional[str] = None
+) -> Dict[str, Any]:
+    """
+    Collect agents from a prompt provider for usage display.
+
+    Args:
+        prompt_provider: Provider that has access to agents
+        agent_name: Name of the current agent (for context)
+
+    Returns:
+        Dictionary of agent name -> agent object
+    """
+    agents_to_show = {}

+    if hasattr(prompt_provider, "_agents"):
+        # Multi-agent app - show all agents
+        agents_to_show = prompt_provider._agents
+    elif hasattr(prompt_provider, "agent"):
+        # Single agent
+        agent = prompt_provider.agent
+        if hasattr(agent, "name"):
+            agents_to_show = {agent.name: agent}
+
+    return agents_to_show
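
Taken together, the two helpers above give callers a one-liner for an end-of-session summary: collect the agents, then render the table. A minimal sketch of that wiring (the module path and the `prompt_provider` object are assumptions for illustration; any object exposing `_agents` or `agent` works):

    # Sketch only: the module path is assumed, not confirmed by this diff.
    from mcp_agent.core.usage_display import (
        collect_agents_from_provider,
        display_usage_report,
    )

    agents = collect_agents_from_provider(prompt_provider, agent_name="default")
    if agents:
        # Force the table even when the progress display is switched off.
        display_usage_report(agents, show_if_progress_disabled=True)
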
@@ -15,6 +15,7 @@ class ProgressAction(str, Enum):
     LOADED = "Loaded"
     INITIALIZED = "Initialized"
     CHATTING = "Chatting"
+    STREAMING = "Streaming"  # Special action for real-time streaming updates
     ROUTING = "Routing"
     PLANNING = "Planning"
     READY = "Ready"
@@ -33,12 +34,22 @@ class ProgressEvent(BaseModel):
     target: str
     details: Optional[str] = None
     agent_name: Optional[str] = None
+    streaming_tokens: Optional[str] = None  # Special field for streaming token count
 
     def __str__(self) -> str:
         """Format the progress event for display."""
-        base = f"{self.action.ljust(11)}. {self.target}"
-        if self.details:
-            base += f" - {self.details}"
+        # Special handling for streaming - show token count in action position
+        if self.action == ProgressAction.STREAMING and self.streaming_tokens:
+            # For streaming, show just the token count instead of "Streaming"
+            action_display = self.streaming_tokens.ljust(11)
+            base = f"{action_display}. {self.target}"
+            if self.details:
+                base += f" - {self.details}"
+        else:
+            base = f"{self.action.ljust(11)}. {self.target}"
+            if self.details:
+                base += f" - {self.details}"
+
         if self.agent_name:
             base = f"[{self.agent_name}] {base}"
         return base
@@ -78,7 +89,8 @@ def convert_log_event(event: Event) -> Optional[ProgressEvent]:
 
     elif "augmented_llm" in namespace:
         model = event_data.get("model", "")
-
+
+        # For all augmented_llm events, put model info in details column
         details = f"{model}"
         chat_turn = event_data.get("chat_turn")
         if chat_turn is not None:
@@ -87,9 +99,15 @@ def convert_log_event(event: Event) -> Optional[ProgressEvent]:
         if not target:
             target = event_data.get("target", "unknown")
 
+        # Extract streaming token count for STREAMING actions
+        streaming_tokens = None
+        if progress_action == ProgressAction.STREAMING:
+            streaming_tokens = event_data.get("details", "")
+
         return ProgressEvent(
             action=ProgressAction(progress_action),
             target=target or "unknown",
             details=details,
             agent_name=event_data.get("agent_name"),
+            streaming_tokens=streaming_tokens,
         )
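
The new `__str__` branch is easiest to see side by side: a STREAMING event prints its token count in the slot where the action verb normally goes. A quick sketch (only fields visible in this hunk are passed; any other required fields on ProgressEvent are assumed to have defaults):

    # Sketch: compare regular vs. streaming formatting.
    from mcp_agent.event_progress import ProgressAction, ProgressEvent

    chatting = ProgressEvent(action=ProgressAction.CHATTING, target="gpt-4o", agent_name="helper")
    streaming = ProgressEvent(
        action=ProgressAction.STREAMING,
        target="gpt-4o",
        agent_name="helper",
        streaming_tokens="412",
    )
    print(chatting)   # [helper] Chatting   . gpt-4o
    print(streaming)  # [helper] 412        . gpt-4o
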
@@ -30,11 +30,13 @@ from mcp_agent.core.prompt import Prompt
 from mcp_agent.core.request_params import RequestParams
 from mcp_agent.event_progress import ProgressAction
 from mcp_agent.llm.memory import Memory, SimpleMemory
+from mcp_agent.llm.model_database import ModelDatabase
 from mcp_agent.llm.provider_types import Provider
 from mcp_agent.llm.sampling_format_converter import (
     BasicFormatConverter,
     ProviderFormatConverter,
 )
+from mcp_agent.llm.usage_tracking import UsageAccumulator
 from mcp_agent.logging.logger import get_logger
 from mcp_agent.mcp.helpers.content_helpers import get_text
 from mcp_agent.mcp.interfaces import (
@@ -95,6 +97,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
     PARAM_USE_HISTORY = "use_history"
     PARAM_MAX_ITERATIONS = "max_iterations"
     PARAM_TEMPLATE_VARS = "template_vars"
+
     # Base set of fields that should always be excluded
     BASE_EXCLUDE_FIELDS = {PARAM_METADATA}
 
@@ -155,12 +158,11 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
         # Initialize the display component
         self.display = ConsoleDisplay(config=self.context.config)
 
-        # Initialize default parameters
-        self.default_request_params = self._initialize_default_params(kwargs)
-
-        # Apply model override if provided
+        # Initialize default parameters, passing model info
+        model_kwargs = kwargs.copy()
         if model:
-            self.default_request_params.model = model
+            model_kwargs["model"] = model
+        self.default_request_params = self._initialize_default_params(model_kwargs)
 
         # Merge with provided params if any
         if self._init_request_params:
@@ -171,13 +173,22 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
         self.type_converter = type_converter
         self.verb = kwargs.get("verb")
 
+        # Initialize usage tracking
+        self.usage_accumulator = UsageAccumulator()
+
     def _initialize_default_params(self, kwargs: dict) -> RequestParams:
         """Initialize default parameters for the LLM.
         Should be overridden by provider implementations to set provider-specific defaults."""
+        # Get model-aware default max tokens
+        model = kwargs.get("model")
+        max_tokens = ModelDatabase.get_default_max_tokens(model)
+
         return RequestParams(
+            model=model,
+            maxTokens=max_tokens,
             systemPrompt=self.instruction,
             parallel_tool_calls=True,
-            max_iterations=10,
+            max_iterations=20,
             use_history=True,
         )
 
@@ -361,16 +372,28 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
         # Start with base arguments
         arguments = base_args.copy()
 
-        # Use provided exclude_fields or fall back to base exclusions
-        exclude_fields = exclude_fields or self.BASE_EXCLUDE_FIELDS.copy()
+        # Combine base exclusions with provider-specific exclusions
+        final_exclude_fields = self.BASE_EXCLUDE_FIELDS.copy()
+        if exclude_fields:
+            final_exclude_fields.update(exclude_fields)
 
         # Add all fields from params that aren't explicitly excluded
-        params_dict = request_params.model_dump(exclude=exclude_fields)
+        # Ensure model_dump only includes set fields if that's the desired behavior,
+        # or adjust exclude_unset=True/False as needed.
+        # Default Pydantic v2 model_dump is exclude_unset=False
+        params_dict = request_params.model_dump(exclude=final_exclude_fields)
+
         for key, value in params_dict.items():
+            # Only add if not None and not already in base_args (base_args take precedence)
+            # or if None is a valid value for the provider, this logic might need adjustment.
             if value is not None and key not in arguments:
                 arguments[key] = value
+            elif value is not None and key in arguments and arguments[key] is None:
+                # Allow overriding a None in base_args with a set value from params
+                arguments[key] = value
 
         # Finally, add any metadata fields as a last layer of overrides
+        # This ensures metadata can override anything previously set if keys conflict.
         if request_params.metadata:
             arguments.update(request_params.metadata)
 
@@ -642,3 +665,13 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
 
         assert self.provider
         return ProviderKeyManager.get_api_key(self.provider.value, self.context.config)
+
+    def get_usage_summary(self) -> dict:
+        """
+        Get a summary of usage statistics for this LLM instance.
+
+        Returns:
+            Dictionary containing usage statistics including tokens, cache metrics,
+            and context window utilization.
+        """
+        return self.usage_accumulator.get_summary()
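
The upshot for callers is that every AugmentedLLM now owns a UsageAccumulator, so cumulative numbers are one call away. A sketch (`llm` stands for any constructed subclass; the keys shown are the ones the usage display module above reads):

    # Sketch: inspect the running totals after some turns have completed.
    summary = llm.get_usage_summary()
    print(summary["turn_count"])
    print(summary["cumulative_input_tokens"])
    print(summary["cumulative_output_tokens"])
    print(summary["cumulative_billing_tokens"])
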
@@ -10,6 +10,7 @@ from mcp_agent.llm.augmented_llm import (
     RequestParams,
 )
 from mcp_agent.llm.provider_types import Provider
+from mcp_agent.llm.usage_tracking import create_turn_usage_from_messages
 from mcp_agent.logging.logger import get_logger
 from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
 
@@ -48,13 +49,34 @@ class PassthroughLLM(AugmentedLLM):
         await self.show_assistant_message(message, title="ASSISTANT/PASSTHROUGH")
 
         # Handle PromptMessage by concatenating all parts
+        result = ""
         if isinstance(message, PromptMessage):
             parts_text = []
             for part in message.content:
                 parts_text.append(str(part))
-            return "\n".join(parts_text)
+            result = "\n".join(parts_text)
+        else:
+            result = str(message)
 
-        return str(message)
+        # Track usage for this passthrough "turn"
+        try:
+            input_content = str(message)
+            output_content = result
+            tool_calls = 1 if input_content.startswith("***CALL_TOOL") else 0
+
+            turn_usage = create_turn_usage_from_messages(
+                input_content=input_content,
+                output_content=output_content,
+                model="passthrough",
+                model_type="passthrough",
+                tool_calls=tool_calls,
+                delay_seconds=0.0,
+            )
+            self.usage_accumulator.add_turn(turn_usage)
+        except Exception as e:
+            self.logger.warning(f"Failed to track usage: {e}")
+
+        return result
 
     async def initialize(self) -> None:
         pass
@@ -146,6 +168,25 @@ class PassthroughLLM(AugmentedLLM):
         if self.is_tool_call(last_message):
             result = Prompt.assistant(await self.generate_str(last_message.first_text()))
             await self.show_assistant_message(result.first_text())
+
+            # Track usage for this tool call "turn"
+            try:
+                input_content = "\n".join(message.all_text() for message in multipart_messages)
+                output_content = result.first_text()
+
+                turn_usage = create_turn_usage_from_messages(
+                    input_content=input_content,
+                    output_content=output_content,
+                    model="passthrough",
+                    model_type="passthrough",
+                    tool_calls=1,  # This is definitely a tool call
+                    delay_seconds=0.0,
+                )
+                self.usage_accumulator.add_turn(turn_usage)
+
+            except Exception as e:
+                self.logger.warning(f"Failed to track usage: {e}")
+
             return result
 
         if last_message.first_text().startswith(FIXED_RESPONSE_INDICATOR):
@@ -155,12 +196,33 @@
 
         if self._fixed_response:
             await self.show_assistant_message(self._fixed_response)
-            return Prompt.assistant(self._fixed_response)
+            result = Prompt.assistant(self._fixed_response)
         else:
             # TODO -- improve when we support Audio/Multimodal gen models e.g. gemini . This should really just return the input as "assistant"...
             concatenated: str = "\n".join(message.all_text() for message in multipart_messages)
             await self.show_assistant_message(concatenated)
-            return Prompt.assistant(concatenated)
+            result = Prompt.assistant(concatenated)
+
+        # Track usage for this passthrough "turn"
+        try:
+            input_content = "\n".join(message.all_text() for message in multipart_messages)
+            output_content = result.first_text()
+            tool_calls = 1 if self.is_tool_call(last_message) else 0
+
+            turn_usage = create_turn_usage_from_messages(
+                input_content=input_content,
+                output_content=output_content,
+                model="passthrough",
+                model_type="passthrough",
+                tool_calls=tool_calls,
+                delay_seconds=0.0,
+            )
+            self.usage_accumulator.add_turn(turn_usage)
+
+        except Exception as e:
+            self.logger.warning(f"Failed to track usage: {e}")
+
+        return result
 
     def is_tool_call(self, message: PromptMessageMultipart) -> bool:
         return message.first_text().startswith(CALL_TOOL_INDICATOR)
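
The same bookkeeping pattern repeats three times above: build a turn record from the input/output strings, then hand it to the accumulator. Condensed into isolation (the keyword arguments mirror the calls in this diff, and the import pairing is taken from the hunks above):

    # Sketch: record one synthetic passthrough turn and read it back.
    from mcp_agent.llm.usage_tracking import (
        UsageAccumulator,
        create_turn_usage_from_messages,
    )

    accumulator = UsageAccumulator()
    turn = create_turn_usage_from_messages(
        input_content="hello",
        output_content="hello",
        model="passthrough",
        model_type="passthrough",
        tool_calls=0,
        delay_seconds=0.0,
    )
    accumulator.add_turn(turn)
    print(accumulator.get_summary()["turn_count"])  # expected: 1
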
@@ -5,6 +5,7 @@ from mcp_agent.core.prompt import Prompt
 from mcp_agent.llm.augmented_llm import RequestParams
 from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM
 from mcp_agent.llm.provider_types import Provider
+from mcp_agent.llm.usage_tracking import create_turn_usage_from_messages
 from mcp_agent.mcp.interfaces import ModelT
 from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
 from mcp_agent.mcp.prompts.prompt_helpers import MessageContent
@@ -83,6 +84,24 @@ class PlaybackLLM(PassthroughLLM):
             message_text=MessageContent.get_first_text(response), title="ASSISTANT/PLAYBACK"
         )
 
+        # Track usage for this playback "turn"
+        try:
+            input_content = str(multipart_messages) if multipart_messages else ""
+            output_content = MessageContent.get_first_text(response)
+
+            turn_usage = create_turn_usage_from_messages(
+                input_content=input_content,
+                output_content=output_content,
+                model="playback",
+                model_type="playback",
+                tool_calls=0,
+                delay_seconds=0.0,
+            )
+            self.usage_accumulator.add_turn(turn_usage)
+
+        except Exception as e:
+            self.logger.warning(f"Failed to track usage: {e}")
+
         return response
 
     async def structured(
@@ -30,7 +30,18 @@ class SlowLLM(PassthroughLLM):
     ) -> str:
         """Sleep for 3 seconds then return the input message as a string."""
         await asyncio.sleep(3)
-        return await super().generate_str(message, request_params)
+        result = await super().generate_str(message, request_params)
+
+        # Override the last turn to include the 3-second delay
+        if self.usage_accumulator.turns:
+            last_turn = self.usage_accumulator.turns[-1]
+            # Update the raw usage to include delay
+            if hasattr(last_turn.raw_usage, 'delay_seconds'):
+                last_turn.raw_usage.delay_seconds = 3.0
+            # Print updated debug info
+            print("SlowLLM: Added 3.0s delay to turn usage")
+
+        return result
 
     async def _apply_prompt_provider_specific(
         self,
mcp_agent/llm/memory.py CHANGED
@@ -35,6 +35,9 @@ class SimpleMemory(Memory, Generic[MessageParamT]):
     def __init__(self) -> None:
         self.history: List[MessageParamT] = []
         self.prompt_messages: List[MessageParamT] = []  # Always included
+        self.conversation_cache_positions: List[int] = []  # Track active conversation cache positions
+        self.cache_walk_distance: int = 6  # Messages between cache blocks
+        self.max_conversation_cache_blocks: int = 2  # Maximum conversation cache blocks
 
     def extend(self, messages: List[MessageParamT], is_prompt: bool = False) -> None:
         """
@@ -99,5 +102,122 @@
             clear_prompts: If True, also clear prompt messages
         """
         self.history = []
+        self.conversation_cache_positions = []  # Reset cache positions
         if clear_prompts:
             self.prompt_messages = []
+
+    def should_apply_conversation_cache(self) -> bool:
+        """
+        Determine if conversation caching should be applied based on walking algorithm.
+
+        Returns:
+            True if we should add or update cache blocks
+        """
+        total_messages = len(self.history)
+
+        # Need at least cache_walk_distance messages to start caching
+        if total_messages < self.cache_walk_distance:
+            return False
+
+        # Check if we need to add a new cache block
+        return len(self._calculate_cache_positions(total_messages)) != len(self.conversation_cache_positions)
+
+    def _calculate_cache_positions(self, total_conversation_messages: int) -> List[int]:
+        """
+        Calculate where cache blocks should be placed using walking algorithm.
+
+        Args:
+            total_conversation_messages: Number of conversation messages (not including prompts)
+
+        Returns:
+            List of positions (relative to conversation start) where cache should be placed
+        """
+        positions = []
+
+        # Place cache blocks every cache_walk_distance messages
+        for i in range(self.cache_walk_distance - 1, total_conversation_messages, self.cache_walk_distance):
+            positions.append(i)
+            if len(positions) >= self.max_conversation_cache_blocks:
+                break
+
+        # Keep only the most recent cache blocks (walking behavior)
+        if len(positions) > self.max_conversation_cache_blocks:
+            positions = positions[-self.max_conversation_cache_blocks:]
+
+        return positions
+
+    def get_conversation_cache_updates(self) -> dict:
+        """
+        Get cache position updates needed for the walking algorithm.
+
+        Returns:
+            Dict with 'add', 'remove', and 'active' position lists (relative to full message array)
+        """
+        total_conversation_messages = len(self.history)
+        new_positions = self._calculate_cache_positions(total_conversation_messages)
+
+        # Convert to absolute positions (including prompt messages)
+        prompt_offset = len(self.prompt_messages)
+        new_absolute_positions = [pos + prompt_offset for pos in new_positions]
+
+        old_positions_set = set(self.conversation_cache_positions)
+        new_positions_set = set(new_absolute_positions)
+
+        return {
+            'add': sorted(new_positions_set - old_positions_set),
+            'remove': sorted(old_positions_set - new_positions_set),
+            'active': sorted(new_absolute_positions)
+        }
+
+    def apply_conversation_cache_updates(self, updates: dict) -> None:
+        """
+        Apply cache position updates.
+
+        Args:
+            updates: Dict from get_conversation_cache_updates()
+        """
+        self.conversation_cache_positions = updates['active'].copy()
+
+    def remove_cache_control_from_messages(self, messages: List[MessageParamT], positions: List[int]) -> None:
+        """
+        Remove cache control from specified message positions.
+
+        Args:
+            messages: The message array to modify
+            positions: List of positions to remove cache control from
+        """
+        for pos in positions:
+            if pos < len(messages):
+                message = messages[pos]
+                if isinstance(message, dict) and "content" in message:
+                    content_list = message["content"]
+                    if isinstance(content_list, list):
+                        for content_block in content_list:
+                            if isinstance(content_block, dict) and "cache_control" in content_block:
+                                del content_block["cache_control"]
+
+    def add_cache_control_to_messages(self, messages: List[MessageParamT], positions: List[int]) -> int:
+        """
+        Add cache control to specified message positions.
+
+        Args:
+            messages: The message array to modify
+            positions: List of positions to add cache control to
+
+        Returns:
+            Number of cache blocks successfully applied
+        """
+        applied_count = 0
+        for pos in positions:
+            if pos < len(messages):
+                message = messages[pos]
+                if isinstance(message, dict) and "content" in message:
+                    content_list = message["content"]
+                    if isinstance(content_list, list) and content_list:
+                        # Apply cache control to the last content block
+                        for content_block in reversed(content_list):
+                            if isinstance(content_block, dict):
+                                content_block["cache_control"] = {"type": "ephemeral"}
+                                applied_count += 1
+                                break
+        return applied_count
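
A worked example pins down the cache arithmetic: with the defaults (cache_walk_distance=6, max_conversation_cache_blocks=2), the first block lands on the sixth conversation message (index 5) and the next one six messages later (index 11). A sketch (the message dicts are placeholders, and extend() is assumed to append to history when is_prompt is False):

    # Sketch: watch cache positions appear as the conversation grows.
    from mcp_agent.llm.memory import SimpleMemory

    memory = SimpleMemory()
    memory.extend([{"role": "user", "content": []} for _ in range(12)])

    print(memory.should_apply_conversation_cache())  # True once history >= 6
    updates = memory.get_conversation_cache_updates()
    print(updates["add"])                             # [5, 11] (no prompt offset here)
    memory.apply_conversation_cache_updates(updates)
    print(memory.conversation_cache_positions)        # [5, 11]
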