auto-coder 0.1.255__py3-none-any.whl → 0.1.257__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of auto-coder might be problematic. Click here for more details.

Files changed (30)
  1. {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/METADATA +2 -2
  2. {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/RECORD +30 -27
  3. autocoder/auto_coder.py +44 -50
  4. autocoder/chat_auto_coder.py +16 -17
  5. autocoder/chat_auto_coder_lang.py +1 -1
  6. autocoder/common/__init__.py +7 -0
  7. autocoder/common/auto_coder_lang.py +46 -16
  8. autocoder/common/code_auto_generate.py +45 -5
  9. autocoder/common/code_auto_generate_diff.py +45 -7
  10. autocoder/common/code_auto_generate_editblock.py +48 -4
  11. autocoder/common/code_auto_generate_strict_diff.py +46 -7
  12. autocoder/common/code_modification_ranker.py +39 -3
  13. autocoder/dispacher/actions/action.py +60 -40
  14. autocoder/dispacher/actions/plugins/action_regex_project.py +12 -6
  15. autocoder/index/entry.py +6 -4
  16. autocoder/index/filter/quick_filter.py +175 -65
  17. autocoder/index/index.py +94 -4
  18. autocoder/models.py +44 -6
  19. autocoder/privacy/__init__.py +3 -0
  20. autocoder/privacy/model_filter.py +100 -0
  21. autocoder/pyproject/__init__.py +1 -0
  22. autocoder/suffixproject/__init__.py +1 -0
  23. autocoder/tsproject/__init__.py +1 -0
  24. autocoder/utils/llms.py +27 -0
  25. autocoder/utils/model_provider_selector.py +192 -0
  26. autocoder/version.py +1 -1
  27. {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/LICENSE +0 -0
  28. {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/WHEEL +0 -0
  29. {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/entry_points.txt +0 -0
  30. {auto_coder-0.1.255.dist-info → auto_coder-0.1.257.dist-info}/top_level.txt +0 -0
autocoder/index/index.py CHANGED
@@ -9,6 +9,7 @@ from autocoder.index.symbols_utils import (
9
9
  SymbolType,
10
10
  symbols_info_to_str,
11
11
  )
12
+ from autocoder.privacy.model_filter import ModelPathFilter
12
13
  from concurrent.futures import ThreadPoolExecutor, as_completed
13
14
  import threading
14
15
 
@@ -17,6 +18,7 @@ import hashlib
17
18
 
18
19
  from autocoder.common.printer import Printer
19
20
  from autocoder.common.auto_coder_lang import get_message
21
+ from autocoder.utils.llms import get_llm_names, get_model_info
20
22
  from autocoder.index.types import (
21
23
  IndexItem,
22
24
  TargetFile,
@@ -30,6 +32,9 @@ class IndexManager:
30
32
  ):
31
33
  self.sources = sources
32
34
  self.source_dir = args.source_dir
35
+ # Initialize model filter for index_llm and index_filter_llm
36
+ self.index_model_filter = None
37
+ self.index_filter_model_filter = None
33
38
  self.anti_quota_limit = (
34
39
  args.index_model_anti_quota_limit or args.anti_quota_limit
35
40
  )
@@ -46,6 +51,12 @@ class IndexManager:
46
51
  self.index_filter_llm = llm
47
52
 
48
53
  self.llm = llm
54
+
55
+ # Initialize model filters
56
+ if self.index_llm:
57
+ self.index_model_filter = ModelPathFilter.from_model_object(self.index_llm, args)
58
+ if self.index_filter_llm:
59
+ self.index_filter_model_filter = ModelPathFilter.from_model_object(self.index_filter_llm, args)
49
60
  self.args = args
50
61
  self.max_input_length = (
51
62
  args.index_model_max_input_length or args.model_max_input_length
@@ -194,6 +205,17 @@ class IndexManager:
194
205
  ext = os.path.splitext(file_path)[1].lower()
195
206
  if ext in [".md", ".html", ".txt", ".doc", ".pdf"]:
196
207
  return True
208
+
209
+ # Check model filter restrictions
210
+ if self.index_model_filter and not self.index_model_filter.is_accessible(file_path):
211
+ self.printer.print_in_terminal(
212
+ "index_file_filtered",
213
+ style="yellow",
214
+ file_path=file_path,
215
+ model_name=",".join(get_llm_names(self.index_llm))
216
+ )
217
+ return True
218
+
197
219
  return False
198
220
 
199
221
  def build_index_for_single_source(self, source: SourceCode):
@@ -212,8 +234,29 @@ class IndexManager:
212
234
  model_name = ",".join(get_llm_names(self.index_llm))
213
235
 
214
236
  try:
237
+ # 获取模型名称列表
238
+ model_names = get_llm_names(self.index_llm)
239
+ model_name = ",".join(model_names)
240
+
241
+ # 获取模型价格信息
242
+ model_info_map = {}
243
+ for name in model_names:
244
+ info = get_model_info(name, self.args.product_mode)
245
+ if info:
246
+ model_info_map[name] = {
247
+ "input_price": info.get("input_price", 0.0),
248
+ "output_price": info.get("output_price", 0.0)
249
+ }
250
+
215
251
  start_time = time.monotonic()
216
252
  source_code = source.source_code
253
+
254
+ # 统计token和成本
255
+ total_input_tokens = 0
256
+ total_output_tokens = 0
257
+ total_input_cost = 0.0
258
+ total_output_cost = 0.0
259
+
217
260
  if len(source.source_code) > self.max_input_length:
218
261
  self.printer.print_in_terminal(
219
262
  "index_file_too_large",
@@ -227,15 +270,38 @@ class IndexManager:
227
270
  )
228
271
  symbols = []
229
272
  for chunk in chunks:
273
+ meta_holder = byzerllm.MetaHolder()
230
274
  chunk_symbols = self.get_all_file_symbols.with_llm(
231
- self.index_llm).run(source.module_name, chunk)
275
+ self.index_llm).with_meta(meta_holder).run(source.module_name, chunk)
232
276
  time.sleep(self.anti_quota_limit)
233
277
  symbols.append(chunk_symbols)
278
+
279
+ if meta_holder.get_meta():
280
+ meta_dict = meta_holder.get_meta()
281
+ total_input_tokens += meta_dict.get("input_tokens_count", 0)
282
+ total_output_tokens += meta_dict.get("generated_tokens_count", 0)
283
+
234
284
  symbols = "\n".join(symbols)
235
285
  else:
286
+ meta_holder = byzerllm.MetaHolder()
236
287
  symbols = self.get_all_file_symbols.with_llm(
237
- self.index_llm).run(source.module_name, source_code)
288
+ self.index_llm).with_meta(meta_holder).run(source.module_name, source_code)
238
289
  time.sleep(self.anti_quota_limit)
290
+
291
+ if meta_holder.get_meta():
292
+ meta_dict = meta_holder.get_meta()
293
+ total_input_tokens += meta_dict.get("input_tokens_count", 0)
294
+ total_output_tokens += meta_dict.get("generated_tokens_count", 0)
295
+
296
+ # 计算总成本
297
+ for name in model_names:
298
+ info = model_info_map.get(name, {})
299
+ total_input_cost += (total_input_tokens * info.get("input_price", 0.0)) / 1000000
300
+ total_output_cost += (total_output_tokens * info.get("output_price", 0.0)) / 1000000
301
+
302
+ # 四舍五入到4位小数
303
+ total_input_cost = round(total_input_cost, 4)
304
+ total_output_cost = round(total_output_cost, 4)
239
305
 
240
306
  self.printer.print_in_terminal(
241
307
  "index_update_success",
@@ -243,7 +309,11 @@ class IndexManager:
243
309
  file_path=file_path,
244
310
  md5=md5,
245
311
  duration=time.monotonic() - start_time,
246
- model_name=model_name
312
+ model_name=model_name,
313
+ input_tokens=total_input_tokens,
314
+ output_tokens=total_output_tokens,
315
+ input_cost=total_input_cost,
316
+ output_cost=total_output_cost
247
317
  )
248
318
 
249
319
  except Exception as e:
@@ -263,6 +333,10 @@ class IndexManager:
263
333
  "symbols": symbols,
264
334
  "last_modified": os.path.getmtime(file_path),
265
335
  "md5": md5,
336
+ "input_tokens_count": total_input_tokens,
337
+ "generated_tokens_count": total_output_tokens,
338
+ "input_tokens_cost": total_input_cost,
339
+ "generated_tokens_cost": total_output_cost
266
340
  }
267
341
 
268
342
  def build_index(self):
@@ -290,6 +364,11 @@ class IndexManager:
290
364
 
291
365
  updated_sources = []
292
366
 
367
+ total_input_tokens = 0
368
+ total_output_tokens = 0
369
+ total_input_cost = 0.0
370
+ total_output_cost = 0.0
371
+
293
372
  with ThreadPoolExecutor(max_workers=self.args.index_build_workers) as executor:
294
373
 
295
374
  wait_to_build_files = []
@@ -346,6 +425,10 @@ class IndexManager:
346
425
  num_files=num_files
347
426
  )
348
427
  module_name = result["module_name"]
428
+ total_input_tokens += result["input_tokens_count"]
429
+ total_output_tokens += result["generated_tokens_count"]
430
+ total_input_cost += result["input_tokens_cost"]
431
+ total_output_cost += result["generated_tokens_cost"]
349
432
  index_data[module_name] = result
350
433
  updated_sources.append(module_name)
351
434
  if len(updated_sources) > 5:
@@ -357,12 +440,19 @@ class IndexManager:
357
440
  if updated_sources or keys_to_remove:
358
441
  with open(self.index_file, "w") as file:
359
442
  json.dump(index_data, file, ensure_ascii=False, indent=2)
443
+
444
+ print("")
360
445
  self.printer.print_in_terminal(
361
446
  "index_file_saved",
362
447
  style="green",
363
448
  updated_files=len(updated_sources),
364
- removed_files=len(keys_to_remove)
449
+ removed_files=len(keys_to_remove),
450
+ input_tokens=total_input_tokens,
451
+ output_tokens=total_output_tokens,
452
+ input_cost=total_input_cost,
453
+ output_cost=total_output_cost
365
454
  )
455
+ print("")
366
456
 
367
457
  return index_data
368
458
 
autocoder/models.py CHANGED
@@ -110,6 +110,20 @@ def save_models(models: List[Dict]) -> None:
110
110
  json.dump(models, f, indent=2, ensure_ascii=False)
111
111
 
112
112
 
113
def add_and_activate_models(models: List[Dict]) -> None:
    """Add ``models`` to models.json and register their API keys.

    A model whose ``name`` is already present in models.json is not added
    again (the stored entry is left untouched). Within ``models`` itself only
    the first occurrence of a repeated name is added. Afterwards, for every
    model in ``models`` that has an ``"api_key"`` entry (including models that
    already existed), ``update_model_with_api_key`` is called with that key.

    Args:
        models: model configuration dicts; each must contain a ``"name"`` key.
    """
    existing_models = load_models()
    # A set gives O(1) membership tests instead of rebuilding a list of
    # names for every incoming model.
    known_names = {m["name"] for m in existing_models}
    for model in models:
        if model["name"] not in known_names:
            existing_models.append(model)
            known_names.add(model["name"])
    save_models(existing_models)

    # NOTE(review): entries appended above still carry their "api_key" field,
    # so save_models writes the key into models.json as well — confirm intended.
    for model in models:
        if "api_key" in model:
            update_model_with_api_key(model["name"], model["api_key"])
126
+
113
127
  def get_model_by_name(name: str) -> Dict:
114
128
  """
115
129
  根据模型名称查找模型
@@ -127,11 +141,23 @@ def update_model_input_price(name: str, price: float) -> bool:
127
141
  """更新模型输入价格
128
142
 
129
143
  Args:
130
- name: 模型名称
131
- price: 输入价格(M/百万input tokens)
144
+ name (str): 要更新的模型名称,必须与models.json中的记录匹配
145
+ price (float): 新的输入价格,单位:美元/百万tokens。必须大于等于0
132
146
 
133
147
  Returns:
134
- bool: 是否更新成功
148
+ bool: 是否成功找到并更新了模型价格
149
+
150
+ Raises:
151
+ ValueError: 如果price为负数时抛出
152
+
153
+ Example:
154
+ >>> update_model_input_price("gpt-4", 3.0)
155
+ True
156
+
157
+ Notes:
158
+ 1. 价格设置后会立即生效并保存到models.json
159
+ 2. 实际费用计算时会按实际使用量精确到小数点后6位
160
+ 3. 设置价格为0表示该模型当前不可用
135
161
  """
136
162
  if price < 0:
137
163
  raise ValueError("Price cannot be negative")
@@ -151,11 +177,23 @@ def update_model_output_price(name: str, price: float) -> bool:
151
177
  """更新模型输出价格
152
178
 
153
179
  Args:
154
- name: 模型名称
155
- price: 输出价格(M/百万output tokens)
180
+ name (str): 要更新的模型名称,必须与models.json中的记录匹配
181
+ price (float): 新的输出价格,单位:美元/百万tokens。必须大于等于0
156
182
 
157
183
  Returns:
158
- bool: 是否更新成功
184
+ bool: 是否成功找到并更新了模型价格
185
+
186
+ Raises:
187
+ ValueError: 如果price为负数时抛出
188
+
189
+ Example:
190
+ >>> update_model_output_price("gpt-4", 6.0)
191
+ True
192
+
193
+ Notes:
194
+ 1. 输出价格通常比输入价格高30%-50%
195
+ 2. 对于按token计费的API,实际收费按(input_tokens * input_price + output_tokens * output_price)计算
196
+ 3. 价格变更会影响所有依赖模型计费的功能(如成本预测、用量监控等)
159
197
  """
160
198
  if price < 0:
161
199
  raise ValueError("Price cannot be negative")
@@ -0,0 +1,3 @@
1
+ from .model_filter import ModelPathFilter
2
+
3
+ __all__ = ["ModelPathFilter"]
@@ -0,0 +1,100 @@
1
+ import re
2
+ import yaml
3
+ from pathlib import Path
4
+ from typing import Dict, List, Optional
5
+ from autocoder.common import AutoCoderArgs
6
+ from autocoder.utils import llms as llm_utils
7
+
8
+
9
class ModelPathFilter:
    """Decide whether a given model may see a file path.

    Rules are regular expressions. They are read from a YAML file
    (``args.model_filter_path`` if set, otherwise
    ``<source_dir>/.model_filters.yml``) under
    ``model_filters -> <model_name> -> forbidden_paths``, and are combined
    with the ``default_forbidden`` rules passed to the constructor. A path is
    inaccessible when any rule matches it via ``re.search``.
    """

    def __init__(self,
                 model_name: str,
                 args: AutoCoderArgs,
                 default_forbidden: Optional[List[str]] = None):
        """
        Model path filter.

        :param model_name: name of the model in use (the key looked up under
            ``model_filters`` in the YAML file)
        :param args: auto-coder arguments; ``model_filter_path`` and
            ``source_dir`` are read to locate the config file
        :param default_forbidden: regex rules applied in addition to the
            configured ones
        """
        self.model_name = model_name
        if args.model_filter_path:
            self.config_path = Path(args.model_filter_path)
        else:
            self.config_path = Path(args.source_dir, ".model_filters.yml")
        self.default_forbidden = default_forbidden or []
        self._rules_cache: Dict[str, List[re.Pattern]] = {}
        self._load_rules()

    def _load_rules(self):
        """Load and compile the rules for ``self.model_name``.

        The default rules are always compiled into the cache, even when the
        config file is missing or empty. This way ``add_temp_rule`` extends
        them rather than replacing them, and ``is_accessible`` does not
        recompile them on every call.
        """
        configured: List[str] = []
        if self.config_path.exists():
            with open(self.config_path, 'r', encoding="utf-8") as f:
                # safe_load returns None for an empty document.
                config = yaml.safe_load(f) or {}
            # Tolerate explicit nulls (e.g. "model_filters:" with no body).
            model_filters = config.get('model_filters') or {}
            model_rules = model_filters.get(self.model_name) or {}
            configured = model_rules.get('forbidden_paths') or []
            if isinstance(configured, str):
                # A single rule written as a scalar instead of a list.
                configured = [configured]

        # Pre-compile the regular expressions once.
        self._rules_cache[self.model_name] = [
            re.compile(rule) for rule in list(configured) + list(self.default_forbidden)
        ]

    def is_accessible(self, file_path: str) -> bool:
        """
        Check whether a file path passes the access rules.

        :param file_path: path to test (``str`` or ``os.PathLike``)
        :return: True if access is allowed, False if any rule matches
        """
        # Empty or None paths are always allowed.
        if not file_path:
            return True

        # Prefer the model-specific (cached) rules.
        patterns = self._rules_cache.get(self.model_name, [])

        # Fall back to the default rules if the cache was cleared externally.
        if not patterns and self.default_forbidden:
            patterns = [re.compile(rule) for rule in self.default_forbidden]

        path_text = str(file_path)
        return not any(pattern.search(path_text) for pattern in patterns)

    def add_temp_rule(self, rule: str):
        """
        Add a temporary rule (discarded by ``reload_rules``).

        :param rule: regular-expression rule
        """
        self._rules_cache.setdefault(self.model_name, []).append(re.compile(rule))

    def reload_rules(self):
        """Re-read the rule configuration, dropping any temporary rules."""
        self._rules_cache.clear()
        self._load_rules()

    def has_rules(self):
        """Return True if at least one rule applies to this model."""
        return bool(self._rules_cache.get(self.model_name, []))

    @classmethod
    def from_model_object(cls,
                          llm_obj,
                          args: AutoCoderArgs,
                          default_forbidden: Optional[List[str]] = None):
        """
        Create a filter from an LLM object.

        :param llm_obj: ByzerLLM instance or similar object accepted by
            ``llm_utils.get_llm_names``
        :param args: auto-coder arguments
        :param default_forbidden: default forbidden-path rules
        :raises ValueError: if no model name can be derived from ``llm_obj``
        """
        # get_llm_names may yield None entries (e.g. for a list of clients);
        # drop them so the join below cannot fail.
        names = [name for name in llm_utils.get_llm_names(llm_obj) if name]
        model_name = ",".join(names)
        if not model_name:
            raise ValueError(
                f"Cannot determine a model name from LLM object {llm_obj!r}")

        return cls(
            model_name=model_name,
            args=args,
            default_forbidden=default_forbidden
        )
@@ -116,6 +116,7 @@ class PyProject:
116
116
  "actions",
117
117
  ".vscode",
118
118
  ".idea",
119
+ "venv",
119
120
  ]
120
121
 
121
122
  @byzerllm.prompt()
@@ -56,6 +56,7 @@ class SuffixProject:
56
56
  ".vscode",
57
57
  "actions",
58
58
  ".idea",
59
+ "venv",
59
60
  ]
60
61
 
61
62
  @byzerllm.prompt()
@@ -48,6 +48,7 @@ class TSProject:
48
48
  "actions",
49
49
  ".vscode",
50
50
  ".idea",
51
+ "venv",
51
52
  ]
52
53
 
53
54
  @byzerllm.prompt()
autocoder/utils/llms.py CHANGED
@@ -3,9 +3,15 @@ from typing import Union,Optional
3
3
 
4
4
  def get_llm_names(llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM,str],target_model_type:Optional[str]=None):
5
5
  if target_model_type is None:
6
+ if isinstance(llm,list):
7
+ return [_llm.default_model_name for _llm in llm]
6
8
  return [llm.default_model_name for llm in [llm] if llm.default_model_name]
9
+
7
10
  llms = llm.get_sub_client(target_model_type)
11
+
8
12
  if llms is None:
13
+ if isinstance(llm,list):
14
+ return [_llm.default_model_name for _llm in llm]
9
15
  return [llm.default_model_name for llm in [llm] if llm.default_model_name]
10
16
  elif isinstance(llms, list):
11
17
  return [llm.default_model_name for llm in llms if llm.default_model_name]
@@ -14,6 +20,27 @@ def get_llm_names(llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM,str],tar
14
20
  else:
15
21
  return [llm.default_model_name for llm in [llms] if llm.default_model_name]
16
22
 
23
def get_model_info(model_names: str, product_mode: str):
    """Look up the stored model configuration for a model name in "lite" mode.

    Args:
        model_names: a model name, or several comma-separated names.
        product_mode: only "lite" performs a lookup; any other mode
            (e.g. "pro") returns None.

    Returns:
        The value returned by ``autocoder.models.get_model_by_name`` for the
        first name that resolves to a model, or None if no name resolves.
        Lookup errors are treated as "not found".
    """
    if product_mode != "lite":
        return None

    # Local import, presumably to avoid an import cycle with autocoder.models.
    from autocoder import models as models_module

    for model_name in model_names.split(","):
        model_name = model_name.strip()
        if not model_name:
            continue
        try:
            info = models_module.get_model_by_name(model_name)
        except Exception:
            # Best-effort lookup: an unknown name just means "try the next one".
            info = None
        if info is not None:
            return info
    return None
+
17
44
  def get_single_llm(model_names: str, product_mode: str):
18
45
  from autocoder import models as models_module
19
46
  if product_mode == "pro":
@@ -0,0 +1,192 @@
1
+ from prompt_toolkit.shortcuts import radiolist_dialog, input_dialog
2
+ from prompt_toolkit.validation import Validator, ValidationError
3
+ from rich.console import Console
4
+ from typing import Optional, Dict, Any, List
5
+ from autocoder.common.printer import Printer
6
+ import re
7
+ from pydantic import BaseModel
8
+
9
+ from autocoder.models import process_api_key_path
10
+
11
class ProviderInfo(BaseModel):
    """Connection details and pricing for one selectable model provider.

    ``r1_*`` fields describe the provider's reasoning model and ``v3_*``
    fields its coding model. Prices are per million tokens. Field order is
    significant for pydantic (it fixes the serialization order), so it is
    kept as declared.
    """

    # Display / lookup name of the provider.
    name: str
    # OpenAI-compatible base URL.
    endpoint: str
    # Remote model identifiers (may be filled in interactively).
    r1_model: str
    v3_model: str
    # Credential, usually supplied by the user at selection time.
    api_key: str
    # Pricing of the reasoning model.
    r1_input_price: float
    r1_output_price: float
    # Pricing of the coding model.
    v3_input_price: float
    v3_output_price: float
21
+
22
+
23
# Built-in providers offered by ModelProviderSelector.select_provider.
# All prices are per million tokens; api_key is always supplied by the user.
PROVIDER_INFO_LIST = [
    ProviderInfo(
        name="volcano",
        endpoint="https://ark.cn-beijing.volces.com/api/v3",
        api_key="",
        # Left empty: the user's own endpoint ids are entered interactively.
        r1_model="",
        r1_input_price=2.0,
        r1_output_price=8.0,
        v3_model="",
        v3_input_price=1.0,
        v3_output_price=4.0,
    ),
    ProviderInfo(
        name="siliconFlow",
        endpoint="https://api.siliconflow.cn/v1",
        api_key="",
        r1_model="Pro/deepseek-ai/DeepSeek-R1",
        r1_input_price=2.0,
        r1_output_price=4.0,
        v3_model="Pro/deepseek-ai/DeepSeek-V3",
        v3_input_price=4.0,
        v3_output_price=16.0,
    ),
    ProviderInfo(
        name="deepseek",
        endpoint="https://api.deepseek.com/v1",
        api_key="",
        r1_model="deepseek-reasoner",
        r1_input_price=4.0,
        r1_output_price=16.0,
        v3_model="deepseek-chat",
        v3_input_price=2.0,
        v3_output_price=8.0,
    ),
]
58
+
59
class VolcanoEndpointValidator(Validator):
    """prompt_toolkit validator for Volcano endpoint ids.

    Accepts exactly ``ep-`` + 14 ASCII digits + ``-`` + 5 lower-case ASCII
    letters/digits, e.g. ``ep-20250204215011-vzbsg``.
    """

    # [0-9] rather than \d so that non-ASCII Unicode digits are rejected.
    _ENDPOINT_PATTERN = re.compile(r'ep-[0-9]{14}-[a-z0-9]{5}')

    def validate(self, document):
        """Raise ValidationError unless the whole input is a valid endpoint id."""
        text = document.text
        # fullmatch, unlike re.match with `^...$`, also rejects a trailing "\n".
        if not self._ENDPOINT_PATTERN.fullmatch(text):
            raise ValidationError(
                message='Invalid endpoint format. Should be like: ep-20250204215011-vzbsg',
                cursor_position=len(text)
            )
68
+
69
class ModelProviderSelector:
    """Interactive wizard that picks a model provider, collects its
    credentials, and converts the result into models.json entries."""

    def __init__(self):
        self.printer = Printer()
        self.console = Console()

    def to_models_json(self, provider_info: ProviderInfo) -> List[Dict[str, Any]]:
        """
        Convert provider info to models.json format.
        Returns a list of model configurations matching the format in models.py default_models_list.

        Args:
            provider_info: ProviderInfo object containing provider details

        Returns:
            List[Dict[str, Any]]: one entry named "r1_chat" if ``r1_model`` is
            set and one named "v3_chat" if ``v3_model`` is set (in that order)
        """
        models = []

        # Add R1 model (for reasoning/design/review)
        if provider_info.r1_model:
            models.append({
                "name": "r1_chat",
                "description": f"{provider_info.name} R1 is for design/review",
                "model_name": provider_info.r1_model,
                "model_type": "saas/openai",
                "base_url": provider_info.endpoint,
                "api_key": provider_info.api_key,
                "api_key_path": "r1_chat",
                "is_reasoning": True,
                "input_price": provider_info.r1_input_price,
                "output_price": provider_info.r1_output_price,
                "average_speed": 0.0
            })

        # Add V3 model (for coding)
        if provider_info.v3_model:
            models.append({
                "name": "v3_chat",
                "description": f"{provider_info.name} Chat is for coding",
                "model_name": provider_info.v3_model,
                "model_type": "saas/openai",
                "base_url": provider_info.endpoint,
                "api_key": provider_info.api_key,
                "api_key_path": "v3_chat",
                "is_reasoning": False,
                "input_price": provider_info.v3_input_price,
                "output_price": provider_info.v3_output_price,
                "average_speed": 0.0
            })

        return models

    def select_provider(self) -> Optional[ProviderInfo]:
        """
        Let the user select a model provider and input necessary credentials.

        Returns:
            A copy of the matching PROVIDER_INFO_LIST entry with ``api_key``
            (and, for volcano, ``r1_model``/``v3_model``) filled in, or None
            if the user cancels a dialog or the choice matches no provider.
        """
        import copy

        result = radiolist_dialog(
            title=self.printer.get_message_from_key("model_provider_select_title"),
            text=self.printer.get_message_from_key("model_provider_select_text"),
            values=[
                ("volcano", self.printer.get_message_from_key("model_provider_volcano")),
                ("siliconflow", self.printer.get_message_from_key("model_provider_guiji")),
                ("deepseek", self.printer.get_message_from_key("model_provider_deepseek"))
            ]
        ).run()

        if result is None:
            return None

        # The dialog keys are lower-case ("siliconflow") while the provider
        # list spells "siliconFlow"; compare case-insensitively so every
        # choice resolves to its provider.
        template = next(
            (p for p in PROVIDER_INFO_LIST if p.name.lower() == result.lower()),
            None,
        )
        if template is None:
            return None

        # Work on a copy so the user's credentials are never written into the
        # shared module-level PROVIDER_INFO_LIST entries.
        provider_info = copy.deepcopy(template)

        if result == "volcano":
            # Get R1 endpoint
            r1_endpoint = input_dialog(
                title=self.printer.get_message_from_key("model_provider_api_key_title"),
                text=self.printer.get_message_from_key("model_provider_volcano_r1_text"),
                validator=VolcanoEndpointValidator()
            ).run()

            if r1_endpoint is None:
                return None

            provider_info.r1_model = r1_endpoint

            # Get V3 endpoint
            v3_endpoint = input_dialog(
                title=self.printer.get_message_from_key("model_provider_api_key_title"),
                text=self.printer.get_message_from_key("model_provider_volcano_v3_text"),
                validator=VolcanoEndpointValidator()
            ).run()

            if v3_endpoint is None:
                return None

            provider_info.v3_model = v3_endpoint

        # Get API key for all providers
        api_key = input_dialog(
            title=self.printer.get_message_from_key("model_provider_api_key_title"),
            text=self.printer.get_message_from_key(f"model_provider_{result}_api_key_text"),
            password=True
        ).run()

        if api_key is None:
            return None

        provider_info.api_key = api_key

        self.printer.print_panel(
            self.printer.get_message_from_key("model_provider_selected"),
            text_options={"justify": "left"},
            panel_options={
                "title": self.printer.get_message_from_key("model_provider_success_title"),
                "border_style": "green"
            }
        )

        return provider_info
autocoder/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.1.255"
1
+ __version__ = "0.1.257"