auto-coder 0.1.202__py3-none-any.whl → 0.1.204__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of auto-coder might be problematic.

@@ -64,7 +64,8 @@ MESSAGES = {
         "mode_desc": "Switch input mode",
         "lib_desc": "Manage libraries",
         "exit_desc": "Exit the program",
-        "design_desc": "Generate SVG image based on the provided description",
+        "design_desc": "Generate SVG image based on the provided description",
+        "commit_desc": "Auto generate yaml file and commit changes based on user's manual changes",
     },
     "zh": {
         "initializing": "🚀 正在初始化系统...",
@@ -130,16 +131,19 @@ MESSAGES = {
         "lib_desc": "管理库",
         "exit_desc": "退出程序",
         "design_desc": "根据需求设计SVG图片",
-
+        "commit_desc": "根据用户人工修改的代码自动生成yaml文件并提交更改",
+
     }
 }
 
+
 def get_system_language():
     try:
         return locale.getdefaultlocale()[0][:2]
     except:
         return 'en'
 
+
 def get_message(key):
     lang = get_system_language()
-    return MESSAGES.get(lang, MESSAGES['en']).get(key, MESSAGES['en'][key])
+    return MESSAGES.get(lang, MESSAGES['en']).get(key, MESSAGES['en'][key])
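The `get_message` helper shown in this hunk falls back to English twice: an unrecognized system language falls back to the `en` table, and a key missing from the selected table falls back to the `en` value. A minimal, self-contained sketch of the same pattern; the two-entry `MESSAGES` table here is an illustrative subset, not the package's full table:

```python
import locale

# Illustrative subset of the package's MESSAGES table (the real table has many more keys).
MESSAGES = {
    "en": {"exit_desc": "Exit the program"},
    "zh": {"exit_desc": "退出程序"},
}

def get_system_language():
    try:
        # e.g. "zh_CN" -> "zh"; raises when the locale is unset
        return locale.getdefaultlocale()[0][:2]
    except Exception:
        return "en"

def get_message(key):
    lang = get_system_language()
    # Unknown language -> English table; unknown key -> English value
    return MESSAGES.get(lang, MESSAGES["en"]).get(key, MESSAGES["en"][key])

print(get_message("exit_desc"))
```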
autocoder/command_args.py CHANGED
@@ -196,6 +196,7 @@ def parse_args(input_args: Optional[List[str]] = None) -> AutoCoderArgs:
         help="是否静默执行,不打印任何信息。默认为False",
     )
 
+
     revert_parser = subparsers.add_parser("revert", help=desc["revert_desc"])
     revert_parser.add_argument("--file", help=desc["revert_desc"])
     revert_parser.add_argument(
@@ -172,7 +172,7 @@ class CodeAutoGenerate:
         conversations = []
 
         if self.args.system_prompt and self.args.system_prompt.strip() == "claude":
-            conversations.append({"role": "system", "content": sys_prompt.prompt()})
+            conversations.append({"role": "system", "content": sys_prompt.claude_sys_prompt.prompt()})
         elif self.args.system_prompt:
             conversations.append({"role": "system", "content": self.args.system_prompt})
 
@@ -308,7 +308,7 @@ class CodeAutoGenerateDiff:
 
         conversations = []
         if self.args.system_prompt and self.args.system_prompt.strip() == "claude":
-            conversations.append({"role": "system", "content": sys_prompt.prompt()})
+            conversations.append({"role": "system", "content": sys_prompt.claude_sys_prompt.prompt()})
         else:
             conversations.append({"role": "user", "content": init_prompt})
 
@@ -390,7 +390,7 @@ class CodeAutoGenerateEditBlock:
         conversations = []
 
         if self.args.system_prompt and self.args.system_prompt.strip() == "claude":
-            conversations.append({"role": "system", "content": sys_prompt.prompt()})
+            conversations.append({"role": "system", "content": sys_prompt.claude_sys_prompt.prompt()})
         elif self.args.system_prompt:
             conversations.append({"role": "system", "content": self.args.system_prompt})
 
@@ -279,7 +279,7 @@ class CodeAutoGenerateStrictDiff:
 
         conversations = []
         if self.args.system_prompt and self.args.system_prompt.strip() == "claude":
-            conversations.append({"role": "system", "content": sys_prompt.prompt()})
+            conversations.append({"role": "system", "content": sys_prompt.claude_sys_prompt.prompt()})
         elif self.args.system_prompt:
             conversations.append({"role": "system", "content": self.args.system_prompt})
 
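The four hunks above apply the same one-line fix to CodeAutoGenerate, CodeAutoGenerateDiff, CodeAutoGenerateEditBlock, and CodeAutoGenerateStrictDiff: `sys_prompt` is a module, and the rendered prompt lives in its `claude_sys_prompt` function (visible in the `from datetime import datetime` hunk further down), so calling `sys_prompt.prompt()` on the module itself could not work. A sketch of the shared branch as it appears in CodeAutoGenerate (CodeAutoGenerateDiff's else-branch appends a user message instead); `args` and `sys_prompt` stand in for the real objects:

```python
def build_initial_conversations(args, sys_prompt):
    """Sketch of the system-prompt selection shared by the generator classes."""
    conversations = []
    if args.system_prompt and args.system_prompt.strip() == "claude":
        # Render the packaged Claude-style system prompt via the prompt function,
        # not via the module that contains it (the bug this diff fixes).
        conversations.append(
            {"role": "system", "content": sys_prompt.claude_sys_prompt.prompt()}
        )
    elif args.system_prompt:
        # Any other non-empty value is used verbatim as the system prompt.
        conversations.append({"role": "system", "content": args.system_prompt})
    return conversations
```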
@@ -3,6 +3,7 @@ from git import Repo, GitCommandError
 from loguru import logger
 from typing import List, Optional
 from pydantic import BaseModel
+import byzerllm
 from rich.console import Console
 from rich.panel import Panel
 from rich.syntax import Syntax
@@ -169,6 +170,439 @@ def revert_change(repo_path: str, message: str) -> bool:
     return False
 
 
+def get_uncommitted_changes(repo_path: str) -> str:
+    """
+    获取当前仓库未提交的所有变更,并以markdown格式返回详细报告
+
+    Args:
+        repo_path: Git仓库路径
+
+    Returns:
+        str: markdown格式的变更报告,包含新增/修改/删除的文件列表及其差异
+    """
+    repo = get_repo(repo_path)
+    if repo is None:
+        return "Error: Repository is not initialized."
+
+    try:
+        # 获取所有变更
+        changes = {
+            'new': [],       # 新增的文件
+            'modified': [],  # 修改的文件
+            'deleted': []    # 删除的文件
+        }
+
+        # 获取未暂存的变更
+        diff_index = repo.index.diff(None)
+
+        # 获取未追踪的文件
+        untracked = repo.untracked_files
+
+        # 处理未暂存的变更
+        for diff_item in diff_index:
+            file_path = diff_item.a_path
+            diff_content = repo.git.diff(None, file_path)
+            if diff_item.new_file:
+                changes['new'].append((file_path, diff_content))
+            elif diff_item.deleted_file:
+                changes['deleted'].append((file_path, diff_content))
+            else:
+                changes['modified'].append((file_path, diff_content))
+
+        # 处理未追踪的文件
+        for file_path in untracked:
+            try:
+                with open(os.path.join(repo_path, file_path), 'r') as f:
+                    content = f.read()
+                changes['new'].append((file_path, f'+++ {file_path}\n{content}'))
+            except Exception as e:
+                logger.error(f"Error reading file {file_path}: {e}")
+
+        # 生成markdown报告
+        report = ["# Git Changes Report\n"]
+
+        # 新增文件
+        if changes['new']:
+            report.append("\n## New Files")
+            for file_path, diff in changes['new']:
+                report.append(f"\n### {file_path}")
+                report.append("```diff")
+                report.append(diff)
+                report.append("```")
+
+        # 修改的文件
+        if changes['modified']:
+            report.append("\n## Modified Files")
+            for file_path, diff in changes['modified']:
+                report.append(f"\n### {file_path}")
+                report.append("```diff")
+                report.append(diff)
+                report.append("```")
+
+        # 删除的文件
+        if changes['deleted']:
+            report.append("\n## Deleted Files")
+            for file_path, diff in changes['deleted']:
+                report.append(f"\n### {file_path}")
+                report.append("```diff")
+                report.append(diff)
+                report.append("```")
+
+        # 如果没有任何变更
+        if not any(changes.values()):
+            return "No uncommitted changes found."
+
+        return "\n".join(report)
+
+    except GitCommandError as e:
+        logger.error(f"Error getting uncommitted changes: {e}")
+        return f"Error: {str(e)}"
+
+@byzerllm.prompt()
+def generate_commit_message(changes_report: str) -> str:
+    '''
+    我是一个Git提交信息生成助手。我们的目标是通过一些变更报告,倒推用户的需求,将需求作为commit message。
+    commit message 需要简洁,不要超过100个字符。
+
+    下面是一些示例:
+    <examples>
+    <example>
+    ## New Files
+    ### notebooks/tests/test_long_context_rag_answer_question.ipynb
+    ```diff
+    diff --git a/notebooks/tests/test_long_context_rag_answer_question.ipynb b/notebooks/tests/test_long_context_rag_answer_question.ipynb
+    new file mode 100644
+    index 00000000..c676b557
+    --- /dev/null
+    +++ b/notebooks/tests/test_long_context_rag_answer_question.ipynb
+    @@ -0,0 +1,122 @@
+    +{
+    + "cells": [
+    +  {
+    +   "cell_type": "markdown",
+    +   "metadata": {},
+    +   "source": [
+    +    "# Test Long Context RAG Answer Question\n",
+    +    "\n",
+    +    "This notebook tests the `_answer_question` functionality in the `LongContextRAG` class."
+    +   ]
+    +  },
+    +  {
+    +   "cell_type": "code",
+    +   "execution_count": null,
+    +   "metadata": {},
+    +   "outputs": [],
+    +   "source": [
+    +    "import os\n",
+    +    "import sys\n",
+    +    "from pathlib import Path\n",
+    +    "import tempfile\n",
+    +    "from loguru import logger\n",
+    +    "from autocoder.rag.long_context_rag import LongContextRAG\n",
+    +    "from autocoder.rag.rag_config import RagConfig\n",
+    +    "from autocoder.rag.cache.simple_cache import AutoCoderRAGAsyncUpdateQueue\n",
+    +    "from autocoder.rag.variable_holder import VariableHolder\n",
+    +    "from tokenizers import Tokenizer\n",
+    +    "\n",
+    +    "# Setup tokenizer\n",
+    +    "VariableHolder.TOKENIZER_PATH = \"/Users/allwefantasy/Downloads/tokenizer.json\"\n",
+    +    "VariableHolder.TOKENIZER_MODEL = Tokenizer.from_file(VariableHolder.TOKENIZER_PATH)"
+    +   ]
+    +  },
+    +  {
+    +   "cell_type": "code",
+    +   "execution_count": null,
+    +   "metadata": {},
+    +   "outputs": [],
+    +   "source": [
+    +    "# Create test files and directory\n",
+    +    "test_dir = tempfile.mkdtemp()\n",
+    +    "print(f\"Created test directory: {test_dir}\")\n",
+    +    "\n",
+    +    "# Create a test Python file\n",
+    +    "test_file = os.path.join(test_dir, \"test_code.py\")\n",
+    +    "with open(test_file, \"w\") as f:\n",
+    +    " f.write(\"\"\"\n",
+    +    "def calculate_sum(a: int, b: int) -> int:\n",
+    +    " \"\"\"Calculate the sum of two integers.\"\"\"\n",
+    +    " return a + b\n",
+    +    "\n",
+    +    "def calculate_product(a: int, b: int) -> int:\n",
+    +    " \"\"\"Calculate the product of two integers.\"\"\"\n",
+    +    " return a * b\n",
+    +    " \"\"\")"
+    +   ]
+    +  },
+    +  {
+    +   "cell_type": "code",
+    +   "execution_count": null,
+    +   "metadata": {},
+    +   "outputs": [],
+    +   "source": [
+    +    "# Initialize RAG components\n",
+    +    "config = RagConfig(\n",
+    +    " model=\"gpt-4-1106-preview\",\n",
+    +    " path=test_dir,\n",
+    +    " required_exts=[\".py\"],\n",
+    +    " cache_type=\"simple\"\n",
+    +    ")\n",
+    +    "\n",
+    +    "rag = LongContextRAG(config)\n",
+    +    "\n",
+    +    "# Test questions\n",
+    +    "test_questions = [\n",
+    +    " \"What does the calculate_sum function do?\",\n",
+    +    " \"Show me all the functions that work with integers\",\n",
+    +    " \"What's the return type of calculate_product?\"\n",
+    +    "]\n",
+    +    "\n",
+    +    "# Test answers\n",
+    +    "for question in test_questions:\n",
+    +    " print(f\"\\nQuestion: {question}\")\n",
+    +    " answer = rag._answer_question(question)\n",
+    +    " print(f\"Answer: {answer}\")"
+    +   ]
+    +  },
+    +  {
+    +   "cell_type": "code",
+    +   "execution_count": null,
+    +   "metadata": {},
+    +   "outputs": [],
+    +   "source": [
+    +    "# Clean up\n",
+    +    "import shutil\n",
+    +    "shutil.rmtree(test_dir)\n",
+    +    "print(f\"Cleaned up test directory: {test_dir}\")"
+    +   ]
+    +  }
+    + ],
+    + "metadata": {
+    +  "kernelspec": {
+    +   "display_name": "Python 3",
+    +   "language": "python",
+    +   "name": "python3"
+    +  },
+    +  "language_info": {
+    +   "codemirror_mode": {
+    +    "name": "ipython",
+    +    "version": 3
+    +   },
+    +   "file_extension": ".py",
+    +   "mimetype": "text/x-python",
+    +   "name": "python",
+    +   "nbconvert_exporter": "python",
+    +   "pygments_lexer": "ipython3",
+    +   "version": "3.10.11"
+    +  }
+    + },
+    + "nbformat": 4,
+    + "nbformat_minor": 4
+    +}
+
+    ```
+
+    输出的commit 信息为:
+
+    在 notebooks/tests 目录下新建一个 jupyter notebook, 对 @@_answer_question(location: src/autocoder/rag/long_context_rag.py) 进行测试
+    </example>
+
+    <example>
+    ## Modified Files
+    ### src/autocoder/utils/_markitdown.py
+    ```diff
+    diff --git a/src/autocoder/utils/_markitdown.py b/src/autocoder/utils/_markitdown.py
+    index da69b92b..dcecb74e 100644
+    --- a/src/autocoder/utils/_markitdown.py
+    +++ b/src/autocoder/utils/_markitdown.py
+    @@ -635,18 +635,22 @@ class DocxConverter(HtmlConverter):
+        """
+        Converts DOCX files to Markdown. Style information (e.g.m headings) and tables are preserved where possible.
+        """
+    +
+    +    def __init__(self):
+    +        self._image_counter = 0
+    +        super().__init__()
+
+        def _save_image(self, image, output_dir: str) -> str:
+            """
+    -        保存图片并返回相对路径
+    +        保存图片并返回相对路径,使用递增的计数器来命名文件
+            """
+            # 获取图片内容和格式
+            image_content = image.open()
+            image_format = image.content_type.split('/')[-1] if image.content_type else 'png'
+
+    -        # 生成唯一文件名
+    -        image_filename = f"image_{hash(image_content.read())}.{image_format}"
+    -        image_content.seek(0)  # 重置文件指针
+    +        # 增加计数器并生成文件名
+    +        self._image_counter += 1
+    +        image_filename = f"image_{self._image_counter}.{image_format}"
+
+            # 保存图片
+            image_path = os.path.join(output_dir, image_filename)
+    ```
+
+    输出的commit 信息为:
+
+    @@DocxConverter(location: src/autocoder/utils/_markitdown.py) 中,修改 _save_image中保存图片的文件名使用递增而不是hash值
+    </example>
+
+    <example>
+    ## Modified Files
+    ### src/autocoder/common/code_auto_generate.py
+    ### src/autocoder/common/code_auto_generate_diff.py
+    ### src/autocoder/common/code_auto_generate_strict_diff.py
+    ```diff
+    diff --git a/src/autocoder/common/code_auto_generate.py b/src/autocoder/common/code_auto_generate.py
+    index b8f3b364..1b3da198 100644
+    --- a/src/autocoder/common/code_auto_generate.py
+    +++ b/src/autocoder/common/code_auto_generate.py
+    @@ -2,6 +2,7 @@ from typing import List, Dict, Tuple
+     from autocoder.common.types import Mode
+     from autocoder.common import AutoCoderArgs
+     import byzerllm
+    +from autocoder.utils.queue_communicate import queue_communicate, CommunicateEvent, CommunicateEventType
+
+
+     class CodeAutoGenerate:
+    @@ -146,6 +147,15 @@ class CodeAutoGenerate:
+         ) -> Tuple[str, Dict[str, str]]:
+             llm_config = {"human_as_model": self.args.human_as_model}
+
+    +        if self.args.request_id and not self.args.skip_events:
+    +            queue_communicate.send_event_no_wait(
+    +                request_id=self.args.request_id,
+    +                event=CommunicateEvent(
+    +                    event_type=CommunicateEventType.CODE_GENERATE_START.value,
+    +                    data=query,
+    +                ),
+    +            )
+    +
+         if self.args.template == "common":
+             init_prompt = self.single_round_instruction.prompt(
+                 instruction=query, content=source_content, context=self.args.context
+    @@ -162,6 +172,16 @@ class CodeAutoGenerate:
+
+             t = self.llm.chat_oai(conversations=conversations, llm_config=llm_config)
+             conversations.append({"role": "assistant", "content": t[0].output})
+    +
+    +        if self.args.request_id and not self.args.skip_events:
+    +            queue_communicate.send_event_no_wait(
+    +                request_id=self.args.request_id,
+    +                event=CommunicateEvent(
+    +                    event_type=CommunicateEventType.CODE_GENERATE_END.value,
+    +                    data="",
+    +                ),
+    +            )
+    +
+             return [t[0].output], conversations
+
+         def multi_round_run(
+    diff --git a/src/autocoder/common/code_auto_generate_diff.py b/src/autocoder/common/code_auto_generate_diff.py
+    index 79a9e8d4..37f191a1 100644
+    --- a/src/autocoder/common/code_auto_generate_diff.py
+    +++ b/src/autocoder/common/code_auto_generate_diff.py
+    @@ -2,6 +2,7 @@ from typing import List, Dict, Tuple
+     from autocoder.common.types import Mode
+     from autocoder.common import AutoCoderArgs
+     import byzerllm
+    +from autocoder.utils.queue_communicate import queue_communicate, CommunicateEvent, CommunicateEventType
+
+
+     class CodeAutoGenerateDiff:
+    @@ -289,6 +290,15 @@ class CodeAutoGenerateDiff:
+         ) -> Tuple[str, Dict[str, str]]:
+             llm_config = {"human_as_model": self.args.human_as_model}
+
+    +        if self.args.request_id and not self.args.skip_events:
+    +            queue_communicate.send_event_no_wait(
+    +                request_id=self.args.request_id,
+    +                event=CommunicateEvent(
+    +                    event_type=CommunicateEventType.CODE_GENERATE_START.value,
+    +                    data=query,
+    +                ),
+    +            )
+    +
+         init_prompt = self.single_round_instruction.prompt(
+             instruction=query, content=source_content, context=self.args.context
+         )
+    @@ -300,6 +310,16 @@ class CodeAutoGenerateDiff:
+
+             t = self.llm.chat_oai(conversations=conversations, llm_config=llm_config)
+             conversations.append({"role": "assistant", "content": t[0].output})
+    +
+    +        if self.args.request_id and not self.args.skip_events:
+    +            queue_communicate.send_event_no_wait(
+    +                request_id=self.args.request_id,
+    +                event=CommunicateEvent(
+    +                    event_type=CommunicateEventType.CODE_GENERATE_END.value,
+    +                    data="",
+    +                ),
+    +            )
+    +
+             return [t[0].output], conversations
+
+         def multi_round_run(
+    diff --git a/src/autocoder/common/code_auto_generate_strict_diff.py b/src/autocoder/common/code_auto_generate_strict_diff.py
+    index 8874ae7a..91409c44 100644
+    --- a/src/autocoder/common/code_auto_generate_strict_diff.py
+    +++ b/src/autocoder/common/code_auto_generate_strict_diff.py
+    @@ -2,6 +2,7 @@ from typing import List, Dict, Tuple
+     from autocoder.common.types import Mode
+     from autocoder.common import AutoCoderArgs
+     import byzerllm
+    +from autocoder.utils.queue_communicate import queue_communicate, CommunicateEvent, CommunicateEventType
+
+
+     class CodeAutoGenerateStrictDiff:
+    @@ -260,6 +261,15 @@ class CodeAutoGenerateStrictDiff:
+         ) -> Tuple[str, Dict[str, str]]:
+             llm_config = {"human_as_model": self.args.human_as_model}
+
+    +        if self.args.request_id and not self.args.skip_events:
+    +            queue_communicate.send_event_no_wait(
+    +                request_id=self.args.request_id,
+    +                event=CommunicateEvent(
+    +                    event_type=CommunicateEventType.CODE_GENERATE_START.value,
+    +                    data=query,
+    +                ),
+    +            )
+    +
+         init_prompt = self.single_round_instruction.prompt(
+             instruction=query, content=source_content, context=self.args.context
+         )
+    @@ -271,6 +281,16 @@ class CodeAutoGenerateStrictDiff:
+
+             t = self.llm.chat_oai(conversations=conversations, llm_config=llm_config)
+             conversations.append({"role": "assistant", "content": t[0].output})
+    +
+    +        if self.args.request_id and not self.args.skip_events:
+    +            queue_communicate.send_event_no_wait(
+    +                request_id=self.args.request_id,
+    +                event=CommunicateEvent(
+    +                    event_type=CommunicateEventType.CODE_GENERATE_END.value,
+    +                    data="",
+    +                ),
+    +            )
+    +
+             return [t[0].output], conversations
+
+         def multi_round_run(
+    ```
+
+    输出的commit 信息为:
+
+    参考 @src/autocoder/common/code_auto_merge_editblock.py 中CODE_GENERATE_START,CODE_GENERATE_END 事件, 在其他文件里添加也添加这些事件. 注意,只需要修改 single_round_run 方法.
+    </example>
+    </examples>
+
+    下面是变更报告:
+    {{ changes_report }}
+
+    请输出commit message, 不要输出任何其他内容.
+    '''
+
 def print_commit_info(commit_result: CommitResult):
     console = Console()
     table = Table(
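Together, the two additions in this hunk implement an auto-commit flow: `get_uncommitted_changes` renders the working tree as a markdown report, and `generate_commit_message` asks an LLM to infer the user's intent from that report. A hedged usage sketch; the import path and model name are assumptions, and `.with_llm(...).run(...)` mirrors the call chain `_check_relevance_with_conversation` uses later in this diff:

```python
import byzerllm
# Import path assumed; the diff does not name the file that defines these helpers.
from autocoder.utils.git_utils import get_uncommitted_changes, generate_commit_message

llm = byzerllm.ByzerLLM()
llm.setup_default_model_name("deepseek_chat")  # hypothetical model name

report = get_uncommitted_changes("/path/to/repo")  # hypothetical repository path
if report != "No uncommitted changes found.":
    # Fill the {{ changes_report }} slot and let the model produce the message.
    message = generate_commit_message.with_llm(llm).run(changes_report=report)
    print(message)
```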
@@ -1,5 +1,5 @@
 import byzerllm
-import datetime
+from datetime import datetime
 
 @byzerllm.prompt()
 def claude_sys_prompt():
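The import swap in this hunk is behavior-relevant, not cosmetic: the two forms bind different objects to the name `datetime`. A minimal illustration:

```python
import datetime
datetime.datetime.now()  # module is bound; the class sits one attribute deeper
# datetime.now()         # AttributeError: module 'datetime' has no attribute 'now'

from datetime import datetime
datetime.now()           # class is bound; now() is called directly
```

So any code in this file written as `datetime.now()` only works after the change.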
@@ -62,6 +62,10 @@ def _check_relevance_with_conversation(
 class DocFilterWorker:
     def __init__(self, llm: ByzerLLM):
         self.llm = llm
+        if self.llm.get_sub_client("recall_model"):
+            self.recall_llm = self.llm.get_sub_client("recall_model")
+        else:
+            self.recall_llm = self.llm
 
     def filter_doc(
         self, conversations: List[Dict[str, str]], docs: List[str]
@@ -72,7 +76,8 @@ class DocFilterWorker:
                 conversations=conversations, documents=docs
             )
         except Exception as e:
-            logger.error(f"Error in _check_relevance_with_conversation: {str(e)}")
+            logger.error(
+                f"Error in _check_relevance_with_conversation: {str(e)}")
             return (None, submit_time_1, time.time())
 
         end_time_2 = time.time()
@@ -88,6 +93,11 @@ class DocFilter:
         path: Optional[str] = None,
     ):
         self.llm = llm
+        if self.llm.get_sub_client("recall_model"):
+            self.recall_llm = self.llm.get_sub_client("recall_model")
+        else:
+            self.recall_llm = self.llm
+
         self.args = args
         self.relevant_score = self.args.rag_doc_filter_relevance or 5
         self.on_ray = on_ray
@@ -95,7 +105,8 @@ class DocFilter:
         if self.on_ray:
             cpu_count = os.cpu_count() or 1
             self.workers = [
-                DocFilterWorker.options(max_concurrency=1000, num_cpus=0).remote(llm)
+                DocFilterWorker.options(
+                    max_concurrency=1000, num_cpus=0).remote(llm)
                 for _ in range(cpu_count)
             ]
 
@@ -137,10 +148,12 @@ class DocFilter:
             submit_time_1 = time.time()
             try:
                 llm = ByzerLLM()
-                llm.setup_default_model_name(self.llm.default_model_name)
                 llm.skip_nontext_check = True
+                llm.setup_default_model_name(self.recall_llm.default_model_name)
+
                 v = (
-                    _check_relevance_with_conversation.with_llm(llm)
+                    _check_relevance_with_conversation.with_llm(
+                        llm)
                     .options({"llm_config": {"max_length": 10}})
                     .run(
                         conversations=conversations,
@@ -194,10 +207,12 @@ class DocFilter:
                     )
                 )
             except Exception as exc:
-                logger.error(f"Document processing generated an exception: {exc}")
+                logger.error(
+                    f"Document processing generated an exception: {exc}")
 
         # Sort relevant_docs by relevance score in descending order
-        relevant_docs.sort(key=lambda x: x.relevance.relevant_score, reverse=True)
+        relevant_docs.sort(
+            key=lambda x: x.relevance.relevant_score, reverse=True)
         return relevant_docs
 
     def filter_docs_with_ray(
@@ -210,7 +225,8 @@ class DocFilter:
             worker = self.workers[count % len(self.workers)]
             count += 1
             future = worker.filter_doc.remote(
-                conversations, [f"##File: {doc.module_name}\n{doc.source_code}"]
+                conversations, [
+                    f"##File: {doc.module_name}\n{doc.source_code}"]
             )
             futures.append((future, doc))
 
@@ -248,8 +264,10 @@ class DocFilter:
                 )
             )
         except Exception as exc:
-            logger.error(f"Document processing generated an exception: {exc}")
+            logger.error(
+                f"Document processing generated an exception: {exc}")
 
         # Sort relevant_docs by relevance score in descending order
-        relevant_docs.sort(key=lambda x: x.relevance.relevant_score, reverse=True)
+        relevant_docs.sort(
+            key=lambda x: x.relevance.relevant_score, reverse=True)
         return relevant_docs
@@ -341,6 +341,11 @@ class LongContextRAG:
 
             return response_generator(), []
         else:
+
+            target_llm = self.llm
+            if self.llm.get_sub_client("qa_model"):
+                target_llm = self.llm.get_sub_client("qa_model")
+
             query = conversations[-1]["content"]
             context = []
 
@@ -349,8 +354,9 @@
                 in query
                 or "简要总结一下对话内容,用作后续的上下文提示 prompt,控制在 200 字以内"
                 in query
-            ):
-                chunks = self.llm.stream_chat_oai(
+            ):
+
+                chunks = target_llm.stream_chat_oai(
                     conversations=conversations,
                     model=model,
                     role_mapping=role_mapping,
@@ -384,7 +390,7 @@ class LongContextRAG:
 
             if self.args.without_contexts and LLMComputeEngine is not None:
                 llm_compute_engine = LLMComputeEngine(
-                    llm=self.llm,
+                    llm=target_llm,
                     inference_enhance=not self.args.disable_inference_enhance,
                     inference_deep_thought=self.args.inference_deep_thought,
                     inference_slow_without_deep_thought=self.args.inference_slow_without_deep_thought,
@@ -566,7 +572,7 @@ class LongContextRAG:
 
             if LLMComputeEngine is not None and not self.args.disable_inference_enhance:
                 llm_compute_engine = LLMComputeEngine(
-                    llm=self.llm,
+                    llm=target_llm,
                     inference_enhance=not self.args.disable_inference_enhance,
                     inference_deep_thought=self.args.inference_deep_thought,
                     precision=self.args.inference_compute_precision,
@@ -597,7 +603,7 @@ class LongContextRAG:
                 }
             ]
 
-            chunks = self.llm.stream_chat_oai(
+            chunks = target_llm.stream_chat_oai(
                 conversations=new_conversations,
                 model=model,
                 role_mapping=role_mapping,
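The recurring change across the DocFilter and LongContextRAG hunks is sub-client routing: `get_sub_client("recall_model")` selects a dedicated recall model for document filtering, `get_sub_client("qa_model")` a dedicated answer model, and both fall back to the default client when no sub-client is configured. A sketch of that shared shape using a simplified stand-in for `byzerllm.ByzerLLM`; the stand-in class and model names are illustrative only:

```python
from typing import Dict, Optional

class StubLLM:
    """Simplified stand-in for byzerllm.ByzerLLM, illustration only."""

    def __init__(self, name: str, sub_clients: Optional[Dict[str, "StubLLM"]] = None):
        self.default_model_name = name
        self._subs = sub_clients or {}

    def get_sub_client(self, key: str) -> Optional["StubLLM"]:
        # Falsy when no sub-client is configured, which the diff's if-checks rely on.
        return self._subs.get(key)

def pick_client(llm: StubLLM, key: str) -> StubLLM:
    # Same fallback shape as the recall_model/qa_model hunks above.
    return llm.get_sub_client(key) or llm

llm = StubLLM("deepseek_chat", {"qa_model": StubLLM("qa_specialist")})
print(pick_client(llm, "qa_model").default_model_name)      # qa_specialist
print(pick_client(llm, "recall_model").default_model_name)  # deepseek_chat (fallback)
```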