mb-rag 1.1.57.post1__py3-none-any.whl → 1.1.58__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mb-rag has been flagged as potentially problematic; see the registry's advisory page for details.
- mb_rag/basic.py +375 -306
- mb_rag/chatbot/chains.py +206 -206
- mb_rag/chatbot/conversation.py +185 -185
- mb_rag/chatbot/prompts.py +58 -58
- mb_rag/rag/embeddings.py +810 -810
- mb_rag/utils/all_data_extract.py +64 -64
- mb_rag/utils/bounding_box.py +231 -231
- mb_rag/utils/document_extract.py +354 -354
- mb_rag/utils/extra.py +73 -73
- mb_rag/utils/pdf_extract.py +428 -428
- mb_rag/version.py +1 -1
- {mb_rag-1.1.57.post1.dist-info → mb_rag-1.1.58.dist-info}/METADATA +11 -11
- mb_rag-1.1.58.dist-info/RECORD +19 -0
- mb_rag-1.1.57.post1.dist-info/RECORD +0 -19
- {mb_rag-1.1.57.post1.dist-info → mb_rag-1.1.58.dist-info}/WHEEL +0 -0
- {mb_rag-1.1.57.post1.dist-info → mb_rag-1.1.58.dist-info}/top_level.txt +0 -0
mb_rag/utils/all_data_extract.py
CHANGED
|
@@ -1,65 +1,65 @@
|
|
|
1
|
-
## Docling data extract
|
|
2
|
-
|
|
3
|
-
from typing import List
|
|
4
|
-
from mb_rag.utils.extra import check_package
|
|
5
|
-
|
|
6
|
-
__all__ = ['DocumentExtractor']
|
|
7
|
-
|
|
8
|
-
class DocumentExtractor:
    """
    DocumentExtractor class for extracting data from documents using Docling.

    A source document is converted with Docling's ``DocumentConverter`` and the
    result is exported to Markdown, plain text or HTML and written to disk.
    """

    # Supported ``data_store_type`` values mapped to the file extension used
    # when building the default output path.
    _EXTENSIONS = {"markdown": "md", "txt": "txt", "html": "html"}

    def __init__(self):
        """
        Initialize the DocumentExtractor class.

        Checks that the Docling package is installed and loads its converter.

        Raises:
            ImportError: If the ``docling`` package is not installed.
        """
        if not check_package("docling"):
            raise ImportError("Docling package not found. Please install it using: pip install docling")
        # ``docling`` has no top-level ``Docling`` name; the converter class lives
        # in ``docling.document_converter``. It is stored under the historical
        # ``self.Docling`` attribute so code that reads that attribute keeps working.
        from docling.document_converter import DocumentConverter
        self.Docling = DocumentConverter

    def _extract_data(self, file_path: str, **kwargs):
        """
        Extract data from a document using Docling.

        Args:
            file_path (str): Path of the document to convert.
            **kwargs: Forwarded to the converter's constructor.

        Returns:
            The conversion result; its ``.document`` attribute provides the
            ``export_to_markdown`` / ``export_to_text`` / ``export_to_html``
            methods used by ``get_data``.

        Raises:
            Exception: Wrapping (and chained to) any error raised while converting.
        """
        try:
            converter = self.Docling(**kwargs)
            return converter.convert(file_path)
        except Exception as e:
            raise Exception(f"Error extracting data from {file_path}: {str(e)}") from e

    def get_data(self, file_path: str, save_path: str = None, data_store_type: str = "markdown", **kwargs):
        """
        Get data from a document using Docling and save the export to disk.

        Args:
            file_path (str): Path to the document.
            save_path (str): Path to save the extracted data. Default is None, in
                which case the file is written to the current directory as
                ``docling_{file_stem}.{ext}`` where ``ext`` is ``md``, ``txt`` or
                ``html`` according to ``data_store_type``.
            data_store_type (str): Export format: "markdown", "txt" or "html".
                Default is "markdown". Any other value falls back to "txt".
            **kwargs: Additional arguments for the Docling converter.

        Returns:
            The Docling conversion result returned by ``_extract_data``.
        """
        import os

        data = self._extract_data(file_path, **kwargs)
        if data_store_type not in self._EXTENSIONS:
            print("Invalid data store type. Defaulting to text (txt)")
            # Normalise the type itself (not just the extension) so the export
            # step below selects the text exporter instead of leaving
            # ``data_with_type`` unassigned.
            data_store_type = "txt"
        data_type = self._EXTENSIONS[data_store_type]

        if save_path is None:
            # splitext keeps inner dots ("a.v2.pdf" -> "a.v2") so inputs that
            # differ only after their first dot do not overwrite each other.
            file_stem = os.path.splitext(os.path.basename(file_path))[0]
            save_path = f"docling_{file_stem}.{data_type}"
        print(f"Saving extracted data to {save_path}")

        document = data.document
        if data_store_type == "markdown":
            data_with_type = document.export_to_markdown()
        elif data_store_type == "txt":
            data_with_type = document.export_to_text()
        else:
            data_with_type = document.export_to_html()

        # Explicit UTF-8 so non-ASCII extracted text does not depend on the locale.
        with open(save_path, 'w', encoding="utf-8") as f:
            f.write(data_with_type)
        return data
|
|
1
|
+
## Docling data extract
|
|
2
|
+
|
|
3
|
+
from typing import List
|
|
4
|
+
from mb_rag.utils.extra import check_package
|
|
5
|
+
|
|
6
|
+
__all__ = ['DocumentExtractor']
|
|
7
|
+
|
|
8
|
+
class DocumentExtractor:
    """
    DocumentExtractor class for extracting data from documents using Docling.

    A source document is converted with Docling's ``DocumentConverter`` and the
    result is exported to Markdown, plain text or HTML and written to disk.
    """

    # Supported ``data_store_type`` values mapped to the file extension used
    # when building the default output path.
    _EXTENSIONS = {"markdown": "md", "txt": "txt", "html": "html"}

    def __init__(self):
        """
        Initialize the DocumentExtractor class.

        Checks that the Docling package is installed and loads its converter.

        Raises:
            ImportError: If the ``docling`` package is not installed.
        """
        if not check_package("docling"):
            raise ImportError("Docling package not found. Please install it using: pip install docling")
        # ``docling`` has no top-level ``Docling`` name; the converter class lives
        # in ``docling.document_converter``. It is stored under the historical
        # ``self.Docling`` attribute so code that reads that attribute keeps working.
        from docling.document_converter import DocumentConverter
        self.Docling = DocumentConverter

    def _extract_data(self, file_path: str, **kwargs):
        """
        Extract data from a document using Docling.

        Args:
            file_path (str): Path of the document to convert.
            **kwargs: Forwarded to the converter's constructor.

        Returns:
            The conversion result; its ``.document`` attribute provides the
            ``export_to_markdown`` / ``export_to_text`` / ``export_to_html``
            methods used by ``get_data``.

        Raises:
            Exception: Wrapping (and chained to) any error raised while converting.
        """
        try:
            converter = self.Docling(**kwargs)
            return converter.convert(file_path)
        except Exception as e:
            raise Exception(f"Error extracting data from {file_path}: {str(e)}") from e

    def get_data(self, file_path: str, save_path: str = None, data_store_type: str = "markdown", **kwargs):
        """
        Get data from a document using Docling and save the export to disk.

        Args:
            file_path (str): Path to the document.
            save_path (str): Path to save the extracted data. Default is None, in
                which case the file is written to the current directory as
                ``docling_{file_stem}.{ext}`` where ``ext`` is ``md``, ``txt`` or
                ``html`` according to ``data_store_type``.
            data_store_type (str): Export format: "markdown", "txt" or "html".
                Default is "markdown". Any other value falls back to "txt".
            **kwargs: Additional arguments for the Docling converter.

        Returns:
            The Docling conversion result returned by ``_extract_data``.
        """
        import os

        data = self._extract_data(file_path, **kwargs)
        if data_store_type not in self._EXTENSIONS:
            print("Invalid data store type. Defaulting to text (txt)")
            # Normalise the type itself (not just the extension) so the export
            # step below selects the text exporter instead of leaving
            # ``data_with_type`` unassigned.
            data_store_type = "txt"
        data_type = self._EXTENSIONS[data_store_type]

        if save_path is None:
            # splitext keeps inner dots ("a.v2.pdf" -> "a.v2") so inputs that
            # differ only after their first dot do not overwrite each other.
            file_stem = os.path.splitext(os.path.basename(file_path))[0]
            save_path = f"docling_{file_stem}.{data_type}"
        print(f"Saving extracted data to {save_path}")

        document = data.document
        if data_store_type == "markdown":
            data_with_type = document.export_to_markdown()
        elif data_store_type == "txt":
            data_with_type = document.export_to_text()
        else:
            data_with_type = document.export_to_html()

        # Explicit UTF-8 so non-ASCII extracted text does not depend on the locale.
        with open(save_path, 'w', encoding="utf-8") as f:
            f.write(data_with_type)
        return data
|
mb_rag/utils/bounding_box.py
CHANGED
|
@@ -1,231 +1,231 @@
|
|
|
1
|
-
"""
|
|
2
|
-
Bounding box utilities
|
|
3
|
-
"""
|
|
4
|
-
|
|
5
|
-
import os
|
|
6
|
-
from typing import List, Dict, Any, Optional, Tuple, Union
|
|
7
|
-
from dataclasses import dataclass
|
|
8
|
-
from mb_rag.utils.extra import check_package
|
|
9
|
-
|
|
10
|
-
__all__ = ['BoundingBoxConfig', 'BoundingBoxProcessor']
|
|
11
|
-
|
|
12
|
-
def check_image_dependencies() -> None:
    """
    Verify that the packages needed for bounding-box work are available.

    Packages are checked in a fixed order (Pillow, OpenCV, Google Generative
    AI); the first one reported missing by ``check_package`` is raised.

    Raises:
        ImportError: Naming the first missing package and how to install it.
    """
    requirements = (
        ("PIL", "Pillow package not found. Please install it using: pip install Pillow"),
        ("cv2", "OpenCV package not found. Please install it using: pip install opencv-python"),
        ("google.generativeai", "Google Generative AI package not found. Please install it using: pip install google-generativeai"),
    )
    for module_name, install_hint in requirements:
        if not check_package(module_name):
            raise ImportError(install_hint)
|
|
24
|
-
|
|
25
|
-
@dataclass()
class BoundingBoxConfig:
    """
    Settings shared by bounding-box operations.

    Fields, in constructor order:
        model_name: Name of the generative model to load.
        api_key: API key; when None or empty, the processor falls back to the
            GOOGLE_API_KEY environment variable.
        default_prompt: Prompt used when no explicit prompt is supplied.
    """
    model_name: str = "gemini-1.5-pro-latest"
    api_key: Optional[str] = None
    default_prompt: str = (
        'Return bounding boxes of container, for each only one return '
        '[ymin, xmin, ymax, xmax]'
    )
|
|
31
|
-
|
|
32
|
-
class BoundingBoxProcessor:
    """
    Class for processing images and generating bounding boxes

    Attributes:
        model: The Google Generative AI model instance
        config: Configuration for bounding box operations
    """

    def __init__(self, config: Optional["BoundingBoxConfig"] = None, **kwargs):
        """
        Initialize bounding box processor

        Args:
            config: Configuration for the processor. When omitted, one is built
                from ``**kwargs``.
            **kwargs: ``BoundingBoxConfig`` fields. When ``config`` is also given,
                they override the matching fields on a copy of it (the caller's
                config object is not modified).

        Raises:
            ImportError: If a required image/AI package is missing.
            TypeError: If ``kwargs`` contains a name that is not a config field.
            ValueError: If the model cannot be initialized (e.g. no API key).
        """
        from dataclasses import replace

        check_image_dependencies()
        if config is None:
            config = BoundingBoxConfig(**kwargs)
        elif kwargs:
            # kwargs used to be silently dropped whenever a config was passed,
            # so e.g. ``from_model(name, default_prompt=...)`` had no effect.
            config = replace(config, **kwargs)
        self.config = config
        self._initialize_model()
        self._initialize_image_libs()

    @classmethod
    def from_model(cls, model_name: str, api_key: Optional[str] = None, **kwargs) -> 'BoundingBoxProcessor':
        """
        Create processor with specific model configuration

        Args:
            model_name: Name of the model
            api_key: Optional API key
            **kwargs: Additional ``BoundingBoxConfig`` fields (e.g. ``default_prompt``)

        Returns:
            BoundingBoxProcessor: Configured processor
        """
        config = BoundingBoxConfig(
            model_name=model_name,
            api_key=api_key
        )
        # __init__ applies kwargs on top of this config.
        return cls(config, **kwargs)

    def _initialize_model(self) -> None:
        """
        Initialize the AI model

        Uses ``config.api_key`` or, when that is empty, the ``GOOGLE_API_KEY``
        environment variable.

        Raises:
            ValueError: If no API key is available or the model cannot be created.
        """
        import google.generativeai as genai

        api_key = self.config.api_key or os.environ.get("GOOGLE_API_KEY")
        if not api_key:
            raise ValueError("Google API key not found. Please provide api_key parameter or set GOOGLE_API_KEY environment variable.")

        try:
            genai.configure(api_key=api_key)
            self.model = genai.GenerativeModel(model_name=self.config.model_name)
        except Exception as e:
            raise ValueError(f"Error initializing Google Generative AI model: {str(e)}") from e

    def _initialize_image_libs(self) -> None:
        """Initialize image processing libraries (Pillow and OpenCV)"""
        from PIL import Image
        import cv2
        self._Image = Image
        self._cv2 = cv2

    @staticmethod
    def _validate_image_path(image_path: str) -> None:
        """
        Validate image path

        Args:
            image_path: Path to image

        Raises:
            FileNotFoundError: If the path does not point to an existing file
                (a directory is rejected as well).
        """
        if not os.path.isfile(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")

    def generate_bounding_boxes(self,
                                image_path: str,
                                prompt: Optional[str] = None
                                ) -> Any:
        """
        Generate bounding boxes for an image

        Args:
            image_path: Path to image
            prompt: Custom prompt for the model; ``config.default_prompt`` is
                used when it is empty or None.

        Returns:
            Any: Raw model response with bounding boxes

        Raises:
            FileNotFoundError: If the image does not exist.
            ValueError: If opening the image or calling the model fails.
        """
        self._validate_image_path(image_path)
        prompt = prompt or self.config.default_prompt

        try:
            # The context manager releases the file handle PIL opens; the model
            # call happens inside it while the image is still readable.
            with self._Image.open(image_path) as image:
                return self.model.generate_content([image, prompt])
        except Exception as e:
            raise ValueError(f"Error generating bounding boxes: {str(e)}") from e

    @staticmethod
    def _parse_bounding_boxes(response: Any) -> Dict[str, List[Any]]:
        """
        Convert a model response into a ``{label: [ymin, xmin, ymax, xmax]}`` dict.

        Accepts a dict (returned unchanged), a string, or an object with a string
        ``.text`` attribute. The text may be wrapped in a Markdown code fence and
        must hold JSON that is a single box, a list of boxes, a list of
        ``{"box_2d": [...], "label": ...}`` objects, or a label -> box mapping.
        Unlabelled boxes are named ``box_0``, ``box_1``, ...

        Raises:
            ValueError: If the response cannot be parsed into boxes.
        """
        import json
        import re

        if isinstance(response, dict):
            return response
        text = response if isinstance(response, str) else getattr(response, "text", None)
        if not isinstance(text, str):
            raise ValueError("Model response contains no text to parse bounding boxes from")

        # Models often wrap JSON in ```json ... ``` fences; keep only the payload.
        fenced = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL)
        payload = (fenced.group(1) if fenced else text).strip()
        try:
            parsed = json.loads(payload)
        except json.JSONDecodeError as e:
            raise ValueError(f"Could not parse bounding boxes from model response: {text!r}") from e

        if isinstance(parsed, dict):
            return {str(label): box for label, box in parsed.items()}
        if isinstance(parsed, list):
            if len(parsed) == 4 and all(isinstance(v, (int, float)) for v in parsed):
                parsed = [parsed]  # a single bare [ymin, xmin, ymax, xmax]
            boxes: Dict[str, List[Any]] = {}
            for index, item in enumerate(parsed):
                if isinstance(item, dict) and "box_2d" in item:
                    label = str(item.get("label", f"box_{index}"))
                    if label in boxes:
                        label = f"{label}_{index}"  # keep repeated labels distinct
                    boxes[label] = item["box_2d"]
                else:
                    boxes[f"box_{index}"] = item
            return boxes
        raise ValueError(f"Could not parse bounding boxes from model response: {text!r}")

    def add_bounding_boxes(self,
                           image_path: str,
                           bounding_boxes: Dict[str, List[int]],
                           color: Tuple[int, int, int] = (0, 0, 255),
                           thickness: int = 4,
                           font_scale: float = 1.0,
                           show: bool = False,
                           google_bb= False
                           ) -> Any:
        """
        Add bounding boxes to an image

        Args:
            image_path: Path to image
            bounding_boxes: Mapping of label -> [ymin, xmin, ymax, xmax]
                (list or tuple of four numbers)
            color: BGR color tuple
            thickness: Line thickness (labels are drawn at half this thickness)
            font_scale: Font scale for labels
            show: Whether to display the image
            google_bb: If True, coordinates are normalized to 0-1000 and are
                scaled to pixels using the image height (y) and width (x).

        Returns:
            Any: Image with bounding boxes

        Raises:
            FileNotFoundError: If the image does not exist.
            ValueError: If the boxes are malformed or drawing fails.
        """
        self._validate_image_path(image_path)

        if not isinstance(bounding_boxes, dict):
            raise ValueError("bounding_boxes must be a dictionary")

        try:
            img = self._cv2.imread(image_path)
            if img is None:
                raise ValueError(f"Failed to load image: {image_path}")
            height, width = img.shape[0], img.shape[1]

            for key, value in bounding_boxes.items():
                if not isinstance(value, (list, tuple)) or len(value) != 4:
                    raise ValueError(f"Invalid bounding box format for key {key}. Expected [ymin, xmin, ymax, xmax]")

                if google_bb:
                    value = [value[0] * height / 1000, value[1] * width / 1000,
                             value[2] * height / 1000, value[3] * width / 1000]
                # OpenCV drawing functions require integer pixel coordinates.
                ymin, xmin, ymax, xmax = (int(v) for v in value)
                if google_bb:
                    print("Converted bounding box from Google (0-1000) coordinates: ", [ymin, xmin, ymax, xmax])

                self._cv2.rectangle(
                    img=img,
                    pt1=(xmin, ymin),
                    pt2=(xmax, ymax),
                    color=color,
                    thickness=thickness
                )
                self._cv2.putText(
                    img=img,
                    text=str(key),
                    org=(xmin, ymin),
                    fontFace=self._cv2.FONT_HERSHEY_SIMPLEX,
                    fontScale=font_scale,
                    color=color,
                    thickness=thickness // 2
                )

            if show:
                self._display_image(img)

            return img
        except Exception as e:
            raise ValueError(f"Error adding bounding boxes to image: {str(e)}") from e

    def _display_image(self, img: Any) -> None:
        """
        Display an image in an OpenCV window and wait for a key press

        Args:
            img: Image to display
        """
        self._cv2.imshow("Image", img)
        self._cv2.waitKey(0)
        self._cv2.destroyAllWindows()

    def save_image(self, img: Any, output_path: str) -> None:
        """
        Save an image

        Args:
            img: Image to save
            output_path: Path to save the image

        Raises:
            ValueError: If writing fails, including when ``cv2.imwrite``
                reports failure by returning False rather than raising.
        """
        try:
            written = self._cv2.imwrite(output_path, img)
        except Exception as e:
            raise ValueError(f"Error saving image: {str(e)}") from e
        if not written:
            raise ValueError(f"Error saving image: could not write {output_path}")

    def process_image(self,
                      image_path: str,
                      output_path: Optional[str] = None,
                      show: bool = False,
                      *,
                      google_bb: bool = True,
                      **kwargs) -> Any:
        """
        Complete image processing pipeline: query the model, parse its boxes,
        draw them and optionally save the result.

        Args:
            image_path: Path to input image
            output_path: Optional path to save output
            show: Whether to display the result
            google_bb: Whether the model's coordinates are normalized to 0-1000
                (Gemini's convention) and must be scaled to pixels.
            **kwargs: Additional arguments for bounding box generation (e.g. ``prompt``)

        Returns:
            Any: Processed image

        Raises:
            ValueError: If generation, parsing, drawing or saving fails.
        """
        response = self.generate_bounding_boxes(image_path, **kwargs)
        # The model returns a response object, not the dict add_bounding_boxes needs.
        boxes = self._parse_bounding_boxes(response)
        img = self.add_bounding_boxes(image_path, boxes, show=show, google_bb=google_bb)

        if output_path:
            self.save_image(img, output_path)

        return img
|
|
1
|
+
"""
|
|
2
|
+
Bounding box utilities
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from typing import List, Dict, Any, Optional, Tuple, Union
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from mb_rag.utils.extra import check_package
|
|
9
|
+
|
|
10
|
+
__all__ = ['BoundingBoxConfig', 'BoundingBoxProcessor']
|
|
11
|
+
|
|
12
|
+
def check_image_dependencies() -> None:
    """
    Verify that the packages needed for bounding-box work are available.

    Packages are checked in a fixed order (Pillow, OpenCV, Google Generative
    AI); the first one reported missing by ``check_package`` is raised.

    Raises:
        ImportError: Naming the first missing package and how to install it.
    """
    requirements = (
        ("PIL", "Pillow package not found. Please install it using: pip install Pillow"),
        ("cv2", "OpenCV package not found. Please install it using: pip install opencv-python"),
        ("google.generativeai", "Google Generative AI package not found. Please install it using: pip install google-generativeai"),
    )
    for module_name, install_hint in requirements:
        if not check_package(module_name):
            raise ImportError(install_hint)
|
|
24
|
+
|
|
25
|
+
@dataclass()
class BoundingBoxConfig:
    """
    Settings shared by bounding-box operations.

    Fields, in constructor order:
        model_name: Name of the generative model to load.
        api_key: API key; when None or empty, the processor falls back to the
            GOOGLE_API_KEY environment variable.
        default_prompt: Prompt used when no explicit prompt is supplied.
    """
    model_name: str = "gemini-1.5-pro-latest"
    api_key: Optional[str] = None
    default_prompt: str = (
        'Return bounding boxes of container, for each only one return '
        '[ymin, xmin, ymax, xmax]'
    )
|
|
31
|
+
|
|
32
|
+
class BoundingBoxProcessor:
    """
    Class for processing images and generating bounding boxes

    Attributes:
        model: The Google Generative AI model instance
        config: Configuration for bounding box operations
    """

    def __init__(self, config: Optional["BoundingBoxConfig"] = None, **kwargs):
        """
        Initialize bounding box processor

        Args:
            config: Configuration for the processor. When omitted, one is built
                from ``**kwargs``.
            **kwargs: ``BoundingBoxConfig`` fields. When ``config`` is also given,
                they override the matching fields on a copy of it (the caller's
                config object is not modified).

        Raises:
            ImportError: If a required image/AI package is missing.
            TypeError: If ``kwargs`` contains a name that is not a config field.
            ValueError: If the model cannot be initialized (e.g. no API key).
        """
        from dataclasses import replace

        check_image_dependencies()
        if config is None:
            config = BoundingBoxConfig(**kwargs)
        elif kwargs:
            # kwargs used to be silently dropped whenever a config was passed,
            # so e.g. ``from_model(name, default_prompt=...)`` had no effect.
            config = replace(config, **kwargs)
        self.config = config
        self._initialize_model()
        self._initialize_image_libs()

    @classmethod
    def from_model(cls, model_name: str, api_key: Optional[str] = None, **kwargs) -> 'BoundingBoxProcessor':
        """
        Create processor with specific model configuration

        Args:
            model_name: Name of the model
            api_key: Optional API key
            **kwargs: Additional ``BoundingBoxConfig`` fields (e.g. ``default_prompt``)

        Returns:
            BoundingBoxProcessor: Configured processor
        """
        config = BoundingBoxConfig(
            model_name=model_name,
            api_key=api_key
        )
        # __init__ applies kwargs on top of this config.
        return cls(config, **kwargs)

    def _initialize_model(self) -> None:
        """
        Initialize the AI model

        Uses ``config.api_key`` or, when that is empty, the ``GOOGLE_API_KEY``
        environment variable.

        Raises:
            ValueError: If no API key is available or the model cannot be created.
        """
        import google.generativeai as genai

        api_key = self.config.api_key or os.environ.get("GOOGLE_API_KEY")
        if not api_key:
            raise ValueError("Google API key not found. Please provide api_key parameter or set GOOGLE_API_KEY environment variable.")

        try:
            genai.configure(api_key=api_key)
            self.model = genai.GenerativeModel(model_name=self.config.model_name)
        except Exception as e:
            raise ValueError(f"Error initializing Google Generative AI model: {str(e)}") from e

    def _initialize_image_libs(self) -> None:
        """Initialize image processing libraries (Pillow and OpenCV)"""
        from PIL import Image
        import cv2
        self._Image = Image
        self._cv2 = cv2

    @staticmethod
    def _validate_image_path(image_path: str) -> None:
        """
        Validate image path

        Args:
            image_path: Path to image

        Raises:
            FileNotFoundError: If the path does not point to an existing file
                (a directory is rejected as well).
        """
        if not os.path.isfile(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")

    def generate_bounding_boxes(self,
                                image_path: str,
                                prompt: Optional[str] = None
                                ) -> Any:
        """
        Generate bounding boxes for an image

        Args:
            image_path: Path to image
            prompt: Custom prompt for the model; ``config.default_prompt`` is
                used when it is empty or None.

        Returns:
            Any: Raw model response with bounding boxes

        Raises:
            FileNotFoundError: If the image does not exist.
            ValueError: If opening the image or calling the model fails.
        """
        self._validate_image_path(image_path)
        prompt = prompt or self.config.default_prompt

        try:
            # The context manager releases the file handle PIL opens; the model
            # call happens inside it while the image is still readable.
            with self._Image.open(image_path) as image:
                return self.model.generate_content([image, prompt])
        except Exception as e:
            raise ValueError(f"Error generating bounding boxes: {str(e)}") from e

    @staticmethod
    def _parse_bounding_boxes(response: Any) -> Dict[str, List[Any]]:
        """
        Convert a model response into a ``{label: [ymin, xmin, ymax, xmax]}`` dict.

        Accepts a dict (returned unchanged), a string, or an object with a string
        ``.text`` attribute. The text may be wrapped in a Markdown code fence and
        must hold JSON that is a single box, a list of boxes, a list of
        ``{"box_2d": [...], "label": ...}`` objects, or a label -> box mapping.
        Unlabelled boxes are named ``box_0``, ``box_1``, ...

        Raises:
            ValueError: If the response cannot be parsed into boxes.
        """
        import json
        import re

        if isinstance(response, dict):
            return response
        text = response if isinstance(response, str) else getattr(response, "text", None)
        if not isinstance(text, str):
            raise ValueError("Model response contains no text to parse bounding boxes from")

        # Models often wrap JSON in ```json ... ``` fences; keep only the payload.
        fenced = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL)
        payload = (fenced.group(1) if fenced else text).strip()
        try:
            parsed = json.loads(payload)
        except json.JSONDecodeError as e:
            raise ValueError(f"Could not parse bounding boxes from model response: {text!r}") from e

        if isinstance(parsed, dict):
            return {str(label): box for label, box in parsed.items()}
        if isinstance(parsed, list):
            if len(parsed) == 4 and all(isinstance(v, (int, float)) for v in parsed):
                parsed = [parsed]  # a single bare [ymin, xmin, ymax, xmax]
            boxes: Dict[str, List[Any]] = {}
            for index, item in enumerate(parsed):
                if isinstance(item, dict) and "box_2d" in item:
                    label = str(item.get("label", f"box_{index}"))
                    if label in boxes:
                        label = f"{label}_{index}"  # keep repeated labels distinct
                    boxes[label] = item["box_2d"]
                else:
                    boxes[f"box_{index}"] = item
            return boxes
        raise ValueError(f"Could not parse bounding boxes from model response: {text!r}")

    def add_bounding_boxes(self,
                           image_path: str,
                           bounding_boxes: Dict[str, List[int]],
                           color: Tuple[int, int, int] = (0, 0, 255),
                           thickness: int = 4,
                           font_scale: float = 1.0,
                           show: bool = False,
                           google_bb= False
                           ) -> Any:
        """
        Add bounding boxes to an image

        Args:
            image_path: Path to image
            bounding_boxes: Mapping of label -> [ymin, xmin, ymax, xmax]
                (list or tuple of four numbers)
            color: BGR color tuple
            thickness: Line thickness (labels are drawn at half this thickness)
            font_scale: Font scale for labels
            show: Whether to display the image
            google_bb: If True, coordinates are normalized to 0-1000 and are
                scaled to pixels using the image height (y) and width (x).

        Returns:
            Any: Image with bounding boxes

        Raises:
            FileNotFoundError: If the image does not exist.
            ValueError: If the boxes are malformed or drawing fails.
        """
        self._validate_image_path(image_path)

        if not isinstance(bounding_boxes, dict):
            raise ValueError("bounding_boxes must be a dictionary")

        try:
            img = self._cv2.imread(image_path)
            if img is None:
                raise ValueError(f"Failed to load image: {image_path}")
            height, width = img.shape[0], img.shape[1]

            for key, value in bounding_boxes.items():
                if not isinstance(value, (list, tuple)) or len(value) != 4:
                    raise ValueError(f"Invalid bounding box format for key {key}. Expected [ymin, xmin, ymax, xmax]")

                if google_bb:
                    value = [value[0] * height / 1000, value[1] * width / 1000,
                             value[2] * height / 1000, value[3] * width / 1000]
                # OpenCV drawing functions require integer pixel coordinates.
                ymin, xmin, ymax, xmax = (int(v) for v in value)
                if google_bb:
                    print("Converted bounding box from Google (0-1000) coordinates: ", [ymin, xmin, ymax, xmax])

                self._cv2.rectangle(
                    img=img,
                    pt1=(xmin, ymin),
                    pt2=(xmax, ymax),
                    color=color,
                    thickness=thickness
                )
                self._cv2.putText(
                    img=img,
                    text=str(key),
                    org=(xmin, ymin),
                    fontFace=self._cv2.FONT_HERSHEY_SIMPLEX,
                    fontScale=font_scale,
                    color=color,
                    thickness=thickness // 2
                )

            if show:
                self._display_image(img)

            return img
        except Exception as e:
            raise ValueError(f"Error adding bounding boxes to image: {str(e)}") from e

    def _display_image(self, img: Any) -> None:
        """
        Display an image in an OpenCV window and wait for a key press

        Args:
            img: Image to display
        """
        self._cv2.imshow("Image", img)
        self._cv2.waitKey(0)
        self._cv2.destroyAllWindows()

    def save_image(self, img: Any, output_path: str) -> None:
        """
        Save an image

        Args:
            img: Image to save
            output_path: Path to save the image

        Raises:
            ValueError: If writing fails, including when ``cv2.imwrite``
                reports failure by returning False rather than raising.
        """
        try:
            written = self._cv2.imwrite(output_path, img)
        except Exception as e:
            raise ValueError(f"Error saving image: {str(e)}") from e
        if not written:
            raise ValueError(f"Error saving image: could not write {output_path}")

    def process_image(self,
                      image_path: str,
                      output_path: Optional[str] = None,
                      show: bool = False,
                      *,
                      google_bb: bool = True,
                      **kwargs) -> Any:
        """
        Complete image processing pipeline: query the model, parse its boxes,
        draw them and optionally save the result.

        Args:
            image_path: Path to input image
            output_path: Optional path to save output
            show: Whether to display the result
            google_bb: Whether the model's coordinates are normalized to 0-1000
                (Gemini's convention) and must be scaled to pixels.
            **kwargs: Additional arguments for bounding box generation (e.g. ``prompt``)

        Returns:
            Any: Processed image

        Raises:
            ValueError: If generation, parsing, drawing or saving fails.
        """
        response = self.generate_bounding_boxes(image_path, **kwargs)
        # The model returns a response object, not the dict add_bounding_boxes needs.
        boxes = self._parse_bounding_boxes(response)
        img = self.add_bounding_boxes(image_path, boxes, show=show, google_bb=google_bb)

        if output_path:
            self.save_image(img, output_path)

        return img
|