eye-cv 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94) hide show
  1. eye/__init__.py +115 -0
  2. eye/__init___supervision_original.py +120 -0
  3. eye/annotators/__init__.py +0 -0
  4. eye/annotators/base.py +22 -0
  5. eye/annotators/core.py +2699 -0
  6. eye/annotators/line.py +107 -0
  7. eye/annotators/modern.py +529 -0
  8. eye/annotators/trace.py +142 -0
  9. eye/annotators/utils.py +177 -0
  10. eye/assets/__init__.py +2 -0
  11. eye/assets/downloader.py +95 -0
  12. eye/assets/list.py +83 -0
  13. eye/classification/__init__.py +0 -0
  14. eye/classification/core.py +188 -0
  15. eye/config.py +2 -0
  16. eye/core/__init__.py +0 -0
  17. eye/core/trackers/__init__.py +1 -0
  18. eye/core/trackers/botsort_tracker.py +336 -0
  19. eye/core/trackers/bytetrack_tracker.py +284 -0
  20. eye/core/trackers/sort_tracker.py +200 -0
  21. eye/core/tracking.py +146 -0
  22. eye/dataset/__init__.py +0 -0
  23. eye/dataset/core.py +919 -0
  24. eye/dataset/formats/__init__.py +0 -0
  25. eye/dataset/formats/coco.py +258 -0
  26. eye/dataset/formats/pascal_voc.py +279 -0
  27. eye/dataset/formats/yolo.py +272 -0
  28. eye/dataset/utils.py +259 -0
  29. eye/detection/__init__.py +0 -0
  30. eye/detection/auto_convert.py +155 -0
  31. eye/detection/core.py +1529 -0
  32. eye/detection/detections_enhanced.py +392 -0
  33. eye/detection/line_zone.py +859 -0
  34. eye/detection/lmm.py +184 -0
  35. eye/detection/overlap_filter.py +270 -0
  36. eye/detection/tools/__init__.py +0 -0
  37. eye/detection/tools/csv_sink.py +181 -0
  38. eye/detection/tools/inference_slicer.py +288 -0
  39. eye/detection/tools/json_sink.py +142 -0
  40. eye/detection/tools/polygon_zone.py +202 -0
  41. eye/detection/tools/smoother.py +123 -0
  42. eye/detection/tools/smoothing.py +179 -0
  43. eye/detection/tools/smoothing_config.py +202 -0
  44. eye/detection/tools/transformers.py +247 -0
  45. eye/detection/utils.py +1175 -0
  46. eye/draw/__init__.py +0 -0
  47. eye/draw/color.py +154 -0
  48. eye/draw/utils.py +374 -0
  49. eye/filters.py +112 -0
  50. eye/geometry/__init__.py +0 -0
  51. eye/geometry/core.py +128 -0
  52. eye/geometry/utils.py +47 -0
  53. eye/keypoint/__init__.py +0 -0
  54. eye/keypoint/annotators.py +442 -0
  55. eye/keypoint/core.py +687 -0
  56. eye/keypoint/skeletons.py +2647 -0
  57. eye/metrics/__init__.py +21 -0
  58. eye/metrics/core.py +72 -0
  59. eye/metrics/detection.py +843 -0
  60. eye/metrics/f1_score.py +648 -0
  61. eye/metrics/mean_average_precision.py +628 -0
  62. eye/metrics/mean_average_recall.py +697 -0
  63. eye/metrics/precision.py +653 -0
  64. eye/metrics/recall.py +652 -0
  65. eye/metrics/utils/__init__.py +0 -0
  66. eye/metrics/utils/object_size.py +158 -0
  67. eye/metrics/utils/utils.py +9 -0
  68. eye/py.typed +0 -0
  69. eye/quick.py +104 -0
  70. eye/tracker/__init__.py +0 -0
  71. eye/tracker/byte_tracker/__init__.py +0 -0
  72. eye/tracker/byte_tracker/core.py +386 -0
  73. eye/tracker/byte_tracker/kalman_filter.py +205 -0
  74. eye/tracker/byte_tracker/matching.py +69 -0
  75. eye/tracker/byte_tracker/single_object_track.py +178 -0
  76. eye/tracker/byte_tracker/utils.py +18 -0
  77. eye/utils/__init__.py +0 -0
  78. eye/utils/conversion.py +132 -0
  79. eye/utils/file.py +159 -0
  80. eye/utils/image.py +794 -0
  81. eye/utils/internal.py +200 -0
  82. eye/utils/iterables.py +84 -0
  83. eye/utils/notebook.py +114 -0
  84. eye/utils/video.py +307 -0
  85. eye/utils_eye/__init__.py +1 -0
  86. eye/utils_eye/geometry.py +71 -0
  87. eye/utils_eye/nms.py +55 -0
  88. eye/validators/__init__.py +140 -0
  89. eye/web.py +271 -0
  90. eye_cv-1.0.0.dist-info/METADATA +319 -0
  91. eye_cv-1.0.0.dist-info/RECORD +94 -0
  92. eye_cv-1.0.0.dist-info/WHEEL +5 -0
  93. eye_cv-1.0.0.dist-info/licenses/LICENSE +21 -0
  94. eye_cv-1.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,859 @@
1
+ import math
2
+ import warnings
3
+ from collections import Counter, defaultdict, deque
4
+ from functools import lru_cache
5
+ from typing import Any, Deque, Dict, Iterable, List, Literal, Optional, Tuple
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import numpy.typing as npt
10
+
11
+ from eye.config import CLASS_NAME_DATA_FIELD
12
+ from eye.detection.core import Detections
13
+ from eye.detection.utils import cross_product
14
+ from eye.draw.color import Color
15
+ from eye.draw.utils import draw_rectangle, draw_text
16
+ from eye.geometry.core import Point, Position, Rect, Vector
17
+ from eye.utils.image import overlay_image
18
+ from eye.utils.internal import EyeWarnings
19
+
20
# Extra vertical space added to every text row's measured height when
# LineZoneAnnotatorMulticlass sizes and lays out its table.
+ TEXT_MARGIN = 10
21
+
22
+
23
class LineZone:
    """
    Count the number of objects that cross a predefined line.

    <video controls>
        <source
            src="https://media.roboflow.com/eye/cookbooks/count-objects-crossing-the-line-result-1280x720.mp4"
            type="video/mp4">
    </video>

    !!! warning

        LineZone uses the `tracker_id`. Read
        [here](/latest/trackers/) to learn how to plug
        tracking into your inference pipeline.

    Attributes:
        in_count (int): The number of objects that have crossed the line from outside
            to inside.
        out_count (int): The number of objects that have crossed the line from inside
            to outside.
        in_count_per_class (Dict[int, int]): Number of objects of each class that have
            crossed the line from outside to inside.
        out_count_per_class (Dict[int, int]): Number of objects of each class that have
            crossed the line from inside to outside.

    Example:
        ```python
        import eye as sv
        from ultralytics import YOLO

        model = YOLO(<SOURCE_MODEL_PATH>)
        tracker = sv.ByteTrack()
        frames_generator = sv.get_video_frames_generator(<SOURCE_VIDEO_PATH>)
        start, end = sv.Point(x=0, y=1080), sv.Point(x=3840, y=1080)
        line_zone = sv.LineZone(start=start, end=end)

        for frame in frames_generator:
            result = model(frame)[0]
            detections = sv.Detections.from_ultralytics(result)
            detections = tracker.update_with_detections(detections)
            crossed_in, crossed_out = line_zone.trigger(detections)

        line_zone.in_count, line_zone.out_count
        # 7, 2
        ```
    """

    def __init__(
        self,
        start: Point,
        end: Point,
        triggering_anchors: Iterable[Position] = (
            Position.TOP_LEFT,
            Position.TOP_RIGHT,
            Position.BOTTOM_LEFT,
            Position.BOTTOM_RIGHT,
        ),
        minimum_crossing_threshold: int = 1,
    ):
        """
        Args:
            start (Point): The starting point of the line.
            end (Point): The ending point of the line.
            triggering_anchors (Iterable[sv.Position]): The positions
                specifying which anchors of the detections bounding box
                to consider when deciding on whether the detection
                has passed the line counter or not. By default, this
                contains the four corners of the detection's bounding box.
                Any iterable is accepted; it is stored as a list.
            minimum_crossing_threshold (int): Detection needs to be seen
                on the other side of the line for this many frames to be
                considered as having crossed the line. This is useful when
                dealing with unstable bounding boxes or when detections
                may linger on the line.

        Raises:
            ValueError: If the line has zero length or if
                `triggering_anchors` is empty.
        """
        self.vector = Vector(start=start, end=end)
        self.limits = self._calculate_region_of_interest_limits(vector=self.vector)
        # The history must hold the state before the crossing plus
        # `minimum_crossing_threshold` states after it; never fewer than 2.
        self.crossing_history_length = max(2, minimum_crossing_threshold + 1)
        self.crossing_state_history: Dict[int, Deque[bool]] = defaultdict(
            lambda: deque(maxlen=self.crossing_history_length)
        )
        self._in_count_per_class: Counter = Counter()
        self._out_count_per_class: Counter = Counter()
        # Materialize exactly once. The anchors are iterated again on every
        # `trigger` call, so a one-shot iterator (e.g. a generator) must not
        # be consumed by the emptiness check below and then stored exhausted.
        self.triggering_anchors: List[Position] = list(triggering_anchors)
        if not self.triggering_anchors:
            raise ValueError("Triggering anchors cannot be empty.")
        self.class_id_to_name: Dict[int, str] = {}

    @property
    def in_count(self) -> int:
        """Total number of counted "in" crossings, summed over all classes."""
        return sum(self._in_count_per_class.values())

    @property
    def out_count(self) -> int:
        """Total number of counted "out" crossings, summed over all classes."""
        return sum(self._out_count_per_class.values())

    @property
    def in_count_per_class(self) -> Dict[int, int]:
        """A copy of the per-class "in" crossing counts."""
        return dict(self._in_count_per_class)

    @property
    def out_count_per_class(self) -> Dict[int, int]:
        """A copy of the per-class "out" crossing counts."""
        return dict(self._out_count_per_class)

    def trigger(self, detections: Detections) -> Tuple[np.ndarray, np.ndarray]:
        """
        Update the `in_count` and `out_count` based on the objects that cross the line.

        Args:
            detections (Detections): A list of detections for which to update the
                counts.

        Returns:
            A tuple of two boolean NumPy arrays. The first array indicates which
                detections have crossed the line from outside to inside. The second
                array indicates which detections have crossed the line from inside to
                outside.
        """
        crossed_in = np.full(len(detections), False)
        crossed_out = np.full(len(detections), False)

        if len(detections) == 0:
            return crossed_in, crossed_out

        if detections.tracker_id is None:
            warnings.warn(
                "Line zone counting skipped. LineZone requires tracker_id. "
                "Run a tracker on the detections before calling trigger() "
                "so that tracker_id is available.",
                category=EyeWarnings,
            )
            return crossed_in, crossed_out

        self._update_class_id_to_name(detections)

        in_limits, has_any_left_trigger, has_any_right_trigger = (
            self._compute_anchor_sides(detections)
        )

        class_ids: List[Optional[int]] = (
            list(detections.class_id)
            if detections.class_id is not None
            else [None] * len(detections)
        )

        for i, (class_id, tracker_id) in enumerate(
            zip(class_ids, detections.tracker_id)
        ):
            # Ignore detections with any anchor outside the region of interest.
            if not in_limits[i]:
                continue

            # Anchors on both sides: the box is straddling the line, so its
            # side is undecided for this frame.
            if has_any_left_trigger[i] and has_any_right_trigger[i]:
                continue

            # True when the anchors lie on the side where the cross product
            # with the line vector is negative (the "left" side in this code).
            tracker_state: bool = has_any_left_trigger[i]
            crossing_history = self.crossing_state_history[tracker_id]
            crossing_history.append(tracker_state)

            # Wait until the bounded history is full before deciding.
            if len(crossing_history) < self.crossing_history_length:
                continue

            # TODO: Account for incorrect class_id.
            #   Most likely this would involve indexing self.crossing_state_history
            #   with (tracker_id, class_id).

            # A crossing is counted only when the oldest state is the sole
            # entry of its kind, i.e. every newer state is on the other side.
            oldest_state = crossing_history[0]
            if crossing_history.count(oldest_state) > 1:
                continue

            if tracker_state:
                self._in_count_per_class[class_id] += 1
                crossed_in[i] = True
            else:
                self._out_count_per_class[class_id] += 1
                crossed_out[i] = True

        return crossed_in, crossed_out

    @staticmethod
    def _calculate_region_of_interest_limits(vector: Vector) -> Tuple[Vector, Vector]:
        """
        Build the two limit vectors that bound the line's region of interest.

        Each limit starts at one endpoint of the line and has unit length,
        perpendicular to the line. The two limits point in opposite directions.

        Args:
            vector (Vector): The line itself.

        Returns:
            (Tuple[Vector, Vector]): The limit anchored at the line's start and
                the limit anchored at the line's end.

        Raises:
            ValueError: If the line has zero length.
        """
        magnitude = vector.magnitude

        if magnitude == 0:
            raise ValueError("The magnitude of the vector cannot be zero.")

        delta_x = vector.end.x - vector.start.x
        delta_y = vector.end.y - vector.start.y

        unit_vector_x = delta_x / magnitude
        unit_vector_y = delta_y / magnitude

        # Rotate the unit direction by 90 degrees: (x, y) -> (-y, x).
        perpendicular_vector_x = -unit_vector_y
        perpendicular_vector_y = unit_vector_x

        start_region_limit = Vector(
            start=vector.start,
            end=Point(
                x=vector.start.x + perpendicular_vector_x,
                y=vector.start.y + perpendicular_vector_y,
            ),
        )
        end_region_limit = Vector(
            start=vector.end,
            end=Point(
                x=vector.end.x - perpendicular_vector_x,
                y=vector.end.y - perpendicular_vector_y,
            ),
        )
        return start_region_limit, end_region_limit

    def _compute_anchor_sides(
        self, detections: Detections
    ) -> Tuple[npt.NDArray[np.bool_], npt.NDArray[np.bool_], npt.NDArray[np.bool_]]:
        """
        Find if detections' anchors are within the limit of the line
        zone and which anchors are on its left and right side.

        Assumes:
            * At least 1 detection is provided
            * Detections have `tracker_id`

        The limit is defined as the region between the two lines,
        perpendicular to the line zone, and passing through its start
        and end points, as shown below:

        Limits:
        ```
             |    IN    ↑
             |          |
         OUT o---LINE---o OUT
             |          |
             ↓    IN    |
        ```

        Args:
            detections (Detections): The detections to check.

        Returns:
            result (Tuple[np.ndarray, np.ndarray, np.ndarray]):
                All 3 arrays are boolean arrays of shape (N, ) where N is the
                number of detections. The first array, `in_limits`, indicates
                if all of the detection's anchors are within the line zone
                limits. The second array, `has_any_left_trigger`, indicates if
                any of the detection's anchors is on the left side of the line
                zone. The third array, `has_any_right_trigger`, indicates if
                any of the detection's anchors is on the right side of the
                line zone.
        """
        assert len(detections) > 0
        assert detections.tracker_id is not None

        # One entry per triggering anchor; the reductions below run over
        # this leading (anchor) axis.
        all_anchors = np.array(
            [
                detections.get_anchors_coordinates(anchor)
                for anchor in self.triggering_anchors
            ]
        )

        cross_products_1 = cross_product(all_anchors, self.limits[0])
        cross_products_2 = cross_product(all_anchors, self.limits[1])

        # Works because limit vectors are pointing in opposite directions
        in_limits = (cross_products_1 > 0) == (cross_products_2 > 0)
        in_limits = np.all(in_limits, axis=0)

        triggers = cross_product(all_anchors, self.vector) < 0
        has_any_left_trigger = np.any(triggers, axis=0)
        has_any_right_trigger = np.any(~triggers, axis=0)

        return in_limits, has_any_left_trigger, has_any_right_trigger

    def _update_class_id_to_name(self, detections: Detections) -> None:
        """
        Update the attribute keeping track of which class
        IDs correspond to which class names.

        When no class names are present in `detections.data`, each class ID
        is mapped to its own string form.

        Assumes that class_names are only provided when class_ids are.
        """
        class_names = detections.data.get(CLASS_NAME_DATA_FIELD)
        assert class_names is None or detections.class_id is not None

        if detections.class_id is None:
            return

        if class_names is None:
            new_names = {class_id: str(class_id) for class_id in detections.class_id}
        else:
            new_names = {
                class_id: class_name
                for class_id, class_name in zip(detections.class_id, class_names)
            }
        self.class_id_to_name.update(new_names)
315
+
316
+
317
class LineZoneAnnotator:
    """Draw a `LineZone` and its "in" / "out" counts on an image."""

    def __init__(
        self,
        thickness: int = 2,
        color: Color = Color.WHITE,
        text_thickness: int = 2,
        text_color: Color = Color.BLACK,
        text_scale: float = 0.5,
        text_offset: float = 1.5,
        text_padding: int = 10,
        custom_in_text: Optional[str] = None,
        custom_out_text: Optional[str] = None,
        display_in_count: bool = True,
        display_out_count: bool = True,
        display_text_box: bool = True,
        text_orient_to_line: bool = False,
        text_centered: bool = True,
    ):
        """
        A class for drawing the `LineZone` and its detected object count
        on an image.

        Args:
            thickness (int): Line thickness.
            color (Color): Line color.
            text_thickness (int): Text thickness.
            text_color (Color): Text color.
            text_scale (float): Text scale.
            text_offset (float): How far the text will be from the line.
            text_padding (int): The empty space in the text box, surrounding the text.
            custom_in_text (Optional[str]): Write something else instead of "in".
            custom_out_text (Optional[str]): Write something else instead of "out".
            display_in_count (bool): Pass `False` to hide the "in" count.
            display_out_count (bool): Pass `False` to hide the "out" count.
            display_text_box (bool): Pass `False` to hide the text background box.
            text_orient_to_line (bool): Match text orientation to the line.
                Recommended to set to `True`.
            text_centered (bool): Pass `False` to disable text centering. Useful
                when the label overlaps something important.
        """
        self.thickness: int = thickness
        self.color: Color = color
        self.text_thickness: int = text_thickness
        self.text_color: Color = text_color
        self.text_scale: float = text_scale
        self.text_offset: float = text_offset
        self.text_padding: int = text_padding
        # A falsy custom text (None or "") falls back to the default label.
        self.in_text: str = custom_in_text if custom_in_text else "in"
        self.out_text: str = custom_out_text if custom_out_text else "out"
        self.display_in_count: bool = display_in_count
        self.display_out_count: bool = display_out_count
        self.display_text_box: bool = display_text_box
        self.text_orient_to_line: bool = text_orient_to_line
        self.text_centered: bool = text_centered

    def annotate(self, frame: np.ndarray, line_counter: LineZone) -> np.ndarray:
        """
        Draws the line on the frame using the line zone provided.

        Args:
            frame (np.ndarray): The image on which the line will be drawn.
            line_counter (LineZone): The line zone
                that will be used to draw the line.

        Returns:
            (np.ndarray): The image with the line drawn on it.
        """
        line_start = line_counter.vector.start.as_xy_int_tuple()
        line_end = line_counter.vector.end.as_xy_int_tuple()
        cv2.line(
            frame,
            line_start,
            line_end,
            self.color.as_bgr(),
            self.thickness,
            lineType=cv2.LINE_AA,
            shift=0,
        )
        # Mark both endpoints with a filled dot in the text color.
        for endpoint in (line_start, line_end):
            cv2.circle(
                frame,
                endpoint,
                radius=5,
                color=self.text_color.as_bgr(),
                thickness=-1,
                lineType=cv2.LINE_AA,
            )

        in_text = f"{self.in_text}: {line_counter.in_count}"
        out_text = f"{self.out_text}: {line_counter.out_count}"
        line_angle_degrees = self._get_line_angle(line_counter)

        for text, is_shown, is_in_count in [
            (in_text, self.display_in_count, True),
            (out_text, self.display_out_count, False),
        ]:
            if not is_shown:
                continue

            # Keep the frame returned by the helpers instead of relying on
            # the drawing utilities mutating `frame` in place.
            if line_angle_degrees == 0 or not self.text_orient_to_line:
                frame = self._draw_basic_label(
                    frame=frame,
                    line_center=line_counter.vector.center,
                    text=text,
                    is_in_count=is_in_count,
                )
            else:
                frame = self._draw_oriented_label(
                    frame=frame,
                    line_zone=line_counter,
                    text=text,
                    is_in_count=is_in_count,
                )

        return frame

    def _get_line_angle(self, line_zone: LineZone) -> float:
        """
        Calculate the line counter angle (in degrees).

        The angle is computed from the integer pixel coordinates of the
        line's endpoints. Vertical lines yield 90 or 270.

        Args:
            line_zone (LineZone): The line zone object.

        Returns:
            (float): Line counter angle, in degrees.
        """
        start_point = line_zone.vector.start.as_xy_int_tuple()
        end_point = line_zone.vector.end.as_xy_int_tuple()

        delta_x = end_point[0] - start_point[0]
        delta_y = end_point[1] - start_point[1]

        if delta_x == 0:
            # Vertical line: avoid dividing by zero.
            line_angle = 90.0
            line_angle += 180 if delta_y < 0 else 0
        else:
            line_angle = math.degrees(math.atan(delta_y / delta_x))
            line_angle += 180 if delta_x < 0 else 0

        return line_angle

    def _calculate_anchor_in_frame(
        self,
        line_zone: LineZone,
        text_width: int,
        text_height: int,
        is_in_count: bool,
        label_dimension: int,
    ) -> Tuple[int, int]:
        """
        Calculate insertion anchor in frame to position the center of the count image.

        Args:
            line_zone (LineZone): The line counter object used for counting.
            text_width (int): Text width.
            text_height (int): Text height.
            is_in_count (bool): Whether the count should be placed over or below line.
            label_dimension (int): Size of the label image. Assumes the
                label is rectangular.

        Returns:
            (Tuple[int, int]): xy, point in an image where the label will be placed.
        """
        angle_radians = math.radians(self._get_line_angle(line_zone))
        angle_cos = math.cos(angle_radians)
        angle_sin = math.sin(angle_radians)

        # Start from the line's midpoint, or from its end when not centered.
        if self.text_centered:
            anchor = list(line_zone.vector.center.as_xy_int_tuple())
        else:
            anchor = list(line_zone.vector.end.as_xy_int_tuple())

        # Step back along the line by half the text width plus padding.
        along_distance = text_width / 2 + self.text_padding
        anchor[0] -= int(angle_cos * along_distance)
        anchor[1] -= int(angle_sin * along_distance)

        # Step away from the line, to one side for "in", the other for "out".
        perpendicular_distance = self.text_offset * text_height
        move_perpendicular_x = int(angle_sin * perpendicular_distance)
        move_perpendicular_y = int(angle_cos * perpendicular_distance)

        if is_in_count:
            anchor[0] += move_perpendicular_x
            anchor[1] -= move_perpendicular_y
        else:
            anchor[0] -= move_perpendicular_x
            anchor[1] += move_perpendicular_y

        # Convert the label's center into its top-left corner, clamped at 0.
        x1 = max(anchor[0] - label_dimension // 2, 0)
        y1 = max(anchor[1] - label_dimension // 2, 0)

        return x1, y1

    def _draw_basic_label(
        self,
        frame: np.ndarray,
        line_center: Point,
        text: str,
        is_in_count: bool,
    ) -> np.ndarray:
        """
        Draw the count label on the frame. For example: "out: 7".
        The label contains horizontal text and is not rotated.

        The `line_center` argument is not modified.

        Args:
            frame (np.ndarray): The entire scene, on which the label will be placed.
            line_center (Point): The center of the line zone.
            text (str): The text that will be drawn.
            is_in_count (bool): Whether to display the in count (above line)
                or out count (below line).

        Returns:
            (np.ndarray): The scene with the label drawn on it.
        """
        _, text_height = cv2.getTextSize(
            text, cv2.FONT_HERSHEY_SIMPLEX, self.text_scale, self.text_thickness
        )[0]

        vertical_shift = int(self.text_offset * text_height)
        if is_in_count:
            vertical_shift = -vertical_shift

        # Build a shifted copy rather than mutating the caller's Point.
        text_anchor = Point(x=line_center.x, y=line_center.y + vertical_shift)

        draw_text(
            scene=frame,
            text=text,
            text_anchor=text_anchor,
            text_color=self.text_color,
            text_scale=self.text_scale,
            text_thickness=self.text_thickness,
            text_padding=self.text_padding,
            background_color=self.color if self.display_text_box else None,
        )

        return frame

    def _draw_oriented_label(
        self,
        frame: np.ndarray,
        line_zone: LineZone,
        text: str,
        is_in_count: bool,
    ) -> np.ndarray:
        """
        Draw the count label on the frame. For example: "out: 7".
        The label is oriented to match the line angle.

        Args:
            frame (np.ndarray): The entire scene, on which the label will be placed.
            line_zone (LineZone): The line zone responsible for counting
                objects crossing it.
            text (str): The text that will be drawn.
            is_in_count (bool): Whether to display the in count (above line)
                or out count (below line).

        Returns:
            (np.ndarray): The scene with the label drawn on it.
        """
        line_angle_degrees = self._get_line_angle(line_zone)
        label_image = self._make_label_image(
            text,
            text_scale=self.text_scale,
            text_thickness=self.text_thickness,
            text_padding=self.text_padding,
            text_color=self.text_color,
            text_box_show=self.display_text_box,
            text_box_color=self.color,
            line_angle_degrees=line_angle_degrees,
        )
        # The label image is built square so it can be rotated in place.
        assert label_image.shape[0] == label_image.shape[1]

        text_width, text_height = cv2.getTextSize(
            text, cv2.FONT_HERSHEY_SIMPLEX, self.text_scale, self.text_thickness
        )[0]

        label_anchor = self._calculate_anchor_in_frame(
            line_zone=line_zone,
            text_width=text_width,
            text_height=text_height,
            is_in_count=is_in_count,
            label_dimension=label_image.shape[0],
        )

        frame = overlay_image(frame, label_image, label_anchor)

        return frame

    @staticmethod
    @lru_cache(maxsize=32)
    def _make_label_image(
        text: str,
        *,
        text_scale: float,
        text_thickness: int,
        text_padding: int,
        text_color: Color,
        text_box_show: bool,
        text_box_color: Color,
        line_angle_degrees: float,
    ) -> np.ndarray:
        """
        Create the small text box displaying line zone count. E.g. "out: 7".

        Results are memoized (up to 32 distinct argument combinations), so
        the returned array is shared between calls and must not be modified
        by callers.

        Args:
            text (str): The text to display.
            text_scale (float): The scale of the text.
            text_thickness (int): The thickness of the text.
            text_padding (int): The padding around the text.
            text_color (Color): The color of the text.
            text_box_show (bool): Whether to display the text box.
            text_box_color (Color): The color of the text box.
            line_angle_degrees (float): The angle of the line in degrees.

        Returns:
            (np.ndarray): The label of shape (H, W, 4), in BGRA format.
        """
        text_width, text_height = cv2.getTextSize(
            text, cv2.FONT_HERSHEY_SIMPLEX, text_scale, text_thickness
        )[0]

        # Square canvas, 1.5x the padded text's longer side, leaving room
        # for the rotation below.
        annotation_dim = int((max(text_width, text_height) + text_padding * 2) * 1.5)
        annotation_shape = (annotation_dim, annotation_dim)
        annotation_center = Point(annotation_dim // 2, annotation_dim // 2)

        annotation = np.zeros((*annotation_shape, 3), dtype=np.uint8)
        annotation_alpha = np.zeros((*annotation_shape, 1), dtype=np.uint8)

        text_args: Dict[str, Any] = dict(
            text=text,
            text_anchor=annotation_center,
            text_scale=text_scale,
            text_thickness=text_thickness,
            text_padding=text_padding,
        )
        # Color layer.
        draw_text(
            scene=annotation,
            text_color=text_color,
            background_color=text_box_color if text_box_show else None,
            **text_args,
        )
        # Alpha layer: the same drawing in white marks the opaque pixels.
        draw_text(
            scene=annotation_alpha,
            text_color=Color.WHITE,
            background_color=Color.WHITE if text_box_show else None,
            **text_args,
        )
        annotation = np.dstack((annotation, annotation_alpha))

        # Make sure text is displayed upright
        if 90 < line_angle_degrees % 360 < 270:
            annotation = cv2.flip(annotation, flipCode=-1).astype(np.uint8)

        rotation_angle = -line_angle_degrees
        rotation_matrix = cv2.getRotationMatrix2D(
            annotation_center.as_xy_float_tuple(), rotation_angle, scale=1
        )
        annotation = cv2.warpAffine(annotation, rotation_matrix, annotation_shape)

        return annotation
698
+
699
+
700
class LineZoneAnnotatorMulticlass:
    """Draw a summary table of per-class line crossings on an image."""

    def __init__(
        self,
        *,
        table_position: Literal[
            Position.TOP_LEFT,
            Position.TOP_RIGHT,
            Position.BOTTOM_LEFT,
            Position.BOTTOM_RIGHT,
        ] = Position.TOP_RIGHT,
        table_color: Color = Color.WHITE,
        table_margin: int = 10,
        table_padding: int = 10,
        table_max_width: int = 400,
        text_color: Color = Color.BLACK,
        text_scale: float = 0.75,
        text_thickness: int = 1,
        force_draw_class_ids: bool = False,
    ):
        """
        Draw a table showing how many items of each class crossed each line.

        Args:
            table_position (Position): The position of the table.
            table_color (Color): The color of the table.
            table_margin (int): The margin of the table from the image border.
            table_padding (int): The padding of the table.
            table_max_width (int): The maximum width of the table.
            text_color (Color): The color of the text.
            text_scale (float): The scale of the text.
            text_thickness (int): The thickness of the text.
            force_draw_class_ids (bool): Instead of writing the class names,
                on the table, write the class IDs. E.g. instead of `person: 6`,
                write `0: 6`.

        Raises:
            ValueError: If `table_position` is not one of the four corners.
        """
        corner_positions = {
            Position.TOP_LEFT,
            Position.TOP_RIGHT,
            Position.BOTTOM_LEFT,
            Position.BOTTOM_RIGHT,
        }
        if table_position not in corner_positions:
            raise ValueError(
                "Invalid table position. Supported values are:"
                " TOP_LEFT, TOP_RIGHT, BOTTOM_LEFT, BOTTOM_RIGHT."
            )

        # Table geometry and look.
        self.table_position = table_position
        self.table_color = table_color
        self.table_margin = table_margin
        self.table_padding = table_padding
        self.table_max_width = table_max_width
        # Text look.
        self.text_color = text_color
        self.text_scale = text_scale
        self.text_thickness = text_thickness
        self.force_draw_class_ids = force_draw_class_ids

    def annotate(
        self,
        frame: np.ndarray,
        line_zones: List[LineZone],
        line_zone_labels: Optional[List[str]] = None,
    ) -> np.ndarray:
        """
        Draws a table with the number of objects of each class that crossed each line.

        Args:
            frame (np.ndarray): The image on which the table will be drawn.
            line_zones (List[LineZone]): The line zones to be annotated.
            line_zone_labels (Optional[List[str]]): The labels, one for each
                line zone. If not provided, the default labels will be used.

        Returns:
            (np.ndarray): The image with the table drawn on it.

        Raises:
            ValueError: If the number of labels differs from the number of
                line zones.
        """
        if line_zone_labels is None:
            line_zone_labels = [
                f"Line {number}:" for number in range(1, len(line_zones) + 1)
            ]
        if len(line_zones) != len(line_zone_labels):
            raise ValueError("The number of line zones and their labels must match.")

        # Compose the table content, one string per row.
        rows = ["Line Crossings:"]
        for zone, zone_label in zip(line_zones, line_zone_labels):
            rows.append(zone_label)
            names_by_id = zone.class_id_to_name
            directional_counts = (
                ("In", zone.in_count_per_class),
                ("Out", zone.out_count_per_class),
            )

            for direction, counts in directional_counts:
                # Directions without any crossing are left out entirely.
                if not counts:
                    continue

                rows.append(f" {direction}:")
                for class_id, count in counts.items():
                    if self.force_draw_class_ids:
                        shown_name = str(class_id)
                    else:
                        shown_name = names_by_id.get(class_id, str(class_id))
                    rows.append(f" {shown_name}: {count}")

        # Measure every row once; the sizes drive both the table dimensions
        # and the text placement.
        row_sizes = [
            cv2.getTextSize(
                row, cv2.FONT_HERSHEY_SIMPLEX, self.text_scale, self.text_thickness
            )[0]
            for row in rows
        ]
        row_heights = [height + TEXT_MARGIN for _, height in row_sizes]

        widest_row = max(0, *(width for width, _ in row_sizes))
        table_width = min(widest_row + 2 * self.table_padding, self.table_max_width)

        table_max_height = frame.shape[0] - 2 * self.table_margin
        table_height = min(
            sum(row_heights) + 2 * self.table_padding, table_max_height
        )

        # Top-left corner of the table for each supported position.
        left_x = self.table_margin
        right_x = frame.shape[1] - table_width - self.table_margin
        top_y = self.table_margin
        bottom_y = frame.shape[0] - table_height - self.table_margin
        corner_by_position = {
            Position.TOP_LEFT: (left_x, top_y),
            Position.TOP_RIGHT: (right_x, top_y),
            Position.BOTTOM_LEFT: (left_x, bottom_y),
            Position.BOTTOM_RIGHT: (right_x, bottom_y),
        }
        table_x1, table_y1 = corner_by_position[self.table_position]

        # Filled background (thickness=-1).
        frame = draw_rectangle(
            scene=frame,
            rect=Rect(x=table_x1, y=table_y1, width=table_width, height=table_height),
            color=self.table_color,
            thickness=-1,
        )

        text_x = table_x1 + self.table_padding
        text_bgr = self.text_color.as_bgr()
        for row_number, (row, row_height) in enumerate(zip(rows, row_heights), start=1):
            text_y = table_y1 + self.table_padding + row_number * row_height
            cv2.putText(
                img=frame,
                text=row,
                org=(text_x, text_y),
                fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=self.text_scale,
                color=text_bgr,
                thickness=self.text_thickness,
                lineType=cv2.LINE_AA,
            )

        return frame