eye_cv-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. eye/__init__.py +115 -0
  2. eye/__init___supervision_original.py +120 -0
  3. eye/annotators/__init__.py +0 -0
  4. eye/annotators/base.py +22 -0
  5. eye/annotators/core.py +2699 -0
  6. eye/annotators/line.py +107 -0
  7. eye/annotators/modern.py +529 -0
  8. eye/annotators/trace.py +142 -0
  9. eye/annotators/utils.py +177 -0
  10. eye/assets/__init__.py +2 -0
  11. eye/assets/downloader.py +95 -0
  12. eye/assets/list.py +83 -0
  13. eye/classification/__init__.py +0 -0
  14. eye/classification/core.py +188 -0
  15. eye/config.py +2 -0
  16. eye/core/__init__.py +0 -0
  17. eye/core/trackers/__init__.py +1 -0
  18. eye/core/trackers/botsort_tracker.py +336 -0
  19. eye/core/trackers/bytetrack_tracker.py +284 -0
  20. eye/core/trackers/sort_tracker.py +200 -0
  21. eye/core/tracking.py +146 -0
  22. eye/dataset/__init__.py +0 -0
  23. eye/dataset/core.py +919 -0
  24. eye/dataset/formats/__init__.py +0 -0
  25. eye/dataset/formats/coco.py +258 -0
  26. eye/dataset/formats/pascal_voc.py +279 -0
  27. eye/dataset/formats/yolo.py +272 -0
  28. eye/dataset/utils.py +259 -0
  29. eye/detection/__init__.py +0 -0
  30. eye/detection/auto_convert.py +155 -0
  31. eye/detection/core.py +1529 -0
  32. eye/detection/detections_enhanced.py +392 -0
  33. eye/detection/line_zone.py +859 -0
  34. eye/detection/lmm.py +184 -0
  35. eye/detection/overlap_filter.py +270 -0
  36. eye/detection/tools/__init__.py +0 -0
  37. eye/detection/tools/csv_sink.py +181 -0
  38. eye/detection/tools/inference_slicer.py +288 -0
  39. eye/detection/tools/json_sink.py +142 -0
  40. eye/detection/tools/polygon_zone.py +202 -0
  41. eye/detection/tools/smoother.py +123 -0
  42. eye/detection/tools/smoothing.py +179 -0
  43. eye/detection/tools/smoothing_config.py +202 -0
  44. eye/detection/tools/transformers.py +247 -0
  45. eye/detection/utils.py +1175 -0
  46. eye/draw/__init__.py +0 -0
  47. eye/draw/color.py +154 -0
  48. eye/draw/utils.py +374 -0
  49. eye/filters.py +112 -0
  50. eye/geometry/__init__.py +0 -0
  51. eye/geometry/core.py +128 -0
  52. eye/geometry/utils.py +47 -0
  53. eye/keypoint/__init__.py +0 -0
  54. eye/keypoint/annotators.py +442 -0
  55. eye/keypoint/core.py +687 -0
  56. eye/keypoint/skeletons.py +2647 -0
  57. eye/metrics/__init__.py +21 -0
  58. eye/metrics/core.py +72 -0
  59. eye/metrics/detection.py +843 -0
  60. eye/metrics/f1_score.py +648 -0
  61. eye/metrics/mean_average_precision.py +628 -0
  62. eye/metrics/mean_average_recall.py +697 -0
  63. eye/metrics/precision.py +653 -0
  64. eye/metrics/recall.py +652 -0
  65. eye/metrics/utils/__init__.py +0 -0
  66. eye/metrics/utils/object_size.py +158 -0
  67. eye/metrics/utils/utils.py +9 -0
  68. eye/py.typed +0 -0
  69. eye/quick.py +104 -0
  70. eye/tracker/__init__.py +0 -0
  71. eye/tracker/byte_tracker/__init__.py +0 -0
  72. eye/tracker/byte_tracker/core.py +386 -0
  73. eye/tracker/byte_tracker/kalman_filter.py +205 -0
  74. eye/tracker/byte_tracker/matching.py +69 -0
  75. eye/tracker/byte_tracker/single_object_track.py +178 -0
  76. eye/tracker/byte_tracker/utils.py +18 -0
  77. eye/utils/__init__.py +0 -0
  78. eye/utils/conversion.py +132 -0
  79. eye/utils/file.py +159 -0
  80. eye/utils/image.py +794 -0
  81. eye/utils/internal.py +200 -0
  82. eye/utils/iterables.py +84 -0
  83. eye/utils/notebook.py +114 -0
  84. eye/utils/video.py +307 -0
  85. eye/utils_eye/__init__.py +1 -0
  86. eye/utils_eye/geometry.py +71 -0
  87. eye/utils_eye/nms.py +55 -0
  88. eye/validators/__init__.py +140 -0
  89. eye/web.py +271 -0
  90. eye_cv-1.0.0.dist-info/METADATA +319 -0
  91. eye_cv-1.0.0.dist-info/RECORD +94 -0
  92. eye_cv-1.0.0.dist-info/WHEEL +5 -0
  93. eye_cv-1.0.0.dist-info/licenses/LICENSE +21 -0
  94. eye_cv-1.0.0.dist-info/top_level.txt +1 -0
eye/tracker/byte_tracker/kalman_filter.py ADDED
@@ -0,0 +1,205 @@
+ from typing import Tuple
+
+ import numpy as np
+ import scipy.linalg
+
+
+ class KalmanFilter:
+     """
+     A simple Kalman filter for tracking bounding boxes in image space.
+
+     The 8-dimensional state space
+
+         x, y, a, h, vx, vy, va, vh
+
+     contains the bounding box center position (x, y), aspect ratio a, height h,
+     and their respective velocities.
+
+     Object motion follows a constant velocity model. The bounding box location
+     (x, y, a, h) is taken as a direct observation of the state space (linear
+     observation model).
+     """
+
+     def __init__(self):
+         ndim, dt = 4, 1.0
+
+         self._motion_mat = np.eye(2 * ndim, 2 * ndim)
+         for i in range(ndim):
+             self._motion_mat[i, ndim + i] = dt
+         self._update_mat = np.eye(ndim, 2 * ndim)
+         self._std_weight_position = 1.0 / 20
+         self._std_weight_velocity = 1.0 / 160
+
+     def initiate(self, measurement: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+         """
+         Create a track from an unassociated measurement.
+
+         Args:
+             measurement (ndarray): Bounding box coordinates (x, y, a, h) with
+                 center position (x, y), aspect ratio a, and height h.
+
+         Returns:
+             Tuple[ndarray, ndarray]: The mean vector (8-dimensional) and
+                 covariance matrix (8x8) of the new track. Unobserved
+                 velocities are initialized to 0 mean.
+         """
+         mean_pos = measurement
+         mean_vel = np.zeros_like(mean_pos)
+         mean = np.r_[mean_pos, mean_vel]
+
+         std = [
+             2 * self._std_weight_position * measurement[3],
+             2 * self._std_weight_position * measurement[3],
+             1e-2,
+             2 * self._std_weight_position * measurement[3],
+             10 * self._std_weight_velocity * measurement[3],
+             10 * self._std_weight_velocity * measurement[3],
+             1e-5,
+             10 * self._std_weight_velocity * measurement[3],
+         ]
+         covariance = np.diag(np.square(std))
+         return mean, covariance
+
+     def predict(
+         self, mean: np.ndarray, covariance: np.ndarray
+     ) -> Tuple[np.ndarray, np.ndarray]:
+         """
+         Run the Kalman filter prediction step.
+
+         Args:
+             mean (ndarray): The 8-dimensional mean vector of the object
+                 state at the previous time step.
+             covariance (ndarray): The 8x8 covariance matrix of the object
+                 state at the previous time step.
+
+         Returns:
+             Tuple[ndarray, ndarray]: The mean vector and covariance matrix
+                 of the predicted state.
+         """
+         std_pos = [
+             self._std_weight_position * mean[3],
+             self._std_weight_position * mean[3],
+             1e-2,
+             self._std_weight_position * mean[3],
+         ]
+         std_vel = [
+             self._std_weight_velocity * mean[3],
+             self._std_weight_velocity * mean[3],
+             1e-5,
+             self._std_weight_velocity * mean[3],
+         ]
+         motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
+
+         mean = np.dot(mean, self._motion_mat.T)
+         covariance = (
+             np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T))
+             + motion_cov
+         )
+
+         return mean, covariance
+
+     def project(
+         self, mean: np.ndarray, covariance: np.ndarray
+     ) -> Tuple[np.ndarray, np.ndarray]:
+         """
+         Project the state distribution to measurement space.
+
+         Args:
+             mean (ndarray): The state's mean vector (8-dimensional array).
+             covariance (ndarray): The state's covariance matrix (8x8).
+
+         Returns:
+             Tuple[ndarray, ndarray]: The projected mean and covariance
+                 matrix of the given state estimate.
+         """
+         std = [
+             self._std_weight_position * mean[3],
+             self._std_weight_position * mean[3],
+             1e-1,
+             self._std_weight_position * mean[3],
+         ]
+         innovation_cov = np.diag(np.square(std))
+
+         mean = np.dot(self._update_mat, mean)
+         covariance = np.linalg.multi_dot(
+             (self._update_mat, covariance, self._update_mat.T)
+         )
+         return mean, covariance + innovation_cov
+
+     def multi_predict(
+         self, mean: np.ndarray, covariance: np.ndarray
+     ) -> Tuple[np.ndarray, np.ndarray]:
+         """
+         Run the Kalman filter prediction step (vectorized version).
+
+         Args:
+             mean (ndarray): The Nx8 mean matrix of the object states at the
+                 previous time step.
+             covariance (ndarray): The Nx8x8 covariance matrices of the
+                 object states at the previous time step.
+
+         Returns:
+             Tuple[ndarray, ndarray]: The mean matrix and covariance
+                 matrices of the predicted states.
+         """
+         std_pos = [
+             self._std_weight_position * mean[:, 3],
+             self._std_weight_position * mean[:, 3],
+             1e-2 * np.ones_like(mean[:, 3]),
+             self._std_weight_position * mean[:, 3],
+         ]
+         std_vel = [
+             self._std_weight_velocity * mean[:, 3],
+             self._std_weight_velocity * mean[:, 3],
+             1e-5 * np.ones_like(mean[:, 3]),
+             self._std_weight_velocity * mean[:, 3],
+         ]
+         sqr = np.square(np.r_[std_pos, std_vel]).T
+
+         motion_cov = []
+         for i in range(len(mean)):
+             motion_cov.append(np.diag(sqr[i]))
+         motion_cov = np.asarray(motion_cov)
+
+         mean = np.dot(mean, self._motion_mat.T)
+         left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
+         covariance = np.dot(left, self._motion_mat.T) + motion_cov
+
+         return mean, covariance
+
+     def update(
+         self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray
+     ) -> Tuple[np.ndarray, np.ndarray]:
+         """
+         Run the Kalman filter correction step.
+
+         Args:
+             mean (ndarray): The predicted state's mean vector (8-dimensional).
+             covariance (ndarray): The state's covariance matrix (8x8).
+             measurement (ndarray): The 4-dimensional measurement vector
+                 (x, y, a, h), where (x, y) is the center position, a the
+                 aspect ratio, and h the height of the bounding box.
+
+         Returns:
+             Tuple[ndarray, ndarray]: The measurement-corrected state
+                 distribution.
+         """
+         projected_mean, projected_cov = self.project(mean, covariance)
+
+         chol_factor, lower = scipy.linalg.cho_factor(
+             projected_cov, lower=True, check_finite=False
+         )
+         kalman_gain = scipy.linalg.cho_solve(
+             (chol_factor, lower),
+             np.dot(covariance, self._update_mat.T).T,
+             check_finite=False,
+         ).T
+         innovation = measurement - projected_mean
+
+         new_mean = mean + np.dot(innovation, kalman_gain.T)
+         new_covariance = covariance - np.linalg.multi_dot(
+             (kalman_gain, projected_cov, kalman_gain.T)
+         )
+         return new_mean, new_covariance
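A minimal usage sketch (not part of the package diff), assuming the module imports as packaged above; the measurement values are made up for illustration:

    import numpy as np
    from eye.tracker.byte_tracker.kalman_filter import KalmanFilter

    kf = KalmanFilter()

    # Hypothetical detection: center (320, 240), aspect ratio 0.5, height 100.
    measurement = np.array([320.0, 240.0, 0.5, 100.0])

    mean, covariance = kf.initiate(measurement)       # 8-dim state, 8x8 covariance
    mean, covariance = kf.predict(mean, covariance)   # propagate one frame ahead
    mean, covariance = kf.update(mean, covariance, measurement)  # correct with a new observation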
eye/tracker/byte_tracker/matching.py ADDED
@@ -0,0 +1,69 @@
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, List, Tuple
+
+ import numpy as np
+ from scipy.optimize import linear_sum_assignment
+
+ from eye.detection.utils import box_iou_batch
+
+ if TYPE_CHECKING:
+     from eye.tracker.byte_tracker.core import STrack
+
+
+ def indices_to_matches(
+     cost_matrix: np.ndarray, indices: np.ndarray, thresh: float
+ ) -> Tuple[np.ndarray, tuple, tuple]:
+     matched_cost = cost_matrix[tuple(zip(*indices))]
+     matched_mask = matched_cost <= thresh
+
+     matches = indices[matched_mask]
+     unmatched_a = tuple(set(range(cost_matrix.shape[0])) - set(matches[:, 0]))
+     unmatched_b = tuple(set(range(cost_matrix.shape[1])) - set(matches[:, 1]))
+     return matches, unmatched_a, unmatched_b
+
+
+ def linear_assignment(
+     cost_matrix: np.ndarray, thresh: float
+ ) -> Tuple[np.ndarray, Tuple[int, ...], Tuple[int, ...]]:
+     if cost_matrix.size == 0:
+         return (
+             np.empty((0, 2), dtype=int),
+             tuple(range(cost_matrix.shape[0])),
+             tuple(range(cost_matrix.shape[1])),
+         )
+
+     cost_matrix[cost_matrix > thresh] = thresh + 1e-4
+     row_ind, col_ind = linear_sum_assignment(cost_matrix)
+     indices = np.column_stack((row_ind, col_ind))
+
+     return indices_to_matches(cost_matrix, indices, thresh)
+
+
+ def iou_distance(atracks: List[STrack], btracks: List[STrack]) -> np.ndarray:
+     if (len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) or (
+         len(btracks) > 0 and isinstance(btracks[0], np.ndarray)
+     ):
+         atlbrs = atracks
+         btlbrs = btracks
+     else:
+         atlbrs = [track.tlbr for track in atracks]
+         btlbrs = [track.tlbr for track in btracks]
+
+     _ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
+     if _ious.size != 0:
+         _ious = box_iou_batch(np.asarray(atlbrs), np.asarray(btlbrs))
+     cost_matrix = 1 - _ious
+
+     return cost_matrix
+
+
+ def fuse_score(cost_matrix: np.ndarray, stracks: List[STrack]) -> np.ndarray:
+     if cost_matrix.size == 0:
+         return cost_matrix
+     iou_sim = 1 - cost_matrix
+     det_scores = np.array([strack.score for strack in stracks])
+     det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
+     fuse_sim = iou_sim * det_scores
+     fuse_cost = 1 - fuse_sim
+     return fuse_cost
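A sketch (not part of the package diff) of how the assignment helper behaves, with a made-up 2-track x 3-detection cost matrix chosen only to illustrate thresholding:

    import numpy as np
    from eye.tracker.byte_tracker.matching import linear_assignment

    # IoU-distance matrix: 0 means perfect overlap; entries above thresh are rejected.
    cost = np.array(
        [
            [0.10, 0.90, 0.80],
            [0.70, 0.20, 0.95],
        ],
        dtype=np.float32,
    )

    matches, unmatched_tracks, unmatched_detections = linear_assignment(cost, thresh=0.8)
    # matches -> [[0, 0], [1, 1]]; detection 2 stays unmatched.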
eye/tracker/byte_tracker/single_object_track.py ADDED
@@ -0,0 +1,178 @@
+ from __future__ import annotations
+
+ from enum import Enum
+ from typing import List
+
+ import numpy as np
+ import numpy.typing as npt
+
+ from eye.tracker.byte_tracker.kalman_filter import KalmanFilter
+ from eye.tracker.byte_tracker.utils import IdCounter
+
+
+ class TrackState(Enum):
+     New = 0
+     Tracked = 1
+     Lost = 2
+     Removed = 3
+
+
+ class STrack:
+     def __init__(
+         self,
+         tlwh: npt.NDArray[np.float32],
+         score: npt.NDArray[np.float32],
+         minimum_consecutive_frames: int,
+         shared_kalman: KalmanFilter,
+         internal_id_counter: IdCounter,
+         external_id_counter: IdCounter,
+     ):
+         self.state = TrackState.New
+         self.is_activated = False
+         self.start_frame = 0
+         self.frame_id = 0
+
+         self._tlwh = np.asarray(tlwh, dtype=np.float32)
+         self.kalman_filter = None
+         self.shared_kalman = shared_kalman
+         self.mean, self.covariance = None, None
+         self.is_activated = False
+
+         self.score = score
+         self.tracklet_len = 0
+
+         self.minimum_consecutive_frames = minimum_consecutive_frames
+
+         self.internal_id_counter = internal_id_counter
+         self.external_id_counter = external_id_counter
+         self.internal_track_id = self.internal_id_counter.NO_ID
+         self.external_track_id = self.external_id_counter.NO_ID
+
+     def predict(self) -> None:
+         mean_state = self.mean.copy()
+         if self.state != TrackState.Tracked:
+             mean_state[7] = 0
+         self.mean, self.covariance = self.kalman_filter.predict(
+             mean_state, self.covariance
+         )
+
+     @staticmethod
+     def multi_predict(stracks: List[STrack], shared_kalman: KalmanFilter) -> None:
+         if len(stracks) > 0:
+             multi_mean = []
+             multi_covariance = []
+             for i, st in enumerate(stracks):
+                 multi_mean.append(st.mean.copy())
+                 multi_covariance.append(st.covariance)
+                 if st.state != TrackState.Tracked:
+                     multi_mean[i][7] = 0
+
+             multi_mean, multi_covariance = shared_kalman.multi_predict(
+                 np.asarray(multi_mean), np.asarray(multi_covariance)
+             )
+             for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
+                 stracks[i].mean = mean
+                 stracks[i].covariance = cov
+
+     def activate(self, kalman_filter: KalmanFilter, frame_id: int) -> None:
+         """Start a new tracklet."""
+         self.kalman_filter = kalman_filter
+         self.internal_track_id = self.internal_id_counter.new_id()
+         self.mean, self.covariance = self.kalman_filter.initiate(
+             self.tlwh_to_xyah(self._tlwh)
+         )
+
+         self.tracklet_len = 0
+         self.state = TrackState.Tracked
+         if frame_id == 1:
+             self.is_activated = True
+
+         if self.minimum_consecutive_frames == 1:
+             self.external_track_id = self.external_id_counter.new_id()
+
+         self.frame_id = frame_id
+         self.start_frame = frame_id
+
+     def re_activate(self, new_track: STrack, frame_id: int) -> None:
+         self.mean, self.covariance = self.kalman_filter.update(
+             self.mean, self.covariance, self.tlwh_to_xyah(new_track.tlwh)
+         )
+         self.tracklet_len = 0
+         self.state = TrackState.Tracked
+
+         self.frame_id = frame_id
+         self.score = new_track.score
+
+     def update(self, new_track: STrack, frame_id: int) -> None:
+         """
+         Update a matched track with a new detection.
+
+         :type new_track: STrack
+         :type frame_id: int
+         """
+         self.frame_id = frame_id
+         self.tracklet_len += 1
+
+         new_tlwh = new_track.tlwh
+         self.mean, self.covariance = self.kalman_filter.update(
+             self.mean, self.covariance, self.tlwh_to_xyah(new_tlwh)
+         )
+         self.state = TrackState.Tracked
+         if self.tracklet_len == self.minimum_consecutive_frames:
+             self.is_activated = True
+             if self.external_track_id == self.external_id_counter.NO_ID:
+                 self.external_track_id = self.external_id_counter.new_id()
+
+         self.score = new_track.score
+
+     @property
+     def tlwh(self) -> npt.NDArray[np.float32]:
+         """Get the current position in bounding box format `(top left x,
+         top left y, width, height)`.
+         """
+         if self.mean is None:
+             return self._tlwh.copy()
+         ret = self.mean[:4].copy()
+         ret[2] *= ret[3]
+         ret[:2] -= ret[2:] / 2
+         return ret
+
+     @property
+     def tlbr(self) -> npt.NDArray[np.float32]:
+         """Convert the bounding box to format `(min x, min y, max x, max y)`,
+         i.e., `(top left, bottom right)`.
+         """
+         ret = self.tlwh.copy()
+         ret[2:] += ret[:2]
+         return ret
+
+     @staticmethod
+     def tlwh_to_xyah(tlwh) -> npt.NDArray[np.float32]:
+         """Convert a bounding box to format `(center x, center y, aspect
+         ratio, height)`, where the aspect ratio is `width / height`.
+         """
+         ret = np.asarray(tlwh).copy()
+         ret[:2] += ret[2:] / 2
+         ret[2] /= ret[3]
+         return ret
+
+     def to_xyah(self) -> npt.NDArray[np.float32]:
+         return self.tlwh_to_xyah(self.tlwh)
+
+     @staticmethod
+     def tlbr_to_tlwh(tlbr) -> npt.NDArray[np.float32]:
+         ret = np.asarray(tlbr).copy()
+         ret[2:] -= ret[:2]
+         return ret
+
+     @staticmethod
+     def tlwh_to_tlbr(tlwh) -> npt.NDArray[np.float32]:
+         ret = np.asarray(tlwh).copy()
+         ret[2:] += ret[:2]
+         return ret
+
+     def __repr__(self) -> str:
+         return "OT_{}_({}-{})".format(
+             self.internal_track_id, self.start_frame, self.frame_id
+         )
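The static conversion helpers are self-contained, so their behavior is easy to sanity-check. A sketch (not part of the package diff) with a made-up box:

    import numpy as np
    from eye.tracker.byte_tracker.single_object_track import STrack

    tlwh = np.array([100.0, 50.0, 40.0, 80.0])  # top-left x/y, width, height

    STrack.tlwh_to_xyah(tlwh)  # -> [120., 90., 0.5, 80.]  (center x/y, w/h ratio, h)
    STrack.tlwh_to_tlbr(tlwh)  # -> [100., 50., 140., 130.]
    STrack.tlbr_to_tlwh(STrack.tlwh_to_tlbr(tlwh))  # round-trips to the original tlwh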
eye/tracker/byte_tracker/utils.py ADDED
@@ -0,0 +1,18 @@
+ class IdCounter:
+     def __init__(self, start_id: int = 0):
+         self.start_id = start_id
+         if self.start_id <= self.NO_ID:
+             raise ValueError(f"start_id must be greater than {self.NO_ID}")
+         self.reset()
+
+     def reset(self) -> None:
+         self._id = self.start_id
+
+     def new_id(self) -> int:
+         returned_id = self._id
+         self._id += 1
+         return returned_id
+
+     @property
+     def NO_ID(self) -> int:
+         return -1
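A quick sketch (not part of the package diff) of the counter's contract; NO_ID is the sentinel -1, so start_id must be non-negative:

    from eye.tracker.byte_tracker.utils import IdCounter

    counter = IdCounter()  # defaults to start_id=0
    counter.new_id()       # -> 0
    counter.new_id()       # -> 1
    counter.reset()
    counter.new_id()       # -> 0 again
    counter.NO_ID          # -> -1, used by STrack for "no id assigned yet"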
eye/utils/__init__.py ADDED
File without changes
eye/utils/conversion.py ADDED
@@ -0,0 +1,132 @@
+ from functools import wraps
+ from typing import List
+
+ import cv2
+ import numpy as np
+ from PIL import Image
+
+ from eye.annotators.base import ImageType
+
+
+ def ensure_cv2_image_for_annotation(annotate_func):
+     """
+     Decorates `BaseAnnotator.annotate` implementations, converting the scene
+     to the OpenCV format used internally by the annotators and converting it
+     back once annotation is complete.
+
+     Assumes the annotators modify the scene in-place.
+     """
+
+     @wraps(annotate_func)
+     def wrapper(self, scene: ImageType, *args, **kwargs):
+         if isinstance(scene, np.ndarray):
+             return annotate_func(self, scene, *args, **kwargs)
+
+         if isinstance(scene, Image.Image):
+             scene_np = pillow_to_cv2(scene)
+             annotated_np = annotate_func(self, scene_np, *args, **kwargs)
+             scene.paste(cv2_to_pillow(annotated_np))
+             return scene
+
+         raise ValueError(f"Unsupported image type: {type(scene)}")
+
+     return wrapper
+
+
+ def ensure_cv2_image_for_processing(image_processing_fun):
+     """
+     Decorates image processing functions that operate on np.ndarray,
+     converting a Pillow `image` to np.ndarray and converting the result
+     back once processing is complete.
+
+     Assumes the decorated function does NOT modify the image in-place.
+     """
+
+     @wraps(image_processing_fun)
+     def wrapper(image: ImageType, *args, **kwargs):
+         if isinstance(image, np.ndarray):
+             return image_processing_fun(image, *args, **kwargs)
+
+         if isinstance(image, Image.Image):
+             scene = pillow_to_cv2(image)
+             annotated = image_processing_fun(scene, *args, **kwargs)
+             return cv2_to_pillow(annotated)
+
+         raise ValueError(f"Unsupported image type: {type(image)}")
+
+     return wrapper
+
+
+ def ensure_pil_image_for_annotation(annotate_func):
+     """
+     Decorates `BaseAnnotator.annotate` implementations that operate on Pillow
+     images, converting an np.ndarray scene to a Pillow image and copying the
+     result back once annotation is complete.
+
+     Assumes the annotators modify the scene in-place.
+     """
+
+     @wraps(annotate_func)
+     def wrapper(self, scene: ImageType, *args, **kwargs):
+         if isinstance(scene, np.ndarray):
+             scene_pil = cv2_to_pillow(scene)
+             annotated_pil = annotate_func(self, scene_pil, *args, **kwargs)
+             np.copyto(scene, pillow_to_cv2(annotated_pil))
+             return scene
+
+         if isinstance(scene, Image.Image):
+             return annotate_func(self, scene, *args, **kwargs)
+
+         raise ValueError(f"Unsupported image type: {type(scene)}")
+
+     return wrapper
+
+
+ def images_to_cv2(images: List[ImageType]) -> List[np.ndarray]:
+     """
+     Converts images provided either as Pillow images or OpenCV images into
+     OpenCV format.
+
+     Args:
+         images (List[ImageType]): Images to be converted.
+
+     Returns:
+         List[np.ndarray]: List of input images in OpenCV format
+             (with order preserved).
+     """
+     result = []
+     for image in images:
+         if issubclass(type(image), Image.Image):
+             image = pillow_to_cv2(image)
+         result.append(image)
+     return result
+
+
+ def pillow_to_cv2(image: Image.Image) -> np.ndarray:
+     """
+     Converts a Pillow image into an OpenCV image, handling the RGB -> BGR
+     conversion.
+
+     Args:
+         image (Image.Image): Pillow image (in RGB format).
+
+     Returns:
+         (np.ndarray): Input image converted to OpenCV format.
+     """
+     scene = np.array(image)
+     scene = cv2.cvtColor(scene, cv2.COLOR_RGB2BGR)
+     return scene
+
+
+ def cv2_to_pillow(image: np.ndarray) -> Image.Image:
+     """
+     Converts an OpenCV image into a Pillow image, handling the BGR -> RGB
+     conversion.
+
+     Args:
+         image (np.ndarray): OpenCV image (in BGR format).
+
+     Returns:
+         (Image.Image): Input image converted to Pillow format.
+     """
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     return Image.fromarray(image)
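A round-trip sketch (not part of the package diff) for the two converters, using a tiny synthetic frame with illustrative values:

    import numpy as np
    from eye.utils.conversion import cv2_to_pillow, pillow_to_cv2

    # 2x2 pure-blue frame in OpenCV's BGR channel order.
    bgr = np.zeros((2, 2, 3), dtype=np.uint8)
    bgr[..., 0] = 255

    pil_image = cv2_to_pillow(bgr)  # channels swapped to RGB
    assert pil_image.getpixel((0, 0)) == (0, 0, 255)

    assert np.array_equal(pillow_to_cv2(pil_image), bgr)  # lossless round trip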