agentpool-2.1.9-py3-none-any.whl → agentpool-2.2.3-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package. It is provided for informational purposes only and reflects the versions as they appear in their respective public registries.
- acp/__init__.py +13 -0
- acp/bridge/README.md +15 -2
- acp/bridge/__init__.py +3 -2
- acp/bridge/__main__.py +60 -19
- acp/bridge/ws_server.py +173 -0
- acp/bridge/ws_server_cli.py +89 -0
- acp/notifications.py +2 -1
- acp/stdio.py +39 -9
- acp/transports.py +362 -2
- acp/utils.py +15 -2
- agentpool/__init__.py +4 -1
- agentpool/agents/__init__.py +2 -0
- agentpool/agents/acp_agent/acp_agent.py +203 -88
- agentpool/agents/acp_agent/acp_converters.py +46 -21
- agentpool/agents/acp_agent/client_handler.py +157 -3
- agentpool/agents/acp_agent/session_state.py +4 -1
- agentpool/agents/agent.py +314 -107
- agentpool/agents/agui_agent/__init__.py +0 -2
- agentpool/agents/agui_agent/agui_agent.py +90 -21
- agentpool/agents/agui_agent/agui_converters.py +0 -131
- agentpool/agents/base_agent.py +163 -1
- agentpool/agents/claude_code_agent/claude_code_agent.py +626 -179
- agentpool/agents/claude_code_agent/converters.py +71 -3
- agentpool/agents/claude_code_agent/history.py +474 -0
- agentpool/agents/context.py +40 -0
- agentpool/agents/events/__init__.py +2 -0
- agentpool/agents/events/builtin_handlers.py +2 -1
- agentpool/agents/events/event_emitter.py +29 -2
- agentpool/agents/events/events.py +20 -0
- agentpool/agents/modes.py +54 -0
- agentpool/agents/tool_call_accumulator.py +213 -0
- agentpool/common_types.py +21 -0
- agentpool/config_resources/__init__.py +38 -1
- agentpool/config_resources/claude_code_agent.yml +3 -0
- agentpool/delegation/pool.py +37 -29
- agentpool/delegation/team.py +1 -0
- agentpool/delegation/teamrun.py +1 -0
- agentpool/diagnostics/__init__.py +53 -0
- agentpool/diagnostics/lsp_manager.py +1593 -0
- agentpool/diagnostics/lsp_proxy.py +41 -0
- agentpool/diagnostics/lsp_proxy_script.py +229 -0
- agentpool/diagnostics/models.py +398 -0
- agentpool/mcp_server/__init__.py +0 -2
- agentpool/mcp_server/client.py +12 -3
- agentpool/mcp_server/manager.py +25 -31
- agentpool/mcp_server/registries/official_registry_client.py +25 -0
- agentpool/mcp_server/tool_bridge.py +78 -66
- agentpool/messaging/__init__.py +0 -2
- agentpool/messaging/compaction.py +72 -197
- agentpool/messaging/message_history.py +12 -0
- agentpool/messaging/messages.py +52 -9
- agentpool/messaging/processing.py +3 -1
- agentpool/models/acp_agents/base.py +0 -22
- agentpool/models/acp_agents/mcp_capable.py +8 -148
- agentpool/models/acp_agents/non_mcp.py +129 -72
- agentpool/models/agents.py +35 -13
- agentpool/models/claude_code_agents.py +33 -2
- agentpool/models/manifest.py +43 -0
- agentpool/repomap.py +1 -1
- agentpool/resource_providers/__init__.py +9 -1
- agentpool/resource_providers/aggregating.py +52 -3
- agentpool/resource_providers/base.py +57 -1
- agentpool/resource_providers/mcp_provider.py +23 -0
- agentpool/resource_providers/plan_provider.py +130 -41
- agentpool/resource_providers/pool.py +2 -0
- agentpool/resource_providers/static.py +2 -0
- agentpool/sessions/__init__.py +2 -1
- agentpool/sessions/manager.py +31 -2
- agentpool/sessions/models.py +50 -0
- agentpool/skills/registry.py +13 -8
- agentpool/storage/manager.py +217 -1
- agentpool/testing.py +537 -19
- agentpool/utils/file_watcher.py +269 -0
- agentpool/utils/identifiers.py +121 -0
- agentpool/utils/pydantic_ai_helpers.py +46 -0
- agentpool/utils/streams.py +690 -1
- agentpool/utils/subprocess_utils.py +155 -0
- agentpool/utils/token_breakdown.py +461 -0
- {agentpool-2.1.9.dist-info → agentpool-2.2.3.dist-info}/METADATA +27 -7
- {agentpool-2.1.9.dist-info → agentpool-2.2.3.dist-info}/RECORD +170 -112
- {agentpool-2.1.9.dist-info → agentpool-2.2.3.dist-info}/WHEEL +1 -1
- agentpool_cli/__main__.py +4 -0
- agentpool_cli/serve_acp.py +41 -20
- agentpool_cli/serve_agui.py +87 -0
- agentpool_cli/serve_opencode.py +119 -0
- agentpool_commands/__init__.py +30 -0
- agentpool_commands/agents.py +74 -1
- agentpool_commands/history.py +62 -0
- agentpool_commands/mcp.py +176 -0
- agentpool_commands/models.py +56 -3
- agentpool_commands/tools.py +57 -0
- agentpool_commands/utils.py +51 -0
- agentpool_config/builtin_tools.py +77 -22
- agentpool_config/commands.py +24 -1
- agentpool_config/compaction.py +258 -0
- agentpool_config/mcp_server.py +131 -1
- agentpool_config/storage.py +46 -1
- agentpool_config/tools.py +7 -1
- agentpool_config/toolsets.py +92 -148
- agentpool_server/acp_server/acp_agent.py +134 -150
- agentpool_server/acp_server/commands/acp_commands.py +216 -51
- agentpool_server/acp_server/commands/docs_commands/fetch_repo.py +10 -10
- agentpool_server/acp_server/server.py +23 -79
- agentpool_server/acp_server/session.py +181 -19
- agentpool_server/opencode_server/.rules +95 -0
- agentpool_server/opencode_server/ENDPOINTS.md +362 -0
- agentpool_server/opencode_server/__init__.py +27 -0
- agentpool_server/opencode_server/command_validation.py +172 -0
- agentpool_server/opencode_server/converters.py +869 -0
- agentpool_server/opencode_server/dependencies.py +24 -0
- agentpool_server/opencode_server/input_provider.py +269 -0
- agentpool_server/opencode_server/models/__init__.py +228 -0
- agentpool_server/opencode_server/models/agent.py +53 -0
- agentpool_server/opencode_server/models/app.py +60 -0
- agentpool_server/opencode_server/models/base.py +26 -0
- agentpool_server/opencode_server/models/common.py +23 -0
- agentpool_server/opencode_server/models/config.py +37 -0
- agentpool_server/opencode_server/models/events.py +647 -0
- agentpool_server/opencode_server/models/file.py +88 -0
- agentpool_server/opencode_server/models/mcp.py +25 -0
- agentpool_server/opencode_server/models/message.py +162 -0
- agentpool_server/opencode_server/models/parts.py +190 -0
- agentpool_server/opencode_server/models/provider.py +81 -0
- agentpool_server/opencode_server/models/pty.py +43 -0
- agentpool_server/opencode_server/models/session.py +99 -0
- agentpool_server/opencode_server/routes/__init__.py +25 -0
- agentpool_server/opencode_server/routes/agent_routes.py +442 -0
- agentpool_server/opencode_server/routes/app_routes.py +139 -0
- agentpool_server/opencode_server/routes/config_routes.py +241 -0
- agentpool_server/opencode_server/routes/file_routes.py +392 -0
- agentpool_server/opencode_server/routes/global_routes.py +94 -0
- agentpool_server/opencode_server/routes/lsp_routes.py +319 -0
- agentpool_server/opencode_server/routes/message_routes.py +705 -0
- agentpool_server/opencode_server/routes/pty_routes.py +299 -0
- agentpool_server/opencode_server/routes/session_routes.py +1205 -0
- agentpool_server/opencode_server/routes/tui_routes.py +139 -0
- agentpool_server/opencode_server/server.py +430 -0
- agentpool_server/opencode_server/state.py +121 -0
- agentpool_server/opencode_server/time_utils.py +8 -0
- agentpool_storage/__init__.py +16 -0
- agentpool_storage/base.py +103 -0
- agentpool_storage/claude_provider.py +907 -0
- agentpool_storage/file_provider.py +129 -0
- agentpool_storage/memory_provider.py +61 -0
- agentpool_storage/models.py +3 -0
- agentpool_storage/opencode_provider.py +730 -0
- agentpool_storage/project_store.py +325 -0
- agentpool_storage/session_store.py +6 -0
- agentpool_storage/sql_provider/__init__.py +4 -2
- agentpool_storage/sql_provider/models.py +48 -0
- agentpool_storage/sql_provider/sql_provider.py +134 -1
- agentpool_storage/sql_provider/utils.py +10 -1
- agentpool_storage/text_log_provider.py +1 -0
- agentpool_toolsets/builtin/__init__.py +0 -8
- agentpool_toolsets/builtin/code.py +95 -56
- agentpool_toolsets/builtin/debug.py +16 -21
- agentpool_toolsets/builtin/execution_environment.py +99 -103
- agentpool_toolsets/builtin/file_edit/file_edit.py +115 -7
- agentpool_toolsets/builtin/skills.py +86 -4
- agentpool_toolsets/fsspec_toolset/__init__.py +13 -1
- agentpool_toolsets/fsspec_toolset/diagnostics.py +860 -73
- agentpool_toolsets/fsspec_toolset/grep.py +74 -2
- agentpool_toolsets/fsspec_toolset/image_utils.py +161 -0
- agentpool_toolsets/fsspec_toolset/toolset.py +159 -38
- agentpool_toolsets/mcp_discovery/__init__.py +5 -0
- agentpool_toolsets/mcp_discovery/data/mcp_servers.parquet +0 -0
- agentpool_toolsets/mcp_discovery/toolset.py +454 -0
- agentpool_toolsets/mcp_run_toolset.py +84 -6
- agentpool_toolsets/builtin/agent_management.py +0 -239
- agentpool_toolsets/builtin/history.py +0 -36
- agentpool_toolsets/builtin/integration.py +0 -85
- agentpool_toolsets/builtin/tool_management.py +0 -90
- {agentpool-2.1.9.dist-info → agentpool-2.2.3.dist-info}/entry_points.txt +0 -0
- {agentpool-2.1.9.dist-info → agentpool-2.2.3.dist-info}/licenses/LICENSE +0 -0
agentpool/utils/subprocess_utils.py (new file)

@@ -0,0 +1,155 @@

````python
"""Utilities for subprocess management with async support."""

from __future__ import annotations

import asyncio
import contextlib
from dataclasses import dataclass
from typing import TYPE_CHECKING

import anyio


if TYPE_CHECKING:
    from collections.abc import AsyncIterator, Awaitable, Callable

    from anyio.abc import ByteReceiveStream, Process


@dataclass
class SubprocessError(RuntimeError):
    """Error raised when a subprocess exits unexpectedly."""

    returncode: int | None
    stderr: str

    def __str__(self) -> str:
        msg = f"Subprocess exited unexpectedly (code {self.returncode})"
        if self.stderr:
            msg = f"{msg}:\n{self.stderr}"
        return msg


async def read_stream(
    stream: ByteReceiveStream | None,
    *,
    timeout: float = 0.5,
    max_bytes: int = 65536,
) -> str:
    """Read all available data from an anyio byte stream.

    Args:
        stream: The anyio ByteReceiveStream to read from
        timeout: Timeout for each read operation
        max_bytes: Maximum bytes to read total

    Returns:
        Decoded string content from the stream
    """
    if stream is None:
        return ""

    chunks: list[bytes] = []
    total_bytes = 0

    try:
        while total_bytes < max_bytes:
            with anyio.move_on_after(timeout) as scope:
                chunk = await stream.receive(4096)
                if not chunk:
                    break
                chunks.append(chunk)
                total_bytes += len(chunk)
            if scope.cancelled_caught:
                break
    except anyio.EndOfStream:
        pass
    except Exception as e:  # noqa: BLE001
        return f"(failed to read stream: {e})"

    return b"".join(chunks).decode(errors="replace").strip()


@contextlib.asynccontextmanager
async def monitor_process(
    process: Process,
    *,
    context: str = "operation",
) -> AsyncIterator[None]:
    """Context manager that monitors a subprocess for unexpected exit.

    Races the wrapped code against process termination. If the process
    exits before the code completes, raises SubprocessError with stderr.

    Args:
        process: The anyio Process to monitor
        context: Description of what's being done (for error messages)

    Raises:
        SubprocessError: If process exits during the wrapped operation

    Example:
        ```python
        async with monitor_process(process, context="initialization"):
            await do_initialization()
            await create_session()
        ```
    """
    process_wait_task = asyncio.create_task(process.wait())

    try:
        yield
    except BaseException:
        # If the wrapped code raises, cancel the wait task and re-raise
        process_wait_task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await process_wait_task
        raise

    # Check if process died during operation
    if process_wait_task.done():
        stderr_output = await read_stream(process.stderr)
        raise SubprocessError(
            returncode=process.returncode,
            stderr=stderr_output,
        )

    # Operation completed successfully, cancel the wait task
    process_wait_task.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await process_wait_task


async def run_with_process_monitor[T](
    process: Process,
    coro: Callable[[], Awaitable[T]],
    *,
    context: str = "operation",
) -> T:
    """Run a coroutine while monitoring a subprocess for unexpected exit.

    Races the coroutine against process termination. If the process
    exits before the coroutine completes, raises SubprocessError with stderr.

    Args:
        process: The anyio Process to monitor
        coro: Async callable to execute
        context: Description of what's being done (for error messages)

    Returns:
        The result of the coroutine

    Raises:
        SubprocessError: If process exits during execution

    Example:
        ```python
        result = await run_with_process_monitor(
            process,
            lambda: initialize_connection(),
            context="initialization",
        )
        ```
    """
    async with monitor_process(process, context=context):
        return await coro()
````
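For orientation, here is a minimal usage sketch (not part of the diff) showing how these helpers compose with `anyio.open_process`. The command name, `start_server`, and `do_handshake` are hypothetical placeholders for illustration, not names from the package; `monitor_process` uses `asyncio.create_task`, so this assumes the default asyncio backend.

```python
import anyio
import anyio.abc

from agentpool.utils.subprocess_utils import SubprocessError, monitor_process


async def do_handshake(process: anyio.abc.Process) -> None:
    """Hypothetical placeholder for protocol setup over the child's stdio."""
    await anyio.sleep(0.1)


async def start_server() -> None:
    # Spawn the child; "my-server" is an illustrative command, not a real one.
    process = await anyio.open_process(["my-server", "--stdio"])
    try:
        # If the child dies mid-handshake, SubprocessError carries its stderr.
        async with monitor_process(process, context="handshake"):
            await do_handshake(process)
    except SubprocessError as exc:
        print(f"server failed: {exc}")
        raise


anyio.run(start_server)
```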
agentpool/utils/token_breakdown.py (new file)

@@ -0,0 +1,461 @@

````python
"""Token breakdown utilities for analyzing context window usage."""

from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass
import json
from typing import TYPE_CHECKING, Any

from pydantic_ai.messages import (
    ModelRequest,
    ModelResponse,
    SystemPromptPart,
    ThinkingPart,
    ToolCallPart,
)
from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.tools import ToolDefinition
from pydantic_ai.usage import RequestUsage, RunUsage


if TYPE_CHECKING:
    from collections.abc import Sequence

    from pydantic_ai.messages import ModelMessage, TextPart
    from pydantic_ai.models import Model
    from pydantic_ai.settings import ModelSettings

    from agentpool.messaging.messages import TokenCost


@dataclass
class TokenUsage:
    """Single item's token count."""

    token_count: int
    label: str


@dataclass
class RunTokenUsage:
    """Token usage for a single agent run."""

    run_id: str | None
    token_count: int
    request_count: int


@dataclass
class TokenBreakdown:
    """Complete token breakdown of context."""

    total_tokens: int

    system_prompts: list[TokenUsage]
    tool_definitions: list[TokenUsage]
    runs: list[RunTokenUsage]

    approximate: bool

    @property
    def system_prompts_tokens(self) -> int:
        return sum(t.token_count for t in self.system_prompts)

    @property
    def tool_definitions_tokens(self) -> int:
        return sum(t.token_count for t in self.tool_definitions)

    @property
    def conversation_tokens(self) -> int:
        return sum(r.token_count for r in self.runs)


def _normalize_tool_schema(tool: ToolDefinition | dict[str, Any]) -> dict[str, Any]:
    """Convert a ToolDefinition or dict to a consistent dict format."""
    if isinstance(tool, ToolDefinition):
        return {
            "name": tool.name,
            "description": tool.description,
            "parameters": tool.parameters_json_schema,
        }
    return tool


def count_tokens(text: str, model_name: str = "gpt-4") -> int:
    """Count tokens using tiktoken.

    Args:
        text: The text to count tokens for.
        model_name: The model name for encoding selection.

    Returns:
        The number of tokens in the text.
    """
    try:
        import tiktoken
    except ImportError:
        # Rough approximation: ~4 chars per token
        return len(text) // 4

    try:
        encoding = tiktoken.encoding_for_model(model_name)
    except KeyError:
        # Fall back to cl100k_base for unknown models
        encoding = tiktoken.get_encoding("cl100k_base")

    return len(encoding.encode(text))


async def calculate_usage_from_parts(
    input_parts: Sequence[Any],
    response_parts: Sequence[TextPart | ThinkingPart | ToolCallPart],
    text_content: str,
    model_name: str | None = None,
    provider: str | None = None,
) -> tuple[RequestUsage, TokenCost | None]:
    """Calculate token usage and cost from input/output parts.

    This is used by agents that don't receive usage info from the backend
    (like ACP and AG-UI agents) to approximate token counts.

    Args:
        input_parts: Input parts (prompts, pending parts) sent to the agent
        response_parts: Response parts received (text, thinking, tool calls)
        text_content: The final text content of the response
        model_name: Model name for token counting and cost calculation
        provider: Provider name for cost calculation

    Returns:
        Tuple of (RequestUsage, TokenCost or None)
    """
    from agentpool.messaging.messages import TokenCost

    model_for_count = model_name or "gpt-4"

    # Input tokens from prompts
    input_text = " ".join(str(p) for p in input_parts)
    input_tokens = count_tokens(input_text, model_for_count)

    # Output tokens from response content
    output_text = text_content
    for part in response_parts:
        if isinstance(part, ThinkingPart) and part.content:
            output_text += part.content
        elif isinstance(part, ToolCallPart) and part.args:
            args_str = json.dumps(part.args) if not isinstance(part.args, str) else part.args
            output_text += args_str
    output_tokens = count_tokens(output_text, model_for_count)

    # Build usage
    usage = RequestUsage(input_tokens=input_tokens, output_tokens=output_tokens)
    run_usage = RunUsage(input_tokens=input_tokens, output_tokens=output_tokens)

    # Calculate cost
    cost_info = await TokenCost.from_usage(
        usage=run_usage,
        model=model_name or "unknown",
        provider=provider,
    )

    return usage, cost_info


def _extract_system_prompts(messages: Sequence[ModelMessage]) -> list[str]:
    """Extract all system prompt contents from messages."""
    prompts: list[str] = []
    for message in messages:
        if isinstance(message, ModelRequest):
            for part in message.parts:
                if isinstance(part, SystemPromptPart):
                    prompts.append(part.content)  # noqa: PERF401
    return prompts


def _group_messages_by_run(
    messages: Sequence[ModelMessage],
) -> dict[str | None, list[ModelMessage]]:
    """Group messages by their run_id."""
    groups: dict[str | None, list[ModelMessage]] = defaultdict(list)
    for message in messages:
        run_id = message.run_id
        groups[run_id].append(message)
    return dict(groups)


def _messages_to_text(messages: Sequence[ModelMessage]) -> str:
    """Convert messages to a text representation for token counting."""
    text_parts: list[str] = []
    for message in messages:
        if isinstance(message, ModelRequest):
            for request_part in message.parts:
                if hasattr(request_part, "content") and isinstance(request_part.content, str):
                    text_parts.append(request_part.content)  # noqa: PERF401
        elif isinstance(message, ModelResponse):
            if text := message.text:
                text_parts.append(text)
            for part in message.parts:
                if isinstance(part, ToolCallPart):
                    # Tool call arguments
                    args = part.args
                    if isinstance(args, str):
                        text_parts.append(args)
                    elif args:
                        text_parts.append(json.dumps(args))
    return "\n".join(text_parts)


async def get_token_breakdown(
    model: Model,
    messages: Sequence[ModelMessage],
    tool_schemas: Sequence[ToolDefinition | dict[str, Any]] | None = None,
    model_settings: ModelSettings | None = None,
) -> TokenBreakdown:
    """Get a breakdown of token usage by component.

    Uses model.count_tokens() if available, falls back to tiktoken.

    Args:
        model: The model to use for token counting.
        messages: The message history to analyze.
        tool_schemas: Tool definitions or raw JSON schemas.
        model_settings: Optional model settings.

    Returns:
        A TokenBreakdown with detailed token usage by component.
    """
    tool_schemas = tool_schemas or []
    approximate = False
    model_name = model.model_name or "gpt-4"

    # Try to use model.count_tokens(), fall back to tiktoken
    async def count_tokens_for_messages(msgs: Sequence[ModelMessage]) -> int:
        nonlocal approximate
        try:
            # Build minimal ModelRequestParameters for counting
            params = ModelRequestParameters()
            usage = await model.count_tokens(list(msgs), model_settings, params)
        except NotImplementedError:
            approximate = True
            return count_tokens(_messages_to_text(msgs), model_name)
        else:
            return usage.input_tokens

    # Extract and count system prompts
    system_prompt_contents = _extract_system_prompts(messages)
    system_prompt_usages: list[TokenUsage] = []
    for i, content in enumerate(system_prompt_contents):
        token_count = count_tokens(content, model_name)
        label = content[:50] + "..." if len(content) > 50 else content  # noqa: PLR2004
        system_prompt_usages.append(
            TokenUsage(token_count=token_count, label=f"System prompt {i + 1}: {label}")
        )
    # Mark as approximate since we're using tiktoken for individual prompts
    if system_prompt_usages:
        approximate = True

    # Count tool definition tokens
    tool_usages: list[TokenUsage] = []
    for tool in tool_schemas:
        schema = _normalize_tool_schema(tool)
        schema_text = json.dumps(schema)
        token_count = count_tokens(schema_text, model_name)
        tool_name = schema.get("name", "unknown")
        tool_usages.append(TokenUsage(token_count=token_count, label=tool_name))
    if tool_usages:
        approximate = True

    # Group messages by run and count tokens per run
    run_groups = _group_messages_by_run(messages)
    run_usages: list[RunTokenUsage] = []
    for run_id, run_messages in run_groups.items():
        token_count = await count_tokens_for_messages(run_messages)
        request_count = sum(1 for m in run_messages if isinstance(m, ModelRequest))
        run_usages.append(
            RunTokenUsage(
                run_id=run_id,
                token_count=token_count,
                request_count=request_count,
            )
        )

    # Calculate total
    total = (
        sum(u.token_count for u in system_prompt_usages)
        + sum(u.token_count for u in tool_usages)
        + sum(r.token_count for r in run_usages)
    )

    return TokenBreakdown(
        total_tokens=total,
        system_prompts=system_prompt_usages,
        tool_definitions=tool_usages,
        runs=run_usages,
        approximate=approximate,
    )


def format_breakdown(breakdown: TokenBreakdown, detailed: bool = False) -> str:
    """Format a token breakdown for display."""
    lines: list[str] = []
    # Header
    approx_marker = " (approximate)" if breakdown.approximate else ""
    lines.append(f"Token Breakdown{approx_marker}")
    lines.append("=" * 50)
    # Summary
    lines.append(f"Total tokens: {breakdown.total_tokens:,}")
    lines.append("")
    # Category breakdown
    lines.append("By category:")
    lines.append(f"  System prompts: {breakdown.system_prompts_tokens:,} tokens")
    lines.append(f"  Tool definitions: {breakdown.tool_definitions_tokens:,} tokens")
    lines.append(f"  Conversation: {breakdown.conversation_tokens:,} tokens")
    if detailed:
        lines.append("")
        lines.append("-" * 50)
        # System prompts detail
        if breakdown.system_prompts:
            lines.append("")
            lines.append("System Prompts:")
            for sp in breakdown.system_prompts:
                lines.append(f"  [{sp.token_count:,} tokens] {sp.label}")  # noqa: PERF401
        # Tool definitions detail
        if breakdown.tool_definitions:
            lines.append("")
            lines.append("Tool Definitions:")
            for tool in breakdown.tool_definitions:
                lines.append(f"  [{tool.token_count:,} tokens] {tool.label}")  # noqa: PERF401
        # Runs detail
        if breakdown.runs:
            lines.append("")
            lines.append("Conversation by Run:")
            for run in breakdown.runs:
                run_label = run.run_id[:8] + "..." if run.run_id else "(no run_id)"
                lines.append(
                    f"  [{run.token_count:,} tokens] {run_label} ({run.request_count} requests)"
                )

    lines.append("")
    return "\n".join(lines)


if __name__ == "__main__":
    import asyncio

    from pydantic_ai.messages import (
        ImageUrl,
        ModelRequest,
        ModelResponse,
        SystemPromptPart,
        TextPart,
        ToolCallPart,
        ToolReturnPart,
        UserPromptPart,
    )
    from pydantic_ai.models.test import TestModel
    from pydantic_ai.tools import ToolDefinition

    async def main() -> None:
        # Create sample tool definitions
        tool_definitions = [
            ToolDefinition(
                name="get_weather",
                description="Get the current weather for a city.",
                parameters_json_schema={
                    "type": "object",
                    "properties": {
                        "city": {"type": "string", "description": "The city name"},
                    },
                    "required": ["city"],
                },
            ),
            ToolDefinition(
                name="search_database",
                description="Search a database with a complex query. Supports filtering, sorting, and pagination.",  # noqa: E501
                parameters_json_schema={
                    "type": "object",
                    "properties": {
                        "query": {"type": "string", "description": "Search query"},
                        "filters": {
                            "type": "object",
                            "description": "Filter conditions",
                            "additionalProperties": True,
                        },
                        "sort_by": {"type": "string", "description": "Field to sort by"},
                        "limit": {"type": "integer", "description": "Max results"},
                        "offset": {"type": "integer", "description": "Skip N results"},
                    },
                    "required": ["query"],
                },
            ),
        ]

        # Create sample message history simulating two runs
        messages: list[ModelMessage] = [
            # First run
            ModelRequest(
                parts=[
                    SystemPromptPart(
                        content="You are a helpful assistant with access to weather and time tools."
                    ),
                    UserPromptPart(content="What's the weather in Paris?"),
                ],
                run_id="run-001-abc",
            ),
            ModelResponse(
                parts=[
                    ToolCallPart(
                        tool_name="get_weather", args={"city": "Paris"}, tool_call_id="call-1"
                    ),
                ],
                run_id="run-001-abc",
            ),
            ModelRequest(
                parts=[
                    ToolReturnPart(
                        tool_name="get_weather",
                        content="Sunny, 22°C in Paris",
                        tool_call_id="call-1",
                    ),
                ],
                run_id="run-001-abc",
            ),
            ModelResponse(
                parts=[
                    TextPart(content="The weather in Paris is sunny with a temperature of 22°C."),
                ],
                run_id="run-001-abc",
            ),
            # Second run (continuing conversation) - includes an image
            ModelRequest(
                parts=[
                    UserPromptPart(
                        content=[
                            "What time is it in Tokyo?",
                            ImageUrl(url="https://example.com/tokyo-clock.jpg"),
                        ]
                    ),
                ],
                run_id="run-002-def",
            ),
            ModelResponse(
                parts=[
                    TextPart(content="The current time in Tokyo is 14:30 JST."),
                ],
                run_id="run-002-def",
            ),
        ]

        # Get the breakdown using TestModel (will use tiktoken fallback)
        model = TestModel()
        breakdown = await get_token_breakdown(
            model=model, messages=messages, tool_schemas=tool_definitions
        )
        # Print summary view
        print("SUMMARY VIEW")
        print(format_breakdown(breakdown, detailed=False))
        # Print detailed view
        print("DETAILED VIEW")
        print(format_breakdown(breakdown, detailed=True))

    asyncio.run(main())
````
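The file's own `__main__` block already demonstrates `get_token_breakdown`; as a complement, here is a minimal sketch (not part of the diff) of the `calculate_usage_from_parts` path, which the docstring says is used when a backend reports no usage. The prompt text, model, and provider values are illustrative assumptions, and `TokenCost.from_usage` may consult pricing data internally.

```python
import asyncio

from pydantic_ai.messages import TextPart

from agentpool.utils.token_breakdown import calculate_usage_from_parts


async def main() -> None:
    # Approximate tokens and cost for a single exchange; values are illustrative.
    usage, cost = await calculate_usage_from_parts(
        input_parts=["Summarize the 2.2.3 release notes"],
        response_parts=[TextPart(content="Here is a short summary ...")],
        text_content="Here is a short summary ...",
        model_name="gpt-4o",
        provider="openai",
    )
    print(f"input={usage.input_tokens} output={usage.output_tokens} cost={cost}")


asyncio.run(main())
```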