fast-agent-mcp 0.2.33__py3-none-any.whl → 0.2.35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {fast_agent_mcp-0.2.33.dist-info → fast_agent_mcp-0.2.35.dist-info}/METADATA +1 -1
- {fast_agent_mcp-0.2.33.dist-info → fast_agent_mcp-0.2.35.dist-info}/RECORD +28 -25
- mcp_agent/agents/base_agent.py +13 -0
- mcp_agent/config.py +8 -0
- mcp_agent/context.py +3 -2
- mcp_agent/core/agent_app.py +41 -1
- mcp_agent/core/enhanced_prompt.py +9 -0
- mcp_agent/core/fastagent.py +14 -2
- mcp_agent/core/interactive_prompt.py +59 -13
- mcp_agent/core/usage_display.py +193 -0
- mcp_agent/event_progress.py +22 -4
- mcp_agent/llm/augmented_llm.py +42 -9
- mcp_agent/llm/augmented_llm_passthrough.py +66 -4
- mcp_agent/llm/augmented_llm_playback.py +19 -0
- mcp_agent/llm/augmented_llm_slow.py +12 -1
- mcp_agent/llm/memory.py +120 -0
- mcp_agent/llm/model_database.py +236 -0
- mcp_agent/llm/model_factory.py +1 -0
- mcp_agent/llm/providers/augmented_llm_anthropic.py +211 -30
- mcp_agent/llm/providers/augmented_llm_google_native.py +18 -1
- mcp_agent/llm/providers/augmented_llm_openai.py +20 -7
- mcp_agent/llm/usage_tracking.py +402 -0
- mcp_agent/logging/events.py +24 -0
- mcp_agent/logging/rich_progress.py +9 -1
- mcp_agent/mcp/interfaces.py +6 -0
- {fast_agent_mcp-0.2.33.dist-info → fast_agent_mcp-0.2.35.dist-info}/WHEEL +0 -0
- {fast_agent_mcp-0.2.33.dist-info → fast_agent_mcp-0.2.35.dist-info}/entry_points.txt +0 -0
- {fast_agent_mcp-0.2.33.dist-info → fast_agent_mcp-0.2.35.dist-info}/licenses/LICENSE +0 -0
mcp_agent/core/usage_display.py
@@ -0,0 +1,193 @@
+"""
+Utility module for displaying usage statistics in a consistent format.
+Consolidates the usage display logic that was duplicated between fastagent.py and interactive_prompt.py.
+"""
+
+from typing import Any, Dict, Optional
+
+from rich.console import Console
+
+
+def display_usage_report(
+    agents: Dict[str, Any], show_if_progress_disabled: bool = False, subdued_colors: bool = False
+) -> None:
+    """
+    Display a formatted table of token usage for all agents.
+
+    Args:
+        agents: Dictionary of agent name -> agent object
+        show_if_progress_disabled: If True, show even when progress display is disabled
+        subdued_colors: If True, use dim styling for a more subdued appearance
+    """
+    # Check if progress display is enabled (only relevant for fastagent context)
+    if not show_if_progress_disabled:
+        try:
+            from mcp_agent import config
+
+            settings = config.get_settings()
+            if not settings.logger.progress_display:
+                return
+        except (ImportError, AttributeError):
+            # If we can't check settings, assume we should display
+            pass
+
+    # Collect usage data from all agents
+    usage_data = []
+    total_input = 0
+    total_output = 0
+    total_tokens = 0
+
+    for agent_name, agent in agents.items():
+        if agent.usage_accumulator:
+            summary = agent.usage_accumulator.get_summary()
+            if summary["turn_count"] > 0:
+                input_tokens = summary["cumulative_input_tokens"]
+                output_tokens = summary["cumulative_output_tokens"]
+                billing_tokens = summary["cumulative_billing_tokens"]
+                turns = summary["turn_count"]
+
+                # Get context percentage for this agent
+                context_percentage = agent.usage_accumulator.context_usage_percentage
+
+                # Get model name from LLM's default_request_params
+                model = "unknown"
+                if hasattr(agent, "_llm") and agent._llm:
+                    llm = agent._llm
+                    if (
+                        hasattr(llm, "default_request_params")
+                        and llm.default_request_params
+                        and hasattr(llm.default_request_params, "model")
+                    ):
+                        model = llm.default_request_params.model or "unknown"
+
+                # Standardize model name truncation - use consistent 25 char width with 22+... truncation
+                if len(model) > 25:
+                    model = model[:22] + "..."
+
+                usage_data.append(
+                    {
+                        "name": agent_name,
+                        "model": model,
+                        "input": input_tokens,
+                        "output": output_tokens,
+                        "total": billing_tokens,
+                        "turns": turns,
+                        "context": context_percentage,
+                    }
+                )
+
+                total_input += input_tokens
+                total_output += output_tokens
+                total_tokens += billing_tokens
+
+    if not usage_data:
+        return
+
+    # Calculate dynamic agent column width (max 15)
+    max_agent_width = min(15, max(len(data["name"]) for data in usage_data) if usage_data else 8)
+    agent_width = max(max_agent_width, 5)  # Minimum of 5 for "Agent" header
+
+    # Display the table
+    console = Console()
+    console.print()
+    console.print("[dim]Usage Summary (Cumulative)[/dim]")
+
+    # Print header with proper spacing
+    console.print(
+        f"[dim]{'Agent':<{agent_width}} {'Input':>9} {'Output':>9} {'Total':>9} {'Turns':>6} {'Context%':>9} {'Model':<25}[/dim]"
+    )
+
+    # Print agent rows - use styling based on subdued_colors flag
+    for data in usage_data:
+        input_str = f"{data['input']:,}"
+        output_str = f"{data['output']:,}"
+        total_str = f"{data['total']:,}"
+        turns_str = str(data["turns"])
+        context_str = f"{data['context']:.1f}%" if data["context"] is not None else "-"
+
+        # Truncate agent name if needed
+        agent_name = data["name"]
+        if len(agent_name) > agent_width:
+            agent_name = agent_name[: agent_width - 3] + "..."
+
+        if subdued_colors:
+            # Original fastagent.py style with dim wrapper
+            console.print(
+                f"[dim]{agent_name:<{agent_width}} "
+                f"{input_str:>9} "
+                f"{output_str:>9} "
+                f"[bold]{total_str:>9}[/bold] "
+                f"{turns_str:>6} "
+                f"{context_str:>9} "
+                f"{data['model']:<25}[/dim]"
+            )
+        else:
+            # Original interactive_prompt.py style
+            console.print(
+                f"{agent_name:<{agent_width}} "
+                f"{input_str:>9} "
+                f"{output_str:>9} "
+                f"[bold]{total_str:>9}[/bold] "
+                f"{turns_str:>6} "
+                f"{context_str:>9} "
+                f"[dim]{data['model']:<25}[/dim]"
+            )
+
+    # Add total row if multiple agents
+    if len(usage_data) > 1:
+        console.print()
+        total_input_str = f"{total_input:,}"
+        total_output_str = f"{total_output:,}"
+        total_tokens_str = f"{total_tokens:,}"
+
+        if subdued_colors:
+            # Original fastagent.py style with dim wrapper on bold
+            console.print(
+                f"[bold dim]{'TOTAL':<{agent_width}} "
+                f"{total_input_str:>9} "
+                f"{total_output_str:>9} "
+                f"[bold]{total_tokens_str:>9}[/bold] "
+                f"{'':<6} "
+                f"{'':<9} "
+                f"{'':<25}[/bold dim]"
+            )
+        else:
+            # Original interactive_prompt.py style
+            console.print(
+                f"[bold]{'TOTAL':<{agent_width}}[/bold] "
+                f"[bold]{total_input_str:>9}[/bold] "
+                f"[bold]{total_output_str:>9}[/bold] "
+                f"[bold]{total_tokens_str:>9}[/bold] "
+                f"{'':<6} "
+                f"{'':<9} "
+                f"{'':<25}"
+            )
+
+    console.print()
+
+
+def collect_agents_from_provider(
+    prompt_provider: Any, agent_name: Optional[str] = None
+) -> Dict[str, Any]:
+    """
+    Collect agents from a prompt provider for usage display.
+
+    Args:
+        prompt_provider: Provider that has access to agents
+        agent_name: Name of the current agent (for context)
+
+    Returns:
+        Dictionary of agent name -> agent object
+    """
+    agents_to_show = {}
+
+    if hasattr(prompt_provider, "_agents"):
+        # Multi-agent app - show all agents
+        agents_to_show = prompt_provider._agents
+    elif hasattr(prompt_provider, "agent"):
+        # Single agent
+        agent = prompt_provider.agent
+        if hasattr(agent, "name"):
+            agents_to_show = {agent.name: agent}
+
+    return agents_to_show
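As a quick orientation for the new module, the sketch below drives `display_usage_report` with hand-rolled stand-ins for the objects it inspects. `FakeAccumulator` and `FakeAgent` are hypothetical stubs (not part of the package) that expose only the attributes the function reads above; with fast-agent-mcp 0.2.35 and rich installed, the snippet should print the usage table directly.

```python
from dataclasses import dataclass, field

from mcp_agent.core.usage_display import display_usage_report


@dataclass
class FakeAccumulator:
    # Mirrors the attributes display_usage_report() reads; values are illustrative.
    context_usage_percentage: float = 12.5

    def get_summary(self) -> dict:
        return {
            "turn_count": 3,
            "cumulative_input_tokens": 1200,
            "cumulative_output_tokens": 450,
            "cumulative_billing_tokens": 1650,
        }


@dataclass
class FakeAgent:
    usage_accumulator: FakeAccumulator = field(default_factory=FakeAccumulator)


# show_if_progress_disabled=True skips the settings lookup shown in the module above.
display_usage_report({"default": FakeAgent()}, show_if_progress_disabled=True)
```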
mcp_agent/event_progress.py
CHANGED
@@ -15,6 +15,7 @@ class ProgressAction(str, Enum):
     LOADED = "Loaded"
     INITIALIZED = "Initialized"
     CHATTING = "Chatting"
+    STREAMING = "Streaming"  # Special action for real-time streaming updates
     ROUTING = "Routing"
     PLANNING = "Planning"
     READY = "Ready"
@@ -33,12 +34,22 @@ class ProgressEvent(BaseModel):
     target: str
     details: Optional[str] = None
     agent_name: Optional[str] = None
+    streaming_tokens: Optional[str] = None  # Special field for streaming token count

     def __str__(self) -> str:
         """Format the progress event for display."""
-
-        if self.
-
+        # Special handling for streaming - show token count in action position
+        if self.action == ProgressAction.STREAMING and self.streaming_tokens:
+            # For streaming, show just the token count instead of "Streaming"
+            action_display = self.streaming_tokens.ljust(11)
+            base = f"{action_display}. {self.target}"
+            if self.details:
+                base += f" - {self.details}"
+        else:
+            base = f"{self.action.ljust(11)}. {self.target}"
+            if self.details:
+                base += f" - {self.details}"
+
         if self.agent_name:
             base = f"[{self.agent_name}] {base}"
         return base
@@ -78,7 +89,8 @@ def convert_log_event(event: Event) -> Optional[ProgressEvent]:

     elif "augmented_llm" in namespace:
         model = event_data.get("model", "")
-
+
+        # For all augmented_llm events, put model info in details column
         details = f"{model}"
         chat_turn = event_data.get("chat_turn")
         if chat_turn is not None:
@@ -87,9 +99,15 @@ def convert_log_event(event: Event) -> Optional[ProgressEvent]:
         if not target:
             target = event_data.get("target", "unknown")

+        # Extract streaming token count for STREAMING actions
+        streaming_tokens = None
+        if progress_action == ProgressAction.STREAMING:
+            streaming_tokens = event_data.get("details", "")
+
         return ProgressEvent(
             action=ProgressAction(progress_action),
             target=target or "unknown",
             details=details,
             agent_name=event_data.get("agent_name"),
+            streaming_tokens=streaming_tokens,
         )
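A small illustration of how the new `STREAMING` action changes `__str__` formatting, using only the fields shown in the hunks above; the token-count string and the exact spacing of the rendered output are illustrative, not prescribed by the package.

```python
from mcp_agent.event_progress import ProgressAction, ProgressEvent

# Ordinary event: the action name fills the left-hand column.
chatting = ProgressEvent(
    action=ProgressAction.CHATTING,
    target="gpt-4o",
    details="turn 2",
    agent_name="researcher",
)
print(chatting)  # e.g. "[researcher] Chatting   . gpt-4o - turn 2"

# Streaming event: streaming_tokens replaces the action name in that column.
streaming = ProgressEvent(
    action=ProgressAction.STREAMING,
    target="gpt-4o",
    streaming_tokens="1234",  # illustrative token-count string
    agent_name="researcher",
)
print(streaming)  # e.g. "[researcher] 1234       . gpt-4o"
```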
mcp_agent/llm/augmented_llm.py
CHANGED
@@ -30,11 +30,13 @@ from mcp_agent.core.prompt import Prompt
 from mcp_agent.core.request_params import RequestParams
 from mcp_agent.event_progress import ProgressAction
 from mcp_agent.llm.memory import Memory, SimpleMemory
+from mcp_agent.llm.model_database import ModelDatabase
 from mcp_agent.llm.provider_types import Provider
 from mcp_agent.llm.sampling_format_converter import (
     BasicFormatConverter,
     ProviderFormatConverter,
 )
+from mcp_agent.llm.usage_tracking import UsageAccumulator
 from mcp_agent.logging.logger import get_logger
 from mcp_agent.mcp.helpers.content_helpers import get_text
 from mcp_agent.mcp.interfaces import (
@@ -95,6 +97,7 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
     PARAM_USE_HISTORY = "use_history"
     PARAM_MAX_ITERATIONS = "max_iterations"
     PARAM_TEMPLATE_VARS = "template_vars"
+
     # Base set of fields that should always be excluded
     BASE_EXCLUDE_FIELDS = {PARAM_METADATA}

@@ -155,12 +158,11 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
         # Initialize the display component
         self.display = ConsoleDisplay(config=self.context.config)

-        # Initialize default parameters
-
-
-        # Apply model override if provided
+        # Initialize default parameters, passing model info
+        model_kwargs = kwargs.copy()
         if model:
-
+            model_kwargs["model"] = model
+        self.default_request_params = self._initialize_default_params(model_kwargs)

         # Merge with provided params if any
         if self._init_request_params:
@@ -171,13 +173,22 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
         self.type_converter = type_converter
         self.verb = kwargs.get("verb")

+        # Initialize usage tracking
+        self.usage_accumulator = UsageAccumulator()
+
     def _initialize_default_params(self, kwargs: dict) -> RequestParams:
         """Initialize default parameters for the LLM.
         Should be overridden by provider implementations to set provider-specific defaults."""
+        # Get model-aware default max tokens
+        model = kwargs.get("model")
+        max_tokens = ModelDatabase.get_default_max_tokens(model)
+
         return RequestParams(
+            model=model,
+            maxTokens=max_tokens,
             systemPrompt=self.instruction,
             parallel_tool_calls=True,
-            max_iterations=
+            max_iterations=20,
             use_history=True,
         )

@@ -361,16 +372,28 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT
         # Start with base arguments
         arguments = base_args.copy()

-        #
-
+        # Combine base exclusions with provider-specific exclusions
+        final_exclude_fields = self.BASE_EXCLUDE_FIELDS.copy()
+        if exclude_fields:
+            final_exclude_fields.update(exclude_fields)

         # Add all fields from params that aren't explicitly excluded
-
+        # Ensure model_dump only includes set fields if that's the desired behavior,
+        # or adjust exclude_unset=True/False as needed.
+        # Default Pydantic v2 model_dump is exclude_unset=False
+        params_dict = request_params.model_dump(exclude=final_exclude_fields)
+
         for key, value in params_dict.items():
+            # Only add if not None and not already in base_args (base_args take precedence)
+            # or if None is a valid value for the provider, this logic might need adjustment.
             if value is not None and key not in arguments:
                 arguments[key] = value
+            elif value is not None and key in arguments and arguments[key] is None:
+                # Allow overriding a None in base_args with a set value from params
+                arguments[key] = value

         # Finally, add any metadata fields as a last layer of overrides
+        # This ensures metadata can override anything previously set if keys conflict.
         if request_params.metadata:
             arguments.update(request_params.metadata)

@@ -642,3 +665,13 @@ class AugmentedLLM(ContextDependent, AugmentedLLMProtocol, Generic[MessageParamT

         assert self.provider
         return ProviderKeyManager.get_api_key(self.provider.value, self.context.config)
+
+    def get_usage_summary(self) -> dict:
+        """
+        Get a summary of usage statistics for this LLM instance.
+
+        Returns:
+            Dictionary containing usage statistics including tokens, cache metrics,
+            and context window utilization.
+        """
+        return self.usage_accumulator.get_summary()
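Since `get_usage_summary` is new in this release, here is a minimal sketch of reading it from an already-constructed LLM. The helper below relies only on the summary keys that `usage_display.py` (above) also reads; how the `llm` instance is obtained is outside the scope of this diff.

```python
from mcp_agent.llm.augmented_llm import AugmentedLLM


def print_usage(llm: AugmentedLLM) -> None:
    """Print cumulative usage for any constructed AugmentedLLM subclass."""
    summary = llm.get_usage_summary()  # added in 0.2.35
    print(f"turns:          {summary['turn_count']}")
    print(f"billing tokens: {summary['cumulative_billing_tokens']}")
    print(f"context used:   {llm.usage_accumulator.context_usage_percentage}")
```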
mcp_agent/llm/augmented_llm_passthrough.py
CHANGED
@@ -10,6 +10,7 @@ from mcp_agent.llm.augmented_llm import (
     RequestParams,
 )
 from mcp_agent.llm.provider_types import Provider
+from mcp_agent.llm.usage_tracking import create_turn_usage_from_messages
 from mcp_agent.logging.logger import get_logger
 from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart

@@ -48,13 +49,34 @@ class PassthroughLLM(AugmentedLLM):
         await self.show_assistant_message(message, title="ASSISTANT/PASSTHROUGH")

         # Handle PromptMessage by concatenating all parts
+        result = ""
         if isinstance(message, PromptMessage):
             parts_text = []
             for part in message.content:
                 parts_text.append(str(part))
-
+            result = "\n".join(parts_text)
+        else:
+            result = str(message)

-
+        # Track usage for this passthrough "turn"
+        try:
+            input_content = str(message)
+            output_content = result
+            tool_calls = 1 if input_content.startswith("***CALL_TOOL") else 0
+
+            turn_usage = create_turn_usage_from_messages(
+                input_content=input_content,
+                output_content=output_content,
+                model="passthrough",
+                model_type="passthrough",
+                tool_calls=tool_calls,
+                delay_seconds=0.0,
+            )
+            self.usage_accumulator.add_turn(turn_usage)
+        except Exception as e:
+            self.logger.warning(f"Failed to track usage: {e}")
+
+        return result

     async def initialize(self) -> None:
         pass
@@ -146,6 +168,25 @@ class PassthroughLLM(AugmentedLLM):
         if self.is_tool_call(last_message):
             result = Prompt.assistant(await self.generate_str(last_message.first_text()))
             await self.show_assistant_message(result.first_text())
+
+            # Track usage for this tool call "turn"
+            try:
+                input_content = "\n".join(message.all_text() for message in multipart_messages)
+                output_content = result.first_text()
+
+                turn_usage = create_turn_usage_from_messages(
+                    input_content=input_content,
+                    output_content=output_content,
+                    model="passthrough",
+                    model_type="passthrough",
+                    tool_calls=1,  # This is definitely a tool call
+                    delay_seconds=0.0,
+                )
+                self.usage_accumulator.add_turn(turn_usage)
+
+            except Exception as e:
+                self.logger.warning(f"Failed to track usage: {e}")
+
             return result

         if last_message.first_text().startswith(FIXED_RESPONSE_INDICATOR):
@@ -155,12 +196,33 @@ class PassthroughLLM(AugmentedLLM):

         if self._fixed_response:
             await self.show_assistant_message(self._fixed_response)
-
+            result = Prompt.assistant(self._fixed_response)
         else:
             # TODO -- improve when we support Audio/Multimodal gen models e.g. gemini . This should really just return the input as "assistant"...
             concatenated: str = "\n".join(message.all_text() for message in multipart_messages)
             await self.show_assistant_message(concatenated)
-
+            result = Prompt.assistant(concatenated)
+
+        # Track usage for this passthrough "turn"
+        try:
+            input_content = "\n".join(message.all_text() for message in multipart_messages)
+            output_content = result.first_text()
+            tool_calls = 1 if self.is_tool_call(last_message) else 0
+
+            turn_usage = create_turn_usage_from_messages(
+                input_content=input_content,
+                output_content=output_content,
+                model="passthrough",
+                model_type="passthrough",
+                tool_calls=tool_calls,
+                delay_seconds=0.0,
+            )
+            self.usage_accumulator.add_turn(turn_usage)
+
+        except Exception as e:
+            self.logger.warning(f"Failed to track usage: {e}")
+
+        return result

     def is_tool_call(self, message: PromptMessageMultipart) -> bool:
         return message.first_text().startswith(CALL_TOOL_INDICATOR)
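The passthrough changes above pair `create_turn_usage_from_messages` with the `UsageAccumulator` initialised in `augmented_llm.py`. Below is a minimal sketch of that pairing in isolation; the function names and keyword arguments are taken verbatim from the hunks in this diff, while the message strings are made up.

```python
from mcp_agent.llm.usage_tracking import UsageAccumulator, create_turn_usage_from_messages

accumulator = UsageAccumulator()

turn = create_turn_usage_from_messages(
    input_content="***CALL_TOOL fetch {}",  # illustrative passthrough input
    output_content="tool call result",
    model="passthrough",
    model_type="passthrough",
    tool_calls=1,
    delay_seconds=0.0,
)
accumulator.add_turn(turn)

summary = accumulator.get_summary()
print(summary["turn_count"])                 # -> 1
print(summary["cumulative_billing_tokens"])  # totals as computed by usage_tracking
```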
mcp_agent/llm/augmented_llm_playback.py
CHANGED
@@ -5,6 +5,7 @@ from mcp_agent.core.prompt import Prompt
 from mcp_agent.llm.augmented_llm import RequestParams
 from mcp_agent.llm.augmented_llm_passthrough import PassthroughLLM
 from mcp_agent.llm.provider_types import Provider
+from mcp_agent.llm.usage_tracking import create_turn_usage_from_messages
 from mcp_agent.mcp.interfaces import ModelT
 from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
 from mcp_agent.mcp.prompts.prompt_helpers import MessageContent
@@ -83,6 +84,24 @@ class PlaybackLLM(PassthroughLLM):
             message_text=MessageContent.get_first_text(response), title="ASSISTANT/PLAYBACK"
         )

+        # Track usage for this playback "turn"
+        try:
+            input_content = str(multipart_messages) if multipart_messages else ""
+            output_content = MessageContent.get_first_text(response)
+
+            turn_usage = create_turn_usage_from_messages(
+                input_content=input_content,
+                output_content=output_content,
+                model="playback",
+                model_type="playback",
+                tool_calls=0,
+                delay_seconds=0.0,
+            )
+            self.usage_accumulator.add_turn(turn_usage)
+
+        except Exception as e:
+            self.logger.warning(f"Failed to track usage: {e}")
+
         return response

     async def structured(
mcp_agent/llm/augmented_llm_slow.py
CHANGED
@@ -30,7 +30,18 @@ class SlowLLM(PassthroughLLM):
     ) -> str:
         """Sleep for 3 seconds then return the input message as a string."""
         await asyncio.sleep(3)
-
+        result = await super().generate_str(message, request_params)
+
+        # Override the last turn to include the 3-second delay
+        if self.usage_accumulator.turns:
+            last_turn = self.usage_accumulator.turns[-1]
+            # Update the raw usage to include delay
+            if hasattr(last_turn.raw_usage, 'delay_seconds'):
+                last_turn.raw_usage.delay_seconds = 3.0
+                # Print updated debug info
+                print("SlowLLM: Added 3.0s delay to turn usage")
+
+        return result

     async def _apply_prompt_provider_specific(
         self,
mcp_agent/llm/memory.py
CHANGED
@@ -35,6 +35,9 @@ class SimpleMemory(Memory, Generic[MessageParamT]):
     def __init__(self) -> None:
         self.history: List[MessageParamT] = []
         self.prompt_messages: List[MessageParamT] = []  # Always included
+        self.conversation_cache_positions: List[int] = []  # Track active conversation cache positions
+        self.cache_walk_distance: int = 6  # Messages between cache blocks
+        self.max_conversation_cache_blocks: int = 2  # Maximum conversation cache blocks

     def extend(self, messages: List[MessageParamT], is_prompt: bool = False) -> None:
         """
@@ -99,5 +102,122 @@ class SimpleMemory(Memory, Generic[MessageParamT]):
             clear_prompts: If True, also clear prompt messages
         """
         self.history = []
+        self.conversation_cache_positions = []  # Reset cache positions
         if clear_prompts:
             self.prompt_messages = []
+
+    def should_apply_conversation_cache(self) -> bool:
+        """
+        Determine if conversation caching should be applied based on walking algorithm.
+
+        Returns:
+            True if we should add or update cache blocks
+        """
+        total_messages = len(self.history)
+
+        # Need at least cache_walk_distance messages to start caching
+        if total_messages < self.cache_walk_distance:
+            return False
+
+        # Check if we need to add a new cache block
+        return len(self._calculate_cache_positions(total_messages)) != len(self.conversation_cache_positions)
+
+    def _calculate_cache_positions(self, total_conversation_messages: int) -> List[int]:
+        """
+        Calculate where cache blocks should be placed using walking algorithm.
+
+        Args:
+            total_conversation_messages: Number of conversation messages (not including prompts)
+
+        Returns:
+            List of positions (relative to conversation start) where cache should be placed
+        """
+        positions = []
+
+        # Place cache blocks every cache_walk_distance messages
+        for i in range(self.cache_walk_distance - 1, total_conversation_messages, self.cache_walk_distance):
+            positions.append(i)
+            if len(positions) >= self.max_conversation_cache_blocks:
+                break
+
+        # Keep only the most recent cache blocks (walking behavior)
+        if len(positions) > self.max_conversation_cache_blocks:
+            positions = positions[-self.max_conversation_cache_blocks:]
+
+        return positions
+
+    def get_conversation_cache_updates(self) -> dict:
+        """
+        Get cache position updates needed for the walking algorithm.
+
+        Returns:
+            Dict with 'add', 'remove', and 'active' position lists (relative to full message array)
+        """
+        total_conversation_messages = len(self.history)
+        new_positions = self._calculate_cache_positions(total_conversation_messages)
+
+        # Convert to absolute positions (including prompt messages)
+        prompt_offset = len(self.prompt_messages)
+        new_absolute_positions = [pos + prompt_offset for pos in new_positions]
+
+        old_positions_set = set(self.conversation_cache_positions)
+        new_positions_set = set(new_absolute_positions)
+
+        return {
+            'add': sorted(new_positions_set - old_positions_set),
+            'remove': sorted(old_positions_set - new_positions_set),
+            'active': sorted(new_absolute_positions)
+        }
+
+    def apply_conversation_cache_updates(self, updates: dict) -> None:
+        """
+        Apply cache position updates.
+
+        Args:
+            updates: Dict from get_conversation_cache_updates()
+        """
+        self.conversation_cache_positions = updates['active'].copy()
+
+    def remove_cache_control_from_messages(self, messages: List[MessageParamT], positions: List[int]) -> None:
+        """
+        Remove cache control from specified message positions.
+
+        Args:
+            messages: The message array to modify
+            positions: List of positions to remove cache control from
+        """
+        for pos in positions:
+            if pos < len(messages):
+                message = messages[pos]
+                if isinstance(message, dict) and "content" in message:
+                    content_list = message["content"]
+                    if isinstance(content_list, list):
+                        for content_block in content_list:
+                            if isinstance(content_block, dict) and "cache_control" in content_block:
+                                del content_block["cache_control"]
+
+    def add_cache_control_to_messages(self, messages: List[MessageParamT], positions: List[int]) -> int:
+        """
+        Add cache control to specified message positions.
+
+        Args:
+            messages: The message array to modify
+            positions: List of positions to add cache control to
+
+        Returns:
+            Number of cache blocks successfully applied
+        """
+        applied_count = 0
+        for pos in positions:
+            if pos < len(messages):
+                message = messages[pos]
+                if isinstance(message, dict) and "content" in message:
+                    content_list = message["content"]
+                    if isinstance(content_list, list) and content_list:
+                        # Apply cache control to the last content block
+                        for content_block in reversed(content_list):
+                            if isinstance(content_block, dict):
+                                content_block["cache_control"] = {"type": "ephemeral"}
+                                applied_count += 1
+                                break
+        return applied_count
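To close, a minimal sketch of the walking-cache helpers added to `SimpleMemory`, exercised directly against a list of Anthropic-style message dicts. The payloads are illustrative, it assumes `extend` with the default `is_prompt=False` lands the messages in the conversation history, and the wiring into a real provider (see `augmented_llm_anthropic.py` in the file list) is not shown here.

```python
from mcp_agent.llm.memory import SimpleMemory

# Provider-format messages: dicts whose "content" is a list of blocks.
messages = [
    {
        "role": "user" if i % 2 == 0 else "assistant",
        "content": [{"type": "text", "text": f"message {i}"}],
    }
    for i in range(8)
]

memory: SimpleMemory[dict] = SimpleMemory()
memory.extend(messages)  # conversation history, not prompt messages

if memory.should_apply_conversation_cache():
    updates = memory.get_conversation_cache_updates()
    memory.remove_cache_control_from_messages(messages, updates["remove"])
    applied = memory.add_cache_control_to_messages(messages, updates["add"])
    if applied:
        memory.apply_conversation_cache_updates(updates)

# With the defaults above (walk distance 6, max 2 blocks), the sixth message
# should now carry cache_control on its text block.
print(messages[5]["content"][0].get("cache_control"))  # {'type': 'ephemeral'}
```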
|