openvisionkit 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,401 @@
1
+ from pathlib import Path
2
+
3
+ import cv2
4
+ import mediapipe as mp
5
+ import numpy as np
6
+ from mediapipe.tasks import python
7
+ from mediapipe.tasks.python import vision
8
+
9
# Directory named "models" located next to this module; holds the .tflite files.
_MODEL_DIR = Path(__file__).parent.joinpath("models")
# Model path used by ObjectDetector when no model_path is given.
_DEFAULT_MODEL = str(_MODEL_DIR.joinpath("efficientdet_lite.tflite"))
11
+
12
+
13
class ObjectDetector:
    """Wrapper around the MediaPipe Tasks ``vision.ObjectDetector``.

    Accepts OpenCV-style BGR frames and runs detection in IMAGE or VIDEO
    running mode. It also provides helpers that draw, filter, count, measure
    and serialise the returned detection results. Instances can be used as a
    context manager so the underlying MediaPipe task is closed deterministically.
    """

    def __init__(
        self,
        model_path: str = _DEFAULT_MODEL,
        max_results=5,
        running_mode="IMAGE",
        display_names_locale=b"en",
        category_allowlist=None,
        category_denylist=None,
        score_threshold=0.5,
    ):
        """Create the underlying MediaPipe object detector.

        Args:
            model_path: Path to a ``.tflite`` detection model. Defaults to
                ``models/efficientdet_lite.tflite`` next to this module.
            max_results: Maximum number of detections returned per frame.
            running_mode: Name of a ``vision.RunningMode`` member
                (case-insensitive, e.g. "IMAGE" or "VIDEO"), or a
                ``vision.RunningMode`` value. :meth:`detect` supports only
                IMAGE and VIDEO.
            display_names_locale: Locale used for category display names.
                ``bytes`` are decoded as UTF-8 because MediaPipe declares this
                option as a string.
            category_allowlist: Optional list of category names to keep.
            category_denylist: Optional list of category names to drop.
            score_threshold: Minimum score for a detection to be returned.
                The default of 0.5 matches the previously hard-coded value.

        Raises:
            AttributeError: If ``running_mode`` is a string that names no
                ``vision.RunningMode`` member.
        """
        if isinstance(running_mode, str):
            # Unknown names still raise AttributeError, as in earlier versions.
            running_mode = getattr(vision.RunningMode, running_mode.upper())
        self.running_mode = running_mode

        if isinstance(display_names_locale, bytes):
            display_names_locale = display_names_locale.decode("utf-8")

        base_options = python.BaseOptions(model_asset_path=model_path)
        options = vision.ObjectDetectorOptions(
            base_options=base_options,
            score_threshold=score_threshold,
            max_results=max_results,
            running_mode=self.running_mode,
            display_names_locale=display_names_locale,
            category_allowlist=category_allowlist,
            category_denylist=category_denylist,
        )
        self.detector = vision.ObjectDetector.create_from_options(options)

        # Drawing style used by visualize_detections().
        self.MARGIN = 15
        self.FONT_THICKNESS = 2
        self.ROW_SIZE = 10
        self.TEXT_COLOR = (0, 255, 0)
        self.FONT_SIZE = 1

    # ───────────────────────── lifecycle ─────────────────────────

    def close(self):
        """Release the resources held by the underlying MediaPipe detector."""
        self.detector.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.close()
        # Never suppress exceptions raised inside the ``with`` body.
        return False

    # ───────────────────────── private helpers ─────────────────────────

    @staticmethod
    def _label_and_score(detection, default_label="unknown"):
        """Return (category_name, score) of a detection's first category.

        Returns ``(default_label, 0.0)`` when the detection has no categories.
        """
        if not detection.categories:
            return default_label, 0.0
        category = detection.categories[0]
        return category.category_name, category.score

    @staticmethod
    def _center(detection):
        """Return the (float) center (cx, cy) of a detection's bounding box."""
        bb = detection.bounding_box
        return bb.origin_x + bb.width / 2, bb.origin_y + bb.height / 2

    @staticmethod
    def _area(detection):
        """Return the bounding-box area (width * height) of a detection."""
        bb = detection.bounding_box
        return bb.width * bb.height

    # ───────────────────────── detection ─────────────────────────

    def _to_mp_image(self, image):
        """Convert an OpenCV image to an RGB ``mp.Image`` for MediaPipe.

        Args:
            image: A BGR (H, W, 3) image as used by OpenCV. Grayscale (H, W) or
                (H, W, 1) and BGRA (H, W, 4) images are converted as well.

        Returns:
            An ``mp.Image`` in ``ImageFormat.SRGB`` holding RGB data.
        """
        if image.ndim == 2 or image.shape[2] == 1:
            code = cv2.COLOR_GRAY2RGB
        elif image.shape[2] == 4:
            code = cv2.COLOR_BGRA2RGB
        else:
            code = cv2.COLOR_BGR2RGB
        rgb = cv2.cvtColor(image, code)
        return mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb)

    def detect(self, image, timestamp_ms=None):
        """Detect objects in a BGR frame.

        Args:
            image: The input image in BGR format (as used by OpenCV).
            timestamp_ms: Frame timestamp in milliseconds, used only in VIDEO
                mode. ``None`` falls back to 0. Per the MediaPipe docs, VIDEO
                mode expects increasing timestamps across calls, so pass real
                timestamps for every frame after the first.

        Returns:
            Tuple ``(detection_result, mp_image)``: the MediaPipe
            ``DetectionResult`` and the RGB ``mp.Image`` that was analysed.

        Raises:
            ValueError: If the detector was created with a running mode other
                than IMAGE or VIDEO.
        """
        # Convert exactly once. Converting BGR->RGB twice would hand the model
        # BGR data labelled as SRGB.
        mp_image = self._to_mp_image(image)
        if self.running_mode == vision.RunningMode.IMAGE:
            result = self.detector.detect(mp_image)
        elif self.running_mode == vision.RunningMode.VIDEO:
            result = self.detector.detect_for_video(mp_image, int(timestamp_ms or 0))
        else:
            raise ValueError(
                f"detect() supports IMAGE and VIDEO running modes, got {self.running_mode!r}"
            )
        return result, mp_image

    def visualize_detections(self, image, detection_result):
        """Draw each detection's box and "label: score" text with OpenCV.

        The image is modified in place and also returned for convenience.

        Args:
            image: The BGR image to draw on.
            detection_result: The ``DetectionResult`` returned by :meth:`detect`.

        Returns:
            The same ``image`` object, now annotated.
        """
        for detection in detection_result.detections:
            bbox = detection.bounding_box
            label, score = self._label_and_score(detection, default_label="Unknown")

            start_point = (int(bbox.origin_x), int(bbox.origin_y))
            end_point = (
                int(bbox.origin_x + bbox.width),
                int(bbox.origin_y + bbox.height),
            )
            cv2.rectangle(
                image, start_point, end_point, self.TEXT_COLOR, self.FONT_THICKNESS
            )

            # The text baseline sits MARGIN pixels below the box's top edge.
            text_origin = (start_point[0], start_point[1] + self.MARGIN)
            cv2.putText(
                image,
                f"{label}: {score:.2f}",
                text_origin,
                cv2.FONT_HERSHEY_SIMPLEX,
                self.FONT_SIZE,
                self.TEXT_COLOR,
                self.FONT_THICKNESS,
            )
        return image

    def detect_objects(self, image, timestamp_ms=None):
        """Detect objects and return an annotated copy of the BGR input frame.

        Args:
            image: The input image in BGR format (as used by OpenCV).
            timestamp_ms: Forwarded to :meth:`detect` (VIDEO mode only).

        Returns:
            A new BGR numpy array with the detections drawn on it. The input
            frame is not modified.
        """
        detection_result, _ = self.detect(image, timestamp_ms=timestamp_ms)
        # Draw on a copy of the caller's frame so the output stays in BGR
        # order, matching what OpenCV display/writing code expects.
        return self.visualize_detections(np.copy(image), detection_result)

    # ───────────────────────── analysis helpers ─────────────────────────

    def count_objects(self, detection_result, label=None):
        """Return the number of detections, optionally only those of one class.

        Args:
            detection_result: Raw result from detect().
            label: If provided, count only detections whose first category name
                equals this label (case-insensitive).

        Returns:
            int
        """
        if label is None:
            return len(detection_result.detections)
        return len(self.filter_by_class(detection_result, [label]))

    def filter_by_class(self, detection_result, allowed_classes):
        """Return only detections whose first category is in allowed_classes.

        Args:
            detection_result: Raw result from detect().
            allowed_classes: Iterable of category name strings
                (case-insensitive). A single string is treated as one class
                name rather than as a sequence of characters.

        Returns:
            List of Detection objects, in their original order.
        """
        if isinstance(allowed_classes, str):
            allowed_classes = [allowed_classes]
        allowed = {name.lower() for name in allowed_classes}
        return [
            d
            for d in detection_result.detections
            if d.categories and d.categories[0].category_name.lower() in allowed
        ]

    def get_largest_object(self, detection_result):
        """Return the detection with the largest bounding-box area, or None.

        Args:
            detection_result: Raw result from detect().

        Returns:
            Detection object, or None when there are no detections.
        """
        return max(detection_result.detections, key=self._area, default=None)

    def is_object_in_zone(self, detection, zone_rect):
        """Check whether a detection's center lies inside a zone rectangle.

        The center is truncated to int pixels. Zone edges count as inside.

        Args:
            detection: A Detection object from detection_result.detections.
            zone_rect: (x, y, w, h) zone in pixel coordinates.

        Returns:
            bool
        """
        cx, cy = (int(v) for v in self._center(detection))
        zx, zy, zw, zh = zone_rect
        return (zx <= cx <= zx + zw) and (zy <= cy <= zy + zh)

    def get_object_centers(self, image, detection_result):
        """Return the label, score and integer center of every detection.

        Args:
            image: Source BGR frame. It is not read or drawn on; the parameter
                is kept for API compatibility.
            detection_result: Raw result from detect().

        Returns:
            List of dicts: [{'label': str, 'score': float, 'center': (cx, cy)}]
        """
        results = []
        for d in detection_result.detections:
            cx, cy = self._center(d)
            label, score = self._label_and_score(d)
            results.append({"label": label, "score": score, "center": (int(cx), int(cy))})
        return results

    def draw_zone(
        self, image, zone_rect, color=(0, 255, 255), label="Zone", thickness=2
    ):
        """Draw a named rectangular zone on a copy of the image.

        Args:
            image: BGR numpy array. It is not modified.
            zone_rect: (x, y, w, h). Float values are truncated to int because
                OpenCV drawing functions require integer coordinates.
            color: BGR color tuple.
            label: Text displayed above the rectangle.
            thickness: Border thickness in pixels.

        Returns:
            Annotated BGR numpy array (a copy).
        """
        out = image.copy()
        x, y, w, h = (int(v) for v in zone_rect)
        cv2.rectangle(out, (x, y), (x + w, y + h), color, thickness)
        cv2.putText(
            out,
            label,
            (x, max(0, y - 8)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            color,
            2,
            cv2.LINE_AA,
        )
        return out

    def get_class_summary(self, detection_result):
        """Return a dict mapping each detected class name to its count.

        Detections without categories are skipped.

        Args:
            detection_result: Raw result from detect().

        Returns:
            dict, e.g. {'person': 3, 'car': 1}, in first-seen order.
        """
        summary = {}
        for d in detection_result.detections:
            if not d.categories:
                continue
            name = d.categories[0].category_name
            summary[name] = summary.get(name, 0) + 1
        return summary

    def filter_by_confidence(self, detection_result, threshold=0.5):
        """Return detections whose first-category score is >= threshold.

        Args:
            detection_result: Raw result from detect().
            threshold: Minimum confidence score (inclusive). Default 0.5.

        Returns:
            List of Detection objects.
        """
        return [
            d
            for d in detection_result.detections
            if d.categories and d.categories[0].score >= threshold
        ]

    def get_bounding_boxes(self, detection_result):
        """Return label, score and bounding box for every detection.

        Args:
            detection_result: Raw result from detect().

        Returns:
            List of dicts: [{'label': str, 'score': float, 'bbox': (x, y, w, h)}]
        """
        boxes = []
        for detection in detection_result.detections:
            bb = detection.bounding_box
            label, score = self._label_and_score(detection)
            boxes.append(
                {
                    "label": label,
                    "score": score,
                    "bbox": (bb.origin_x, bb.origin_y, bb.width, bb.height),
                }
            )
        return boxes

    def is_crowded(self, detection_result, threshold=5):
        """Return True if the number of detections is >= threshold.

        Args:
            detection_result: Raw result from detect().
            threshold: Minimum count to consider crowded. Default 5.

        Returns:
            bool
        """
        return len(detection_result.detections) >= threshold

    def get_objects_by_size(self, detection_result):
        """Return detections sorted by bounding-box area, largest first.

        Args:
            detection_result: Raw result from detect().

        Returns:
            List of Detection objects sorted by descending area (stable).
        """
        return sorted(detection_result.detections, key=self._area, reverse=True)

    def get_proximity(self, det_a, det_b):
        """Return the Euclidean distance between two detections' box centers.

        Args:
            det_a: A Detection object.
            det_b: A Detection object.

        Returns:
            float: pixel distance between the bounding-box centers.
        """
        cx_a, cy_a = self._center(det_a)
        cx_b, cy_b = self._center(det_b)
        return float(((cx_a - cx_b) ** 2 + (cy_a - cy_b) ** 2) ** 0.5)

    def detect_line_crossing(
        self, detection_result, line_start, line_end, line_threshold=10
    ):
        """Return detections whose center lies within line_threshold of a segment.

        This is a per-frame proximity test (a "virtual tripwire"). It does not
        track objects between frames, so it reports closeness to the line,
        not direction of travel.

        Args:
            detection_result: Raw result from detect().
            line_start: (x, y) start of the line segment.
            line_end: (x, y) end of the line segment.
            line_threshold: Maximum distance in pixels from the center to the
                closest point on the segment (end points included). Default 10.

        Returns:
            List of dicts: [{'label': str, 'center': (cx, cy), 'distance': float}]
        """
        p1 = np.array(line_start, dtype=float)
        p2 = np.array(line_end, dtype=float)
        seg = p2 - p1
        seg_len_sq = float(np.dot(seg, seg))  # loop-invariant; hoisted

        crossing = []
        for detection in detection_result.detections:
            cx, cy = self._center(detection)
            p = np.array([cx, cy], dtype=float)
            if seg_len_sq == 0.0:
                # Degenerate segment: measure distance to the single point.
                closest = p1
            else:
                # Project onto the segment's line and clamp to its end points.
                t = float(np.clip(np.dot(p - p1, seg) / seg_len_sq, 0.0, 1.0))
                closest = p1 + t * seg
            dist = float(np.linalg.norm(p - closest))
            if dist <= line_threshold:
                label, _ = self._label_and_score(detection)
                crossing.append({"label": label, "center": (cx, cy), "distance": dist})
        return crossing

    def export_to_json(self, detection_result):
        """Serialise detection results to a JSON-compatible dict.

        Args:
            detection_result: Raw result from detect().

        Returns:
            dict with key 'detections'. Each entry has 'label', 'score' and
            'bbox' (with x, y, width, height).
        """
        detections = []
        for d in detection_result.detections:
            bb = d.bounding_box
            label, score = self._label_and_score(d)
            detections.append(
                {
                    "label": label,
                    "score": score,
                    "bbox": {
                        "x": bb.origin_x,
                        "y": bb.origin_y,
                        "width": bb.width,
                        "height": bb.height,
                    },
                }
            )
        return {"detections": detections}

    def batch_detect(self, images, timestamps_ms=None):
        """Run detect() on several images and return one result per image.

        Args:
            images: Iterable of BGR numpy arrays.
            timestamps_ms: Optional iterable of per-frame timestamps in
                milliseconds, forwarded to :meth:`detect`. Supply these in
                VIDEO mode. Without them every frame falls back to timestamp 0.

        Returns:
            List of DetectionResult objects (one per image).

        Raises:
            ValueError: If timestamps_ms is given and its length differs from
                the number of images.
        """
        if timestamps_ms is None:
            return [self.detect(img)[0] for img in images]
        images = list(images)
        timestamps_ms = list(timestamps_ms)
        if len(images) != len(timestamps_ms):
            raise ValueError(
                f"got {len(images)} images but {len(timestamps_ms)} timestamps"
            )
        return [
            self.detect(img, timestamp_ms=ts)[0]
            for img, ts in zip(images, timestamps_ms)
        ]