ragbandit_core-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. ragbandit/__init__.py +26 -0
  2. ragbandit/config/__init__.py +3 -0
  3. ragbandit/config/llms.py +34 -0
  4. ragbandit/config/pricing.py +38 -0
  5. ragbandit/documents/__init__.py +66 -0
  6. ragbandit/documents/chunkers/__init__.py +18 -0
  7. ragbandit/documents/chunkers/base_chunker.py +201 -0
  8. ragbandit/documents/chunkers/fixed_size_chunker.py +174 -0
  9. ragbandit/documents/chunkers/semantic_chunker.py +205 -0
  10. ragbandit/documents/document_pipeline.py +350 -0
  11. ragbandit/documents/embedders/__init__.py +14 -0
  12. ragbandit/documents/embedders/base_embedder.py +82 -0
  13. ragbandit/documents/embedders/mistral_embedder.py +129 -0
  14. ragbandit/documents/ocr/__init__.py +13 -0
  15. ragbandit/documents/ocr/base_ocr.py +136 -0
  16. ragbandit/documents/ocr/mistral_ocr.py +147 -0
  17. ragbandit/documents/processors/__init__.py +16 -0
  18. ragbandit/documents/processors/base_processor.py +88 -0
  19. ragbandit/documents/processors/footnotes_processor.py +353 -0
  20. ragbandit/documents/processors/references_processor.py +408 -0
  21. ragbandit/documents/utils/__init__.py +11 -0
  22. ragbandit/documents/utils/secure_file_handler.py +95 -0
  23. ragbandit/prompt_tools/__init__.py +27 -0
  24. ragbandit/prompt_tools/footnotes_processor_tools.py +195 -0
  25. ragbandit/prompt_tools/prompt_tool.py +118 -0
  26. ragbandit/prompt_tools/references_processor_tools.py +31 -0
  27. ragbandit/prompt_tools/semantic_chunker_tools.py +56 -0
  28. ragbandit/schema.py +206 -0
  29. ragbandit/utils/__init__.py +19 -0
  30. ragbandit/utils/in_memory_log_handler.py +33 -0
  31. ragbandit/utils/llm_utils.py +188 -0
  32. ragbandit/utils/mistral_client.py +76 -0
  33. ragbandit/utils/token_usage_tracker.py +220 -0
  34. ragbandit_core-0.1.1.dist-info/METADATA +145 -0
  35. ragbandit_core-0.1.1.dist-info/RECORD +38 -0
  36. ragbandit_core-0.1.1.dist-info/WHEEL +5 -0
  37. ragbandit_core-0.1.1.dist-info/licenses/LICENSE.md +9 -0
  38. ragbandit_core-0.1.1.dist-info/top_level.txt +1 -0
ragbandit/documents/embedders/mistral_embedder.py
@@ -0,0 +1,129 @@
+from ragbandit.utils import mistral_client_manager
+from ragbandit.utils.token_usage_tracker import TokenUsageTracker
+from ragbandit.documents.embedders.base_embedder import BaseEmbedder
+from ragbandit.schema import (
+    ChunkingResult,
+    EmbeddingResult,
+    ChunkWithEmbedding,
+)
+from datetime import datetime, timezone
+
+
+class MistralEmbedder(BaseEmbedder):
+    """Document embedder that uses Mistral AI's embedding models."""
+
+    def __init__(
+        self,
+        api_key: str,
+        model: str = "mistral-embed",
+        name: str = None,
+    ):
+        """
+        Initialize the Mistral embedder.
+
+        Args:
+            api_key: Mistral API key
+            model: Embedding model to use
+            name: Optional name for the embedder
+        """
+        super().__init__(name)
+        self.model = model
+
+        # Initialize the Mistral client
+        self.client = mistral_client_manager.get_client(api_key)
+
+        self.logger.info(
+            f"Initialized MistralEmbedder with model {self.model}"
+        )
+
+    # ------------------------------------------------------------------
+    # Public API
+    def embed_chunks(
+        self,
+        chunk_result: ChunkingResult,
+        usage_tracker: TokenUsageTracker | None = None,
+    ) -> EmbeddingResult:
+        """Orchestrate the Mistral embedding flow."""
+
+        chunks = chunk_result.chunks
+        if not chunks:
+            self.logger.warning("No chunks to embed")
+            return self._empty_embedding_result()
+
+        try:
+            texts = self._extract_texts(chunks)
+            response = self._call_mistral_embeddings(texts)
+            return self._build_embedding_result(
+                chunks, response, usage_tracker
+            )
+        except Exception as e:
+            self.logger.error(f"Error generating embeddings: {e}")
+            return self._empty_embedding_result(chunks)
+
+    # ------------------------------------------------------------------
+    # Helpers
+    def _extract_texts(
+        self,
+        chunks: list[ChunkWithEmbedding | object],
+    ) -> list[str]:
+        """Extract raw text from chunks."""
+        return [c.text for c in chunks]
+
+    def _call_mistral_embeddings(self, texts: list[str]):
+        """Call the Mistral embeddings API and return the raw response."""
+        self.logger.info("Requesting embeddings from Mistral API")
+        return self.client.embeddings.create(model=self.model, inputs=texts)
+
+    def _build_embedding_result(
+        self,
+        chunks,
+        response,
+        usage_tracker: TokenUsageTracker | None = None,
+    ) -> EmbeddingResult:
+        """Convert API response into EmbeddingResult."""
+
+        embedded_chunks: list[ChunkWithEmbedding] = []
+        for i, data in enumerate(response.data):
+            embedded_chunks.append(
+                ChunkWithEmbedding(
+                    text=chunks[i].text,
+                    metadata=chunks[i].metadata,
+                    embedding=list(data.embedding),
+                    embedding_model=self.model,
+                )
+            )
+
+        if usage_tracker:
+            usage_tracker.add_embedding_tokens(
+                response.usage.prompt_tokens, self.model
+            )
+
+        return EmbeddingResult(
+            processed_at=datetime.now(timezone.utc),
+            chunks_with_embeddings=embedded_chunks,
+            model_name=self.model,
+            metrics=usage_tracker.get_summary() if usage_tracker else None,
+        )
+
+    def _empty_embedding_result(
+        self, chunks: list | None = None
+    ) -> EmbeddingResult:
+        """Return an EmbeddingResult with no embeddings (error fallback)."""
+
+        empty_embeds: list[ChunkWithEmbedding] = []
+        for c in (chunks or []):
+            empty_embeds.append(
+                ChunkWithEmbedding(
+                    text=c.text if hasattr(c, "text") else "",
+                    metadata=c.metadata if hasattr(c, "metadata") else None,
+                    embedding=[],
+                    embedding_model=self.model,
+                )
+            )
+
+        return EmbeddingResult(
+            processed_at=None,
+            chunks_with_embeddings=empty_embeds,
+            model_name=self.model,
+            metrics=None,
+        )
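
For orientation, a minimal usage sketch of MistralEmbedder (illustrative only, not part of the published wheel). The helper name embed_document and the tracker argument are hypothetical; the chunk_result is assumed to come from one of the chunkers listed above and to expose .chunks as embed_chunks expects.

    # Hypothetical helper; chunk_result is assumed to come from a chunker in
    # ragbandit.documents.chunkers and to expose .chunks as embed_chunks expects.
    from ragbandit.documents.embedders.mistral_embedder import MistralEmbedder

    def embed_document(chunk_result, api_key, tracker=None):
        embedder = MistralEmbedder(api_key=api_key, model="mistral-embed")
        result = embedder.embed_chunks(chunk_result, usage_tracker=tracker)
        # On failure, embed_chunks falls back to empty embeddings rather than raising.
        for chunk in result.chunks_with_embeddings:
            print(chunk.embedding_model, len(chunk.embedding))
        return result
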
ragbandit/documents/ocr/__init__.py
@@ -0,0 +1,13 @@
+"""
+OCR (Optical Character Recognition) implementations for document processing.
+
+This module provides OCR processors that convert document images to text.
+"""
+
+from ragbandit.documents.ocr.base_ocr import BaseOCR
+from ragbandit.documents.ocr.mistral_ocr import MistralOCRDocument
+
+__all__ = [
+    "BaseOCR",
+    "MistralOCRDocument"
+]
ragbandit/documents/ocr/base_ocr.py
@@ -0,0 +1,136 @@
+import logging
+import os
+from abc import ABC, abstractmethod
+from pathlib import Path
+from io import BytesIO, BufferedReader
+from ragbandit.documents.utils.secure_file_handler import SecureFileHandler
+from ragbandit.schema import OCRResult
+
+
+class BaseOCR(ABC):
+    """Base class for OCR document processing.
+
+    This class defines the OCR interface and shared PDF file-handling helpers;
+    concrete implementations (e.g. MistralOCRDocument) perform the actual OCR.
+    """
+
+    def __init__(self, logger: logging.Logger = None, **kwargs):
+        """Initialize the OCR processor.
+
+        Args:
+            logger: Optional logger for OCR events
+            **kwargs: Additional keyword arguments (e.g., encryption_key)
+        """
+        self.logger = logger or logging.getLogger(__name__)
+        self.kwargs = kwargs
+
+    def validate_pdf(self, pdf_filepath: str) -> str:
+        """Validate that a PDF file exists.
+
+        Args:
+            pdf_filepath: Path to the PDF file to validate
+
+        Returns:
+            str: The basename of the file
+
+        Raises:
+            ValueError: If the file does not exist
+        """
+        file_name = os.path.basename(pdf_filepath)
+        pdf_file_exists = os.path.isfile(pdf_filepath)
+
+        if not pdf_file_exists:
+            self.logger.error(f"PDF file {pdf_filepath} not found")
+            raise ValueError(f"PDF file {pdf_filepath} not found")
+
+        return file_name
+
+    def read_encrypted_file(self, pdf_filepath: str) -> BufferedReader:
+        """Read an encrypted PDF file and return a buffered reader.
+
+        Args:
+            pdf_filepath: Path to the encrypted PDF file
+
+        Returns:
+            BufferedReader: A buffered reader for the decrypted file content
+
+        Raises:
+            ValueError: If encryption_key is not provided in kwargs
+        """
+        self.logger.info("Decrypting for OCR...")
+
+        encryption_key = self.kwargs.get("encryption_key")
+        if not encryption_key:
+            raise ValueError(
+                "encryption_key must be provided in kwargs "
+                "for encrypted file operations. "
+                "Pass encryption_key when initializing the OCR processor."
+            )
+
+        secure_handler = SecureFileHandler(encryption_key=encryption_key)
+        decrypted = secure_handler.read_encrypted_file(Path(pdf_filepath))
+        raw = BytesIO(decrypted)
+        raw.seek(0)
+        return BufferedReader(raw)
+
+    def read_unencrypted_file(self, pdf_filepath: str) -> BufferedReader:
+        """Read an unencrypted PDF file and return a buffered reader.
+
+        Args:
+            pdf_filepath: Path to the unencrypted PDF file
+
+        Returns:
+            BufferedReader: A buffered reader for the file content
+        """
+        self.logger.info("Reading file for OCR...")
+        with open(pdf_filepath, 'rb') as f:
+            content = f.read()
+
+        raw = BytesIO(content)
+        raw.seek(0)
+        return BufferedReader(raw)
+
+    def validate_and_prepare_file(
+        self, pdf_filepath: str, encrypted: bool = True
+    ) -> tuple[str, BufferedReader]:
+        """Validate and prepare a PDF file for OCR processing.
+
+        Args:
+            pdf_filepath: Path to the PDF file to process
+            encrypted: Whether the file is encrypted (default: True)
+
+        Returns:
+            tuple: (file_name, file_reader)
+
+        Raises:
+            ValueError: If the file does not exist
+        """
+        file_name = self.validate_pdf(pdf_filepath)
+
+        if encrypted:
+            reader = self.read_encrypted_file(pdf_filepath)
+        else:
+            reader = self.read_unencrypted_file(pdf_filepath)
+
+        return file_name, reader
+
+    @abstractmethod
+    def process(self, pdf_filepath: str) -> OCRResult:
+        """Process a PDF file through OCR.
+
+        Args:
+            pdf_filepath: Path to the PDF file to process
+
+        Returns:
+            OCRResult: The OCR result from the processor
+        """
+        raise NotImplementedError("Subclasses must implement process method")
+
+    # ----------------------------------------------------------------------
+    def __str__(self) -> str:
+        """Return a string representation of the OCR processor."""
+        return self.__class__.__name__
+
+    def __repr__(self) -> str:
+        """Return a string representation of the OCR processor."""
+        return f"{self.__class__.__name__}()"
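
A sketch of how a subclass is meant to reuse these helpers (hypothetical; the concrete implementation shipped in the package is MistralOCRDocument in the next file). StubOCR is a made-up name for illustration.

    # Hypothetical subclass outline, illustration only.
    from ragbandit.documents.ocr.base_ocr import BaseOCR
    from ragbandit.schema import OCRResult

    class StubOCR(BaseOCR):
        def process(self, pdf_filepath: str, encrypted: bool = False) -> OCRResult:
            # Existence check plus an (optionally decrypted) BufferedReader in one call;
            # decryption requires StubOCR(encryption_key=...) at construction time.
            file_name, reader = self.validate_and_prepare_file(pdf_filepath, encrypted)
            raise NotImplementedError(f"would OCR {file_name} and map it to an OCRResult")
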
ragbandit/documents/ocr/mistral_ocr.py
@@ -0,0 +1,147 @@
+from ragbandit.documents.ocr.base_ocr import BaseOCR
+from ragbandit.schema import (
+    OCRResult, OCRPage, PageDimensions,
+    Image, OCRUsageInfo, PagesProcessedMetrics
+)
+from ragbandit.config.pricing import OCR_MODEL_COSTS
+from mistralai.models.ocrresponse import OCRResponse
+from io import BufferedReader
+from datetime import datetime
+
+import logging
+from ragbandit.utils import mistral_client_manager
+
+
+class MistralOCRDocument(BaseOCR):
+    """OCR document processor using Mistral's API."""
+
+    # Explicit model identifier used for all Mistral OCR requests
+    MODEL_NAME = "mistral-ocr-latest"
+
+    def __init__(self, api_key: str, logger: logging.Logger = None, **kwargs):
+        """
+        Initialize the Mistral OCR processor.
+
+        Args:
+            api_key: Mistral API key
+            logger: Optional logger for OCR events
+            **kwargs: Additional keyword arguments
+                - encryption_key: Optional key for encrypted file operations
+        """
+        # Pass all kwargs to the base class
+        super().__init__(logger=logger, **kwargs)
+        self.client = mistral_client_manager.get_client(api_key)
+
+    # ----------------- Helper methods ----------------- #
+
+    def _upload_file(self, file_name: str, content: BufferedReader):
+        """Upload PDF to Mistral Cloud and return the uploaded file object."""
+        return self.client.files.upload(
+            file={"file_name": file_name, "content": content},
+            purpose="ocr",
+        )
+
+    def _get_signed_url(self, file_id: str) -> str:
+        """Retrieve a signed URL for the previously uploaded file."""
+        return self.client.files.get_signed_url(file_id=file_id).url
+
+    def _run_ocr(self, document_url: str) -> OCRResponse:
+        """Run Mistral OCR on the provided document URL."""
+        return self.client.ocr.process(
+            model=self.MODEL_NAME,
+            document={"type": "document_url", "document_url": document_url},
+            include_image_base64=True,
+        )
+
+    def _delete_file_with_retries(
+        self, file_id: str, max_tries: int = 10
+    ) -> None:
+        """Attempt to delete a file from Mistral Cloud with retries."""
+        while max_tries > 0:
+            resp = self.client.files.delete(file_id=file_id)
+            if resp.deleted:
+                self.logger.info("File deletion successful!")
+                return
+            max_tries -= 1
+        self.logger.error(f"Deleting unsuccessful. ID: {file_id}")
+
+    def _convert_pages(self, ocr_response: OCRResponse) -> list[OCRPage]:
+        """Convert OCRResponse pages to internal OCRPage schema."""
+        pages = []
+        for i, page in enumerate(ocr_response.pages):
+            images = [
+                Image.model_validate(img, from_attributes=True)
+                for img in (page.images or [])
+            ]
+            ocr_page = OCRPage(
+                index=i,
+                markdown=page.markdown,
+                images=images,
+                dimensions=PageDimensions.model_validate(
+                    page.dimensions, from_attributes=True
+                ),
+            )
+            pages.append(ocr_page)
+        return pages
+
+    def _build_usage_info(self, ocr_response: OCRResponse) -> OCRUsageInfo:
+        """Extract usage information from the OCR response."""
+        return OCRUsageInfo.model_validate(
+            ocr_response.usage_info, from_attributes=True
+        )
+
+    def _build_metrics(self, pages: list[OCRPage]) -> PagesProcessedMetrics:
+        """Create page-processing cost metrics."""
+        cost_per_page = OCR_MODEL_COSTS.get(self.MODEL_NAME, 0.0)
+        return PagesProcessedMetrics(
+            pages_processed=len(pages),
+            cost_per_page=cost_per_page,
+            total_cost_usd=len(pages) * cost_per_page,
+        )
+
+    def _build_result(
+        self,
+        pdf_filepath: str,
+        ocr_response: OCRResponse,
+        pages: list[OCRPage],
+        usage_info: OCRUsageInfo,
+        metrics: list[PagesProcessedMetrics],
+    ) -> OCRResult:
+        """Assemble the OCRResult object."""
+        return OCRResult(
+            source_file_path=pdf_filepath,
+            processed_at=datetime.now(),
+            model=ocr_response.model,
+            document_annotation=ocr_response.document_annotation,
+            pages=pages,
+            usage_info=usage_info,
+            metrics=metrics,
+        )
+
+    def process(self, pdf_filepath: str, encrypted: bool = False) -> OCRResult:
+        """High-level orchestration for running Mistral OCR on a PDF."""
+
+        file_name, reader = self.validate_and_prepare_file(
+            pdf_filepath, encrypted
+        )
+
+        uploaded = self._upload_file(file_name, reader)
+        del reader  # free memory
+
+        try:
+            doc_url = self._get_signed_url(uploaded.id)
+            ocr_resp = self._run_ocr(doc_url)
+        finally:
+            self._delete_file_with_retries(uploaded.id)
+
+        pages = self._convert_pages(ocr_resp)
+        usage_info = self._build_usage_info(ocr_resp)
+        metrics = [self._build_metrics(pages)]
+
+        return self._build_result(
+            pdf_filepath=pdf_filepath,
+            ocr_response=ocr_resp,
+            pages=pages,
+            usage_info=usage_info,
+            metrics=metrics,
+        )
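
A usage sketch for the processor above (illustrative, not from the package). The file path and API key are placeholders, and the reported cost depends on the mistral-ocr-latest entry in OCR_MODEL_COSTS.

    # Hypothetical usage; "document.pdf" and the API key are placeholders.
    from ragbandit.documents.ocr.mistral_ocr import MistralOCRDocument

    ocr = MistralOCRDocument(api_key="...")  # add encryption_key=... for encrypted inputs
    result = ocr.process("document.pdf", encrypted=False)

    for page in result.pages:
        print(page.index, len(page.markdown), "characters of markdown")
    print(result.metrics[0].pages_processed, "pages,",
          result.metrics[0].total_cost_usd, "USD")
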
ragbandit/documents/processors/__init__.py
@@ -0,0 +1,16 @@
+"""
+Document processors for enhancing and transforming document content.
+
+This module provides various processors that can be applied to documents
+to extract, transform, or enhance their content.
+"""
+
+from ragbandit.documents.processors.base_processor import BaseProcessor
+from ragbandit.documents.processors.footnotes_processor import FootnoteProcessor  # noqa
+from ragbandit.documents.processors.references_processor import ReferencesProcessor  # noqa
+
+__all__ = [
+    "BaseProcessor",
+    "FootnoteProcessor",
+    "ReferencesProcessor"
+]
ragbandit/documents/processors/base_processor.py
@@ -0,0 +1,88 @@
+import logging
+from abc import ABC, abstractmethod
+from datetime import datetime, timezone
+from ragbandit.schema import OCRResult, ProcessingResult, ProcessedPage
+from ragbandit.utils.token_usage_tracker import TokenUsageTracker
+
+
+class BaseProcessor(ABC):
+    """
+    Base class or mix-in for individual processors.
+    Subclasses override `process()` and, optionally, `extend_response()`.
+    """
+
+    def __init__(self, name: str | None = None, api_key: str | None = None):
+        """
+        Initialize the processor.
+
+        Args:
+            name: Optional name for the processor
+            api_key: API key for LLM services
+        """
+        # Hierarchical names make it easy to filter later:
+        #   pipeline.text_cleaner, pipeline.language_model, …
+        base = "pipeline"
+        self.logger = logging.getLogger(
+            f"{base}.{name or self.__class__.__name__}"
+        )
+        self.api_key = api_key
+
+    @abstractmethod
+    def process(
+        self,
+        document: OCRResult | ProcessingResult,
+        usage_tracker: TokenUsageTracker | None = None,
+    ) -> ProcessingResult:
+        """
+        Do one step of work and return:
+          * a (possibly modified) ProcessingResult
+          * a dict of metadata specific to this processor
+
+        Args:
+            document: The OCR or intermediate processing result to process.
+                This can be either an `OCRResult` (raw OCR output) or
+                a `ProcessingResult` (output of a previous processor).
+            usage_tracker: Optional token usage tracker
+        """
+        raise NotImplementedError
+
+    # ----------------------------------------------------------------------
+    def __str__(self) -> str:
+        """Return a string representation of the processor."""
+        return self.__class__.__name__
+
+    def __repr__(self) -> str:
+        """Return a string representation of the processor."""
+        return f"{self.__class__.__name__}()"
+
+    # ----------------------------------------------------------------------
+    # Utility helpers
+    @staticmethod
+    def ensure_processing_result(
+        document: OCRResult | ProcessingResult,
+        processor_name: str = "bootstrap",
+    ) -> ProcessingResult:
+        """Ensure the incoming `document` is a `ProcessingResult`.
+
+        If an `OCRResult` is supplied (as is the case for the first processor
+        in a pipeline), it will be converted to a shallow `ProcessingResult` so
+        that downstream logic can assume a consistent type.
+        """
+
+        # Always create a fresh ProcessingResult so that timestamps, metrics,
+        # and extracted data do not roll over between processors.
+
+        source_pages = document.pages if hasattr(document, "pages") else []
+
+        pages_processed = [
+            ProcessedPage(**page.model_dump()) for page in source_pages
+        ]
+
+        return ProcessingResult(
+            processor_name=processor_name,
+            processed_at=datetime.now(timezone.utc),
+            pages=pages_processed,
+            processing_trace=[],
+            extracted_data={},
+            metrics=None,
+        )
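
A minimal subclass sketch showing the intended pattern: normalize the input with ensure_processing_result, transform the pages, and return the ProcessingResult. It is hypothetical and assumes ProcessedPage keeps the markdown field that OCRPage carries (the schema file is not shown in this diff); BlankLineStripper is a made-up example.

    # Hypothetical processor; assumes ProcessedPage exposes a mutable `markdown` field.
    from ragbandit.documents.processors.base_processor import BaseProcessor
    from ragbandit.schema import OCRResult, ProcessingResult
    from ragbandit.utils.token_usage_tracker import TokenUsageTracker

    class BlankLineStripper(BaseProcessor):
        """Toy example: drops blank lines from each page's markdown."""

        def process(
            self,
            document: OCRResult | ProcessingResult,
            usage_tracker: TokenUsageTracker | None = None,
        ) -> ProcessingResult:
            result = self.ensure_processing_result(document, processor_name=str(self))
            for page in result.pages:
                page.markdown = "\n".join(
                    line for line in page.markdown.splitlines() if line.strip()
                )
            return result
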