ragbandit-core 0.1.1__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragbandit/__init__.py +26 -0
- ragbandit/config/__init__.py +3 -0
- ragbandit/config/llms.py +34 -0
- ragbandit/config/pricing.py +38 -0
- ragbandit/documents/__init__.py +66 -0
- ragbandit/documents/chunkers/__init__.py +18 -0
- ragbandit/documents/chunkers/base_chunker.py +201 -0
- ragbandit/documents/chunkers/fixed_size_chunker.py +174 -0
- ragbandit/documents/chunkers/semantic_chunker.py +205 -0
- ragbandit/documents/document_pipeline.py +350 -0
- ragbandit/documents/embedders/__init__.py +14 -0
- ragbandit/documents/embedders/base_embedder.py +82 -0
- ragbandit/documents/embedders/mistral_embedder.py +129 -0
- ragbandit/documents/ocr/__init__.py +13 -0
- ragbandit/documents/ocr/base_ocr.py +136 -0
- ragbandit/documents/ocr/mistral_ocr.py +147 -0
- ragbandit/documents/processors/__init__.py +16 -0
- ragbandit/documents/processors/base_processor.py +88 -0
- ragbandit/documents/processors/footnotes_processor.py +353 -0
- ragbandit/documents/processors/references_processor.py +408 -0
- ragbandit/documents/utils/__init__.py +11 -0
- ragbandit/documents/utils/secure_file_handler.py +95 -0
- ragbandit/prompt_tools/__init__.py +27 -0
- ragbandit/prompt_tools/footnotes_processor_tools.py +195 -0
- ragbandit/prompt_tools/prompt_tool.py +118 -0
- ragbandit/prompt_tools/references_processor_tools.py +31 -0
- ragbandit/prompt_tools/semantic_chunker_tools.py +56 -0
- ragbandit/schema.py +206 -0
- ragbandit/utils/__init__.py +19 -0
- ragbandit/utils/in_memory_log_handler.py +33 -0
- ragbandit/utils/llm_utils.py +188 -0
- ragbandit/utils/mistral_client.py +76 -0
- ragbandit/utils/token_usage_tracker.py +220 -0
- ragbandit_core-0.1.1.dist-info/METADATA +145 -0
- ragbandit_core-0.1.1.dist-info/RECORD +38 -0
- ragbandit_core-0.1.1.dist-info/WHEEL +5 -0
- ragbandit_core-0.1.1.dist-info/licenses/LICENSE.md +9 -0
- ragbandit_core-0.1.1.dist-info/top_level.txt +1 -0
ragbandit/documents/embedders/mistral_embedder.py
@@ -0,0 +1,129 @@
+from ragbandit.utils import mistral_client_manager
+from ragbandit.utils.token_usage_tracker import TokenUsageTracker
+from ragbandit.documents.embedders.base_embedder import BaseEmbedder
+from ragbandit.schema import (
+    ChunkingResult,
+    EmbeddingResult,
+    ChunkWithEmbedding,
+)
+from datetime import datetime, timezone
+
+
+class MistralEmbedder(BaseEmbedder):
+    """Document embedder that uses Mistral AI's embedding models."""
+
+    def __init__(
+        self,
+        api_key: str,
+        model: str = "mistral-embed",
+        name: str = None,
+    ):
+        """
+        Initialize the Mistral embedder.
+
+        Args:
+            api_key: Mistral API key
+            model: Embedding model to use
+            name: Optional name for the embedder
+        """
+        super().__init__(name)
+        self.model = model
+
+        # Initialize the Mistral client
+        self.client = mistral_client_manager.get_client(api_key)
+
+        self.logger.info(
+            f"Initialized MistralEmbedder with model {self.model}"
+        )
+
+    # ------------------------------------------------------------------
+    # Public API
+    def embed_chunks(
+        self,
+        chunk_result: ChunkingResult,
+        usage_tracker: TokenUsageTracker | None = None,
+    ) -> EmbeddingResult:
+        """Orchestrate the Mistral embedding flow."""
+
+        chunks = chunk_result.chunks
+        if not chunks:
+            self.logger.warning("No chunks to embed")
+            return self._empty_embedding_result()
+
+        try:
+            texts = self._extract_texts(chunks)
+            response = self._call_mistral_embeddings(texts)
+            return self._build_embedding_result(
+                chunks, response, usage_tracker
+            )
+        except Exception as e:
+            self.logger.error(f"Error generating embeddings: {e}")
+            return self._empty_embedding_result(chunks)
+
+    # ------------------------------------------------------------------
+    # Helpers
+    def _extract_texts(
+        self,
+        chunks: list[ChunkWithEmbedding | object],
+    ) -> list[str]:
+        """Extract raw text from chunks."""
+        return [c.text for c in chunks]
+
+    def _call_mistral_embeddings(self, texts: list[str]):
+        """Call the Mistral embeddings API and return the raw response."""
+        self.logger.info("Requesting embeddings from Mistral API")
+        return self.client.embeddings.create(model=self.model, inputs=texts)
+
+    def _build_embedding_result(
+        self,
+        chunks,
+        response,
+        usage_tracker: TokenUsageTracker | None = None,
+    ) -> EmbeddingResult:
+        """Convert API response into EmbeddingResult."""
+
+        embedded_chunks: list[ChunkWithEmbedding] = []
+        for i, data in enumerate(response.data):
+            embedded_chunks.append(
+                ChunkWithEmbedding(
+                    text=chunks[i].text,
+                    metadata=chunks[i].metadata,
+                    embedding=list(data.embedding),
+                    embedding_model=self.model,
+                )
+            )
+
+        if usage_tracker:
+            usage_tracker.add_embedding_tokens(
+                response.usage.prompt_tokens, self.model
+            )
+
+        return EmbeddingResult(
+            processed_at=datetime.now(timezone.utc),
+            chunks_with_embeddings=embedded_chunks,
+            model_name=self.model,
+            metrics=usage_tracker.get_summary() if usage_tracker else None,
+        )
+
+    def _empty_embedding_result(
+        self, chunks: list | None = None
+    ) -> EmbeddingResult:
+        """Return an EmbeddingResult with no embeddings (error fallback)."""
+
+        empty_embeds: list[ChunkWithEmbedding] = []
+        for c in (chunks or []):
+            empty_embeds.append(
+                ChunkWithEmbedding(
+                    text=c.text if hasattr(c, "text") else "",
+                    metadata=c.metadata if hasattr(c, "metadata") else None,
+                    embedding=[],
+                    embedding_model=self.model,
+                )
+            )
+
+        return EmbeddingResult(
+            processed_at=None,
+            chunks_with_embeddings=empty_embeds,
+            model_name=self.model,
+            metrics=None,
+        )

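For orientation, a minimal usage sketch of the embedder above. It assumes a MISTRAL_API_KEY environment variable and a chunk_result (a ChunkingResult) produced upstream by one of the chunkers shipped in ragbandit/documents/chunkers; neither assumption is part of this file.

import os

from ragbandit.documents.embedders.mistral_embedder import MistralEmbedder

# Assumed to exist: a ChunkingResult produced by a chunker earlier in the
# pipeline (e.g. FixedSizeChunker or SemanticChunker from this package).
chunk_result = ...

embedder = MistralEmbedder(api_key=os.environ["MISTRAL_API_KEY"])
result = embedder.embed_chunks(chunk_result)

# Each item pairs the original chunk text/metadata with its embedding vector.
for chunk in result.chunks_with_embeddings:
    print(len(chunk.embedding), chunk.text[:40])
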
ragbandit/documents/ocr/__init__.py
@@ -0,0 +1,13 @@
+"""
+OCR (Optical Character Recognition) implementations for document processing.
+
+This module provides OCR processors that convert document images to text.
+"""
+
+from ragbandit.documents.ocr.base_ocr import BaseOCR
+from ragbandit.documents.ocr.mistral_ocr import MistralOCRDocument
+
+__all__ = [
+    "BaseOCR",
+    "MistralOCRDocument"
+]

ragbandit/documents/ocr/base_ocr.py
@@ -0,0 +1,136 @@
+import logging
+import os
+from abc import ABC, abstractmethod
+from pathlib import Path
+from io import BytesIO, BufferedReader
+from ragbandit.documents.utils.secure_file_handler import SecureFileHandler
+from ragbandit.schema import OCRResult
+
+
+class BaseOCR(ABC):
+    """Base class for OCR document processing.
+
+    This class provides the interface for OCR processing and a default
+    implementation using Mistral's OCR API.
+    """
+
+    def __init__(self, logger: logging.Logger = None, **kwargs):
+        """Initialize the OCR processor.
+
+        Args:
+            logger: Optional logger for OCR events
+            **kwargs: Additional keyword arguments (e.g., encryption_key)
+        """
+        self.logger = logger or logging.getLogger(__name__)
+        self.kwargs = kwargs
+
+    def validate_pdf(self, pdf_filepath: str) -> str:
+        """Validate that a PDF file exists.
+
+        Args:
+            pdf_filepath: Path to the PDF file to validate
+
+        Returns:
+            str: The basename of the file
+
+        Raises:
+            ValueError: If the file does not exist
+        """
+        file_name = os.path.basename(pdf_filepath)
+        pdf_file_exists = os.path.isfile(pdf_filepath)
+
+        if not pdf_file_exists:
+            self.logger.error(f"PDF file {pdf_filepath} not found")
+            raise ValueError(f"PDF file {pdf_filepath} not found")
+
+        return file_name
+
+    def read_encrypted_file(self, pdf_filepath: str) -> BufferedReader:
+        """Read an encrypted PDF file and return a buffered reader.
+
+        Args:
+            pdf_filepath: Path to the encrypted PDF file
+
+        Returns:
+            BufferedReader: A buffered reader for the decrypted file content
+
+        Raises:
+            ValueError: If encryption_key is not provided in kwargs
+        """
+        self.logger.info("Decrypting for OCR...")
+
+        encryption_key = self.kwargs.get("encryption_key")
+        if not encryption_key:
+            raise ValueError(
+                "encryption_key must be provided in kwargs "
+                "for encrypted file operations. "
+                "Pass encryption_key when initializing the OCR processor."
+            )
+
+        secure_handler = SecureFileHandler(encryption_key=encryption_key)
+        decrypted = secure_handler.read_encrypted_file(Path(pdf_filepath))
+        raw = BytesIO(decrypted)
+        raw.seek(0)
+        return BufferedReader(raw)
+
+    def read_unencrypted_file(self, pdf_filepath: str) -> BufferedReader:
+        """Read an unencrypted PDF file and return a buffered reader.
+
+        Args:
+            pdf_filepath: Path to the unencrypted PDF file
+
+        Returns:
+            BufferedReader: A buffered reader for the file content
+        """
+        self.logger.info("Reading file for OCR...")
+        with open(pdf_filepath, 'rb') as f:
+            content = f.read()
+
+        raw = BytesIO(content)
+        raw.seek(0)
+        return BufferedReader(raw)
+
+    def validate_and_prepare_file(
+        self, pdf_filepath: str, encrypted: bool = True
+    ) -> tuple[str, BufferedReader]:
+        """Validate and prepare a PDF file for OCR processing.
+
+        Args:
+            pdf_filepath: Path to the PDF file to process
+            encrypted: Whether the file is encrypted (default: True)
+
+        Returns:
+            tuple: (file_name, file_reader)
+
+        Raises:
+            ValueError: If the file does not exist
+        """
+        file_name = self.validate_pdf(pdf_filepath)
+
+        if encrypted:
+            reader = self.read_encrypted_file(pdf_filepath)
+        else:
+            reader = self.read_unencrypted_file(pdf_filepath)
+
+        return file_name, reader
+
+    @abstractmethod
+    def process(self, pdf_filepath: str) -> OCRResult:
+        """Process a PDF file through OCR.
+
+        Args:
+            pdf_filepath: Path to the PDF file to process
+
+        Returns:
+            OCRResult: The OCR result from the processor
+        """
+        raise NotImplementedError("Subclasses must implement process method")
+
+    # ----------------------------------------------------------------------
+    def __str__(self) -> str:
+        """Return a string representation of the OCR processor."""
+        return self.__class__.__name__
+
+    def __repr__(self) -> str:
+        """Return a string representation of the OCR processor."""
+        return f"{self.__class__.__name__}()"

ragbandit/documents/ocr/mistral_ocr.py
@@ -0,0 +1,147 @@
+from ragbandit.documents.ocr.base_ocr import BaseOCR
+from ragbandit.schema import (
+    OCRResult, OCRPage, PageDimensions,
+    Image, OCRUsageInfo, PagesProcessedMetrics
+)
+from ragbandit.config.pricing import OCR_MODEL_COSTS
+from mistralai.models.ocrresponse import OCRResponse
+from io import BufferedReader
+from datetime import datetime
+
+import logging
+from ragbandit.utils import mistral_client_manager
+
+
+class MistralOCRDocument(BaseOCR):
+    """OCR document processor using Mistral's API."""
+
+    # Explicit model identifier used for all Mistral OCR requests
+    MODEL_NAME = "mistral-ocr-latest"
+
+    def __init__(self, api_key: str, logger: logging.Logger = None, **kwargs):
+        """
+        Initialize the Mistral OCR processor.
+
+        Args:
+            api_key: Mistral API key
+            logger: Optional logger for OCR events
+            **kwargs: Additional keyword arguments
+                - encryption_key: Optional key for encrypted file operations
+        """
+        # Pass all kwargs to the base class
+        super().__init__(logger=logger, **kwargs)
+        self.client = mistral_client_manager.get_client(api_key)

+    # ----------------- Helper methods ----------------- #
+
+    def _upload_file(self, file_name: str, content: BufferedReader):
+        """Upload PDF to Mistral Cloud and return the uploaded file object."""
+        return self.client.files.upload(
+            file={"file_name": file_name, "content": content},
+            purpose="ocr",
+        )
+
+    def _get_signed_url(self, file_id: str) -> str:
+        """Retrieve a signed URL for the previously uploaded file."""
+        return self.client.files.get_signed_url(file_id=file_id).url
+
+    def _run_ocr(self, document_url: str) -> OCRResponse:
+        """Run Mistral OCR on the provided document URL."""
+        return self.client.ocr.process(
+            model=self.MODEL_NAME,
+            document={"type": "document_url", "document_url": document_url},
+            include_image_base64=True,
+        )
+
+    def _delete_file_with_retries(
+        self, file_id: str, max_tries: int = 10
+    ) -> None:
+        """Attempt to delete a file from Mistral Cloud with retries."""
+        while max_tries > 0:
+            resp = self.client.files.delete(file_id=file_id)
+            if resp.deleted:
+                self.logger.info("File deletion successful!")
+                return
+            max_tries -= 1
+        self.logger.error(f"Deleting unsuccessful. ID: {file_id}")
+
+    def _convert_pages(self, ocr_response: OCRResponse) -> list[OCRPage]:
+        """Convert OCRResponse pages to internal OCRPage schema."""
+        pages = []
+        for i, page in enumerate(ocr_response.pages):
+            images = [
+                Image.model_validate(img, from_attributes=True)
+                for img in (page.images or [])
+            ]
+            ocr_page = OCRPage(
+                index=i,
+                markdown=page.markdown,
+                images=images,
+                dimensions=PageDimensions.model_validate(
+                    page.dimensions, from_attributes=True
+                ),
+            )
+            pages.append(ocr_page)
+        return pages
+
+    def _build_usage_info(self, ocr_response: OCRResponse) -> OCRUsageInfo:
+        """Extract usage information from the OCR response."""
+        return OCRUsageInfo.model_validate(
+            ocr_response.usage_info, from_attributes=True
+        )
+
+    def _build_metrics(self, pages: list[OCRPage]) -> PagesProcessedMetrics:
+        """Create page-processing cost metrics."""
+        cost_per_page = OCR_MODEL_COSTS.get(self.MODEL_NAME, 0.0)
+        return PagesProcessedMetrics(
+            pages_processed=len(pages),
+            cost_per_page=cost_per_page,
+            total_cost_usd=len(pages) * cost_per_page,
+        )
+
+    def _build_result(
+        self,
+        pdf_filepath: str,
+        ocr_response: OCRResponse,
+        pages: list[OCRPage],
+        usage_info: OCRUsageInfo,
+        metrics: list[PagesProcessedMetrics],
+    ) -> OCRResult:
+        """Assemble the OCRResult object."""
+        return OCRResult(
+            source_file_path=pdf_filepath,
+            processed_at=datetime.now(),
+            model=ocr_response.model,
+            document_annotation=ocr_response.document_annotation,
+            pages=pages,
+            usage_info=usage_info,
+            metrics=metrics,
+        )
+
+    def process(self, pdf_filepath: str, encrypted: bool = False) -> OCRResult:
+        """High-level orchestration for running Mistral OCR on a PDF."""
+
+        file_name, reader = self.validate_and_prepare_file(
+            pdf_filepath, encrypted
+        )
+
+        uploaded = self._upload_file(file_name, reader)
+        del reader  # free memory
+
+        try:
+            doc_url = self._get_signed_url(uploaded.id)
+            ocr_resp = self._run_ocr(doc_url)
+        finally:
+            self._delete_file_with_retries(uploaded.id)
+
+        pages = self._convert_pages(ocr_resp)
+        usage_info = self._build_usage_info(ocr_resp)
+        metrics = [self._build_metrics(pages)]
+
+        return self._build_result(
+            pdf_filepath=pdf_filepath,
+            ocr_response=ocr_resp,
+            pages=pages,
+            usage_info=usage_info,
+            metrics=metrics,
+        )

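Likewise, a minimal usage sketch for the Mistral OCR processor above, assuming an unencrypted local PDF (the file name is illustrative) and a MISTRAL_API_KEY environment variable.

import os

from ragbandit.documents.ocr.mistral_ocr import MistralOCRDocument

ocr = MistralOCRDocument(api_key=os.environ["MISTRAL_API_KEY"])

# "paper.pdf" is a placeholder path; encrypted=False skips the
# SecureFileHandler decryption branch defined in BaseOCR.
result = ocr.process("paper.pdf", encrypted=False)

print(result.model, len(result.pages))
print(result.pages[0].markdown[:200])  # OCR output is returned as markdown
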
ragbandit/documents/processors/__init__.py
@@ -0,0 +1,16 @@
+"""
+Document processors for enhancing and transforming document content.
+
+This module provides various processors that can be applied to documents
+to extract, transform, or enhance their content.
+"""
+
+from ragbandit.documents.processors.base_processor import BaseProcessor
+from ragbandit.documents.processors.footnotes_processor import FootnoteProcessor  # noqa
+from ragbandit.documents.processors.references_processor import ReferencesProcessor  # noqa
+
+__all__ = [
+    "BaseProcessor",
+    "FootnoteProcessor",
+    "ReferencesProcessor"
+]

ragbandit/documents/processors/base_processor.py
@@ -0,0 +1,88 @@
+import logging
+from abc import ABC, abstractmethod
+from datetime import datetime, timezone
+from ragbandit.schema import OCRResult, ProcessingResult, ProcessedPage
+from ragbandit.utils.token_usage_tracker import TokenUsageTracker
+
+
+class BaseProcessor(ABC):
+    """
+    Base class or mix-in for individual processors.
+    Subclasses override `process()` and, optionally, `extend_response()`.
+    """
+
+    def __init__(self, name: str | None = None, api_key: str | None = None):
+        """
+        Initialize the processor.
+
+        Args:
+            name: Optional name for the processor
+            api_key: API key for LLM services
+        """
+        # Hierarchical names make it easy to filter later:
+        # pipeline.text_cleaner, pipeline.language_model, …
+        base = "pipeline"
+        self.logger = logging.getLogger(
+            f"{base}.{name or self.__class__.__name__}"
+        )
+        self.api_key = api_key
+
+    @abstractmethod
+    def process(
+        self,
+        document: OCRResult | ProcessingResult,
+        usage_tracker: TokenUsageTracker | None = None,
+    ) -> ProcessingResult:
+        """
+        Do one step of work and return:
+          * a (possibly modified) ProcessingResult
+          * a dict of metadata specific to this processor
+
+        Args:
+            response: The OCR or intermediate processing result to process.
+                This can be either an `OCRResult` (raw OCR output) or
+                a `ProcessingResult` (output of a previous processor).
+            usage_tracker: Optional token usage tracker
+        """
+        raise NotImplementedError
+
+    # ----------------------------------------------------------------------
+    def __str__(self) -> str:
+        """Return a string representation of the processor."""
+        return self.__class__.__name__
+
+    def __repr__(self) -> str:
+        """Return a string representation of the processor."""
+        return f"{self.__class__.__name__}()"
+
+    # ----------------------------------------------------------------------
+    # Utility helpers
+    @staticmethod
+    def ensure_processing_result(
+        document: OCRResult | ProcessingResult,
+        processor_name: str = "bootstrap",
+    ) -> ProcessingResult:
+        """Ensure the incoming `document` is a `ProcessingResult`.
+
+        If an `OCRResult` is supplied (as is the case for the first processor
+        in a pipeline), it will be converted to a shallow `ProcessingResult` so
+        that downstream logic can assume a consistent type.
+        """
+
+        # Always create a fresh ProcessingResult so that timestamps, metrics,
+        # and extracted data do not roll over between processors.
+
+        source_pages = document.pages if hasattr(document, "pages") else []
+
+        pages_processed = [
+            ProcessedPage(**page.model_dump()) for page in source_pages
+        ]
+
+        return ProcessingResult(
+            processor_name=processor_name,
+            processed_at=datetime.now(timezone.utc),
+            pages=pages_processed,
+            processing_trace=[],
+            extracted_data={},
+            metrics=None,
+        )
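
To make the contract above concrete, a sketch of what a subclass might look like. The UppercaseProcessor name and its behaviour are purely illustrative and assume that ProcessedPage exposes the markdown field carried over from OCRPage.

from ragbandit.documents.processors.base_processor import BaseProcessor
from ragbandit.schema import OCRResult, ProcessingResult
from ragbandit.utils.token_usage_tracker import TokenUsageTracker


class UppercaseProcessor(BaseProcessor):
    """Toy processor that upper-cases the markdown of every page."""

    def process(
        self,
        document: OCRResult | ProcessingResult,
        usage_tracker: TokenUsageTracker | None = None,
    ) -> ProcessingResult:
        # Normalise the input so the rest of the method always sees a
        # ProcessingResult, regardless of where it sits in the pipeline.
        result = self.ensure_processing_result(document, processor_name="uppercase")
        for page in result.pages:
            page.markdown = page.markdown.upper()  # assumes a markdown field
        return result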