pydatamax 0.1.14__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registries, and is provided for informational purposes only.
- datamax/__init__.py +1 -1
- datamax/loader/core.py +118 -118
- datamax/loader/minio_handler.py +171 -171
- datamax/loader/oss_handler.py +191 -191
- datamax/parser/__init__.py +2 -4
- datamax/parser/base.py +76 -76
- datamax/parser/core.py +406 -288
- datamax/parser/csv_parser.py +31 -10
- datamax/parser/doc_parser.py +466 -10
- datamax/parser/docx_parser.py +449 -11
- datamax/parser/epub_parser.py +41 -41
- datamax/parser/html_parser.py +37 -37
- datamax/parser/image_parser.py +34 -34
- datamax/parser/json_parser.py +32 -10
- datamax/parser/md_parser.py +72 -72
- datamax/parser/pdf_parser.py +101 -101
- datamax/parser/ppt_parser.py +70 -20
- datamax/parser/pptx_parser.py +45 -45
- datamax/parser/txt_parser.py +45 -45
- datamax/parser/xls_parser.py +26 -26
- datamax/parser/xlsx_parser.py +212 -215
- datamax/utils/__init__.py +23 -2
- datamax/utils/constants.py +58 -58
- datamax/utils/data_cleaner.py +275 -237
- datamax/utils/env_setup.py +79 -79
- datamax/utils/gotocr_pdf.py +265 -265
- datamax/utils/mineru_operator.py +62 -62
- datamax/utils/paddleocr_pdf_operator.py +90 -90
- datamax/utils/ppt_extract.py +140 -140
- datamax/utils/qa_generator.py +369 -376
- datamax/utils/tokenizer.py +21 -21
- datamax/utils/uno_handler.py +426 -0
- {pydatamax-0.1.14.dist-info → pydatamax-0.1.15.dist-info}/METADATA +117 -5
- pydatamax-0.1.15.dist-info/RECORD +38 -0
- {pydatamax-0.1.14.dist-info → pydatamax-0.1.15.dist-info}/licenses/LICENSE +21 -21
- {pydatamax-0.1.14.dist-info → pydatamax-0.1.15.dist-info}/top_level.txt +0 -1
- pydatamax-0.1.14.dist-info/RECORD +0 -39
- tests/__init__.py +0 -0
- tests/test_basic.py +0 -20
- {pydatamax-0.1.14.dist-info → pydatamax-0.1.15.dist-info}/WHEEL +0 -0
datamax/utils/qa_generator.py
CHANGED
@@ -1,376 +1,369 @@
 import json
 import os.path
 import re
 import threading
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path

 import requests
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.document_loaders import UnstructuredMarkdownLoader
 from loguru import logger
 from pyexpat.errors import messages
 from tqdm import tqdm  # For progress bar display

 lock = threading.Lock()


 # ------------prompt-----------------
 def get_system_prompt_for_question(query_text, question_number):
     """Generate system prompt for question generation task"""
     system_prompt = f"""
         # 角色使命
         你是一位专业的文本分析专家,擅长从复杂文本中提取关键信息并生成可用于模型微调的结构化数据(仅生成问题)。

         ## 核心任务
         根据用户提供的文本,生成不少于 ${question_number} 个高质量问题。

         ## 约束条件(重要!)
         - 必须基于文本内容直接生成
         - 问题应具有明确答案指向性
         - 需覆盖文本的不同方面
         - 禁止生成假设性、重复或相似问题
         - 确保生成得完整性

         ## 处理流程
         1. 【文本解析】分段处理内容,识别关键实体和核心概念
         2. 【问题生成】基于信息密度选择最佳提问点
         3. 【质量检查】确保:
            - 问题答案可在原文中找到依据
            - 标签与问题内容强相关
            - 无格式错误

         ## 输出格式
         - JSON 数组格式必须正确
         - 字段名使用英文双引号
         - 输出的 JSON 数组必须严格符合以下结构:
         \`\`\`json
         ["问题1", "问题2", "..."]
         \`\`\`

         ## 输出示例
         \`\`\`json
         [ "人工智能伦理框架应包含哪些核心要素?","民法典对个人数据保护有哪些新规定?"]
         \`\`\`

         ## 待处理文本
         ${query_text}

         ## 限制
         - 必须按照规定的 JSON 格式输出,不要输出任何其他不相关内容
         - 生成不少于${question_number}个高质量问题
         - 问题不要和材料本身相关,例如禁止出现作者、章节、目录等相关问题
         - 问题不得包含【报告、文章、文献、表格】中提到的这种话术,必须是一个自然的问题
     """
     return system_prompt


 def get_system_prompt_for_answer(text, query_question):
     """Generate system prompt for answer generation task"""
     system_prompt = f"""
         # Role: 微调数据集生成专家
         ## Profile:
         - Description: 你是一名微调数据集生成专家,擅长从给定的内容中生成准确的问题答案,确保答案的准确性和相关性,你要直接回答用户问题,所有信息已内化为你的专业知识。

         ## Skills :
         1. 答案必须基于给定的内容
         2. 答案必须准确,不能胡编乱造
         3. 答案必须与问题相关
         4. 答案必须符合逻辑
         5. 基于给定参考内容,用自然流畅的语言整合成一个完整答案,不需要提及文献来源或引用标记

         ## Workflow:
         1. Take a deep breath and work on this problem step-by-step.
         2. 首先,分析给定的文件内容
         3. 然后,从内容中提取关键信息
         4. 接着,生成与问题相关的准确答案
         5. 最后,确保答案的准确性和相关性

         ## 参考内容:
         ${text}

         ## 问题
         ${query_question}

         ## Constrains:
         1. 答案必须基于给定的内容
         2. 答案必须准确,必须与问题相关,不能胡编乱造
         3. 答案必须充分、详细、包含所有必要的信息、适合微调大模型训练使用
         4. 答案中不得出现 ' 参考 / 依据 / 文献中提到 ' 等任何引用性表述,只需呈现最终结果
     """
     return system_prompt


 # ------------spliter----------------
 def load_and_split_markdown(md_path: str, chunk_size: int, chunk_overlap: int) -> list:
     """
     Parse Markdown using UnstructuredMarkdownLoader
     Chunking strategy that preserves original paragraph structure

     Args:
         md_path: Path to the markdown file
         chunk_size: Size of each chunk
         chunk_overlap: Overlap between chunks

     Returns:
         List of document chunks
     """
     try:
         # Use LangChain's MarkdownLoader to load Markdown file
         loader = UnstructuredMarkdownLoader(md_path)
         documents = loader.load()
         # Further split documents if needed
         splitter = RecursiveCharacterTextSplitter(
             chunk_size=chunk_size,
             chunk_overlap=chunk_overlap,
             length_function=len,
             is_separator_regex=False,
         )
         return splitter.split_documents(documents)
     except Exception as e:
         logger.error(f"加载 {Path(md_path).name} 失败: {str(e)}")
         return []


 # ------------llm generator-------------------
 def extract_json_from_llm_output(output: str):
     """
     Extract JSON content from LLM output, handling multiple possible formats

     Args:
         output: Raw output string from LLM

     Returns:
         Parsed JSON list if successful, None otherwise
     """
     # Try to parse the entire output directly
     try:
         return json.loads(output)
     except json.JSONDecodeError:
         pass

     # Try to extract content wrapped in ```json ```
     json_match = re.search(r"```json\n([\s\S]*?)\n```", output)
     if json_match:
         try:
             return json.loads(json_match.group(1))
         except json.JSONDecodeError as e:
             print(f"解析 JSON 时出错: {e}")

     # Try to extract the most JSON-like part
     json_start = output.find("[")
     json_end = output.rfind("]") + 1
     if json_start != -1 and json_end != 0:
         try:
             return json.loads(output[json_start:json_end])
         except json.JSONDecodeError:
             pass

     print("模型未按标准格式输出:", output)
     return None


 def llm_generator(
     api_key: str,
     model: str,
     base_url: str,
     prompt: str,
     type: str,
     message: list = None,
     temperature: float = 0.7,
     top_p: float = 0.9,
     max_token: int = 2048,
 ) -> list:
     """Generate content using LLM API"""
     try:
         if not message:
             message = [
                 {"role": "system", "content": prompt},
                 {"role": "user", "content": "请严格按照要求生成内容"},
             ]
         headers = {
             "Authorization": f"Bearer {api_key}",
             "Content-Type": "application/json",
         }
         data = {
             "model": model,
             "messages": message,
             "max_tokens": max_token,
             "temperature": temperature,
             "top_p": top_p,
         }
         response = requests.post(base_url, headers=headers, json=data, timeout=30)
         response.raise_for_status()
         result = response.json()

         # Parse LLM response
         if "choices" in result and len(result["choices"]) > 0:
             output = result["choices"][0]["message"]["content"]
             if type == "question":
                 fmt_output = extract_json_from_llm_output(output)
             else:
                 return output
             return fmt_output
         return []

     except Exception as e:
         print(f"LLM提取关键词失败: {e, e.__traceback__.tb_lineno}")
         return []


 # ------------thread_process-------------


 def process_questions(
     api_key: str,
     model: str,
     base_url: str,
     page_content: list,
     question_number: int,
     message: list,
     max_workers: int = 5,
 ) -> list:
     """Generate questions using multi-threading"""
     total_questions = []

     def _generate_questions(page):
         """Inner function for question generation"""
         prompt = get_system_prompt_for_question(page, question_number)
         questions = llm_generator(
             api_key=api_key,
             model=model,
             base_url=base_url,
             message=message,
             prompt=prompt,
             type="question",
         )
         return [{"question": q, "page": page} for q in questions] if questions else []

     logger.info(f"开始生成问题 (线程数: {max_workers})...")
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
         futures = [executor.submit(_generate_questions, page) for page in page_content]

         with tqdm(as_completed(futures), total=len(futures), desc="生成问题") as pbar:
             for future in pbar:
                 result = future.result()
                 if result:
                     with lock:
                         total_questions.extend(result)
                         pbar.set_postfix({"已生成问题": len(total_questions)})

     return total_questions


 def process_answers(
     api_key: str,
     model: str,
     base_url: str,
     question_items: list,
     message: list = None,
     max_workers=5,
 ) -> dict:
     """Generate answers using multi-threading"""
     qa_pairs = {}

     def _generate_answer(item):
         """Inner function for answer generation"""
         prompt = get_system_prompt_for_answer(item["page"], item["question"])
         answer = llm_generator(
             api_key=api_key,
             model=model,
             base_url=base_url,
             prompt=prompt,
             message=message,
             type="answer",
         )
         return item["question"], answer

     logger.info(f"开始生成答案 (线程数: {max_workers})...")
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
         futures = {
             executor.submit(_generate_answer, item): item for item in question_items
         }

         with tqdm(as_completed(futures), total=len(futures), desc="生成答案") as pbar:
             for future in pbar:
                 question, answer = future.result()
                 if answer:
                     with lock:
                         qa_pairs[question] = answer
                         pbar.set_postfix({"已生成答案": len(qa_pairs)})
     return qa_pairs


 def generatr_qa_pairs(
     file_path: str,
     api_key: str,
     base_url: str,
     model_name: str,
     chunk_size=500,
     chunk_overlap=100,
     question_number=5,
     message: list = None,
     max_workers=5,
 ):
     """Main function to generate QA pairs from markdown file"""
     # 1. Split markdown text into chunks
     pages = load_and_split_markdown(
         md_path=file_path, chunk_size=chunk_size, chunk_overlap=chunk_overlap
     )
     page_content = [i.page_content for i in pages]
     logger.info(f"markdown被分解了{len(page_content)}个chunk")

     # 2. Generate questions using multi-threading
     questions = process_questions(
         page_content=page_content,
         message=message,
         question_number=question_number,
         max_workers=max_workers,
         api_key=api_key,
         base_url=base_url,
         model=model_name,
     )
     if not questions:
         logger.error("未能生成任何问题,请检查输入文档和API设置")

     # 3. Generate answers using multi-threading
     qa_pairs = process_answers(
         question_items=questions,
         message=message,
         max_workers=max_workers,
         api_key=api_key,
         base_url=base_url,
         model=model_name,
     )
-
-    … (lines 346-369 of the 0.1.14 file were not captured in this extraction)
-        model_name="qwen-max",
-        chunk_size=500,
-        chunk_overlap=100,
-        question_number=5,
-        max_workers=5,
-        # message=[]
-    )
+
+    logger.success(
+        f"完成! 共生成 {len(qa_pairs)} 个问答对"
+    )
+
+    #
+    res_list = []
+    for question, answer in qa_pairs.items():
+        qa_entry = {"instruction": question, "input": "", "output": answer}
+        res_list.append(qa_entry)
+    return res_list
+
+
+if __name__ == "__main__":
+    generatr_qa_pairs(
+        file_path=r"C:\Users\cykro\Desktop\文档整理\知识图谱\知识图谱概要设计.md",
+        api_key="sk-xxxxxxxxxxxxxxxxxxxxxxxxxx",
+        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
+        model_name="qwen-max",
+        chunk_size=500,
+        chunk_overlap=100,
+        question_number=5,
+        max_workers=5,
+        # message=[]
+    )
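The visible behavioral change in this file is at the tail: in 0.1.15, generatr_qa_pairs converts the qa_pairs dict into Alpaca-style records ({"instruction", "input", "output"}) and returns them as a list, whereas the captured 0.1.14 lines ended without an explicit return. A minimal consumption sketch under that reading; the input path, API key, output filename, and the json.dump serialization step are illustrative assumptions, not part of the package:

import json

from datamax.utils.qa_generator import generatr_qa_pairs

# Hypothetical invocation: the path, key, and endpoint below are placeholders.
qa_records = generatr_qa_pairs(
    file_path="docs/overview.md",
    api_key="sk-...",
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions",
    model_name="qwen-max",
    chunk_size=500,
    chunk_overlap=100,
    question_number=5,
    max_workers=5,
)

# As of 0.1.15, each record is {"instruction": question, "input": "", "output": answer},
# so the result can be dumped directly as an Alpaca-style fine-tuning dataset.
with open("qa_dataset.json", "w", encoding="utf-8") as f:
    json.dump(qa_records, f, ensure_ascii=False, indent=2)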