magic-pdf 1.0.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff reflects the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
@@ -0,0 +1,689 @@
+ import os
+ import glob
+ import traceback
+ import asyncio
+ import json
+ import re
+ import urllib.request
+ import logging
+ from concurrent.futures import ThreadPoolExecutor
+ import warnings
+ from typing import List, Dict, Tuple, Optional
+ from pdf2image import convert_from_path
+ import pytesseract
+ from llama_cpp import Llama, LlamaGrammar
+ import tiktoken
+ import numpy as np
+ from PIL import Image
+ from decouple import Config as DecoupleConfig, RepositoryEnv
+ import cv2
+ from filelock import FileLock, Timeout
+ from transformers import AutoTokenizer
+ from openai import AsyncOpenAI
+ from anthropic import AsyncAnthropic
+ try:
+     import nvgpu
+     GPU_AVAILABLE = True
+ except ImportError:
+     GPU_AVAILABLE = False
+
+ # Configuration
+ config = DecoupleConfig(RepositoryEnv('.env'))
+
+ USE_LOCAL_LLM = config.get("USE_LOCAL_LLM", default=False, cast=bool)
+ API_PROVIDER = config.get("API_PROVIDER", default="OPENAI", cast=str) # OPENAI or CLAUDE
+ ANTHROPIC_API_KEY = config.get("ANTHROPIC_API_KEY", default="your-anthropic-api-key", cast=str)
+ OPENAI_API_KEY = config.get("OPENAI_API_KEY", default="your-openai-api-key", cast=str)
+ CLAUDE_MODEL_STRING = config.get("CLAUDE_MODEL_STRING", default="claude-3-haiku-20240307", cast=str)
+ CLAUDE_MAX_TOKENS = 4096 # Maximum allowed tokens for Claude API
+ TOKEN_BUFFER = 500 # Buffer to account for token estimation inaccuracies
+ TOKEN_CUSHION = 300 # Don't use the full max tokens to avoid hitting the limit
+ OPENAI_COMPLETION_MODEL = config.get("OPENAI_COMPLETION_MODEL", default="gpt-4o-mini", cast=str)
+ OPENAI_EMBEDDING_MODEL = config.get("OPENAI_EMBEDDING_MODEL", default="text-embedding-3-small", cast=str)
+ OPENAI_MAX_TOKENS = 12000 # Maximum allowed tokens for OpenAI API
+ DEFAULT_LOCAL_MODEL_NAME = "Llama-3.1-8B-Lexi-Uncensored_Q5_fixedrope.gguf"
+ LOCAL_LLM_CONTEXT_SIZE_IN_TOKENS = 2048
+ USE_VERBOSE = False
+
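+ # The settings above are read from a local `.env` file via python-decouple. A minimal
+ # example `.env` (placeholder values only; substitute your own keys and model names):
+ #   USE_LOCAL_LLM=False
+ #   API_PROVIDER=OPENAI
+ #   OPENAI_API_KEY=your-openai-api-key
+ #   ANTHROPIC_API_KEY=your-anthropic-api-key
+ #   CLAUDE_MODEL_STRING=claude-3-haiku-20240307
+ #   OPENAI_COMPLETION_MODEL=gpt-4o-mini
+ #   OPENAI_EMBEDDING_MODEL=text-embedding-3-small
+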
+ openai_client = AsyncOpenAI(api_key=OPENAI_API_KEY)
+ warnings.filterwarnings("ignore", category=FutureWarning)
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+ # GPU Check
+ def is_gpu_available():
+     if not GPU_AVAILABLE:
+         logging.warning("GPU support not available: nvgpu module not found")
+         return {"gpu_found": False, "num_gpus": 0, "first_gpu_vram": 0, "total_vram": 0, "error": "nvgpu module not found"}
+     try:
+         gpu_info = nvgpu.gpu_info()
+         num_gpus = len(gpu_info)
+         if num_gpus == 0:
+             logging.warning("No GPUs found on the system")
+             return {"gpu_found": False, "num_gpus": 0, "first_gpu_vram": 0, "total_vram": 0}
+         first_gpu_vram = gpu_info[0]['mem_total']
+         total_vram = sum(gpu['mem_total'] for gpu in gpu_info)
+         logging.info(f"GPU(s) found: {num_gpus}, Total VRAM: {total_vram} MB")
+         return {"gpu_found": True, "num_gpus": num_gpus, "first_gpu_vram": first_gpu_vram, "total_vram": total_vram, "gpu_info": gpu_info}
+     except Exception as e:
+         logging.error(f"Error checking GPU availability: {e}")
+         return {"gpu_found": False, "num_gpus": 0, "first_gpu_vram": 0, "total_vram": 0, "error": str(e)}
+
+ # Model Download
+ async def download_models() -> Tuple[List[str], List[Dict[str, str]]]:
+     download_status = []
+     model_url = "https://huggingface.co/Orenguteng/Llama-3.1-8B-Lexi-Uncensored-GGUF/resolve/main/Llama-3.1-8B-Lexi-Uncensored_Q5_fixedrope.gguf"
+     model_name = os.path.basename(model_url)
+     current_file_path = os.path.abspath(__file__)
+     base_dir = os.path.dirname(current_file_path)
+     models_dir = os.path.join(base_dir, 'models')
+
+     os.makedirs(models_dir, exist_ok=True)
+     lock = FileLock(os.path.join(models_dir, "download.lock"))
+     status = {"url": model_url, "status": "success", "message": "File already exists."}
+     filename = os.path.join(models_dir, model_name)
+
+     try:
+         with lock.acquire(timeout=1200):
+             if not os.path.exists(filename):
+                 logging.info(f"Downloading model {model_name} from {model_url}...")
+                 urllib.request.urlretrieve(model_url, filename)
+                 file_size = os.path.getsize(filename) / (1024 * 1024)
+                 if file_size < 100:
+                     os.remove(filename)
+                     status["status"] = "failure"
+                     status["message"] = f"Downloaded file is too small ({file_size:.2f} MB), probably not a valid model file."
+                     logging.error(f"Error: {status['message']}")
+                 else:
+                     logging.info(f"Successfully downloaded: {filename} (Size: {file_size:.2f} MB)")
+             else:
+                 logging.info(f"Model file already exists: {filename}")
+     except Timeout:
+         logging.error(f"Error: Could not acquire lock for downloading {model_name}")
+         status["status"] = "failure"
+         status["message"] = "Could not acquire lock for downloading."
+
+     download_status.append(status)
+     logging.info("Model download process completed.")
+     return [model_name], download_status
+
+ # Model Loading
+ def load_model(llm_model_name: str, raise_exception: bool = True):
+     global USE_VERBOSE
+     try:
+         current_file_path = os.path.abspath(__file__)
+         base_dir = os.path.dirname(current_file_path)
+         models_dir = os.path.join(base_dir, 'models')
+         matching_files = glob.glob(os.path.join(models_dir, f"{llm_model_name}*"))
+         if not matching_files:
+             logging.error(f"Error: No model file found matching: {llm_model_name}")
+             raise FileNotFoundError
+         model_file_path = max(matching_files, key=os.path.getmtime)
+         logging.info(f"Loading model: {model_file_path}")
+         try:
+             logging.info("Attempting to load model with GPU acceleration...")
+             model_instance = Llama(
+                 model_path=model_file_path,
+                 n_ctx=LOCAL_LLM_CONTEXT_SIZE_IN_TOKENS,
+                 verbose=USE_VERBOSE,
+                 n_gpu_layers=-1
+             )
+             logging.info("Model loaded successfully with GPU acceleration.")
+         except Exception as gpu_e:
+             logging.warning(f"Failed to load model with GPU acceleration: {gpu_e}")
+             logging.info("Falling back to CPU...")
+             try:
+                 model_instance = Llama(
+                     model_path=model_file_path,
+                     n_ctx=LOCAL_LLM_CONTEXT_SIZE_IN_TOKENS,
+                     verbose=USE_VERBOSE,
+                     n_gpu_layers=0
+                 )
+                 logging.info("Model loaded successfully with CPU.")
+             except Exception as cpu_e:
+                 logging.error(f"Failed to load model with CPU: {cpu_e}")
+                 if raise_exception:
+                     raise
+                 return None
+         return model_instance
+     except Exception as e:
+         logging.error(f"Exception occurred while loading the model: {e}")
+         traceback.print_exc()
+         if raise_exception:
+             raise
+         return None
+
+ # API Interaction Functions
+ async def generate_completion(prompt: str, max_tokens: int = 5000) -> Optional[str]:
+     if USE_LOCAL_LLM:
+         result = await generate_completion_from_local_llm(DEFAULT_LOCAL_MODEL_NAME, prompt, max_tokens)
+         # The local helper returns a dict for a single prompt but a plain string when it has to
+         # chunk the input, so normalize to a string here to match the API-based code paths.
+         return result["generated_text"] if isinstance(result, dict) else result
+     elif API_PROVIDER == "CLAUDE":
+         return await generate_completion_from_claude(prompt, max_tokens)
+     elif API_PROVIDER == "OPENAI":
+         return await generate_completion_from_openai(prompt, max_tokens)
+     else:
+         logging.error(f"Invalid API_PROVIDER: {API_PROVIDER}")
+         return None
+
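+ # Tokenizer selection: tiktoken is used for OpenAI "gpt-*" models, while Hugging Face
+ # tokenizers (EleutherAI/gpt-neox-20b for Claude, huggyllama/llama-7b for local Llama models)
+ # serve as stand-ins for token counting, with a heuristic approximation as the last resort.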
+ def get_tokenizer(model_name: str):
+     if model_name.lower().startswith("gpt-"):
+         return tiktoken.encoding_for_model(model_name)
+     elif model_name.lower().startswith("claude-"):
+         return AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b", clean_up_tokenization_spaces=False)
+     elif model_name.lower().startswith("llama-"):
+         return AutoTokenizer.from_pretrained("huggyllama/llama-7b", clean_up_tokenization_spaces=False)
+     else:
+         raise ValueError(f"Unsupported model: {model_name}")
+
+ def estimate_tokens(text: str, model_name: str) -> int:
+     try:
+         tokenizer = get_tokenizer(model_name)
+         return len(tokenizer.encode(text))
+     except Exception as e:
+         logging.warning(f"Error using tokenizer for {model_name}: {e}. Falling back to approximation.")
+         return approximate_tokens(text)
+
+ def approximate_tokens(text: str) -> int:
+     # Normalize whitespace
+     text = re.sub(r'\s+', ' ', text.strip())
+     # Split on whitespace and punctuation, keeping punctuation
+     tokens = re.findall(r'\b\w+\b|\S', text)
+     count = 0
+     for token in tokens:
+         if token.isdigit():
+             count += max(1, len(token) // 2)  # Numbers often tokenize to multiple tokens
+         elif re.match(r'^[A-Z]{2,}$', token):  # Acronyms
+             count += len(token)
+         elif re.search(r'[^\w\s]', token):  # Punctuation and special characters
+             count += 1
+         elif len(token) > 10:  # Long words often split into multiple tokens
+             count += len(token) // 4 + 1
+         else:
+             count += 1
+     # Add a 10% buffer for potential underestimation
+     return int(count * 1.1)
+
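+ # Chunking helpers: chunk_text splits text into sentence-aligned chunks that stay under a
+ # token budget, and adjust_overlaps then prepends the tail of the previous chunk to each
+ # chunk (when it fits) so that context is not lost at chunk boundaries.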
+ def chunk_text(text: str, max_chunk_tokens: int, model_name: str) -> List[str]:
+     chunks = []
+     tokenizer = get_tokenizer(model_name)
+     sentences = re.split(r'(?<=[.!?])\s+', text)
+     current_chunk = []
+     current_chunk_tokens = 0
+
+     for sentence in sentences:
+         sentence_tokens = len(tokenizer.encode(sentence))
+         if current_chunk_tokens + sentence_tokens > max_chunk_tokens:
+             if current_chunk:  # Guard against emitting an empty chunk when the very first sentence is already over budget
+                 chunks.append(' '.join(current_chunk))
+             current_chunk = [sentence]
+             current_chunk_tokens = sentence_tokens
+         else:
+             current_chunk.append(sentence)
+             current_chunk_tokens += sentence_tokens
+
+     if current_chunk:
+         chunks.append(' '.join(current_chunk))
+
+     adjusted_chunks = adjust_overlaps(chunks, tokenizer, max_chunk_tokens)
+     return adjusted_chunks
+
+ def split_long_sentence(sentence: str, max_tokens: int, model_name: str) -> List[str]:
+     words = sentence.split()
+     chunks = []
+     current_chunk = []
+     current_chunk_tokens = 0
+     tokenizer = get_tokenizer(model_name)
+
+     for word in words:
+         word_tokens = len(tokenizer.encode(word))
+         if current_chunk_tokens + word_tokens > max_tokens and current_chunk:
+             chunks.append(' '.join(current_chunk))
+             current_chunk = [word]
+             current_chunk_tokens = word_tokens
+         else:
+             current_chunk.append(word)
+             current_chunk_tokens += word_tokens
+
+     if current_chunk:
+         chunks.append(' '.join(current_chunk))
+
+     return chunks
+
+ def adjust_overlaps(chunks: List[str], tokenizer, max_chunk_tokens: int, overlap_size: int = 50) -> List[str]:
+     adjusted_chunks = []
+     for i in range(len(chunks)):
+         if i == 0:
+             adjusted_chunks.append(chunks[i])
+         else:
+             overlap_tokens = len(tokenizer.encode(' '.join(chunks[i-1].split()[-overlap_size:])))
+             current_tokens = len(tokenizer.encode(chunks[i]))
+             if overlap_tokens + current_tokens > max_chunk_tokens:
+                 overlap_adjusted = chunks[i].split()[:-overlap_size]
+                 adjusted_chunks.append(' '.join(overlap_adjusted))
+             else:
+                 adjusted_chunks.append(' '.join(chunks[i-1].split()[-overlap_size:] + chunks[i].split()))
+
+     return adjusted_chunks
+
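+ # API helpers: both the Claude and OpenAI paths first estimate the prompt's token count;
+ # if the prompt leaves no room for a completion, the input is chunked and the chunks are
+ # processed sequentially, with the partial results joined into a single string.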
+ async def generate_completion_from_claude(prompt: str, max_tokens: int = CLAUDE_MAX_TOKENS - TOKEN_BUFFER) -> Optional[str]:
+     if not ANTHROPIC_API_KEY:
+         logging.error("Anthropic API key not found. Please set the ANTHROPIC_API_KEY environment variable.")
+         return None
+     client = AsyncAnthropic(api_key=ANTHROPIC_API_KEY)
+     prompt_tokens = estimate_tokens(prompt, CLAUDE_MODEL_STRING)
+     adjusted_max_tokens = min(max_tokens, CLAUDE_MAX_TOKENS - prompt_tokens - TOKEN_BUFFER)
+     if adjusted_max_tokens <= 0:
+         logging.warning("Prompt is too long for Claude API. Chunking the input.")
+         chunks = chunk_text(prompt, CLAUDE_MAX_TOKENS - TOKEN_CUSHION, CLAUDE_MODEL_STRING)
+         results = []
+         for chunk in chunks:
+             try:
+                 async with client.messages.stream(
+                     model=CLAUDE_MODEL_STRING,
+                     max_tokens=CLAUDE_MAX_TOKENS // 2,
+                     temperature=0.7,
+                     messages=[{"role": "user", "content": chunk}],
+                 ) as stream:
+                     message = await stream.get_final_message()
+                     results.append(message.content[0].text)
+                     logging.info(f"Chunk processed. Input tokens: {message.usage.input_tokens:,}, Output tokens: {message.usage.output_tokens:,}")
+             except Exception as e:
+                 logging.error(f"An error occurred while processing a chunk: {e}")
+         return " ".join(results)
+     else:
+         try:
+             async with client.messages.stream(
+                 model=CLAUDE_MODEL_STRING,
+                 max_tokens=adjusted_max_tokens,
+                 temperature=0.7,
+                 messages=[{"role": "user", "content": prompt}],
+             ) as stream:
+                 message = await stream.get_final_message()
+                 output_text = message.content[0].text
+                 logging.info(f"Total input tokens: {message.usage.input_tokens:,}")
+                 logging.info(f"Total output tokens: {message.usage.output_tokens:,}")
+                 logging.info(f"Generated output (abbreviated): {output_text[:150]}...")
+                 return output_text
+         except Exception as e:
+             logging.error(f"An error occurred while requesting from Claude API: {e}")
+             return None
+
+ async def generate_completion_from_openai(prompt: str, max_tokens: int = 5000) -> Optional[str]:
+     if not OPENAI_API_KEY:
+         logging.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
+         return None
+     prompt_tokens = estimate_tokens(prompt, OPENAI_COMPLETION_MODEL)
+     adjusted_max_tokens = min(max_tokens, 4096 - prompt_tokens - TOKEN_BUFFER)  # 4096 is typical max for GPT-3.5 and GPT-4
+     if adjusted_max_tokens <= 0:
+         logging.warning("Prompt is too long for OpenAI API. Chunking the input.")
+         chunks = chunk_text(prompt, OPENAI_MAX_TOKENS - TOKEN_CUSHION, OPENAI_COMPLETION_MODEL)
+         results = []
+         for chunk in chunks:
+             try:
+                 response = await openai_client.chat.completions.create(
+                     model=OPENAI_COMPLETION_MODEL,
+                     messages=[{"role": "user", "content": chunk}],
+                     max_tokens=OPENAI_MAX_TOKENS // 2,  # adjusted_max_tokens is non-positive in this branch, so use a fixed per-chunk budget (mirrors the Claude path)
+                     temperature=0.7,
+                 )
+                 result = response.choices[0].message.content
+                 results.append(result)
+                 logging.info(f"Chunk processed. Output tokens: {response.usage.completion_tokens:,}")
+             except Exception as e:
+                 logging.error(f"An error occurred while processing a chunk: {e}")
+         return " ".join(results)
+     else:
+         try:
+             response = await openai_client.chat.completions.create(
+                 model=OPENAI_COMPLETION_MODEL,
+                 messages=[{"role": "user", "content": prompt}],
+                 max_tokens=adjusted_max_tokens,
+                 temperature=0.7,
+             )
+             output_text = response.choices[0].message.content
+             logging.info(f"Total tokens: {response.usage.total_tokens:,}")
+             logging.info(f"Generated output (abbreviated): {output_text[:150]}...")
+             return output_text
+         except Exception as e:
+             logging.error(f"An error occurred while requesting from OpenAI API: {e}")
+             return None
+
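+ # Local inference: the llama_cpp model mirrors the API helpers' chunking fallback and can
+ # optionally constrain generation with a GBNF grammar file looked up in ./grammar_files.
+ # Note it returns a dict (generated_text, finish_reason, usage) for a single prompt, but a
+ # joined string when the input had to be chunked.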
+ async def generate_completion_from_local_llm(llm_model_name: str, input_prompt: str, number_of_tokens_to_generate: int = 100, temperature: float = 0.7, grammar_file_string: str = None):
+     logging.info(f"Starting text completion using model: '{llm_model_name}' for input prompt: '{input_prompt}'")
+     llm = load_model(llm_model_name)
+     prompt_tokens = estimate_tokens(input_prompt, llm_model_name)
+     adjusted_max_tokens = min(number_of_tokens_to_generate, LOCAL_LLM_CONTEXT_SIZE_IN_TOKENS - prompt_tokens - TOKEN_BUFFER)
+     if adjusted_max_tokens <= 0:
+         logging.warning("Prompt is too long for LLM. Chunking the input.")
+         chunks = chunk_text(input_prompt, LOCAL_LLM_CONTEXT_SIZE_IN_TOKENS - TOKEN_CUSHION, llm_model_name)
+         results = []
+         for chunk in chunks:
+             try:
+                 output = llm(
+                     prompt=chunk,
+                     max_tokens=LOCAL_LLM_CONTEXT_SIZE_IN_TOKENS - TOKEN_CUSHION,
+                     temperature=temperature,
+                 )
+                 results.append(output['choices'][0]['text'])
+                 logging.info(f"Chunk processed. Output tokens: {output['usage']['completion_tokens']:,}")
+             except Exception as e:
+                 logging.error(f"An error occurred while processing a chunk: {e}")
+         return " ".join(results)
+     else:
+         grammar_file_string_lower = grammar_file_string.lower() if grammar_file_string else ""
+         if grammar_file_string_lower:
+             list_of_grammar_files = glob.glob("./grammar_files/*.gbnf")
+             matching_grammar_files = [x for x in list_of_grammar_files if grammar_file_string_lower in os.path.splitext(os.path.basename(x).lower())[0]]
+             if len(matching_grammar_files) == 0:
+                 logging.error(f"No grammar file found matching: {grammar_file_string}")
+                 raise FileNotFoundError
+             grammar_file_path = max(matching_grammar_files, key=os.path.getmtime)
+             logging.info(f"Loading selected grammar file: '{grammar_file_path}'")
+             llama_grammar = LlamaGrammar.from_file(grammar_file_path)
+             output = llm(
+                 prompt=input_prompt,
+                 max_tokens=adjusted_max_tokens,
+                 temperature=temperature,
+                 grammar=llama_grammar
+             )
+         else:
+             output = llm(
+                 prompt=input_prompt,
+                 max_tokens=adjusted_max_tokens,
+                 temperature=temperature
+             )
+         generated_text = output['choices'][0]['text']
+         if grammar_file_string == 'json':
+             generated_text = generated_text.encode('unicode_escape').decode()
+         finish_reason = str(output['choices'][0]['finish_reason'])
+         llm_model_usage_json = json.dumps(output['usage'])
+         logging.info(f"Completed text completion in {output['usage']['total_time']:.2f} seconds. Beginning of generated text: \n'{generated_text[:150]}'...")
+         return {
+             "generated_text": generated_text,
+             "finish_reason": finish_reason,
+             "llm_model_usage_json": llm_model_usage_json
+         }
+
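+ # OCR preprocessing: each page image is converted to grayscale, binarized with Otsu's
+ # threshold, and lightly dilated before being passed to Tesseract, which tends to produce
+ # cleaner OCR output on scanned pages.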
+ # Image Processing Functions
+ def preprocess_image(image):
+     gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
+     gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
+     kernel = np.ones((1, 1), np.uint8)
+     gray = cv2.dilate(gray, kernel, iterations=1)
+     return Image.fromarray(gray)
+
+ def convert_pdf_to_images(input_pdf_file_path: str, max_pages: int = 0, skip_first_n_pages: int = 0) -> List[Image.Image]:
+     logging.info(f"Processing PDF file {input_pdf_file_path}")
+     if max_pages == 0:
+         last_page = None
+         logging.info("Converting all pages to images...")
+     else:
+         last_page = skip_first_n_pages + max_pages
+         logging.info(f"Converting pages {skip_first_n_pages + 1} to {last_page}")
+     first_page = skip_first_n_pages + 1  # pdf2image uses 1-based indexing
+     images = convert_from_path(input_pdf_file_path, first_page=first_page, last_page=last_page)
+     logging.info(f"Converted {len(images)} pages from PDF file to images.")
+     return images
+
+ def ocr_image(image):
+     preprocessed_image = preprocess_image(image)
+     return pytesseract.image_to_string(preprocessed_image)
+
+ async def process_chunk(chunk: str, prev_context: str, chunk_index: int, total_chunks: int, reformat_as_markdown: bool, suppress_headers_and_page_numbers: bool) -> Tuple[str, str]:
+     logging.info(f"Processing chunk {chunk_index + 1}/{total_chunks} (length: {len(chunk):,} characters)")
+
+     # Step 1: OCR Correction
+     ocr_correction_prompt = f"""Correct OCR-induced errors in the text, ensuring it flows coherently with the previous context. Follow these guidelines:
+
+ 1. Fix OCR-induced typos and errors:
+ - Correct words split across line breaks
+ - Fix common OCR errors (e.g., 'rn' misread as 'm')
+ - Use context and common sense to correct errors
+ - Only fix clear errors, don't alter the content unnecessarily
+ - Do not add extra periods or any unnecessary punctuation
+
+ 2. Maintain original structure:
+ - Keep all headings and subheadings intact
+
+ 3. Preserve original content:
+ - Keep all important information from the original text
+ - Do not add any new information not present in the original text
+ - Remove unnecessary line breaks within sentences or paragraphs
+ - Maintain paragraph breaks
+
+ 4. Maintain coherence:
+ - Ensure the content connects smoothly with the previous context
+ - Handle text that starts or ends mid-sentence appropriately
+
+ IMPORTANT: Respond ONLY with the corrected text. Preserve all original formatting, including line breaks. Do not include any introduction, explanation, or metadata.
+
+ Previous context:
+ {prev_context[-500:]}
+
+ Current chunk to process:
+ {chunk}
+
+ Corrected text:
+ """
+
+     ocr_corrected_chunk = await generate_completion(ocr_correction_prompt, max_tokens=len(chunk) + 500)
+
+     processed_chunk = ocr_corrected_chunk
+
+     # Step 2: Markdown Formatting (if requested)
+     if reformat_as_markdown:
+         markdown_prompt = f"""Reformat the following text as markdown, improving readability while preserving the original structure. Follow these guidelines:
+ 1. Preserve all original headings, converting them to appropriate markdown heading levels (# for main titles, ## for subtitles, etc.)
+ - Ensure each heading is on its own line
+ - Add a blank line before and after each heading
+ 2. Maintain the original paragraph structure. Remove all breaks within a word that should be a single word (for example, "cor- rect" should be "correct")
+ 3. Format lists properly (unordered or ordered) if they exist in the original text
+ 4. Use emphasis (*italic*) and strong emphasis (**bold**) where appropriate, based on the original formatting
+ 5. Preserve all original content and meaning
+ 6. Do not add any extra punctuation or modify the existing punctuation
+ 7. Remove any spuriously inserted introductory text such as "Here is the corrected text:" that may have been added by the LLM and which is obviously not part of the original text.
+ 8. Remove any obviously duplicated content that appears to have been accidentally included twice. Follow these strict guidelines:
+ - Remove only exact or near-exact repeated paragraphs or sections within the main chunk.
+ - Consider the context (before and after the main chunk) to identify duplicates that span chunk boundaries.
+ - Do not remove content that is simply similar but conveys different information.
+ - Preserve all unique content, even if it seems redundant.
+ - Ensure the text flows smoothly after removal.
+ - Do not add any new content or explanations.
+ - If no obvious duplicates are found, return the main chunk unchanged.
+ 9. {"Identify but do not remove headers, footers, or page numbers. Instead, format them distinctly, e.g., as blockquotes." if not suppress_headers_and_page_numbers else "Carefully remove headers, footers, and page numbers while preserving all other content."}
+
+ Text to reformat:
+
+ {ocr_corrected_chunk}
+
+ Reformatted markdown:
+ """
+         processed_chunk = await generate_completion(markdown_prompt, max_tokens=len(ocr_corrected_chunk) + 500)
+     new_context = processed_chunk[-1000:]  # Use the last 1000 characters as context for the next chunk
+     logging.info(f"Chunk {chunk_index + 1}/{total_chunks} processed. Output length: {len(processed_chunk):,} characters")
+     return processed_chunk, new_context
+
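+ # Chunk orchestration: with a local model the chunks are processed sequentially so each
+ # chunk can see the tail of the previous output as context; with an API provider the
+ # chunks are processed concurrently (each with an empty previous-context string) and the
+ # results are re-sorted by index afterwards to preserve document order.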
+ async def process_chunks(chunks: List[str], reformat_as_markdown: bool, suppress_headers_and_page_numbers: bool) -> List[str]:
+     total_chunks = len(chunks)
+     async def process_chunk_with_context(chunk: str, prev_context: str, index: int) -> Tuple[int, str, str]:
+         processed_chunk, new_context = await process_chunk(chunk, prev_context, index, total_chunks, reformat_as_markdown, suppress_headers_and_page_numbers)
+         return index, processed_chunk, new_context
+     if USE_LOCAL_LLM:
+         logging.info("Using local LLM. Processing chunks sequentially...")
+         context = ""
+         processed_chunks = []
+         for i, chunk in enumerate(chunks):
+             processed_chunk, context = await process_chunk(chunk, context, i, total_chunks, reformat_as_markdown, suppress_headers_and_page_numbers)
+             processed_chunks.append(processed_chunk)
+     else:
+         logging.info("Using API-based LLM. Processing chunks concurrently while maintaining order...")
+         tasks = [process_chunk_with_context(chunk, "", i) for i, chunk in enumerate(chunks)]
+         results = await asyncio.gather(*tasks)
+         # Sort results by index to maintain order
+         sorted_results = sorted(results, key=lambda x: x[0])
+         processed_chunks = [chunk for _, chunk, _ in sorted_results]
+     logging.info(f"All {total_chunks} chunks processed successfully")
+     return processed_chunks
+
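+ # Document-level chunking: the page texts are joined and split on paragraph boundaries into
+ # roughly 8,000-character chunks (falling back to sentence splits for oversized paragraphs),
+ # with a 10-word overlap carried between consecutive chunks.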
+ async def process_document(list_of_extracted_text_strings: List[str], reformat_as_markdown: bool = True, suppress_headers_and_page_numbers: bool = True) -> str:
+     logging.info(f"Starting document processing. Total pages: {len(list_of_extracted_text_strings):,}")
+     full_text = "\n\n".join(list_of_extracted_text_strings)
+     logging.info(f"Size of full text before processing: {len(full_text):,} characters")
+     chunk_size, overlap = 8000, 10
+     # Improved chunking logic
+     paragraphs = re.split(r'\n\s*\n', full_text)
+     chunks = []
+     current_chunk = []
+     current_chunk_length = 0
+     for paragraph in paragraphs:
+         paragraph_length = len(paragraph)
+         if current_chunk_length + paragraph_length <= chunk_size:
+             current_chunk.append(paragraph)
+             current_chunk_length += paragraph_length
+         else:
+             # If adding the whole paragraph exceeds the chunk size,
+             # we need to split the paragraph into sentences
+             if current_chunk:
+                 chunks.append("\n\n".join(current_chunk))
+             sentences = re.split(r'(?<=[.!?])\s+', paragraph)
+             current_chunk = []
+             current_chunk_length = 0
+             for sentence in sentences:
+                 sentence_length = len(sentence)
+                 if current_chunk_length + sentence_length <= chunk_size:
+                     current_chunk.append(sentence)
+                     current_chunk_length += sentence_length
+                 else:
+                     if current_chunk:
+                         chunks.append(" ".join(current_chunk))
+                     current_chunk = [sentence]
+                     current_chunk_length = sentence_length
+     # Add any remaining content as the last chunk
+     if current_chunk:
+         chunks.append("\n\n".join(current_chunk) if len(current_chunk) > 1 else current_chunk[0])
+     # Add overlap between chunks
+     for i in range(1, len(chunks)):
+         overlap_text = chunks[i-1].split()[-overlap:]
+         chunks[i] = " ".join(overlap_text) + " " + chunks[i]
+     logging.info(f"Document split into {len(chunks):,} chunks. Chunk size: {chunk_size:,}, Overlap: {overlap:,}")
+     processed_chunks = await process_chunks(chunks, reformat_as_markdown, suppress_headers_and_page_numbers)
+     final_text = "".join(processed_chunks)
+     logging.info(f"Size of text after combining chunks: {len(final_text):,} characters")
+     logging.info(f"Document processing complete. Final text length: {len(final_text):,} characters")
+     return final_text
+
+ def remove_corrected_text_header(text):
+     return text.replace("# Corrected text\n", "").replace("# Corrected text:", "").replace("\nCorrected text", "").replace("Corrected text:", "")
+
+ async def assess_output_quality(original_text, processed_text):
+     max_chars = 15000  # Limit to avoid exceeding token limits
+     available_chars_per_text = max_chars // 2  # Split equally between original and processed
+
+     original_sample = original_text[:available_chars_per_text]
+     processed_sample = processed_text[:available_chars_per_text]
+
+     prompt = f"""Compare the following samples of original OCR text with the processed output and assess the quality of the processing. Consider the following factors:
+ 1. Accuracy of error correction
+ 2. Improvement in readability
+ 3. Preservation of original content and meaning
+ 4. Appropriate use of markdown formatting (if applicable)
+ 5. Removal of hallucinations or irrelevant content
+
+ Original text sample:
+ ```
+ {original_sample}
+ ```
+
+ Processed text sample:
+ ```
+ {processed_sample}
+ ```
+
+ Provide a quality score between 0 and 100, where 100 is perfect processing. Also provide a brief explanation of your assessment.
+
+ Your response should be in the following format:
+ SCORE: [Your score]
+ EXPLANATION: [Your explanation]
+ """
+
+     response = await generate_completion(prompt, max_tokens=1000)
+
+     try:
+         lines = response.strip().split('\n')
+         score_line = next(line for line in lines if line.startswith('SCORE:'))
+         score = int(score_line.split(':')[1].strip())
+         explanation = '\n'.join(line for line in lines if line.startswith('EXPLANATION:')).replace('EXPLANATION:', '').strip()
+         logging.info(f"Quality assessment: Score {score}/100")
+         logging.info(f"Explanation: {explanation}")
+         return score, explanation
+     except Exception as e:
+         logging.error(f"Error parsing quality assessment response: {e}")
+         logging.error(f"Raw response: {response}")
+         return None, None
+
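+ # End-to-end pipeline: convert the PDF to page images, OCR them with Tesseract in a thread
+ # pool, write the raw OCR text to disk, run the chunked LLM correction (and optional
+ # markdown reformatting), write the corrected output, and finally ask the LLM to score the
+ # processed text against a sample of the raw OCR output.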
+ async def main():
+     try:
+         # Suppress HTTP request logs
+         logging.getLogger("httpx").setLevel(logging.WARNING)
+         input_pdf_file_path = '160301289-Warren-Buffett-Katharine-Graham-Letter.pdf'
+         max_test_pages = 0
+         skip_first_n_pages = 0
+         reformat_as_markdown = True
+         suppress_headers_and_page_numbers = True
+
+         # Download the model if using local LLM
+         if USE_LOCAL_LLM:
+             _, download_status = await download_models()
+             logging.info(f"Model download status: {download_status}")
+             logging.info(f"Using Local LLM with Model: {DEFAULT_LOCAL_MODEL_NAME}")
+         else:
+             logging.info(f"Using API for completions: {API_PROVIDER}")
+             logging.info(f"Using OpenAI model for embeddings: {OPENAI_EMBEDDING_MODEL}")
+
+         base_name = os.path.splitext(input_pdf_file_path)[0]
+         output_extension = '.md' if reformat_as_markdown else '.txt'
+
+         raw_ocr_output_file_path = f"{base_name}__raw_ocr_output.txt"
+         llm_corrected_output_file_path = base_name + '_llm_corrected' + output_extension
+
+         list_of_scanned_images = convert_pdf_to_images(input_pdf_file_path, max_test_pages, skip_first_n_pages)
+         logging.info(f"Tesseract version: {pytesseract.get_tesseract_version()}")
+         logging.info("Extracting text from converted pages...")
+         with ThreadPoolExecutor() as executor:
+             list_of_extracted_text_strings = list(executor.map(ocr_image, list_of_scanned_images))
+         logging.info("Done extracting text from converted pages.")
+         raw_ocr_output = "\n".join(list_of_extracted_text_strings)
+         with open(raw_ocr_output_file_path, "w") as f:
+             f.write(raw_ocr_output)
+         logging.info(f"Raw OCR output written to: {raw_ocr_output_file_path}")
+
+         logging.info("Processing document...")
+         final_text = await process_document(list_of_extracted_text_strings, reformat_as_markdown, suppress_headers_and_page_numbers)
+         cleaned_text = remove_corrected_text_header(final_text)
+
+         # Save the LLM corrected output
+         with open(llm_corrected_output_file_path, 'w') as f:
+             f.write(cleaned_text)
+         logging.info(f"LLM Corrected text written to: {llm_corrected_output_file_path}")
+
+         if final_text:
+             logging.info(f"First 500 characters of LLM corrected processed text:\n{final_text[:500]}...")
+         else:
+             logging.warning("final_text is empty or not defined.")
+
+         logging.info(f"Done processing {input_pdf_file_path}.")
+         logging.info("\nSee output files:")
+         logging.info(f" Raw OCR: {raw_ocr_output_file_path}")
+         logging.info(f" LLM Corrected: {llm_corrected_output_file_path}")
+
+         # Perform a final quality check
+         quality_score, explanation = await assess_output_quality(raw_ocr_output, final_text)
+         if quality_score is not None:
+             logging.info(f"Final quality score: {quality_score}/100")
+             logging.info(f"Explanation: {explanation}")
+         else:
+             logging.warning("Unable to determine final quality score.")
+     except Exception as e:
+         logging.error(f"An error occurred in the main function: {e}")
+         logging.error(traceback.format_exc())
+
+ if __name__ == '__main__':
+     asyncio.run(main())