pydatamax 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. datamax/__init__.py +1 -1
  2. datamax/loader/core.py +118 -118
  3. datamax/loader/{MinioHandler.py → minio_handler.py} +171 -171
  4. datamax/loader/{OssHandler.py → oss_handler.py} +191 -191
  5. datamax/parser/__init__.py +2 -4
  6. datamax/parser/base.py +76 -76
  7. datamax/parser/core.py +406 -288
  8. datamax/parser/csv_parser.py +31 -10
  9. datamax/parser/doc_parser.py +525 -61
  10. datamax/parser/docx_parser.py +512 -62
  11. datamax/parser/epub_parser.py +41 -41
  12. datamax/parser/html_parser.py +37 -37
  13. datamax/parser/image_parser.py +34 -34
  14. datamax/parser/json_parser.py +32 -10
  15. datamax/parser/md_parser.py +72 -72
  16. datamax/parser/pdf_parser.py +101 -101
  17. datamax/parser/ppt_parser.py +70 -20
  18. datamax/parser/pptx_parser.py +45 -45
  19. datamax/parser/txt_parser.py +45 -45
  20. datamax/parser/xls_parser.py +26 -26
  21. datamax/parser/xlsx_parser.py +212 -208
  22. datamax/utils/__init__.py +23 -2
  23. datamax/utils/constants.py +58 -58
  24. datamax/utils/data_cleaner.py +275 -237
  25. datamax/utils/env_setup.py +79 -79
  26. datamax/utils/gotocr_pdf.py +265 -265
  27. datamax/utils/mineru_operator.py +62 -62
  28. datamax/utils/paddleocr_pdf_operator.py +90 -90
  29. datamax/utils/ppt_extract.py +140 -140
  30. datamax/utils/qa_generator.py +369 -376
  31. datamax/utils/tokenizer.py +21 -21
  32. datamax/utils/uno_handler.py +426 -0
  33. pydatamax-0.1.15.dist-info/METADATA +340 -0
  34. pydatamax-0.1.15.dist-info/RECORD +38 -0
  35. {pydatamax-0.1.13.dist-info → pydatamax-0.1.15.dist-info}/licenses/LICENSE +21 -21
  36. {pydatamax-0.1.13.dist-info → pydatamax-0.1.15.dist-info}/top_level.txt +0 -1
  37. pydatamax-0.1.13.dist-info/METADATA +0 -280
  38. pydatamax-0.1.13.dist-info/RECORD +0 -39
  39. tests/__init__.py +0 -0
  40. tests/test_basic.py +0 -20
  41. {pydatamax-0.1.13.dist-info → pydatamax-0.1.15.dist-info}/WHEEL +0 -0
datamax/utils/qa_generator.py
@@ -1,376 +1,369 @@
- import json
- import os.path
- import re
- import threading
- from concurrent.futures import ThreadPoolExecutor, as_completed
- from pathlib import Path
-
- import requests
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_community.document_loaders import UnstructuredMarkdownLoader
- from loguru import logger
- from pyexpat.errors import messages
- from tqdm import tqdm  # For progress bar display
-
- lock = threading.Lock()
-
-
- # ------------prompt-----------------
- def get_system_prompt_for_question(query_text, question_number):
-     """Generate system prompt for question generation task"""
-     system_prompt = f"""
- # 角色使命
- 你是一位专业的文本分析专家,擅长从复杂文本中提取关键信息并生成可用于模型微调的结构化数据(仅生成问题)。
-
- ## 核心任务
- 根据用户提供的文本,生成不少于 ${question_number} 个高质量问题。
-
- ## 约束条件(重要!)
- - 必须基于文本内容直接生成
- - 问题应具有明确答案指向性
- - 需覆盖文本的不同方面
- - 禁止生成假设性、重复或相似问题
- - 确保生成得完整性
-
- ## 处理流程
- 1. 【文本解析】分段处理内容,识别关键实体和核心概念
- 2. 【问题生成】基于信息密度选择最佳提问点
- 3. 【质量检查】确保:
-    - 问题答案可在原文中找到依据
-    - 标签与问题内容强相关
-    - 无格式错误
-
- ## 输出格式
- - JSON 数组格式必须正确
- - 字段名使用英文双引号
- - 输出的 JSON 数组必须严格符合以下结构:
- \`\`\`json
- ["问题1", "问题2", "..."]
- \`\`\`
-
- ## 输出示例
- \`\`\`json
- [ "人工智能伦理框架应包含哪些核心要素?","民法典对个人数据保护有哪些新规定?"]
- \`\`\`
-
- ## 待处理文本
- ${query_text}
-
- ## 限制
- - 必须按照规定的 JSON 格式输出,不要输出任何其他不相关内容
- - 生成不少于${question_number}个高质量问题
- - 问题不要和材料本身相关,例如禁止出现作者、章节、目录等相关问题
- - 问题不得包含【报告、文章、文献、表格】中提到的这种话术,必须是一个自然的问题
-     """
-     return system_prompt
-
-
- def get_system_prompt_for_answer(text, query_question):
-     """Generate system prompt for answer generation task"""
-     system_prompt = f"""
- # Role: 微调数据集生成专家
- ## Profile:
- - Description: 你是一名微调数据集生成专家,擅长从给定的内容中生成准确的问题答案,确保答案的准确性和相关性,你要直接回答用户问题,所有信息已内化为你的专业知识。
-
- ## Skills :
- 1. 答案必须基于给定的内容
- 2. 答案必须准确,不能胡编乱造
- 3. 答案必须与问题相关
- 4. 答案必须符合逻辑
- 5. 基于给定参考内容,用自然流畅的语言整合成一个完整答案,不需要提及文献来源或引用标记
-
- ## Workflow:
- 1. Take a deep breath and work on this problem step-by-step.
- 2. 首先,分析给定的文件内容
- 3. 然后,从内容中提取关键信息
- 4. 接着,生成与问题相关的准确答案
- 5. 最后,确保答案的准确性和相关性
-
- ## 参考内容:
- ${text}
-
- ## 问题
- ${query_question}
-
- ## Constrains:
- 1. 答案必须基于给定的内容
- 2. 答案必须准确,必须与问题相关,不能胡编乱造
- 3. 答案必须充分、详细、包含所有必要的信息、适合微调大模型训练使用
- 4. 答案中不得出现 ' 参考 / 依据 / 文献中提到 ' 等任何引用性表述,只需呈现最终结果
-     """
-     return system_prompt
-
-
- # ------------spliter----------------
- def load_and_split_markdown(md_path: str, chunk_size: int, chunk_overlap: int) -> list:
-     """
-     Parse Markdown using UnstructuredMarkdownLoader
-     Chunking strategy that preserves original paragraph structure
-
-     Args:
-         md_path: Path to the markdown file
-         chunk_size: Size of each chunk
-         chunk_overlap: Overlap between chunks
-
-     Returns:
-         List of document chunks
-     """
-     try:
-         # Use LangChain's MarkdownLoader to load Markdown file
-         loader = UnstructuredMarkdownLoader(md_path)
-         documents = loader.load()
-         # Further split documents if needed
-         splitter = RecursiveCharacterTextSplitter(
-             chunk_size=chunk_size,
-             chunk_overlap=chunk_overlap,
-             length_function=len,
-             is_separator_regex=False,
-         )
-         return splitter.split_documents(documents)
-     except Exception as e:
-         logger.error(f"加载 {Path(md_path).name} 失败: {str(e)}")
-         return []
-
-
- # ------------llm generator-------------------
- def extract_json_from_llm_output(output: str):
-     """
-     Extract JSON content from LLM output, handling multiple possible formats
-
-     Args:
-         output: Raw output string from LLM
-
-     Returns:
-         Parsed JSON list if successful, None otherwise
-     """
-     # Try to parse the entire output directly
-     try:
-         return json.loads(output)
-     except json.JSONDecodeError:
-         pass
-
-     # Try to extract content wrapped in ```json ```
-     json_match = re.search(r"```json\n([\s\S]*?)\n```", output)
-     if json_match:
-         try:
-             return json.loads(json_match.group(1))
-         except json.JSONDecodeError as e:
-             print(f"解析 JSON 时出错: {e}")
-
-     # Try to extract the most JSON-like part
-     json_start = output.find("[")
-     json_end = output.rfind("]") + 1
-     if json_start != -1 and json_end != 0:
-         try:
-             return json.loads(output[json_start:json_end])
-         except json.JSONDecodeError:
-             pass
-
-     print("模型未按标准格式输出:", output)
-     return None
-
-
- def llm_generator(
-     api_key: str,
-     model: str,
-     base_url: str,
-     prompt: str,
-     type: str,
-     message: list = None,
-     temperature: float = 0.7,
-     top_p: float = 0.9,
-     max_token: int = 2048,
- ) -> list:
-     """Generate content using LLM API"""
-     try:
-         if not message:
-             message = [
-                 {"role": "system", "content": prompt},
-                 {"role": "user", "content": "请严格按照要求生成内容"},
-             ]
-         headers = {
-             "Authorization": f"Bearer {api_key}",
-             "Content-Type": "application/json",
-         }
-         data = {
-             "model": model,
-             "messages": message,
-             "max_tokens": max_token,
-             "temperature": temperature,
-             "top_p": top_p,
-         }
-         response = requests.post(base_url, headers=headers, json=data, timeout=30)
-         response.raise_for_status()
-         result = response.json()
-
-         # Parse LLM response
-         if "choices" in result and len(result["choices"]) > 0:
-             output = result["choices"][0]["message"]["content"]
-             if type == "question":
-                 fmt_output = extract_json_from_llm_output(output)
-             else:
-                 return output
-             return fmt_output
-         return []
-
-     except Exception as e:
-         print(f"LLM提取关键词失败: {e, e.__traceback__.tb_lineno}")
-         return []
-
-
- # ------------thread_process-------------
-
-
- def process_questions(
-     api_key: str,
-     model: str,
-     base_url: str,
-     page_content: list,
-     question_number: int,
-     message: list,
-     max_workers: int = 5,
- ) -> list:
-     """Generate questions using multi-threading"""
-     total_questions = []
-
-     def _generate_questions(page):
-         """Inner function for question generation"""
-         prompt = get_system_prompt_for_question(page, question_number)
-         questions = llm_generator(
-             api_key=api_key,
-             model=model,
-             base_url=base_url,
-             message=message,
-             prompt=prompt,
-             type="question",
-         )
-         return [{"question": q, "page": page} for q in questions] if questions else []
-
-     logger.info(f"开始生成问题 (线程数: {max_workers})...")
-     with ThreadPoolExecutor(max_workers=max_workers) as executor:
-         futures = [executor.submit(_generate_questions, page) for page in page_content]
-
-         with tqdm(as_completed(futures), total=len(futures), desc="生成问题") as pbar:
-             for future in pbar:
-                 result = future.result()
-                 if result:
-                     with lock:
-                         total_questions.extend(result)
-                         pbar.set_postfix({"已生成问题": len(total_questions)})
-
-     return total_questions
-
-
- def process_answers(
-     api_key: str,
-     model: str,
-     base_url: str,
-     question_items: list,
-     message: list = None,
-     max_workers=5,
- ) -> dict:
-     """Generate answers using multi-threading"""
-     qa_pairs = {}
-
-     def _generate_answer(item):
-         """Inner function for answer generation"""
-         prompt = get_system_prompt_for_answer(item["page"], item["question"])
-         answer = llm_generator(
-             api_key=api_key,
-             model=model,
-             base_url=base_url,
-             prompt=prompt,
-             message=message,
-             type="answer",
-         )
-         return item["question"], answer
-
-     logger.info(f"开始生成答案 (线程数: {max_workers})...")
-     with ThreadPoolExecutor(max_workers=max_workers) as executor:
-         futures = {
-             executor.submit(_generate_answer, item): item for item in question_items
-         }
-
-         with tqdm(as_completed(futures), total=len(futures), desc="生成答案") as pbar:
-             for future in pbar:
-                 question, answer = future.result()
-                 if answer:
-                     with lock:
-                         qa_pairs[question] = answer
-                         pbar.set_postfix({"已生成答案": len(qa_pairs)})
-     return qa_pairs
-
-
- def generatr_qa_pairs(
-     file_path: str,
-     api_key: str,
-     base_url: str,
-     model_name: str,
-     chunk_size=500,
-     chunk_overlap=100,
-     question_number=5,
-     message: list = None,
-     max_workers=5,
- ):
-     """Main function to generate QA pairs from markdown file"""
-     # 1. Split markdown text into chunks
-     pages = load_and_split_markdown(
-         md_path=file_path, chunk_size=chunk_size, chunk_overlap=chunk_overlap
-     )
-     page_content = [i.page_content for i in pages]
-     logger.info(f"markdown被分解了{len(page_content)}个chunk")
-
-     # 2. Generate questions using multi-threading
-     questions = process_questions(
-         page_content=page_content,
-         message=message,
-         question_number=question_number,
-         max_workers=max_workers,
-         api_key=api_key,
-         base_url=base_url,
-         model=model_name,
-     )
-     if not questions:
-         logger.error("未能生成任何问题,请检查输入文档和API设置")
-
-     # 3. Generate answers using multi-threading
-     qa_pairs = process_answers(
-         question_items=questions,
-         message=message,
-         max_workers=max_workers,
-         api_key=api_key,
-         base_url=base_url,
-         model=model_name,
-     )
-
-     # 4. Save results
-     res_list = []
-     with open(
-         f"{os.path.basename(file_path).strip('.md')}.jsonl", "w", encoding="utf-8"
-     ) as f:
-         for question, answer in qa_pairs.items():
-             # Build properly formatted JSON object
-             qa_entry = {"instruction": question, "input": "", "output": answer}
-             res_list.append(qa_entry)
-             # Write to JSONL file (one JSON object per line)
-             f.write(json.dumps(qa_entry, ensure_ascii=False) + "\n")
-
-     logger.success(
-         f"完成! 共生成 {len(qa_pairs)} 个问答对,已保存到 {os.path.basename(file_path).strip('.md')}.jsonl"
-     )
-
-     return res_list
-
-
- if __name__ == "__main__":
-     generatr_qa_pairs(
-         file_path=r"C:\Users\cykro\Desktop\文档整理\知识图谱\知识图谱概要设计.md",
-         api_key="sk-xxxxxxxxxxxxxxxxxxxxxxxxxx",
-         base_url="https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
-         model_name="qwen-max",
-         chunk_size=500,
-         chunk_overlap=100,
-         question_number=5,
-         max_workers=5,
-         # message=[]
-     )
+ import json
+ import os.path
+ import re
+ import threading
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from pathlib import Path
+
+ import requests
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.document_loaders import UnstructuredMarkdownLoader
+ from loguru import logger
+ from pyexpat.errors import messages
+ from tqdm import tqdm  # For progress bar display
+
+ lock = threading.Lock()
+
+
+ # ------------prompt-----------------
+ def get_system_prompt_for_question(query_text, question_number):
+     """Generate system prompt for question generation task"""
+     system_prompt = f"""
+ # 角色使命
+ 你是一位专业的文本分析专家,擅长从复杂文本中提取关键信息并生成可用于模型微调的结构化数据(仅生成问题)。
+
+ ## 核心任务
+ 根据用户提供的文本,生成不少于 ${question_number} 个高质量问题。
+
+ ## 约束条件(重要!)
+ - 必须基于文本内容直接生成
+ - 问题应具有明确答案指向性
+ - 需覆盖文本的不同方面
+ - 禁止生成假设性、重复或相似问题
+ - 确保生成得完整性
+
+ ## 处理流程
+ 1. 【文本解析】分段处理内容,识别关键实体和核心概念
+ 2. 【问题生成】基于信息密度选择最佳提问点
+ 3. 【质量检查】确保:
+    - 问题答案可在原文中找到依据
+    - 标签与问题内容强相关
+    - 无格式错误
+
+ ## 输出格式
+ - JSON 数组格式必须正确
+ - 字段名使用英文双引号
+ - 输出的 JSON 数组必须严格符合以下结构:
+ \`\`\`json
+ ["问题1", "问题2", "..."]
+ \`\`\`
+
+ ## 输出示例
+ \`\`\`json
+ [ "人工智能伦理框架应包含哪些核心要素?","民法典对个人数据保护有哪些新规定?"]
+ \`\`\`
+
+ ## 待处理文本
+ ${query_text}
+
+ ## 限制
+ - 必须按照规定的 JSON 格式输出,不要输出任何其他不相关内容
+ - 生成不少于${question_number}个高质量问题
+ - 问题不要和材料本身相关,例如禁止出现作者、章节、目录等相关问题
+ - 问题不得包含【报告、文章、文献、表格】中提到的这种话术,必须是一个自然的问题
+     """
+     return system_prompt
+
+
+ def get_system_prompt_for_answer(text, query_question):
+     """Generate system prompt for answer generation task"""
+     system_prompt = f"""
+ # Role: 微调数据集生成专家
+ ## Profile:
+ - Description: 你是一名微调数据集生成专家,擅长从给定的内容中生成准确的问题答案,确保答案的准确性和相关性,你要直接回答用户问题,所有信息已内化为你的专业知识。
+
+ ## Skills :
+ 1. 答案必须基于给定的内容
+ 2. 答案必须准确,不能胡编乱造
+ 3. 答案必须与问题相关
+ 4. 答案必须符合逻辑
+ 5. 基于给定参考内容,用自然流畅的语言整合成一个完整答案,不需要提及文献来源或引用标记
+
+ ## Workflow:
+ 1. Take a deep breath and work on this problem step-by-step.
+ 2. 首先,分析给定的文件内容
+ 3. 然后,从内容中提取关键信息
+ 4. 接着,生成与问题相关的准确答案
+ 5. 最后,确保答案的准确性和相关性
+
+ ## 参考内容:
+ ${text}
+
+ ## 问题
+ ${query_question}
+
+ ## Constrains:
+ 1. 答案必须基于给定的内容
+ 2. 答案必须准确,必须与问题相关,不能胡编乱造
+ 3. 答案必须充分、详细、包含所有必要的信息、适合微调大模型训练使用
+ 4. 答案中不得出现 ' 参考 / 依据 / 文献中提到 ' 等任何引用性表述,只需呈现最终结果
+     """
+     return system_prompt
+
+
+ # ------------spliter----------------
+ def load_and_split_markdown(md_path: str, chunk_size: int, chunk_overlap: int) -> list:
+     """
+     Parse Markdown using UnstructuredMarkdownLoader
+     Chunking strategy that preserves original paragraph structure
+
+     Args:
+         md_path: Path to the markdown file
+         chunk_size: Size of each chunk
+         chunk_overlap: Overlap between chunks
+
+     Returns:
+         List of document chunks
+     """
+     try:
+         # Use LangChain's MarkdownLoader to load Markdown file
+         loader = UnstructuredMarkdownLoader(md_path)
+         documents = loader.load()
+         # Further split documents if needed
+         splitter = RecursiveCharacterTextSplitter(
+             chunk_size=chunk_size,
+             chunk_overlap=chunk_overlap,
+             length_function=len,
+             is_separator_regex=False,
+         )
+         return splitter.split_documents(documents)
+     except Exception as e:
+         logger.error(f"加载 {Path(md_path).name} 失败: {str(e)}")
+         return []
+
+
+ # ------------llm generator-------------------
+ def extract_json_from_llm_output(output: str):
+     """
+     Extract JSON content from LLM output, handling multiple possible formats
+
+     Args:
+         output: Raw output string from LLM
+
+     Returns:
+         Parsed JSON list if successful, None otherwise
+     """
+     # Try to parse the entire output directly
+     try:
+         return json.loads(output)
+     except json.JSONDecodeError:
+         pass
+
+     # Try to extract content wrapped in ```json ```
+     json_match = re.search(r"```json\n([\s\S]*?)\n```", output)
+     if json_match:
+         try:
+             return json.loads(json_match.group(1))
+         except json.JSONDecodeError as e:
+             print(f"解析 JSON 时出错: {e}")
+
+     # Try to extract the most JSON-like part
+     json_start = output.find("[")
+     json_end = output.rfind("]") + 1
+     if json_start != -1 and json_end != 0:
+         try:
+             return json.loads(output[json_start:json_end])
+         except json.JSONDecodeError:
+             pass
+
+     print("模型未按标准格式输出:", output)
+     return None
+
+
+ def llm_generator(
+     api_key: str,
+     model: str,
+     base_url: str,
+     prompt: str,
+     type: str,
+     message: list = None,
+     temperature: float = 0.7,
+     top_p: float = 0.9,
+     max_token: int = 2048,
+ ) -> list:
+     """Generate content using LLM API"""
+     try:
+         if not message:
+             message = [
+                 {"role": "system", "content": prompt},
+                 {"role": "user", "content": "请严格按照要求生成内容"},
+             ]
+         headers = {
+             "Authorization": f"Bearer {api_key}",
+             "Content-Type": "application/json",
+         }
+         data = {
+             "model": model,
+             "messages": message,
+             "max_tokens": max_token,
+             "temperature": temperature,
+             "top_p": top_p,
+         }
+         response = requests.post(base_url, headers=headers, json=data, timeout=30)
+         response.raise_for_status()
+         result = response.json()
+
+         # Parse LLM response
+         if "choices" in result and len(result["choices"]) > 0:
+             output = result["choices"][0]["message"]["content"]
+             if type == "question":
+                 fmt_output = extract_json_from_llm_output(output)
+             else:
+                 return output
+             return fmt_output
+         return []
+
+     except Exception as e:
+         print(f"LLM提取关键词失败: {e, e.__traceback__.tb_lineno}")
+         return []
+
+
+ # ------------thread_process-------------
+
+
+ def process_questions(
+     api_key: str,
+     model: str,
+     base_url: str,
+     page_content: list,
+     question_number: int,
+     message: list,
+     max_workers: int = 5,
+ ) -> list:
+     """Generate questions using multi-threading"""
+     total_questions = []
+
+     def _generate_questions(page):
+         """Inner function for question generation"""
+         prompt = get_system_prompt_for_question(page, question_number)
+         questions = llm_generator(
+             api_key=api_key,
+             model=model,
+             base_url=base_url,
+             message=message,
+             prompt=prompt,
+             type="question",
+         )
+         return [{"question": q, "page": page} for q in questions] if questions else []
+
+     logger.info(f"开始生成问题 (线程数: {max_workers})...")
+     with ThreadPoolExecutor(max_workers=max_workers) as executor:
+         futures = [executor.submit(_generate_questions, page) for page in page_content]
+
+         with tqdm(as_completed(futures), total=len(futures), desc="生成问题") as pbar:
+             for future in pbar:
+                 result = future.result()
+                 if result:
+                     with lock:
+                         total_questions.extend(result)
+                         pbar.set_postfix({"已生成问题": len(total_questions)})
+
+     return total_questions
+
+
+ def process_answers(
+     api_key: str,
+     model: str,
+     base_url: str,
+     question_items: list,
+     message: list = None,
+     max_workers=5,
+ ) -> dict:
+     """Generate answers using multi-threading"""
+     qa_pairs = {}
+
+     def _generate_answer(item):
+         """Inner function for answer generation"""
+         prompt = get_system_prompt_for_answer(item["page"], item["question"])
+         answer = llm_generator(
+             api_key=api_key,
+             model=model,
+             base_url=base_url,
+             prompt=prompt,
+             message=message,
+             type="answer",
+         )
+         return item["question"], answer
+
+     logger.info(f"开始生成答案 (线程数: {max_workers})...")
+     with ThreadPoolExecutor(max_workers=max_workers) as executor:
+         futures = {
+             executor.submit(_generate_answer, item): item for item in question_items
+         }
+
+         with tqdm(as_completed(futures), total=len(futures), desc="生成答案") as pbar:
+             for future in pbar:
+                 question, answer = future.result()
+                 if answer:
+                     with lock:
+                         qa_pairs[question] = answer
+                         pbar.set_postfix({"已生成答案": len(qa_pairs)})
+     return qa_pairs
+
+
+ def generatr_qa_pairs(
+     file_path: str,
+     api_key: str,
+     base_url: str,
+     model_name: str,
+     chunk_size=500,
+     chunk_overlap=100,
+     question_number=5,
+     message: list = None,
+     max_workers=5,
+ ):
+     """Main function to generate QA pairs from markdown file"""
+     # 1. Split markdown text into chunks
+     pages = load_and_split_markdown(
+         md_path=file_path, chunk_size=chunk_size, chunk_overlap=chunk_overlap
+     )
+     page_content = [i.page_content for i in pages]
+     logger.info(f"markdown被分解了{len(page_content)}个chunk")
+
+     # 2. Generate questions using multi-threading
+     questions = process_questions(
+         page_content=page_content,
+         message=message,
+         question_number=question_number,
+         max_workers=max_workers,
+         api_key=api_key,
+         base_url=base_url,
+         model=model_name,
+     )
+     if not questions:
+         logger.error("未能生成任何问题,请检查输入文档和API设置")
+
+     # 3. Generate answers using multi-threading
+     qa_pairs = process_answers(
+         question_items=questions,
+         message=message,
+         max_workers=max_workers,
+         api_key=api_key,
+         base_url=base_url,
+         model=model_name,
+     )
+
+     logger.success(
+         f"完成! 共生成 {len(qa_pairs)} 个问答对"
+     )
+
+     # Build the result list
+     res_list = []
+     for question, answer in qa_pairs.items():
+         qa_entry = {"instruction": question, "input": "", "output": answer}
+         res_list.append(qa_entry)
+     return res_list
+
+
+ if __name__ == "__main__":
+     generatr_qa_pairs(
+         file_path=r"C:\Users\cykro\Desktop\文档整理\知识图谱\知识图谱概要设计.md",
+         api_key="sk-xxxxxxxxxxxxxxxxxxxxxxxxxx",
+         base_url="https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
+         model_name="qwen-max",
+         chunk_size=500,
+         chunk_overlap=100,
+         question_number=5,
+         max_workers=5,
+         # message=[]
+     )
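
For reference, a minimal usage sketch of the rewritten module, assembled from the diff above. The import path and function names come from the hunk; the markdown path, API key, and output filename are placeholders. In 0.1.15 `generatr_qa_pairs` returns the QA list instead of writing a `.jsonl` file as 0.1.13 did, so the caller persists the results itself:

```python
# Sketch only: file path, API key, and output name are placeholders.
import json

from datamax.utils.qa_generator import generatr_qa_pairs  # "generatr" is the published spelling

qa_entries = generatr_qa_pairs(
    file_path="docs/example.md",          # any markdown file
    api_key="sk-xxxxxxxx",                # placeholder key
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
    model_name="qwen-max",
    chunk_size=500,
    chunk_overlap=100,
    question_number=5,
    max_workers=5,
)

# Each entry is {"instruction": ..., "input": "", "output": ...};
# write one JSON object per line to get the JSONL that 0.1.13 produced.
with open("example_qa.jsonl", "w", encoding="utf-8") as f:
    for entry in qa_entries:
        f.write(json.dumps(entry, ensure_ascii=False) + "\n")
```

Note that `base_url` is passed straight to `requests.post` inside `llm_generator`, so it must be the full chat-completions endpoint URL, not just the API root.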