auto-coder 0.1.288__py3-none-any.whl → 0.1.290__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of auto-coder might be problematic; see the registry's advisory page for this release for more details.

@@ -0,0 +1,139 @@
1
+ from typing import List, Dict, Any, Optional, Union
2
+ import logging
3
+ import byzerllm
4
+ from pydantic import BaseModel
5
+ from autocoder.common import AutoCoderArgs
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
class SearchQuery(BaseModel):
    """One search-engine query derived from a conversation history."""

    # Text of the query to send to a search engine.
    query: str
    # Relative importance; the LLM is asked for a 1-10 score. Defaults to 5.
    importance: int = 5
    # Short explanation of what the query is meant to find.
    purpose: str = ""
15
+
16
class ConversationToQueries:
    """
    Turn a conversation history into search-engine queries via an LLM.

    ``generate_search_queries`` is a byzerllm prompt function (its docstring is
    the prompt template). ``extract_queries`` runs it against ``self.llm`` and
    post-processes the parsed result into a ranked, size-capped list.
    """

    def __init__(self, llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM]):
        """
        Create the converter.

        Args:
            llm: ByzerLLM / SimpleByzerLLM instance used to execute the prompt function.
        """
        self.llm = llm

    # NOTE: the docstring below is the runtime prompt template rendered by
    # byzerllm (Jinja syntax: {{ max_queries }}, {% for msg in conversations %}).
    # It is sent to the LLM verbatim, so it must not be translated or reworded.
    @byzerllm.prompt()
    def generate_search_queries(self, conversations: List[Dict[str, Any]], max_queries: int = 3) -> str:
        """
        根据历史对话生成搜索查询。

        参数:
            conversations: 历史对话列表,每个对话是一个字典,包含 'role' 和 'content' 字段
            max_queries: 最大生成的查询数量,默认为 3

        返回:
            生成的搜索查询列表的 JSON 字符串

        任务说明:
        你是一个专业的对话分析助手。你的任务是分析用户与 AI 的对话历史,从中提取关键信息,
        并生成用于搜索引擎的查询,以便获取与对话相关的知识和信息。

        具体要求:
        1. 仔细分析对话历史,特别是最近的几轮对话
        2. 识别用户可能需要更多信息或知识的关键问题和主题
        3. 将这些关键问题转化为明确、简洁的搜索查询
        4. 每个查询应该足够具体,能够通过搜索引擎找到有用的结果
        5. 为每个查询提供重要性评分(1-10 分)和用途说明
        6. 最多生成 {{ max_queries }} 个查询,按重要性排序
        7. 返回符合指定格式的 JSON 数据

        可能的场景:
        - 用户询问特定技术或概念,需要进一步的解释或示例
        - 用户遇到编程问题,需要查找解决方案或最佳实践
        - 用户讨论的话题涉及多个方面,需要查找不同角度的信息
        - 用户想了解某个领域的最新发展或趋势

        ---

        对话历史:
        <conversations>
        {% for msg in conversations %}
        {{ msg.role }}: {{ msg.content }}
        {% endfor %}
        </conversations>

        请分析上述对话,提取关键问题并生成最多 {{ max_queries }} 个搜索查询。

        输出格式:
        ```json
        [
            {
                "query": "搜索查询1",
                "importance": 评分(1-10),
                "purpose": "该查询的目的说明"
            },
            {
                "query": "搜索查询2",
                "importance": 评分(1-10),
                "purpose": "该查询的目的说明"
            }
        ]
        ```
        """

    @staticmethod
    def _normalize_queries(result: Any) -> List[SearchQuery]:
        """Coerce the prompt's parsed result into a list of SearchQuery objects."""
        if result is None:
            return []
        if isinstance(result, SearchQuery):
            # The model answered with a single JSON object instead of an array.
            return [result]
        # Expected case: an iterable of parsed models; drop anything else.
        return [item for item in result if isinstance(item, SearchQuery)]

    def extract_queries(self, conversations: List[Dict[str, Any]], max_queries: int = 3) -> List[SearchQuery]:
        """
        Extract search queries from a conversation history.

        Args:
            conversations: Conversation history; each item has 'role' and 'content'.
            max_queries: Maximum number of queries to return. Values <= 0 (or an
                empty conversation) return [] without calling the LLM.

        Returns:
            SearchQuery objects sorted by descending importance and truncated to
            at most ``max_queries`` items. Returns [] if the LLM call or parsing
            fails; the error is logged with its traceback.
        """
        if not conversations or max_queries <= 0:
            return []
        try:
            result = self.generate_search_queries.with_llm(self.llm).with_return_type(SearchQuery).run(
                conversations=conversations,
                max_queries=max_queries
            )
            queries = self._normalize_queries(result)
            # Most important first; list.sort is stable so ties keep LLM order.
            queries.sort(key=lambda q: q.importance, reverse=True)
            # The prompt only *asks* for at most max_queries; enforce the cap here.
            return queries[:max_queries]
        except Exception as e:
            # Best-effort: query rewriting must never break retrieval.
            logger.exception("Error extracting queries from conversation: %s", e)
            return []
114
+
115
def extract_search_queries(
    conversations: List[Dict[str, Any]],
    args: AutoCoderArgs,
    llm: Union[byzerllm.ByzerLLM, byzerllm.SimpleByzerLLM],
    max_queries: int = 3,
) -> List[SearchQuery]:
    """
    Extract search queries from a conversation history (convenience wrapper).

    Args:
        conversations: Conversation history; each item has 'role' and 'content'.
        args: AutoCoder arguments. Not read by this function; accepted for
            call-site compatibility.
        llm: ByzerLLM / SimpleByzerLLM instance used to generate the queries.
        max_queries: Maximum number of queries to generate. ``0``, negative
            values, or ``None`` disable extraction and return [].

    Returns:
        A list of SearchQuery objects, or [] when extraction is disabled, the
        conversation is empty, or an error occurs (the error is logged).
    """
    # A non-positive / missing budget means query rewriting is disabled:
    # skip the LLM call entirely.
    if not max_queries or max_queries <= 0:
        return []
    if not conversations:
        return []
    try:
        extractor = ConversationToQueries(llm)
        return extractor.extract_queries(conversations, max_queries)
    except Exception as e:
        # Best-effort: callers fall back to the raw user query on [].
        logger.exception("Error extracting search queries from conversation: %s", e)
        return []
@@ -38,7 +38,8 @@ from pydantic import BaseModel
38
38
  from byzerllm.utils.types import SingleOutputMeta
39
39
  from autocoder.rag.lang import get_message_with_format_and_newline
40
40
  from autocoder.rag.qa_conversation_strategy import get_qa_strategy
41
-
41
+ from autocoder.rag.searchable import SearchableResults
42
+ from autocoder.rag.conversation_to_queries import extract_search_queries
42
43
  try:
43
44
  from autocoder_pro.rag.llm_compute import LLMComputeEngine
44
45
  pro_version = version("auto-coder-pro")
@@ -257,7 +258,7 @@ class LongContextRAG:
257
258
  请根据提供的文档内容、用户对话历史以及最后一个问题,提取并总结文档中与问题相关的重要信息。
258
259
  如果文档中没有相关信息,请回复"该文档中没有与问题相关的信息"。
259
260
  提取的信息尽量保持和原文中的一样,并且只输出这些信息。
260
- """
261
+ """
261
262
 
262
263
  def _get_document_retriever_class(self):
263
264
  """Get the document retriever class based on configuration."""
@@ -333,7 +334,9 @@ class LongContextRAG:
333
334
 
334
335
  def _filter_docs(self, conversations: List[Dict[str, str]]) -> DocFilterResult:
335
336
  query = conversations[-1]["content"]
336
- documents = self._retrieve_documents(options={"query": query})
337
+ queries = extract_search_queries(conversations=conversations, args=self.args, llm=self.llm, max_queries=self.args.rag_recall_max_queries)
338
+ documents = self._retrieve_documents(
339
+ options={"queries": [query] + [query.query for query in queries]})
337
340
  return self.doc_filter.filter_docs(
338
341
  conversations=conversations, documents=documents
339
342
  )
@@ -500,6 +503,9 @@ class LongContextRAG:
500
503
  except json.JSONDecodeError:
501
504
  pass
502
505
 
506
+ if not only_contexts and extra_request_params.get("only_contexts", False):
507
+ only_contexts = True
508
+
503
509
  logger.info(f"Query: {query} only_contexts: {only_contexts}")
504
510
  start_time = time.time()
505
511
 
@@ -543,7 +549,10 @@ class LongContextRAG:
543
549
  model_name=rag_stat.recall_stat.model_name
544
550
  )
545
551
  query = conversations[-1]["content"]
546
- documents = self._retrieve_documents(options={"query": query})
552
+ queries = extract_search_queries(
553
+ conversations=conversations, args=self.args, llm=self.llm, max_queries=self.args.rag_recall_max_queries)
554
+ documents = self._retrieve_documents(
555
+ options={"queries": [query] + [query.query for query in queries]})
547
556
 
548
557
  # 使用带进度报告的过滤方法
549
558
  for progress_update, result in self.doc_filter.filter_docs_with_progress(conversations, documents):
@@ -593,10 +602,19 @@ class LongContextRAG:
593
602
  )
594
603
 
595
604
  if only_contexts:
596
- final_docs = []
597
- for doc in relevant_docs:
598
- final_docs.append(doc.model_dump())
599
- return [json.dumps(final_docs, ensure_ascii=False)], []
605
+ try:
606
+ searcher = SearchableResults()
607
+ result = searcher.reorder(docs=relevant_docs)
608
+ yield (json.dumps(result.model_dump(), ensure_ascii=False), SingleOutputMeta(input_tokens_count=rag_stat.recall_stat.total_input_tokens + rag_stat.chunk_stat.total_input_tokens,
609
+ generated_tokens_count=rag_stat.recall_stat.total_generated_tokens +
610
+ rag_stat.chunk_stat.total_generated_tokens,
611
+ ))
612
+ except Exception as e:
613
+ yield (str(e), SingleOutputMeta(input_tokens_count=rag_stat.recall_stat.total_input_tokens + rag_stat.chunk_stat.total_input_tokens,
614
+ generated_tokens_count=rag_stat.recall_stat.total_generated_tokens +
615
+ rag_stat.chunk_stat.total_generated_tokens,
616
+ ))
617
+ return
600
618
 
601
619
  if not relevant_docs:
602
620
  yield ("没有找到可以回答你问题的相关文档", SingleOutputMeta(input_tokens_count=rag_stat.recall_stat.total_input_tokens + rag_stat.chunk_stat.total_input_tokens,
@@ -816,12 +834,13 @@ class LongContextRAG:
816
834
 
817
835
  self._print_rag_stats(rag_stat)
818
836
  else:
819
-
820
- qa_strategy = get_qa_strategy(self.args.rag_qa_conversation_strategy)
837
+
838
+ qa_strategy = get_qa_strategy(
839
+ self.args.rag_qa_conversation_strategy)
821
840
  new_conversations = qa_strategy.create_conversation(
822
841
  documents=[doc.source_code for doc in relevant_docs],
823
- conversations=conversations
824
- )
842
+ conversations=conversations, local_image_host=self.args.local_image_host
843
+ )
825
844
 
826
845
  chunks = target_llm.stream_chat_oai(
827
846
  conversations=new_conversations,
@@ -8,7 +8,7 @@ class QAConversationStrategy(ABC):
8
8
  Different strategies organize documents and conversations differently.
9
9
  """
10
10
  @abstractmethod
11
- def create_conversation(self, documents: List[Any], conversations: List[Dict[str,str]]) -> List[Dict]:
11
+ def create_conversation(self, documents: List[Any], conversations: List[Dict[str,str]], local_image_host: str) -> List[Dict]:
12
12
  """
13
13
  Create a conversation structure based on documents and history
14
14
 
@@ -26,10 +26,10 @@ class MultiRoundStrategy(QAConversationStrategy):
26
26
  Multi-round strategy: First let the model read documents, then do Q&A.
27
27
  Creates multiple conversation turns.
28
28
  """
29
- def create_conversation(self, documents: List[Any], conversations: List[Dict[str,str]]) -> List[Dict]:
29
+ def create_conversation(self, documents: List[Any], conversations: List[Dict[str,str]], local_image_host: str) -> List[Dict]:
30
30
  messages = []
31
31
  messages.extend([
32
- {"role": "user", "content": self._read_docs_prompt.prompt(documents)},
32
+ {"role": "user", "content": self._read_docs_prompt.prompt(documents, local_image_host)},
33
33
  {"role": "assistant", "content": "好的"}
34
34
  ])
35
35
  messages.extend(conversations)
@@ -37,7 +37,7 @@ class MultiRoundStrategy(QAConversationStrategy):
37
37
 
38
38
  @byzerllm.prompt()
39
39
  def _read_docs_prompt(
40
- self, relevant_docs: List[str]
40
+ self, relevant_docs: List[str], local_image_host: str
41
41
  ) -> Generator[str, None, None]:
42
42
  """
43
43
  请阅读以下:
@@ -53,29 +53,35 @@ class MultiRoundStrategy(QAConversationStrategy):
53
53
  - 如果文档提供的信息无法回答问题,请明确回复:"抱歉,文档中没有足够的信息来回答这个问题。"
54
54
  - 不要添加、推测或扩展文档未提及的信息
55
55
 
56
- 2. 格式如 ![image](./path.png) 的 Markdown 图片处理
56
+ 2. 格式如 ![image](/path/to/images/path.png) 的 Markdown 图片处理
57
57
  - 根据Markdown 图片前后文本内容推测改图片与问题的相关性,有相关性则在回答中输出该Markdown图片路径
58
58
  - 根据相关图片在文档中的位置,自然融入答复内容,保持上下文连贯
59
59
  - 完整保留原始图片路径,不省略任何部分
60
60
 
61
61
  3. 回答格式要求
62
62
  - 使用markdown格式提升可读性
63
+ {% if local_image_host %}
64
+ 4. 图片路径处理
65
+ - 图片地址需返回绝对路径,
66
+ - 为请求图片资源 需增加 http://{{ local_image_host }}/static/ 作为前缀
67
+ 例如:/path/to/images/image.png, 返回 http://{{ local_image_host }}/static/path/to/images/image.png
68
+ {% endif %}
63
69
  """
64
70
 
65
71
  class SingleRoundStrategy(QAConversationStrategy):
66
72
  """
67
73
  Single-round strategy: Put documents and conversation history in a single round.
68
74
  """
69
- def create_conversation(self, documents: List[Any], conversations: List[Dict[str,str]]) -> List[Dict]:
75
+ def create_conversation(self, documents: List[Any], conversations: List[Dict[str,str]], local_image_host: str) -> List[Dict]:
70
76
  messages = []
71
77
  messages.extend([
72
- {"role": "user", "content": self._single_round_answer_question.prompt(documents, conversations)}
78
+ {"role": "user", "content": self._single_round_answer_question.prompt(documents, conversations, local_image_host)}
73
79
  ])
74
80
  return messages
75
81
 
76
82
  @byzerllm.prompt()
77
83
  def _single_round_answer_question(
78
- self, relevant_docs: List[str], conversations: List[Dict[str, str]]
84
+ self, relevant_docs: List[str], conversations: List[Dict[str, str]], local_image_host: str
79
85
  ) -> Generator[str, None, None]:
80
86
  """
81
87
  文档:
@@ -98,14 +104,19 @@ class SingleRoundStrategy(QAConversationStrategy):
98
104
  - 如果文档提供的信息无法回答问题,请明确回复:"抱歉,文档中没有足够的信息来回答这个问题。"
99
105
  - 不要添加、推测或扩展文档未提及的信息
100
106
 
101
- 2. 格式如 ![image](./path.png) 的 Markdown 图片处理
107
+ 2. 格式如 ![image](/path/to/images/path.png) 的 Markdown 图片处理
102
108
  - 根据Markdown 图片前后文本内容推测改图片与问题的相关性,有相关性则在回答中输出该Markdown图片路径
103
109
  - 根据相关图片在文档中的位置,自然融入答复内容,保持上下文连贯
104
110
  - 完整保留原始图片路径,不省略任何部分
105
111
 
106
112
  3. 回答格式要求
107
113
  - 使用markdown格式提升可读性
108
-
114
+ {% if local_image_host %}
115
+ 4. 图片路径处理
116
+ - 图片地址需返回绝对路径,
117
+ - 为请求图片资源 需增加 http://{{ local_image_host }}/static/ 作为前缀
118
+ 例如:/path/to/images/image.png, 返回 http://{{ local_image_host }}/static/path/to/images/image.png
119
+ {% endif %}
109
120
  """
110
121
 
111
122
  def get_qa_strategy(strategy_name: str) -> QAConversationStrategy:
@@ -0,0 +1,58 @@
1
+ import json
2
+ from collections import Counter
3
+ from typing import Dict, List, Any, Optional, Tuple, Set
4
+ from pydantic import BaseModel
5
+ from autocoder.rag.relevant_utils import FilterDoc
6
+
7
+
8
class FileOccurrence(BaseModel):
    """A file path paired with how often it was referenced by search results."""

    # Path of the referenced file.
    file_path: str
    # Number of times the path appeared across the filtered documents.
    count: int
    # Optional relevance score; defaults to 0.0 when not supplied.
    score: float = 0.0
13
+
14
class FileResult(BaseModel):
    """Wrapper model holding a list of file occurrences."""

    # File occurrences; SearchableResults.reorder fills this most-frequent first.
    files: List[FileOccurrence]
16
+
17
class SearchableResults:
    """Rank the files referenced by filtered RAG documents by how often they occur."""

    def extract_original_docs(self, docs: List[FilterDoc]) -> List[str]:
        """
        Collect the file paths referenced by each document.

        For each doc, ``source_code.metadata["original_docs"]`` is used when it
        is present and not None. Otherwise the doc's own
        ``source_code.module_name`` is used.

        Args:
            docs: Filtered documents whose ``source_code`` carries metadata.

        Returns:
            A flat list of paths. Duplicates are preserved so they can be counted.
        """
        all_files: List[str] = []
        for doc in docs:
            source = doc.source_code
            # Tolerate missing/None metadata by treating it as empty.
            metadata = getattr(source, "metadata", None) or {}
            original = metadata.get("original_docs")
            if original is None:
                # No original-document info: fall back to the doc itself.
                all_files.append(source.module_name)
            elif isinstance(original, str):
                # A bare string would otherwise be extended character by character.
                all_files.append(original)
            else:
                all_files.extend(original)
        return all_files

    def count_file_occurrences(self, files: List[str]) -> List[FileOccurrence]:
        """
        Count how often each path occurs.

        Args:
            files: Paths, possibly with duplicates.

        Returns:
            FileOccurrence objects sorted by descending count. Ties keep
            first-seen order, because ``most_common`` sorts stably.
        """
        return [
            FileOccurrence(file_path=file_path, count=count)
            for file_path, count in Counter(files).most_common()
        ]

    def reorder(self, docs: List[FilterDoc]) -> FileResult:
        """
        Rank the files referenced by ``docs`` by occurrence (main entry point).

        Args:
            docs: Filtered documents to analyse.

        Returns:
            A FileResult whose ``files`` are ordered most-frequent first.
        """
        all_files = self.extract_original_docs(docs)
        return FileResult(files=self.count_file_occurrences(all_files))
57
+
58
+
autocoder/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.1.288"
1
+ __version__ = "0.1.290"