auto-coder 0.1.255__py3-none-any.whl → 0.1.257__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of auto-coder might be problematic. Click here for more details.
- {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/METADATA +2 -2
- {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/RECORD +30 -27
- autocoder/auto_coder.py +44 -50
- autocoder/chat_auto_coder.py +16 -17
- autocoder/chat_auto_coder_lang.py +1 -1
- autocoder/common/__init__.py +7 -0
- autocoder/common/auto_coder_lang.py +46 -16
- autocoder/common/code_auto_generate.py +45 -5
- autocoder/common/code_auto_generate_diff.py +45 -7
- autocoder/common/code_auto_generate_editblock.py +48 -4
- autocoder/common/code_auto_generate_strict_diff.py +46 -7
- autocoder/common/code_modification_ranker.py +39 -3
- autocoder/dispacher/actions/action.py +60 -40
- autocoder/dispacher/actions/plugins/action_regex_project.py +12 -6
- autocoder/index/entry.py +6 -4
- autocoder/index/filter/quick_filter.py +175 -65
- autocoder/index/index.py +94 -4
- autocoder/models.py +44 -6
- autocoder/privacy/__init__.py +3 -0
- autocoder/privacy/model_filter.py +100 -0
- autocoder/pyproject/__init__.py +1 -0
- autocoder/suffixproject/__init__.py +1 -0
- autocoder/tsproject/__init__.py +1 -0
- autocoder/utils/llms.py +27 -0
- autocoder/utils/model_provider_selector.py +192 -0
- autocoder/version.py +1 -1
- {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/LICENSE +0 -0
- {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/WHEEL +0 -0
- {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/entry_points.txt +0 -0
- {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/top_level.txt +0 -0
|
@@ -15,6 +15,7 @@ from autocoder.utils.conversation_store import store_code_model_conversation
|
|
|
15
15
|
from autocoder.common.printer import Printer
|
|
16
16
|
import time
|
|
17
17
|
from autocoder.utils.llms import get_llm_names
|
|
18
|
+
from autocoder.common import SourceCodeList
|
|
18
19
|
from loguru import logger
|
|
19
20
|
class ActionRegexProject:
|
|
20
21
|
def __init__(
|
|
@@ -36,20 +37,21 @@ class ActionRegexProject:
|
|
|
36
37
|
pp = RegexProject(args=args, llm=self.llm)
|
|
37
38
|
self.pp = pp
|
|
38
39
|
pp.run()
|
|
39
|
-
|
|
40
|
+
source_code_list = SourceCodeList(pp.sources)
|
|
40
41
|
if self.llm:
|
|
41
42
|
if args.in_code_apply:
|
|
42
43
|
old_query = args.query
|
|
43
44
|
args.query = (args.context or "") + "\n\n" + args.query
|
|
44
|
-
|
|
45
|
+
source_code_list = build_index_and_filter_files(
|
|
45
46
|
llm=self.llm, args=args, sources=pp.sources
|
|
46
47
|
)
|
|
47
48
|
if args.in_code_apply:
|
|
48
49
|
args.query = old_query
|
|
49
|
-
self.process_content(
|
|
50
|
+
self.process_content(source_code_list)
|
|
50
51
|
|
|
51
|
-
def process_content(self,
|
|
52
|
+
def process_content(self, source_code_list: SourceCodeList):
|
|
52
53
|
args = self.args
|
|
54
|
+
content = source_code_list.to_str()
|
|
53
55
|
|
|
54
56
|
if args.execute and self.llm and not args.human_as_model:
|
|
55
57
|
if len(content) > self.args.model_max_input_length:
|
|
@@ -78,21 +80,25 @@ class ActionRegexProject:
|
|
|
78
80
|
generate = CodeAutoGenerate(llm=self.llm, args=self.args, action=self)
|
|
79
81
|
if self.args.enable_multi_round_generate:
|
|
80
82
|
generate_result = generate.multi_round_run(
|
|
81
|
-
query=args.query,
|
|
83
|
+
query=args.query, source_code_list=source_code_list
|
|
82
84
|
)
|
|
83
85
|
else:
|
|
84
86
|
generate_result = generate.single_round_run(
|
|
85
|
-
query=args.query,
|
|
87
|
+
query=args.query, source_code_list=source_code_list
|
|
86
88
|
)
|
|
87
89
|
|
|
88
90
|
elapsed_time = time.time() - start_time
|
|
89
91
|
speed = generate_result.metadata.get('generated_tokens_count', 0) / elapsed_time if elapsed_time > 0 else 0
|
|
90
92
|
model_names = ",".join(get_llm_names(self.llm))
|
|
93
|
+
input_tokens_cost = generate_result.metadata.get('input_tokens_cost', 0)
|
|
94
|
+
generated_tokens_cost = generate_result.metadata.get('generated_tokens_cost', 0)
|
|
91
95
|
self.printer.print_in_terminal(
|
|
92
96
|
"code_generation_complete",
|
|
93
97
|
duration=elapsed_time,
|
|
94
98
|
input_tokens=generate_result.metadata.get('input_tokens_count', 0),
|
|
95
99
|
output_tokens=generate_result.metadata.get('generated_tokens_count', 0),
|
|
100
|
+
input_cost=input_tokens_cost,
|
|
101
|
+
output_cost=generated_tokens_cost,
|
|
96
102
|
speed=round(speed, 2),
|
|
97
103
|
model_names=model_names
|
|
98
104
|
)
|
autocoder/index/entry.py
CHANGED
|
@@ -23,10 +23,11 @@ from autocoder.index.filter.quick_filter import QuickFilter
|
|
|
23
23
|
from autocoder.index.filter.normal_filter import NormalFilter
|
|
24
24
|
from autocoder.index.index import IndexManager
|
|
25
25
|
from loguru import logger
|
|
26
|
+
from autocoder.common import SourceCodeList
|
|
26
27
|
|
|
27
28
|
def build_index_and_filter_files(
|
|
28
29
|
llm, args: AutoCoderArgs, sources: List[SourceCode]
|
|
29
|
-
) ->
|
|
30
|
+
) -> SourceCodeList:
|
|
30
31
|
# Initialize timing and statistics
|
|
31
32
|
total_start_time = time.monotonic()
|
|
32
33
|
stats = {
|
|
@@ -253,7 +254,8 @@ def build_index_and_filter_files(
|
|
|
253
254
|
for file in final_filenames:
|
|
254
255
|
print(f"{file} - {final_files[file].reason}")
|
|
255
256
|
|
|
256
|
-
source_code = ""
|
|
257
|
+
source_code = ""
|
|
258
|
+
source_code_list = SourceCodeList(sources=[])
|
|
257
259
|
depulicated_sources = set()
|
|
258
260
|
|
|
259
261
|
for file in sources:
|
|
@@ -263,7 +265,7 @@ def build_index_and_filter_files(
|
|
|
263
265
|
depulicated_sources.add(file.module_name)
|
|
264
266
|
source_code += f"##File: {file.module_name}\n"
|
|
265
267
|
source_code += f"{file.source_code}\n\n"
|
|
266
|
-
|
|
268
|
+
source_code_list.sources.append(file)
|
|
267
269
|
if args.request_id and not args.skip_events:
|
|
268
270
|
queue_communicate.send_event(
|
|
269
271
|
request_id=args.request_id,
|
|
@@ -339,4 +341,4 @@ def build_index_and_filter_files(
|
|
|
339
341
|
)
|
|
340
342
|
)
|
|
341
343
|
|
|
342
|
-
return
|
|
344
|
+
return source_code_list
|
|
@@ -4,21 +4,21 @@ from autocoder.utils.auto_coder_utils.chat_stream_out import stream_out
|
|
|
4
4
|
from autocoder.common.utils_code_auto_generate import stream_chat_with_continue
|
|
5
5
|
from byzerllm.utils.str2model import to_model
|
|
6
6
|
from autocoder.index.types import IndexItem
|
|
7
|
-
from autocoder.common import AutoCoderArgs,SourceCode
|
|
7
|
+
from autocoder.common import AutoCoderArgs, SourceCode
|
|
8
8
|
import byzerllm
|
|
9
9
|
import time
|
|
10
10
|
from autocoder.index.index import IndexManager
|
|
11
11
|
from autocoder.index.types import (
|
|
12
12
|
IndexItem,
|
|
13
|
-
TargetFile,
|
|
13
|
+
TargetFile,
|
|
14
14
|
FileNumberList
|
|
15
15
|
)
|
|
16
16
|
from autocoder.rag.token_counter import count_tokens
|
|
17
17
|
from autocoder.common.printer import Printer
|
|
18
18
|
from concurrent.futures import ThreadPoolExecutor
|
|
19
|
-
import
|
|
19
|
+
from byzerllm import MetaHolder
|
|
20
20
|
|
|
21
|
-
from autocoder.utils.llms import get_llm_names
|
|
21
|
+
from autocoder.utils.llms import get_llm_names, get_model_info
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
def get_file_path(file_path):
|
|
@@ -32,8 +32,9 @@ class QuickFilterResult(BaseModel):
|
|
|
32
32
|
has_error: bool
|
|
33
33
|
error_message: Optional[str] = None
|
|
34
34
|
|
|
35
|
+
|
|
35
36
|
class QuickFilter():
|
|
36
|
-
def __init__(self, index_manager: IndexManager,stats:Dict[str,Any],sources:List[SourceCode]):
|
|
37
|
+
def __init__(self, index_manager: IndexManager, stats: Dict[str, Any], sources: List[SourceCode]):
|
|
37
38
|
self.index_manager = index_manager
|
|
38
39
|
self.args = index_manager.args
|
|
39
40
|
self.stats = stats
|
|
@@ -41,72 +42,142 @@ class QuickFilter():
|
|
|
41
42
|
self.printer = Printer()
|
|
42
43
|
self.max_tokens = self.args.index_filter_model_max_input_length
|
|
43
44
|
|
|
44
|
-
|
|
45
45
|
def big_filter(self, index_items: List[IndexItem],) -> QuickFilterResult:
|
|
46
46
|
chunks = []
|
|
47
47
|
current_chunk = []
|
|
48
|
-
|
|
48
|
+
|
|
49
49
|
# 将 index_items 切分成多个 chunks,第一个chunk尽可能接近max_tokens
|
|
50
50
|
for item in index_items:
|
|
51
51
|
# 使用 quick_filter_files.prompt 生成文本再统计
|
|
52
52
|
temp_chunk = current_chunk + [item]
|
|
53
|
-
prompt_text = self.quick_filter_files.prompt(
|
|
54
|
-
|
|
53
|
+
prompt_text = self.quick_filter_files.prompt(
|
|
54
|
+
temp_chunk, self.args.query)
|
|
55
|
+
temp_size = count_tokens(prompt_text)
|
|
55
56
|
# 如果当前chunk为空,或者添加item后不超过max_tokens,就添加到当前chunk
|
|
56
57
|
if not current_chunk or temp_size <= self.max_tokens:
|
|
57
|
-
current_chunk.append(item)
|
|
58
|
+
current_chunk.append(item)
|
|
58
59
|
else:
|
|
59
60
|
# 当前chunk已满,创建新chunk
|
|
60
61
|
chunks.append(current_chunk)
|
|
61
|
-
current_chunk = [item]
|
|
62
|
-
|
|
62
|
+
current_chunk = [item]
|
|
63
|
+
|
|
63
64
|
if current_chunk:
|
|
64
65
|
chunks.append(current_chunk)
|
|
65
|
-
|
|
66
|
-
tokens_len = count_tokens(
|
|
66
|
+
|
|
67
|
+
tokens_len = count_tokens(
|
|
68
|
+
self.quick_filter_files.prompt(index_items, self.args.query))
|
|
67
69
|
self.printer.print_in_terminal(
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
70
|
+
"quick_filter_too_long",
|
|
71
|
+
style="yellow",
|
|
72
|
+
tokens_len=tokens_len,
|
|
73
|
+
max_tokens=self.max_tokens,
|
|
74
|
+
split_size=len(chunks)
|
|
75
|
+
)
|
|
74
76
|
|
|
75
77
|
def process_chunk(chunk_index: int, chunk: List[IndexItem]) -> QuickFilterResult:
|
|
76
78
|
try:
|
|
77
|
-
|
|
79
|
+
# 获取模型名称列表
|
|
80
|
+
model_names = get_llm_names(
|
|
81
|
+
self.index_manager.index_filter_llm)
|
|
82
|
+
model_name = ",".join(model_names)
|
|
78
83
|
files: Dict[str, TargetFile] = {}
|
|
79
|
-
|
|
84
|
+
|
|
85
|
+
# 获取模型价格信息
|
|
86
|
+
model_info_map = {}
|
|
87
|
+
for name in model_names:
|
|
88
|
+
# 第二个参数是产品模式,从args中获取
|
|
89
|
+
info = get_model_info(name, self.args.product_mode)
|
|
90
|
+
if info:
|
|
91
|
+
model_info_map[name] = {
|
|
92
|
+
# 每百万tokens成本
|
|
93
|
+
"input_price": info.get("input_price", 0.0),
|
|
94
|
+
# 每百万tokens成本
|
|
95
|
+
"output_price": info.get("output_price", 0.0)
|
|
96
|
+
}
|
|
97
|
+
|
|
80
98
|
if chunk_index == 0:
|
|
81
99
|
# 第一个chunk使用流式输出
|
|
82
100
|
stream_generator = stream_chat_with_continue(
|
|
83
101
|
self.index_manager.index_filter_llm,
|
|
84
|
-
[{"role": "user", "content": self.quick_filter_files.prompt(
|
|
102
|
+
[{"role": "user", "content": self.quick_filter_files.prompt(
|
|
103
|
+
chunk, self.args.query)}],
|
|
85
104
|
{}
|
|
86
105
|
)
|
|
87
|
-
full_response,
|
|
106
|
+
full_response, last_meta = stream_out(
|
|
88
107
|
stream_generator,
|
|
89
108
|
model_name=model_name,
|
|
90
|
-
title=self.printer.get_message_from_key_with_format(
|
|
109
|
+
title=self.printer.get_message_from_key_with_format(
|
|
110
|
+
"quick_filter_title", model_name=model_name),
|
|
91
111
|
args=self.args
|
|
92
112
|
)
|
|
93
113
|
file_number_list = to_model(full_response, FileNumberList)
|
|
114
|
+
|
|
115
|
+
# 计算总成本
|
|
116
|
+
total_input_cost = 0.0
|
|
117
|
+
total_output_cost = 0.0
|
|
118
|
+
|
|
119
|
+
for name in model_names:
|
|
120
|
+
info = model_info_map.get(name, {})
|
|
121
|
+
# 计算公式:token数 * 单价 / 1000000
|
|
122
|
+
total_input_cost += (last_meta.input_tokens_count *
|
|
123
|
+
info.get("input_price", 0.0)) / 1000000
|
|
124
|
+
total_output_cost += (last_meta.generated_tokens_count *
|
|
125
|
+
info.get("output_price", 0.0)) / 1000000
|
|
126
|
+
|
|
127
|
+
# 四舍五入到4位小数
|
|
128
|
+
total_input_cost = round(total_input_cost, 4)
|
|
129
|
+
total_output_cost = round(total_output_cost, 4)
|
|
130
|
+
|
|
131
|
+
# 打印 token 统计信息和成本
|
|
132
|
+
self.printer.print_in_terminal(
|
|
133
|
+
"quick_filter_stats",
|
|
134
|
+
style="blue",
|
|
135
|
+
input_tokens=last_meta.input_tokens_count,
|
|
136
|
+
output_tokens=last_meta.generated_tokens_count,
|
|
137
|
+
input_cost=total_input_cost,
|
|
138
|
+
output_cost=total_output_cost,
|
|
139
|
+
model_names=model_name
|
|
140
|
+
)
|
|
94
141
|
else:
|
|
95
142
|
# 其他chunks直接使用with_llm
|
|
96
|
-
|
|
97
|
-
|
|
143
|
+
meta_holder = MetaHolder()
|
|
144
|
+
start_time = time.monotonic()
|
|
145
|
+
file_number_list = self.quick_filter_files.with_llm(self.index_manager.index_filter_llm).with_meta(
|
|
146
|
+
meta_holder).with_return_type(FileNumberList).run(chunk, self.args.query)
|
|
147
|
+
end_time = time.monotonic()
|
|
148
|
+
|
|
149
|
+
total_input_cost = 0.0
|
|
150
|
+
total_output_cost = 0.0
|
|
151
|
+
if meta_holder.get_meta():
|
|
152
|
+
meta_dict = meta_holder.get_meta()
|
|
153
|
+
total_input_cost = meta_dict.get("input_tokens_count", 0) * model_info_map.get(model_name, {}).get("input_price", 0.0) / 1000000
|
|
154
|
+
total_output_cost = meta_dict.get("generated_tokens_count", 0) * model_info_map.get(model_name, {}).get("output_price", 0.0) / 1000000
|
|
155
|
+
|
|
156
|
+
self.printer.print_in_terminal(
|
|
157
|
+
"quick_filter_stats",
|
|
158
|
+
style="blue",
|
|
159
|
+
input_tokens=meta_dict.get("input_tokens_count", 0),
|
|
160
|
+
output_tokens=meta_dict.get("generated_tokens_count", 0),
|
|
161
|
+
input_cost=total_input_cost,
|
|
162
|
+
output_cost=total_output_cost,
|
|
163
|
+
model_names=model_name,
|
|
164
|
+
elapsed_time=f"{end_time - start_time:.2f}"
|
|
165
|
+
)
|
|
166
|
+
|
|
98
167
|
if file_number_list:
|
|
99
168
|
for file_number in file_number_list.file_list:
|
|
100
|
-
file_path = get_file_path(
|
|
169
|
+
file_path = get_file_path(
|
|
170
|
+
chunk[file_number].module_name)
|
|
101
171
|
files[file_path] = TargetFile(
|
|
102
172
|
file_path=chunk[file_number].module_name,
|
|
103
|
-
reason=self.printer.get_message_from_key(
|
|
173
|
+
reason=self.printer.get_message_from_key(
|
|
174
|
+
"quick_filter_reason")
|
|
104
175
|
)
|
|
105
176
|
return QuickFilterResult(
|
|
106
177
|
files=files,
|
|
107
178
|
has_error=False
|
|
108
179
|
)
|
|
109
|
-
|
|
180
|
+
|
|
110
181
|
except Exception as e:
|
|
111
182
|
self.printer.print_in_terminal(
|
|
112
183
|
"quick_filter_failed",
|
|
@@ -123,25 +194,25 @@ class QuickFilter():
|
|
|
123
194
|
if chunks:
|
|
124
195
|
with ThreadPoolExecutor() as executor:
|
|
125
196
|
# 提交所有chunks到线程池并收集结果
|
|
126
|
-
futures = [executor.submit(process_chunk, i, chunk)
|
|
127
|
-
|
|
128
|
-
|
|
197
|
+
futures = [executor.submit(process_chunk, i, chunk)
|
|
198
|
+
for i, chunk in enumerate(chunks)]
|
|
199
|
+
|
|
129
200
|
# 等待所有任务完成并收集结果
|
|
130
201
|
for future in futures:
|
|
131
202
|
results.append(future.result())
|
|
132
|
-
|
|
203
|
+
|
|
133
204
|
# 合并所有结果
|
|
134
205
|
final_files: Dict[str, TargetFile] = {}
|
|
135
206
|
has_error = False
|
|
136
207
|
error_messages: List[str] = []
|
|
137
|
-
|
|
208
|
+
|
|
138
209
|
for result in results:
|
|
139
210
|
if result.has_error:
|
|
140
211
|
has_error = True
|
|
141
212
|
if result.error_message:
|
|
142
213
|
error_messages.append(result.error_message)
|
|
143
214
|
final_files.update(result.files)
|
|
144
|
-
|
|
215
|
+
|
|
145
216
|
return QuickFilterResult(
|
|
146
217
|
files=final_files,
|
|
147
218
|
has_error=has_error,
|
|
@@ -149,7 +220,7 @@ class QuickFilter():
|
|
|
149
220
|
)
|
|
150
221
|
|
|
151
222
|
@byzerllm.prompt()
|
|
152
|
-
def quick_filter_files(self,file_meta_list:List[IndexItem],query:str) -> str:
|
|
223
|
+
def quick_filter_files(self, file_meta_list: List[IndexItem], query: str) -> str:
|
|
153
224
|
'''
|
|
154
225
|
当用户提一个需求的时候,我们需要找到相关的文件,然后阅读这些文件,并且修改其中部分文件。
|
|
155
226
|
现在,给定下面的索引文件:
|
|
@@ -160,7 +231,7 @@ class QuickFilter():
|
|
|
160
231
|
|
|
161
232
|
索引文件包含文件序号(##[]括起来的部分),文件路径,文件符号信息等。
|
|
162
233
|
下面是用户的查询需求:
|
|
163
|
-
|
|
234
|
+
|
|
164
235
|
<query>
|
|
165
236
|
{{ query }}
|
|
166
237
|
</query>
|
|
@@ -182,63 +253,101 @@ class QuickFilter():
|
|
|
182
253
|
2. 如果 query 里是一段历史对话,那么对话里的内容提及的文件路径必须要返回。
|
|
183
254
|
3. json格式数据不允许有注释
|
|
184
255
|
'''
|
|
185
|
-
file_meta_str = "\n".join(
|
|
256
|
+
file_meta_str = "\n".join(
|
|
257
|
+
[f"##[{index}]{item.module_name}\n{item.symbols}" for index, item in enumerate(file_meta_list)])
|
|
186
258
|
context = {
|
|
187
259
|
"content": file_meta_str,
|
|
188
260
|
"query": query
|
|
189
261
|
}
|
|
190
|
-
return context
|
|
262
|
+
return context
|
|
191
263
|
|
|
192
264
|
def filter(self, index_items: List[IndexItem], query: str) -> QuickFilterResult:
|
|
193
265
|
final_files: Dict[str, TargetFile] = {}
|
|
194
|
-
start_time = time.monotonic()
|
|
266
|
+
start_time = time.monotonic()
|
|
267
|
+
|
|
268
|
+
prompt_str = self.quick_filter_files.prompt(index_items, query)
|
|
269
|
+
|
|
270
|
+
tokens_len = count_tokens(prompt_str)
|
|
195
271
|
|
|
196
|
-
prompt_str = self.quick_filter_files.prompt(index_items,query)
|
|
197
|
-
|
|
198
|
-
tokens_len = count_tokens(prompt_str)
|
|
199
|
-
|
|
200
272
|
# Print current index size
|
|
201
273
|
self.printer.print_in_terminal(
|
|
202
274
|
"quick_filter_tokens_len",
|
|
203
275
|
style="blue",
|
|
204
276
|
tokens_len=tokens_len
|
|
205
277
|
)
|
|
206
|
-
|
|
207
|
-
if tokens_len > self.max_tokens:
|
|
278
|
+
|
|
279
|
+
if tokens_len > self.max_tokens:
|
|
208
280
|
return self.big_filter(index_items)
|
|
209
|
-
|
|
281
|
+
|
|
210
282
|
try:
|
|
211
|
-
|
|
212
|
-
|
|
283
|
+
# 获取模型名称
|
|
284
|
+
model_names = get_llm_names(self.index_manager.index_filter_llm)
|
|
285
|
+
model_name = ",".join(model_names)
|
|
286
|
+
|
|
287
|
+
# 获取模型价格信息
|
|
288
|
+
model_info_map = {}
|
|
289
|
+
for name in model_names:
|
|
290
|
+
# 第二个参数是产品模式,从args中获取
|
|
291
|
+
info = get_model_info(name, self.args.product_mode)
|
|
292
|
+
if info:
|
|
293
|
+
model_info_map[name] = {
|
|
294
|
+
# 每百万tokens成本
|
|
295
|
+
"input_price": info.get("input_price", 0.0),
|
|
296
|
+
# 每百万tokens成本
|
|
297
|
+
"output_price": info.get("output_price", 0.0)
|
|
298
|
+
}
|
|
299
|
+
|
|
213
300
|
# 渲染 Prompt 模板
|
|
214
|
-
query = self.quick_filter_files.prompt(
|
|
215
|
-
|
|
301
|
+
query = self.quick_filter_files.prompt(
|
|
302
|
+
index_items, self.args.query)
|
|
303
|
+
|
|
216
304
|
# 使用流式输出处理
|
|
217
305
|
stream_generator = stream_chat_with_continue(
|
|
218
306
|
self.index_manager.index_filter_llm,
|
|
219
307
|
[{"role": "user", "content": query}],
|
|
220
308
|
{}
|
|
221
309
|
)
|
|
222
|
-
|
|
310
|
+
|
|
223
311
|
# 获取完整响应
|
|
224
312
|
full_response, last_meta = stream_out(
|
|
225
313
|
stream_generator,
|
|
226
314
|
model_name=model_name,
|
|
227
|
-
title=self.printer.get_message_from_key_with_format(
|
|
315
|
+
title=self.printer.get_message_from_key_with_format(
|
|
316
|
+
"quick_filter_title", model_name=model_name),
|
|
228
317
|
args=self.args
|
|
229
|
-
)
|
|
318
|
+
)
|
|
230
319
|
# 解析结果
|
|
231
320
|
file_number_list = to_model(full_response, FileNumberList)
|
|
232
|
-
end_time = time.monotonic()
|
|
233
|
-
|
|
321
|
+
end_time = time.monotonic()
|
|
322
|
+
|
|
323
|
+
# 计算总成本
|
|
324
|
+
total_input_cost = 0.0
|
|
325
|
+
total_output_cost = 0.0
|
|
326
|
+
|
|
327
|
+
for name in model_names:
|
|
328
|
+
info = model_info_map.get(name, {})
|
|
329
|
+
# 计算公式:token数 * 单价 / 1000000
|
|
330
|
+
total_input_cost += (last_meta.input_tokens_count *
|
|
331
|
+
info.get("input_price", 0.0)) / 1000000
|
|
332
|
+
total_output_cost += (last_meta.generated_tokens_count *
|
|
333
|
+
info.get("output_price", 0.0)) / 1000000
|
|
334
|
+
|
|
335
|
+
# 四舍五入到4位小数
|
|
336
|
+
total_input_cost = round(total_input_cost, 4)
|
|
337
|
+
total_output_cost = round(total_output_cost, 4)
|
|
338
|
+
|
|
339
|
+
# 打印 token 统计信息和成本
|
|
234
340
|
self.printer.print_in_terminal(
|
|
235
|
-
"quick_filter_stats",
|
|
341
|
+
"quick_filter_stats",
|
|
236
342
|
style="blue",
|
|
237
343
|
elapsed_time=f"{end_time - start_time:.2f}",
|
|
238
344
|
input_tokens=last_meta.input_tokens_count,
|
|
239
|
-
output_tokens=last_meta.generated_tokens_count
|
|
345
|
+
output_tokens=last_meta.generated_tokens_count,
|
|
346
|
+
input_cost=total_input_cost,
|
|
347
|
+
output_cost=total_output_cost,
|
|
348
|
+
model_names=model_name
|
|
240
349
|
)
|
|
241
|
-
|
|
350
|
+
|
|
242
351
|
except Exception as e:
|
|
243
352
|
self.printer.print_in_terminal(
|
|
244
353
|
"quick_filter_failed",
|
|
@@ -250,16 +359,17 @@ class QuickFilter():
|
|
|
250
359
|
has_error=True,
|
|
251
360
|
error_message=str(e)
|
|
252
361
|
)
|
|
253
|
-
|
|
362
|
+
|
|
254
363
|
if file_number_list:
|
|
255
364
|
for file_number in file_number_list.file_list:
|
|
256
365
|
final_files[get_file_path(index_items[file_number].module_name)] = TargetFile(
|
|
257
366
|
file_path=index_items[file_number].module_name,
|
|
258
|
-
reason=self.printer.get_message_from_key(
|
|
367
|
+
reason=self.printer.get_message_from_key(
|
|
368
|
+
"quick_filter_reason")
|
|
259
369
|
)
|
|
260
|
-
end_time = time.monotonic()
|
|
261
|
-
self.stats["timings"]["quick_filter"] = end_time - start_time
|
|
370
|
+
end_time = time.monotonic()
|
|
371
|
+
self.stats["timings"]["quick_filter"] = end_time - start_time
|
|
262
372
|
return QuickFilterResult(
|
|
263
373
|
files=final_files,
|
|
264
374
|
has_error=False
|
|
265
|
-
)
|
|
375
|
+
)
|