alita-sdk 0.3.379__py3-none-any.whl → 0.3.627__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- alita_sdk/cli/__init__.py +10 -0
- alita_sdk/cli/__main__.py +17 -0
- alita_sdk/cli/agent/__init__.py +5 -0
- alita_sdk/cli/agent/default.py +258 -0
- alita_sdk/cli/agent_executor.py +156 -0
- alita_sdk/cli/agent_loader.py +245 -0
- alita_sdk/cli/agent_ui.py +228 -0
- alita_sdk/cli/agents.py +3113 -0
- alita_sdk/cli/callbacks.py +647 -0
- alita_sdk/cli/cli.py +168 -0
- alita_sdk/cli/config.py +306 -0
- alita_sdk/cli/context/__init__.py +30 -0
- alita_sdk/cli/context/cleanup.py +198 -0
- alita_sdk/cli/context/manager.py +731 -0
- alita_sdk/cli/context/message.py +285 -0
- alita_sdk/cli/context/strategies.py +289 -0
- alita_sdk/cli/context/token_estimation.py +127 -0
- alita_sdk/cli/formatting.py +182 -0
- alita_sdk/cli/input_handler.py +419 -0
- alita_sdk/cli/inventory.py +1073 -0
- alita_sdk/cli/mcp_loader.py +315 -0
- alita_sdk/cli/testcases/__init__.py +94 -0
- alita_sdk/cli/testcases/data_generation.py +119 -0
- alita_sdk/cli/testcases/discovery.py +96 -0
- alita_sdk/cli/testcases/executor.py +84 -0
- alita_sdk/cli/testcases/logger.py +85 -0
- alita_sdk/cli/testcases/parser.py +172 -0
- alita_sdk/cli/testcases/prompts.py +91 -0
- alita_sdk/cli/testcases/reporting.py +125 -0
- alita_sdk/cli/testcases/setup.py +108 -0
- alita_sdk/cli/testcases/test_runner.py +282 -0
- alita_sdk/cli/testcases/utils.py +39 -0
- alita_sdk/cli/testcases/validation.py +90 -0
- alita_sdk/cli/testcases/workflow.py +196 -0
- alita_sdk/cli/toolkit.py +327 -0
- alita_sdk/cli/toolkit_loader.py +85 -0
- alita_sdk/cli/tools/__init__.py +43 -0
- alita_sdk/cli/tools/approval.py +224 -0
- alita_sdk/cli/tools/filesystem.py +1751 -0
- alita_sdk/cli/tools/planning.py +389 -0
- alita_sdk/cli/tools/terminal.py +414 -0
- alita_sdk/community/__init__.py +72 -12
- alita_sdk/community/inventory/__init__.py +236 -0
- alita_sdk/community/inventory/config.py +257 -0
- alita_sdk/community/inventory/enrichment.py +2137 -0
- alita_sdk/community/inventory/extractors.py +1469 -0
- alita_sdk/community/inventory/ingestion.py +3172 -0
- alita_sdk/community/inventory/knowledge_graph.py +1457 -0
- alita_sdk/community/inventory/parsers/__init__.py +218 -0
- alita_sdk/community/inventory/parsers/base.py +295 -0
- alita_sdk/community/inventory/parsers/csharp_parser.py +907 -0
- alita_sdk/community/inventory/parsers/go_parser.py +851 -0
- alita_sdk/community/inventory/parsers/html_parser.py +389 -0
- alita_sdk/community/inventory/parsers/java_parser.py +593 -0
- alita_sdk/community/inventory/parsers/javascript_parser.py +629 -0
- alita_sdk/community/inventory/parsers/kotlin_parser.py +768 -0
- alita_sdk/community/inventory/parsers/markdown_parser.py +362 -0
- alita_sdk/community/inventory/parsers/python_parser.py +604 -0
- alita_sdk/community/inventory/parsers/rust_parser.py +858 -0
- alita_sdk/community/inventory/parsers/swift_parser.py +832 -0
- alita_sdk/community/inventory/parsers/text_parser.py +322 -0
- alita_sdk/community/inventory/parsers/yaml_parser.py +370 -0
- alita_sdk/community/inventory/patterns/__init__.py +61 -0
- alita_sdk/community/inventory/patterns/ast_adapter.py +380 -0
- alita_sdk/community/inventory/patterns/loader.py +348 -0
- alita_sdk/community/inventory/patterns/registry.py +198 -0
- alita_sdk/community/inventory/presets.py +535 -0
- alita_sdk/community/inventory/retrieval.py +1403 -0
- alita_sdk/community/inventory/toolkit.py +173 -0
- alita_sdk/community/inventory/toolkit_utils.py +176 -0
- alita_sdk/community/inventory/visualize.py +1370 -0
- alita_sdk/configurations/__init__.py +1 -1
- alita_sdk/configurations/ado.py +141 -20
- alita_sdk/configurations/bitbucket.py +94 -2
- alita_sdk/configurations/confluence.py +130 -1
- alita_sdk/configurations/figma.py +76 -0
- alita_sdk/configurations/gitlab.py +91 -0
- alita_sdk/configurations/jira.py +103 -0
- alita_sdk/configurations/openapi.py +329 -0
- alita_sdk/configurations/qtest.py +72 -1
- alita_sdk/configurations/report_portal.py +96 -0
- alita_sdk/configurations/sharepoint.py +148 -0
- alita_sdk/configurations/testio.py +83 -0
- alita_sdk/configurations/testrail.py +88 -0
- alita_sdk/configurations/xray.py +93 -0
- alita_sdk/configurations/zephyr_enterprise.py +93 -0
- alita_sdk/configurations/zephyr_essential.py +75 -0
- alita_sdk/runtime/clients/artifact.py +3 -3
- alita_sdk/runtime/clients/client.py +388 -46
- alita_sdk/runtime/clients/mcp_discovery.py +342 -0
- alita_sdk/runtime/clients/mcp_manager.py +262 -0
- alita_sdk/runtime/clients/sandbox_client.py +8 -21
- alita_sdk/runtime/langchain/_constants_bkup.py +1318 -0
- alita_sdk/runtime/langchain/assistant.py +157 -39
- alita_sdk/runtime/langchain/constants.py +647 -1
- alita_sdk/runtime/langchain/document_loaders/AlitaDocxMammothLoader.py +315 -3
- alita_sdk/runtime/langchain/document_loaders/AlitaExcelLoader.py +103 -60
- alita_sdk/runtime/langchain/document_loaders/AlitaJSONLinesLoader.py +77 -0
- alita_sdk/runtime/langchain/document_loaders/AlitaJSONLoader.py +10 -4
- alita_sdk/runtime/langchain/document_loaders/AlitaPowerPointLoader.py +226 -7
- alita_sdk/runtime/langchain/document_loaders/AlitaTextLoader.py +5 -2
- alita_sdk/runtime/langchain/document_loaders/constants.py +40 -19
- alita_sdk/runtime/langchain/langraph_agent.py +405 -84
- alita_sdk/runtime/langchain/utils.py +106 -7
- alita_sdk/runtime/llms/preloaded.py +2 -6
- alita_sdk/runtime/models/mcp_models.py +61 -0
- alita_sdk/runtime/skills/__init__.py +91 -0
- alita_sdk/runtime/skills/callbacks.py +498 -0
- alita_sdk/runtime/skills/discovery.py +540 -0
- alita_sdk/runtime/skills/executor.py +610 -0
- alita_sdk/runtime/skills/input_builder.py +371 -0
- alita_sdk/runtime/skills/models.py +330 -0
- alita_sdk/runtime/skills/registry.py +355 -0
- alita_sdk/runtime/skills/skill_runner.py +330 -0
- alita_sdk/runtime/toolkits/__init__.py +31 -0
- alita_sdk/runtime/toolkits/application.py +29 -10
- alita_sdk/runtime/toolkits/artifact.py +20 -11
- alita_sdk/runtime/toolkits/datasource.py +13 -6
- alita_sdk/runtime/toolkits/mcp.py +783 -0
- alita_sdk/runtime/toolkits/mcp_config.py +1048 -0
- alita_sdk/runtime/toolkits/planning.py +178 -0
- alita_sdk/runtime/toolkits/skill_router.py +238 -0
- alita_sdk/runtime/toolkits/subgraph.py +251 -6
- alita_sdk/runtime/toolkits/tools.py +356 -69
- alita_sdk/runtime/toolkits/vectorstore.py +11 -5
- alita_sdk/runtime/tools/__init__.py +10 -3
- alita_sdk/runtime/tools/application.py +27 -6
- alita_sdk/runtime/tools/artifact.py +511 -28
- alita_sdk/runtime/tools/data_analysis.py +183 -0
- alita_sdk/runtime/tools/function.py +67 -35
- alita_sdk/runtime/tools/graph.py +10 -4
- alita_sdk/runtime/tools/image_generation.py +148 -46
- alita_sdk/runtime/tools/llm.py +1003 -128
- alita_sdk/runtime/tools/loop.py +3 -1
- alita_sdk/runtime/tools/loop_output.py +3 -1
- alita_sdk/runtime/tools/mcp_inspect_tool.py +284 -0
- alita_sdk/runtime/tools/mcp_remote_tool.py +181 -0
- alita_sdk/runtime/tools/mcp_server_tool.py +8 -5
- alita_sdk/runtime/tools/planning/__init__.py +36 -0
- alita_sdk/runtime/tools/planning/models.py +246 -0
- alita_sdk/runtime/tools/planning/wrapper.py +607 -0
- alita_sdk/runtime/tools/router.py +2 -4
- alita_sdk/runtime/tools/sandbox.py +65 -48
- alita_sdk/runtime/tools/skill_router.py +776 -0
- alita_sdk/runtime/tools/tool.py +3 -1
- alita_sdk/runtime/tools/vectorstore.py +9 -3
- alita_sdk/runtime/tools/vectorstore_base.py +70 -14
- alita_sdk/runtime/utils/AlitaCallback.py +137 -21
- alita_sdk/runtime/utils/constants.py +5 -1
- alita_sdk/runtime/utils/mcp_client.py +492 -0
- alita_sdk/runtime/utils/mcp_oauth.py +361 -0
- alita_sdk/runtime/utils/mcp_sse_client.py +434 -0
- alita_sdk/runtime/utils/mcp_tools_discovery.py +124 -0
- alita_sdk/runtime/utils/serialization.py +155 -0
- alita_sdk/runtime/utils/streamlit.py +40 -13
- alita_sdk/runtime/utils/toolkit_utils.py +30 -9
- alita_sdk/runtime/utils/utils.py +36 -0
- alita_sdk/tools/__init__.py +134 -35
- alita_sdk/tools/ado/repos/__init__.py +51 -32
- alita_sdk/tools/ado/repos/repos_wrapper.py +148 -89
- alita_sdk/tools/ado/test_plan/__init__.py +25 -9
- alita_sdk/tools/ado/test_plan/test_plan_wrapper.py +23 -1
- alita_sdk/tools/ado/utils.py +1 -18
- alita_sdk/tools/ado/wiki/__init__.py +25 -12
- alita_sdk/tools/ado/wiki/ado_wrapper.py +291 -22
- alita_sdk/tools/ado/work_item/__init__.py +26 -13
- alita_sdk/tools/ado/work_item/ado_wrapper.py +73 -11
- alita_sdk/tools/advanced_jira_mining/__init__.py +11 -8
- alita_sdk/tools/aws/delta_lake/__init__.py +13 -9
- alita_sdk/tools/aws/delta_lake/tool.py +5 -1
- alita_sdk/tools/azure_ai/search/__init__.py +11 -8
- alita_sdk/tools/azure_ai/search/api_wrapper.py +1 -1
- alita_sdk/tools/base/tool.py +5 -1
- alita_sdk/tools/base_indexer_toolkit.py +271 -84
- alita_sdk/tools/bitbucket/__init__.py +17 -11
- alita_sdk/tools/bitbucket/api_wrapper.py +59 -11
- alita_sdk/tools/bitbucket/cloud_api_wrapper.py +49 -35
- alita_sdk/tools/browser/__init__.py +5 -4
- alita_sdk/tools/carrier/__init__.py +5 -6
- alita_sdk/tools/carrier/backend_reports_tool.py +6 -6
- alita_sdk/tools/carrier/run_ui_test_tool.py +6 -6
- alita_sdk/tools/carrier/ui_reports_tool.py +5 -5
- alita_sdk/tools/chunkers/__init__.py +3 -1
- alita_sdk/tools/chunkers/code/treesitter/treesitter.py +37 -13
- alita_sdk/tools/chunkers/sematic/json_chunker.py +1 -0
- alita_sdk/tools/chunkers/sematic/markdown_chunker.py +97 -6
- alita_sdk/tools/chunkers/sematic/proposal_chunker.py +1 -1
- alita_sdk/tools/chunkers/universal_chunker.py +270 -0
- alita_sdk/tools/cloud/aws/__init__.py +10 -7
- alita_sdk/tools/cloud/azure/__init__.py +10 -7
- alita_sdk/tools/cloud/gcp/__init__.py +10 -7
- alita_sdk/tools/cloud/k8s/__init__.py +10 -7
- alita_sdk/tools/code/linter/__init__.py +10 -8
- alita_sdk/tools/code/loaders/codesearcher.py +3 -2
- alita_sdk/tools/code/sonar/__init__.py +11 -8
- alita_sdk/tools/code_indexer_toolkit.py +82 -22
- alita_sdk/tools/confluence/__init__.py +22 -16
- alita_sdk/tools/confluence/api_wrapper.py +107 -30
- alita_sdk/tools/confluence/loader.py +14 -2
- alita_sdk/tools/custom_open_api/__init__.py +12 -5
- alita_sdk/tools/elastic/__init__.py +11 -8
- alita_sdk/tools/elitea_base.py +493 -30
- alita_sdk/tools/figma/__init__.py +58 -11
- alita_sdk/tools/figma/api_wrapper.py +1235 -143
- alita_sdk/tools/figma/figma_client.py +73 -0
- alita_sdk/tools/figma/toon_tools.py +2748 -0
- alita_sdk/tools/github/__init__.py +14 -15
- alita_sdk/tools/github/github_client.py +224 -100
- alita_sdk/tools/github/graphql_client_wrapper.py +119 -33
- alita_sdk/tools/github/schemas.py +14 -5
- alita_sdk/tools/github/tool.py +5 -1
- alita_sdk/tools/github/tool_prompts.py +9 -22
- alita_sdk/tools/gitlab/__init__.py +16 -11
- alita_sdk/tools/gitlab/api_wrapper.py +218 -48
- alita_sdk/tools/gitlab_org/__init__.py +10 -9
- alita_sdk/tools/gitlab_org/api_wrapper.py +63 -64
- alita_sdk/tools/google/bigquery/__init__.py +13 -12
- alita_sdk/tools/google/bigquery/tool.py +5 -1
- alita_sdk/tools/google_places/__init__.py +11 -8
- alita_sdk/tools/google_places/api_wrapper.py +1 -1
- alita_sdk/tools/jira/__init__.py +17 -10
- alita_sdk/tools/jira/api_wrapper.py +92 -41
- alita_sdk/tools/keycloak/__init__.py +11 -8
- alita_sdk/tools/localgit/__init__.py +9 -3
- alita_sdk/tools/localgit/local_git.py +62 -54
- alita_sdk/tools/localgit/tool.py +5 -1
- alita_sdk/tools/memory/__init__.py +12 -4
- alita_sdk/tools/non_code_indexer_toolkit.py +1 -0
- alita_sdk/tools/ocr/__init__.py +11 -8
- alita_sdk/tools/openapi/__init__.py +491 -106
- alita_sdk/tools/openapi/api_wrapper.py +1368 -0
- alita_sdk/tools/openapi/tool.py +20 -0
- alita_sdk/tools/pandas/__init__.py +20 -12
- alita_sdk/tools/pandas/api_wrapper.py +38 -25
- alita_sdk/tools/pandas/dataframe/generator/base.py +3 -1
- alita_sdk/tools/postman/__init__.py +10 -9
- alita_sdk/tools/pptx/__init__.py +11 -10
- alita_sdk/tools/pptx/pptx_wrapper.py +1 -1
- alita_sdk/tools/qtest/__init__.py +31 -11
- alita_sdk/tools/qtest/api_wrapper.py +2135 -86
- alita_sdk/tools/rally/__init__.py +10 -9
- alita_sdk/tools/rally/api_wrapper.py +1 -1
- alita_sdk/tools/report_portal/__init__.py +12 -8
- alita_sdk/tools/salesforce/__init__.py +10 -8
- alita_sdk/tools/servicenow/__init__.py +17 -15
- alita_sdk/tools/servicenow/api_wrapper.py +1 -1
- alita_sdk/tools/sharepoint/__init__.py +10 -7
- alita_sdk/tools/sharepoint/api_wrapper.py +129 -38
- alita_sdk/tools/sharepoint/authorization_helper.py +191 -1
- alita_sdk/tools/sharepoint/utils.py +8 -2
- alita_sdk/tools/slack/__init__.py +10 -7
- alita_sdk/tools/slack/api_wrapper.py +2 -2
- alita_sdk/tools/sql/__init__.py +12 -9
- alita_sdk/tools/testio/__init__.py +10 -7
- alita_sdk/tools/testrail/__init__.py +11 -10
- alita_sdk/tools/testrail/api_wrapper.py +1 -1
- alita_sdk/tools/utils/__init__.py +9 -4
- alita_sdk/tools/utils/content_parser.py +103 -18
- alita_sdk/tools/utils/text_operations.py +410 -0
- alita_sdk/tools/utils/tool_prompts.py +79 -0
- alita_sdk/tools/vector_adapters/VectorStoreAdapter.py +30 -13
- alita_sdk/tools/xray/__init__.py +13 -9
- alita_sdk/tools/yagmail/__init__.py +9 -3
- alita_sdk/tools/zephyr/__init__.py +10 -7
- alita_sdk/tools/zephyr_enterprise/__init__.py +11 -7
- alita_sdk/tools/zephyr_essential/__init__.py +10 -7
- alita_sdk/tools/zephyr_essential/api_wrapper.py +30 -13
- alita_sdk/tools/zephyr_essential/client.py +2 -2
- alita_sdk/tools/zephyr_scale/__init__.py +11 -8
- alita_sdk/tools/zephyr_scale/api_wrapper.py +2 -2
- alita_sdk/tools/zephyr_squad/__init__.py +10 -7
- {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.627.dist-info}/METADATA +154 -8
- alita_sdk-0.3.627.dist-info/RECORD +468 -0
- alita_sdk-0.3.627.dist-info/entry_points.txt +2 -0
- alita_sdk-0.3.379.dist-info/RECORD +0 -360
- {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.627.dist-info}/WHEEL +0 -0
- {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.627.dist-info}/licenses/LICENSE +0 -0
- {alita_sdk-0.3.379.dist-info → alita_sdk-0.3.627.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,285 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CLI Message wrapper with context tracking.
|
|
3
|
+
|
|
4
|
+
Provides a message class that tracks inclusion status, token counts,
|
|
5
|
+
and supports conversion to/from LangChain message formats.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from datetime import datetime, timezone
|
|
10
|
+
from typing import Any, Dict, List, Optional, Union
|
|
11
|
+
|
|
12
|
+
from .token_estimation import estimate_message_tokens
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class CLIMessage:
    """
    Chat message with context management metadata.

    Tracks whether the message is included in context, its token count,
    and provides conversion to various message formats (plain dicts,
    LangChain messages, and a persistence state dict).

    Attributes:
        role: Message role (user, assistant, system)
        content: Message content text
        index: Position in the full conversation history
        token_count: Cached token count for this message
        included: Whether message is included in LLM context
        created_at: Timestamp when message was created (timezone-aware UTC)
        priority: Priority weight for importance-based pruning
        weight: Additional weight factor for pruning decisions
    """
    role: str
    content: str
    index: int
    token_count: int = 0
    included: bool = True
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    priority: float = 1.0
    weight: float = 1.0

    def __post_init__(self):
        """Calculate token count if not provided.

        NOTE(review): this path calls estimate_message_tokens without a
        model argument (its default model applies), while from_dict()
        passes an explicit model — confirm the defaults agree.
        """
        if self.token_count == 0 and self.content:
            self.token_count = estimate_message_tokens(self.role, self.content)

    @classmethod
    def from_dict(cls, msg_dict: Dict[str, Any], index: int, model: str = 'gpt-4') -> 'CLIMessage':
        """
        Create CLIMessage from a simple dict format.

        Missing keys fall back to defaults: role='user', content='',
        included=True, priority=1.0, weight=1.0.

        Args:
            msg_dict: Dictionary with 'role' and 'content' keys
            index: Position in conversation
            model: Model name for token estimation

        Returns:
            CLIMessage instance
        """
        role = msg_dict.get('role', 'user')
        content = msg_dict.get('content', '')
        # Estimate eagerly with the explicit model so __post_init__'s
        # default-model fallback is not used.
        token_count = estimate_message_tokens(role, content, model)

        return cls(
            role=role,
            content=content,
            index=index,
            token_count=token_count,
            included=msg_dict.get('included', True),
            priority=msg_dict.get('priority', 1.0),
            weight=msg_dict.get('weight', 1.0),
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to simple dict format for LLM calls.

        Only role and content are emitted; context-management metadata
        (token_count, included, priority, weight) is intentionally dropped.

        Returns:
            Dictionary with 'role' and 'content' keys
        """
        return {
            'role': self.role,
            'content': self.content,
        }

    def to_langchain_message(self) -> Any:
        """
        Convert to LangChain message format.

        Returns:
            Appropriate LangChain message type (HumanMessage, AIMessage, SystemMessage)
        """
        # Imported lazily so langchain_core is only required when converting.
        from langchain_core.messages import HumanMessage, AIMessage, SystemMessage

        if self.role == 'user':
            return HumanMessage(content=self.content)
        elif self.role == 'assistant':
            return AIMessage(content=self.content)
        elif self.role == 'system':
            return SystemMessage(content=self.content)
        else:
            # Default to HumanMessage for unknown roles
            return HumanMessage(content=self.content)

    @classmethod
    def from_langchain_message(cls, message: Any, index: int) -> 'CLIMessage':
        """
        Create CLIMessage from LangChain message.

        Unknown message types default to role 'user'. The token count is
        not set here; __post_init__ estimates it with the default model.

        Args:
            message: LangChain message (HumanMessage, AIMessage, etc.)
            index: Position in conversation

        Returns:
            CLIMessage instance
        """
        from langchain_core.messages import HumanMessage, AIMessage, SystemMessage

        content = message.content if hasattr(message, 'content') else str(message)

        if isinstance(message, HumanMessage):
            role = 'user'
        elif isinstance(message, AIMessage):
            role = 'assistant'
        elif isinstance(message, SystemMessage):
            role = 'system'
        else:
            role = 'user'

        return cls(
            role=role,
            content=content,
            index=index,
        )

    def to_state_dict(self) -> Dict[str, Any]:
        """
        Convert to state dictionary for persistence.

        created_at is serialized as an ISO-8601 string; from_state_dict()
        is the inverse of this method.

        Returns:
            Dictionary with all fields for saving to context_state.json
        """
        return {
            'index': self.index,
            'role': self.role,
            'content': self.content,
            'token_count': self.token_count,
            'included': self.included,
            'created_at': self.created_at.isoformat(),
            'priority': self.priority,
            'weight': self.weight,
        }

    @classmethod
    def from_state_dict(cls, state: Dict[str, Any]) -> 'CLIMessage':
        """
        Restore CLIMessage from state dictionary.

        'role', 'content', and 'index' are required (KeyError if absent);
        all other fields fall back to their defaults. A missing or None
        created_at is replaced with the current UTC time.

        Args:
            state: Dictionary from to_state_dict()

        Returns:
            CLIMessage instance
        """
        created_at = state.get('created_at')
        if isinstance(created_at, str):
            created_at = datetime.fromisoformat(created_at)
        elif created_at is None:
            created_at = datetime.now(timezone.utc)

        return cls(
            role=state['role'],
            content=state['content'],
            index=state['index'],
            token_count=state.get('token_count', 0),
            included=state.get('included', True),
            created_at=created_at,
            priority=state.get('priority', 1.0),
            weight=state.get('weight', 1.0),
        )

    @property
    def meta(self) -> Dict[str, Any]:
        """
        Get metadata in context_manager compatible format.

        Used for compatibility with pruning strategies from context_manager.
        """
        return {
            'context': {
                'token_count': self.token_count,
                'weight': self.weight,
                'priority': self.priority,
                'included': self.included,
            }
        }

    @property
    def reply_to_id(self) -> Optional[int]:
        """
        Get reply-to ID (for thread awareness).

        For CLI messages, we use simple sequential ordering.
        User messages reply to previous assistant, and vice versa.

        NOTE(review): the first reply has id 0, which is falsy — callers
        must compare against None, not truthiness.
        """
        if self.index > 0:
            return self.index - 1
        return None

    @property
    def author_participant(self) -> Any:
        """
        Get author participant for importance scoring.

        Returns a simple object with entity_name attribute (duck-typed
        stand-in for a real participant record; a fresh instance per call).
        """
        class Participant:
            def __init__(self, role: str):
                self.entity_name = role

        return Participant(self.role)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def messages_to_cli_messages(
    messages: List[Dict[str, Any]],
    model: str = 'gpt-4'
) -> List[CLIMessage]:
    """
    Build CLIMessage objects from a list of plain message dicts.

    Each dict's position in the input list becomes the message index.

    Args:
        messages: List of dicts with 'role' and 'content' keys
        model: Model name used for token estimation

    Returns:
        List of CLIMessage objects
    """
    converted: List[CLIMessage] = []
    for position, raw in enumerate(messages):
        converted.append(CLIMessage.from_dict(raw, index=position, model=model))
    return converted
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
def cli_messages_to_dicts(
    messages: List[CLIMessage],
    include_only: bool = True
) -> List[Dict[str, str]]:
    """
    Flatten CLIMessage objects into role/content dicts for LLM calls.

    Args:
        messages: List of CLIMessage objects
        include_only: If True, drop messages whose included flag is False

    Returns:
        List of dicts with 'role' and 'content' keys
    """
    return [
        msg.to_dict()
        for msg in messages
        if msg.included or not include_only
    ]
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def cli_messages_to_langchain(
    messages: List[CLIMessage],
    include_only: bool = True
) -> List[Any]:
    """
    Convert CLIMessage objects into LangChain message objects.

    Args:
        messages: List of CLIMessage objects
        include_only: If True, drop messages whose included flag is False

    Returns:
        List of LangChain message objects
    """
    return [
        msg.to_langchain_message()
        for msg in messages
        if msg.included or not include_only
    ]
|
|
@@ -0,0 +1,289 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Pruning strategies for CLI context management.
|
|
3
|
+
|
|
4
|
+
Implements various strategies for selecting which messages to include
|
|
5
|
+
in the LLM context when the token limit is exceeded.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from typing import Dict, List, Optional, Type, Union
|
|
11
|
+
|
|
12
|
+
from .message import CLIMessage
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class PruningConfig:
    """Configuration values that drive context-pruning behavior.

    Attributes:
        max_context_tokens: Hard ceiling on tokens kept in context.
        preserve_recent_messages: How many trailing messages are never pruned.
        pruning_method: Name of the strategy to apply.
        weights: Scoring weights used by weight-aware strategies.
    """

    max_context_tokens: int = 8000
    preserve_recent_messages: int = 5
    pruning_method: str = 'oldest_first'
    weights: Dict[str, float] = field(default_factory=lambda: {
        'recency': 1.0,
        'importance': 1.0,
        'user_messages': 1.2,
        'thread_continuity': 1.0,
    })

    @classmethod
    def from_strategy(cls, strategy_config: Dict) -> 'PruningConfig':
        """Build a PruningConfig from a raw strategy-configuration dict.

        Any key missing from *strategy_config* falls back to the same
        defaults the dataclass fields declare.
        """
        fallback_weights = {
            'recency': 1.0,
            'importance': 1.0,
            'user_messages': 1.2,
            'thread_continuity': 1.0,
        }
        return cls(
            max_context_tokens=strategy_config.get('max_context_tokens', 8000),
            preserve_recent_messages=strategy_config.get('preserve_recent_messages', 5),
            pruning_method=strategy_config.get('pruning_method', 'oldest_first'),
            weights=strategy_config.get('weights', fallback_weights),
        )
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class PruningStrategy(ABC):
    """Common interface for context-pruning strategies.

    A strategy receives candidate messages plus a token budget and decides
    which messages survive. Shared settings arrive via PruningConfig.
    """

    def __init__(self, config: PruningConfig):
        # Shared configuration (token limits, weights, etc.)
        self.config = config

    @abstractmethod
    def select_messages(
        self,
        messages: List[CLIMessage],
        available_tokens: int
    ) -> List[CLIMessage]:
        """
        Pick the subset of *messages* that fits within *available_tokens*.

        Args:
            messages: Candidate messages to choose from
            available_tokens: Token budget the selection must respect

        Returns:
            The chosen messages, within the token budget
        """

    @staticmethod
    def get_token_count(message: CLIMessage) -> int:
        """Return the cached token count stored on *message*."""
        return message.token_count
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class OldestFirstStrategy(PruningStrategy):
    """Keep the newest messages that fit the budget; oldest overflow is dropped.

    Messages are considered newest-first and accepted until one no longer
    fits, so the surviving window is a contiguous tail of the conversation.
    """

    def select_messages(
        self,
        messages: List[CLIMessage],
        available_tokens: int
    ) -> List[CLIMessage]:
        """Return the largest suffix of the conversation that fits the budget."""
        budget_left = available_tokens
        kept: List[CLIMessage] = []

        # Walk newest-to-oldest; stop at the first message that overflows.
        for candidate in sorted(messages, key=lambda m: m.index, reverse=True):
            cost = self.get_token_count(candidate)
            if cost > budget_left:
                break
            kept.append(candidate)
            budget_left -= cost

        # Hand back the survivors in conversation (oldest-first) order.
        return sorted(kept, key=lambda m: m.index)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class ImportanceBasedStrategy(PruningStrategy):
    """Select messages based on calculated importance scores.

    Messages are scored by priority, weight, length, reply linkage, role,
    and recency, then greedily packed into the token budget starting from
    the highest score. Unlike OldestFirstStrategy this does not stop at
    the first message that overflows — it keeps scanning for smaller
    messages that still fit.
    """

    def select_messages(
        self,
        messages: List[CLIMessage],
        available_tokens: int
    ) -> List[CLIMessage]:
        """
        Select the highest-importance messages within the token budget.

        Args:
            messages: List of messages to select from
            available_tokens: Maximum tokens available for selection

        Returns:
            Selected messages, in original conversation order
        """
        # Calculate importance scores
        scored_messages = [
            (self._calculate_importance_score(message), message)
            for message in messages
        ]

        # Sort by importance (highest first); key compares scores only,
        # so equal-score messages keep their original relative order.
        scored_messages.sort(key=lambda x: x[0], reverse=True)

        # Greedily select until token limit is reached
        selected = []
        current_tokens = 0

        for score, message in scored_messages:
            msg_tokens = self.get_token_count(message)
            if current_tokens + msg_tokens <= available_tokens:
                selected.append(message)
                current_tokens += msg_tokens

        # Return in original order
        selected.sort(key=lambda x: x.index)
        return selected

    def _calculate_importance_score(self, message: CLIMessage) -> float:
        """Calculate importance score for a message.

        The score is the product of: message priority, message weight, a
        length factor (capped at 1.5x), a reply bonus, a role factor, and
        a recency factor that grows with the message index.
        """
        base_score = message.priority
        weight = message.weight

        # Factor in message length (longer messages might be more important)
        token_count = message.token_count
        length_factor = min(1.5, token_count / 100)  # Cap at 1.5x

        # Factor in replies (messages with replies might be more important).
        # Bug fix: reply_to_id == 0 is a valid reply (to the first message)
        # but is falsy, so a truthiness test wrongly denied the bonus to
        # the second message in the conversation — compare against None.
        reply_factor = 1.2 if message.reply_to_id is not None else 1.0

        # Factor in user vs assistant messages
        weights = self.config.weights
        role_factor = 1.0
        if message.role == 'user':
            role_factor = weights.get('user_messages', 1.0)
        elif message.role == 'system':
            role_factor = 1.5  # System messages are usually important

        # Factor in recency (newer messages get higher scores)
        recency_weight = weights.get('recency', 1.0)
        # Simple recency factor based on index
        recency_factor = 1.0 + (message.index * 0.01 * recency_weight)

        return base_score * weight * length_factor * reply_factor * role_factor * recency_factor
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class PruningStrategyFactory:
    """Registry and factory mapping strategy names to strategy classes."""

    _strategies: Dict[str, Type[PruningStrategy]] = {
        'oldest_first': OldestFirstStrategy,
        'importance_based': ImportanceBasedStrategy,
    }

    @classmethod
    def create(cls, strategy_name: str, config: PruningConfig) -> PruningStrategy:
        """
        Instantiate the strategy registered under *strategy_name*.

        Args:
            strategy_name: Registered name ('oldest_first', 'importance_based')
            config: Configuration handed to the strategy constructor

        Returns:
            PruningStrategy instance

        Raises:
            ValueError: If no strategy is registered under that name
        """
        if strategy_name not in cls._strategies:
            available = list(cls._strategies.keys())
            raise ValueError(
                f"Unknown pruning strategy: {strategy_name}. "
                f"Available strategies: {available}"
            )
        return cls._strategies[strategy_name](config)

    @classmethod
    def register_strategy(cls, name: str, strategy_class: Type[PruningStrategy]):
        """Register a custom pruning strategy under *name*."""
        cls._strategies[name] = strategy_class

    @classmethod
    def available_strategies(cls) -> List[str]:
        """Return the names of every registered strategy."""
        return list(cls._strategies.keys())
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
class PruningOrchestrator:
    """
    Orchestrates the pruning process by applying strategies with token budgets.

    Handles preserving recent messages, calculating available token budgets,
    and coordinating strategy application.
    """

    def __init__(self, strategy_config: Dict):
        """
        Initialize the orchestrator with strategy configuration.

        Args:
            strategy_config: Configuration dictionary containing:
                - pruning_method: Strategy name to use
                - preserve_recent_messages: Number of recent messages to preserve
                - max_context_tokens: Maximum tokens allowed
                - weights: Weighting configuration for strategies
        """
        self.strategy_config = strategy_config
        self.pruning_config = PruningConfig.from_strategy(strategy_config)

    def apply_pruning(
        self,
        messages: List[CLIMessage],
        summaries: Optional[List[dict]] = None,
        summary_tokens: int = 0,
    ) -> List[CLIMessage]:
        """
        Apply the configured pruning strategy to reduce context size.

        Args:
            messages: List of messages to prune
            summaries: List of summaries (for token calculation)
            summary_tokens: Pre-calculated summary token count

        Returns:
            List of messages to include in context (with included=True),
            sorted by message index.
        """
        cfg = self.strategy_config
        method_name = cfg.get('pruning_method', 'oldest_first')
        keep_recent = cfg.get('preserve_recent_messages', 5)
        token_limit = cfg.get('max_context_tokens', 8000)

        # Partition: the newest `keep_recent` messages are always protected.
        if 0 < keep_recent < len(messages):
            prunable = messages[:-keep_recent]
            protected = messages[-keep_recent:]
        else:
            prunable = []
            protected = messages

        # Budget remaining after the protected tail and existing summaries.
        reserved = summary_tokens + sum(m.token_count for m in protected)
        budget = token_limit - reserved

        if budget <= 0 or not prunable:
            # Nothing older can fit; flag every prunable message as dropped.
            for msg in prunable:
                msg.included = False
            return protected

        # Run the configured strategy over the older messages; an unknown
        # method name degrades gracefully to oldest_first.
        try:
            strategy = PruningStrategyFactory.create(method_name, self.pruning_config)
            kept_older = strategy.select_messages(prunable, budget)
        except ValueError:
            strategy = PruningStrategyFactory.create('oldest_first', self.pruning_config)
            kept_older = strategy.select_messages(prunable, budget)

        # Record inclusion flags on the older messages by index membership.
        kept_indices = {m.index for m in kept_older}
        for msg in prunable:
            msg.included = msg.index in kept_indices

        # Merge survivors with the protected tail in original order.
        return sorted(kept_older + protected, key=lambda m: m.index)
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Token estimation utilities for CLI context management.
|
|
3
|
+
|
|
4
|
+
Uses tiktoken for accurate token counting with fallback to character-based estimation.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import List, Optional, TYPE_CHECKING
|
|
8
|
+
from functools import lru_cache
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
import tiktoken
|
|
12
|
+
TIKTOKEN_AVAILABLE = True
|
|
13
|
+
except ImportError:
|
|
14
|
+
TIKTOKEN_AVAILABLE = False
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from .message import CLIMessage
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@lru_cache(maxsize=8)
def get_encoding_for_model(model: str = 'gpt-4') -> Optional[object]:
    """
    Resolve the tiktoken encoding appropriate for *model*.

    Models absent from tiktoken's model map fall back to ``cl100k_base``,
    the encoding used by most modern models. Results are memoized.

    Args:
        model: Model name to get encoding for

    Returns:
        Tiktoken encoding object, or None if tiktoken is unavailable or
        the lookup fails for any reason.
    """
    if not TIKTOKEN_AVAILABLE:
        return None

    try:
        # cl100k_base is the default for any model tiktoken doesn't know.
        encoding_name = tiktoken.model.MODEL_TO_ENCODING.get(
            model.lower(), 'cl100k_base'
        )
        return tiktoken.get_encoding(encoding_name)
    except Exception:
        # Best-effort: callers treat None as "use the character heuristic".
        return None
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def estimate_tokens(text: str, model: str = 'gpt-4') -> int:
    """
    Estimate the token count of *text* using tiktoken when available.

    Falls back to a character-based heuristic (~4 chars/token) if tiktoken
    is missing or encoding fails.

    Args:
        text: Text to estimate tokens for
        model: Model name for encoding selection

    Returns:
        Estimated token count (0 for empty or non-string input)
    """
    if not text or not isinstance(text, str):
        return 0

    if TIKTOKEN_AVAILABLE:
        encoding = get_encoding_for_model(model)
        if encoding:
            try:
                return len(encoding.encode(text))
            except Exception:
                # Fall through to the heuristic below.
                pass

    # Heuristic: roughly 4 characters per token for GPT-family models.
    return max(1, len(text) // 4)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def calculate_total_tokens(
    messages: List['CLIMessage'],
    summaries: Optional[List[dict]] = None,
    include_only: bool = True
) -> int:
    """
    Calculate total tokens from messages and summaries.

    Args:
        messages: List of CLIMessage objects
        summaries: Optional list of summary dictionaries with 'token_count' key
        include_only: If True, only count messages where included=True

    Returns:
        Total token count
    """
    # Count every message unless include_only filters out excluded ones.
    total = sum(
        msg.token_count
        for msg in messages
        if not include_only or msg.included
    )

    # Summaries contribute their pre-computed token counts (0 if missing).
    if summaries:
        total += sum(entry.get('token_count', 0) for entry in summaries)

    return total
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def estimate_message_tokens(role: str, content: str, model: str = 'gpt-4') -> int:
    """
    Estimate tokens for a chat message including role overhead.

    Chat-style messages carry extra tokens for the role name and message
    framing, so this adds a small fixed overhead on top of the content.

    Args:
        role: Message role (user, assistant, system)
        content: Message content
        model: Model name for encoding

    Returns:
        Estimated token count including overhead
    """
    # Most chat models add roughly 4 tokens of per-message framing.
    framing_overhead = 4

    return (
        estimate_tokens(content, model)
        + estimate_tokens(role, model)  # role name itself (typically 1-2 tokens)
        + framing_overhead
    )
|