aider-ce 0.88.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aider/__init__.py +20 -0
- aider/__main__.py +4 -0
- aider/_version.py +34 -0
- aider/analytics.py +258 -0
- aider/args.py +1056 -0
- aider/args_formatter.py +228 -0
- aider/change_tracker.py +133 -0
- aider/coders/__init__.py +36 -0
- aider/coders/agent_coder.py +2166 -0
- aider/coders/agent_prompts.py +104 -0
- aider/coders/architect_coder.py +48 -0
- aider/coders/architect_prompts.py +40 -0
- aider/coders/ask_coder.py +9 -0
- aider/coders/ask_prompts.py +35 -0
- aider/coders/base_coder.py +3613 -0
- aider/coders/base_prompts.py +87 -0
- aider/coders/chat_chunks.py +64 -0
- aider/coders/context_coder.py +53 -0
- aider/coders/context_prompts.py +75 -0
- aider/coders/editblock_coder.py +657 -0
- aider/coders/editblock_fenced_coder.py +10 -0
- aider/coders/editblock_fenced_prompts.py +143 -0
- aider/coders/editblock_func_coder.py +141 -0
- aider/coders/editblock_func_prompts.py +27 -0
- aider/coders/editblock_prompts.py +175 -0
- aider/coders/editor_diff_fenced_coder.py +9 -0
- aider/coders/editor_diff_fenced_prompts.py +11 -0
- aider/coders/editor_editblock_coder.py +9 -0
- aider/coders/editor_editblock_prompts.py +21 -0
- aider/coders/editor_whole_coder.py +9 -0
- aider/coders/editor_whole_prompts.py +12 -0
- aider/coders/help_coder.py +16 -0
- aider/coders/help_prompts.py +46 -0
- aider/coders/patch_coder.py +706 -0
- aider/coders/patch_prompts.py +159 -0
- aider/coders/search_replace.py +757 -0
- aider/coders/shell.py +37 -0
- aider/coders/single_wholefile_func_coder.py +102 -0
- aider/coders/single_wholefile_func_prompts.py +27 -0
- aider/coders/udiff_coder.py +429 -0
- aider/coders/udiff_prompts.py +115 -0
- aider/coders/udiff_simple.py +14 -0
- aider/coders/udiff_simple_prompts.py +25 -0
- aider/coders/wholefile_coder.py +144 -0
- aider/coders/wholefile_func_coder.py +134 -0
- aider/coders/wholefile_func_prompts.py +27 -0
- aider/coders/wholefile_prompts.py +65 -0
- aider/commands.py +2173 -0
- aider/copypaste.py +72 -0
- aider/deprecated.py +126 -0
- aider/diffs.py +128 -0
- aider/dump.py +29 -0
- aider/editor.py +147 -0
- aider/exceptions.py +115 -0
- aider/format_settings.py +26 -0
- aider/gui.py +545 -0
- aider/help.py +163 -0
- aider/help_pats.py +19 -0
- aider/helpers/__init__.py +9 -0
- aider/helpers/similarity.py +98 -0
- aider/history.py +180 -0
- aider/io.py +1608 -0
- aider/linter.py +304 -0
- aider/llm.py +55 -0
- aider/main.py +1415 -0
- aider/mcp/__init__.py +174 -0
- aider/mcp/server.py +149 -0
- aider/mdstream.py +243 -0
- aider/models.py +1313 -0
- aider/onboarding.py +429 -0
- aider/openrouter.py +129 -0
- aider/prompts.py +56 -0
- aider/queries/tree-sitter-language-pack/README.md +7 -0
- aider/queries/tree-sitter-language-pack/arduino-tags.scm +5 -0
- aider/queries/tree-sitter-language-pack/c-tags.scm +9 -0
- aider/queries/tree-sitter-language-pack/chatito-tags.scm +16 -0
- aider/queries/tree-sitter-language-pack/clojure-tags.scm +7 -0
- aider/queries/tree-sitter-language-pack/commonlisp-tags.scm +122 -0
- aider/queries/tree-sitter-language-pack/cpp-tags.scm +15 -0
- aider/queries/tree-sitter-language-pack/csharp-tags.scm +26 -0
- aider/queries/tree-sitter-language-pack/d-tags.scm +26 -0
- aider/queries/tree-sitter-language-pack/dart-tags.scm +92 -0
- aider/queries/tree-sitter-language-pack/elisp-tags.scm +5 -0
- aider/queries/tree-sitter-language-pack/elixir-tags.scm +54 -0
- aider/queries/tree-sitter-language-pack/elm-tags.scm +19 -0
- aider/queries/tree-sitter-language-pack/gleam-tags.scm +41 -0
- aider/queries/tree-sitter-language-pack/go-tags.scm +42 -0
- aider/queries/tree-sitter-language-pack/java-tags.scm +20 -0
- aider/queries/tree-sitter-language-pack/javascript-tags.scm +88 -0
- aider/queries/tree-sitter-language-pack/lua-tags.scm +34 -0
- aider/queries/tree-sitter-language-pack/matlab-tags.scm +10 -0
- aider/queries/tree-sitter-language-pack/ocaml-tags.scm +115 -0
- aider/queries/tree-sitter-language-pack/ocaml_interface-tags.scm +98 -0
- aider/queries/tree-sitter-language-pack/pony-tags.scm +39 -0
- aider/queries/tree-sitter-language-pack/properties-tags.scm +5 -0
- aider/queries/tree-sitter-language-pack/python-tags.scm +14 -0
- aider/queries/tree-sitter-language-pack/r-tags.scm +21 -0
- aider/queries/tree-sitter-language-pack/racket-tags.scm +12 -0
- aider/queries/tree-sitter-language-pack/ruby-tags.scm +64 -0
- aider/queries/tree-sitter-language-pack/rust-tags.scm +60 -0
- aider/queries/tree-sitter-language-pack/solidity-tags.scm +43 -0
- aider/queries/tree-sitter-language-pack/swift-tags.scm +51 -0
- aider/queries/tree-sitter-language-pack/udev-tags.scm +20 -0
- aider/queries/tree-sitter-languages/README.md +24 -0
- aider/queries/tree-sitter-languages/c-tags.scm +9 -0
- aider/queries/tree-sitter-languages/c_sharp-tags.scm +46 -0
- aider/queries/tree-sitter-languages/cpp-tags.scm +15 -0
- aider/queries/tree-sitter-languages/dart-tags.scm +91 -0
- aider/queries/tree-sitter-languages/elisp-tags.scm +8 -0
- aider/queries/tree-sitter-languages/elixir-tags.scm +54 -0
- aider/queries/tree-sitter-languages/elm-tags.scm +19 -0
- aider/queries/tree-sitter-languages/fortran-tags.scm +15 -0
- aider/queries/tree-sitter-languages/go-tags.scm +30 -0
- aider/queries/tree-sitter-languages/haskell-tags.scm +3 -0
- aider/queries/tree-sitter-languages/hcl-tags.scm +77 -0
- aider/queries/tree-sitter-languages/java-tags.scm +20 -0
- aider/queries/tree-sitter-languages/javascript-tags.scm +88 -0
- aider/queries/tree-sitter-languages/julia-tags.scm +60 -0
- aider/queries/tree-sitter-languages/kotlin-tags.scm +27 -0
- aider/queries/tree-sitter-languages/matlab-tags.scm +10 -0
- aider/queries/tree-sitter-languages/ocaml-tags.scm +115 -0
- aider/queries/tree-sitter-languages/ocaml_interface-tags.scm +98 -0
- aider/queries/tree-sitter-languages/php-tags.scm +26 -0
- aider/queries/tree-sitter-languages/python-tags.scm +12 -0
- aider/queries/tree-sitter-languages/ql-tags.scm +26 -0
- aider/queries/tree-sitter-languages/ruby-tags.scm +64 -0
- aider/queries/tree-sitter-languages/rust-tags.scm +60 -0
- aider/queries/tree-sitter-languages/scala-tags.scm +65 -0
- aider/queries/tree-sitter-languages/typescript-tags.scm +41 -0
- aider/queries/tree-sitter-languages/zig-tags.scm +3 -0
- aider/reasoning_tags.py +82 -0
- aider/repo.py +621 -0
- aider/repomap.py +1174 -0
- aider/report.py +260 -0
- aider/resources/__init__.py +3 -0
- aider/resources/model-metadata.json +776 -0
- aider/resources/model-settings.yml +2068 -0
- aider/run_cmd.py +133 -0
- aider/scrape.py +293 -0
- aider/sendchat.py +242 -0
- aider/sessions.py +256 -0
- aider/special.py +203 -0
- aider/tools/__init__.py +72 -0
- aider/tools/command.py +105 -0
- aider/tools/command_interactive.py +122 -0
- aider/tools/delete_block.py +182 -0
- aider/tools/delete_line.py +155 -0
- aider/tools/delete_lines.py +184 -0
- aider/tools/extract_lines.py +341 -0
- aider/tools/finished.py +48 -0
- aider/tools/git_branch.py +129 -0
- aider/tools/git_diff.py +60 -0
- aider/tools/git_log.py +57 -0
- aider/tools/git_remote.py +53 -0
- aider/tools/git_show.py +51 -0
- aider/tools/git_status.py +46 -0
- aider/tools/grep.py +256 -0
- aider/tools/indent_lines.py +221 -0
- aider/tools/insert_block.py +288 -0
- aider/tools/list_changes.py +86 -0
- aider/tools/ls.py +93 -0
- aider/tools/make_editable.py +85 -0
- aider/tools/make_readonly.py +69 -0
- aider/tools/remove.py +91 -0
- aider/tools/replace_all.py +126 -0
- aider/tools/replace_line.py +173 -0
- aider/tools/replace_lines.py +217 -0
- aider/tools/replace_text.py +187 -0
- aider/tools/show_numbered_context.py +147 -0
- aider/tools/tool_utils.py +313 -0
- aider/tools/undo_change.py +95 -0
- aider/tools/update_todo_list.py +156 -0
- aider/tools/view.py +57 -0
- aider/tools/view_files_matching.py +141 -0
- aider/tools/view_files_with_symbol.py +129 -0
- aider/urls.py +17 -0
- aider/utils.py +456 -0
- aider/versioncheck.py +113 -0
- aider/voice.py +205 -0
- aider/waiting.py +38 -0
- aider/watch.py +318 -0
- aider/watch_prompts.py +12 -0
- aider/website/Gemfile +8 -0
- aider/website/_includes/blame.md +162 -0
- aider/website/_includes/get-started.md +22 -0
- aider/website/_includes/help-tip.md +5 -0
- aider/website/_includes/help.md +24 -0
- aider/website/_includes/install.md +5 -0
- aider/website/_includes/keys.md +4 -0
- aider/website/_includes/model-warnings.md +67 -0
- aider/website/_includes/multi-line.md +22 -0
- aider/website/_includes/python-m-aider.md +5 -0
- aider/website/_includes/recording.css +228 -0
- aider/website/_includes/recording.md +34 -0
- aider/website/_includes/replit-pipx.md +9 -0
- aider/website/_includes/works-best.md +1 -0
- aider/website/_sass/custom/custom.scss +103 -0
- aider/website/docs/config/adv-model-settings.md +2261 -0
- aider/website/docs/config/agent-mode.md +194 -0
- aider/website/docs/config/aider_conf.md +548 -0
- aider/website/docs/config/api-keys.md +90 -0
- aider/website/docs/config/dotenv.md +493 -0
- aider/website/docs/config/editor.md +127 -0
- aider/website/docs/config/mcp.md +95 -0
- aider/website/docs/config/model-aliases.md +104 -0
- aider/website/docs/config/options.md +890 -0
- aider/website/docs/config/reasoning.md +210 -0
- aider/website/docs/config.md +44 -0
- aider/website/docs/faq.md +384 -0
- aider/website/docs/git.md +76 -0
- aider/website/docs/index.md +47 -0
- aider/website/docs/install/codespaces.md +39 -0
- aider/website/docs/install/docker.md +57 -0
- aider/website/docs/install/optional.md +100 -0
- aider/website/docs/install/replit.md +8 -0
- aider/website/docs/install.md +115 -0
- aider/website/docs/languages.md +264 -0
- aider/website/docs/legal/contributor-agreement.md +111 -0
- aider/website/docs/legal/privacy.md +104 -0
- aider/website/docs/llms/anthropic.md +77 -0
- aider/website/docs/llms/azure.md +48 -0
- aider/website/docs/llms/bedrock.md +132 -0
- aider/website/docs/llms/cohere.md +34 -0
- aider/website/docs/llms/deepseek.md +32 -0
- aider/website/docs/llms/gemini.md +49 -0
- aider/website/docs/llms/github.md +111 -0
- aider/website/docs/llms/groq.md +36 -0
- aider/website/docs/llms/lm-studio.md +39 -0
- aider/website/docs/llms/ollama.md +75 -0
- aider/website/docs/llms/openai-compat.md +39 -0
- aider/website/docs/llms/openai.md +58 -0
- aider/website/docs/llms/openrouter.md +78 -0
- aider/website/docs/llms/other.md +117 -0
- aider/website/docs/llms/vertex.md +50 -0
- aider/website/docs/llms/warnings.md +10 -0
- aider/website/docs/llms/xai.md +53 -0
- aider/website/docs/llms.md +54 -0
- aider/website/docs/more/analytics.md +127 -0
- aider/website/docs/more/edit-formats.md +116 -0
- aider/website/docs/more/infinite-output.md +165 -0
- aider/website/docs/more-info.md +8 -0
- aider/website/docs/recordings/auto-accept-architect.md +31 -0
- aider/website/docs/recordings/dont-drop-original-read-files.md +35 -0
- aider/website/docs/recordings/index.md +21 -0
- aider/website/docs/recordings/model-accepts-settings.md +69 -0
- aider/website/docs/recordings/tree-sitter-language-pack.md +80 -0
- aider/website/docs/repomap.md +112 -0
- aider/website/docs/scripting.md +100 -0
- aider/website/docs/sessions.md +203 -0
- aider/website/docs/troubleshooting/aider-not-found.md +24 -0
- aider/website/docs/troubleshooting/edit-errors.md +76 -0
- aider/website/docs/troubleshooting/imports.md +62 -0
- aider/website/docs/troubleshooting/models-and-keys.md +54 -0
- aider/website/docs/troubleshooting/support.md +79 -0
- aider/website/docs/troubleshooting/token-limits.md +96 -0
- aider/website/docs/troubleshooting/warnings.md +12 -0
- aider/website/docs/troubleshooting.md +11 -0
- aider/website/docs/usage/browser.md +57 -0
- aider/website/docs/usage/caching.md +49 -0
- aider/website/docs/usage/commands.md +133 -0
- aider/website/docs/usage/conventions.md +119 -0
- aider/website/docs/usage/copypaste.md +121 -0
- aider/website/docs/usage/images-urls.md +48 -0
- aider/website/docs/usage/lint-test.md +118 -0
- aider/website/docs/usage/modes.md +211 -0
- aider/website/docs/usage/not-code.md +179 -0
- aider/website/docs/usage/notifications.md +87 -0
- aider/website/docs/usage/tips.md +79 -0
- aider/website/docs/usage/tutorials.md +30 -0
- aider/website/docs/usage/voice.md +121 -0
- aider/website/docs/usage/watch.md +294 -0
- aider/website/docs/usage.md +102 -0
- aider/website/share/index.md +101 -0
- aider_ce-0.88.20.dist-info/METADATA +187 -0
- aider_ce-0.88.20.dist-info/RECORD +279 -0
- aider_ce-0.88.20.dist-info/WHEEL +5 -0
- aider_ce-0.88.20.dist-info/entry_points.txt +2 -0
- aider_ce-0.88.20.dist-info/licenses/LICENSE.txt +202 -0
- aider_ce-0.88.20.dist-info/top_level.txt +1 -0

aider/coders/base_coder.py
@@ -0,0 +1,3613 @@
+#!/usr/bin/env python
+
+import asyncio
+import base64
+import hashlib
+import json
+import locale
+import math
+import mimetypes
+import os
+import platform
+import re
+import sys
+import threading
+import time
+import traceback
+import weakref
+from collections import defaultdict
+from datetime import datetime
+
+# Optional dependency: used to convert locale codes (eg ``en_US``)
+# into human-readable language names (eg ``English``).
+try:
+    from babel import Locale  # type: ignore
+except ImportError:  # Babel not installed – we will fall back to a small mapping
+    Locale = None
+from json.decoder import JSONDecodeError
+from pathlib import Path
+from typing import List
+
+import httpx
+from litellm import experimental_mcp_client
+from litellm.types.utils import (
+    ChatCompletionMessageToolCall,
+    Choices,
+    Function,
+    Message,
+    ModelResponse,
+)
+from prompt_toolkit.patch_stdout import patch_stdout
+from rich.console import Console
+
+from aider import __version__, models, prompts, urls, utils
+from aider.analytics import Analytics
+from aider.commands import Commands, SwitchCoder
+from aider.exceptions import LiteLLMExceptions
+from aider.history import ChatSummary
+from aider.io import ConfirmGroup, InputOutput
+from aider.linter import Linter
+from aider.llm import litellm
+from aider.mcp.server import LocalServer
+from aider.models import RETRY_TIMEOUT
+from aider.reasoning_tags import (
+    REASONING_TAG,
+    format_reasoning_content,
+    remove_reasoning_content,
+    replace_reasoning_tags,
+)
+from aider.repo import ANY_GIT_ERROR, GitRepo
+from aider.repomap import RepoMap
+from aider.run_cmd import run_cmd
+from aider.sessions import SessionManager
+from aider.utils import format_tokens, is_image_file
+
+from ..dump import dump  # noqa: F401
+from .chat_chunks import ChatChunks
+
+
+class UnknownEditFormat(ValueError):
+    def __init__(self, edit_format, valid_formats):
+        self.edit_format = edit_format
+        self.valid_formats = valid_formats
+        super().__init__(
+            f"Unknown edit format {edit_format}. Valid formats are: {', '.join(valid_formats)}"
+        )
+
+
+class MissingAPIKeyError(ValueError):
+    pass
+
+
+class FinishReasonLength(Exception):
+    pass
+
+
+def wrap_fence(name):
+    return f"<{name}>", f"</{name}>"
+
+
+all_fences = [
+    ("`" * 3, "`" * 3),
+    ("`" * 4, "`" * 4),  # LLMs ignore and revert to triple-backtick, causing #2879
+    wrap_fence("source"),
+    wrap_fence("code"),
+    wrap_fence("pre"),
+    wrap_fence("codeblock"),
+    wrap_fence("sourcecode"),
+]
+
+
+class Coder:
+    abs_fnames = None
+    abs_read_only_fnames = None
+    abs_read_only_stubs_fnames = None
+    repo = None
+    last_aider_commit_hash = None
+    aider_edited_files = None
+    last_asked_for_commit_time = 0
+    repo_map = None
+    functions = None
+    num_exhausted_context_windows = 0
+    num_malformed_responses = 0
+    last_keyboard_interrupt = None
+    num_reflections = 0
+    max_reflections = 3
+    num_tool_calls = 0
+    max_tool_calls = 25
+    edit_format = None
+    yield_stream = False
+    temperature = None
+    auto_lint = True
+    auto_test = False
+    test_cmd = None
+    lint_outcome = None
+    test_outcome = None
+    multi_response_content = ""
+    partial_response_content = ""
+    partial_response_tool_calls = []
+    commit_before_message = []
+    message_cost = 0.0
+    add_cache_headers = False
+    cache_warming_thread = None
+    num_cache_warming_pings = 0
+    suggest_shell_commands = True
+    detect_urls = True
+    ignore_mentions = None
+    chat_language = None
+    commit_language = None
+    file_watcher = None
+    mcp_servers = None
+    mcp_tools = None
+    run_one_completed = True
+    compact_context_completed = True
+    suppress_announcements_for_next_prompt = False
+    tool_reflection = False
+
+    # Context management settings (for all modes)
+    context_management_enabled = False  # Disabled by default except for agent mode
+    large_file_token_threshold = (
+        25000  # Files larger than this will be truncated when context management is enabled
+    )
+
+    @classmethod
+    async def create(
+        self,
+        main_model=None,
+        edit_format=None,
+        io=None,
+        from_coder=None,
+        summarize_from_coder=True,
+        args=None,
+        **kwargs,
+    ):
+        import aider.coders as coders
+
+        if not main_model:
+            if from_coder:
+                main_model = from_coder.main_model
+            else:
+                main_model = models.Model(models.DEFAULT_MODEL_NAME)
+
+        if edit_format == "code":
+            edit_format = None
+        if edit_format is None:
+            if from_coder:
+                edit_format = from_coder.edit_format
+            else:
+                edit_format = main_model.edit_format
+
+        if not io and from_coder:
+            io = from_coder.io
+
+        if from_coder:
+            use_kwargs = dict(from_coder.original_kwargs)  # copy orig kwargs
+
+            # If the edit format changes, we can't leave old ASSISTANT
+            # messages in the chat history. The old edit format will
+            # confused the new LLM. It may try and imitate it, disobeying
+            # the system prompt.
+            done_messages = from_coder.done_messages
+            if edit_format != from_coder.edit_format and done_messages and summarize_from_coder:
+                try:
+                    done_messages = await from_coder.summarizer.summarize_all(done_messages)
+                except ValueError:
+                    # If summarization fails, keep the original messages and warn the user
+                    io.tool_warning(
+                        "Chat history summarization failed, continuing with full history"
+                    )
+
+            # Bring along context from the old Coder
+            update = dict(
+                fnames=list(from_coder.abs_fnames),
+                read_only_fnames=list(from_coder.abs_read_only_fnames),  # Copy read-only files
+                read_only_stubs_fnames=list(
+                    from_coder.abs_read_only_stubs_fnames
+                ),  # Copy read-only stubs
+                done_messages=done_messages,
+                cur_messages=from_coder.cur_messages,
+                aider_commit_hashes=from_coder.aider_commit_hashes,
+                commands=from_coder.commands.clone(),
+                total_cost=from_coder.total_cost,
+                ignore_mentions=from_coder.ignore_mentions,
+                total_tokens_sent=from_coder.total_tokens_sent,
+                total_tokens_received=from_coder.total_tokens_received,
+                file_watcher=from_coder.file_watcher,
+            )
+            use_kwargs.update(update)  # override to complete the switch
+            use_kwargs.update(kwargs)  # override passed kwargs
+
+            kwargs = use_kwargs
+            from_coder.ok_to_warm_cache = False
+
+        for coder in coders.__all__:
+            if hasattr(coder, "edit_format") and coder.edit_format == edit_format:
+                res = coder(main_model, io, args=args, **kwargs)
+                await res.initialize_mcp_tools()
+                res.original_kwargs = dict(kwargs)
+                return res
+
+        valid_formats = [
+            str(c.edit_format)
+            for c in coders.__all__
+            if hasattr(c, "edit_format") and c.edit_format is not None
+        ]
+        raise UnknownEditFormat(edit_format, valid_formats)
+
+    async def clone(self, **kwargs):
+        new_coder = await Coder.create(from_coder=self, **kwargs)
+        return new_coder
+
+    def get_announcements(self):
+        lines = []
+        lines.append(f"Aider v{__version__}")
+
+        # Model
+        main_model = self.main_model
+        weak_model = main_model.weak_model
+
+        if weak_model is not main_model:
+            prefix = "Main model"
+        else:
+            prefix = "Model"
+
+        output = f"{prefix}: {main_model.name} with {self.edit_format} edit format"
+
+        # Check for thinking token budget
+        thinking_tokens = main_model.get_thinking_tokens()
+        if thinking_tokens:
+            output += f", {thinking_tokens} think tokens"
+
+        # Check for reasoning effort
+        reasoning_effort = main_model.get_reasoning_effort()
+        if reasoning_effort:
+            output += f", reasoning {reasoning_effort}"
+
+        if self.add_cache_headers or main_model.caches_by_default:
+            output += ", prompt cache"
+        if main_model.info.get("supports_assistant_prefill"):
+            output += ", infinite output"
+
+        lines.append(output)
+
+        if self.edit_format == "architect":
+            output = (
+                f"Editor model: {main_model.editor_model.name} with"
+                f" {main_model.editor_edit_format} edit format"
+            )
+            lines.append(output)
+
+        if weak_model is not main_model:
+            output = f"Weak model: {weak_model.name}"
+            lines.append(output)
+
+        # Repo
+        if self.repo:
+            rel_repo_dir = self.repo.get_rel_repo_dir()
+            num_files = len(self.repo.get_tracked_files())
+
+            lines.append(f"Git repo: {rel_repo_dir} with {num_files:,} files")
+            if num_files > 1000:
+                lines.append(
+                    "Warning: For large repos, consider using --subtree-only and .aiderignore"
+                )
+                lines.append(f"See: {urls.large_repos}")
+        else:
+            lines.append("Git repo: none")
+
+        # Repo-map
+        if self.repo_map:
+            map_tokens = self.repo_map.max_map_tokens
+            if map_tokens > 0:
+                refresh = self.repo_map.refresh
+                lines.append(f"Repo-map: using {map_tokens} tokens, {refresh} refresh")
+                max_map_tokens = self.main_model.get_repo_map_tokens() * 2
+                if map_tokens > max_map_tokens:
+                    lines.append(
+                        f"Warning: map-tokens > {max_map_tokens} is not recommended. Too much"
+                        " irrelevant code can confuse LLMs."
+                    )
+            else:
+                lines.append("Repo-map: disabled because map_tokens == 0")
+        else:
+            lines.append("Repo-map: disabled")
+
+        if self.mcp_tools:
+            mcp_servers = []
+            for server_name, server_tools in self.mcp_tools:
+                mcp_servers.append(server_name)
+            lines.append(f"MCP servers configured: {', '.join(mcp_servers)}")
+
+        for fname in self.abs_read_only_stubs_fnames:
+            rel_fname = self.get_rel_fname(fname)
+            lines.append(f"Added {rel_fname} to the chat (read-only stub).")
+
+        if self.done_messages:
+            lines.append("Restored previous conversation history.")
+
+        if self.io.multiline_mode:
+            lines.append("Multiline mode: Enabled. Enter inserts newline, Alt-Enter submits text")
+
+        return lines
+
+    ok_to_warm_cache = False
+
+    def __init__(
+        self,
+        main_model,
+        io,
+        args=None,
+        repo=None,
+        fnames=None,
+        add_gitignore_files=False,
+        read_only_fnames=None,
+        read_only_stubs_fnames=None,
+        show_diffs=False,
+        auto_commits=True,
+        dirty_commits=True,
+        dry_run=False,
+        map_tokens=1024,
+        verbose=False,
+        stream=True,
+        use_git=True,
+        cur_messages=None,
+        done_messages=None,
+        restore_chat_history=False,
+        auto_lint=True,
+        auto_test=False,
+        lint_cmds=None,
+        test_cmd=None,
+        aider_commit_hashes=None,
+        map_mul_no_files=8,
+        map_max_line_length=100,
+        commands=None,
+        summarizer=None,
+        total_cost=0.0,
+        analytics=None,
+        map_refresh="auto",
+        cache_prompts=False,
+        num_cache_warming_pings=0,
+        suggest_shell_commands=True,
+        chat_language=None,
+        commit_language=None,
+        detect_urls=True,
+        ignore_mentions=None,
+        total_tokens_sent=0,
+        total_tokens_received=0,
+        file_watcher=None,
+        auto_copy_context=False,
+        auto_accept_architect=True,
+        mcp_servers=None,
+        enable_context_compaction=False,
+        context_compaction_max_tokens=None,
+        context_compaction_summary_tokens=8192,
+        map_cache_dir=".",
+        repomap_in_memory=False,
+        preserve_todo_list=False,
+        linear_output=False,
+    ):
+        # initialize from args.map_cache_dir
+        self.map_cache_dir = map_cache_dir
+
+        # Fill in a dummy Analytics if needed, but it is never .enable()'d
+        self.analytics = analytics if analytics is not None else Analytics()
+
+        self.event = self.analytics.event
+        self.chat_language = chat_language
+        self.commit_language = commit_language
+        self.commit_before_message = []
+        self.aider_commit_hashes = set()
+        self.rejected_urls = set()
+        self.abs_root_path_cache = {}
+
+        self.auto_copy_context = auto_copy_context
+        self.auto_accept_architect = auto_accept_architect
+        self.preserve_todo_list = preserve_todo_list
+
+        self.ignore_mentions = ignore_mentions
+        if not self.ignore_mentions:
+            self.ignore_mentions = set()
+
+        self.file_watcher = file_watcher
+        if self.file_watcher:
+            self.file_watcher.coder = self
+
+        self.suggest_shell_commands = suggest_shell_commands
+        self.detect_urls = detect_urls
+        self.args = args
+
+        self.num_cache_warming_pings = num_cache_warming_pings
+        self.mcp_servers = mcp_servers
+        self.enable_context_compaction = enable_context_compaction
+
+        self.context_compaction_max_tokens = context_compaction_max_tokens
+        self.context_compaction_summary_tokens = context_compaction_summary_tokens
+
+        if not fnames:
+            fnames = []
+
+        if io is None:
+            io = InputOutput()
+
+        if aider_commit_hashes:
+            self.aider_commit_hashes = aider_commit_hashes
+        else:
+            self.aider_commit_hashes = set()
+
+        self.chat_completion_call_hashes = []
+        self.chat_completion_response_hashes = []
+        self.need_commit_before_edits = set()
+
+        self.total_cost = total_cost
+        self.total_tokens_sent = total_tokens_sent
+        self.total_tokens_received = total_tokens_received
+        self.message_tokens_sent = 0
+        self.message_tokens_received = 0
+
+        self.verbose = verbose
+        self.abs_fnames = set()
+        self.abs_read_only_fnames = set()
+        self.add_gitignore_files = add_gitignore_files
+        self.abs_read_only_stubs_fnames = set()
+
+        if cur_messages:
+            self.cur_messages = cur_messages
+        else:
+            self.cur_messages = []
+
+        if done_messages:
+            self.done_messages = done_messages
+        else:
+            self.done_messages = []
+
+        self.io = io
+        self.io.coder = weakref.ref(self)
+
+        self.shell_commands = []
+        self.partial_response_tool_calls = []
+
+        if not auto_commits:
+            dirty_commits = False
+
+        self.auto_commits = auto_commits
+        self.dirty_commits = dirty_commits
+
+        self.dry_run = dry_run
+        self.pretty = self.io.pretty
+        self.linear_output = linear_output
+        self.main_model = main_model
+
+        # Set the reasoning tag name based on model settings or default
+        self.reasoning_tag_name = (
+            self.main_model.reasoning_tag if self.main_model.reasoning_tag else REASONING_TAG
+        )
+
+        self.stream = stream and main_model.streaming
+
+        if cache_prompts and self.main_model.cache_control:
+            self.add_cache_headers = True
+
+        self.show_diffs = show_diffs
+
+        self.commands = commands or Commands(self.io, self)
+        self.commands.coder = self
+
+        self.data_cache = {"repo": {"last_key": ""}, "relative_files": None}
+
+        self.repo = repo
+        if use_git and self.repo is None:
+            try:
+                self.repo = GitRepo(
+                    self.io,
+                    fnames,
+                    None,
+                    models=main_model.commit_message_models(),
+                )
+            except FileNotFoundError:
+                pass
+
+        if self.repo:
+            self.root = self.repo.root
+
+        for fname in fnames:
+            fname = Path(fname)
+            if self.repo and self.repo.git_ignored_file(fname) and not self.add_gitignore_files:
+                self.io.tool_warning(f"Skipping {fname} that matches gitignore spec.")
+                continue
+
+            if self.repo and self.repo.ignored_file(fname):
+                self.io.tool_warning(f"Skipping {fname} that matches aiderignore spec.")
+                continue
+
+            if not fname.exists():
+                if utils.touch_file(fname):
+                    self.io.tool_output(f"Creating empty file {fname}")
+                else:
+                    self.io.tool_warning(f"Can not create {fname}, skipping.")
+                    continue
+
+            if not fname.is_file():
+                self.io.tool_warning(f"Skipping {fname} that is not a normal file.")
+                continue
+
+            fname = str(fname.resolve())
+
+            self.abs_fnames.add(fname)
+            self.check_added_files()
+
+        if not self.repo:
+            self.root = utils.find_common_root(self.abs_fnames)
+
+        if read_only_fnames:
+            self.abs_read_only_fnames = set()
+            for fname in read_only_fnames:
+                abs_fname = self.abs_root_path(fname)
+                if os.path.exists(abs_fname):
+                    self.abs_read_only_fnames.add(abs_fname)
+                else:
+                    self.io.tool_warning(f"Error: Read-only file {fname} does not exist. Skipping.")
+
+        if read_only_stubs_fnames:
+            self.abs_read_only_stubs_fnames = set()
+            for fname in read_only_stubs_fnames:
+                abs_fname = self.abs_root_path(fname)
+                if os.path.exists(abs_fname):
+                    self.abs_read_only_stubs_fnames.add(abs_fname)
+                else:
+                    self.io.tool_warning(
+                        f"Error: Read-only (stub) file {fname} does not exist. Skipping."
+                    )
+
+        if map_tokens is None:
+            use_repo_map = main_model.use_repo_map
+            map_tokens = 1024
+        else:
+            use_repo_map = map_tokens > 0
+
+        max_inp_tokens = self.main_model.info.get("max_input_tokens") or 0
+
+        has_map_prompt = hasattr(self, "gpt_prompts") and self.gpt_prompts.repo_content_prefix
+
+        if use_repo_map and self.repo and has_map_prompt:
+            self.repo_map = RepoMap(
+                map_tokens,
+                self.map_cache_dir,
+                self.main_model,
+                io,
+                self.gpt_prompts.repo_content_prefix,
+                self.verbose,
+                max_inp_tokens,
+                map_mul_no_files=map_mul_no_files,
+                refresh=map_refresh,
+                max_code_line_length=map_max_line_length,
+                repo_root=self.root,
+                use_memory_cache=repomap_in_memory,
+            )
+
+        self.summarizer = summarizer or ChatSummary(
+            [self.main_model.weak_model, self.main_model],
+            self.main_model.max_chat_history_tokens,
+        )
+
+        self.summarizer_thread = None
+        self.summarized_done_messages = []
+        self.summarizing_messages = None
+
+        self.files_edited_by_tools = set()
+
+        if not self.done_messages and restore_chat_history:
+            history_md = self.io.read_text(self.io.chat_history_file)
+            if history_md:
+                self.done_messages = utils.split_chat_history_markdown(history_md)
+                self.summarize_start()
+
+        # Linting and testing
+        self.linter = Linter(root=self.root, encoding=io.encoding)
+        self.auto_lint = auto_lint
+        self.setup_lint_cmds(lint_cmds)
+        self.lint_cmds = lint_cmds
+        self.auto_test = auto_test
+        self.test_cmd = test_cmd
+
+        # Clean up todo list file on startup unless preserve_todo_list is True
+        if not getattr(self, "preserve_todo_list", False):
+            todo_file_path = ".aider.todo.txt"
+            abs_path = self.abs_root_path(todo_file_path)
+            if os.path.isfile(abs_path):
+                try:
+                    os.remove(abs_path)
+                    if self.verbose:
+                        self.io.tool_output(f"Removed existing todo list file: {todo_file_path}")
+                except Exception as e:
+                    self.io.tool_warning(f"Could not remove todo list file {todo_file_path}: {e}")
+
+        # validate the functions jsonschema
+        if self.functions:
+            from jsonschema import Draft7Validator
+
+            for function in self.functions:
+                Draft7Validator.check_schema(function)
+
+            if self.verbose:
+                self.io.tool_output("JSON Schema:")
+                self.io.tool_output(json.dumps(self.functions, indent=4))
+
+    def setup_lint_cmds(self, lint_cmds):
+        if not lint_cmds:
+            return
+        for lang, cmd in lint_cmds.items():
+            self.linter.set_linter(lang, cmd)
+
+    def show_announcements(self):
+        bold = True
+        for line in self.get_announcements():
+            self.io.tool_output(line, bold=bold)
+            bold = False
+
+    def add_rel_fname(self, rel_fname):
+        self.abs_fnames.add(self.abs_root_path(rel_fname))
+        self.check_added_files()
+
+    def drop_rel_fname(self, fname):
+        abs_fname = self.abs_root_path(fname)
+        if abs_fname in self.abs_fnames:
+            self.abs_fnames.remove(abs_fname)
+            return True
+
+    def abs_root_path(self, path):
+        key = path
+        if key in self.abs_root_path_cache:
+            return self.abs_root_path_cache[key]
+
+        res = Path(self.root) / path
+        res = utils.safe_abs_path(res)
+        self.abs_root_path_cache[key] = res
+        return res
+
+    fences = all_fences
+    fence = fences[0]
+
+    def show_pretty(self):
+        if not self.pretty:
+            return False
+
+        # only show pretty output if fences are the normal triple-backtick
+        if self.fence[0][0] != "`":
+            return False
+
+        return True
+
+    def get_abs_fnames_content(self):
+        for fname in list(self.abs_fnames):
+            content = self.io.read_text(fname)
+
+            if content is None:
+                relative_fname = self.get_rel_fname(fname)
+                self.io.tool_warning(f"Dropping {relative_fname} from the chat.")
+                self.abs_fnames.remove(fname)
+            else:
+                yield fname, content
+
+    def choose_fence(self):
+        all_content = ""
+        for _fname, content in self.get_abs_fnames_content():
+            all_content += content + "\n"
+        for _fname in self.abs_read_only_fnames:
+            content = self.io.read_text(_fname)
+            if content is not None:
+                all_content += content + "\n"
+        for _fname in self.abs_read_only_stubs_fnames:
+            content = self.io.read_text(_fname)
+            if content is not None:
+                all_content += content + "\n"
+
+        lines = all_content.splitlines()
+        good = False
+        for fence_open, fence_close in self.fences:
+            if any(line.startswith(fence_open) or line.startswith(fence_close) for line in lines):
+                continue
+            good = True
+            break
+
+        if good:
+            self.fence = (fence_open, fence_close)
+        else:
+            self.fence = self.fences[0]
+            self.io.tool_warning(
+                "Unable to find a fencing strategy! Falling back to:"
+                f" {self.fence[0]}...{self.fence[1]}"
+            )
+
+        return
+
+    def get_files_content(self, fnames=None):
+        if not fnames:
+            fnames = self.abs_fnames
+
+        prompt = ""
+        for fname, content in self.get_abs_fnames_content():
+            if not is_image_file(fname):
+                relative_fname = self.get_rel_fname(fname)
+                prompt += "\n"
+                prompt += relative_fname
+                prompt += f"\n{self.fence[0]}\n"
+
+                # Apply context management if enabled for large files
+                if self.context_management_enabled:
+                    # Calculate tokens for this file
+                    file_tokens = self.main_model.token_count(content)
+
+                    if file_tokens > self.large_file_token_threshold:
+                        # Truncate the file content
+                        lines = content.splitlines()
+
+                        # Keep the first and last parts of the file with a marker in between
+                        keep_lines = (
+                            self.large_file_token_threshold // 40
+                        )  # Rough estimate of tokens per line
+                        first_chunk = lines[: keep_lines // 2]
+                        last_chunk = lines[-(keep_lines // 2) :]
+
+                        truncated_content = "\n".join(first_chunk)
+                        truncated_content += (
+                            f"\n\n... [File truncated due to size ({file_tokens} tokens). Use"
+                            " /context-management to toggle truncation off] ...\n\n"
+                        )
+                        truncated_content += "\n".join(last_chunk)
+
+                        # Add message about truncation
+                        self.io.tool_output(
+                            f"⚠️ '{relative_fname}' is very large ({file_tokens} tokens). "
+                            "Use /context-management to toggle truncation off if needed."
+                        )
+
+                        prompt += truncated_content
+                    else:
+                        prompt += content
+                else:
+                    prompt += content
+
+                prompt += f"{self.fence[1]}\n"
+
+        return prompt
+
+    def get_read_only_files_content(self):
+        prompt = ""
+        # Handle regular read-only files
+        for fname in self.abs_read_only_fnames:
+            content = self.io.read_text(fname)
+            if content is not None and not is_image_file(fname):
+                relative_fname = self.get_rel_fname(fname)
+                prompt += "\n"
+                prompt += relative_fname
+                prompt += f"\n{self.fence[0]}\n"
+
+                # Apply context management if enabled for large files (same as get_files_content)
+                if self.context_management_enabled:
+                    # Calculate tokens for this file
+                    file_tokens = self.main_model.token_count(content)
+
+                    if file_tokens > self.large_file_token_threshold:
+                        # Truncate the file content
+                        lines = content.splitlines()
+
+                        # Keep the first and last parts of the file with a marker in between
+                        keep_lines = (
+                            self.large_file_token_threshold // 40
+                        )  # Rough estimate of tokens per line
+                        first_chunk = lines[: keep_lines // 2]
+                        last_chunk = lines[-(keep_lines // 2) :]
+
+                        truncated_content = "\n".join(first_chunk)
+                        truncated_content += (
+                            f"\n\n... [File truncated due to size ({file_tokens} tokens). Use"
+                            " /context-management to toggle truncation off] ...\n\n"
+                        )
+                        truncated_content += "\n".join(last_chunk)
+
+                        # Add message about truncation
+                        self.io.tool_output(
+                            f"⚠️ '{relative_fname}' is very large ({file_tokens} tokens). "
+                            "Use /context-management to toggle truncation off if needed."
+                        )
+
+                        prompt += truncated_content
+                    else:
+                        prompt += content
+                else:
+                    prompt += content
+
+                prompt += f"{self.fence[1]}\n"
+
+        # Handle stub files
+        for fname in self.abs_read_only_stubs_fnames:
+            if not is_image_file(fname):
+                relative_fname = self.get_rel_fname(fname)
+                prompt += "\n"
+                prompt += f"{relative_fname} (stub)"
+                prompt += f"\n{self.fence[0]}\n"
+                stub = self.get_file_stub(fname)
+                prompt += stub
+                prompt += f"{self.fence[1]}\n"
+        return prompt
+
+    def get_cur_message_text(self):
+        text = ""
+        for msg in self.cur_messages:
+            # For some models the content is None if the message
+            # contains tool calls.
+            content = msg["content"] or ""
+            text += content + "\n"
+        return text
+
+    def get_ident_mentions(self, text):
+        # Split the string on any character that is not alphanumeric
+        # \W+ matches one or more non-word characters (equivalent to [^a-zA-Z0-9_]+)
+        words = set(re.split(r"\W+", text))
+        return words
+
+    def get_ident_filename_matches(self, idents):
+        all_fnames = defaultdict(set)
+        for fname in self.get_all_relative_files():
+            # Skip empty paths or just '.'
+            if not fname or fname == ".":
+                continue
+
+            try:
+                # Handle dotfiles properly
+                path = Path(fname)
+                base = path.stem.lower()  # Use stem instead of with_suffix("").name
+                if len(base) >= 5:
+                    all_fnames[base].add(fname)
+            except ValueError:
+                # Skip paths that can't be processed
+                continue
+
+        matches = set()
+        for ident in idents:
+            if len(ident) < 5:
+                continue
+            matches.update(all_fnames[ident.lower()])
+
+        return matches
+
+    def get_repo_map(self, force_refresh=False):
+        if not self.repo_map or not self.repo:
+            return
+
+        self.io.update_spinner("Updating repo map")
+
+        cur_msg_text = self.get_cur_message_text()
+        staged_files_hash = hash(str([item.a_path for item in self.repo.repo.index.diff("HEAD")]))
+        read_only_count = len(set(self.abs_read_only_fnames)) + len(
+            set(self.abs_read_only_stubs_fnames)
+        )
+        self.data_cache["repo"]["mentioned_idents"] = self.get_ident_mentions(cur_msg_text)
+
+        if (
+            staged_files_hash != self.data_cache["repo"]["last_key"]
+            or read_only_count != self.data_cache["repo"]["read_only_count"]
+        ):
+            self.data_cache["repo"]["last_key"] = staged_files_hash
+
+            mentioned_idents = self.data_cache["repo"]["mentioned_idents"]
+            mentioned_fnames = self.get_file_mentions(cur_msg_text)
+            mentioned_fnames.update(self.get_ident_filename_matches(mentioned_idents))
+
+            all_abs_files = set(self.get_all_abs_files())
+
+            # Exclude metadata/docs from repo map inputs to reduce parsing overhead
+            def _include_in_map(abs_path):
+                try:
+                    rel = self.get_rel_fname(abs_path)
+                except Exception:
+                    rel = str(abs_path)
+                parts = Path(rel).parts
+                if ".meta" in parts or ".docs" in parts:
+                    return False
+                if ".min." in parts[-1]:
+                    return False
+                if self.repo.git_ignored_file(abs_path):
+                    return False
+                return True
+
+            all_abs_files = {p for p in all_abs_files if _include_in_map(p)}
+            repo_abs_read_only_fnames = set(self.abs_read_only_fnames) & all_abs_files
+            repo_abs_read_only_stubs_fnames = set(self.abs_read_only_stubs_fnames) & all_abs_files
+            chat_files = (
+                set(self.abs_fnames) | repo_abs_read_only_fnames | repo_abs_read_only_stubs_fnames
+            )
+            other_files = all_abs_files - chat_files
+
+            self.data_cache["repo"].update(
+                {
+                    "chat_files": chat_files,
+                    "other_files": other_files,
+                    "mentioned_fnames": mentioned_fnames,
+                    "all_abs_files": all_abs_files,
+                    "read_only_count": len(set(self.abs_read_only_fnames)) + len(
+                        set(self.abs_read_only_stubs_fnames)
+                    ),
+                }
+            )
+
+        repo_content = self.repo_map.get_repo_map(
+            self.data_cache["repo"]["chat_files"],
+            self.data_cache["repo"]["other_files"],
+            mentioned_fnames=self.data_cache["repo"]["mentioned_fnames"],
+            mentioned_idents=self.data_cache["repo"]["mentioned_idents"],
+            force_refresh=force_refresh,
+        )
+
+        # fall back to global repo map if files in chat are disjoint from rest of repo
+        if not repo_content:
+            repo_content = self.repo_map.get_repo_map(
+                set(),
+                self.data_cache["repo"]["all_abs_files"],
+                mentioned_fnames=self.data_cache["repo"]["mentioned_fnames"],
+                mentioned_idents=self.data_cache["repo"]["mentioned_idents"],
+            )
+
+        # fall back to completely unhinted repo
+        if not repo_content:
+            repo_content = self.repo_map.get_repo_map(
+                set(),
+                self.data_cache["repo"]["all_abs_files"],
+            )
+
+        self.io.update_spinner(self.io.last_spinner_text)
+        return repo_content
+
+    def get_repo_messages(self):
+        repo_messages = []
+        repo_content = self.get_repo_map()
+        if repo_content:
+            repo_messages += [
+                dict(role="user", content=repo_content),
+                dict(
+                    role="assistant",
+                    content="Ok, I won't try and edit those files without asking first.",
+                ),
+            ]
+        return repo_messages
+
+    def get_readonly_files_messages(self):
+        readonly_messages = []
+
+        # Handle non-image files
+        read_only_content = self.get_read_only_files_content()
+        if read_only_content:
+            readonly_messages += [
+                dict(
+                    role="user", content=self.gpt_prompts.read_only_files_prefix + read_only_content
+                ),
+                dict(
+                    role="assistant",
+                    content="Ok, I will use these files as references.",
+                ),
+            ]
+
+        # Handle image files
+        images_message = self.get_images_message(
+            list(self.abs_read_only_fnames) + list(self.abs_read_only_stubs_fnames)
+        )
+        if images_message is not None:
+            readonly_messages += [
+                images_message,
+                dict(role="assistant", content="Ok, I will use these images as references."),
+            ]
+
+        return readonly_messages
+
+    def get_chat_files_messages(self):
+        chat_files_messages = []
+        if self.abs_fnames:
+            files_content = self.gpt_prompts.files_content_prefix
+            files_content += self.get_files_content()
+            files_reply = self.gpt_prompts.files_content_assistant_reply
+        elif self.gpt_prompts.files_no_full_files_with_repo_map:
+            files_content = self.gpt_prompts.files_no_full_files_with_repo_map
+            files_reply = self.gpt_prompts.files_no_full_files_with_repo_map_reply
+        else:
+            files_content = self.gpt_prompts.files_no_full_files
+            files_reply = "Ok."
+
+        if files_content:
+            chat_files_messages += [
+                dict(role="user", content=files_content),
+                dict(role="assistant", content=files_reply),
+            ]
+
+        images_message = self.get_images_message(self.abs_fnames)
+        if images_message is not None:
+            chat_files_messages += [
+                images_message,
+                dict(role="assistant", content="Ok."),
+            ]
+
+        return chat_files_messages
+
+    def get_images_message(self, fnames):
+        supports_images = self.main_model.info.get("supports_vision")
+        supports_pdfs = self.main_model.info.get("supports_pdf_input") or self.main_model.info.get(
+            "max_pdf_size_mb"
+        )
+
+        # https://github.com/BerriAI/litellm/pull/6928
+        supports_pdfs = supports_pdfs or "claude-3-5-sonnet-20241022" in self.main_model.name
+
+        if not (supports_images or supports_pdfs):
+            return None
+
+        image_messages = []
+        for fname in fnames:
+            if not is_image_file(fname):
+                continue
+
+            mime_type, _ = mimetypes.guess_type(fname)
+            if not mime_type:
+                continue
+
+            with open(fname, "rb") as image_file:
+                encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
+            image_url = f"data:{mime_type};base64,{encoded_string}"
+            rel_fname = self.get_rel_fname(fname)
+
+            if mime_type.startswith("image/") and supports_images:
+                image_messages += [
+                    {"type": "text", "text": f"Image file: {rel_fname}"},
+                    {"type": "image_url", "image_url": {"url": image_url, "detail": "high"}},
+                ]
+            elif mime_type == "application/pdf" and supports_pdfs:
+                image_messages += [
+                    {"type": "text", "text": f"PDF file: {rel_fname}"},
+                    {"type": "image_url", "image_url": image_url},
+                ]
+
+        if not image_messages:
+            return None
+
+        return {"role": "user", "content": image_messages}
+
+    async def run_stream(self, user_message):
+        self.io.user_input(user_message)
+        self.init_before_message()
+        async for chunk in self.send_message(user_message):
+            yield chunk
+
+    def init_before_message(self):
+        self.aider_edited_files = set()
+        self.reflected_message = None
+        self.num_reflections = 0
+        self.lint_outcome = None
+        self.test_outcome = None
+        self.shell_commands = []
+        self.message_cost = 0
+
+        if self.repo:
+            self.commit_before_message.append(self.repo.get_head_commit_sha())
+
+    async def run(self, with_message=None, preproc=True):
+        while self.io.confirmation_in_progress:
+            await asyncio.sleep(0.1)  # Yield control and wait briefly
+
+        if self.linear_output:
+            return await self._run_linear(with_message, preproc)
+
+        if self.io.prompt_session:
+            with patch_stdout(raw=True):
+                return await self._run_patched(with_message, preproc)
+        else:
+            return await self._run_patched(with_message, preproc)
+
+    async def _run_linear(self, with_message=None, preproc=True):
+        try:
+            if with_message:
+                self.io.user_input(with_message)
+                await self.run_one(with_message, preproc)
+                return self.partial_response_content
+
+            user_message = None
+            await self.io.cancel_input_task()
+            await self.io.cancel_output_task()
+
+            while True:
+                try:
+                    if self.commands.cmd_running:
+                        await asyncio.sleep(0.1)
+                        continue
+
+                    if not self.suppress_announcements_for_next_prompt:
+                        self.show_announcements()
+                        self.suppress_announcements_for_next_prompt = True
+
+                    await self.io.recreate_input()
+                    await self.io.input_task
+                    user_message = self.io.input_task.result()
+
+                    self.io.output_task = asyncio.create_task(self._generate(user_message, preproc))
+
+                    await self.io.output_task
+
+                    self.io.ring_bell()
+                    user_message = None
+                except KeyboardInterrupt:
+                    if self.io.input_task:
+                        self.io.set_placeholder("")
+                        await self.io.cancel_input_task()
+
+                    if self.io.output_task:
+                        await self.io.cancel_output_task()
+                        self.io.stop_spinner()
+
+                    self.keyboard_interrupt()
+                except (asyncio.CancelledError, IndexError):
+                    pass
+
+                self.auto_save_session()
+        except EOFError:
+            return
+        finally:
+            await self.io.cancel_input_task()
+            await self.io.cancel_output_task()
+
+    async def _run_patched(self, with_message=None, preproc=True):
+        try:
+            if with_message:
+                self.io.user_input(with_message)
+                await self.run_one(with_message, preproc)
+                return self.partial_response_content
+
+            user_message = None
+            self.user_message = ""
+            await self.io.cancel_input_task()
+            await self.io.cancel_output_task()
+
+            while True:
+                try:
+                    if (
+                        not self.io.confirmation_in_progress
+                        and not user_message
+                        and (
+                            not self.io.input_task
+                            or self.io.input_task.done()
+                            or self.io.input_task.cancelled()
+                        )
+                        and (not self.io.output_task or not self.io.placeholder)
+                    ):
+                        if not self.suppress_announcements_for_next_prompt:
+                            self.show_announcements()
+                            self.suppress_announcements_for_next_prompt = True
+
+                        # Stop spinner before showing announcements or getting input
+                        self.io.stop_spinner()
+                        self.copy_context()
+                        await self.io.recreate_input()
+
+                        if self.user_message:
+                            self.io.output_task = asyncio.create_task(
+                                self._generate(self.user_message, preproc)
+                            )
+
+                            self.user_message = ""
+                            # Start spinner for processing task
+                            self.io.start_spinner("Processing...")
+
+                    if self.commands.cmd_running:
+                        await asyncio.sleep(0.1)
+                        continue
+
+                    tasks = set()
+
+                    if self.io.output_task:
+                        if self.io.output_task.done():
+                            exception = self.io.output_task.exception()
+                            if exception:
+                                if isinstance(exception, SwitchCoder):
+                                    await self.io.output_task
+                        elif not self.io.output_task.done() and not self.io.output_task.cancelled():
+                            tasks.add(self.io.output_task)
+
+                    if (
+                        self.io.input_task
+                        and not self.io.input_task.done()
+                        and not self.io.input_task.cancelled()
+                    ):
+                        tasks.add(self.io.input_task)
+
+                    if tasks:
+                        done, pending = await asyncio.wait(
+                            tasks, return_when=asyncio.FIRST_COMPLETED
+                        )
+
+                        if self.io.input_task and self.io.input_task in done:
+                            if self.io.output_task:
+                                if not self.io.confirmation_in_progress:
+                                    await self.io.cancel_output_task()
+                                    self.io.stop_spinner()
+
+                            try:
+                                if self.io.input_task:
+                                    user_message = self.io.input_task.result()
+                                    await self.io.cancel_input_task()
+
+                                    if self.commands.is_run_command(user_message):
+                                        self.commands.cmd_running = True
+
+                            except (asyncio.CancelledError, KeyboardInterrupt):
+                                user_message = None
+
+                            if not user_message:
+                                await self.io.cancel_input_task()
+                                continue
+
+                        if self.io.output_task and self.io.output_task in pending:
+                            try:
+                                tasks = set()
+                                tasks.add(self.io.output_task)
+
+                                # We just did a confirmation so add a new input task
+                                if self.io.get_confirmation_acknowledgement():
+                                    await self.io.recreate_input()
+                                    tasks.add(self.io.input_task)
+
+                                done, pending = await asyncio.wait(
|
+
tasks, return_when=asyncio.FIRST_COMPLETED
|
|
1256
|
+
)
|
|
1257
|
+
|
|
1258
|
+
if (
|
|
1259
|
+
self.io.input_task
|
|
1260
|
+
and self.io.input_task in done
|
|
1261
|
+
and not self.io.confirmation_in_progress
|
|
1262
|
+
):
|
|
1263
|
+
await self.io.cancel_output_task()
|
|
1264
|
+
self.io.stop_spinner()
|
|
1265
|
+
self.io.acknowledge_confirmation()
|
|
1266
|
+
|
|
1267
|
+
try:
|
|
1268
|
+
user_message = self.io.input_task.result()
|
|
1269
|
+
await self.io.cancel_input_task()
|
|
1270
|
+
except (asyncio.CancelledError, KeyboardInterrupt):
|
|
1271
|
+
user_message = None
|
|
1272
|
+
|
|
1273
|
+
except (asyncio.CancelledError, KeyboardInterrupt):
|
|
1274
|
+
pass
|
|
1275
|
+
|
|
1276
|
+
# Stop spinner when processing task completes
|
|
1277
|
+
self.io.stop_spinner()
|
|
1278
|
+
|
|
1279
|
+
if user_message and not self.io.acknowledge_confirmation():
|
|
1280
|
+
self.user_message = user_message
|
|
1281
|
+
|
|
1282
|
+
self.io.ring_bell()
|
|
1283
|
+
user_message = None
|
|
1284
|
+
except KeyboardInterrupt:
|
|
1285
|
+
self.io.set_placeholder("")
|
|
1286
|
+
|
|
1287
|
+
await self.io.cancel_input_task()
|
|
1288
|
+
await self.io.cancel_output_task()
|
|
1289
|
+
|
|
1290
|
+
self.io.stop_spinner()
|
|
1291
|
+
self.keyboard_interrupt()
|
|
1292
|
+
|
|
1293
|
+
self.auto_save_session()
|
|
1294
|
+
except EOFError:
|
|
1295
|
+
return
|
|
1296
|
+
finally:
|
|
1297
|
+
await self.io.cancel_input_task()
|
|
1298
|
+
await self.io.cancel_output_task()
|
|
1299
|
+
|
|
1300
|
+
async def _generate(self, user_message, preproc):
|
|
1301
|
+
await asyncio.sleep(0.1)
|
|
1302
|
+
|
|
1303
|
+
try:
|
|
1304
|
+
self.compact_context_completed = False
|
|
1305
|
+
await self.compact_context_if_needed()
|
|
1306
|
+
self.compact_context_completed = True
|
|
1307
|
+
|
|
1308
|
+
self.run_one_completed = False
|
|
1309
|
+
await self.run_one(user_message, preproc)
|
|
1310
|
+
self.show_undo_hint()
|
|
1311
|
+
except asyncio.CancelledError:
|
|
1312
|
+
# Don't show undo hint if cancelled
|
|
1313
|
+
raise
|
|
1314
|
+
finally:
|
|
1315
|
+
self.run_one_completed = True
|
|
1316
|
+
self.compact_context_completed = True
|
|
1317
|
+
self.io.stop_spinner()
|
|
1318
|
+
|
|
1319
|
+
def copy_context(self):
|
|
1320
|
+
if self.auto_copy_context:
|
|
1321
|
+
self.commands.cmd_copy_context()
|
|
1322
|
+
|
|
1323
|
+
async def get_input(self):
|
|
1324
|
+
inchat_files = self.get_inchat_relative_files()
|
|
1325
|
+
all_read_only_fnames = self.abs_read_only_fnames | self.abs_read_only_stubs_fnames
|
|
1326
|
+
all_read_only_files = [self.get_rel_fname(fname) for fname in all_read_only_fnames]
|
|
1327
|
+
all_files = sorted(set(inchat_files + all_read_only_files))
|
|
1328
|
+
edit_format = "" if self.edit_format == self.main_model.edit_format else self.edit_format
|
|
1329
|
+
|
|
1330
|
+
return await self.io.get_input(
|
|
1331
|
+
self.root,
|
|
1332
|
+
all_files,
|
|
1333
|
+
self.get_addable_relative_files(),
|
|
1334
|
+
self.commands,
|
|
1335
|
+
abs_read_only_fnames=self.abs_read_only_fnames,
|
|
1336
|
+
abs_read_only_stubs_fnames=self.abs_read_only_stubs_fnames,
|
|
1337
|
+
edit_format=edit_format,
|
|
1338
|
+
)
|
|
1339
|
+
|
|
1340
|
+
async def preproc_user_input(self, inp):
|
|
1341
|
+
if not inp:
|
|
1342
|
+
return
|
|
1343
|
+
|
|
1344
|
+
# Strip whitespace from beginning and end
|
|
1345
|
+
inp = inp.strip()
|
|
1346
|
+
|
|
1347
|
+
if self.commands.is_command(inp):
|
|
1348
|
+
if inp[0] in "!":
|
|
1349
|
+
inp = f"/run {inp[1:]}"
|
|
1350
|
+
|
|
1351
|
+
if self.commands.is_run_command(inp):
|
|
1352
|
+
self.commands.cmd_running = True
|
|
1353
|
+
|
|
1354
|
+
return await self.commands.run(inp)
|
|
1355
|
+
|
|
1356
|
+
await self.check_for_file_mentions(inp)
|
|
1357
|
+
inp = await self.check_for_urls(inp)
|
|
1358
|
+
|
|
1359
|
+
return inp
|
|
1360
|
+
|
|
1361
|
+
async def run_one(self, user_message, preproc):
|
|
1362
|
+
self.init_before_message()
|
|
1363
|
+
|
|
1364
|
+
if preproc:
|
|
1365
|
+
message = await self.preproc_user_input(user_message)
|
|
1366
|
+
else:
|
|
1367
|
+
message = user_message
|
|
1368
|
+
|
|
1369
|
+
if self.commands.is_command(user_message):
|
|
1370
|
+
return
|
|
1371
|
+
|
|
1372
|
+
while True:
|
|
1373
|
+
self.reflected_message = None
|
|
1374
|
+
self.tool_reflection = False
|
|
1375
|
+
|
|
1376
|
+
async for _ in self.send_message(message):
|
|
1377
|
+
pass
|
|
1378
|
+
|
|
1379
|
+
if not self.reflected_message:
|
|
1380
|
+
break
|
|
1381
|
+
|
|
1382
|
+
if self.num_reflections >= self.max_reflections:
|
|
1383
|
+
self.io.tool_warning(f"Only {self.max_reflections} reflections allowed, stopping.")
|
|
1384
|
+
return
|
|
1385
|
+
|
|
1386
|
+
self.num_reflections += 1
|
|
1387
|
+
|
|
1388
|
+
if self.tool_reflection:
|
|
1389
|
+
self.num_reflections -= 1
|
|
1390
|
+
|
|
1391
|
+
if self.reflected_message is True:
|
|
1392
|
+
message = None
|
|
1393
|
+
else:
|
|
1394
|
+
message = self.reflected_message
|
|
1395
|
+
|
|
1396
|
+
async def check_and_open_urls(self, exc, friendly_msg=None):
|
|
1397
|
+
"""Check exception for URLs, offer to open in a browser, with user-friendly error msgs."""
|
|
1398
|
+
text = str(exc)
|
|
1399
|
+
|
|
1400
|
+
if friendly_msg:
|
|
1401
|
+
self.io.tool_warning(text)
|
|
1402
|
+
self.io.tool_error(f"{friendly_msg}")
|
|
1403
|
+
else:
|
|
1404
|
+
self.io.tool_error(text)
|
|
1405
|
+
|
|
1406
|
+
# Exclude double quotes from the matched URL characters
|
|
1407
|
+
url_pattern = re.compile(r'(https?://[^\s/$.?#].[^\s"]*)')
|
|
1408
|
+
# Use set to remove duplicates
|
|
1409
|
+
urls = list(set(url_pattern.findall(text)))
|
|
1410
|
+
for url in urls:
|
|
1411
|
+
url = url.rstrip(".',\"}") # Added } to the characters to strip
|
|
1412
|
+
await self.io.offer_url(url)
|
|
1413
|
+
return urls
+
+    async def check_for_urls(self, inp: str) -> List[str]:
+        """Check input for URLs and offer to add them to the chat."""
+        if not self.detect_urls:
+            return inp
+
+        # Exclude double quotes from the matched URL characters
+        url_pattern = re.compile(r'(https?://[^\s/$.?#].[^\s"]*[^\s,.])')
+        # Use set to remove duplicates
+        urls = list(set(url_pattern.findall(inp)))
+        group = ConfirmGroup(urls)
+        for url in urls:
+            if url not in self.rejected_urls:
+                url = url.rstrip(".',\"")
+                if await self.io.confirm_ask(
+                    "Add URL to the chat?", subject=url, group=group, allow_never=True
+                ):
+                    inp += "\n\n"
+                    inp += await self.commands.cmd_web(url, return_content=True)
+                else:
+                    self.rejected_urls.add(url)
+
+        return inp
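
As a quick illustration of what the URL pattern above matches (a standalone sketch, not part of the package), trailing commas and periods are excluded from the match:

# Illustrative sketch only:
import re

url_pattern = re.compile(r'(https?://[^\s/$.?#].[^\s"]*[^\s,.])')
print(url_pattern.findall("See https://example.com/docs, then continue."))
# -> ['https://example.com/docs']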
|
|
1437
|
+
|
|
1438
|
+
def keyboard_interrupt(self):
|
|
1439
|
+
# Ensure cursor is visible on exit
|
|
1440
|
+
Console().show_cursor(True)
|
|
1441
|
+
|
|
1442
|
+
self.io.tool_warning("\n\n^C KeyboardInterrupt")
|
|
1443
|
+
|
|
1444
|
+
self.last_keyboard_interrupt = time.time()
|
|
1445
|
+
|
|
1446
|
+
def summarize_start(self):
|
|
1447
|
+
if not self.summarizer.check_max_tokens(self.done_messages):
|
|
1448
|
+
return
|
|
1449
|
+
|
|
1450
|
+
self.summarize_end()
|
|
1451
|
+
|
|
1452
|
+
if self.verbose:
|
|
1453
|
+
self.io.tool_output("Starting to summarize chat history.")
|
|
1454
|
+
|
|
1455
|
+
self.summarizer_thread = threading.Thread(target=self.summarize_worker)
|
|
1456
|
+
self.summarizer_thread.start()
|
|
1457
|
+
|
|
1458
|
+
def summarize_worker(self):
|
|
1459
|
+
self.summarizing_messages = list(self.done_messages)
|
|
1460
|
+
try:
|
|
1461
|
+
self.summarized_done_messages = asyncio.run(
|
|
1462
|
+
self.summarizer.summarize(self.summarizing_messages)
|
|
1463
|
+
)
|
|
1464
|
+
except ValueError as err:
|
|
1465
|
+
self.io.tool_warning(err.args[0])
|
|
1466
|
+
self.summarized_done_messages = self.summarizing_messages
|
|
1467
|
+
|
|
1468
|
+
if self.verbose:
|
|
1469
|
+
self.io.tool_output("Finished summarizing chat history.")
|
|
1470
|
+
|
|
1471
|
+
def summarize_end(self):
|
|
1472
|
+
if self.summarizer_thread is None:
|
|
1473
|
+
return
|
|
1474
|
+
|
|
1475
|
+
self.summarizer_thread.join()
|
|
1476
|
+
self.summarizer_thread = None
|
|
1477
|
+
|
|
1478
|
+
if self.summarizing_messages == self.done_messages:
|
|
1479
|
+
self.done_messages = self.summarized_done_messages
|
|
1480
|
+
self.summarizing_messages = None
|
|
1481
|
+
self.summarized_done_messages = []
+
+    async def compact_context_if_needed(self):
+        if not self.enable_context_compaction:
+            self.summarize_start()
+            return
+
+        if not self.summarizer.check_max_tokens(
+            self.done_messages, max_tokens=self.context_compaction_max_tokens
+        ):
+            return
+
+        self.io.tool_output("Compacting chat history to make room for new messages...")
+
+        try:
+            # Create a summary of the conversation
+            summary_text = await self.summarizer.summarize_all_as_text(
+                self.done_messages,
+                self.gpt_prompts.compaction_prompt,
+                self.context_compaction_summary_tokens,
+            )
+            if not summary_text:
+                raise ValueError("Summarization returned an empty result.")
+
+            # Replace old messages with the summary
+            self.done_messages = [
+                {
+                    "role": "user",
+                    "content": summary_text,
+                },
+                {
+                    "role": "assistant",
+                    "content": (
+                        "Ok, I will use this summary as the context for our conversation going"
+                        " forward."
+                    ),
+                },
+            ]
+            self.io.tool_output("...chat history compacted.")
+        except Exception as e:
+            self.io.tool_warning(f"Context compaction failed: {e}")
+            self.io.tool_warning("Proceeding with full history for now.")
+            self.summarize_start()
+            return
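
In rough terms, the decision flow of this method is (editorial summary of the code above, using its attribute names):

# compaction disabled           -> fall back to the threaded summarizer (summarize_start)
# history under the token limit -> do nothing
# history over the token limit  -> replace done_messages with one summary exchange,
#                                  or fall back to summarize_start if summarization fails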
|
|
1525
|
+
|
|
1526
|
+
def move_back_cur_messages(self, message):
|
|
1527
|
+
self.done_messages += self.cur_messages
|
|
1528
|
+
|
|
1529
|
+
# TODO check for impact on image messages
|
|
1530
|
+
if message:
|
|
1531
|
+
self.done_messages += [
|
|
1532
|
+
dict(role="user", content=message),
|
|
1533
|
+
dict(role="assistant", content="Ok."),
|
|
1534
|
+
]
|
|
1535
|
+
self.cur_messages = []
+
+    def normalize_language(self, lang_code):
+        """
+        Convert a locale code such as ``en_US`` or ``fr`` into a readable
+        language name (e.g. ``English`` or ``French``). If Babel is
+        available it is used for reliable conversion; otherwise a small
+        built-in fallback map handles common languages.
+        """
+        if not lang_code:
+            return None
+
+        if lang_code.upper() in ("C", "POSIX"):
+            return None
+
+        # Probably already a language name
+        if (
+            len(lang_code) > 3
+            and "_" not in lang_code
+            and "-" not in lang_code
+            and lang_code[0].isupper()
+        ):
+            return lang_code
+
+        # Preferred: Babel
+        if Locale is not None:
+            try:
+                loc = Locale.parse(lang_code.replace("-", "_"))
+                return loc.get_display_name("en").capitalize()
+            except Exception:
+                pass  # Fall back to manual mapping
+
+        # Simple fallback for common languages
+        fallback = {
+            "en": "English",
+            "fr": "French",
+            "es": "Spanish",
+            "de": "German",
+            "it": "Italian",
+            "pt": "Portuguese",
+            "zh": "Chinese",
+            "ja": "Japanese",
+            "ko": "Korean",
+            "ru": "Russian",
+        }
+        primary_lang_code = lang_code.replace("-", "_").split("_")[0].lower()
+        return fallback.get(primary_lang_code, lang_code)
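
A minimal standalone sketch of the no-Babel fallback path (the helper name below is ours, for illustration only):

def normalize_language_fallback(lang_code):
    fallback = {"en": "English", "fr": "French", "de": "German"}
    primary = lang_code.replace("-", "_").split("_")[0].lower()
    return fallback.get(primary, lang_code)

print(normalize_language_fallback("en_US"))  # English
print(normalize_language_fallback("fr"))     # French
print(normalize_language_fallback("nl"))     # nl (unknown codes pass through unchanged)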
+
+    def get_user_language(self):
+        """
+        Detect the user's language preference and return a human-readable
+        language name such as ``English``. Detection order:
+
+        1. ``self.chat_language`` if explicitly set
+        2. ``locale.getlocale()``
+        3. ``LANG`` / ``LANGUAGE`` / ``LC_ALL`` / ``LC_MESSAGES`` environment variables
+        """
+
+        # Explicit override
+        if self.chat_language:
+            return self.normalize_language(self.chat_language)
+
+        # System locale
+        try:
+            lang = locale.getlocale()[0]
+            if lang:
+                lang = self.normalize_language(lang)
+                if lang:
+                    return lang
+        except Exception:
+            pass
+
+        # Environment variables
+        for env_var in ("LANG", "LANGUAGE", "LC_ALL", "LC_MESSAGES"):
+            lang = os.environ.get(env_var)
+            if lang:
+                lang = lang.split(".")[0]  # Strip encoding if present
+                return self.normalize_language(lang)
+
+        return None
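
The final environment-variable step strips any encoding suffix before normalizing; a standalone sketch (the function name is ours, illustrative only):

import os

def detect_lang_from_env():
    for env_var in ("LANG", "LANGUAGE", "LC_ALL", "LC_MESSAGES"):
        lang = os.environ.get(env_var)
        if lang:
            return lang.split(".")[0]  # "en_US.UTF-8" -> "en_US"
    return None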
|
|
1615
|
+
|
|
1616
|
+
def get_platform_info(self):
|
|
1617
|
+
platform_text = ""
|
|
1618
|
+
try:
|
|
1619
|
+
platform_text = f"- Platform: {platform.platform()}\n"
|
|
1620
|
+
except KeyError:
|
|
1621
|
+
# Skip platform info if it can't be retrieved
|
|
1622
|
+
platform_text = "- Platform information unavailable\n"
|
|
1623
|
+
|
|
1624
|
+
shell_var = "COMSPEC" if os.name == "nt" else "SHELL"
|
|
1625
|
+
shell_val = os.getenv(shell_var)
|
|
1626
|
+
platform_text += f"- Shell: {shell_var}={shell_val}\n"
|
|
1627
|
+
|
|
1628
|
+
user_lang = self.get_user_language()
|
|
1629
|
+
if user_lang:
|
|
1630
|
+
platform_text += f"- Language: {user_lang}\n"
|
|
1631
|
+
|
|
1632
|
+
dt = datetime.now().astimezone().strftime("%Y-%m-%d")
|
|
1633
|
+
platform_text += f"- Current date: {dt}\n"
|
|
1634
|
+
|
|
1635
|
+
if self.repo:
|
|
1636
|
+
platform_text += "- The user is operating inside a git repository\n"
|
|
1637
|
+
|
|
1638
|
+
if self.lint_cmds:
|
|
1639
|
+
if self.auto_lint:
|
|
1640
|
+
platform_text += (
|
|
1641
|
+
"- The user's pre-commit runs these lint commands, don't suggest running"
|
|
1642
|
+
" them:\n"
|
|
1643
|
+
)
|
|
1644
|
+
else:
|
|
1645
|
+
platform_text += "- The user prefers these lint commands:\n"
|
|
1646
|
+
for lang, cmd in self.lint_cmds.items():
|
|
1647
|
+
if lang is None:
|
|
1648
|
+
platform_text += f" - {cmd}\n"
|
|
1649
|
+
else:
|
|
1650
|
+
platform_text += f" - {lang}: {cmd}\n"
|
|
1651
|
+
|
|
1652
|
+
if self.test_cmd:
|
|
1653
|
+
if self.auto_test:
|
|
1654
|
+
platform_text += (
|
|
1655
|
+
"- The user's pre-commit runs this test command, don't suggest running them: "
|
|
1656
|
+
)
|
|
1657
|
+
else:
|
|
1658
|
+
platform_text += "- The user prefers this test command: "
|
|
1659
|
+
platform_text += self.test_cmd + "\n"
|
|
1660
|
+
|
|
1661
|
+
return platform_text
|
|
1662
|
+
|
|
1663
|
+
def fmt_system_prompt(self, prompt):
|
|
1664
|
+
final_reminders = []
|
|
1665
|
+
|
|
1666
|
+
lazy_prompt = ""
|
|
1667
|
+
if self.main_model.lazy:
|
|
1668
|
+
lazy_prompt = self.gpt_prompts.lazy_prompt
|
|
1669
|
+
final_reminders.append(lazy_prompt)
|
|
1670
|
+
|
|
1671
|
+
overeager_prompt = ""
|
|
1672
|
+
if self.main_model.overeager:
|
|
1673
|
+
overeager_prompt = self.gpt_prompts.overeager_prompt
|
|
1674
|
+
final_reminders.append(overeager_prompt)
|
|
1675
|
+
|
|
1676
|
+
user_lang = self.get_user_language()
|
|
1677
|
+
if user_lang:
|
|
1678
|
+
final_reminders.append(f"Reply in {user_lang}.\n")
|
|
1679
|
+
|
|
1680
|
+
platform_text = self.get_platform_info()
|
|
1681
|
+
|
|
1682
|
+
if self.suggest_shell_commands:
|
|
1683
|
+
shell_cmd_prompt = self.gpt_prompts.shell_cmd_prompt.format(platform=platform_text)
|
|
1684
|
+
shell_cmd_reminder = self.gpt_prompts.shell_cmd_reminder.format(platform=platform_text)
|
|
1685
|
+
rename_with_shell = self.gpt_prompts.rename_with_shell
|
|
1686
|
+
else:
|
|
1687
|
+
shell_cmd_prompt = self.gpt_prompts.no_shell_cmd_prompt.format(platform=platform_text)
|
|
1688
|
+
shell_cmd_reminder = self.gpt_prompts.no_shell_cmd_reminder.format(
|
|
1689
|
+
platform=platform_text
|
|
1690
|
+
)
|
|
1691
|
+
rename_with_shell = ""
|
|
1692
|
+
|
|
1693
|
+
if user_lang: # user_lang is the result of self.get_user_language()
|
|
1694
|
+
language = user_lang
|
|
1695
|
+
else:
|
|
1696
|
+
# Default if no specific lang detected
|
|
1697
|
+
language = "the same language they are using"
|
|
1698
|
+
|
|
1699
|
+
if self.fence[0] == "`" * 4:
|
|
1700
|
+
quad_backtick_reminder = (
|
|
1701
|
+
"\nIMPORTANT: Use *quadruple* backticks ```` as fences, not triple backticks!\n"
|
|
1702
|
+
)
|
|
1703
|
+
else:
|
|
1704
|
+
quad_backtick_reminder = ""
|
|
1705
|
+
|
|
1706
|
+
if self.mcp_tools and len(self.mcp_tools) > 0:
|
|
1707
|
+
final_reminders.append(self.gpt_prompts.tool_prompt)
|
|
1708
|
+
|
|
1709
|
+
final_reminders = "\n\n".join(final_reminders)
|
|
1710
|
+
|
|
1711
|
+
prompt = prompt.format(
|
|
1712
|
+
fence=self.fence,
|
|
1713
|
+
quad_backtick_reminder=quad_backtick_reminder,
|
|
1714
|
+
final_reminders=final_reminders,
|
|
1715
|
+
platform=platform_text,
|
|
1716
|
+
shell_cmd_prompt=shell_cmd_prompt,
|
|
1717
|
+
rename_with_shell=rename_with_shell,
|
|
1718
|
+
shell_cmd_reminder=shell_cmd_reminder,
|
|
1719
|
+
go_ahead_tip=self.gpt_prompts.go_ahead_tip,
|
|
1720
|
+
language=language,
|
|
1721
|
+
lazy_prompt=lazy_prompt,
|
|
1722
|
+
overeager_prompt=overeager_prompt,
|
|
1723
|
+
)
|
|
1724
|
+
|
|
1725
|
+
return prompt
|
|
1726
|
+
|
|
1727
|
+
def format_chat_chunks(self):
|
|
1728
|
+
self.choose_fence()
|
|
1729
|
+
main_sys = self.fmt_system_prompt(self.gpt_prompts.main_system)
|
|
1730
|
+
if self.main_model.system_prompt_prefix:
|
|
1731
|
+
main_sys = self.main_model.system_prompt_prefix + "\n" + main_sys
|
|
1732
|
+
|
|
1733
|
+
example_messages = []
|
|
1734
|
+
if self.main_model.examples_as_sys_msg:
|
|
1735
|
+
if self.gpt_prompts.example_messages:
|
|
1736
|
+
main_sys += "\n# Example conversations:\n\n"
|
|
1737
|
+
for msg in self.gpt_prompts.example_messages:
|
|
1738
|
+
role = msg["role"]
|
|
1739
|
+
content = self.fmt_system_prompt(msg["content"])
|
|
1740
|
+
main_sys += f"## {role.upper()}: {content}\n\n"
|
|
1741
|
+
main_sys = main_sys.strip()
|
|
1742
|
+
else:
|
|
1743
|
+
for msg in self.gpt_prompts.example_messages:
|
|
1744
|
+
example_messages.append(
|
|
1745
|
+
dict(
|
|
1746
|
+
role=msg["role"],
|
|
1747
|
+
content=self.fmt_system_prompt(msg["content"]),
|
|
1748
|
+
)
|
|
1749
|
+
)
|
|
1750
|
+
if self.gpt_prompts.example_messages:
|
|
1751
|
+
example_messages += [
|
|
1752
|
+
dict(
|
|
1753
|
+
role="user",
|
|
1754
|
+
content=(
|
|
1755
|
+
"I switched to a new code base. Please don't consider the above files"
|
|
1756
|
+
" or try to edit them any longer."
|
|
1757
|
+
),
|
|
1758
|
+
),
|
|
1759
|
+
dict(role="assistant", content="Ok."),
|
|
1760
|
+
]
|
|
1761
|
+
|
|
1762
|
+
if self.gpt_prompts.system_reminder:
|
|
1763
|
+
main_sys += "\n" + self.fmt_system_prompt(self.gpt_prompts.system_reminder)
|
|
1764
|
+
|
|
1765
|
+
chunks = ChatChunks()
|
|
1766
|
+
|
|
1767
|
+
if self.main_model.use_system_prompt:
|
|
1768
|
+
chunks.system = [
|
|
1769
|
+
dict(role="system", content=main_sys),
|
|
1770
|
+
]
|
|
1771
|
+
else:
|
|
1772
|
+
chunks.system = [
|
|
1773
|
+
dict(role="user", content=main_sys),
|
|
1774
|
+
dict(role="assistant", content="Ok."),
|
|
1775
|
+
]
|
|
1776
|
+
|
|
1777
|
+
chunks.examples = example_messages
|
|
1778
|
+
|
|
1779
|
+
self.summarize_end()
|
|
1780
|
+
chunks.done = self.done_messages
|
|
1781
|
+
|
|
1782
|
+
chunks.repo = self.get_repo_messages()
|
|
1783
|
+
chunks.readonly_files = self.get_readonly_files_messages()
|
|
1784
|
+
chunks.chat_files = self.get_chat_files_messages()
|
|
1785
|
+
|
|
1786
|
+
if self.gpt_prompts.system_reminder:
|
|
1787
|
+
reminder_message = [
|
|
1788
|
+
dict(
|
|
1789
|
+
role="system", content=self.fmt_system_prompt(self.gpt_prompts.system_reminder)
|
|
1790
|
+
),
|
|
1791
|
+
]
|
|
1792
|
+
else:
|
|
1793
|
+
reminder_message = []
|
|
1794
|
+
|
|
1795
|
+
chunks.cur = list(self.cur_messages)
|
|
1796
|
+
chunks.reminder = []
|
|
1797
|
+
|
|
1798
|
+
# TODO review impact of token count on image messages
|
|
1799
|
+
messages_tokens = self.main_model.token_count(chunks.all_messages())
|
|
1800
|
+
reminder_tokens = self.main_model.token_count(reminder_message)
|
|
1801
|
+
cur_tokens = self.main_model.token_count(chunks.cur)
|
|
1802
|
+
|
|
1803
|
+
if None not in (messages_tokens, reminder_tokens, cur_tokens):
|
|
1804
|
+
total_tokens = messages_tokens
|
|
1805
|
+
# Only add tokens for reminder and cur if they're not already included
|
|
1806
|
+
# in the messages_tokens calculation
|
|
1807
|
+
if not chunks.reminder:
|
|
1808
|
+
total_tokens += reminder_tokens
|
|
1809
|
+
if not chunks.cur:
|
|
1810
|
+
total_tokens += cur_tokens
|
|
1811
|
+
else:
|
|
1812
|
+
# add the reminder anyway
|
|
1813
|
+
total_tokens = 0
|
|
1814
|
+
|
|
1815
|
+
if chunks.cur:
|
|
1816
|
+
final = chunks.cur[-1]
|
|
1817
|
+
else:
|
|
1818
|
+
final = None
|
|
1819
|
+
|
|
1820
|
+
max_input_tokens = self.main_model.info.get("max_input_tokens") or 0
|
|
1821
|
+
# Add the reminder prompt if we still have room to include it.
|
|
1822
|
+
if (
|
|
1823
|
+
not max_input_tokens
|
|
1824
|
+
or total_tokens < max_input_tokens
|
|
1825
|
+
and self.gpt_prompts.system_reminder
|
|
1826
|
+
):
|
|
1827
|
+
if self.main_model.reminder == "sys":
|
|
1828
|
+
chunks.reminder = reminder_message
|
|
1829
|
+
elif self.main_model.reminder == "user" and final and final["role"] == "user":
|
|
1830
|
+
# stuff it into the user message
|
|
1831
|
+
new_content = (
|
|
1832
|
+
final["content"]
|
|
1833
|
+
+ "\n\n"
|
|
1834
|
+
+ self.fmt_system_prompt(self.gpt_prompts.system_reminder)
|
|
1835
|
+
)
|
|
1836
|
+
chunks.cur[-1] = dict(role=final["role"], content=new_content)
|
|
1837
|
+
|
|
1838
|
+
return chunks
|
|
1839
|
+
|
|
1840
|
+
def format_messages(self):
|
|
1841
|
+
chunks = self.format_chat_chunks()
|
|
1842
|
+
if self.add_cache_headers:
|
|
1843
|
+
chunks.add_cache_control_headers()
|
|
1844
|
+
|
|
1845
|
+
return chunks
|
|
1846
|
+
|
|
1847
|
+
def warm_cache(self, chunks):
|
|
1848
|
+
if not self.add_cache_headers:
|
|
1849
|
+
return
|
|
1850
|
+
if not self.num_cache_warming_pings:
|
|
1851
|
+
return
|
|
1852
|
+
if not self.ok_to_warm_cache:
|
|
1853
|
+
return
|
|
1854
|
+
|
|
1855
|
+
delay = 5 * 60 - 5
|
|
1856
|
+
delay = float(os.environ.get("AIDER_CACHE_KEEPALIVE_DELAY", delay))
|
|
1857
|
+
self.next_cache_warm = time.time() + delay
|
|
1858
|
+
self.warming_pings_left = self.num_cache_warming_pings
|
|
1859
|
+
self.cache_warming_chunks = chunks
|
|
1860
|
+
|
|
1861
|
+
if self.cache_warming_thread:
|
|
1862
|
+
return
|
|
1863
|
+
|
|
1864
|
+
def warm_cache_worker():
|
|
1865
|
+
while self.ok_to_warm_cache:
|
|
1866
|
+
time.sleep(1)
|
|
1867
|
+
if self.warming_pings_left <= 0:
|
|
1868
|
+
continue
|
|
1869
|
+
now = time.time()
|
|
1870
|
+
if now < self.next_cache_warm:
|
|
1871
|
+
continue
|
|
1872
|
+
|
|
1873
|
+
self.warming_pings_left -= 1
|
|
1874
|
+
self.next_cache_warm = time.time() + delay
|
|
1875
|
+
|
|
1876
|
+
kwargs = dict(self.main_model.extra_params) or dict()
|
|
1877
|
+
kwargs["max_tokens"] = 1
|
|
1878
|
+
|
|
1879
|
+
try:
|
|
1880
|
+
completion = litellm.completion(
|
|
1881
|
+
model=self.main_model.name,
|
|
1882
|
+
messages=self.cache_warming_chunks.cacheable_messages(),
|
|
1883
|
+
stream=False,
|
|
1884
|
+
**kwargs,
|
|
1885
|
+
)
|
|
1886
|
+
except Exception as err:
|
|
1887
|
+
self.io.tool_warning(f"Cache warming error: {str(err)}")
|
|
1888
|
+
continue
|
|
1889
|
+
|
|
1890
|
+
cache_hit_tokens = getattr(
|
|
1891
|
+
completion.usage, "prompt_cache_hit_tokens", 0
|
|
1892
|
+
) or getattr(completion.usage, "cache_read_input_tokens", 0)
|
|
1893
|
+
|
|
1894
|
+
if self.verbose:
|
|
1895
|
+
self.io.tool_output(f"Warmed {format_tokens(cache_hit_tokens)} cached tokens.")
|
|
1896
|
+
|
|
1897
|
+
self.cache_warming_thread = threading.Timer(0, warm_cache_worker)
|
|
1898
|
+
self.cache_warming_thread.daemon = True
|
|
1899
|
+
self.cache_warming_thread.start()
|
|
1900
|
+
|
|
1901
|
+
return chunks
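
The keep-alive cadence above works out to just under five minutes by default, presumably to stay inside a provider-side prompt-cache window; a small sketch of the arithmetic and the override (env var name taken from the code above):

delay = 5 * 60 - 5  # 295 seconds between cache-warming pings by default
# Overridable at runtime, e.g. AIDER_CACHE_KEEPALIVE_DELAY=120 for two-minute pings.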
|
|
1902
|
+
|
|
1903
|
+
async def check_tokens(self, messages):
|
|
1904
|
+
"""Check if the messages will fit within the model's token limits."""
|
|
1905
|
+
input_tokens = self.main_model.token_count(messages)
|
|
1906
|
+
max_input_tokens = self.main_model.info.get("max_input_tokens") or 0
|
|
1907
|
+
|
|
1908
|
+
if max_input_tokens and input_tokens >= max_input_tokens:
|
|
1909
|
+
self.io.tool_error(
|
|
1910
|
+
f"Your estimated chat context of {input_tokens:,} tokens exceeds the"
|
|
1911
|
+
f" {max_input_tokens:,} token limit for {self.main_model.name}!"
|
|
1912
|
+
)
|
|
1913
|
+
self.io.tool_output("To reduce the chat context:")
|
|
1914
|
+
self.io.tool_output("- Use /drop to remove unneeded files from the chat")
|
|
1915
|
+
self.io.tool_output("- Use /clear to clear the chat history")
|
|
1916
|
+
self.io.tool_output("- Break your code into smaller files")
|
|
1917
|
+
self.io.tool_output(
|
|
1918
|
+
"It's probably safe to try and send the request, most providers won't charge if"
|
|
1919
|
+
" the context limit is exceeded."
|
|
1920
|
+
)
|
|
1921
|
+
|
|
1922
|
+
if not await self.io.confirm_ask("Try to proceed anyway?"):
|
|
1923
|
+
return False
|
|
1924
|
+
return True
|
|
1925
|
+
|
|
1926
|
+
async def send_message(self, inp):
|
|
1927
|
+
self.event("message_send_starting")
|
|
1928
|
+
|
|
1929
|
+
# Notify IO that LLM processing is starting
|
|
1930
|
+
self.io.llm_started()
|
|
1931
|
+
|
|
1932
|
+
if inp:
|
|
1933
|
+
self.cur_messages += [
|
|
1934
|
+
dict(role="user", content=inp),
|
|
1935
|
+
]
|
|
1936
|
+
|
|
1937
|
+
loop = asyncio.get_running_loop()
|
|
1938
|
+
chunks = await loop.run_in_executor(None, self.format_messages)
|
|
1939
|
+
messages = chunks.all_messages()
|
|
1940
|
+
|
|
1941
|
+
if not await self.check_tokens(messages):
|
|
1942
|
+
return
|
|
1943
|
+
self.warm_cache(chunks)
|
|
1944
|
+
|
|
1945
|
+
if self.verbose:
|
|
1946
|
+
utils.show_messages(messages, functions=self.functions)
|
|
1947
|
+
|
|
1948
|
+
self.multi_response_content = ""
|
|
1949
|
+
if self.show_pretty():
|
|
1950
|
+
spinner_text = (
|
|
1951
|
+
f"Waiting for {self.main_model.name} • ${self.format_cost(self.total_cost)} session"
|
|
1952
|
+
)
|
|
1953
|
+
self.io.start_spinner(spinner_text)
|
|
1954
|
+
|
|
1955
|
+
if self.stream:
|
|
1956
|
+
self.mdstream = True
|
|
1957
|
+
else:
|
|
1958
|
+
self.mdstream = None
|
|
1959
|
+
else:
|
|
1960
|
+
self.mdstream = None
|
|
1961
|
+
|
|
1962
|
+
retry_delay = 0.125
|
|
1963
|
+
|
|
1964
|
+
litellm_ex = LiteLLMExceptions()
|
|
1965
|
+
|
|
1966
|
+
self.usage_report = None
|
|
1967
|
+
exhausted = False
|
|
1968
|
+
interrupted = False
|
|
1969
|
+
|
|
1970
|
+
try:
|
|
1971
|
+
while True:
|
|
1972
|
+
try:
|
|
1973
|
+
async for chunk in self.send(messages, tools=self.get_tool_list()):
|
|
1974
|
+
yield chunk
|
|
1975
|
+
break
|
|
1976
|
+
except litellm_ex.exceptions_tuple() as err:
|
|
1977
|
+
ex_info = litellm_ex.get_ex_info(err)
|
|
1978
|
+
|
|
1979
|
+
if ex_info.name == "ContextWindowExceededError":
|
|
1980
|
+
exhausted = True
|
|
1981
|
+
break
|
|
1982
|
+
|
|
1983
|
+
should_retry = ex_info.retry
|
|
1984
|
+
if should_retry:
|
|
1985
|
+
retry_delay *= 2
|
|
1986
|
+
if retry_delay > RETRY_TIMEOUT:
|
|
1987
|
+
should_retry = False
|
|
1988
|
+
|
|
1989
|
+
if not should_retry:
|
|
1990
|
+
self.mdstream = None
|
|
1991
|
+
await self.check_and_open_urls(err, ex_info.description)
|
|
1992
|
+
break
|
|
1993
|
+
|
|
1994
|
+
err_msg = str(err)
|
|
1995
|
+
if ex_info.description:
|
|
1996
|
+
self.io.tool_warning(err_msg)
|
|
1997
|
+
self.io.tool_error(ex_info.description)
|
|
1998
|
+
else:
|
|
1999
|
+
self.io.tool_error(err_msg)
|
|
2000
|
+
|
|
2001
|
+
self.io.tool_output(f"Retrying in {retry_delay:.1f} seconds...")
|
|
2002
|
+
await asyncio.sleep(retry_delay)
|
|
2003
|
+
continue
|
|
2004
|
+
except KeyboardInterrupt:
|
|
2005
|
+
interrupted = True
|
|
2006
|
+
break
|
|
2007
|
+
except FinishReasonLength:
|
|
2008
|
+
# We hit the output limit!
|
|
2009
|
+
if not self.main_model.info.get("supports_assistant_prefill"):
|
|
2010
|
+
exhausted = True
|
|
2011
|
+
break
|
|
2012
|
+
|
|
2013
|
+
self.multi_response_content = self.get_multi_response_content_in_progress()
|
|
2014
|
+
|
|
2015
|
+
if messages[-1]["role"] == "assistant":
|
|
2016
|
+
messages[-1]["content"] = self.multi_response_content
|
|
2017
|
+
else:
|
|
2018
|
+
messages.append(
|
|
2019
|
+
dict(role="assistant", content=self.multi_response_content, prefix=True)
|
|
2020
|
+
)
|
|
2021
|
+
except Exception as err:
|
|
2022
|
+
self.mdstream = None
|
|
2023
|
+
lines = traceback.format_exception(type(err), err, err.__traceback__)
|
|
2024
|
+
self.io.tool_warning("".join(lines))
|
|
2025
|
+
self.io.tool_error(str(err))
|
|
2026
|
+
self.event("message_send_exception", exception=str(err))
|
|
2027
|
+
return
|
|
2028
|
+
finally:
|
|
2029
|
+
if self.mdstream:
|
|
2030
|
+
content_to_show = self.live_incremental_response(True)
|
|
2031
|
+
self.stream_wrapper(content_to_show, final=True)
|
|
2032
|
+
self.mdstream = None
|
|
2033
|
+
|
|
2034
|
+
# Ensure any waiting spinner is stopped
|
|
2035
|
+
self.io.start_spinner("Processing Answer...")
|
|
2036
|
+
|
|
2037
|
+
self.partial_response_content = self.get_multi_response_content_in_progress(True)
|
|
2038
|
+
self.remove_reasoning_content()
|
|
2039
|
+
self.multi_response_content = ""
|
|
2040
|
+
|
|
2041
|
+
self.io.tool_output()
|
|
2042
|
+
|
|
2043
|
+
self.show_usage_report()
|
|
2044
|
+
|
|
2045
|
+
self.add_assistant_reply_to_cur_messages()
|
|
2046
|
+
|
|
2047
|
+
if exhausted:
|
|
2048
|
+
if self.cur_messages and self.cur_messages[-1]["role"] == "user":
|
|
2049
|
+
self.cur_messages += [
|
|
2050
|
+
dict(
|
|
2051
|
+
role="assistant",
|
|
2052
|
+
content="FinishReasonLength exception: you sent too many tokens",
|
|
2053
|
+
),
|
|
2054
|
+
]
|
|
2055
|
+
|
|
2056
|
+
await self.show_exhausted_error()
|
|
2057
|
+
self.num_exhausted_context_windows += 1
|
|
2058
|
+
return
|
|
2059
|
+
|
|
2060
|
+
if self.partial_response_function_call:
|
|
2061
|
+
args = self.parse_partial_args()
|
|
2062
|
+
if args:
|
|
2063
|
+
content = args.get("explanation") or ""
|
|
2064
|
+
else:
|
|
2065
|
+
content = ""
|
|
2066
|
+
elif self.partial_response_content:
|
|
2067
|
+
content = self.partial_response_content
|
|
2068
|
+
else:
|
|
2069
|
+
content = ""
|
|
2070
|
+
|
|
2071
|
+
if interrupted:
|
|
2072
|
+
if self.cur_messages and self.cur_messages[-1]["role"] == "user":
|
|
2073
|
+
self.cur_messages[-1]["content"] += "\n^C KeyboardInterrupt"
|
|
2074
|
+
else:
|
|
2075
|
+
self.cur_messages += [dict(role="user", content="^C KeyboardInterrupt")]
|
|
2076
|
+
self.cur_messages += [
|
|
2077
|
+
dict(role="assistant", content="I see that you interrupted my previous reply.")
|
|
2078
|
+
]
|
|
2079
|
+
return
|
|
2080
|
+
|
|
2081
|
+
edited = await self.apply_updates()
|
|
2082
|
+
|
|
2083
|
+
if edited:
|
|
2084
|
+
self.aider_edited_files.update(edited)
|
|
2085
|
+
saved_message = await self.auto_commit(edited)
|
|
2086
|
+
|
|
2087
|
+
if not saved_message and hasattr(self.gpt_prompts, "files_content_gpt_edits_no_repo"):
|
|
2088
|
+
saved_message = self.gpt_prompts.files_content_gpt_edits_no_repo
|
|
2089
|
+
|
|
2090
|
+
self.move_back_cur_messages(saved_message)
|
|
2091
|
+
|
|
2092
|
+
if not interrupted:
|
|
2093
|
+
add_rel_files_message = await self.check_for_file_mentions(content)
|
|
2094
|
+
if add_rel_files_message:
|
|
2095
|
+
if self.reflected_message:
|
|
2096
|
+
self.reflected_message += "\n\n" + add_rel_files_message
|
|
2097
|
+
else:
|
|
2098
|
+
self.reflected_message = add_rel_files_message
|
|
2099
|
+
return
|
|
2100
|
+
|
|
2101
|
+
# Process any tools using MCP servers
|
|
2102
|
+
try:
|
|
2103
|
+
if self.partial_response_tool_calls:
|
|
2104
|
+
tool_calls = []
|
|
2105
|
+
tool_id_set = set()
|
|
2106
|
+
|
|
2107
|
+
for tool_call_dict in self.partial_response_tool_calls:
|
|
2108
|
+
# LLM APIs sometimes return duplicates and that's annoying
|
|
2109
|
+
if tool_call_dict.get("id") in tool_id_set:
|
|
2110
|
+
continue
|
|
2111
|
+
|
|
2112
|
+
tool_id_set.add(tool_call_dict.get("id"))
|
|
2113
|
+
|
|
2114
|
+
tool_calls.append(
|
|
2115
|
+
ChatCompletionMessageToolCall(
|
|
2116
|
+
id=tool_call_dict.get("id"),
|
|
2117
|
+
function=Function(
|
|
2118
|
+
name=tool_call_dict.get("function", {}).get("name"),
|
|
2119
|
+
arguments=tool_call_dict.get("function", {}).get(
|
|
2120
|
+
"arguments", ""
|
|
2121
|
+
),
|
|
2122
|
+
),
|
|
2123
|
+
type=tool_call_dict.get("type"),
|
|
2124
|
+
)
|
|
2125
|
+
)
|
|
2126
|
+
|
|
2127
|
+
tool_call_response = ModelResponse(
|
|
2128
|
+
choices=[
|
|
2129
|
+
Choices(
|
|
2130
|
+
finish_reason="tool_calls",
|
|
2131
|
+
index=0,
|
|
2132
|
+
message=Message(
|
|
2133
|
+
content=None,
|
|
2134
|
+
role="assistant",
|
|
2135
|
+
tool_calls=tool_calls,
|
|
2136
|
+
),
|
|
2137
|
+
)
|
|
2138
|
+
]
|
|
2139
|
+
)
|
|
2140
|
+
|
|
2141
|
+
if await self.process_tool_calls(tool_call_response):
|
|
2142
|
+
self.num_tool_calls += 1
|
|
2143
|
+
self.reflected_message = True
|
|
2144
|
+
return
|
|
2145
|
+
except Exception as e:
|
|
2146
|
+
self.io.tool_error(f"Error processing tool calls: {str(e)}")
|
|
2147
|
+
# Continue without tool processing
|
|
2148
|
+
|
|
2149
|
+
self.num_tool_calls = 0
|
|
2150
|
+
|
|
2151
|
+
try:
|
|
2152
|
+
if await self.reply_completed():
|
|
2153
|
+
return
|
|
2154
|
+
except KeyboardInterrupt:
|
|
2155
|
+
interrupted = True
|
|
2156
|
+
|
|
2157
|
+
if self.reflected_message:
|
|
2158
|
+
return
|
|
2159
|
+
|
|
2160
|
+
if edited and self.auto_lint:
|
|
2161
|
+
lint_errors = self.lint_edited(edited)
|
|
2162
|
+
await self.auto_commit(edited, context="Ran the linter")
|
|
2163
|
+
self.lint_outcome = not lint_errors
|
|
2164
|
+
if lint_errors:
|
|
2165
|
+
ok = await self.io.confirm_ask("Attempt to fix lint errors?")
|
|
2166
|
+
if ok:
|
|
2167
|
+
self.reflected_message = lint_errors
|
|
2168
|
+
return
|
|
2169
|
+
|
|
2170
|
+
shared_output = await self.run_shell_commands()
|
|
2171
|
+
if shared_output:
|
|
2172
|
+
self.cur_messages += [
|
|
2173
|
+
dict(role="user", content=shared_output),
|
|
2174
|
+
dict(role="assistant", content="Ok"),
|
|
2175
|
+
]
|
|
2176
|
+
|
|
2177
|
+
if edited and self.auto_test:
|
|
2178
|
+
test_errors = await self.commands.cmd_test(self.test_cmd)
|
|
2179
|
+
self.test_outcome = not test_errors
|
|
2180
|
+
if test_errors:
|
|
2181
|
+
ok = await self.io.confirm_ask("Attempt to fix test errors?")
|
|
2182
|
+
if ok:
|
|
2183
|
+
self.reflected_message = test_errors
|
|
2184
|
+
return
|
|
2185
|
+
|
|
2186
|
+
async def process_tool_calls(self, tool_call_response):
|
|
2187
|
+
if tool_call_response is None:
|
|
2188
|
+
return False
|
|
2189
|
+
|
|
2190
|
+
# Handle different response structures
|
|
2191
|
+
try:
|
|
2192
|
+
# Try to get tool calls from the standard OpenAI response format
|
|
2193
|
+
if hasattr(tool_call_response, "choices") and tool_call_response.choices:
|
|
2194
|
+
message = tool_call_response.choices[0].message
|
|
2195
|
+
if hasattr(message, "tool_calls") and message.tool_calls:
|
|
2196
|
+
original_tool_calls = message.tool_calls
|
|
2197
|
+
else:
|
|
2198
|
+
return False
|
|
2199
|
+
else:
|
|
2200
|
+
# Handle other response formats
|
|
2201
|
+
return False
|
|
2202
|
+
except (AttributeError, IndexError):
|
|
2203
|
+
return False
|
|
2204
|
+
|
|
2205
|
+
if not original_tool_calls:
|
|
2206
|
+
return False
+
+        # Expand any tool calls that have concatenated JSON in their arguments.
+        # This is necessary because some models (like Gemini) will serialize
+        # multiple tool calls in this way.
+        expanded_tool_calls = []
+        for tool_call in original_tool_calls:
+            args_string = tool_call.function.arguments.strip()
+
+            # If there are no arguments, or it's not a string that looks like it could
+            # be concatenated JSON, just add it and continue.
+            if not args_string or not (args_string.startswith("{") or args_string.startswith("[")):
+                expanded_tool_calls.append(tool_call)
+                continue
+
+            json_chunks = utils.split_concatenated_json(args_string)
+
+            # If it's just a single JSON object, there's nothing to expand.
+            if len(json_chunks) <= 1:
+                expanded_tool_calls.append(tool_call)
+                continue
+
+            # We have concatenated JSON, so expand it into multiple tool calls.
+            for i, chunk in enumerate(json_chunks):
+                if not chunk.strip():
+                    continue
+
+                # Create a new tool call for each JSON chunk, with a unique ID.
+                new_function = tool_call.function.model_copy(update={"arguments": chunk})
+                new_tool_call = tool_call.model_copy(
+                    update={"id": f"{tool_call.id}-{i}", "function": new_function}
+                )
+                expanded_tool_calls.append(new_tool_call)
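
A sketch with hypothetical values of what this expansion does to one tool call whose arguments arrive as two concatenated JSON objects:

args_string = '{"path": "a.py"}{"path": "b.py"}'
# utils.split_concatenated_json(args_string)
#   -> ['{"path": "a.py"}', '{"path": "b.py"}']
# A call with id "call_123" then becomes two calls, "call_123-0" and "call_123-1",
# each carrying one of the argument objects.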
|
|
2239
|
+
|
|
2240
|
+
# Collect all tool calls grouped by server
|
|
2241
|
+
server_tool_calls = self._gather_server_tool_calls(expanded_tool_calls)
|
|
2242
|
+
|
|
2243
|
+
if server_tool_calls and self.num_tool_calls < self.max_tool_calls:
|
|
2244
|
+
self._print_tool_call_info(server_tool_calls)
|
|
2245
|
+
|
|
2246
|
+
if await self.io.confirm_ask("Run tools?", group_response="Run MCP Tools"):
|
|
2247
|
+
await self.io.recreate_input()
|
|
2248
|
+
tool_responses = await self._execute_tool_calls(server_tool_calls)
|
|
2249
|
+
|
|
2250
|
+
# Add all tool responses
|
|
2251
|
+
for tool_response in tool_responses:
|
|
2252
|
+
self.cur_messages.append(tool_response)
|
|
2253
|
+
|
|
2254
|
+
return True
|
|
2255
|
+
elif self.num_tool_calls >= self.max_tool_calls:
|
|
2256
|
+
self.io.tool_warning(f"Only {self.max_tool_calls} tool calls allowed, stopping.")
|
|
2257
|
+
|
|
2258
|
+
return False
|
|
2259
|
+
|
|
2260
|
+
def _print_tool_call_info(self, server_tool_calls):
|
|
2261
|
+
"""Print information about an MCP tool call."""
|
|
2262
|
+
# self.io.tool_output("Preparing to run MCP tools", bold=False)
|
|
2263
|
+
|
|
2264
|
+
for server, tool_calls in server_tool_calls.items():
|
|
2265
|
+
for tool_call in tool_calls:
|
|
2266
|
+
color_start = "[blue]" if self.pretty else ""
|
|
2267
|
+
color_end = "[/blue]" if self.pretty else ""
|
|
2268
|
+
|
|
2269
|
+
self.io.tool_output(
|
|
2270
|
+
f"{color_start}Tool Call:{color_end} {server.name} • {tool_call.function.name}"
|
|
2271
|
+
)
|
|
2272
|
+
# Parse and format arguments as headers with values
|
|
2273
|
+
if tool_call.function.arguments:
|
|
2274
|
+
# Only do JSON unwrapping for tools containing "replace" in their name
|
|
2275
|
+
if (
|
|
2276
|
+
"replace" in tool_call.function.name.lower()
|
|
2277
|
+
or "insert" in tool_call.function.name.lower()
|
|
2278
|
+
or "update" in tool_call.function.name.lower()
|
|
2279
|
+
):
|
|
2280
|
+
try:
|
|
2281
|
+
args_dict = json.loads(tool_call.function.arguments)
|
|
2282
|
+
first_key = True
|
|
2283
|
+
for key, value in args_dict.items():
|
|
2284
|
+
# Convert explicit \\n sequences to actual newlines using regex
|
|
2285
|
+
# Only match \\n that is not preceded by any other backslashes
|
|
2286
|
+
if isinstance(value, str):
|
|
2287
|
+
value = re.sub(r"(?<!\\)\\n", "\n", value)
|
|
2288
|
+
# Add extra newline before first key/header
|
|
2289
|
+
if first_key:
|
|
2290
|
+
self.io.tool_output("\n")
|
|
2291
|
+
first_key = False
|
|
2292
|
+
self.io.tool_output(f"{color_start}{key}:{color_end}")
|
|
2293
|
+
# Split the value by newlines and output each line separately
|
|
2294
|
+
if isinstance(value, str):
|
|
2295
|
+
for line in value.split("\n"):
|
|
2296
|
+
self.io.tool_output(f"{line}")
|
|
2297
|
+
else:
|
|
2298
|
+
self.io.tool_output(f"{str(value)}")
|
|
2299
|
+
self.io.tool_output("")
|
|
2300
|
+
except json.JSONDecodeError:
|
|
2301
|
+
# If JSON parsing fails, show raw arguments
|
|
2302
|
+
raw_args = tool_call.function.arguments
|
|
2303
|
+
self.io.tool_output(f"{color_start}Arguments:{color_end} {raw_args}")
|
|
2304
|
+
else:
|
|
2305
|
+
# For non-replace tools, show raw arguments
|
|
2306
|
+
raw_args = tool_call.function.arguments
|
|
2307
|
+
self.io.tool_output(f"{color_start}Arguments:{color_end} {raw_args}")
|
|
2308
|
+
|
|
2309
|
+
if self.verbose:
|
|
2310
|
+
self.io.tool_output(f"Tool ID: {tool_call.id}")
|
|
2311
|
+
self.io.tool_output(f"Tool type: {tool_call.type}")
|
|
2312
|
+
|
|
2313
|
+
self.io.tool_output("\n")
+
+    def _gather_server_tool_calls(self, tool_calls):
+        """Collect all tool calls grouped by server.
+        Args:
+            tool_calls: List of tool calls from the LLM response
+
+        Returns:
+            dict: Dictionary mapping servers to their respective tool calls
+        """
+        if not self.mcp_tools or len(self.mcp_tools) == 0:
+            return None
+
+        server_tool_calls = {}
+        tool_id_set = set()
+
+        for tool_call in tool_calls:
+            # LLM APIs sometimes return duplicates and that's annoying part 3
+            if tool_call.get("id") in tool_id_set:
+                continue
+
+            tool_id_set.add(tool_call.get("id"))
+
+            # Check if this tool_call matches any MCP tool
+            for server_name, server_tools in self.mcp_tools:
+                for tool in server_tools:
+                    tool_name_from_schema = tool.get("function", {}).get("name")
+                    if (
+                        tool_name_from_schema
+                        and tool_name_from_schema.lower() == tool_call.function.name.lower()
+                    ):
+                        # Find the McpServer instance that will be used for communication
+                        for server in self.mcp_servers:
+                            if server.name == server_name:
+                                if server not in server_tool_calls:
+                                    server_tool_calls[server] = []
+                                server_tool_calls[server].append(tool_call)
+                                break
+
+        return server_tool_calls
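
For orientation, the returned mapping groups tool calls under the McpServer objects that will execute them; a hypothetical shape:

# server_tool_calls (illustrative, hypothetical ids):
# {
#     <McpServer name="filesystem">: [<ToolCall id="call_1">, <ToolCall id="call_2">],
#     <McpServer name="search">:     [<ToolCall id="call_3">],
# }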
|
|
2353
|
+
|
|
2354
|
+
async def _execute_tool_calls(self, tool_calls):
|
|
2355
|
+
"""Process tool calls from the response and execute them if they match MCP tools.
|
|
2356
|
+
Returns a list of tool response messages."""
|
|
2357
|
+
tool_responses = []
|
|
2358
|
+
|
|
2359
|
+
# Define the coroutine to execute all tool calls for a single server
|
|
2360
|
+
async def _exec_server_tools(server, tool_calls_list):
|
|
2361
|
+
if isinstance(server, LocalServer):
|
|
2362
|
+
if hasattr(self, "_execute_local_tool_calls"):
|
|
2363
|
+
return await self._execute_local_tool_calls(tool_calls_list)
|
|
2364
|
+
else:
|
|
2365
|
+
# This coder doesn't support local tools, return errors for all calls
|
|
2366
|
+
error_responses = []
|
|
2367
|
+
for tool_call in tool_calls_list:
|
|
2368
|
+
error_responses.append(
|
|
2369
|
+
{
|
|
2370
|
+
"role": "tool",
|
|
2371
|
+
"tool_call_id": tool_call.id,
|
|
2372
|
+
"content": (
|
|
2373
|
+
f"Coder does not support local tool: {tool_call.function.name}"
|
|
2374
|
+
),
|
|
2375
|
+
}
|
|
2376
|
+
)
|
|
2377
|
+
return error_responses
|
|
2378
|
+
|
|
2379
|
+
tool_responses = []
|
|
2380
|
+
try:
|
|
2381
|
+
# Connect to the server once
|
|
2382
|
+
session = await server.connect()
|
|
2383
|
+
tool_id_set = set()
|
|
2384
|
+
|
|
2385
|
+
# Execute all tool calls for this server
|
|
2386
|
+
for tool_call in tool_calls_list:
|
|
2387
|
+
# LLM APIs sometimes return duplicates and that's annoying part 4
|
|
2388
|
+
if tool_call.id in tool_id_set:
|
|
2389
|
+
continue
|
|
2390
|
+
|
|
2391
|
+
tool_id_set.add(tool_call.id)
|
|
2392
|
+
|
|
2393
|
+
try:
|
|
2394
|
+
# Arguments can be a stream of JSON objects.
|
|
2395
|
+
# We need to parse them and run a tool call for each.
|
|
2396
|
+
args_string = tool_call.function.arguments.strip()
|
|
2397
|
+
parsed_args_list = []
|
|
2398
|
+
if args_string:
|
|
2399
|
+
json_chunks = utils.split_concatenated_json(args_string)
|
|
2400
|
+
for chunk in json_chunks:
|
|
2401
|
+
try:
|
|
2402
|
+
parsed_args_list.append(json.loads(chunk))
|
|
2403
|
+
except json.JSONDecodeError:
|
|
2404
|
+
self.io.tool_warning(
|
|
2405
|
+
"Could not parse JSON chunk for tool"
|
|
2406
|
+
f" {tool_call.function.name}: {chunk}"
|
|
2407
|
+
)
|
|
2408
|
+
continue
|
|
2409
|
+
|
|
2410
|
+
if not parsed_args_list and not args_string:
|
|
2411
|
+
parsed_args_list.append({}) # For tool calls with no arguments
|
|
2412
|
+
|
|
2413
|
+
all_results_content = []
|
|
2414
|
+
for args in parsed_args_list:
|
|
2415
|
+
new_tool_call = tool_call.model_copy(deep=True)
|
|
2416
|
+
new_tool_call.function.arguments = json.dumps(args)
|
|
2417
|
+
|
|
2418
|
+
call_result = await experimental_mcp_client.call_openai_tool(
|
|
2419
|
+
session=session,
|
|
2420
|
+
openai_tool=new_tool_call,
|
|
2421
|
+
)
|
|
2422
|
+
|
|
2423
|
+
content_parts = []
|
|
2424
|
+
if call_result.content:
|
|
2425
|
+
for item in call_result.content:
|
|
2426
|
+
if hasattr(item, "resource"): # EmbeddedResource
|
|
2427
|
+
resource = item.resource
|
|
2428
|
+
if hasattr(resource, "text"): # TextResourceContents
|
|
2429
|
+
content_parts.append(resource.text)
|
|
2430
|
+
elif hasattr(resource, "blob"): # BlobResourceContents
|
|
2431
|
+
try:
|
|
2432
|
+
decoded_blob = base64.b64decode(
|
|
2433
|
+
resource.blob
|
|
2434
|
+
).decode("utf-8")
|
|
2435
|
+
content_parts.append(decoded_blob)
|
|
2436
|
+
except (UnicodeDecodeError, TypeError):
|
|
2437
|
+
# Handle non-text blobs gracefully
|
|
2438
|
+
name = getattr(resource, "name", "unnamed")
|
|
2439
|
+
mime_type = getattr(
|
|
2440
|
+
resource, "mimeType", "unknown mime type"
|
|
2441
|
+
)
|
|
2442
|
+
content_parts.append(
|
|
2443
|
+
"[embedded binary resource:"
|
|
2444
|
+
f" {name} ({mime_type})]"
|
|
2445
|
+
)
|
|
2446
|
+
elif hasattr(item, "text"): # TextContent
|
|
2447
|
+
content_parts.append(item.text)
|
|
2448
|
+
|
|
2449
|
+
result_text = "".join(content_parts)
|
|
2450
|
+
all_results_content.append(result_text)
|
|
2451
|
+
|
|
2452
|
+
tool_responses.append(
|
|
2453
|
+
{
|
|
2454
|
+
"role": "tool",
|
|
2455
|
+
"tool_call_id": tool_call.id,
|
|
2456
|
+
"content": "\n\n".join(all_results_content),
|
|
2457
|
+
}
|
|
2458
|
+
)
|
|
2459
|
+
|
|
2460
|
+
except Exception as e:
|
|
2461
|
+
tool_error = f"Error executing tool call {tool_call.function.name}: \n{e}"
|
|
2462
|
+
self.io.tool_warning(
|
|
2463
|
+
f"Executing {tool_call.function.name} on {server.name} failed: \n "
|
|
2464
|
+
f" Error: {e}\n"
|
|
2465
|
+
)
|
|
2466
|
+
tool_responses.append(
|
|
2467
|
+
{"role": "tool", "tool_call_id": tool_call.id, "content": tool_error}
|
|
2468
|
+
)
|
|
2469
|
+
except httpx.RemoteProtocolError as e:
|
|
2470
|
+
connection_error = f"Server {server.name} disconnected unexpectedly: {e}"
|
|
2471
|
+
self.io.tool_warning(connection_error)
|
|
2472
|
+
for tool_call in tool_calls_list:
|
|
2473
|
+
tool_responses.append(
|
|
2474
|
+
{"role": "tool", "tool_call_id": tool_call.id, "content": connection_error}
|
|
2475
|
+
)
|
|
2476
|
+
except Exception as e:
|
|
2477
|
+
connection_error = f"Could not connect to server {server.name}\n{e}"
|
|
2478
|
+
self.io.tool_warning(connection_error)
|
|
2479
|
+
for tool_call in tool_calls_list:
|
|
2480
|
+
tool_responses.append(
|
|
2481
|
+
{"role": "tool", "tool_call_id": tool_call.id, "content": connection_error}
|
|
2482
|
+
)
|
|
2483
|
+
|
|
2484
|
+
return tool_responses
|
|
2485
|
+
|
|
2486
|
+
# Execute all tool calls concurrently
|
|
2487
|
+
async def _execute_all_tool_calls():
|
|
2488
|
+
tasks = []
|
|
2489
|
+
for server, tool_calls_list in tool_calls.items():
|
|
2490
|
+
tasks.append(_exec_server_tools(server, tool_calls_list))
|
|
2491
|
+
# Wait for all tasks to complete
|
|
2492
|
+
results = await asyncio.gather(*tasks)
|
|
2493
|
+
return results
|
|
2494
|
+
|
|
2495
|
+
# Run the async execution and collect results
|
|
2496
|
+
if tool_calls:
|
|
2497
|
+
all_results = []
|
|
2498
|
+
max_retries = 3
|
|
2499
|
+
for i in range(max_retries):
|
|
2500
|
+
try:
|
|
2501
|
+
all_results = await _execute_all_tool_calls()
|
|
2502
|
+
break
|
|
2503
|
+
except asyncio.exceptions.CancelledError:
|
|
2504
|
+
if i < max_retries - 1:
|
|
2505
|
+
await asyncio.sleep(0.1) # Brief pause before retrying
|
|
2506
|
+
else:
|
|
2507
|
+
self.io.tool_warning(
|
|
2508
|
+
"MCP tool execution failed after multiple retries due to cancellation."
|
|
2509
|
+
)
|
|
2510
|
+
all_results = []
|
|
2511
|
+
|
|
2512
|
+
# Flatten the results from all servers
|
|
2513
|
+
for server_results in all_results:
|
|
2514
|
+
tool_responses.extend(server_results)
|
|
2515
|
+
|
|
2516
|
+
return tool_responses
|
|
2517
|
+
|
|
2518
|
+
async def initialize_mcp_tools(self):
|
|
2519
|
+
"""
|
|
2520
|
+
Initialize tools from all configured MCP servers. MCP Servers that fail to be
|
|
2521
|
+
initialized will not be available to the Coder instance.
|
|
2522
|
+
"""
|
|
2523
|
+
tools = []
|
|
2524
|
+
|
|
2525
|
+
async def get_server_tools(server):
|
|
2526
|
+
try:
|
|
2527
|
+
session = await server.connect()
|
|
2528
|
+
server_tools = await experimental_mcp_client.load_mcp_tools(
|
|
2529
|
+
session=session, format="openai"
|
|
2530
|
+
)
|
|
2531
|
+
return (server.name, server_tools)
|
|
2532
|
+
except Exception as e:
|
|
2533
|
+
if server.name != "unnamed-server" and server.name != "local_tools":
|
|
2534
|
+
self.io.tool_warning(f"Error initializing MCP server {server.name}: {e}")
|
|
2535
|
+
return None
|
|
2536
|
+
|
|
2537
|
+
async def get_all_server_tools():
|
|
2538
|
+
tasks = [get_server_tools(server) for server in self.mcp_servers]
|
|
2539
|
+
results = await asyncio.gather(*tasks)
|
|
2540
|
+
return [result for result in results if result is not None]
|
|
2541
|
+
|
|
2542
|
+
if self.mcp_servers:
|
|
2543
|
+
# Retry initialization in case of CancelledError
|
|
2544
|
+
max_retries = 3
|
|
2545
|
+
for i in range(max_retries):
|
|
2546
|
+
try:
|
|
2547
|
+
tools = await get_all_server_tools()
|
|
2548
|
+
break
|
|
2549
|
+
except asyncio.exceptions.CancelledError:
|
|
2550
|
+
if i < max_retries - 1:
|
|
2551
|
+
await asyncio.sleep(0.1) # Brief pause before retrying
|
|
2552
|
+
else:
|
|
2553
|
+
self.io.tool_warning(
|
|
2554
|
+
"MCP tool initialization failed after multiple retries due to"
|
|
2555
|
+
" cancellation."
|
|
2556
|
+
)
|
|
2557
|
+
tools = []
|
|
2558
|
+
|
|
2559
|
+
if len(tools) > 0:
|
|
2560
|
+
if self.verbose:
|
|
2561
|
+
self.io.tool_output("MCP servers configured:")
|
|
2562
|
+
|
|
2563
|
+
for server_name, server_tools in tools:
|
|
2564
|
+
self.io.tool_output(f" - {server_name}")
|
|
2565
|
+
|
|
2566
|
+
for tool in server_tools:
|
|
2567
|
+
tool_name = tool.get("function", {}).get("name", "unknown")
|
|
2568
|
+
tool_desc = tool.get("function", {}).get("description", "").split("\n")[0]
|
|
2569
|
+
self.io.tool_output(f" - {tool_name}: {tool_desc}")
|
|
2570
|
+
|
|
2571
|
+
self.mcp_tools = tools
|
|
2572
|
+
|
|
2573
|
+
def get_tool_list(self):
|
|
2574
|
+
"""Get a flattened list of all MCP tools."""
|
|
2575
|
+
tool_list = []
|
|
2576
|
+
if self.mcp_tools:
|
|
2577
|
+
for _, server_tools in self.mcp_tools:
|
|
2578
|
+
tool_list.extend(server_tools)
|
|
2579
|
+
return tool_list
|
|
2580
|
+
|
|
2581
|
+
async def reply_completed(self):
|
|
2582
|
+
pass
|
|
2583
|
+
|
|
2584
|
+
async def show_exhausted_error(self):
    output_tokens = 0
    if self.partial_response_content:
        output_tokens = self.main_model.token_count(self.partial_response_content)
    max_output_tokens = self.main_model.info.get("max_output_tokens") or 0

    input_tokens = self.main_model.token_count(self.format_messages().all_messages())
    max_input_tokens = self.main_model.info.get("max_input_tokens") or 0

    total_tokens = input_tokens + output_tokens

    fudge = 0.7

    out_err = ""
    if output_tokens >= max_output_tokens * fudge:
        out_err = " -- possibly exceeded output limit!"

    inp_err = ""
    if input_tokens >= max_input_tokens * fudge:
        inp_err = " -- possibly exhausted context window!"

    tot_err = ""
    if total_tokens >= max_input_tokens * fudge:
        tot_err = " -- possibly exhausted context window!"

    res = ["", ""]
    res.append(f"Model {self.main_model.name} has hit a token limit!")
    res.append("Token counts below are approximate.")
    res.append("")
    res.append(f"Input tokens: ~{input_tokens:,} of {max_input_tokens:,}{inp_err}")
    res.append(f"Output tokens: ~{output_tokens:,} of {max_output_tokens:,}{out_err}")
    res.append(f"Total tokens: ~{total_tokens:,} of {max_input_tokens:,}{tot_err}")

    if output_tokens >= max_output_tokens:
        res.append("")
        res.append("To reduce output tokens:")
        res.append("- Ask for smaller changes in each request.")
        res.append("- Break your code into smaller source files.")
        if "diff" not in self.main_model.edit_format:
            res.append("- Use a stronger model that can return diffs.")

    if input_tokens >= max_input_tokens or total_tokens >= max_input_tokens:
        res.append("")
        res.append("To reduce input tokens:")
        res.append("- Use /tokens to see token usage.")
        res.append("- Use /drop to remove unneeded files from the chat session.")
        res.append("- Use /clear to clear the chat history.")
        res.append("- Break your code into smaller source files.")

    res = "".join([line + "\n" for line in res])
    self.io.tool_error(res)
    await self.io.offer_url(urls.token_limits)

def lint_edited(self, fnames):
    res = ""
    for fname in fnames:
        if not fname:
            continue
        errors = self.linter.lint(self.abs_root_path(fname))

        if errors:
            res += "\n"
            res += errors
            res += "\n"

    if res:
        self.io.tool_warning(res)

    return res

def __del__(self):
    """Cleanup when the Coder object is destroyed."""
    self.ok_to_warm_cache = False

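# --- Editorial sketch (illustration only, not part of the aider-ce source) ---
# show_exhausted_error() flags a limit as "possibly" hit once usage crosses 70%
# (fudge = 0.7) of the advertised window, since local token counts are only
# approximate. The heuristic restated on its own:
def probably_hit_limit(used_tokens, max_tokens, fudge=0.7):
    # True once usage crosses the fudge fraction of the advertised limit.
    return used_tokens >= max_tokens * fudge

# probably_hit_limit(7_200, 10_000) -> True; probably_hit_limit(5_000, 10_000) -> False
# ------------------------------------------------------------------------------
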
def add_assistant_reply_to_cur_messages(self):
    """
    Add the assistant's reply to `cur_messages`.
    Handles model-specific quirks, like Deepseek which requires `content`
    to be `None` when `tool_calls` are present.
    """
    msg = dict(role="assistant")
    has_tool_calls = self.partial_response_tool_calls or self.partial_response_function_call

    # If we have tool calls and we're using a Deepseek model, force content to be None.
    if has_tool_calls and self.main_model.is_deepseek():
        msg["content"] = None
    else:
        # Otherwise, use logic similar to the base implementation.
        content = self.partial_response_content
        if content:
            msg["content"] = content
        elif has_tool_calls:
            msg["content"] = None

    if self.partial_response_tool_calls:
        msg["tool_calls"] = self.partial_response_tool_calls
    elif self.partial_response_function_call:
        msg["function_call"] = self.partial_response_function_call

    # Only add a message if it's not empty.
    if msg.get("content") is not None or msg.get("tool_calls") or msg.get("function_call"):
        self.cur_messages.append(msg)

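# --- Editorial sketch (illustration only, not part of the aider-ce source) ---
# The two message shapes add_assistant_reply_to_cur_messages() can append, as
# literal dicts; the tool-call payload below is made up for illustration.
plain_reply = {"role": "assistant", "content": "Here is the change you asked for."}

tool_call_reply = {
    "role": "assistant",
    "content": None,  # None when the reply carries only tool calls (always for Deepseek)
    "tool_calls": [
        {
            "id": "call_0",
            "type": "function",
            "function": {"name": "example_tool", "arguments": "{}"},
        }
    ],
}
# ------------------------------------------------------------------------------
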
def get_file_mentions(self, content, ignore_current=False):
    words = set(word for word in content.split())

    # drop sentence punctuation from the end
    words = set(word.rstrip(",.!;:?") for word in words)

    # strip away all kinds of quotes
    quotes = "\"'`*_"
    words = set(word.strip(quotes) for word in words)

    if ignore_current:
        addable_rel_fnames = self.get_all_relative_files()
        existing_basenames = {}
    else:
        addable_rel_fnames = self.get_addable_relative_files()

        # Get basenames of files already in chat or read-only
        existing_basenames = {os.path.basename(f) for f in self.get_inchat_relative_files()} | {
            os.path.basename(self.get_rel_fname(f))
            for f in self.abs_read_only_fnames | self.abs_read_only_stubs_fnames
        }

    mentioned_rel_fnames = set()
    fname_to_rel_fnames = {}
    for rel_fname in addable_rel_fnames:
        normalized_rel_fname = rel_fname.replace("\\", "/")
        normalized_words = set(word.replace("\\", "/") for word in words)
        if normalized_rel_fname in normalized_words:
            mentioned_rel_fnames.add(rel_fname)

        fname = os.path.basename(rel_fname)

        # Don't add basenames that could be plain words like "run" or "make"
        if "/" in fname or "\\" in fname or "." in fname or "_" in fname or "-" in fname:
            if fname not in fname_to_rel_fnames:
                fname_to_rel_fnames[fname] = []
            fname_to_rel_fnames[fname].append(rel_fname)

    for fname, rel_fnames in fname_to_rel_fnames.items():
        # If the basename is already in chat, don't add based on a basename mention
        if fname in existing_basenames:
            continue
        # If the basename mention is unique among addable files and present in the text
        if len(rel_fnames) == 1 and fname in words:
            mentioned_rel_fnames.add(rel_fnames[0])

    return mentioned_rel_fnames

async def check_for_file_mentions(self, content):
    mentioned_rel_fnames = self.get_file_mentions(content)

    new_mentions = mentioned_rel_fnames - self.ignore_mentions

    if not new_mentions:
        return

    added_fnames = []
    group = ConfirmGroup(new_mentions)
    for rel_fname in sorted(new_mentions):
        if await self.io.confirm_ask(
            "Add file to the chat?", subject=rel_fname, group=group, allow_never=True
        ):
            await self.io.recreate_input()
            self.add_rel_fname(rel_fname)
            added_fnames.append(rel_fname)
        else:
            self.ignore_mentions.add(rel_fname)

    if added_fnames:
        return prompts.added_files.format(fnames=", ".join(added_fnames))

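# --- Editorial sketch (illustration only, not part of the aider-ce source) ---
# get_file_mentions() strips punctuation and quotes, then only trusts bare
# basenames that look like filenames and map to exactly one addable path. The
# basename rule restated standalone (sample paths are made up):
def unique_filename_mentions(words, rel_fnames):
    by_base = {}
    for rel in rel_fnames:
        base = rel.rsplit("/", 1)[-1]
        # Skip basenames that could be plain words like "run" or "make".
        if any(ch in base for ch in "./\\_-"):
            by_base.setdefault(base, []).append(rel)
    return {rels[0] for base, rels in by_base.items() if len(rels) == 1 and base in words}

# unique_filename_mentions({"please", "edit", "io.py"}, ["aider/io.py", "docs/faq.md"])
#   -> {"aider/io.py"}
# ------------------------------------------------------------------------------
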
async def send(self, messages, model=None, functions=None, tools=None):
    self.got_reasoning_content = False
    self.ended_reasoning_content = False

    self._streaming_buffer_length = 0
    self.io.reset_streaming_response()

    if not model:
        model = self.main_model

    self.partial_response_content = ""
    self.partial_response_function_call = dict()
    self.partial_response_tool_calls = []

    completion = None

    try:
        hash_object, completion = await model.send_completion(
            messages,
            functions,
            self.stream,
            self.temperature,
            # This could include any tools, but for now it is just MCP tools
            tools=tools,
        )
        self.chat_completion_call_hashes.append(hash_object.hexdigest())

        if not isinstance(completion, ModelResponse):
            async for chunk in self.show_send_output_stream(completion):
                yield chunk
        else:
            self.show_send_output(completion)

        # Calculate costs for successful responses
        self.calculate_and_show_tokens_and_cost(messages, completion)

    except LiteLLMExceptions().exceptions_tuple() as err:
        ex_info = LiteLLMExceptions().get_ex_info(err)
        if ex_info.name == "ContextWindowExceededError":
            # Still calculate costs for context window errors
            self.calculate_and_show_tokens_and_cost(messages, completion)
        raise

    except KeyboardInterrupt as kbi:
        self.keyboard_interrupt()
        raise kbi
    finally:
        self.preprocess_response()

        if self.partial_response_content:
            self.io.ai_output(self.partial_response_content)
        elif self.partial_response_function_call:
            # TODO: push this into subclasses
            args = self.parse_partial_args()
            if args:
                self.io.ai_output(json.dumps(args, indent=4))

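# --- Editorial sketch (illustration only, not part of the aider-ce source) ---
# send() is an async generator: streaming responses yield text chunks, while
# non-streaming responses are rendered by show_send_output() and yield nothing.
# A minimal consumer, assuming `coder` and `messages` already exist:
async def collect_reply(coder, messages):
    chunks = []
    async for chunk in coder.send(messages, tools=coder.get_tool_list()):
        chunks.append(chunk)
    # For non-streaming models the loop body never runs; the reply is in
    # coder.partial_response_content either way.
    return coder.partial_response_content or "".join(chunks)
# ------------------------------------------------------------------------------
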
def show_send_output(self, completion):
    if self.verbose:
        print(completion)

    if not isinstance(completion, ModelResponse):
        self.io.tool_error(str(completion))
        return

    if not completion.choices:
        self.io.tool_error(str(completion))
        return

    show_func_err = None
    show_content_err = None
    try:
        if completion.choices[0].message.tool_calls:
            self.partial_response_function_call = (
                completion.choices[0].message.tool_calls[0].function
            )
    except AttributeError as func_err:
        show_func_err = func_err

    try:
        reasoning_content = completion.choices[0].message.reasoning_content
    except AttributeError:
        try:
            reasoning_content = completion.choices[0].message.reasoning
        except AttributeError:
            reasoning_content = None

    try:
        self.partial_response_content = completion.choices[0].message.content or ""
    except AttributeError as content_err:
        show_content_err = content_err

    resp_hash = dict(
        function_call=str(self.partial_response_function_call),
        content=self.partial_response_content,
    )
    resp_hash = hashlib.sha1(json.dumps(resp_hash, sort_keys=True).encode())
    self.chat_completion_response_hashes.append(resp_hash.hexdigest())

    if show_func_err and show_content_err:
        self.io.tool_error(show_func_err)
        self.io.tool_error(show_content_err)
        raise Exception("No data found in LLM response!")

    show_resp = self.render_incremental_response(True)

    if reasoning_content:
        formatted_reasoning = format_reasoning_content(
            reasoning_content, self.reasoning_tag_name
        )
        show_resp = formatted_reasoning + show_resp

    show_resp = replace_reasoning_tags(show_resp, self.reasoning_tag_name)

    self.io.assistant_output(show_resp, pretty=self.show_pretty())

    if (
        hasattr(completion.choices[0], "finish_reason")
        and completion.choices[0].finish_reason == "length"
    ):
        raise FinishReasonLength()

async def show_send_output_stream(self, completion):
|
|
2880
|
+
received_content = False
|
|
2881
|
+
|
|
2882
|
+
async for chunk in completion:
|
|
2883
|
+
# Check if confirmation is in progress and wait if needed
|
|
2884
|
+
while self.io.confirmation_in_progress:
|
|
2885
|
+
await asyncio.sleep(0.1) # Yield control and wait briefly
|
|
2886
|
+
|
|
2887
|
+
if isinstance(chunk, str):
|
|
2888
|
+
text = chunk
|
|
2889
|
+
received_content = True
|
|
2890
|
+
else:
|
|
2891
|
+
if len(chunk.choices) == 0:
|
|
2892
|
+
continue
|
|
2893
|
+
|
|
2894
|
+
if (
|
|
2895
|
+
hasattr(chunk.choices[0], "finish_reason")
|
|
2896
|
+
and chunk.choices[0].finish_reason == "length"
|
|
2897
|
+
):
|
|
2898
|
+
raise FinishReasonLength()
|
|
2899
|
+
|
|
2900
|
+
try:
|
|
2901
|
+
if chunk.choices[0].delta.tool_calls:
|
|
2902
|
+
received_content = True
|
|
2903
|
+
for tool_call_chunk in chunk.choices[0].delta.tool_calls:
|
|
2904
|
+
self.tool_reflection = True
|
|
2905
|
+
|
|
2906
|
+
index = tool_call_chunk.index
|
|
2907
|
+
if len(self.partial_response_tool_calls) <= index:
|
|
2908
|
+
self.partial_response_tool_calls.extend(
|
|
2909
|
+
[{}] * (index - len(self.partial_response_tool_calls) + 1)
|
|
2910
|
+
)
|
|
2911
|
+
|
|
2912
|
+
if tool_call_chunk.id:
|
|
2913
|
+
self.partial_response_tool_calls[index]["id"] = tool_call_chunk.id
|
|
2914
|
+
if tool_call_chunk.type:
|
|
2915
|
+
self.partial_response_tool_calls[index][
|
|
2916
|
+
"type"
|
|
2917
|
+
] = tool_call_chunk.type
|
|
2918
|
+
if tool_call_chunk.function:
|
|
2919
|
+
if "function" not in self.partial_response_tool_calls[index]:
|
|
2920
|
+
self.partial_response_tool_calls[index]["function"] = {}
|
|
2921
|
+
if tool_call_chunk.function.name:
|
|
2922
|
+
if (
|
|
2923
|
+
"name"
|
|
2924
|
+
not in self.partial_response_tool_calls[index]["function"]
|
|
2925
|
+
):
|
|
2926
|
+
self.partial_response_tool_calls[index]["function"][
|
|
2927
|
+
"name"
|
|
2928
|
+
] = ""
|
|
2929
|
+
self.partial_response_tool_calls[index]["function"][
|
|
2930
|
+
"name"
|
|
2931
|
+
] += tool_call_chunk.function.name
|
|
2932
|
+
if tool_call_chunk.function.arguments:
|
|
2933
|
+
if (
|
|
2934
|
+
"arguments"
|
|
2935
|
+
not in self.partial_response_tool_calls[index]["function"]
|
|
2936
|
+
):
|
|
2937
|
+
self.partial_response_tool_calls[index]["function"][
|
|
2938
|
+
"arguments"
|
|
2939
|
+
] = ""
|
|
2940
|
+
self.partial_response_tool_calls[index]["function"][
|
|
2941
|
+
"arguments"
|
|
2942
|
+
] += tool_call_chunk.function.arguments
|
|
2943
|
+
except (AttributeError, IndexError):
|
|
2944
|
+
# Handle cases where the response structure doesn't match expectations
|
|
2945
|
+
pass
|
|
2946
|
+
|
|
2947
|
+
try:
|
|
2948
|
+
func = chunk.choices[0].delta.function_call
|
|
2949
|
+
# dump(func)
|
|
2950
|
+
for k, v in func.items():
|
|
2951
|
+
self.tool_reflection = True
|
|
2952
|
+
|
|
2953
|
+
if k in self.partial_response_function_call:
|
|
2954
|
+
self.partial_response_function_call[k] += v
|
|
2955
|
+
else:
|
|
2956
|
+
self.partial_response_function_call[k] = v
|
|
2957
|
+
|
|
2958
|
+
received_content = True
|
|
2959
|
+
except AttributeError:
|
|
2960
|
+
pass
|
|
2961
|
+
|
|
2962
|
+
text = ""
|
|
2963
|
+
|
|
2964
|
+
try:
|
|
2965
|
+
reasoning_content = chunk.choices[0].delta.reasoning_content
|
|
2966
|
+
except AttributeError:
|
|
2967
|
+
try:
|
|
2968
|
+
reasoning_content = chunk.choices[0].delta.reasoning
|
|
2969
|
+
except AttributeError:
|
|
2970
|
+
reasoning_content = None
|
|
2971
|
+
|
|
2972
|
+
if reasoning_content:
|
|
2973
|
+
if not self.got_reasoning_content:
|
|
2974
|
+
text += f"<{REASONING_TAG}>\n\n"
|
|
2975
|
+
text += reasoning_content
|
|
2976
|
+
self.got_reasoning_content = True
|
|
2977
|
+
received_content = True
|
|
2978
|
+
|
|
2979
|
+
try:
|
|
2980
|
+
content = chunk.choices[0].delta.content
|
|
2981
|
+
if content:
|
|
2982
|
+
if self.got_reasoning_content and not self.ended_reasoning_content:
|
|
2983
|
+
text += f"\n\n</{self.reasoning_tag_name}>\n\n"
|
|
2984
|
+
self.ended_reasoning_content = True
|
|
2985
|
+
|
|
2986
|
+
text += content
|
|
2987
|
+
received_content = True
|
|
2988
|
+
except AttributeError:
|
|
2989
|
+
pass
|
|
2990
|
+
|
|
2991
|
+
self.partial_response_content += text
|
|
2992
|
+
if self.show_pretty():
|
|
2993
|
+
# Use simplified streaming - just call the method with full content
|
|
2994
|
+
content_to_show = self.live_incremental_response(False)
|
|
2995
|
+
self.stream_wrapper(content_to_show, final=False)
|
|
2996
|
+
elif text:
|
|
2997
|
+
# Apply reasoning tag formatting for non-pretty output
|
|
2998
|
+
text = replace_reasoning_tags(text, self.reasoning_tag_name)
|
|
2999
|
+
try:
|
|
3000
|
+
self.stream_wrapper(text, final=False)
|
|
3001
|
+
except UnicodeEncodeError:
|
|
3002
|
+
# Safely encode and decode the text
|
|
3003
|
+
safe_text = text.encode(sys.stdout.encoding, errors="backslashreplace").decode(
|
|
3004
|
+
sys.stdout.encoding
|
|
3005
|
+
)
|
|
3006
|
+
self.stream_wrapper(safe_text, final=False)
|
|
3007
|
+
yield text
|
|
3008
|
+
|
|
3009
|
+
if not received_content and len(self.partial_response_tool_calls) == 0:
|
|
3010
|
+
self.io.tool_warning("Empty response received from LLM. Check your provider account?")
|
|
3011
|
+
|
|
3012
|
+
def stream_wrapper(self, content, final):
    if not hasattr(self, "_streaming_buffer_length"):
        self._streaming_buffer_length = 0

    if final:
        content += "\n\n"

    if isinstance(content, str):
        self._streaming_buffer_length += len(content)

    self.io.stream_output(content, final=final)

    if final:
        self._streaming_buffer_length = 0

def live_incremental_response(self, final):
    show_resp = self.render_incremental_response(final)
    # Apply any reasoning tag formatting
    show_resp = replace_reasoning_tags(show_resp, self.reasoning_tag_name)

    # Track streaming state to avoid repetitive output
    if not hasattr(self, "_streaming_buffer_length"):
        self._streaming_buffer_length = 0

    # Only send new content that hasn't been streamed yet
    if len(show_resp) >= self._streaming_buffer_length:
        new_content = show_resp[self._streaming_buffer_length :]
        return new_content

    else:
        self._streaming_buffer_length = 0
        self.io.reset_streaming_response()
        return show_resp

def render_incremental_response(self, final):
    # Just return the current content - the streaming logic will handle incremental updates
    return self.get_multi_response_content_in_progress()

def preprocess_response(self):
    if len(self.partial_response_tool_calls):
        tool_list = []
        tool_id_set = set()

        for tool_call_dict in self.partial_response_tool_calls:
            # LLM APIs sometimes return duplicates and that's annoying part 2
            if tool_call_dict.get("id") in tool_id_set:
                continue

            tool_id_set.add(tool_call_dict.get("id"))
            tool_list.append(tool_call_dict)

        self.partial_response_tool_calls = tool_list

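# --- Editorial sketch (illustration only, not part of the aider-ce source) ---
# live_incremental_response() re-renders the whole reply on every chunk and
# returns only the suffix that has not been streamed yet, with
# _streaming_buffer_length acting as a cursor. The idea in isolation
# (simplified: the real code resets the io stream when content shrinks):
def new_suffix(full_text, already_streamed):
    # Returns (text to emit now, new cursor position).
    if len(full_text) >= already_streamed:
        return full_text[already_streamed:], len(full_text)
    return full_text, len(full_text)

# new_suffix("Hello, wor", 5) -> (", wor", 10)
# ------------------------------------------------------------------------------
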
def remove_reasoning_content(self):
    """Remove reasoning content from the model's response."""

    self.partial_response_content = remove_reasoning_content(
        self.partial_response_content,
        self.reasoning_tag_name,
    )

def calculate_and_show_tokens_and_cost(self, messages, completion=None):
    prompt_tokens = 0
    completion_tokens = 0
    cache_hit_tokens = 0
    cache_write_tokens = 0

    if completion and hasattr(completion, "usage") and completion.usage is not None:
        prompt_tokens = completion.usage.prompt_tokens
        completion_tokens = completion.usage.completion_tokens
        cache_hit_tokens = getattr(completion.usage, "prompt_cache_hit_tokens", 0) or getattr(
            completion.usage, "cache_read_input_tokens", 0
        )
        cache_write_tokens = getattr(completion.usage, "cache_creation_input_tokens", 0)

        if hasattr(completion.usage, "cache_read_input_tokens") or hasattr(
            completion.usage, "cache_creation_input_tokens"
        ):
            self.message_tokens_sent += prompt_tokens
            self.message_tokens_sent += cache_write_tokens
        else:
            self.message_tokens_sent += prompt_tokens

    else:
        prompt_tokens = self.main_model.token_count(messages)
        completion_tokens = self.main_model.token_count(self.partial_response_content)
        self.message_tokens_sent += prompt_tokens

    self.message_tokens_received += completion_tokens

    tokens_report = f"Tokens: {format_tokens(self.message_tokens_sent)} sent"

    if cache_write_tokens:
        tokens_report += f", {format_tokens(cache_write_tokens)} cache write"
    if cache_hit_tokens:
        tokens_report += f", {format_tokens(cache_hit_tokens)} cache hit"
    tokens_report += f", {format_tokens(self.message_tokens_received)} received."

    if not self.main_model.info.get("input_cost_per_token"):
        self.usage_report = tokens_report
        return

    try:
        # Try and use litellm's built in cost calculator. Seems to work for non-streaming only?
        cost = litellm.completion_cost(completion_response=completion)
    except Exception:
        cost = 0

    if not cost:
        cost = self.compute_costs_from_tokens(
            prompt_tokens, completion_tokens, cache_write_tokens, cache_hit_tokens
        )

    self.total_cost += cost
    self.message_cost += cost

    cost_report = (
        f"Cost: ${self.format_cost(self.message_cost)} message,"
        f" ${self.format_cost(self.total_cost)} session."
    )

    if cache_hit_tokens and cache_write_tokens:
        sep = "\n"
    else:
        sep = " "

    self.usage_report = tokens_report + sep + cost_report

def format_cost(self, value):
    if value == 0:
        return "0.00"
    magnitude = abs(value)
    if magnitude >= 0.01:
        return f"{value:.2f}"
    else:
        return f"{value:.{max(2, 2 - int(math.log10(magnitude)))}f}"

def compute_costs_from_tokens(
    self, prompt_tokens, completion_tokens, cache_write_tokens, cache_hit_tokens
):
    cost = 0

    input_cost_per_token = self.main_model.info.get("input_cost_per_token") or 0
    output_cost_per_token = self.main_model.info.get("output_cost_per_token") or 0
    input_cost_per_token_cache_hit = (
        self.main_model.info.get("input_cost_per_token_cache_hit") or 0
    )

    # deepseek
    # prompt_cache_hit_tokens + prompt_cache_miss_tokens
    # == prompt_tokens == total tokens that were sent
    #
    # Anthropic
    # cache_creation_input_tokens + cache_read_input_tokens + prompt
    # == total tokens that were

    if input_cost_per_token_cache_hit:
        # must be deepseek
        cost += input_cost_per_token_cache_hit * cache_hit_tokens
        cost += (prompt_tokens - input_cost_per_token_cache_hit) * input_cost_per_token
    else:
        # hard code the anthropic adjustments, no-ops for other models since cache_x_tokens==0
        cost += cache_write_tokens * input_cost_per_token * 1.25
        cost += cache_hit_tokens * input_cost_per_token * 0.10
        cost += prompt_tokens * input_cost_per_token

    cost += completion_tokens * output_cost_per_token
    return cost

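# --- Editorial note (illustration only, not part of the aider-ce source) -----
# Worked example of the Anthropic-style branch of compute_costs_from_tokens(),
# using made-up prices of $3/M input tokens and $15/M output tokens:
#   input = 0.000003, output = 0.000015
#   cost = 1_000 * input * 1.25   (cache write tokens, 25% surcharge)  = 0.00375
#        + 4_000 * input * 0.10   (cache hit tokens, 90% discount)     = 0.00120
#        + 2_000 * input          (regular prompt tokens)              = 0.00600
#        +   500 * output         (completion tokens)                  = 0.00750
#        = 0.01845  (about $0.018 for the message)
# ------------------------------------------------------------------------------
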
def show_usage_report(self):
|
|
3181
|
+
if not self.usage_report:
|
|
3182
|
+
return
|
|
3183
|
+
|
|
3184
|
+
self.total_tokens_sent += self.message_tokens_sent
|
|
3185
|
+
self.total_tokens_received += self.message_tokens_received
|
|
3186
|
+
|
|
3187
|
+
self.io.tool_output(self.usage_report)
|
|
3188
|
+
self.io.rule()
|
|
3189
|
+
|
|
3190
|
+
prompt_tokens = self.message_tokens_sent
|
|
3191
|
+
completion_tokens = self.message_tokens_received
|
|
3192
|
+
self.event(
|
|
3193
|
+
"message_send",
|
|
3194
|
+
main_model=self.main_model,
|
|
3195
|
+
edit_format=self.edit_format,
|
|
3196
|
+
prompt_tokens=prompt_tokens,
|
|
3197
|
+
completion_tokens=completion_tokens,
|
|
3198
|
+
total_tokens=prompt_tokens + completion_tokens,
|
|
3199
|
+
cost=self.message_cost,
|
|
3200
|
+
total_cost=self.total_cost,
|
|
3201
|
+
)
|
|
3202
|
+
|
|
3203
|
+
self.message_cost = 0.0
|
|
3204
|
+
self.message_tokens_sent = 0
|
|
3205
|
+
self.message_tokens_received = 0
|
|
3206
|
+
|
|
3207
|
+
def get_multi_response_content_in_progress(self, final=False):
|
|
3208
|
+
cur = self.multi_response_content or ""
|
|
3209
|
+
new = self.partial_response_content or ""
|
|
3210
|
+
|
|
3211
|
+
if new.rstrip() != new and not final:
|
|
3212
|
+
new = new.rstrip()
|
|
3213
|
+
|
|
3214
|
+
return cur + new
|
|
3215
|
+
|
|
3216
|
+
def get_file_stub(self, fname):
|
|
3217
|
+
return RepoMap.get_file_stub(fname, self.io)
|
|
3218
|
+
|
|
3219
|
+
def get_rel_fname(self, fname):
|
|
3220
|
+
try:
|
|
3221
|
+
return os.path.relpath(fname, self.root)
|
|
3222
|
+
except ValueError:
|
|
3223
|
+
return fname
|
|
3224
|
+
|
|
3225
|
+
def get_inchat_relative_files(self):
|
|
3226
|
+
files = [self.get_rel_fname(fname) for fname in self.abs_fnames]
|
|
3227
|
+
return sorted(set(files))
|
|
3228
|
+
|
|
3229
|
+
def is_file_safe(self, fname):
|
|
3230
|
+
try:
|
|
3231
|
+
return Path(self.abs_root_path(fname)).is_file()
|
|
3232
|
+
except OSError:
|
|
3233
|
+
return
|
|
3234
|
+
|
|
3235
|
+
def get_all_relative_files(self):
|
|
3236
|
+
staged_files_hash = hash(str([item.a_path for item in self.repo.repo.index.diff("HEAD")]))
|
|
3237
|
+
if (
|
|
3238
|
+
staged_files_hash == self.data_cache["repo"]["last_key"]
|
|
3239
|
+
and self.data_cache["relative_files"]
|
|
3240
|
+
):
|
|
3241
|
+
return self.data_cache["relative_files"]
|
|
3242
|
+
|
|
3243
|
+
if self.repo:
|
|
3244
|
+
files = self.repo.get_tracked_files()
|
|
3245
|
+
else:
|
|
3246
|
+
files = self.get_inchat_relative_files()
|
|
3247
|
+
|
|
3248
|
+
# This is quite slow in large repos
|
|
3249
|
+
# files = [fname for fname in files if self.is_file_safe(fname)]
|
|
3250
|
+
|
|
3251
|
+
self.data_cache["relative_files"] = sorted(set(files))
|
|
3252
|
+
|
|
3253
|
+
return self.data_cache["relative_files"]
|
|
3254
|
+
|
|
3255
|
+
def get_all_abs_files(self):
|
|
3256
|
+
files = self.get_all_relative_files()
|
|
3257
|
+
files = [self.abs_root_path(path) for path in files]
|
|
3258
|
+
return files
|
|
3259
|
+
|
|
3260
|
+
def get_addable_relative_files(self):
|
|
3261
|
+
all_files = set(self.get_all_relative_files())
|
|
3262
|
+
inchat_files = set(self.get_inchat_relative_files())
|
|
3263
|
+
read_only_files = set(self.get_rel_fname(fname) for fname in self.abs_read_only_fnames)
|
|
3264
|
+
stub_files = set(self.get_rel_fname(fname) for fname in self.abs_read_only_stubs_fnames)
|
|
3265
|
+
return all_files - inchat_files - read_only_files - stub_files
|
|
3266
|
+
|
|
3267
|
+
def check_for_dirty_commit(self, path):
|
|
3268
|
+
if not self.repo:
|
|
3269
|
+
return
|
|
3270
|
+
if not self.dirty_commits:
|
|
3271
|
+
return
|
|
3272
|
+
if not self.repo.is_dirty(path):
|
|
3273
|
+
return
|
|
3274
|
+
|
|
3275
|
+
# We need a committed copy of the file in order to /undo, so skip this
|
|
3276
|
+
# fullp = Path(self.abs_root_path(path))
|
|
3277
|
+
# if not fullp.stat().st_size:
|
|
3278
|
+
# return
|
|
3279
|
+
|
|
3280
|
+
self.io.tool_output(f"Committing {path} before applying edits.")
|
|
3281
|
+
self.need_commit_before_edits.add(path)
|
|
3282
|
+
|
|
3283
|
+
async def allowed_to_edit(self, path):
|
|
3284
|
+
full_path = self.abs_root_path(path)
|
|
3285
|
+
if self.repo:
|
|
3286
|
+
need_to_add = not self.repo.path_in_repo(path)
|
|
3287
|
+
else:
|
|
3288
|
+
need_to_add = False
|
|
3289
|
+
|
|
3290
|
+
if full_path in self.abs_fnames:
|
|
3291
|
+
self.check_for_dirty_commit(path)
|
|
3292
|
+
return True
|
|
3293
|
+
|
|
3294
|
+
if self.repo and self.repo.git_ignored_file(path):
|
|
3295
|
+
self.io.tool_warning(f"Skipping edits to {path} that matches gitignore spec.")
|
|
3296
|
+
return
|
|
3297
|
+
|
|
3298
|
+
if not Path(full_path).exists():
|
|
3299
|
+
if not await self.io.confirm_ask("Create new file?", subject=path):
|
|
3300
|
+
self.io.tool_output(f"Skipping edits to {path}")
|
|
3301
|
+
return
|
|
3302
|
+
|
|
3303
|
+
if not self.dry_run:
|
|
3304
|
+
if not utils.touch_file(full_path):
|
|
3305
|
+
self.io.tool_error(f"Unable to create {path}, skipping edits.")
|
|
3306
|
+
return
|
|
3307
|
+
|
|
3308
|
+
# Seems unlikely that we needed to create the file, but it was
|
|
3309
|
+
# actually already part of the repo.
|
|
3310
|
+
# But let's only add if we need to, just to be safe.
|
|
3311
|
+
if need_to_add:
|
|
3312
|
+
self.repo.repo.git.add(full_path)
|
|
3313
|
+
|
|
3314
|
+
self.abs_fnames.add(full_path)
|
|
3315
|
+
self.check_added_files()
|
|
3316
|
+
return True
|
|
3317
|
+
|
|
3318
|
+
if not await self.io.confirm_ask(
|
|
3319
|
+
"Allow edits to file that has not been added to the chat?",
|
|
3320
|
+
subject=path,
|
|
3321
|
+
):
|
|
3322
|
+
self.io.tool_output(f"Skipping edits to {path}")
|
|
3323
|
+
return
|
|
3324
|
+
|
|
3325
|
+
if need_to_add:
|
|
3326
|
+
self.repo.repo.git.add(full_path)
|
|
3327
|
+
|
|
3328
|
+
self.abs_fnames.add(full_path)
|
|
3329
|
+
self.check_added_files()
|
|
3330
|
+
self.check_for_dirty_commit(path)
|
|
3331
|
+
|
|
3332
|
+
return True
|
|
3333
|
+
|
|
3334
|
+
warning_given = False
|
|
3335
|
+
|
|
3336
|
+
def check_added_files(self):
|
|
3337
|
+
if self.warning_given:
|
|
3338
|
+
return
|
|
3339
|
+
|
|
3340
|
+
warn_number_of_files = 4
|
|
3341
|
+
warn_number_of_tokens = 20 * 1024
|
|
3342
|
+
|
|
3343
|
+
num_files = len(self.abs_fnames)
|
|
3344
|
+
if num_files < warn_number_of_files:
|
|
3345
|
+
return
|
|
3346
|
+
|
|
3347
|
+
tokens = 0
|
|
3348
|
+
for fname in self.abs_fnames:
|
|
3349
|
+
if is_image_file(fname):
|
|
3350
|
+
continue
|
|
3351
|
+
content = self.io.read_text(fname)
|
|
3352
|
+
tokens += self.main_model.token_count(content)
|
|
3353
|
+
|
|
3354
|
+
if tokens < warn_number_of_tokens:
|
|
3355
|
+
return
|
|
3356
|
+
|
|
3357
|
+
self.io.tool_warning("Warning: it's best to only add files that need changes to the chat.")
|
|
3358
|
+
self.io.tool_warning(urls.edit_errors)
|
|
3359
|
+
self.warning_given = True
|
|
3360
|
+
|
|
3361
|
+
async def prepare_to_edit(self, edits):
|
|
3362
|
+
res = []
|
|
3363
|
+
seen = dict()
|
|
3364
|
+
|
|
3365
|
+
self.need_commit_before_edits = set()
|
|
3366
|
+
|
|
3367
|
+
for edit in edits:
|
|
3368
|
+
path = edit[0]
|
|
3369
|
+
if path is None:
|
|
3370
|
+
res.append(edit)
|
|
3371
|
+
continue
|
|
3372
|
+
if path == "python":
|
|
3373
|
+
dump(edits)
|
|
3374
|
+
if path in seen:
|
|
3375
|
+
allowed = seen[path]
|
|
3376
|
+
else:
|
|
3377
|
+
allowed = await self.allowed_to_edit(path)
|
|
3378
|
+
seen[path] = allowed
|
|
3379
|
+
|
|
3380
|
+
if allowed:
|
|
3381
|
+
res.append(edit)
|
|
3382
|
+
|
|
3383
|
+
await self.dirty_commit()
|
|
3384
|
+
self.need_commit_before_edits = set()
|
|
3385
|
+
|
|
3386
|
+
return res
|
|
3387
|
+
|
|
3388
|
+
async def apply_updates(self):
|
|
3389
|
+
edited = set()
|
|
3390
|
+
try:
|
|
3391
|
+
edits = self.get_edits()
|
|
3392
|
+
edits = self.apply_edits_dry_run(edits)
|
|
3393
|
+
edits = await self.prepare_to_edit(edits)
|
|
3394
|
+
edited = set(edit[0] for edit in edits)
|
|
3395
|
+
|
|
3396
|
+
self.apply_edits(edits)
|
|
3397
|
+
except ValueError as err:
|
|
3398
|
+
self.num_malformed_responses += 1
|
|
3399
|
+
|
|
3400
|
+
err = err.args[0]
|
|
3401
|
+
|
|
3402
|
+
self.io.tool_error("The LLM did not conform to the edit format.")
|
|
3403
|
+
self.io.tool_output(urls.edit_errors)
|
|
3404
|
+
self.io.tool_output()
|
|
3405
|
+
self.io.tool_output(str(err))
|
|
3406
|
+
|
|
3407
|
+
self.reflected_message = str(err)
|
|
3408
|
+
return edited
|
|
3409
|
+
|
|
3410
|
+
except ANY_GIT_ERROR as err:
|
|
3411
|
+
self.io.tool_error(str(err))
|
|
3412
|
+
return edited
|
|
3413
|
+
except Exception as err:
|
|
3414
|
+
self.io.tool_error("Exception while updating files:")
|
|
3415
|
+
self.io.tool_error(str(err), strip=False)
|
|
3416
|
+
self.io.tool_error(traceback.format_exc())
|
|
3417
|
+
self.reflected_message = str(err)
|
|
3418
|
+
return edited
|
|
3419
|
+
|
|
3420
|
+
for path in edited:
|
|
3421
|
+
if self.dry_run:
|
|
3422
|
+
self.io.tool_output(f"Did not apply edit to {path} (--dry-run)")
|
|
3423
|
+
else:
|
|
3424
|
+
self.io.tool_output(f"Applied edit to {path}")
|
|
3425
|
+
|
|
3426
|
+
return edited
|
|
3427
|
+
|
|
3428
|
+
def parse_partial_args(self):
    # dump(self.partial_response_function_call)

    data = self.partial_response_function_call.get("arguments")
    if not data:
        return

    try:
        return json.loads(data)
    except JSONDecodeError:
        pass

    try:
        return json.loads(data + "]}")
    except JSONDecodeError:
        pass

    try:
        return json.loads(data + "}]}")
    except JSONDecodeError:
        pass

    try:
        return json.loads(data + '"}]}')
    except JSONDecodeError:
        pass

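# --- Editorial sketch (illustration only, not part of the aider-ce source) ---
# parse_partial_args() retries json.loads with progressively longer closing
# suffixes, because a truncated function-call stream has usually just lost its
# trailing quote/brackets. The same repair loop in isolation:
import json

def parse_truncated_json(data, suffixes=("", "]}", "}]}", '"}]}')):
    for suffix in suffixes:
        try:
            return json.loads(data + suffix)
        except json.JSONDecodeError:
            continue
    return None

# parse_truncated_json('{"files": [{"path": "foo.py')  -> {'files': [{'path': 'foo.py'}]}
# parse_truncated_json('not json at all')              -> None
# ------------------------------------------------------------------------------
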
def _find_occurrences(self, content, pattern, near_context=None):
    """Find all occurrences of pattern, optionally filtered by near_context."""
    occurrences = []
    start = 0
    while True:
        index = content.find(pattern, start)
        if index == -1:
            break

        if near_context:
            # Check if near_context is within a window around the match
            window_start = max(0, index - 200)
            window_end = min(len(content), index + len(pattern) + 200)
            window = content[window_start:window_end]
            if near_context in window:
                occurrences.append(index)
        else:
            occurrences.append(index)

        start = index + 1  # Move past this occurrence's start
    return occurrences

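# --- Editorial note (illustration only, not part of the aider-ce source) -----
# _find_occurrences() keeps a match only when near_context appears inside a
# ~200-character window around it; `self` is unused, so the helper can be
# exercised unbound for illustration (assuming the enclosing class is Coder):
#
#   text = "x = compute()\n" + "# padding\n" * 40 + "y = compute()\n"
#   Coder._find_occurrences(None, text, "compute", near_context="y =")
#     -> [418]  # only the second call; the first is ~400 characters from "y ="
# ------------------------------------------------------------------------------
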
# commits...
|
|
3478
|
+
|
|
3479
|
+
def get_context_from_history(self, history):
    context = ""
    if history:
        for msg in history:
            msg_content = msg.get("content") or ""
            context += "\n" + msg["role"].upper() + ": " + msg_content + "\n"

    return context

async def auto_commit(self, edited, context=None):
    if not self.repo or not self.auto_commits or self.dry_run:
        return

    if not context:
        context = self.get_context_from_history(self.cur_messages)

    try:
        res = await self.repo.commit(
            fnames=edited, context=context, aider_edits=True, coder=self
        )
        if res:
            self.show_auto_commit_outcome(res)
            commit_hash, commit_message = res
            return self.gpt_prompts.files_content_gpt_edits.format(
                hash=commit_hash,
                message=commit_message,
            )

        return self.gpt_prompts.files_content_gpt_no_edits
    except ANY_GIT_ERROR as err:
        self.io.tool_error(f"Unable to commit: {str(err)}")
        return

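# --- Editorial note (illustration only, not part of the aider-ce source) -----
# get_context_from_history() flattens the current exchange into the plain-text
# transcript that auto_commit() passes to the repo for commit-message generation:
#
#   get_context_from_history([
#       {"role": "user", "content": "rename foo to bar"},
#       {"role": "assistant", "content": "Done."},
#   ])
#   -> "\nUSER: rename foo to bar\n\nASSISTANT: Done.\n"
# ------------------------------------------------------------------------------
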
def show_auto_commit_outcome(self, res):
|
|
3513
|
+
commit_hash, commit_message = res
|
|
3514
|
+
self.last_aider_commit_hash = commit_hash
|
|
3515
|
+
self.aider_commit_hashes.add(commit_hash)
|
|
3516
|
+
self.last_aider_commit_message = commit_message
|
|
3517
|
+
if self.show_diffs:
|
|
3518
|
+
self.commands.cmd_diff()
|
|
3519
|
+
|
|
3520
|
+
def show_undo_hint(self):
|
|
3521
|
+
if not self.commit_before_message:
|
|
3522
|
+
return
|
|
3523
|
+
if self.commit_before_message[-1] != self.repo.get_head_commit_sha():
|
|
3524
|
+
self.io.tool_output("You can use /undo to undo and discard each aider commit.")
|
|
3525
|
+
|
|
3526
|
+
async def dirty_commit(self):
|
|
3527
|
+
if not self.need_commit_before_edits:
|
|
3528
|
+
return
|
|
3529
|
+
if not self.dirty_commits:
|
|
3530
|
+
return
|
|
3531
|
+
if not self.repo:
|
|
3532
|
+
return
|
|
3533
|
+
|
|
3534
|
+
await self.repo.commit(fnames=self.need_commit_before_edits, coder=self)
|
|
3535
|
+
|
|
3536
|
+
# files changed, move cur messages back behind the files messages
|
|
3537
|
+
# self.move_back_cur_messages(self.gpt_prompts.files_content_local_edits)
|
|
3538
|
+
return True
|
|
3539
|
+
|
|
3540
|
+
def get_edits(self, mode="update"):
|
|
3541
|
+
return []
|
|
3542
|
+
|
|
3543
|
+
def apply_edits(self, edits):
|
|
3544
|
+
return
|
|
3545
|
+
|
|
3546
|
+
def apply_edits_dry_run(self, edits):
|
|
3547
|
+
return edits
|
|
3548
|
+
|
|
3549
|
+
def auto_save_session(self):
|
|
3550
|
+
"""Automatically save the current session as 'auto-save'."""
|
|
3551
|
+
if not getattr(self.args, "auto_save", False):
|
|
3552
|
+
return
|
|
3553
|
+
try:
|
|
3554
|
+
session_manager = SessionManager(self, self.io)
|
|
3555
|
+
session_manager.save_session("auto-save", False)
|
|
3556
|
+
except Exception:
|
|
3557
|
+
# Don't show errors for auto-save to avoid interrupting the user experience
|
|
3558
|
+
pass
|
|
3559
|
+
|
|
3560
|
+
async def run_shell_commands(self):
    if not self.suggest_shell_commands:
        return ""

    done = set()
    group = ConfirmGroup(set(self.shell_commands))
    accumulated_output = ""
    for command in self.shell_commands:
        if command in done:
            continue
        done.add(command)
        output = await self.handle_shell_commands(command, group)
        if output:
            accumulated_output += output + "\n\n"
    return accumulated_output

async def handle_shell_commands(self, commands_str, group):
    commands = commands_str.strip().splitlines()
    command_count = sum(
        1 for cmd in commands if cmd.strip() and not cmd.strip().startswith("#")
    )
    prompt = "Run shell command?" if command_count == 1 else "Run shell commands?"
    if not await self.io.confirm_ask(
        prompt,
        subject="\n".join(commands),
        explicit_yes_required=True,
        group=group,
        allow_never=True,
    ):
        return

    accumulated_output = ""
    for command in commands:
        command = command.strip()
        if not command or command.startswith("#"):
            continue

        self.io.tool_output()
        self.io.tool_output(f"Running {command}")
        # Add the command to input history
        self.io.add_to_input_history(f"/run {command.strip()}")
        exit_status, output = await asyncio.to_thread(
            run_cmd, command, error_print=self.io.tool_error, cwd=self.root
        )
        if output:
            accumulated_output += f"Output from {command}\n{output}\n"

    if accumulated_output.strip() and await self.io.confirm_ask(
        "Add command output to the chat?", allow_never=True
    ):
        num_lines = len(accumulated_output.strip().splitlines())
        line_plural = "line" if num_lines == 1 else "lines"
        self.io.tool_output(f"Added {num_lines} {line_plural} of output to the chat.")
    return accumulated_output
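# --- Editorial sketch (illustration only, not part of the aider-ce source) ---
# handle_shell_commands() pushes the blocking run_cmd() call onto a worker
# thread with asyncio.to_thread so the event loop (and the streaming UI) stays
# responsive. The same pattern in isolation, with subprocess.run standing in
# for run_cmd:
import asyncio
import subprocess

async def run_blocking_command(command, cwd=None):
    # subprocess.run blocks; to_thread keeps it off the event loop.
    proc = await asyncio.to_thread(
        subprocess.run, command, shell=True, cwd=cwd, capture_output=True, text=True
    )
    return proc.returncode, proc.stdout + proc.stderr

# asyncio.run(run_blocking_command("echo hello")) -> (0, "hello\n")
# ------------------------------------------------------------------------------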