auto-coder 0.1.256__py3-none-any.whl → 0.1.257__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of auto-coder might be problematic. Click here for more details.
- {auto_coder-0.1.256.dist-info → auto_coder-0.1.257.dist-info}/METADATA +2 -2
- {auto_coder-0.1.256.dist-info → auto_coder-0.1.257.dist-info}/RECORD +24 -21
- autocoder/auto_coder.py +30 -50
- autocoder/chat_auto_coder.py +16 -17
- autocoder/chat_auto_coder_lang.py +1 -1
- autocoder/common/__init__.py +7 -0
- autocoder/common/auto_coder_lang.py +38 -8
- autocoder/common/code_auto_generate.py +22 -2
- autocoder/common/code_auto_generate_diff.py +23 -4
- autocoder/common/code_auto_generate_editblock.py +24 -2
- autocoder/common/code_auto_generate_strict_diff.py +23 -3
- autocoder/dispacher/actions/action.py +38 -28
- autocoder/dispacher/actions/plugins/action_regex_project.py +8 -6
- autocoder/index/entry.py +6 -4
- autocoder/index/index.py +94 -4
- autocoder/models.py +14 -0
- autocoder/privacy/__init__.py +3 -0
- autocoder/privacy/model_filter.py +100 -0
- autocoder/utils/model_provider_selector.py +192 -0
- autocoder/version.py +1 -1
- {auto_coder-0.1.256.dist-info → auto_coder-0.1.257.dist-info}/LICENSE +0 -0
- {auto_coder-0.1.256.dist-info → auto_coder-0.1.257.dist-info}/WHEEL +0 -0
- {auto_coder-0.1.256.dist-info → auto_coder-0.1.257.dist-info}/entry_points.txt +0 -0
- {auto_coder-0.1.256.dist-info → auto_coder-0.1.257.dist-info}/top_level.txt +0 -0
|
@@ -27,6 +27,8 @@ from loguru import logger
|
|
|
27
27
|
import time
|
|
28
28
|
from autocoder.common.printer import Printer
|
|
29
29
|
from autocoder.utils.llms import get_llm_names
|
|
30
|
+
from autocoder.privacy.model_filter import ModelPathFilter
|
|
31
|
+
from autocoder.common import SourceCodeList
|
|
30
32
|
|
|
31
33
|
|
|
32
34
|
class BaseAction:
|
|
@@ -55,14 +57,15 @@ class ActionTSProject(BaseAction):
|
|
|
55
57
|
self.pp = pp
|
|
56
58
|
pp.run()
|
|
57
59
|
|
|
58
|
-
source_code = pp.output()
|
|
60
|
+
# source_code = pp.output()
|
|
61
|
+
source_code_list = SourceCodeList(pp.sources)
|
|
59
62
|
if self.llm:
|
|
60
63
|
if args.in_code_apply:
|
|
61
64
|
old_query = args.query
|
|
62
65
|
args.query = (args.context or "") + "\n\n" + args.query
|
|
63
|
-
|
|
66
|
+
source_code_list = build_index_and_filter_files(
|
|
64
67
|
llm=self.llm, args=args, sources=pp.sources
|
|
65
|
-
)
|
|
68
|
+
)
|
|
66
69
|
if args.in_code_apply:
|
|
67
70
|
args.query = old_query
|
|
68
71
|
|
|
@@ -81,17 +84,21 @@ class ActionTSProject(BaseAction):
|
|
|
81
84
|
html_path=html_path,
|
|
82
85
|
max_iter=self.args.image_max_iter,
|
|
83
86
|
)
|
|
84
|
-
|
|
87
|
+
html_code = ""
|
|
85
88
|
with open(html_path, "r") as f:
|
|
86
89
|
html_code = f.read()
|
|
87
|
-
|
|
90
|
+
|
|
91
|
+
source_code_list.sources.append(SourceCode(
|
|
92
|
+
module_name=html_path,
|
|
93
|
+
source_code=html_code,
|
|
94
|
+
tag="IMAGE"))
|
|
88
95
|
|
|
89
|
-
self.process_content(
|
|
96
|
+
self.process_content(source_code_list)
|
|
90
97
|
return True
|
|
91
98
|
|
|
92
|
-
def process_content(self,
|
|
99
|
+
def process_content(self, source_code_list: SourceCodeList):
|
|
93
100
|
args = self.args
|
|
94
|
-
|
|
101
|
+
content = source_code_list.to_str()
|
|
95
102
|
if args.execute and self.llm and not args.human_as_model:
|
|
96
103
|
content_length = self._get_content_length(content)
|
|
97
104
|
if content_length > self.args.model_max_input_length:
|
|
@@ -116,13 +123,14 @@ class ActionTSProject(BaseAction):
|
|
|
116
123
|
)
|
|
117
124
|
else:
|
|
118
125
|
generate = CodeAutoGenerate(llm=self.llm, args=self.args, action=self)
|
|
126
|
+
|
|
119
127
|
if self.args.enable_multi_round_generate:
|
|
120
128
|
generate_result = generate.multi_round_run(
|
|
121
|
-
query=args.query,
|
|
129
|
+
query=args.query, source_code_list=source_code_list
|
|
122
130
|
)
|
|
123
131
|
else:
|
|
124
132
|
generate_result = generate.single_round_run(
|
|
125
|
-
query=args.query,
|
|
133
|
+
query=args.query, source_code_list=source_code_list
|
|
126
134
|
)
|
|
127
135
|
elapsed_time = time.time() - start_time
|
|
128
136
|
speed = generate_result.metadata.get('generated_tokens_count', 0) / elapsed_time if elapsed_time > 0 else 0
|
|
@@ -191,11 +199,12 @@ class ActionPyScriptProject(BaseAction):
|
|
|
191
199
|
pp = Level1PyProject(
|
|
192
200
|
script_path=args.script_path, package_name=args.package_name
|
|
193
201
|
)
|
|
194
|
-
|
|
195
|
-
|
|
202
|
+
pp.run()
|
|
203
|
+
source_code_list = SourceCodeList(pp.sources)
|
|
204
|
+
self.process_content(source_code_list)
|
|
196
205
|
return True
|
|
197
206
|
|
|
198
|
-
def process_content(self,
|
|
207
|
+
def process_content(self, source_code_list: SourceCodeList):
|
|
199
208
|
args = self.args
|
|
200
209
|
if args.execute:
|
|
201
210
|
self.printer.print_in_terminal("code_generation_start")
|
|
@@ -216,11 +225,11 @@ class ActionPyScriptProject(BaseAction):
|
|
|
216
225
|
generate = CodeAutoGenerate(llm=self.llm, args=self.args, action=self)
|
|
217
226
|
if self.args.enable_multi_round_generate:
|
|
218
227
|
generate_result = generate.multi_round_run(
|
|
219
|
-
query=args.query,
|
|
228
|
+
query=args.query, source_code_list=source_code_list
|
|
220
229
|
)
|
|
221
230
|
else:
|
|
222
231
|
generate_result = generate.single_round_run(
|
|
223
|
-
query=args.query,
|
|
232
|
+
query=args.query, source_code_list=source_code_list
|
|
224
233
|
)
|
|
225
234
|
|
|
226
235
|
elapsed_time = time.time() - start_time
|
|
@@ -293,24 +302,24 @@ class ActionPyProject(BaseAction):
|
|
|
293
302
|
pp = PyProject(args=self.args, llm=self.llm)
|
|
294
303
|
self.pp = pp
|
|
295
304
|
pp.run(packages=args.py_packages.split(",") if args.py_packages else [])
|
|
296
|
-
|
|
305
|
+
source_code_list = SourceCodeList(pp.sources)
|
|
297
306
|
|
|
298
307
|
if self.llm:
|
|
299
308
|
old_query = args.query
|
|
300
309
|
if args.in_code_apply:
|
|
301
310
|
args.query = (args.context or "") + "\n\n" + args.query
|
|
302
|
-
|
|
311
|
+
source_code_list = build_index_and_filter_files(
|
|
303
312
|
llm=self.llm, args=args, sources=pp.sources
|
|
304
313
|
)
|
|
305
314
|
if args.in_code_apply:
|
|
306
315
|
args.query = old_query
|
|
307
316
|
|
|
308
|
-
self.process_content(
|
|
317
|
+
self.process_content(source_code_list)
|
|
309
318
|
return True
|
|
310
319
|
|
|
311
|
-
def process_content(self,
|
|
320
|
+
def process_content(self, source_code_list: SourceCodeList):
|
|
312
321
|
args = self.args
|
|
313
|
-
|
|
322
|
+
content = source_code_list.to_str()
|
|
314
323
|
if args.execute and self.llm and not args.human_as_model:
|
|
315
324
|
content_length = self._get_content_length(content)
|
|
316
325
|
if content_length > self.args.model_max_input_length:
|
|
@@ -342,11 +351,11 @@ class ActionPyProject(BaseAction):
|
|
|
342
351
|
|
|
343
352
|
if self.args.enable_multi_round_generate:
|
|
344
353
|
generate_result = generate.multi_round_run(
|
|
345
|
-
query=args.query,
|
|
354
|
+
query=args.query, source_code_list=source_code_list
|
|
346
355
|
)
|
|
347
356
|
else:
|
|
348
357
|
generate_result = generate.single_round_run(
|
|
349
|
-
query=args.query,
|
|
358
|
+
query=args.query, source_code_list=source_code_list
|
|
350
359
|
)
|
|
351
360
|
elapsed_time = time.time() - start_time
|
|
352
361
|
speed = generate_result.metadata.get('generated_tokens_count', 0) / elapsed_time if elapsed_time > 0 else 0
|
|
@@ -414,20 +423,21 @@ class ActionSuffixProject(BaseAction):
|
|
|
414
423
|
pp = SuffixProject(args=args, llm=self.llm)
|
|
415
424
|
self.pp = pp
|
|
416
425
|
pp.run()
|
|
417
|
-
|
|
426
|
+
source_code_list = SourceCodeList(pp.sources)
|
|
418
427
|
if self.llm:
|
|
419
428
|
if args.in_code_apply:
|
|
420
429
|
old_query = args.query
|
|
421
430
|
args.query = (args.context or "") + "\n\n" + args.query
|
|
422
|
-
|
|
431
|
+
source_code_list = build_index_and_filter_files(
|
|
423
432
|
llm=self.llm, args=args, sources=pp.sources
|
|
424
433
|
)
|
|
425
434
|
if args.in_code_apply:
|
|
426
435
|
args.query = old_query
|
|
427
|
-
self.process_content(
|
|
436
|
+
self.process_content(source_code_list)
|
|
428
437
|
|
|
429
|
-
def process_content(self,
|
|
438
|
+
def process_content(self, source_code_list: SourceCodeList):
|
|
430
439
|
args = self.args
|
|
440
|
+
content = source_code_list.to_str()
|
|
431
441
|
|
|
432
442
|
if args.execute and self.llm and not args.human_as_model:
|
|
433
443
|
content_length = self._get_content_length(content)
|
|
@@ -455,11 +465,11 @@ class ActionSuffixProject(BaseAction):
|
|
|
455
465
|
generate = CodeAutoGenerate(llm=self.llm, args=self.args, action=self)
|
|
456
466
|
if self.args.enable_multi_round_generate:
|
|
457
467
|
generate_result = generate.multi_round_run(
|
|
458
|
-
query=args.query,
|
|
468
|
+
query=args.query, source_code_list=source_code_list
|
|
459
469
|
)
|
|
460
470
|
else:
|
|
461
471
|
generate_result = generate.single_round_run(
|
|
462
|
-
query=args.query,
|
|
472
|
+
query=args.query, source_code_list=source_code_list
|
|
463
473
|
)
|
|
464
474
|
|
|
465
475
|
elapsed_time = time.time() - start_time
|
|
@@ -15,6 +15,7 @@ from autocoder.utils.conversation_store import store_code_model_conversation
|
|
|
15
15
|
from autocoder.common.printer import Printer
|
|
16
16
|
import time
|
|
17
17
|
from autocoder.utils.llms import get_llm_names
|
|
18
|
+
from autocoder.common import SourceCodeList
|
|
18
19
|
from loguru import logger
|
|
19
20
|
class ActionRegexProject:
|
|
20
21
|
def __init__(
|
|
@@ -36,20 +37,21 @@ class ActionRegexProject:
|
|
|
36
37
|
pp = RegexProject(args=args, llm=self.llm)
|
|
37
38
|
self.pp = pp
|
|
38
39
|
pp.run()
|
|
39
|
-
|
|
40
|
+
source_code_list = SourceCodeList(pp.sources)
|
|
40
41
|
if self.llm:
|
|
41
42
|
if args.in_code_apply:
|
|
42
43
|
old_query = args.query
|
|
43
44
|
args.query = (args.context or "") + "\n\n" + args.query
|
|
44
|
-
|
|
45
|
+
source_code_list = build_index_and_filter_files(
|
|
45
46
|
llm=self.llm, args=args, sources=pp.sources
|
|
46
47
|
)
|
|
47
48
|
if args.in_code_apply:
|
|
48
49
|
args.query = old_query
|
|
49
|
-
self.process_content(
|
|
50
|
+
self.process_content(source_code_list)
|
|
50
51
|
|
|
51
|
-
def process_content(self,
|
|
52
|
+
def process_content(self, source_code_list: SourceCodeList):
|
|
52
53
|
args = self.args
|
|
54
|
+
content = source_code_list.to_str()
|
|
53
55
|
|
|
54
56
|
if args.execute and self.llm and not args.human_as_model:
|
|
55
57
|
if len(content) > self.args.model_max_input_length:
|
|
@@ -78,11 +80,11 @@ class ActionRegexProject:
|
|
|
78
80
|
generate = CodeAutoGenerate(llm=self.llm, args=self.args, action=self)
|
|
79
81
|
if self.args.enable_multi_round_generate:
|
|
80
82
|
generate_result = generate.multi_round_run(
|
|
81
|
-
query=args.query,
|
|
83
|
+
query=args.query, source_code_list=source_code_list
|
|
82
84
|
)
|
|
83
85
|
else:
|
|
84
86
|
generate_result = generate.single_round_run(
|
|
85
|
-
query=args.query,
|
|
87
|
+
query=args.query, source_code_list=source_code_list
|
|
86
88
|
)
|
|
87
89
|
|
|
88
90
|
elapsed_time = time.time() - start_time
|
autocoder/index/entry.py
CHANGED
|
@@ -23,10 +23,11 @@ from autocoder.index.filter.quick_filter import QuickFilter
|
|
|
23
23
|
from autocoder.index.filter.normal_filter import NormalFilter
|
|
24
24
|
from autocoder.index.index import IndexManager
|
|
25
25
|
from loguru import logger
|
|
26
|
+
from autocoder.common import SourceCodeList
|
|
26
27
|
|
|
27
28
|
def build_index_and_filter_files(
|
|
28
29
|
llm, args: AutoCoderArgs, sources: List[SourceCode]
|
|
29
|
-
) ->
|
|
30
|
+
) -> SourceCodeList:
|
|
30
31
|
# Initialize timing and statistics
|
|
31
32
|
total_start_time = time.monotonic()
|
|
32
33
|
stats = {
|
|
@@ -253,7 +254,8 @@ def build_index_and_filter_files(
|
|
|
253
254
|
for file in final_filenames:
|
|
254
255
|
print(f"{file} - {final_files[file].reason}")
|
|
255
256
|
|
|
256
|
-
source_code = ""
|
|
257
|
+
source_code = ""
|
|
258
|
+
source_code_list = SourceCodeList(sources=[])
|
|
257
259
|
depulicated_sources = set()
|
|
258
260
|
|
|
259
261
|
for file in sources:
|
|
@@ -263,7 +265,7 @@ def build_index_and_filter_files(
|
|
|
263
265
|
depulicated_sources.add(file.module_name)
|
|
264
266
|
source_code += f"##File: {file.module_name}\n"
|
|
265
267
|
source_code += f"{file.source_code}\n\n"
|
|
266
|
-
|
|
268
|
+
source_code_list.sources.append(file)
|
|
267
269
|
if args.request_id and not args.skip_events:
|
|
268
270
|
queue_communicate.send_event(
|
|
269
271
|
request_id=args.request_id,
|
|
@@ -339,4 +341,4 @@ def build_index_and_filter_files(
|
|
|
339
341
|
)
|
|
340
342
|
)
|
|
341
343
|
|
|
342
|
-
return
|
|
344
|
+
return source_code_list
|
autocoder/index/index.py
CHANGED
|
@@ -9,6 +9,7 @@ from autocoder.index.symbols_utils import (
|
|
|
9
9
|
SymbolType,
|
|
10
10
|
symbols_info_to_str,
|
|
11
11
|
)
|
|
12
|
+
from autocoder.privacy.model_filter import ModelPathFilter
|
|
12
13
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
13
14
|
import threading
|
|
14
15
|
|
|
@@ -17,6 +18,7 @@ import hashlib
|
|
|
17
18
|
|
|
18
19
|
from autocoder.common.printer import Printer
|
|
19
20
|
from autocoder.common.auto_coder_lang import get_message
|
|
21
|
+
from autocoder.utils.llms import get_llm_names, get_model_info
|
|
20
22
|
from autocoder.index.types import (
|
|
21
23
|
IndexItem,
|
|
22
24
|
TargetFile,
|
|
@@ -30,6 +32,9 @@ class IndexManager:
|
|
|
30
32
|
):
|
|
31
33
|
self.sources = sources
|
|
32
34
|
self.source_dir = args.source_dir
|
|
35
|
+
# Initialize model filter for index_llm and index_filter_llm
|
|
36
|
+
self.index_model_filter = None
|
|
37
|
+
self.index_filter_model_filter = None
|
|
33
38
|
self.anti_quota_limit = (
|
|
34
39
|
args.index_model_anti_quota_limit or args.anti_quota_limit
|
|
35
40
|
)
|
|
@@ -46,6 +51,12 @@ class IndexManager:
|
|
|
46
51
|
self.index_filter_llm = llm
|
|
47
52
|
|
|
48
53
|
self.llm = llm
|
|
54
|
+
|
|
55
|
+
# Initialize model filters
|
|
56
|
+
if self.index_llm:
|
|
57
|
+
self.index_model_filter = ModelPathFilter.from_model_object(self.index_llm, args)
|
|
58
|
+
if self.index_filter_llm:
|
|
59
|
+
self.index_filter_model_filter = ModelPathFilter.from_model_object(self.index_filter_llm, args)
|
|
49
60
|
self.args = args
|
|
50
61
|
self.max_input_length = (
|
|
51
62
|
args.index_model_max_input_length or args.model_max_input_length
|
|
@@ -194,6 +205,17 @@ class IndexManager:
|
|
|
194
205
|
ext = os.path.splitext(file_path)[1].lower()
|
|
195
206
|
if ext in [".md", ".html", ".txt", ".doc", ".pdf"]:
|
|
196
207
|
return True
|
|
208
|
+
|
|
209
|
+
# Check model filter restrictions
|
|
210
|
+
if self.index_model_filter and not self.index_model_filter.is_accessible(file_path):
|
|
211
|
+
self.printer.print_in_terminal(
|
|
212
|
+
"index_file_filtered",
|
|
213
|
+
style="yellow",
|
|
214
|
+
file_path=file_path,
|
|
215
|
+
model_name=",".join(get_llm_names(self.index_llm))
|
|
216
|
+
)
|
|
217
|
+
return True
|
|
218
|
+
|
|
197
219
|
return False
|
|
198
220
|
|
|
199
221
|
def build_index_for_single_source(self, source: SourceCode):
|
|
@@ -212,8 +234,29 @@ class IndexManager:
|
|
|
212
234
|
model_name = ",".join(get_llm_names(self.index_llm))
|
|
213
235
|
|
|
214
236
|
try:
|
|
237
|
+
# 获取模型名称列表
|
|
238
|
+
model_names = get_llm_names(self.index_llm)
|
|
239
|
+
model_name = ",".join(model_names)
|
|
240
|
+
|
|
241
|
+
# 获取模型价格信息
|
|
242
|
+
model_info_map = {}
|
|
243
|
+
for name in model_names:
|
|
244
|
+
info = get_model_info(name, self.args.product_mode)
|
|
245
|
+
if info:
|
|
246
|
+
model_info_map[name] = {
|
|
247
|
+
"input_price": info.get("input_price", 0.0),
|
|
248
|
+
"output_price": info.get("output_price", 0.0)
|
|
249
|
+
}
|
|
250
|
+
|
|
215
251
|
start_time = time.monotonic()
|
|
216
252
|
source_code = source.source_code
|
|
253
|
+
|
|
254
|
+
# 统计token和成本
|
|
255
|
+
total_input_tokens = 0
|
|
256
|
+
total_output_tokens = 0
|
|
257
|
+
total_input_cost = 0.0
|
|
258
|
+
total_output_cost = 0.0
|
|
259
|
+
|
|
217
260
|
if len(source.source_code) > self.max_input_length:
|
|
218
261
|
self.printer.print_in_terminal(
|
|
219
262
|
"index_file_too_large",
|
|
@@ -227,15 +270,38 @@ class IndexManager:
|
|
|
227
270
|
)
|
|
228
271
|
symbols = []
|
|
229
272
|
for chunk in chunks:
|
|
273
|
+
meta_holder = byzerllm.MetaHolder()
|
|
230
274
|
chunk_symbols = self.get_all_file_symbols.with_llm(
|
|
231
|
-
self.index_llm).run(source.module_name, chunk)
|
|
275
|
+
self.index_llm).with_meta(meta_holder).run(source.module_name, chunk)
|
|
232
276
|
time.sleep(self.anti_quota_limit)
|
|
233
277
|
symbols.append(chunk_symbols)
|
|
278
|
+
|
|
279
|
+
if meta_holder.get_meta():
|
|
280
|
+
meta_dict = meta_holder.get_meta()
|
|
281
|
+
total_input_tokens += meta_dict.get("input_tokens_count", 0)
|
|
282
|
+
total_output_tokens += meta_dict.get("generated_tokens_count", 0)
|
|
283
|
+
|
|
234
284
|
symbols = "\n".join(symbols)
|
|
235
285
|
else:
|
|
286
|
+
meta_holder = byzerllm.MetaHolder()
|
|
236
287
|
symbols = self.get_all_file_symbols.with_llm(
|
|
237
|
-
self.index_llm).run(source.module_name, source_code)
|
|
288
|
+
self.index_llm).with_meta(meta_holder).run(source.module_name, source_code)
|
|
238
289
|
time.sleep(self.anti_quota_limit)
|
|
290
|
+
|
|
291
|
+
if meta_holder.get_meta():
|
|
292
|
+
meta_dict = meta_holder.get_meta()
|
|
293
|
+
total_input_tokens += meta_dict.get("input_tokens_count", 0)
|
|
294
|
+
total_output_tokens += meta_dict.get("generated_tokens_count", 0)
|
|
295
|
+
|
|
296
|
+
# 计算总成本
|
|
297
|
+
for name in model_names:
|
|
298
|
+
info = model_info_map.get(name, {})
|
|
299
|
+
total_input_cost += (total_input_tokens * info.get("input_price", 0.0)) / 1000000
|
|
300
|
+
total_output_cost += (total_output_tokens * info.get("output_price", 0.0)) / 1000000
|
|
301
|
+
|
|
302
|
+
# 四舍五入到4位小数
|
|
303
|
+
total_input_cost = round(total_input_cost, 4)
|
|
304
|
+
total_output_cost = round(total_output_cost, 4)
|
|
239
305
|
|
|
240
306
|
self.printer.print_in_terminal(
|
|
241
307
|
"index_update_success",
|
|
@@ -243,7 +309,11 @@ class IndexManager:
|
|
|
243
309
|
file_path=file_path,
|
|
244
310
|
md5=md5,
|
|
245
311
|
duration=time.monotonic() - start_time,
|
|
246
|
-
model_name=model_name
|
|
312
|
+
model_name=model_name,
|
|
313
|
+
input_tokens=total_input_tokens,
|
|
314
|
+
output_tokens=total_output_tokens,
|
|
315
|
+
input_cost=total_input_cost,
|
|
316
|
+
output_cost=total_output_cost
|
|
247
317
|
)
|
|
248
318
|
|
|
249
319
|
except Exception as e:
|
|
@@ -263,6 +333,10 @@ class IndexManager:
|
|
|
263
333
|
"symbols": symbols,
|
|
264
334
|
"last_modified": os.path.getmtime(file_path),
|
|
265
335
|
"md5": md5,
|
|
336
|
+
"input_tokens_count": total_input_tokens,
|
|
337
|
+
"generated_tokens_count": total_output_tokens,
|
|
338
|
+
"input_tokens_cost": total_input_cost,
|
|
339
|
+
"generated_tokens_cost": total_output_cost
|
|
266
340
|
}
|
|
267
341
|
|
|
268
342
|
def build_index(self):
|
|
@@ -290,6 +364,11 @@ class IndexManager:
|
|
|
290
364
|
|
|
291
365
|
updated_sources = []
|
|
292
366
|
|
|
367
|
+
total_input_tokens = 0
|
|
368
|
+
total_output_tokens = 0
|
|
369
|
+
total_input_cost = 0.0
|
|
370
|
+
total_output_cost = 0.0
|
|
371
|
+
|
|
293
372
|
with ThreadPoolExecutor(max_workers=self.args.index_build_workers) as executor:
|
|
294
373
|
|
|
295
374
|
wait_to_build_files = []
|
|
@@ -346,6 +425,10 @@ class IndexManager:
|
|
|
346
425
|
num_files=num_files
|
|
347
426
|
)
|
|
348
427
|
module_name = result["module_name"]
|
|
428
|
+
total_input_tokens += result["input_tokens_count"]
|
|
429
|
+
total_output_tokens += result["generated_tokens_count"]
|
|
430
|
+
total_input_cost += result["input_tokens_cost"]
|
|
431
|
+
total_output_cost += result["generated_tokens_cost"]
|
|
349
432
|
index_data[module_name] = result
|
|
350
433
|
updated_sources.append(module_name)
|
|
351
434
|
if len(updated_sources) > 5:
|
|
@@ -357,12 +440,19 @@ class IndexManager:
|
|
|
357
440
|
if updated_sources or keys_to_remove:
|
|
358
441
|
with open(self.index_file, "w") as file:
|
|
359
442
|
json.dump(index_data, file, ensure_ascii=False, indent=2)
|
|
443
|
+
|
|
444
|
+
print("")
|
|
360
445
|
self.printer.print_in_terminal(
|
|
361
446
|
"index_file_saved",
|
|
362
447
|
style="green",
|
|
363
448
|
updated_files=len(updated_sources),
|
|
364
|
-
removed_files=len(keys_to_remove)
|
|
449
|
+
removed_files=len(keys_to_remove),
|
|
450
|
+
input_tokens=total_input_tokens,
|
|
451
|
+
output_tokens=total_output_tokens,
|
|
452
|
+
input_cost=total_input_cost,
|
|
453
|
+
output_cost=total_output_cost
|
|
365
454
|
)
|
|
455
|
+
print("")
|
|
366
456
|
|
|
367
457
|
return index_data
|
|
368
458
|
|
autocoder/models.py
CHANGED
|
@@ -110,6 +110,20 @@ def save_models(models: List[Dict]) -> None:
|
|
|
110
110
|
json.dump(models, f, indent=2, ensure_ascii=False)
|
|
111
111
|
|
|
112
112
|
|
|
113
|
+
def add_and_activate_models(models: List[Dict]) -> None:
    """Register new models and apply the API keys supplied with them.

    Models whose ``name`` is not yet present in the list returned by
    ``load_models`` are appended (in input order; a later entry with a name
    already seen is ignored), and the resulting list is written back with
    ``save_models``.  Afterwards, every entry of ``models`` that carries an
    ``api_key`` is passed to ``update_model_with_api_key`` -- including
    entries whose name already existed in the stored list.

    :param models: Model definitions; each must contain a ``name`` key.
    """
    existing_models = load_models()
    # Keep the known names in a set so each membership test is O(1) instead
    # of rebuilding the whole name list for every incoming model.  Names are
    # added as models are appended so duplicates inside ``models`` are skipped.
    known_names = {m.get("name") for m in existing_models}
    for model in models:
        if model["name"] not in known_names:
            existing_models.append(model)
            known_names.add(model["name"])
    # NOTE(review): new entries are persisted exactly as given, including any
    # "api_key" field -- confirm that storing the key in the model list is intended.
    save_models(existing_models)

    for model in models:
        if "api_key" in model:
            update_model_with_api_key(model["name"], model["api_key"])
|
|
126
|
+
|
|
113
127
|
def get_model_by_name(name: str) -> Dict:
|
|
114
128
|
"""
|
|
115
129
|
根据模型名称查找模型
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
import re
|
|
2
|
+
import yaml
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import Dict, List, Optional
|
|
5
|
+
from autocoder.common import AutoCoderArgs
|
|
6
|
+
from autocoder.utils import llms as llm_utils
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ModelPathFilter:
    """Decide whether a file path may be exposed to a given model.

    Rules are regular expressions.  They come from the
    ``model_filters.<model_name>.forbidden_paths`` entry of a YAML config file
    (``args.model_filter_path`` when set, otherwise
    ``<args.source_dir>/.model_filters.yml``) plus the ``default_forbidden``
    rules given to the constructor.  A path is forbidden when any rule
    matches it via ``re.search``.
    """

    def __init__(self,
                 model_name: str,
                 args: AutoCoderArgs,
                 default_forbidden: Optional[List[str]] = None):
        """
        Model path filter.

        :param model_name: Name of the model in use; the key looked up under
            ``model_filters`` in the config file.
        :param args: AutoCoder arguments; ``model_filter_path`` and
            ``source_dir`` are read to locate the config file.
        :param default_forbidden: Regex rules that always apply, in addition
            to any rules from the config file.
        """
        self.model_name = model_name
        if args.model_filter_path:
            self.config_path = Path(args.model_filter_path)
        else:
            self.config_path = Path(args.source_dir, ".model_filters.yml")
        self.default_forbidden = default_forbidden or []
        self._rules_cache: Dict[str, List[re.Pattern]] = {}
        self._load_rules()

    def _load_rules(self):
        """Load config rules, merge them with the defaults, and cache the compiled patterns.

        The cache entry is always created -- even when the config file is
        missing -- so default rules stay in effect after ``add_temp_rule``.
        Empty or null YAML sections are treated as "no rules".
        """
        config_rules: List[str] = []
        if self.config_path.exists():
            with open(self.config_path, "r", encoding="utf-8") as f:
                # safe_load returns None for an empty document.
                config = yaml.safe_load(f) or {}
            model_filters = config.get("model_filters") or {}
            model_rules = model_filters.get(self.model_name) or {}
            config_rules = model_rules.get("forbidden_paths") or []
            # Accept a single rule written as a plain string.
            if isinstance(config_rules, str):
                config_rules = [config_rules]

        all_rules = list(config_rules) + list(self.default_forbidden)
        # Pre-compile the regular expressions once.
        self._rules_cache[self.model_name] = [re.compile(rule) for rule in all_rules]

    def is_accessible(self, file_path: str) -> bool:
        """
        Check whether a file path passes the access rules.

        :param file_path: Path to test; an empty or None path is always allowed.
        :return: True if access is allowed, False if any forbidden rule matches.
        """
        if not file_path:
            return True

        # Prefer the model-specific (cached) rules; fall back to the defaults
        # only if the cache holds nothing for this model.
        patterns = self._rules_cache.get(self.model_name, [])
        if not patterns and self.default_forbidden:
            patterns = [re.compile(rule) for rule in self.default_forbidden]

        return not any(pattern.search(file_path) for pattern in patterns)

    def add_temp_rule(self, rule: str):
        """
        Add an in-memory rule on top of the existing ones.

        The rule is discarded by ``reload_rules``.

        :param rule: Regular-expression rule.
        """
        self._rules_cache.setdefault(self.model_name, []).append(re.compile(rule))

    def reload_rules(self):
        """Reload the rules from the config file and defaults, dropping temporary rules."""
        self._rules_cache.clear()
        self._load_rules()

    def has_rules(self):
        """Return True if at least one rule (configured, default or temporary) is active."""
        return bool(self._rules_cache.get(self.model_name, []))

    @classmethod
    def from_model_object(cls,
                          llm_obj,
                          args: AutoCoderArgs,
                          default_forbidden: Optional[List[str]] = None):
        """
        Create a filter from an LLM object.

        :param llm_obj: ByzerLLM instance or similar object.
        :param args: AutoCoder arguments.
        :param default_forbidden: Default forbidden path rules.
        :raises ValueError: If no model name can be derived from ``llm_obj``.
        """
        model_name = ",".join(llm_utils.get_llm_names(llm_obj) or [])
        if not model_name:
            raise ValueError(
                f"Unable to determine model name from LLM object of type {type(llm_obj).__name__}"
            )

        return cls(
            model_name=model_name,
            args=args,
            default_forbidden=default_forbidden
        )
|