auto_coder-0.1.254-py3-none-any.whl → auto_coder-0.1.256-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of auto-coder might be problematic.

@@ -125,13 +125,17 @@ class ActionTSProject(BaseAction):
             query=args.query, source_content=content
         )
         elapsed_time = time.time() - start_time
-        speed = generate_result.metadata.get('generated_tokens_count', 0) / elapsed_time if elapsed_time > 0 else 0
-        model_names = ",".join(get_llm_names(self.llm))
+        speed = generate_result.metadata.get('generated_tokens_count', 0) / elapsed_time if elapsed_time > 0 else 0
+        input_tokens_cost = generate_result.metadata.get('input_tokens_cost', 0)
+        generated_tokens_cost = generate_result.metadata.get('generated_tokens_cost', 0)
+        model_names = ",".join(get_llm_names(generate.llms))
         self.printer.print_in_terminal(
             "code_generation_complete",
             duration=elapsed_time,
             input_tokens=generate_result.metadata.get('input_tokens_count', 0),
             output_tokens=generate_result.metadata.get('generated_tokens_count', 0),
+            input_cost=input_tokens_cost,
+            output_cost=generated_tokens_cost,
             speed=round(speed, 2),
             model_names=model_names
         )
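
All five action classes in this diff gain the same reporting pattern: the generation metadata now carries precomputed cost fields alongside the token counts, and the completion banner surfaces them. A minimal sketch of the derivation, assuming the metadata keys shown above are populated upstream (report_generation_stats is a hypothetical helper, not part of the package):

    import time

    def report_generation_stats(metadata: dict, start_time: float) -> dict:
        # Hypothetical helper mirroring the print_in_terminal arguments above.
        elapsed = time.time() - start_time
        output_tokens = metadata.get('generated_tokens_count', 0)
        return {
            "duration": elapsed,
            "input_tokens": metadata.get('input_tokens_count', 0),
            "output_tokens": output_tokens,
            # Cost fields are assumed to be precomputed by the generator.
            "input_cost": metadata.get('input_tokens_cost', 0),
            "output_cost": metadata.get('generated_tokens_cost', 0),
            # Guard against division by zero for instantaneous completions.
            "speed": round(output_tokens / elapsed, 2) if elapsed > 0 else 0,
        }

In most of these hunks the model names now come from generate.llms rather than self.llm, which pairs with the list handling added to get_llm_names in autocoder/utils/llms.py later in this diff.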
@@ -221,12 +225,16 @@ class ActionPyScriptProject(BaseAction):
 
         elapsed_time = time.time() - start_time
         speed = generate_result.metadata.get('generated_tokens_count', 0) / elapsed_time if elapsed_time > 0 else 0
-        model_names = ",".join(get_llm_names(self.llm))
+        model_names = ",".join(get_llm_names(generate.llms))
+        input_tokens_cost = generate_result.metadata.get('input_tokens_cost', 0)
+        generated_tokens_cost = generate_result.metadata.get('generated_tokens_cost', 0)
         self.printer.print_in_terminal(
             "code_generation_complete",
             duration=elapsed_time,
             input_tokens=generate_result.metadata.get('input_tokens_count', 0),
             output_tokens=generate_result.metadata.get('generated_tokens_count', 0),
+            input_cost=input_tokens_cost,
+            output_cost=generated_tokens_cost,
             speed=round(speed, 2),
             model_names=model_names
         )
@@ -264,13 +272,7 @@ class ActionPyScriptProject(BaseAction):
             model=self.llm.default_model_name,
         )
 
-        end_time = time.time()
-        self.printer.print_in_terminal(
-            "code_generation_complete",
-            duration=end_time - start_time,
-            input_tokens=generate_result.metadata.get('input_tokens_count', 0),
-            output_tokens=generate_result.metadata.get('generated_tokens_count', 0)
-        )
+        end_time = time.time()
 
         with open(self.args.target_file, "w") as file:
             file.write(content)
 
@@ -348,12 +350,16 @@ class ActionPyProject(BaseAction):
         )
         elapsed_time = time.time() - start_time
         speed = generate_result.metadata.get('generated_tokens_count', 0) / elapsed_time if elapsed_time > 0 else 0
-        model_names = ",".join(get_llm_names(self.llm))
+        model_names = ",".join(get_llm_names(generate.llms))
+        input_tokens_cost = generate_result.metadata.get('input_tokens_cost', 0)
+        generated_tokens_cost = generate_result.metadata.get('generated_tokens_cost', 0)
         self.printer.print_in_terminal(
             "code_generation_complete",
             duration=elapsed_time,
             input_tokens=generate_result.metadata.get('input_tokens_count', 0),
             output_tokens=generate_result.metadata.get('generated_tokens_count', 0),
+            input_cost=input_tokens_cost,
+            output_cost=generated_tokens_cost,
             speed=round(speed, 2),
             model_names=model_names
         )
@@ -458,12 +464,16 @@ class ActionSuffixProject(BaseAction):
 
         elapsed_time = time.time() - start_time
         speed = generate_result.metadata.get('generated_tokens_count', 0) / elapsed_time if elapsed_time > 0 else 0
-        model_names = ",".join(get_llm_names(self.llm))
+        model_names = ",".join(get_llm_names(generate.llms))
+        input_tokens_cost = generate_result.metadata.get('input_tokens_cost', 0)
+        generated_tokens_cost = generate_result.metadata.get('generated_tokens_cost', 0)
         self.printer.print_in_terminal(
             "code_generation_complete",
             duration=elapsed_time,
             input_tokens=generate_result.metadata.get('input_tokens_count', 0),
             output_tokens=generate_result.metadata.get('generated_tokens_count', 0),
+            input_cost=input_tokens_cost,
+            output_cost=generated_tokens_cost,
             speed=round(speed, 2),
             model_names=model_names
         )
@@ -88,11 +88,15 @@ class ActionRegexProject:
         elapsed_time = time.time() - start_time
         speed = generate_result.metadata.get('generated_tokens_count', 0) / elapsed_time if elapsed_time > 0 else 0
         model_names = ",".join(get_llm_names(self.llm))
+        input_tokens_cost = generate_result.metadata.get('input_tokens_cost', 0)
+        generated_tokens_cost = generate_result.metadata.get('generated_tokens_cost', 0)
         self.printer.print_in_terminal(
             "code_generation_complete",
             duration=elapsed_time,
             input_tokens=generate_result.metadata.get('input_tokens_count', 0),
             output_tokens=generate_result.metadata.get('generated_tokens_count', 0),
+            input_cost=input_tokens_cost,
+            output_cost=generated_tokens_cost,
             speed=round(speed, 2),
             model_names=model_names
         )
@@ -4,21 +4,21 @@ from autocoder.utils.auto_coder_utils.chat_stream_out import stream_out
 from autocoder.common.utils_code_auto_generate import stream_chat_with_continue
 from byzerllm.utils.str2model import to_model
 from autocoder.index.types import IndexItem
-from autocoder.common import AutoCoderArgs,SourceCode
+from autocoder.common import AutoCoderArgs, SourceCode
 import byzerllm
 import time
 from autocoder.index.index import IndexManager
 from autocoder.index.types import (
     IndexItem,
-    TargetFile,
+    TargetFile,
     FileNumberList
 )
 from autocoder.rag.token_counter import count_tokens
 from autocoder.common.printer import Printer
 from concurrent.futures import ThreadPoolExecutor
-import threading
+from byzerllm import MetaHolder
 
-from autocoder.utils.llms import get_llm_names
+from autocoder.utils.llms import get_llm_names, get_model_info
 
 
 def get_file_path(file_path):
@@ -32,8 +32,9 @@ class QuickFilterResult(BaseModel):
     has_error: bool
     error_message: Optional[str] = None
 
+
 class QuickFilter():
-    def __init__(self, index_manager: IndexManager,stats:Dict[str,Any],sources:List[SourceCode]):
+    def __init__(self, index_manager: IndexManager, stats: Dict[str, Any], sources: List[SourceCode]):
         self.index_manager = index_manager
         self.args = index_manager.args
         self.stats = stats
@@ -41,72 +42,142 @@ class QuickFilter():
         self.printer = Printer()
         self.max_tokens = self.args.index_filter_model_max_input_length
 
-
     def big_filter(self, index_items: List[IndexItem],) -> QuickFilterResult:
         chunks = []
         current_chunk = []
-
+
         # Split index_items into multiple chunks; the first chunk should get as close to max_tokens as possible
        for item in index_items:
             # Render the text with quick_filter_files.prompt before counting tokens
             temp_chunk = current_chunk + [item]
-            prompt_text = self.quick_filter_files.prompt(temp_chunk, self.args.query)
-            temp_size = count_tokens(prompt_text)
+            prompt_text = self.quick_filter_files.prompt(
+                temp_chunk, self.args.query)
+            temp_size = count_tokens(prompt_text)
             # If the current chunk is empty, or the item still fits under max_tokens, append it to the current chunk
             if not current_chunk or temp_size <= self.max_tokens:
-                current_chunk.append(item)
+                current_chunk.append(item)
             else:
                 # The current chunk is full; start a new one
                 chunks.append(current_chunk)
-                current_chunk = [item]
-
+                current_chunk = [item]
+
         if current_chunk:
             chunks.append(current_chunk)
-
-        tokens_len = count_tokens(self.quick_filter_files.prompt(index_items, self.args.query))
+
+        tokens_len = count_tokens(
+            self.quick_filter_files.prompt(index_items, self.args.query))
         self.printer.print_in_terminal(
-            "quick_filter_too_long",
-            style="yellow",
-            tokens_len=tokens_len,
-            max_tokens=self.max_tokens,
-            split_size=len(chunks)
-        )
+            "quick_filter_too_long",
+            style="yellow",
+            tokens_len=tokens_len,
+            max_tokens=self.max_tokens,
+            split_size=len(chunks)
+        )
 
         def process_chunk(chunk_index: int, chunk: List[IndexItem]) -> QuickFilterResult:
             try:
-                model_name = ",".join(get_llm_names(self.index_manager.index_filter_llm))
+                # Get the list of model names
+                model_names = get_llm_names(
+                    self.index_manager.index_filter_llm)
+                model_name = ",".join(model_names)
                 files: Dict[str, TargetFile] = {}
-
+
+                # Get model pricing information
+                model_info_map = {}
+                for name in model_names:
+                    # The second argument is the product mode, taken from args
+                    info = get_model_info(name, self.args.product_mode)
+                    if info:
+                        model_info_map[name] = {
+                            # Cost per million tokens
+                            "input_price": info.get("input_price", 0.0),
+                            # Cost per million tokens
+                            "output_price": info.get("output_price", 0.0)
+                        }
+
                 if chunk_index == 0:
                     # The first chunk uses streaming output
                     stream_generator = stream_chat_with_continue(
                         self.index_manager.index_filter_llm,
-                        [{"role": "user", "content": self.quick_filter_files.prompt(chunk, self.args.query)}],
+                        [{"role": "user", "content": self.quick_filter_files.prompt(
+                            chunk, self.args.query)}],
                         {}
                     )
-                    full_response, _ = stream_out(
+                    full_response, last_meta = stream_out(
                         stream_generator,
                         model_name=model_name,
-                        title=self.printer.get_message_from_key_with_format("quick_filter_title", model_name=model_name),
+                        title=self.printer.get_message_from_key_with_format(
+                            "quick_filter_title", model_name=model_name),
                         args=self.args
                     )
                     file_number_list = to_model(full_response, FileNumberList)
+
+                    # Compute the total cost
+                    total_input_cost = 0.0
+                    total_output_cost = 0.0
+
+                    for name in model_names:
+                        info = model_info_map.get(name, {})
+                        # Formula: token count * unit price / 1000000
+                        total_input_cost += (last_meta.input_tokens_count *
+                                             info.get("input_price", 0.0)) / 1000000
+                        total_output_cost += (last_meta.generated_tokens_count *
+                                              info.get("output_price", 0.0)) / 1000000
+
+                    # Round to 4 decimal places
+                    total_input_cost = round(total_input_cost, 4)
+                    total_output_cost = round(total_output_cost, 4)
+
+                    # Print token statistics and cost
+                    self.printer.print_in_terminal(
+                        "quick_filter_stats",
+                        style="blue",
+                        input_tokens=last_meta.input_tokens_count,
+                        output_tokens=last_meta.generated_tokens_count,
+                        input_cost=total_input_cost,
+                        output_cost=total_output_cost,
+                        model_names=model_name
+                    )
                 else:
                     # Remaining chunks use with_llm directly
-                    file_number_list = self.quick_filter_files.with_llm(self.index_manager.index_filter_llm).with_return_type(FileNumberList).run(chunk, self.args.query)
-
+                    meta_holder = MetaHolder()
+                    start_time = time.monotonic()
+                    file_number_list = self.quick_filter_files.with_llm(self.index_manager.index_filter_llm).with_meta(
+                        meta_holder).with_return_type(FileNumberList).run(chunk, self.args.query)
+                    end_time = time.monotonic()
+
+                    total_input_cost = 0.0
+                    total_output_cost = 0.0
+                    if meta_holder.get_meta():
+                        meta_dict = meta_holder.get_meta()
+                        total_input_cost = meta_dict.get("input_tokens_count", 0) * model_info_map.get(model_name, {}).get("input_price", 0.0) / 1000000
+                        total_output_cost = meta_dict.get("generated_tokens_count", 0) * model_info_map.get(model_name, {}).get("output_price", 0.0) / 1000000
+
+                        self.printer.print_in_terminal(
+                            "quick_filter_stats",
+                            style="blue",
+                            input_tokens=meta_dict.get("input_tokens_count", 0),
+                            output_tokens=meta_dict.get("generated_tokens_count", 0),
+                            input_cost=total_input_cost,
+                            output_cost=total_output_cost,
+                            model_names=model_name,
+                            elapsed_time=f"{end_time - start_time:.2f}"
+                        )
+
                 if file_number_list:
                     for file_number in file_number_list.file_list:
-                        file_path = get_file_path(chunk[file_number].module_name)
+                        file_path = get_file_path(
+                            chunk[file_number].module_name)
                         files[file_path] = TargetFile(
                             file_path=chunk[file_number].module_name,
-                            reason=self.printer.get_message_from_key("quick_filter_reason")
+                            reason=self.printer.get_message_from_key(
+                                "quick_filter_reason")
                         )
                 return QuickFilterResult(
                     files=files,
                     has_error=False
                 )
-
+
             except Exception as e:
                 self.printer.print_in_terminal(
                     "quick_filter_failed",
@@ -123,25 +194,25 @@ class QuickFilter():
         if chunks:
             with ThreadPoolExecutor() as executor:
                 # Submit all chunks to the thread pool and collect the results
-                futures = [executor.submit(process_chunk, i, chunk)
-                           for i, chunk in enumerate(chunks)]
-
+                futures = [executor.submit(process_chunk, i, chunk)
+                           for i, chunk in enumerate(chunks)]
+
                 # Wait for all tasks to finish and gather their results
                 for future in futures:
                     results.append(future.result())
-
+
         # Merge all results
         final_files: Dict[str, TargetFile] = {}
         has_error = False
         error_messages: List[str] = []
-
+
         for result in results:
             if result.has_error:
                 has_error = True
                 if result.error_message:
                     error_messages.append(result.error_message)
             final_files.update(result.files)
-
+
         return QuickFilterResult(
             files=final_files,
             has_error=has_error,
@@ -149,7 +220,7 @@ class QuickFilter():
         )
 
     @byzerllm.prompt()
-    def quick_filter_files(self,file_meta_list:List[IndexItem],query:str) -> str:
+    def quick_filter_files(self, file_meta_list: List[IndexItem], query: str) -> str:
         '''
         When a user raises a requirement, we need to find the relevant files, read them, and modify some of them.
         Now, given the index file below:
@@ -160,7 +231,7 @@ class QuickFilter():
 
         The index file contains file numbers (the part enclosed in ##[]), file paths, file symbol information, and so on.
         Below is the user's query:
-
+
         <query>
         {{ query }}
         </query>
@@ -182,63 +253,101 @@ class QuickFilter():
         2. If the query is a piece of conversation history, file paths mentioned in that conversation must be returned.
         3. The JSON data must not contain comments
         '''
-        file_meta_str = "\n".join([f"##[{index}]{item.module_name}\n{item.symbols}" for index,item in enumerate(file_meta_list)])
+        file_meta_str = "\n".join(
+            [f"##[{index}]{item.module_name}\n{item.symbols}" for index, item in enumerate(file_meta_list)])
         context = {
             "content": file_meta_str,
             "query": query
         }
-        return context
+        return context
 
     def filter(self, index_items: List[IndexItem], query: str) -> QuickFilterResult:
         final_files: Dict[str, TargetFile] = {}
-        start_time = time.monotonic()
+        start_time = time.monotonic()
+
+        prompt_str = self.quick_filter_files.prompt(index_items, query)
+
+        tokens_len = count_tokens(prompt_str)
 
-        prompt_str = self.quick_filter_files.prompt(index_items,query)
-
-        tokens_len = count_tokens(prompt_str)
-
         # Print current index size
         self.printer.print_in_terminal(
             "quick_filter_tokens_len",
             style="blue",
             tokens_len=tokens_len
         )
-
-        if tokens_len > self.max_tokens:
+
+        if tokens_len > self.max_tokens:
             return self.big_filter(index_items)
-
+
         try:
-            model_name = ",".join(get_llm_names(self.index_manager.index_filter_llm))
-
+            # Get the model names
+            model_names = get_llm_names(self.index_manager.index_filter_llm)
+            model_name = ",".join(model_names)
+
+            # Get model pricing information
+            model_info_map = {}
+            for name in model_names:
+                # The second argument is the product mode, taken from args
+                info = get_model_info(name, self.args.product_mode)
+                if info:
+                    model_info_map[name] = {
+                        # Cost per million tokens
+                        "input_price": info.get("input_price", 0.0),
+                        # Cost per million tokens
+                        "output_price": info.get("output_price", 0.0)
+                    }
+
             # Render the prompt template
-            query = self.quick_filter_files.prompt(index_items, self.args.query)
-
+            query = self.quick_filter_files.prompt(
+                index_items, self.args.query)
+
             # Process with streaming output
             stream_generator = stream_chat_with_continue(
                 self.index_manager.index_filter_llm,
                 [{"role": "user", "content": query}],
                 {}
             )
-
+
             # Get the full response
             full_response, last_meta = stream_out(
                 stream_generator,
                 model_name=model_name,
-                title=self.printer.get_message_from_key_with_format("quick_filter_title", model_name=model_name),
+                title=self.printer.get_message_from_key_with_format(
+                    "quick_filter_title", model_name=model_name),
                 args=self.args
-            )
+            )
             # Parse the result
             file_number_list = to_model(full_response, FileNumberList)
-            end_time = time.monotonic()
-            # Print token statistics
+            end_time = time.monotonic()
+
+            # Compute the total cost
+            total_input_cost = 0.0
+            total_output_cost = 0.0
+
+            for name in model_names:
+                info = model_info_map.get(name, {})
+                # Formula: token count * unit price / 1000000
+                total_input_cost += (last_meta.input_tokens_count *
+                                     info.get("input_price", 0.0)) / 1000000
+                total_output_cost += (last_meta.generated_tokens_count *
+                                      info.get("output_price", 0.0)) / 1000000
+
+            # Round to 4 decimal places
+            total_input_cost = round(total_input_cost, 4)
+            total_output_cost = round(total_output_cost, 4)
+
+            # Print token statistics and cost
             self.printer.print_in_terminal(
-                "quick_filter_stats",
+                "quick_filter_stats",
                 style="blue",
                 elapsed_time=f"{end_time - start_time:.2f}",
                 input_tokens=last_meta.input_tokens_count,
-                output_tokens=last_meta.generated_tokens_count
+                output_tokens=last_meta.generated_tokens_count,
+                input_cost=total_input_cost,
+                output_cost=total_output_cost,
+                model_names=model_name
             )
-
+
         except Exception as e:
             self.printer.print_in_terminal(
                 "quick_filter_failed",
@@ -250,16 +359,17 @@ class QuickFilter():
                 has_error=True,
                 error_message=str(e)
             )
-
+
         if file_number_list:
             for file_number in file_number_list.file_list:
                 final_files[get_file_path(index_items[file_number].module_name)] = TargetFile(
                     file_path=index_items[file_number].module_name,
-                    reason=self.printer.get_message_from_key("quick_filter_reason")
+                    reason=self.printer.get_message_from_key(
+                        "quick_filter_reason")
                 )
-        end_time = time.monotonic()
-        self.stats["timings"]["quick_filter"] = end_time - start_time
+        end_time = time.monotonic()
+        self.stats["timings"]["quick_filter"] = end_time - start_time
         return QuickFilterResult(
             files=final_files,
             has_error=False
-        )
+        )
autocoder/models.py CHANGED
@@ -127,11 +127,23 @@ def update_model_input_price(name: str, price: float) -> bool:
     """Update a model's input price
 
     Args:
-        name: model name
-        price: input price (per million input tokens)
+        name (str): name of the model to update; must match a record in models.json
+        price (float): new input price, in USD per million tokens; must be >= 0
 
     Returns:
-        bool: whether the update succeeded
+        bool: whether the model was found and its price updated
+
+    Raises:
+        ValueError: raised when price is negative
+
+    Example:
+        >>> update_model_input_price("gpt-4", 3.0)
+        True
+
+    Notes:
+        1. The price takes effect immediately and is persisted to models.json
+        2. Actual cost is computed from real usage, precise to 6 decimal places
+        3. Setting the price to 0 marks the model as currently unavailable
     """
     if price < 0:
         raise ValueError("Price cannot be negative")
@@ -151,11 +163,23 @@ def update_model_output_price(name: str, price: float) -> bool:
     """Update a model's output price
 
     Args:
-        name: model name
-        price: output price (per million output tokens)
+        name (str): name of the model to update; must match a record in models.json
+        price (float): new output price, in USD per million tokens; must be >= 0
 
     Returns:
-        bool: whether the update succeeded
+        bool: whether the model was found and its price updated
+
+    Raises:
+        ValueError: raised when price is negative
+
+    Example:
+        >>> update_model_output_price("gpt-4", 6.0)
+        True
+
+    Notes:
+        1. Output prices are typically 30%-50% higher than input prices
+        2. For token-billed APIs, the actual charge is (input_tokens * input_price + output_tokens * output_price)
+        3. Price changes affect every feature that depends on model billing (e.g. cost estimation and usage monitoring)
     """
     if price < 0:
         raise ValueError("Price cannot be negative")
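
The expanded docstrings pin down the call contract. A brief usage sketch (the model name is illustrative; both functions raise ValueError on negative prices and return False when the name has no record in models.json):

    from autocoder.models import update_model_input_price, update_model_output_price

    # Prices are in USD per million tokens.
    update_model_input_price("gpt-4", 3.0)    # True if "gpt-4" exists in models.json
    update_model_output_price("gpt-4", 6.0)   # True if "gpt-4" exists in models.json

    try:
        update_model_input_price("gpt-4", -1.0)
    except ValueError:
        pass  # negative prices are rejected before any lookup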
@@ -180,10 +204,7 @@ def update_model_speed(name: str, speed: float) -> bool:
 
     Returns:
         bool: whether the update succeeded
-    """
-    if speed <= 0:
-        raise ValueError("Speed must be positive")
-
+    """
     models = load_models()
     updated = False
     for model in models:
@@ -116,6 +116,7 @@ class PyProject:
         "actions",
         ".vscode",
         ".idea",
+        "venv",
     ]
 
     @byzerllm.prompt()
@@ -56,6 +56,7 @@ class SuffixProject:
         ".vscode",
         "actions",
         ".idea",
+        "venv",
     ]
 
     @byzerllm.prompt()
@@ -48,6 +48,7 @@ class TSProject:
         "actions",
         ".vscode",
         ".idea",
+        "venv",
     ]
 
     @byzerllm.prompt()
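
All three project scanners add venv to their directory exclusion lists, so virtual environments are no longer swept into the index. A minimal sketch of how such a list typically prunes a directory walk (the walk logic is illustrative, not the package's actual traversal code):

    import os

    exclude_dirs = ["actions", ".vscode", ".idea", "venv"]  # per the lists above

    def iter_source_files(root: str):
        for dirpath, dirnames, filenames in os.walk(root):
            # Prune in place so os.walk never descends into excluded directories.
            dirnames[:] = [d for d in dirnames if d not in exclude_dirs]
            for name in filenames:
                yield os.path.join(dirpath, name)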
autocoder/utils/llms.py CHANGED
@@ -3,9 +3,15 @@ from typing import Union,Optional
 
 def get_llm_names(llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM,str],target_model_type:Optional[str]=None):
     if target_model_type is None:
+        if isinstance(llm,list):
+            return [_llm.default_model_name for _llm in llm]
         return [llm.default_model_name for llm in [llm] if llm.default_model_name]
+
     llms = llm.get_sub_client(target_model_type)
+
     if llms is None:
+        if isinstance(llm,list):
+            return [_llm.default_model_name for _llm in llm]
         return [llm.default_model_name for llm in [llm] if llm.default_model_name]
     elif isinstance(llms, list):
         return [llm.default_model_name for llm in llms if llm.default_model_name]
@@ -14,6 +20,27 @@ def get_llm_names(llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM,str],tar
     else:
         return [llm.default_model_name for llm in [llms] if llm.default_model_name]
 
+def get_model_info(model_names: str, product_mode: str):
+    from autocoder import models as models_module
+    def get_model_by_name(model_name: str):
+        try:
+            return models_module.get_model_by_name(model_name)
+        except Exception as e:
+            return None
+
+    if product_mode == "pro":
+        return None
+
+    if product_mode == "lite":
+        if "," in model_names:
+            # Multiple code models specified
+            model_names = model_names.split(",")
+            for _, model_name in enumerate(model_names):
+                return get_model_by_name(model_name)
+        else:
+            # Single code model
+            return get_model_by_name(model_names)
+
 def get_single_llm(model_names: str, product_mode: str):
     from autocoder import models as models_module
     if product_mode == "pro":
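
Note the control flow in the new get_model_info: in lite mode with a comma-separated list, the for loop returns on its first iteration, so only the first model's record is ever fetched, and pro mode always returns None. A sketch of the observable behavior (model names are illustrative):

    from autocoder.utils.llms import get_model_info

    get_model_info("deepseek-chat", "lite")        # record for deepseek-chat, or None
    get_model_info("deepseek-chat,gpt-4", "lite")  # record for deepseek-chat only
    get_model_info("anything", "pro")              # always None in pro mode

QuickFilter sidesteps the first-model quirk by calling get_model_info once per name rather than passing the joined string.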
autocoder/version.py CHANGED
@@ -1 +1 @@
-__version__ = "0.1.254"
+__version__ = "0.1.256"