kssrag 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kssrag/cli.py +59 -13
- kssrag/config.py +15 -1
- kssrag/core/agents.py +61 -10
- kssrag/core/chunkers.py +95 -1
- kssrag/core/vectorstores.py +103 -2
- kssrag/models/openrouter.py +77 -15
- kssrag/server.py +49 -0
- kssrag/utils/document_loaders.py +80 -2
- kssrag/utils/helpers.py +74 -31
- kssrag/utils/ocr.py +48 -0
- kssrag/utils/ocr_loader.py +151 -0
- kssrag-0.2.0.dist-info/METADATA +840 -0
- kssrag-0.2.0.dist-info/RECORD +33 -0
- tests/test_bm25s.py +74 -0
- tests/test_config.py +42 -0
- tests/test_image_chunker.py +17 -0
- tests/test_integration.py +35 -0
- tests/test_ocr.py +142 -0
- tests/test_streaming.py +41 -0
- kssrag-0.1.1.dist-info/METADATA +0 -407
- kssrag-0.1.1.dist-info/RECORD +0 -25
- {kssrag-0.1.1.dist-info → kssrag-0.2.0.dist-info}/WHEEL +0 -0
- {kssrag-0.1.1.dist-info → kssrag-0.2.0.dist-info}/entry_points.txt +0 -0
- {kssrag-0.1.1.dist-info → kssrag-0.2.0.dist-info}/top_level.txt +0 -0
kssrag/server.py
CHANGED
|
@@ -1,8 +1,12 @@
|
|
|
1
1
|
from fastapi import FastAPI, HTTPException
|
|
2
2
|
from fastapi.middleware.cors import CORSMiddleware
|
|
3
|
+
from fastapi.responses import StreamingResponse
|
|
3
4
|
from pydantic import BaseModel
|
|
4
5
|
from typing import Dict, Any, Optional, List
|
|
5
6
|
import uuid
|
|
7
|
+
import json
|
|
8
|
+
|
|
9
|
+
from kssrag.models.openrouter import OpenRouterLLM
|
|
6
10
|
|
|
7
11
|
from .core.agents import RAGAgent
|
|
8
12
|
from .utils.helpers import logger
|
|
@@ -12,6 +16,10 @@ class QueryRequest(BaseModel):
|
|
|
12
16
|
query: str
|
|
13
17
|
session_id: Optional[str] = None
|
|
14
18
|
|
|
19
|
+
class StreamResponse(BaseModel):
    """Pydantic model with a ``chunk`` string and a ``done`` flag."""

    # Text carried by this message (required).
    chunk: str
    # Completion marker; False unless explicitly set.
    done: bool = False
|
|
22
|
+
|
|
15
23
|
class ServerConfig(BaseModel):
|
|
16
24
|
"""Configuration for the FastAPI server"""
|
|
17
25
|
host: str = config.SERVER_HOST
|
|
@@ -80,6 +88,47 @@ def create_app(rag_agent: RAGAgent, server_config: Optional[ServerConfig] = None
|
|
|
80
88
|
logger.error(f"Error handling query: {str(e)}")
|
|
81
89
|
raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
|
|
82
90
|
|
|
91
|
+
@app.post("/stream")
async def stream_query(request: QueryRequest):
    """Stream the LLM output for a query as Server-Sent Events.

    Every event is a ``data:`` line carrying a JSON object. Normal events
    look like ``{"chunk": <text>, "done": false}``; the stream ends with
    ``{"chunk": "", "done": true}``. If streaming fails part-way, a final
    ``{"error": <message>, "done": true}`` event is sent instead.

    Args:
        request: Query payload; ``session_id`` is optional and a fresh UUID
            is generated when it is missing.

    Raises:
        HTTPException: 400 when the query is blank, 500 when the stream
            cannot be set up.
    """
    query = request.query
    session_id = request.session_id or str(uuid.uuid4())

    if not query.strip():
        raise HTTPException(status_code=400, detail="Query cannot be empty")

    try:
        if session_id not in sessions:
            sessions[session_id] = RAGAgent(
                retriever=rag_agent.retriever,
                llm=OpenRouterLLM(stream=True),
                system_prompt=rag_agent.system_prompt,
            )

        agent = sessions[session_id]

        def sse_event(payload):
            # One SSE event: a "data:" field terminated by a blank line.
            return f"data: {json.dumps(payload)}\n\n"

        # Deliberately a plain (non-async) generator: predict_stream is
        # consumed with an ordinary for-loop, so running it inside an async
        # generator would block the event loop between chunks. Starlette
        # iterates synchronous iterables in a worker thread.
        def generate():
            try:
                for chunk in agent.llm.predict_stream(agent._build_messages(query)):
                    yield sse_event({'chunk': chunk, 'done': False})
                yield sse_event({'chunk': '', 'done': True})
            except Exception as e:
                # Headers are already sent at this point, so the failure is
                # reported in-band as a terminal event.
                logger.error(f"Streaming error: {str(e)}")
                yield sse_event({'error': str(e), 'done': True})

        return StreamingResponse(
            generate(),
            # SSE clients (e.g. EventSource) require this MIME type.
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
            },
        )

    except Exception as e:
        logger.error(f"Streaming query failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Streaming error: {str(e)}")
|
|
131
|
+
|
|
83
132
|
@app.get("/health")
|
|
84
133
|
async def health_check():
|
|
85
134
|
"""Health check endpoint"""
|
kssrag/utils/document_loaders.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import json
|
|
2
|
+
import os
|
|
2
3
|
from typing import List, Dict, Any, Optional
|
|
3
4
|
from ..utils.helpers import logger
|
|
4
5
|
|
|
@@ -20,15 +21,92 @@ def load_json_file(file_path: str) -> Any:
|
|
|
20
21
|
logger.error(f"Failed to load JSON file: {str(e)}")
|
|
21
22
|
raise
|
|
22
23
|
|
|
24
|
+
def load_docx_file(file_path: str) -> str:
    """Load text from a DOCX file.

    Collects every non-blank paragraph, then every non-blank table cell
    (table by table, row by row), one entry per line.

    Args:
        file_path: Path to the ``.docx`` file.

    Returns:
        The extracted text with surrounding whitespace stripped.

    Raises:
        ImportError: If python-docx is not installed.
        Exception: Any error raised while reading the document is logged
            and re-raised unchanged.
    """
    # Guard only the import so that an unrelated ImportError raised while
    # parsing is not misreported as a missing dependency.
    try:
        from docx import Document
    except ImportError as exc:
        raise ImportError(
            "python-docx is required for DOCX support. Install with: pip install kssrag[office]"
        ) from exc

    try:
        doc = Document(file_path)
        lines = [
            paragraph.text
            for paragraph in doc.paragraphs
            if paragraph.text.strip()
        ]

        # Extract text from tables
        for table in doc.tables:
            for row in table.rows:
                lines.extend(
                    cell.text for cell in row.cells if cell.text.strip()
                )

        return "\n".join(lines).strip()
    except Exception as e:
        logger.error(f"Failed to load DOCX file: {str(e)}")
        raise
|
|
47
|
+
|
|
48
|
+
def load_excel_file(file_path: str) -> str:
|
|
49
|
+
"""Load text from Excel file"""
|
|
50
|
+
try:
|
|
51
|
+
import openpyxl
|
|
52
|
+
workbook = openpyxl.load_workbook(file_path)
|
|
53
|
+
text = ""
|
|
54
|
+
|
|
55
|
+
for sheet_name in workbook.sheetnames:
|
|
56
|
+
sheet = workbook[sheet_name]
|
|
57
|
+
text += f"Sheet: {sheet_name}\n"
|
|
58
|
+
|
|
59
|
+
for row in sheet.iter_rows(values_only=True):
|
|
60
|
+
row_text = " | ".join(str(cell) if cell is not None else "" for cell in row)
|
|
61
|
+
if row_text.strip():
|
|
62
|
+
text += row_text + "\n"
|
|
63
|
+
text += "\n"
|
|
64
|
+
|
|
65
|
+
return text.strip()
|
|
66
|
+
except ImportError:
|
|
67
|
+
raise ImportError("openpyxl is required for Excel support. Install with: pip install kssrag[office]")
|
|
68
|
+
except Exception as e:
|
|
69
|
+
logger.error(f"Failed to load Excel file: {str(e)}")
|
|
70
|
+
raise
|
|
71
|
+
|
|
72
|
+
def load_pptx_file(file_path: str) -> str:
    """Load text from a PowerPoint file.

    Each slide is rendered as a ``Slide <n>:`` header (numbered from 1)
    followed by the text of every shape that exposes a non-blank ``text``
    attribute, then a blank line.

    Args:
        file_path: Path to the ``.pptx`` file.

    Returns:
        The extracted text with surrounding whitespace stripped.

    Raises:
        ImportError: If python-pptx is not installed.
        Exception: Any error raised while reading the presentation is
            logged and re-raised unchanged.
    """
    # Guard only the import so that an unrelated ImportError raised while
    # parsing is not misreported as a missing dependency.
    try:
        from pptx import Presentation
    except ImportError as exc:
        raise ImportError(
            "python-pptx is required for PowerPoint support. Install with: pip install kssrag[office]"
        ) from exc

    try:
        prs = Presentation(file_path)
        parts = []

        for slide_number, slide in enumerate(prs.slides, 1):
            parts.append(f"Slide {slide_number}:\n")
            parts.extend(
                shape.text + "\n"
                for shape in slide.shapes
                if hasattr(shape, "text") and shape.text.strip()
            )
            parts.append("\n")

        return "".join(parts).strip()
    except Exception as e:
        logger.error(f"Failed to load PowerPoint file: {str(e)}")
        raise
|
|
94
|
+
|
|
23
95
|
def load_document(file_path: str) -> str:
    """Load a document, choosing the loader from the file extension.

    The extension is matched case-insensitively, so ``REPORT.DOCX`` is
    handled the same way as ``report.docx``.

    Args:
        file_path: Path ending in ``.txt``, ``.docx``, ``.xlsx``, ``.xls``
            or ``.pptx``.

    Returns:
        The text produced by the matching loader.

    Raises:
        ValueError: If the extension is not one of the supported types.
    """
    # Lower-case only for matching; the loaders receive the original path.
    suffix_source = file_path.lower()

    if suffix_source.endswith('.txt'):
        return load_txt_file(file_path)
    if suffix_source.endswith('.docx'):
        return load_docx_file(file_path)
    if suffix_source.endswith(('.xlsx', '.xls')):
        return load_excel_file(file_path)
    if suffix_source.endswith('.pptx'):
        return load_pptx_file(file_path)
    raise ValueError(f"Unsupported file type: {file_path}")
|
|
29
107
|
|
|
30
108
|
def load_json_documents(file_path: str, metadata_field: str = "name") -> List[Dict[str, Any]]:
|
|
31
|
-
"""Load documents from JSON file
|
|
109
|
+
"""Load documents from JSON file"""
|
|
32
110
|
data = load_json_file(file_path)
|
|
33
111
|
|
|
34
112
|
# Apply limit for testing if specified
|
kssrag/utils/helpers.py
CHANGED
|
@@ -8,18 +8,81 @@ logging.basicConfig(
|
|
|
8
8
|
)
|
|
9
9
|
logger = logging.getLogger("KSSRAG")
|
|
10
10
|
|
|
11
|
-
#
|
|
12
|
-
|
|
13
|
-
|
|
11
|
+
# Initialize as None - will be set when actually needed
|
|
12
|
+
FAISS_AVAILABLE = None
|
|
13
|
+
FAISS_AVX_TYPE = None
|
|
14
14
|
|
|
15
|
-
def
|
|
16
|
-
"""
|
|
17
|
-
|
|
15
|
+
def setup_faiss(vector_store_type: str = None):
    """Handle FAISS initialization - only when explicitly called.

    FAISS is imported only for the ``"faiss"`` and ``"hybrid_online"``
    store types. The SWIG module variants are tried from the most to the
    least specialised build, and the public names of the first one that
    imports are copied into this module's globals.

    The outcome of a real load attempt (success or failure) is cached in
    ``FAISS_AVAILABLE`` / ``FAISS_AVX_TYPE``. A call for a store type that
    does not need FAISS is *not* cached; otherwise it would make every
    later ``setup_faiss("faiss")`` return the stale "not loaded" answer
    without ever trying to import FAISS.

    Args:
        vector_store_type: Name of the vector store being initialised.

    Returns:
        A ``(faiss_available, faiss_avx_type)`` tuple; the second item is
        the label of the variant that loaded, or ``"not_loaded"``.
    """
    global FAISS_AVAILABLE, FAISS_AVX_TYPE

    # A previous load attempt already settled the answer.
    if FAISS_AVAILABLE is not None:
        return FAISS_AVAILABLE, FAISS_AVX_TYPE

    # Only load FAISS if explicitly using FAISS-based stores
    if vector_store_type not in ("faiss", "hybrid_online"):
        logger.debug(f"Skipping FAISS initialization for vector store: {vector_store_type}")
        # Intentionally not cached - see docstring.
        return False, "not_loaded"

    # Imported here so the function does not depend on the module's
    # top-level import block.
    import importlib

    faiss_available = False
    faiss_avx_type = "not_loaded"

    try:
        # Try different FAISS versions in order of preference
        faiss_import_attempts = [
            ("AVX512-SPR", "faiss.swigfaiss_avx512_spr"),
            ("AVX512", "faiss.swigfaiss_avx512"),
            ("AVX2", "faiss.swigfaiss_avx2"),
            ("Standard", "faiss.swigfaiss"),
        ]

        for avx_type, import_path in faiss_import_attempts:
            try:
                logger.info(f"Loading faiss with {avx_type} support.")
                faiss_module = importlib.import_module(import_path)
                # Make the FAISS symbols available globally
                globals().update({
                    name: getattr(faiss_module, name)
                    for name in dir(faiss_module)
                    if not name.startswith('_')
                })

                faiss_available = True
                faiss_avx_type = avx_type
                logger.info(f"Successfully loaded faiss with {avx_type} support.")
                break

            except ImportError as e:
                logger.debug(f"Could not load library with {avx_type} support: {e}")
                continue

        if not faiss_available:
            logger.warning("Could not load any FAISS version. FAISS-based vector stores will be disabled.")

    except Exception as e:
        logger.error(f"Failed to initialize FAISS: {str(e)}")
        faiss_available = False

    # Cache the results of the actual attempt
    FAISS_AVAILABLE = faiss_available
    FAISS_AVX_TYPE = faiss_avx_type

    return faiss_available, faiss_avx_type
|
|
68
|
+
|
|
69
|
+
def validate_config():
    """Check the configuration and warn about a missing OpenRouter key.

    FAISS is deliberately not loaded here; that is left to the vector
    stores. Always returns True, including when the config module cannot
    be imported.
    """
    try:
        from ..config import config as settings

        api_key_missing = not settings.OPENROUTER_API_KEY
        if api_key_missing:
            logger.warning("OPENROUTER_API_KEY not set. LLM functionality will not work.")
    except ImportError:
        # Config not available, continue anyway
        pass

    return True
|
|
82
|
+
|
|
83
|
+
# Your signature in the code
|
|
84
|
+
def kss_signature():
    """Return the author's signature string."""
    signature = "Built with HATE by Ksschkw (github.com/Ksschkw)"
    return signature
|
|
23
86
|
|
|
24
87
|
def import_custom_component(import_path: str):
|
|
25
88
|
"""Import a custom component from a string path"""
|
|
@@ -31,25 +94,5 @@ def import_custom_component(import_path: str):
|
|
|
31
94
|
logger.error(f"Failed to import custom component {import_path}: {str(e)}")
|
|
32
95
|
raise
|
|
33
96
|
|
|
34
|
-
#
|
|
35
|
-
#
|
|
36
|
-
# from .utils.helpers import logger
|
|
37
|
-
|
|
38
|
-
# def setup_faiss():
|
|
39
|
-
# """Handle FAISS initialization with proper error handling"""
|
|
40
|
-
# try:
|
|
41
|
-
# # Try to load with AVX2 support first
|
|
42
|
-
# logger.info("Loading faiss with AVX2 support.")
|
|
43
|
-
# from faiss.swigfaiss_avx2 import *
|
|
44
|
-
# logger.info("Successfully loaded faiss with AVX2 support.")
|
|
45
|
-
# return True
|
|
46
|
-
# except ImportError as e:
|
|
47
|
-
# logger.info(f"Could not load library with AVX2 support due to:\n{repr(e)}")
|
|
48
|
-
# logger.info("Falling back to standard FAISS without AVX2 support")
|
|
49
|
-
# try:
|
|
50
|
-
# from faiss.swigfaiss import *
|
|
51
|
-
# logger.info("Successfully loaded standard faiss.")
|
|
52
|
-
# return False
|
|
53
|
-
# except ImportError as e:
|
|
54
|
-
# logger.error(f"Failed to load any FAISS version: {repr(e)}")
|
|
55
|
-
# raise
|
|
97
|
+
# Remove the auto-initialization at module level
|
|
98
|
+
# FAISS will now only load when explicitly called by vector stores that need it
|
kssrag/utils/ocr.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
"""
|
|
2
|
+
OCR utilities for KSS RAG.
|
|
3
|
+
Requires extra dependencies: `paddleocr`, `paddlepaddle`, `pytesseract`, `Pillow`.
|
|
4
|
+
Install via: pip install kssrag[ocr]
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
try:
|
|
8
|
+
import pytesseract
|
|
9
|
+
from paddleocr import PaddleOCR
|
|
10
|
+
from PIL import Image
|
|
11
|
+
except ImportError as e:
|
|
12
|
+
raise ImportError(
|
|
13
|
+
"OCR functionality requires extra dependencies. "
|
|
14
|
+
"Install with: pip install kssrag[ocr]"
|
|
15
|
+
) from e
|
|
16
|
+
|
|
17
|
+
# Module-level PaddleOCR engine (handwritten text), used by ocr_paddle.
# It is constructed as soon as this module is imported.
_paddle_ocr = PaddleOCR(
    use_angle_cls=True,
    lang="en",
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def ocr_tesseract(image_path: str) -> str:
    """OCR for typed text using Tesseract.

    Args:
        image_path: Path of the image to read.

    Returns:
        The recognised text with surrounding whitespace stripped.
    """
    # Context manager closes the image's file handle once OCR is done.
    with Image.open(image_path) as img:
        text = pytesseract.image_to_string(img)
    return text.strip()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def ocr_paddle(image_path: str) -> str:
    """OCR for handwritten text using PaddleOCR.

    Args:
        image_path: Path of the image to read.

    Returns:
        The recognised text fragments joined by single spaces, or an empty
        string when nothing was recognised.
    """
    results = _paddle_ocr.ocr(image_path, cls=True)
    fragments = []
    # Skip a missing result or an empty/None page instead of failing while
    # unpacking it; OCRLoader.ocr_paddle applies the same kind of guard.
    for page in results or []:
        if not page:
            continue
        for _, (txt, _) in page:
            fragments.append(txt)
    return " ".join(fragments).strip()
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def extract_text_from_image(image_path: str, mode: str = "typed") -> str:
    """Run OCR on an image with the engine selected by ``mode``.

    ``'typed'`` uses Tesseract and ``'handwritten'`` uses PaddleOCR; any
    other value raises ValueError.
    """
    if mode == "typed":
        return ocr_tesseract(image_path)
    if mode == "handwritten":
        return ocr_paddle(image_path)
    raise ValueError("Invalid OCR mode. Choose 'typed' or 'handwritten'.")
|
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import cv2
|
|
3
|
+
import pytesseract
|
|
4
|
+
from paddleocr import PaddleOCR
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from PIL import Image
|
|
7
|
+
from .helpers import logger
|
|
8
|
+
|
|
9
|
+
class OCRLoader:
    """Production OCR handler with PaddleOCR (handwritten) and Tesseract (typed)"""

    def __init__(self):
        # Stays None when PaddleOCR cannot be initialised; ocr_paddle then
        # reports that, while ocr_tesseract does not use this attribute.
        self.paddle_ocr = None
        self._initialize_paddle_ocr()

    def _initialize_paddle_ocr(self):
        """Initialize PaddleOCR, preferring bundled model directories.

        The custom detection/recognition directories are used only when
        both contain an ``inference.yml``; otherwise PaddleOCR is created
        with its defaults. If anything fails, one default initialisation
        is retried, and if that fails too ``self.paddle_ocr`` is left as
        None rather than raising from the constructor.
        """
        try:
            # Try to use custom model directories first
            models_root = Path(__file__).parent.parent / 'paddle_models' / 'models'
            det_model_dir = str(models_root / 'ppocrv5_server_det')
            rec_model_dir = str(models_root / 'ppocrv5_server_rec')

            # Create directories if they don't exist
            os.makedirs(det_model_dir, exist_ok=True)
            os.makedirs(rec_model_dir, exist_ok=True)

            # Check if custom directories have the required files
            custom_dirs_valid = (
                os.path.exists(os.path.join(det_model_dir, 'inference.yml')) and
                os.path.exists(os.path.join(rec_model_dir, 'inference.yml'))
            )

            if custom_dirs_valid:
                self.paddle_ocr = PaddleOCR(
                    det_model_dir=det_model_dir,
                    rec_model_dir=rec_model_dir,
                    use_angle_cls=True,
                    lang="en"
                )
                logger.info("PaddleOCR initialized successfully with custom model directories")
            else:
                logger.info("Custom model directories not found, using default PaddleOCR initialization")
                self.paddle_ocr = PaddleOCR(use_angle_cls=True, lang="en")
                logger.info("PaddleOCR initialized successfully with default directories")

        except Exception as e:
            logger.warning(f"PaddleOCR initialization failed: {str(e)}. Using default initialization.")
            # Fallback to default initialization. Guarded so that a broken
            # PaddleOCR install does not make the whole loader unusable.
            try:
                self.paddle_ocr = PaddleOCR(use_angle_cls=True, lang="en")
            except Exception as fallback_error:
                logger.error(f"PaddleOCR initialization failed: {str(fallback_error)}")
                self.paddle_ocr = None

    def ocr_tesseract(self, image_path: str) -> str:
        """OCR for typed text using Tesseract with error handling.

        Returns:
            The recognised text, stripped; may be empty (a warning is logged).

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
            RuntimeError: If opening the image or running Tesseract fails.
        """
        try:
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"Image file not found: {image_path}")

            # Context manager closes the image's file handle after OCR.
            with Image.open(image_path) as img:
                text = pytesseract.image_to_string(img)

            text = text.strip()
            if not text:
                logger.warning(f"Tesseract extracted no text from {image_path}")

            return text

        except FileNotFoundError:
            # Re-raise FileNotFoundError directly
            raise
        except Exception as e:
            logger.error(f"Tesseract OCR failed for {image_path}: {str(e)}")
            raise RuntimeError(f"Tesseract OCR failed: {str(e)}") from e

    def ocr_paddle(self, image_path: str) -> str:
        """OCR for handwritten text using PaddleOCR with error handling.

        Returns:
            The recognised text fragments joined by spaces; may be empty
            (a warning is logged).

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
            RuntimeError: If PaddleOCR is not initialised, the image cannot
                be read, or recognition fails.
        """
        if self.paddle_ocr is None:
            raise RuntimeError("PaddleOCR not initialized. OCR functionality unavailable.")

        try:
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"Image file not found: {image_path}")

            img = cv2.imread(image_path)
            if img is None:
                raise ValueError(f"Could not read image at {image_path}")

            result = self.paddle_ocr.ocr(img, cls=True)
            lines = []

            # Only the first entry of the result is read; each item in it
            # is expected to hold the text at index [1][0].
            if result and result[0]:
                for line in result[0]:
                    if line and len(line) >= 2:
                        text_content = line[1][0] if isinstance(line[1], (list, tuple)) and len(line[1]) > 0 else ""
                        if text_content:
                            lines.append(text_content)

            extracted_text = " ".join(lines).strip()

            if not extracted_text:
                logger.warning(f"PaddleOCR extracted no text from {image_path}")

            return extracted_text

        except FileNotFoundError:
            # Re-raise FileNotFoundError directly
            raise
        except Exception as e:
            logger.error(f"PaddleOCR failed for {image_path}: {str(e)}")
            raise RuntimeError(f"PaddleOCR failed: {str(e)}") from e

    def extract_text(self, image_path: str, mode: str = "typed") -> str:
        """Extract text from image using specified OCR engine.

        Args:
            image_path: Path of the image to read.
            mode: ``"typed"`` (Tesseract) or ``"handwritten"`` (PaddleOCR).

        Raises:
            ValueError: If ``mode`` is neither of the two supported values.
        """
        if mode not in ["typed", "handwritten"]:
            raise ValueError(f"Invalid OCR mode: {mode}. Must be 'typed' or 'handwritten'")

        if mode == "handwritten":
            return self.ocr_paddle(image_path)
        else:  # typed
            return self.ocr_tesseract(image_path)
|