auto-coder 0.1.255__py3-none-any.whl → 0.1.257__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/METADATA +2 -2
- {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/RECORD +30 -27
- autocoder/auto_coder.py +44 -50
- autocoder/chat_auto_coder.py +16 -17
- autocoder/chat_auto_coder_lang.py +1 -1
- autocoder/common/__init__.py +7 -0
- autocoder/common/auto_coder_lang.py +46 -16
- autocoder/common/code_auto_generate.py +45 -5
- autocoder/common/code_auto_generate_diff.py +45 -7
- autocoder/common/code_auto_generate_editblock.py +48 -4
- autocoder/common/code_auto_generate_strict_diff.py +46 -7
- autocoder/common/code_modification_ranker.py +39 -3
- autocoder/dispacher/actions/action.py +60 -40
- autocoder/dispacher/actions/plugins/action_regex_project.py +12 -6
- autocoder/index/entry.py +6 -4
- autocoder/index/filter/quick_filter.py +175 -65
- autocoder/index/index.py +94 -4
- autocoder/models.py +44 -6
- autocoder/privacy/__init__.py +3 -0
- autocoder/privacy/model_filter.py +100 -0
- autocoder/pyproject/__init__.py +1 -0
- autocoder/suffixproject/__init__.py +1 -0
- autocoder/tsproject/__init__.py +1 -0
- autocoder/utils/llms.py +27 -0
- autocoder/utils/model_provider_selector.py +192 -0
- autocoder/version.py +1 -1
- {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/LICENSE +0 -0
- {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/WHEEL +0 -0
- {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/entry_points.txt +0 -0
- {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/top_level.txt +0 -0
autocoder/index/index.py
CHANGED

@@ -9,6 +9,7 @@ from autocoder.index.symbols_utils import (
     SymbolType,
     symbols_info_to_str,
 )
+from autocoder.privacy.model_filter import ModelPathFilter
 from concurrent.futures import ThreadPoolExecutor, as_completed
 import threading

@@ -17,6 +18,7 @@ import hashlib

 from autocoder.common.printer import Printer
 from autocoder.common.auto_coder_lang import get_message
+from autocoder.utils.llms import get_llm_names, get_model_info
 from autocoder.index.types import (
     IndexItem,
     TargetFile,
@@ -30,6 +32,9 @@ class IndexManager:
     ):
         self.sources = sources
         self.source_dir = args.source_dir
+        # Initialize model filter for index_llm and index_filter_llm
+        self.index_model_filter = None
+        self.index_filter_model_filter = None
         self.anti_quota_limit = (
             args.index_model_anti_quota_limit or args.anti_quota_limit
         )
@@ -46,6 +51,12 @@ class IndexManager:
             self.index_filter_llm = llm

         self.llm = llm
+
+        # Initialize model filters
+        if self.index_llm:
+            self.index_model_filter = ModelPathFilter.from_model_object(self.index_llm, args)
+        if self.index_filter_llm:
+            self.index_filter_model_filter = ModelPathFilter.from_model_object(self.index_filter_llm, args)
         self.args = args
         self.max_input_length = (
             args.index_model_max_input_length or args.model_max_input_length
@@ -194,6 +205,17 @@ class IndexManager:
         ext = os.path.splitext(file_path)[1].lower()
         if ext in [".md", ".html", ".txt", ".doc", ".pdf"]:
             return True
+
+        # Check model filter restrictions
+        if self.index_model_filter and not self.index_model_filter.is_accessible(file_path):
+            self.printer.print_in_terminal(
+                "index_file_filtered",
+                style="yellow",
+                file_path=file_path,
+                model_name=",".join(get_llm_names(self.index_llm))
+            )
+            return True
+
         return False

     def build_index_for_single_source(self, source: SourceCode):
@@ -212,8 +234,29 @@ class IndexManager:
         model_name = ",".join(get_llm_names(self.index_llm))

         try:
+            # Get the list of model names
+            model_names = get_llm_names(self.index_llm)
+            model_name = ",".join(model_names)
+
+            # Get model pricing info
+            model_info_map = {}
+            for name in model_names:
+                info = get_model_info(name, self.args.product_mode)
+                if info:
+                    model_info_map[name] = {
+                        "input_price": info.get("input_price", 0.0),
+                        "output_price": info.get("output_price", 0.0)
+                    }
+
             start_time = time.monotonic()
             source_code = source.source_code
+
+            # Track tokens and cost
+            total_input_tokens = 0
+            total_output_tokens = 0
+            total_input_cost = 0.0
+            total_output_cost = 0.0
+
             if len(source.source_code) > self.max_input_length:
                 self.printer.print_in_terminal(
                     "index_file_too_large",
@@ -227,15 +270,38 @@ class IndexManager:
                 )
                 symbols = []
                 for chunk in chunks:
+                    meta_holder = byzerllm.MetaHolder()
                     chunk_symbols = self.get_all_file_symbols.with_llm(
-                        self.index_llm).run(source.module_name, chunk)
+                        self.index_llm).with_meta(meta_holder).run(source.module_name, chunk)
                     time.sleep(self.anti_quota_limit)
                     symbols.append(chunk_symbols)
+
+                    if meta_holder.get_meta():
+                        meta_dict = meta_holder.get_meta()
+                        total_input_tokens += meta_dict.get("input_tokens_count", 0)
+                        total_output_tokens += meta_dict.get("generated_tokens_count", 0)
+
                 symbols = "\n".join(symbols)
             else:
+                meta_holder = byzerllm.MetaHolder()
                 symbols = self.get_all_file_symbols.with_llm(
-                    self.index_llm).run(source.module_name, source_code)
+                    self.index_llm).with_meta(meta_holder).run(source.module_name, source_code)
                 time.sleep(self.anti_quota_limit)
+
+                if meta_holder.get_meta():
+                    meta_dict = meta_holder.get_meta()
+                    total_input_tokens += meta_dict.get("input_tokens_count", 0)
+                    total_output_tokens += meta_dict.get("generated_tokens_count", 0)
+
+            # Compute total cost
+            for name in model_names:
+                info = model_info_map.get(name, {})
+                total_input_cost += (total_input_tokens * info.get("input_price", 0.0)) / 1000000
+                total_output_cost += (total_output_tokens * info.get("output_price", 0.0)) / 1000000
+
+            # Round to 4 decimal places
+            total_input_cost = round(total_input_cost, 4)
+            total_output_cost = round(total_output_cost, 4)

             self.printer.print_in_terminal(
                 "index_update_success",
@@ -243,7 +309,11 @@ class IndexManager:
                 file_path=file_path,
                 md5=md5,
                 duration=time.monotonic() - start_time,
-                model_name=model_name
+                model_name=model_name,
+                input_tokens=total_input_tokens,
+                output_tokens=total_output_tokens,
+                input_cost=total_input_cost,
+                output_cost=total_output_cost
             )

         except Exception as e:
@@ -263,6 +333,10 @@ class IndexManager:
             "symbols": symbols,
             "last_modified": os.path.getmtime(file_path),
             "md5": md5,
+            "input_tokens_count": total_input_tokens,
+            "generated_tokens_count": total_output_tokens,
+            "input_tokens_cost": total_input_cost,
+            "generated_tokens_cost": total_output_cost
         }

     def build_index(self):
@@ -290,6 +364,11 @@ class IndexManager:

         updated_sources = []

+        total_input_tokens = 0
+        total_output_tokens = 0
+        total_input_cost = 0.0
+        total_output_cost = 0.0
+
         with ThreadPoolExecutor(max_workers=self.args.index_build_workers) as executor:

             wait_to_build_files = []
@@ -346,6 +425,10 @@ class IndexManager:
                     num_files=num_files
                 )
                 module_name = result["module_name"]
+                total_input_tokens += result["input_tokens_count"]
+                total_output_tokens += result["generated_tokens_count"]
+                total_input_cost += result["input_tokens_cost"]
+                total_output_cost += result["generated_tokens_cost"]
                 index_data[module_name] = result
                 updated_sources.append(module_name)
                 if len(updated_sources) > 5:
@@ -357,12 +440,19 @@ class IndexManager:
         if updated_sources or keys_to_remove:
             with open(self.index_file, "w") as file:
                 json.dump(index_data, file, ensure_ascii=False, indent=2)
+
+            print("")
             self.printer.print_in_terminal(
                 "index_file_saved",
                 style="green",
                 updated_files=len(updated_sources),
-                removed_files=len(keys_to_remove)
+                removed_files=len(keys_to_remove),
+                input_tokens=total_input_tokens,
+                output_tokens=total_output_tokens,
+                input_cost=total_input_cost,
+                output_cost=total_output_cost
             )
+            print("")

         return index_data

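For reference, the cost accounting added above is plain per-million-token arithmetic; a minimal sketch with hypothetical prices and token counts (real values come from models.json and byzerllm's MetaHolder):

# Hypothetical pricing: 2.0 USD per million input tokens, 8.0 per million output tokens.
input_price, output_price = 2.0, 8.0
input_tokens, output_tokens = 12500, 3200

input_cost = round((input_tokens * input_price) / 1000000, 4)     # 0.025
output_cost = round((output_tokens * output_price) / 1000000, 4)  # 0.0256
print(input_cost + output_cost)                                   # 0.0506
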
autocoder/models.py
CHANGED

@@ -110,6 +110,20 @@ def save_models(models: List[Dict]) -> None:
         json.dump(models, f, indent=2, ensure_ascii=False)


+def add_and_activate_models(models: List[Dict]) -> None:
+    """
+    Add models
+    """
+    exits_models = load_models()
+    for model in models:
+        if model["name"] not in [m["name"] for m in exits_models]:
+            exits_models.append(model)
+    save_models(exits_models)
+
+    for model in models:
+        if "api_key" in model:
+            update_model_with_api_key(model["name"], model["api_key"])
+
 def get_model_by_name(name: str) -> Dict:
     """
     Look up a model by its name
@@ -127,11 +141,23 @@ def update_model_input_price(name: str, price: float) -> bool:
     """Update a model's input price

     Args:
-        name:
-        price
+        name (str): Name of the model to update; must match a record in models.json
+        price (float): New input price, in USD per million tokens. Must be >= 0

     Returns:
-        bool:
+        bool: Whether the model was found and its price updated
+
+    Raises:
+        ValueError: Raised if price is negative
+
+    Example:
+        >>> update_model_input_price("gpt-4", 3.0)
+        True
+
+    Notes:
+        1. The price takes effect immediately and is saved to models.json
+        2. Actual billing is computed from real usage, to 6 decimal places
+        3. A price of 0 means the model is currently unavailable
     """
     if price < 0:
         raise ValueError("Price cannot be negative")
@@ -151,11 +177,23 @@ def update_model_output_price(name: str, price: float) -> bool:
     """Update a model's output price

     Args:
-        name:
-        price
+        name (str): Name of the model to update; must match a record in models.json
+        price (float): New output price, in USD per million tokens. Must be >= 0

     Returns:
-        bool:
+        bool: Whether the model was found and its price updated
+
+    Raises:
+        ValueError: Raised if price is negative
+
+    Example:
+        >>> update_model_output_price("gpt-4", 6.0)
+        True
+
+    Notes:
+        1. Output prices are typically 30%-50% higher than input prices
+        2. For token-billed APIs, the actual charge is (input_tokens * input_price + output_tokens * output_price)
+        3. Price changes affect every feature that depends on model billing (e.g. cost estimation, usage monitoring)
     """
     if price < 0:
         raise ValueError("Price cannot be negative")

autocoder/privacy/model_filter.py
ADDED

@@ -0,0 +1,100 @@
+import re
+import yaml
+from pathlib import Path
+from typing import Dict, List, Optional
+from autocoder.common import AutoCoderArgs
+from autocoder.utils import llms as llm_utils
+
+
+class ModelPathFilter:
+    def __init__(self,
+                 model_name: str,
+                 args: AutoCoderArgs,
+                 default_forbidden: List[str] = None):
+        """
+        Model path filter
+        :param model_name: Name of the model currently in use
+        :param args: auto-coder arguments
+        :param default_forbidden: Default forbidden-path rules
+        """
+        self.model_name = model_name
+        if args.model_filter_path:
+            self.config_path = Path(args.model_filter_path)
+        else:
+            self.config_path = Path(args.source_dir, ".model_filters.yml")
+        self.default_forbidden = default_forbidden or []
+        self._rules_cache: Dict[str, List[re.Pattern]] = {}
+        self._load_rules()
+
+    def _load_rules(self):
+        """Load and compile the regex rules"""
+        if not self.config_path.exists():
+            return
+
+        with open(self.config_path, 'r', encoding="utf-8") as f:
+            config = yaml.safe_load(f)
+
+        model_rules = config.get('model_filters', {}).get(self.model_name, {})
+        all_rules = model_rules.get('forbidden_paths', []) + self.default_forbidden
+
+        # Pre-compile the regular expressions
+        self._rules_cache[self.model_name] = [
+            re.compile(rule) for rule in all_rules
+        ]
+
+    def is_accessible(self, file_path: str) -> bool:
+        """
+        Check whether a file path complies with the access rules
+        :return: True means access is allowed, False means forbidden
+        """
+        # Prefer model-specific rules
+        patterns = self._rules_cache.get(self.model_name, [])
+
+        # Fall back to the default rules
+        if not patterns and self.default_forbidden:
+            patterns = [re.compile(rule) for rule in self.default_forbidden]
+
+        # If the path is empty or None, allow it
+        if not file_path:
+            return True
+
+        return not any(pattern.search(file_path) for pattern in patterns)
+
+    def add_temp_rule(self, rule: str):
+        """
+        Add a temporary rule
+        :param rule: Regular-expression rule
+        """
+        patterns = self._rules_cache.get(self.model_name, [])
+        patterns.append(re.compile(rule))
+        self._rules_cache[self.model_name] = patterns
+
+    def reload_rules(self):
+        """Reload the rule configuration"""
+        self._rules_cache.clear()
+        self._load_rules()
+
+    def has_rules(self):
+        """Check whether any rules exist"""
+        return bool(self._rules_cache.get(self.model_name, []))
+
+    @classmethod
+    def from_model_object(cls,
+                          llm_obj,
+                          args: AutoCoderArgs,
+                          default_forbidden: Optional[List[str]] = None):
+        """
+        Create a filter from an LLM object
+        :param llm_obj: ByzerLLM instance or similar object
+        :param args: auto-coder arguments
+        :param default_forbidden: Default forbidden-path rules
+        """
+        model_name = ",".join(llm_utils.get_llm_names(llm_obj))
+        if not model_name:
+            raise ValueError(f"{model_name} is not found")
+
+        return cls(
+            model_name=model_name,
+            args=args,
+            default_forbidden=default_forbidden
+        )
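A sketch of how ModelPathFilter might be driven from a rules file. The file name (.model_filters.yml) and the model_filters/forbidden_paths keys follow _load_rules above; the model name, paths, and the assumption that AutoCoderArgs can be constructed with just source_dir are hypothetical:

import yaml
from pathlib import Path
from autocoder.common import AutoCoderArgs
from autocoder.privacy.model_filter import ModelPathFilter

# Hypothetical rules: hide anything under secrets/ and any .env file from this model.
Path("demo").mkdir(exist_ok=True)
Path("demo/.model_filters.yml").write_text(yaml.safe_dump({
    "model_filters": {
        "deepseek-chat": {"forbidden_paths": [r".*/secrets/.*", r".*\.env$"]}
    }
}))

args = AutoCoderArgs(source_dir="demo")  # assumes defaults suffice for the other fields
path_filter = ModelPathFilter("deepseek-chat", args)
print(path_filter.is_accessible("src/app/main.py"))      # True: no rule matches
print(path_filter.is_accessible("src/secrets/key.txt"))  # False: first rule matches
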
autocoder/pyproject/__init__.py
CHANGED
autocoder/tsproject/__init__.py
CHANGED
autocoder/utils/llms.py
CHANGED

@@ -3,9 +3,15 @@ from typing import Union,Optional

 def get_llm_names(llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM,str],target_model_type:Optional[str]=None):
     if target_model_type is None:
+        if isinstance(llm,list):
+            return [_llm.default_model_name for _llm in llm]
         return [llm.default_model_name for llm in [llm] if llm.default_model_name]
+
     llms = llm.get_sub_client(target_model_type)
+
     if llms is None:
+        if isinstance(llm,list):
+            return [_llm.default_model_name for _llm in llm]
         return [llm.default_model_name for llm in [llm] if llm.default_model_name]
     elif isinstance(llms, list):
         return [llm.default_model_name for llm in llms if llm.default_model_name]
@@ -14,6 +20,27 @@ def get_llm_names(llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM,str],target_model_type:Optional[str]=None):
     else:
         return [llm.default_model_name for llm in [llms] if llm.default_model_name]

+def get_model_info(model_names: str, product_mode: str):
+    from autocoder import models as models_module
+    def get_model_by_name(model_name: str):
+        try:
+            return models_module.get_model_by_name(model_name)
+        except Exception as e:
+            return None
+
+    if product_mode == "pro":
+        return None
+
+    if product_mode == "lite":
+        if "," in model_names:
+            # Multiple code models specified
+            model_names = model_names.split(",")
+            for _, model_name in enumerate(model_names):
+                return get_model_by_name(model_name)
+        else:
+            # Single code model
+            return get_model_by_name(model_names)
+
 def get_single_llm(model_names: str, product_mode: str):
     from autocoder import models as models_module
     if product_mode == "pro":

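Note that get_model_info returns inside the first loop iteration, so for a comma-separated list only the first model's record is consulted. A hedged usage sketch (the model names are hypothetical; lookups depend on what is in models.json):

from autocoder.utils.llms import get_model_info

# "lite" mode consults models.json; the early return means only "v3_chat"
# is looked up here and "r1_chat" is skipped.
info = get_model_info("v3_chat,r1_chat", "lite")
if info:
    print(info.get("input_price", 0.0), info.get("output_price", 0.0))

# "pro" mode keeps no local price records, so the function returns None.
assert get_model_info("v3_chat", "pro") is None
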
autocoder/utils/model_provider_selector.py
ADDED

@@ -0,0 +1,192 @@
+from prompt_toolkit.shortcuts import radiolist_dialog, input_dialog
+from prompt_toolkit.validation import Validator, ValidationError
+from rich.console import Console
+from typing import Optional, Dict, Any, List
+from autocoder.common.printer import Printer
+import re
+from pydantic import BaseModel
+
+from autocoder.models import process_api_key_path
+
+class ProviderInfo(BaseModel):
+    name: str
+    endpoint: str
+    r1_model: str
+    v3_model: str
+    api_key: str
+    r1_input_price: float
+    r1_output_price: float
+    v3_input_price: float
+    v3_output_price: float
+
+
+PROVIDER_INFO_LIST = [
+    ProviderInfo(
+        name="volcano",
+        endpoint="https://ark.cn-beijing.volces.com/api/v3",
+        r1_model="",
+        v3_model="",
+        api_key="",
+        r1_input_price=2.0,
+        r1_output_price=8.0,
+        v3_input_price=1.0,
+        v3_output_price=4.0,
+    ),
+    ProviderInfo(
+        name="siliconFlow",
+        endpoint="https://api.siliconflow.cn/v1",
+        r1_model="Pro/deepseek-ai/DeepSeek-R1",
+        v3_model="Pro/deepseek-ai/DeepSeek-V3",
+        api_key="",
+        r1_input_price=2.0,
+        r1_output_price=4.0,
+        v3_input_price=4.0,
+        v3_output_price=16.0,
+    ),
+    ProviderInfo(
+        name="deepseek",
+        endpoint="https://api.deepseek.com/v1",
+        r1_model="deepseek-reasoner",
+        v3_model="deepseek-chat",
+        api_key="",
+        r1_input_price=4.0,
+        r1_output_price=16.0,
+        v3_input_price=2.0,
+        v3_output_price=8.0,
+    ),
+]
+
+class VolcanoEndpointValidator(Validator):
+    def validate(self, document):
+        text = document.text
+        pattern = r'^ep-\d{14}-[a-z0-9]{5}$'
+        if not re.match(pattern, text):
+            raise ValidationError(
+                message='Invalid endpoint format. Should be like: ep-20250204215011-vzbsg',
+                cursor_position=len(text)
+            )
+
+class ModelProviderSelector:
+    def __init__(self):
+        self.printer = Printer()
+        self.console = Console()
+
+    def to_models_json(self, provider_info: ProviderInfo) -> List[Dict[str, Any]]:
+        """
+        Convert provider info to models.json format.
+        Returns a list of model configurations matching the format in models.py default_models_list.
+
+        Args:
+            provider_info: ProviderInfo object containing provider details
+
+        Returns:
+            List[Dict[str, Any]]: List of model configurations
+        """
+        models = []
+
+        # Add R1 model (for reasoning/design/review)
+        if provider_info.r1_model:
+            models.append({
+                "name": f"r1_chat",
+                "description": f"{provider_info.name} R1 is for design/review",
+                "model_name": provider_info.r1_model,
+                "model_type": "saas/openai",
+                "base_url": provider_info.endpoint,
+                "api_key": provider_info.api_key,
+                "api_key_path": f"r1_chat",
+                "is_reasoning": True,
+                "input_price": provider_info.r1_input_price,
+                "output_price": provider_info.r1_output_price,
+                "average_speed": 0.0
+            })
+
+        # Add V3 model (for coding)
+        if provider_info.v3_model:
+            models.append({
+                "name": f"v3_chat",
+                "description": f"{provider_info.name} Chat is for coding",
+                "model_name": provider_info.v3_model,
+                "model_type": "saas/openai",
+                "base_url": provider_info.endpoint,
+                "api_key": provider_info.api_key,
+                "api_key_path": f"v3_chat",
+                "is_reasoning": False,
+                "input_price": provider_info.v3_input_price,
+                "output_price": provider_info.v3_output_price,
+                "average_speed": 0.0
+            })
+
+        return models
+
+    def select_provider(self) -> Optional[Dict[str, Any]]:
+        """
+        Let user select a model provider and input necessary credentials.
+        Returns a dictionary with provider info or None if cancelled.
+        """
+        result = radiolist_dialog(
+            title=self.printer.get_message_from_key("model_provider_select_title"),
+            text=self.printer.get_message_from_key("model_provider_select_text"),
+            values=[
+                ("volcano", self.printer.get_message_from_key("model_provider_volcano")),
+                ("siliconflow", self.printer.get_message_from_key("model_provider_guiji")),
+                ("deepseek", self.printer.get_message_from_key("model_provider_deepseek"))
+            ]
+        ).run()
+
+        if result is None:
+            return None
+
+        provider_info = None
+        for provider in PROVIDER_INFO_LIST:
+            if provider.name == result:
+                provider_info = provider
+                break
+
+        if result == "volcano":
+            # Get R1 endpoint
+            r1_endpoint = input_dialog(
+                title=self.printer.get_message_from_key("model_provider_api_key_title"),
+                text=self.printer.get_message_from_key("model_provider_volcano_r1_text"),
+                validator=VolcanoEndpointValidator()
+            ).run()
+
+            if r1_endpoint is None:
+                return None
+
+            provider_info.r1_model = r1_endpoint
+
+            # Get V3 endpoint
+            v3_endpoint = input_dialog(
+                title=self.printer.get_message_from_key("model_provider_api_key_title"),
+                text=self.printer.get_message_from_key("model_provider_volcano_v3_text"),
+                validator=VolcanoEndpointValidator()
+            ).run()
+
+            if v3_endpoint is None:
+                return None
+
+            provider_info.v3_model = v3_endpoint
+
+        # Get API key for all providers
+        api_key = input_dialog(
+            title=self.printer.get_message_from_key("model_provider_api_key_title"),
+            text=self.printer.get_message_from_key(f"model_provider_{result}_api_key_text"),
+            password=True
+        ).run()
+
+        if api_key is None:
+            return None
+
+        provider_info.api_key = api_key
+
+        self.printer.print_panel(
+            self.printer.get_message_from_key("model_provider_selected"),
+            text_options={"justify": "left"},
+            panel_options={
+                "title": self.printer.get_message_from_key("model_provider_success_title"),
+                "border_style": "green"
+            }
+        )
+
+        return provider_info
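A sketch of how the selector's output might feed the new add_and_activate_models helper in autocoder/models.py (both are introduced in this release; the dialogs require an interactive terminal):

from autocoder.models import add_and_activate_models
from autocoder.utils.model_provider_selector import ModelProviderSelector

selector = ModelProviderSelector()
provider_info = selector.select_provider()  # runs the radiolist/input dialogs; None if cancelled
if provider_info is not None:
    # Convert the chosen provider into models.json entries and register them;
    # add_and_activate_models also stores any api_key via update_model_with_api_key.
    add_and_activate_models(selector.to_models_json(provider_info))
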
autocoder/version.py
CHANGED

@@ -1 +1 @@
-__version__ = "0.1.255"
+__version__ = "0.1.257"