magic-pdf 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- magic_pdf/libs/boxbase.py +5 -2
- magic_pdf/libs/draw_bbox.py +14 -2
- magic_pdf/libs/language.py +9 -0
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/batch_analyze.py +103 -99
- magic_pdf/model/doc_analyze_by_custom_model.py +77 -18
- magic_pdf/model/pdf_extract_kit.py +23 -21
- magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py +7 -3
- magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py +1 -1
- magic_pdf/model/sub_modules/model_init.py +4 -3
- magic_pdf/model/sub_modules/ocr/paddleocr/ocr_utils.py +33 -26
- magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py +25 -6
- magic_pdf/pdf_parse_union_core_v2.py +137 -32
- magic_pdf/post_proc/llm_aided.py +59 -26
- magic_pdf/post_proc/llm_aided_ocr.py +689 -0
- magic_pdf/pre_proc/ocr_span_list_modify.py +1 -1
- magic_pdf/resources/model_config/model_configs.yaml +2 -2
- {magic_pdf-1.0.0.dist-info → magic_pdf-1.1.0.dist-info}/METADATA +50 -41
- {magic_pdf-1.0.0.dist-info → magic_pdf-1.1.0.dist-info}/RECORD +23 -22
- {magic_pdf-1.0.0.dist-info → magic_pdf-1.1.0.dist-info}/WHEEL +1 -1
- {magic_pdf-1.0.0.dist-info → magic_pdf-1.1.0.dist-info}/LICENSE.md +0 -0
- {magic_pdf-1.0.0.dist-info → magic_pdf-1.1.0.dist-info}/entry_points.txt +0 -0
- {magic_pdf-1.0.0.dist-info → magic_pdf-1.1.0.dist-info}/top_level.txt +0 -0
magic_pdf/post_proc/llm_aided_ocr.py (new file)
@@ -0,0 +1,689 @@
import os
import glob
import traceback
import asyncio
import json
import re
import urllib.request
import logging
from concurrent.futures import ThreadPoolExecutor
import warnings
from typing import List, Dict, Tuple, Optional
from pdf2image import convert_from_path
import pytesseract
from llama_cpp import Llama, LlamaGrammar
import tiktoken
import numpy as np
from PIL import Image
from decouple import Config as DecoupleConfig, RepositoryEnv
import cv2
from filelock import FileLock, Timeout
from transformers import AutoTokenizer
from openai import AsyncOpenAI
from anthropic import AsyncAnthropic
try:
    import nvgpu
    GPU_AVAILABLE = True
except ImportError:
    GPU_AVAILABLE = False

# Configuration
config = DecoupleConfig(RepositoryEnv('.env'))

USE_LOCAL_LLM = config.get("USE_LOCAL_LLM", default=False, cast=bool)
API_PROVIDER = config.get("API_PROVIDER", default="OPENAI", cast=str) # OPENAI or CLAUDE
ANTHROPIC_API_KEY = config.get("ANTHROPIC_API_KEY", default="your-anthropic-api-key", cast=str)
OPENAI_API_KEY = config.get("OPENAI_API_KEY", default="your-openai-api-key", cast=str)
CLAUDE_MODEL_STRING = config.get("CLAUDE_MODEL_STRING", default="claude-3-haiku-20240307", cast=str)
CLAUDE_MAX_TOKENS = 4096 # Maximum allowed tokens for Claude API
TOKEN_BUFFER = 500 # Buffer to account for token estimation inaccuracies
TOKEN_CUSHION = 300 # Don't use the full max tokens to avoid hitting the limit
OPENAI_COMPLETION_MODEL = config.get("OPENAI_COMPLETION_MODEL", default="gpt-4o-mini", cast=str)
OPENAI_EMBEDDING_MODEL = config.get("OPENAI_EMBEDDING_MODEL", default="text-embedding-3-small", cast=str)
OPENAI_MAX_TOKENS = 12000 # Maximum allowed tokens for OpenAI API
DEFAULT_LOCAL_MODEL_NAME = "Llama-3.1-8B-Lexi-Uncensored_Q5_fixedrope.gguf"
LOCAL_LLM_CONTEXT_SIZE_IN_TOKENS = 2048
USE_VERBOSE = False

openai_client = AsyncOpenAI(api_key=OPENAI_API_KEY)
warnings.filterwarnings("ignore", category=FutureWarning)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# GPU Check
def is_gpu_available():
    if not GPU_AVAILABLE:
        logging.warning("GPU support not available: nvgpu module not found")
        return {"gpu_found": False, "num_gpus": 0, "first_gpu_vram": 0, "total_vram": 0, "error": "nvgpu module not found"}
    try:
        gpu_info = nvgpu.gpu_info()
        num_gpus = len(gpu_info)
        if num_gpus == 0:
            logging.warning("No GPUs found on the system")
            return {"gpu_found": False, "num_gpus": 0, "first_gpu_vram": 0, "total_vram": 0}
        first_gpu_vram = gpu_info[0]['mem_total']
        total_vram = sum(gpu['mem_total'] for gpu in gpu_info)
        logging.info(f"GPU(s) found: {num_gpus}, Total VRAM: {total_vram} MB")
        return {"gpu_found": True, "num_gpus": num_gpus, "first_gpu_vram": first_gpu_vram, "total_vram": total_vram, "gpu_info": gpu_info}
    except Exception as e:
        logging.error(f"Error checking GPU availability: {e}")
        return {"gpu_found": False, "num_gpus": 0, "first_gpu_vram": 0, "total_vram": 0, "error": str(e)}

# Model Download
async def download_models() -> Tuple[List[str], List[Dict[str, str]]]:
    download_status = []
    model_url = "https://huggingface.co/Orenguteng/Llama-3.1-8B-Lexi-Uncensored-GGUF/resolve/main/Llama-3.1-8B-Lexi-Uncensored_Q5_fixedrope.gguf"
    model_name = os.path.basename(model_url)
    current_file_path = os.path.abspath(__file__)
    base_dir = os.path.dirname(current_file_path)
    models_dir = os.path.join(base_dir, 'models')

    os.makedirs(models_dir, exist_ok=True)
    lock = FileLock(os.path.join(models_dir, "download.lock"))
    status = {"url": model_url, "status": "success", "message": "File already exists."}
    filename = os.path.join(models_dir, model_name)

    try:
        with lock.acquire(timeout=1200):
            if not os.path.exists(filename):
                logging.info(f"Downloading model {model_name} from {model_url}...")
                urllib.request.urlretrieve(model_url, filename)
                file_size = os.path.getsize(filename) / (1024 * 1024)
                if file_size < 100:
                    os.remove(filename)
                    status["status"] = "failure"
                    status["message"] = f"Downloaded file is too small ({file_size:.2f} MB), probably not a valid model file."
                    logging.error(f"Error: {status['message']}")
                else:
                    logging.info(f"Successfully downloaded: {filename} (Size: {file_size:.2f} MB)")
            else:
                logging.info(f"Model file already exists: {filename}")
    except Timeout:
        logging.error(f"Error: Could not acquire lock for downloading {model_name}")
        status["status"] = "failure"
        status["message"] = "Could not acquire lock for downloading."

    download_status.append(status)
    logging.info("Model download process completed.")
    return [model_name], download_status

# Model Loading
def load_model(llm_model_name: str, raise_exception: bool = True):
    global USE_VERBOSE
    try:
        current_file_path = os.path.abspath(__file__)
        base_dir = os.path.dirname(current_file_path)
        models_dir = os.path.join(base_dir, 'models')
        matching_files = glob.glob(os.path.join(models_dir, f"{llm_model_name}*"))
        if not matching_files:
            logging.error(f"Error: No model file found matching: {llm_model_name}")
            raise FileNotFoundError
        model_file_path = max(matching_files, key=os.path.getmtime)
        logging.info(f"Loading model: {model_file_path}")
        try:
            logging.info("Attempting to load model with GPU acceleration...")
            model_instance = Llama(
                model_path=model_file_path,
                n_ctx=LOCAL_LLM_CONTEXT_SIZE_IN_TOKENS,
                verbose=USE_VERBOSE,
                n_gpu_layers=-1
            )
            logging.info("Model loaded successfully with GPU acceleration.")
        except Exception as gpu_e:
            logging.warning(f"Failed to load model with GPU acceleration: {gpu_e}")
            logging.info("Falling back to CPU...")
            try:
                model_instance = Llama(
                    model_path=model_file_path,
                    n_ctx=LOCAL_LLM_CONTEXT_SIZE_IN_TOKENS,
                    verbose=USE_VERBOSE,
                    n_gpu_layers=0
                )
                logging.info("Model loaded successfully with CPU.")
            except Exception as cpu_e:
                logging.error(f"Failed to load model with CPU: {cpu_e}")
                if raise_exception:
                    raise
                return None
        return model_instance
    except Exception as e:
        logging.error(f"Exception occurred while loading the model: {e}")
        traceback.print_exc()
        if raise_exception:
            raise
        return None

# API Interaction Functions
async def generate_completion(prompt: str, max_tokens: int = 5000) -> Optional[str]:
    if USE_LOCAL_LLM:
        return await generate_completion_from_local_llm(DEFAULT_LOCAL_MODEL_NAME, prompt, max_tokens)
    elif API_PROVIDER == "CLAUDE":
        return await generate_completion_from_claude(prompt, max_tokens)
    elif API_PROVIDER == "OPENAI":
        return await generate_completion_from_openai(prompt, max_tokens)
    else:
        logging.error(f"Invalid API_PROVIDER: {API_PROVIDER}")
        return None

def get_tokenizer(model_name: str):
    if model_name.lower().startswith("gpt-"):
        return tiktoken.encoding_for_model(model_name)
    elif model_name.lower().startswith("claude-"):
        return AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b", clean_up_tokenization_spaces=False)
    elif model_name.lower().startswith("llama-"):
        return AutoTokenizer.from_pretrained("huggyllama/llama-7b", clean_up_tokenization_spaces=False)
    else:
        raise ValueError(f"Unsupported model: {model_name}")

def estimate_tokens(text: str, model_name: str) -> int:
    try:
        tokenizer = get_tokenizer(model_name)
        return len(tokenizer.encode(text))
    except Exception as e:
        logging.warning(f"Error using tokenizer for {model_name}: {e}. Falling back to approximation.")
        return approximate_tokens(text)

def approximate_tokens(text: str) -> int:
    # Normalize whitespace
    text = re.sub(r'\s+', ' ', text.strip())
    # Split on whitespace and punctuation, keeping punctuation
    tokens = re.findall(r'\b\w+\b|\S', text)
    count = 0
    for token in tokens:
        if token.isdigit():
            count += max(1, len(token) // 2)  # Numbers often tokenize to multiple tokens
        elif re.match(r'^[A-Z]{2,}$', token):  # Acronyms
            count += len(token)
        elif re.search(r'[^\w\s]', token):  # Punctuation and special characters
            count += 1
        elif len(token) > 10:  # Long words often split into multiple tokens
            count += len(token) // 4 + 1
        else:
            count += 1
    # Add a 10% buffer for potential underestimation
    return int(count * 1.1)

def chunk_text(text: str, max_chunk_tokens: int, model_name: str) -> List[str]:
    chunks = []
    tokenizer = get_tokenizer(model_name)
    sentences = re.split(r'(?<=[.!?])\s+', text)
    current_chunk = []
    current_chunk_tokens = 0

    for sentence in sentences:
        sentence_tokens = len(tokenizer.encode(sentence))
        if current_chunk_tokens + sentence_tokens > max_chunk_tokens:
            chunks.append(' '.join(current_chunk))
            current_chunk = [sentence]
            current_chunk_tokens = sentence_tokens
        else:
            current_chunk.append(sentence)
            current_chunk_tokens += sentence_tokens

    if current_chunk:
        chunks.append(' '.join(current_chunk))

    adjusted_chunks = adjust_overlaps(chunks, tokenizer, max_chunk_tokens)
    return adjusted_chunks

def split_long_sentence(sentence: str, max_tokens: int, model_name: str) -> List[str]:
    words = sentence.split()
    chunks = []
    current_chunk = []
    current_chunk_tokens = 0
    tokenizer = get_tokenizer(model_name)

    for word in words:
        word_tokens = len(tokenizer.encode(word))
        if current_chunk_tokens + word_tokens > max_tokens and current_chunk:
            chunks.append(' '.join(current_chunk))
            current_chunk = [word]
            current_chunk_tokens = word_tokens
        else:
            current_chunk.append(word)
            current_chunk_tokens += word_tokens

    if current_chunk:
        chunks.append(' '.join(current_chunk))

    return chunks

def adjust_overlaps(chunks: List[str], tokenizer, max_chunk_tokens: int, overlap_size: int = 50) -> List[str]:
    adjusted_chunks = []
    for i in range(len(chunks)):
        if i == 0:
            adjusted_chunks.append(chunks[i])
        else:
            overlap_tokens = len(tokenizer.encode(' '.join(chunks[i-1].split()[-overlap_size:])))
            current_tokens = len(tokenizer.encode(chunks[i]))
            if overlap_tokens + current_tokens > max_chunk_tokens:
                overlap_adjusted = chunks[i].split()[:-overlap_size]
                adjusted_chunks.append(' '.join(overlap_adjusted))
            else:
                adjusted_chunks.append(' '.join(chunks[i-1].split()[-overlap_size:] + chunks[i].split()))

    return adjusted_chunks

async def generate_completion_from_claude(prompt: str, max_tokens: int = CLAUDE_MAX_TOKENS - TOKEN_BUFFER) -> Optional[str]:
    if not ANTHROPIC_API_KEY:
        logging.error("Anthropic API key not found. Please set the ANTHROPIC_API_KEY environment variable.")
        return None
    client = AsyncAnthropic(api_key=ANTHROPIC_API_KEY)
    prompt_tokens = estimate_tokens(prompt, CLAUDE_MODEL_STRING)
    adjusted_max_tokens = min(max_tokens, CLAUDE_MAX_TOKENS - prompt_tokens - TOKEN_BUFFER)
    if adjusted_max_tokens <= 0:
        logging.warning("Prompt is too long for Claude API. Chunking the input.")
        chunks = chunk_text(prompt, CLAUDE_MAX_TOKENS - TOKEN_CUSHION, CLAUDE_MODEL_STRING)
        results = []
        for chunk in chunks:
            try:
                async with client.messages.stream(
                    model=CLAUDE_MODEL_STRING,
                    max_tokens=CLAUDE_MAX_TOKENS // 2,
                    temperature=0.7,
                    messages=[{"role": "user", "content": chunk}],
                ) as stream:
                    message = await stream.get_final_message()
                    results.append(message.content[0].text)
                    logging.info(f"Chunk processed. Input tokens: {message.usage.input_tokens:,}, Output tokens: {message.usage.output_tokens:,}")
            except Exception as e:
                logging.error(f"An error occurred while processing a chunk: {e}")
        return " ".join(results)
    else:
        try:
            async with client.messages.stream(
                model=CLAUDE_MODEL_STRING,
                max_tokens=adjusted_max_tokens,
                temperature=0.7,
                messages=[{"role": "user", "content": prompt}],
            ) as stream:
                message = await stream.get_final_message()
                output_text = message.content[0].text
                logging.info(f"Total input tokens: {message.usage.input_tokens:,}")
                logging.info(f"Total output tokens: {message.usage.output_tokens:,}")
                logging.info(f"Generated output (abbreviated): {output_text[:150]}...")
                return output_text
        except Exception as e:
            logging.error(f"An error occurred while requesting from Claude API: {e}")
            return None

async def generate_completion_from_openai(prompt: str, max_tokens: int = 5000) -> Optional[str]:
    if not OPENAI_API_KEY:
        logging.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
        return None
    prompt_tokens = estimate_tokens(prompt, OPENAI_COMPLETION_MODEL)
    adjusted_max_tokens = min(max_tokens, 4096 - prompt_tokens - TOKEN_BUFFER)  # 4096 is typical max for GPT-3.5 and GPT-4
    if adjusted_max_tokens <= 0:
        logging.warning("Prompt is too long for OpenAI API. Chunking the input.")
        chunks = chunk_text(prompt, OPENAI_MAX_TOKENS - TOKEN_CUSHION, OPENAI_COMPLETION_MODEL)
        results = []
        for chunk in chunks:
            try:
                response = await openai_client.chat.completions.create(
                    model=OPENAI_COMPLETION_MODEL,
                    messages=[{"role": "user", "content": chunk}],
                    max_tokens=adjusted_max_tokens,
                    temperature=0.7,
                )
                result = response.choices[0].message.content
                results.append(result)
                logging.info(f"Chunk processed. Output tokens: {response.usage.completion_tokens:,}")
            except Exception as e:
                logging.error(f"An error occurred while processing a chunk: {e}")
        return " ".join(results)
    else:
        try:
            response = await openai_client.chat.completions.create(
                model=OPENAI_COMPLETION_MODEL,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=adjusted_max_tokens,
                temperature=0.7,
            )
            output_text = response.choices[0].message.content
            logging.info(f"Total tokens: {response.usage.total_tokens:,}")
            logging.info(f"Generated output (abbreviated): {output_text[:150]}...")
            return output_text
        except Exception as e:
            logging.error(f"An error occurred while requesting from OpenAI API: {e}")
            return None

async def generate_completion_from_local_llm(llm_model_name: str, input_prompt: str, number_of_tokens_to_generate: int = 100, temperature: float = 0.7, grammar_file_string: str = None):
    logging.info(f"Starting text completion using model: '{llm_model_name}' for input prompt: '{input_prompt}'")
    llm = load_model(llm_model_name)
    prompt_tokens = estimate_tokens(input_prompt, llm_model_name)
    adjusted_max_tokens = min(number_of_tokens_to_generate, LOCAL_LLM_CONTEXT_SIZE_IN_TOKENS - prompt_tokens - TOKEN_BUFFER)
    if adjusted_max_tokens <= 0:
        logging.warning("Prompt is too long for LLM. Chunking the input.")
        chunks = chunk_text(input_prompt, LOCAL_LLM_CONTEXT_SIZE_IN_TOKENS - TOKEN_CUSHION, llm_model_name)
        results = []
        for chunk in chunks:
            try:
                output = llm(
                    prompt=chunk,
                    max_tokens=LOCAL_LLM_CONTEXT_SIZE_IN_TOKENS - TOKEN_CUSHION,
                    temperature=temperature,
                )
                results.append(output['choices'][0]['text'])
                logging.info(f"Chunk processed. Output tokens: {output['usage']['completion_tokens']:,}")
            except Exception as e:
                logging.error(f"An error occurred while processing a chunk: {e}")
        return " ".join(results)
    else:
        grammar_file_string_lower = grammar_file_string.lower() if grammar_file_string else ""
        if grammar_file_string_lower:
            list_of_grammar_files = glob.glob("./grammar_files/*.gbnf")
            matching_grammar_files = [x for x in list_of_grammar_files if grammar_file_string_lower in os.path.splitext(os.path.basename(x).lower())[0]]
            if len(matching_grammar_files) == 0:
                logging.error(f"No grammar file found matching: {grammar_file_string}")
                raise FileNotFoundError
            grammar_file_path = max(matching_grammar_files, key=os.path.getmtime)
            logging.info(f"Loading selected grammar file: '{grammar_file_path}'")
            llama_grammar = LlamaGrammar.from_file(grammar_file_path)
            output = llm(
                prompt=input_prompt,
                max_tokens=adjusted_max_tokens,
                temperature=temperature,
                grammar=llama_grammar
            )
        else:
            output = llm(
                prompt=input_prompt,
                max_tokens=adjusted_max_tokens,
                temperature=temperature
            )
        generated_text = output['choices'][0]['text']
        if grammar_file_string == 'json':
            generated_text = generated_text.encode('unicode_escape').decode()
        finish_reason = str(output['choices'][0]['finish_reason'])
        llm_model_usage_json = json.dumps(output['usage'])
        logging.info(f"Completed text completion in {output['usage']['total_time']:.2f} seconds. Beginning of generated text: \n'{generated_text[:150]}'...")
        return {
            "generated_text": generated_text,
            "finish_reason": finish_reason,
            "llm_model_usage_json": llm_model_usage_json
        }

# Image Processing Functions
def preprocess_image(image):
    gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
    gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
    kernel = np.ones((1, 1), np.uint8)
    gray = cv2.dilate(gray, kernel, iterations=1)
    return Image.fromarray(gray)

def convert_pdf_to_images(input_pdf_file_path: str, max_pages: int = 0, skip_first_n_pages: int = 0) -> List[Image.Image]:
    logging.info(f"Processing PDF file {input_pdf_file_path}")
    if max_pages == 0:
        last_page = None
        logging.info("Converting all pages to images...")
    else:
        last_page = skip_first_n_pages + max_pages
        logging.info(f"Converting pages {skip_first_n_pages + 1} to {last_page}")
    first_page = skip_first_n_pages + 1  # pdf2image uses 1-based indexing
    images = convert_from_path(input_pdf_file_path, first_page=first_page, last_page=last_page)
    logging.info(f"Converted {len(images)} pages from PDF file to images.")
    return images

def ocr_image(image):
    preprocessed_image = preprocess_image(image)
    return pytesseract.image_to_string(preprocessed_image)

async def process_chunk(chunk: str, prev_context: str, chunk_index: int, total_chunks: int, reformat_as_markdown: bool, suppress_headers_and_page_numbers: bool) -> Tuple[str, str]:
    logging.info(f"Processing chunk {chunk_index + 1}/{total_chunks} (length: {len(chunk):,} characters)")

    # Step 1: OCR Correction
    ocr_correction_prompt = f"""Correct OCR-induced errors in the text, ensuring it flows coherently with the previous context. Follow these guidelines:

1. Fix OCR-induced typos and errors:
   - Correct words split across line breaks
   - Fix common OCR errors (e.g., 'rn' misread as 'm')
   - Use context and common sense to correct errors
   - Only fix clear errors, don't alter the content unnecessarily
   - Do not add extra periods or any unnecessary punctuation

2. Maintain original structure:
   - Keep all headings and subheadings intact

3. Preserve original content:
   - Keep all important information from the original text
   - Do not add any new information not present in the original text
   - Remove unnecessary line breaks within sentences or paragraphs
   - Maintain paragraph breaks

4. Maintain coherence:
   - Ensure the content connects smoothly with the previous context
   - Handle text that starts or ends mid-sentence appropriately

IMPORTANT: Respond ONLY with the corrected text. Preserve all original formatting, including line breaks. Do not include any introduction, explanation, or metadata.

Previous context:
{prev_context[-500:]}

Current chunk to process:
{chunk}

Corrected text:
"""

    ocr_corrected_chunk = await generate_completion(ocr_correction_prompt, max_tokens=len(chunk) + 500)

    processed_chunk = ocr_corrected_chunk

    # Step 2: Markdown Formatting (if requested)
    if reformat_as_markdown:
        markdown_prompt = f"""Reformat the following text as markdown, improving readability while preserving the original structure. Follow these guidelines:
1. Preserve all original headings, converting them to appropriate markdown heading levels (# for main titles, ## for subtitles, etc.)
   - Ensure each heading is on its own line
   - Add a blank line before and after each heading
2. Maintain the original paragraph structure. Remove all breaks within a word that should be a single word (for example, "cor- rect" should be "correct")
3. Format lists properly (unordered or ordered) if they exist in the original text
4. Use emphasis (*italic*) and strong emphasis (**bold**) where appropriate, based on the original formatting
5. Preserve all original content and meaning
6. Do not add any extra punctuation or modify the existing punctuation
7. Remove any spuriously inserted introductory text such as "Here is the corrected text:" that may have been added by the LLM and which is obviously not part of the original text.
8. Remove any obviously duplicated content that appears to have been accidentally included twice. Follow these strict guidelines:
   - Remove only exact or near-exact repeated paragraphs or sections within the main chunk.
   - Consider the context (before and after the main chunk) to identify duplicates that span chunk boundaries.
   - Do not remove content that is simply similar but conveys different information.
   - Preserve all unique content, even if it seems redundant.
   - Ensure the text flows smoothly after removal.
   - Do not add any new content or explanations.
   - If no obvious duplicates are found, return the main chunk unchanged.
9. {"Identify but do not remove headers, footers, or page numbers. Instead, format them distinctly, e.g., as blockquotes." if not suppress_headers_and_page_numbers else "Carefully remove headers, footers, and page numbers while preserving all other content."}

Text to reformat:

{ocr_corrected_chunk}

Reformatted markdown:
"""
        processed_chunk = await generate_completion(markdown_prompt, max_tokens=len(ocr_corrected_chunk) + 500)
    new_context = processed_chunk[-1000:]  # Use the last 1000 characters as context for the next chunk
    logging.info(f"Chunk {chunk_index + 1}/{total_chunks} processed. Output length: {len(processed_chunk):,} characters")
    return processed_chunk, new_context

async def process_chunks(chunks: List[str], reformat_as_markdown: bool, suppress_headers_and_page_numbers: bool) -> List[str]:
    total_chunks = len(chunks)
    async def process_chunk_with_context(chunk: str, prev_context: str, index: int) -> Tuple[int, str, str]:
        processed_chunk, new_context = await process_chunk(chunk, prev_context, index, total_chunks, reformat_as_markdown, suppress_headers_and_page_numbers)
        return index, processed_chunk, new_context
    if USE_LOCAL_LLM:
        logging.info("Using local LLM. Processing chunks sequentially...")
        context = ""
        processed_chunks = []
        for i, chunk in enumerate(chunks):
            processed_chunk, context = await process_chunk(chunk, context, i, total_chunks, reformat_as_markdown, suppress_headers_and_page_numbers)
            processed_chunks.append(processed_chunk)
    else:
        logging.info("Using API-based LLM. Processing chunks concurrently while maintaining order...")
        tasks = [process_chunk_with_context(chunk, "", i) for i, chunk in enumerate(chunks)]
        results = await asyncio.gather(*tasks)
        # Sort results by index to maintain order
        sorted_results = sorted(results, key=lambda x: x[0])
        processed_chunks = [chunk for _, chunk, _ in sorted_results]
    logging.info(f"All {total_chunks} chunks processed successfully")
    return processed_chunks

async def process_document(list_of_extracted_text_strings: List[str], reformat_as_markdown: bool = True, suppress_headers_and_page_numbers: bool = True) -> str:
    logging.info(f"Starting document processing. Total pages: {len(list_of_extracted_text_strings):,}")
    full_text = "\n\n".join(list_of_extracted_text_strings)
    logging.info(f"Size of full text before processing: {len(full_text):,} characters")
    chunk_size, overlap = 8000, 10
    # Improved chunking logic
    paragraphs = re.split(r'\n\s*\n', full_text)
    chunks = []
    current_chunk = []
    current_chunk_length = 0
    for paragraph in paragraphs:
        paragraph_length = len(paragraph)
        if current_chunk_length + paragraph_length <= chunk_size:
            current_chunk.append(paragraph)
            current_chunk_length += paragraph_length
        else:
            # If adding the whole paragraph exceeds the chunk size,
            # we need to split the paragraph into sentences
            if current_chunk:
                chunks.append("\n\n".join(current_chunk))
            sentences = re.split(r'(?<=[.!?])\s+', paragraph)
            current_chunk = []
            current_chunk_length = 0
            for sentence in sentences:
                sentence_length = len(sentence)
                if current_chunk_length + sentence_length <= chunk_size:
                    current_chunk.append(sentence)
                    current_chunk_length += sentence_length
                else:
                    if current_chunk:
                        chunks.append(" ".join(current_chunk))
                    current_chunk = [sentence]
                    current_chunk_length = sentence_length
    # Add any remaining content as the last chunk
    if current_chunk:
        chunks.append("\n\n".join(current_chunk) if len(current_chunk) > 1 else current_chunk[0])
    # Add overlap between chunks
    for i in range(1, len(chunks)):
        overlap_text = chunks[i-1].split()[-overlap:]
        chunks[i] = " ".join(overlap_text) + " " + chunks[i]
    logging.info(f"Document split into {len(chunks):,} chunks. Chunk size: {chunk_size:,}, Overlap: {overlap:,}")
    processed_chunks = await process_chunks(chunks, reformat_as_markdown, suppress_headers_and_page_numbers)
    final_text = "".join(processed_chunks)
    logging.info(f"Size of text after combining chunks: {len(final_text):,} characters")
    logging.info(f"Document processing complete. Final text length: {len(final_text):,} characters")
    return final_text

def remove_corrected_text_header(text):
    return text.replace("# Corrected text\n", "").replace("# Corrected text:", "").replace("\nCorrected text", "").replace("Corrected text:", "")

async def assess_output_quality(original_text, processed_text):
    max_chars = 15000  # Limit to avoid exceeding token limits
    available_chars_per_text = max_chars // 2  # Split equally between original and processed

    original_sample = original_text[:available_chars_per_text]
    processed_sample = processed_text[:available_chars_per_text]

    prompt = f"""Compare the following samples of original OCR text with the processed output and assess the quality of the processing. Consider the following factors:
1. Accuracy of error correction
2. Improvement in readability
3. Preservation of original content and meaning
4. Appropriate use of markdown formatting (if applicable)
5. Removal of hallucinations or irrelevant content

Original text sample:
```
{original_sample}
```

Processed text sample:
```
{processed_sample}
```

Provide a quality score between 0 and 100, where 100 is perfect processing. Also provide a brief explanation of your assessment.

Your response should be in the following format:
SCORE: [Your score]
EXPLANATION: [Your explanation]
"""

    response = await generate_completion(prompt, max_tokens=1000)

    try:
        lines = response.strip().split('\n')
        score_line = next(line for line in lines if line.startswith('SCORE:'))
        score = int(score_line.split(':')[1].strip())
        explanation = '\n'.join(line for line in lines if line.startswith('EXPLANATION:')).replace('EXPLANATION:', '').strip()
        logging.info(f"Quality assessment: Score {score}/100")
        logging.info(f"Explanation: {explanation}")
        return score, explanation
    except Exception as e:
        logging.error(f"Error parsing quality assessment response: {e}")
        logging.error(f"Raw response: {response}")
        return None, None

async def main():
    try:
        # Suppress HTTP request logs
        logging.getLogger("httpx").setLevel(logging.WARNING)
        input_pdf_file_path = '160301289-Warren-Buffett-Katharine-Graham-Letter.pdf'
        max_test_pages = 0
        skip_first_n_pages = 0
        reformat_as_markdown = True
        suppress_headers_and_page_numbers = True

        # Download the model if using local LLM
        if USE_LOCAL_LLM:
            _, download_status = await download_models()
            logging.info(f"Model download status: {download_status}")
            logging.info(f"Using Local LLM with Model: {DEFAULT_LOCAL_MODEL_NAME}")
        else:
            logging.info(f"Using API for completions: {API_PROVIDER}")
            logging.info(f"Using OpenAI model for embeddings: {OPENAI_EMBEDDING_MODEL}")

        base_name = os.path.splitext(input_pdf_file_path)[0]
        output_extension = '.md' if reformat_as_markdown else '.txt'

        raw_ocr_output_file_path = f"{base_name}__raw_ocr_output.txt"
        llm_corrected_output_file_path = base_name + '_llm_corrected' + output_extension

        list_of_scanned_images = convert_pdf_to_images(input_pdf_file_path, max_test_pages, skip_first_n_pages)
        logging.info(f"Tesseract version: {pytesseract.get_tesseract_version()}")
        logging.info("Extracting text from converted pages...")
        with ThreadPoolExecutor() as executor:
            list_of_extracted_text_strings = list(executor.map(ocr_image, list_of_scanned_images))
        logging.info("Done extracting text from converted pages.")
        raw_ocr_output = "\n".join(list_of_extracted_text_strings)
        with open(raw_ocr_output_file_path, "w") as f:
            f.write(raw_ocr_output)
        logging.info(f"Raw OCR output written to: {raw_ocr_output_file_path}")

        logging.info("Processing document...")
        final_text = await process_document(list_of_extracted_text_strings, reformat_as_markdown, suppress_headers_and_page_numbers)
        cleaned_text = remove_corrected_text_header(final_text)

        # Save the LLM corrected output
        with open(llm_corrected_output_file_path, 'w') as f:
            f.write(cleaned_text)
        logging.info(f"LLM Corrected text written to: {llm_corrected_output_file_path}")

        if final_text:
            logging.info(f"First 500 characters of LLM corrected processed text:\n{final_text[:500]}...")
        else:
            logging.warning("final_text is empty or not defined.")

        logging.info(f"Done processing {input_pdf_file_path}.")
        logging.info("\nSee output files:")
        logging.info(f" Raw OCR: {raw_ocr_output_file_path}")
        logging.info(f" LLM Corrected: {llm_corrected_output_file_path}")

        # Perform a final quality check
        quality_score, explanation = await assess_output_quality(raw_ocr_output, final_text)
        if quality_score is not None:
            logging.info(f"Final quality score: {quality_score}/100")
            logging.info(f"Explanation: {explanation}")
        else:
            logging.warning("Unable to determine final quality score.")
    except Exception as e:
        logging.error(f"An error occurred in the main function: {e}")
        logging.error(traceback.format_exc())

if __name__ == '__main__':
    asyncio.run(main())
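
The new module runs as a standalone script via its `main()` entry point, but the individual stages it defines can also be driven directly. Below is a minimal illustrative sketch, not part of the package, of how the OCR-then-LLM pipeline added above might be called from other code; it assumes a configured `.env` file next to the caller (the module reads it at import time) and uses a placeholder PDF path.

import asyncio
from magic_pdf.post_proc.llm_aided_ocr import (
    convert_pdf_to_images, ocr_image, process_document,
)

async def run_pipeline(pdf_path: str) -> str:
    # Rasterize the PDF, OCR each page with Tesseract, then have the
    # configured LLM correct and reformat the combined text.
    images = convert_pdf_to_images(pdf_path)
    page_texts = [ocr_image(img) for img in images]
    return await process_document(page_texts, reformat_as_markdown=True)

if __name__ == "__main__":
    corrected = asyncio.run(run_pipeline("example.pdf"))  # placeholder path
    print(corrected[:500])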