eye-cv 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. eye/__init__.py +115 -0
  2. eye/__init___supervision_original.py +120 -0
  3. eye/annotators/__init__.py +0 -0
  4. eye/annotators/base.py +22 -0
  5. eye/annotators/core.py +2699 -0
  6. eye/annotators/line.py +107 -0
  7. eye/annotators/modern.py +529 -0
  8. eye/annotators/trace.py +142 -0
  9. eye/annotators/utils.py +177 -0
  10. eye/assets/__init__.py +2 -0
  11. eye/assets/downloader.py +95 -0
  12. eye/assets/list.py +83 -0
  13. eye/classification/__init__.py +0 -0
  14. eye/classification/core.py +188 -0
  15. eye/config.py +2 -0
  16. eye/core/__init__.py +0 -0
  17. eye/core/trackers/__init__.py +1 -0
  18. eye/core/trackers/botsort_tracker.py +336 -0
  19. eye/core/trackers/bytetrack_tracker.py +284 -0
  20. eye/core/trackers/sort_tracker.py +200 -0
  21. eye/core/tracking.py +146 -0
  22. eye/dataset/__init__.py +0 -0
  23. eye/dataset/core.py +919 -0
  24. eye/dataset/formats/__init__.py +0 -0
  25. eye/dataset/formats/coco.py +258 -0
  26. eye/dataset/formats/pascal_voc.py +279 -0
  27. eye/dataset/formats/yolo.py +272 -0
  28. eye/dataset/utils.py +259 -0
  29. eye/detection/__init__.py +0 -0
  30. eye/detection/auto_convert.py +155 -0
  31. eye/detection/core.py +1529 -0
  32. eye/detection/detections_enhanced.py +392 -0
  33. eye/detection/line_zone.py +859 -0
  34. eye/detection/lmm.py +184 -0
  35. eye/detection/overlap_filter.py +270 -0
  36. eye/detection/tools/__init__.py +0 -0
  37. eye/detection/tools/csv_sink.py +181 -0
  38. eye/detection/tools/inference_slicer.py +288 -0
  39. eye/detection/tools/json_sink.py +142 -0
  40. eye/detection/tools/polygon_zone.py +202 -0
  41. eye/detection/tools/smoother.py +123 -0
  42. eye/detection/tools/smoothing.py +179 -0
  43. eye/detection/tools/smoothing_config.py +202 -0
  44. eye/detection/tools/transformers.py +247 -0
  45. eye/detection/utils.py +1175 -0
  46. eye/draw/__init__.py +0 -0
  47. eye/draw/color.py +154 -0
  48. eye/draw/utils.py +374 -0
  49. eye/filters.py +112 -0
  50. eye/geometry/__init__.py +0 -0
  51. eye/geometry/core.py +128 -0
  52. eye/geometry/utils.py +47 -0
  53. eye/keypoint/__init__.py +0 -0
  54. eye/keypoint/annotators.py +442 -0
  55. eye/keypoint/core.py +687 -0
  56. eye/keypoint/skeletons.py +2647 -0
  57. eye/metrics/__init__.py +21 -0
  58. eye/metrics/core.py +72 -0
  59. eye/metrics/detection.py +843 -0
  60. eye/metrics/f1_score.py +648 -0
  61. eye/metrics/mean_average_precision.py +628 -0
  62. eye/metrics/mean_average_recall.py +697 -0
  63. eye/metrics/precision.py +653 -0
  64. eye/metrics/recall.py +652 -0
  65. eye/metrics/utils/__init__.py +0 -0
  66. eye/metrics/utils/object_size.py +158 -0
  67. eye/metrics/utils/utils.py +9 -0
  68. eye/py.typed +0 -0
  69. eye/quick.py +104 -0
  70. eye/tracker/__init__.py +0 -0
  71. eye/tracker/byte_tracker/__init__.py +0 -0
  72. eye/tracker/byte_tracker/core.py +386 -0
  73. eye/tracker/byte_tracker/kalman_filter.py +205 -0
  74. eye/tracker/byte_tracker/matching.py +69 -0
  75. eye/tracker/byte_tracker/single_object_track.py +178 -0
  76. eye/tracker/byte_tracker/utils.py +18 -0
  77. eye/utils/__init__.py +0 -0
  78. eye/utils/conversion.py +132 -0
  79. eye/utils/file.py +159 -0
  80. eye/utils/image.py +794 -0
  81. eye/utils/internal.py +200 -0
  82. eye/utils/iterables.py +84 -0
  83. eye/utils/notebook.py +114 -0
  84. eye/utils/video.py +307 -0
  85. eye/utils_eye/__init__.py +1 -0
  86. eye/utils_eye/geometry.py +71 -0
  87. eye/utils_eye/nms.py +55 -0
  88. eye/validators/__init__.py +140 -0
  89. eye/web.py +271 -0
  90. eye_cv-1.0.0.dist-info/METADATA +319 -0
  91. eye_cv-1.0.0.dist-info/RECORD +94 -0
  92. eye_cv-1.0.0.dist-info/WHEEL +5 -0
  93. eye_cv-1.0.0.dist-info/licenses/LICENSE +21 -0
  94. eye_cv-1.0.0.dist-info/top_level.txt +1 -0
eye/detection/lmm.py ADDED
@@ -0,0 +1,184 @@
+ import re
+ from enum import Enum
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import numpy as np
+
+ from eye.detection.utils import polygon_to_mask, polygon_to_xyxy
+
+
+ class LMM(Enum):
+     PALIGEMMA = "paligemma"
+     FLORENCE_2 = "florence_2"
+
+
+ RESULT_TYPES: Dict[LMM, type] = {LMM.PALIGEMMA: str, LMM.FLORENCE_2: dict}
+
+ REQUIRED_ARGUMENTS: Dict[LMM, List[str]] = {
+     LMM.PALIGEMMA: ["resolution_wh"],
+     LMM.FLORENCE_2: ["resolution_wh"],
+ }
+
+ ALLOWED_ARGUMENTS: Dict[LMM, List[str]] = {
+     LMM.PALIGEMMA: ["resolution_wh", "classes"],
+     LMM.FLORENCE_2: ["resolution_wh"],
+ }
+
+ SUPPORTED_TASKS_FLORENCE_2 = [
+     "<OD>",
+     "<CAPTION_TO_PHRASE_GROUNDING>",
+     "<DENSE_REGION_CAPTION>",
+     "<REGION_PROPOSAL>",
+     "<OCR_WITH_REGION>",
+     "<REFERRING_EXPRESSION_SEGMENTATION>",
+     "<REGION_TO_SEGMENTATION>",
+     "<OPEN_VOCABULARY_DETECTION>",
+     "<REGION_TO_CATEGORY>",
+     "<REGION_TO_DESCRIPTION>",
+ ]
+
+
+ def validate_lmm_parameters(
+     lmm: Union[LMM, str], result: Any, kwargs: Dict[str, Any]
+ ) -> LMM:
+     if isinstance(lmm, str):
+         try:
+             lmm = LMM(lmm.lower())
+         except ValueError:
+             raise ValueError(
+                 f"Invalid lmm value: {lmm}. Must be one of {[e.value for e in LMM]}"
+             )
+
+     if not isinstance(result, RESULT_TYPES[lmm]):
+         raise ValueError(
+             f"Invalid LMM result type: {type(result)}. Must be {RESULT_TYPES[lmm]}"
+         )
+
+     required_args = REQUIRED_ARGUMENTS.get(lmm, [])
+     for arg in required_args:
+         if arg not in kwargs:
+             raise ValueError(f"Missing required argument: {arg}")
+
+     allowed_args = ALLOWED_ARGUMENTS.get(lmm, [])
+     for arg in kwargs:
+         if arg not in allowed_args:
+             raise ValueError(f"Argument {arg} is not allowed for {lmm.name}")
+
+     return lmm
+
+
+ def from_paligemma(
+     result: str, resolution_wh: Tuple[int, int], classes: Optional[List[str]] = None
+ ) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray]:
+     w, h = resolution_wh
+     pattern = re.compile(
+         r"(?<!<loc\d{4}>)<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})> ([\w\s\-]+)"
+     )
+     matches = pattern.findall(result)
+     matches = np.array(matches) if matches else np.empty((0, 5))
+
+     xyxy, class_name = matches[:, [1, 0, 3, 2]], matches[:, 4]
+     xyxy = xyxy.astype(int) / 1024 * np.array([w, h, w, h])
+     class_name = np.char.strip(class_name.astype(str))
+     class_id = None
+
+     if classes is not None:
+         mask = np.array([name in classes for name in class_name]).astype(bool)
+         xyxy, class_name = xyxy[mask], class_name[mask]
+         class_id = np.array([classes.index(name) for name in class_name])
+
+     return xyxy, class_id, class_name
+
+
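For reference, a minimal sketch of how `from_paligemma` behaves on a hand-written PaliGemma-style location string. The coordinates below are invented for illustration; the regex above captures `<loc####>` tags in `y_min, x_min, y_max, x_max` order on a 0-1023 grid, which the `[1, 0, 3, 2]` column shuffle converts to `xyxy`:

```python
from eye.detection.lmm import from_paligemma

# One detection: y1=0256, x1=0128, y2=0768, x2=0512 on the 1024-step grid.
result = "<loc0256><loc0128><loc0768><loc0512> cat"

xyxy, class_id, class_name = from_paligemma(
    result, resolution_wh=(640, 480), classes=["cat", "dog"]
)
# xyxy       -> [[ 80. 120. 320. 360.]]  (pixels, x1 y1 x2 y2)
# class_id   -> [0]
# class_name -> ['cat']
```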
+ def from_florence_2(
+     result: dict, resolution_wh: Tuple[int, int]
+ ) -> Tuple[
+     np.ndarray, Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]
+ ]:
+     """
+     Parse results from the Florence 2 multi-modal model.
+     https://huggingface.co/microsoft/Florence-2-large
+
+     Parameters:
+         result: dict containing the model output
+         resolution_wh: (width, height) of the source image
+
+     Returns:
+         xyxy (np.ndarray): An array of shape `(n, 4)` containing
+             the bounding boxes coordinates in format `[x1, y1, x2, y2]`
+         labels (Optional[np.ndarray]): An array of shape `(n,)` containing
+             the class labels for each bounding box
+         masks (Optional[np.ndarray]): An array of shape `(n, h, w)` containing
+             the segmentation masks for each bounding box
+         obb_boxes (Optional[np.ndarray]): An array of shape `(n, 4, 2)` containing
+             oriented bounding boxes.
+     """
+     assert len(result) == 1, f"Expected result with a single element. Got: {result}"
+     task = next(iter(result.keys()))
+     if task not in SUPPORTED_TASKS_FLORENCE_2:
+         raise ValueError(
+             f"{task} not supported. Supported tasks are: {SUPPORTED_TASKS_FLORENCE_2}"
+         )
+     result = result[task]
+
+     if task in ["<OD>", "<CAPTION_TO_PHRASE_GROUNDING>", "<DENSE_REGION_CAPTION>"]:
+         xyxy = np.array(result["bboxes"], dtype=np.float32)
+         labels = np.array(result["labels"])
+         return xyxy, labels, None, None
+
+     if task == "<REGION_PROPOSAL>":
+         xyxy = np.array(result["bboxes"], dtype=np.float32)
+         # provides labels, but they are ["", "", "", ...]
+         return xyxy, None, None, None
+
+     if task == "<OCR_WITH_REGION>":
+         xyxyxyxy = np.array(result["quad_boxes"], dtype=np.float32)
+         xyxyxyxy = xyxyxyxy.reshape(-1, 4, 2)
+         xyxy = np.array([polygon_to_xyxy(polygon) for polygon in xyxyxyxy])
+         labels = np.array(result["labels"])
+         return xyxy, labels, None, xyxyxyxy
+
+     if task in ["<REFERRING_EXPRESSION_SEGMENTATION>", "<REGION_TO_SEGMENTATION>"]:
+         xyxy_list = []
+         masks_list = []
+         for polygons_of_same_class in result["polygons"]:
+             for polygon in polygons_of_same_class:
+                 polygon = np.reshape(polygon, (-1, 2)).astype(np.int32)
+                 mask = polygon_to_mask(polygon, resolution_wh).astype(bool)
+                 masks_list.append(mask)
+                 xyxy = polygon_to_xyxy(polygon)
+                 xyxy_list.append(xyxy)
+             # per-class labels also provided, but they are ["", "", "", ...]
+             # when we figure out how to set class names, we can do
+             # zip(result["labels"], result["polygons"])
+         xyxy = np.array(xyxy_list, dtype=np.float32)
+         masks = np.array(masks_list)
+         return xyxy, None, masks, None
+
+     if task == "<OPEN_VOCABULARY_DETECTION>":
+         xyxy = np.array(result["bboxes"], dtype=np.float32)
+         labels = np.array(result["bboxes_labels"])
+         # Also has "polygons" and "polygons_labels", but they don't seem to be used
+         return xyxy, labels, None, None
+
+     if task in ["<REGION_TO_CATEGORY>", "<REGION_TO_DESCRIPTION>"]:
+         assert isinstance(
+             result, str
+         ), f"Expected string as {task} result, got {type(result)}"
+
+         if result == "No object detected.":
+             return np.empty((0, 4), dtype=np.float32), np.array([]), None, None
+
+         pattern = re.compile(r"<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>")
+         match = pattern.search(result)
+         assert (
+             match is not None
+         ), f"Expected string to end in location tags, but got {result}"
+
+         w, h = resolution_wh
+         xyxy = np.array([match.groups()], dtype=np.float32)
+         xyxy *= np.array([w, h, w, h]) / 1000
+         result_string = result[: match.start()]
+         labels = np.array([result_string])
+         return xyxy, labels, None, None
+
+     assert False, f"Unimplemented task: {task}"
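A quick sketch of the `<OD>` branch. The dict mirrors the shape Florence-2's processor emits for object detection; the numbers are made up for illustration:

```python
from eye.detection.lmm import from_florence_2

result = {"<OD>": {"bboxes": [[10.0, 20.0, 110.0, 220.0]], "labels": ["cat"]}}

xyxy, labels, masks, obb = from_florence_2(result, resolution_wh=(640, 480))
# xyxy   -> [[ 10.  20. 110. 220.]]
# labels -> ['cat'];  masks and obb are None for this task
```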
eye/detection/overlap_filter.py ADDED
@@ -0,0 +1,270 @@
+ from __future__ import annotations
+
+ from enum import Enum
+ from typing import List, Union
+
+ import numpy as np
+ import numpy.typing as npt
+
+ from eye.detection.utils import box_iou_batch, mask_iou_batch
+
+
+ def resize_masks(masks: np.ndarray, max_dimension: int = 640) -> np.ndarray:
+     """
+     Resize all masks in the array to have a maximum dimension of max_dimension,
+     maintaining aspect ratio.
+
+     Args:
+         masks (np.ndarray): 3D array of binary masks with shape (N, H, W).
+         max_dimension (int): The maximum dimension for the resized masks.
+
+     Returns:
+         np.ndarray: Array of resized masks.
+     """
+     max_height = np.max(masks.shape[1])
+     max_width = np.max(masks.shape[2])
+     scale = min(max_dimension / max_height, max_dimension / max_width)
+
+     new_height = int(scale * max_height)
+     new_width = int(scale * max_width)
+
+     x = np.linspace(0, max_width - 1, new_width).astype(int)
+     y = np.linspace(0, max_height - 1, new_height).astype(int)
+     xv, yv = np.meshgrid(x, y)
+
+     resized_masks = masks[:, yv, xv]
+
+     resized_masks = resized_masks.reshape(masks.shape[0], new_height, new_width)
+     return resized_masks
+
+
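`resize_masks` is a pure-NumPy nearest-neighbour resampler: it builds integer index grids with `np.linspace`/`np.meshgrid` and fancy-indexes all N masks at once. A small sketch with invented shapes:

```python
import numpy as np
from eye.detection.overlap_filter import resize_masks

# Two 1000x800 boolean masks, downsampled so the longest side is 640 px.
masks = np.zeros((2, 1000, 800), dtype=bool)
masks[0, 100:400, 100:400] = True

small = resize_masks(masks, max_dimension=640)
print(small.shape)  # (2, 640, 512) -- aspect ratio preserved
```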
+ def mask_non_max_suppression(
+     predictions: np.ndarray,
+     masks: np.ndarray,
+     iou_threshold: float = 0.5,
+     mask_dimension: int = 640,
+ ) -> np.ndarray:
+     """
+     Perform Non-Maximum Suppression (NMS) on segmentation predictions.
+
+     Args:
+         predictions (np.ndarray): A 2D array of object detection predictions in
+             the format of `(x_min, y_min, x_max, y_max, score)`
+             or `(x_min, y_min, x_max, y_max, score, class)`. Shape: `(N, 5)` or
+             `(N, 6)`, where N is the number of predictions.
+         masks (np.ndarray): A 3D array of binary masks corresponding to the
+             predictions. Shape: `(N, H, W)`, where N is the number of predictions,
+             and H, W are the dimensions of each mask.
+         iou_threshold (float): The intersection-over-union threshold
+             to use for non-maximum suppression.
+         mask_dimension (int): The dimension to which the masks should be
+             resized before computing IOU values. Defaults to 640.
+
+     Returns:
+         np.ndarray: A boolean array indicating which predictions to keep after
+             non-maximum suppression.
+
+     Raises:
+         AssertionError: If `iou_threshold` is not within the closed
+             range from `0` to `1`.
+     """
+     assert 0 <= iou_threshold <= 1, (
+         "Value of `iou_threshold` must be in the closed range from 0 to 1, "
+         f"{iou_threshold} given."
+     )
+     rows, columns = predictions.shape
+
+     if columns == 5:
+         predictions = np.c_[predictions, np.zeros(rows)]
+
+     sort_index = predictions[:, 4].argsort()[::-1]
+     predictions = predictions[sort_index]
+     masks = masks[sort_index]
+     masks_resized = resize_masks(masks, mask_dimension)
+     ious = mask_iou_batch(masks_resized, masks_resized)
+     categories = predictions[:, 5]
+
+     keep = np.ones(rows, dtype=bool)
+     for i in range(rows):
+         if keep[i]:
+             condition = (ious[i] > iou_threshold) & (categories[i] == categories)
+             keep[i + 1 :] = np.where(condition[i + 1 :], False, keep[i + 1 :])
+
+     return keep[sort_index.argsort()]
+
+
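A minimal sketch of mask NMS in action; the boxes, scores, and mask rectangles below are invented. The two class-0 instances overlap with mask IoU ≈ 0.82, so the lower-scoring one is suppressed:

```python
import numpy as np
from eye.detection.overlap_filter import mask_non_max_suppression

predictions = np.array([
    [10, 10, 50, 50, 0.90, 0],
    [12, 12, 52, 52, 0.80, 0],   # heavy overlap with the first, same class
    [60, 60, 90, 90, 0.70, 1],
])
masks = np.zeros((3, 100, 100), dtype=bool)
masks[0, 10:50, 10:50] = True
masks[1, 12:52, 12:52] = True
masks[2, 60:90, 60:90] = True

keep = mask_non_max_suppression(predictions, masks, iou_threshold=0.5)
print(keep)  # [ True False  True]
```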
+ def box_non_max_suppression(
+     predictions: np.ndarray, iou_threshold: float = 0.5
+ ) -> np.ndarray:
+     """
+     Perform Non-Maximum Suppression (NMS) on object detection predictions.
+
+     Args:
+         predictions (np.ndarray): An array of object detection predictions in
+             the format of `(x_min, y_min, x_max, y_max, score)`
+             or `(x_min, y_min, x_max, y_max, score, class)`.
+         iou_threshold (float): The intersection-over-union threshold
+             to use for non-maximum suppression.
+
+     Returns:
+         np.ndarray: A boolean array indicating which predictions to keep after
+             non-maximum suppression.
+
+     Raises:
+         AssertionError: If `iou_threshold` is not within the
+             closed range from `0` to `1`.
+     """
+     assert 0 <= iou_threshold <= 1, (
+         "Value of `iou_threshold` must be in the closed range from 0 to 1, "
+         f"{iou_threshold} given."
+     )
+     rows, columns = predictions.shape
+
+     # add column #5 - category filled with zeros for agnostic nms
+     if columns == 5:
+         predictions = np.c_[predictions, np.zeros(rows)]
+
+     # sort predictions column #4 - score
+     sort_index = np.flip(predictions[:, 4].argsort())
+     predictions = predictions[sort_index]
+
+     boxes = predictions[:, :4]
+     categories = predictions[:, 5]
+     ious = box_iou_batch(boxes, boxes)
+     ious = ious - np.eye(rows)
+
+     keep = np.ones(rows, dtype=bool)
+
+     for index, (iou, category) in enumerate(zip(ious, categories)):
+         if not keep[index]:
+             continue
+
+         # drop detections with iou > iou_threshold and
+         # same category as current detections
+         condition = (iou > iou_threshold) & (categories == category)
+         keep = keep & ~condition
+
+     return keep[sort_index.argsort()]
+
+
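The box variant works the same way but on `box_iou_batch` directly (subtracting the identity so a box never suppresses itself). A sketch with the same invented boxes as above:

```python
import numpy as np
from eye.detection.overlap_filter import box_non_max_suppression

predictions = np.array([
    [10, 10, 50, 50, 0.90, 0],
    [12, 12, 52, 52, 0.80, 0],   # IoU ~0.82 with the first box, same class -> dropped
    [60, 60, 90, 90, 0.70, 1],
])
keep = box_non_max_suppression(predictions, iou_threshold=0.5)
print(keep)  # [ True False  True]
```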
+ def group_overlapping_boxes(
+     predictions: npt.NDArray[np.float64], iou_threshold: float = 0.5
+ ) -> List[List[int]]:
+     """
+     Apply greedy version of non-maximum merging to avoid detecting too many
+     overlapping bounding boxes for a given object.
+
+     Args:
+         predictions (npt.NDArray[np.float64]): An array of shape `(n, 5)` containing
+             the bounding boxes coordinates in format `[x1, y1, x2, y2]`
+             and the confidence scores.
+         iou_threshold (float): The intersection-over-union threshold
+             to use for non-maximum suppression. Defaults to 0.5.
+
+     Returns:
+         List[List[int]]: Groups of prediction indices to be merged.
+             Each group may have 1 or more elements.
+     """
+     merge_groups: List[List[int]] = []
+
+     scores = predictions[:, 4]
+     order = scores.argsort()
+
+     while len(order) > 0:
+         idx = int(order[-1])
+
+         order = order[:-1]
+         if len(order) == 0:
+             merge_groups.append([idx])
+             break
+
+         merge_candidate = np.expand_dims(predictions[idx], axis=0)
+         ious = box_iou_batch(predictions[order][:, :4], merge_candidate[:, :4])
+         ious = ious.flatten()
+
+         above_threshold = ious >= iou_threshold
+         merge_group = [idx, *np.flip(order[above_threshold]).tolist()]
+         merge_groups.append(merge_group)
+         order = order[~above_threshold]
+     return merge_groups
+
+
+ def box_non_max_merge(
+     predictions: npt.NDArray[np.float64],
+     iou_threshold: float = 0.5,
+ ) -> List[List[int]]:
+     """
+     Apply greedy version of non-maximum merging per category to avoid detecting
+     too many overlapping bounding boxes for a given object.
+
+     Args:
+         predictions (npt.NDArray[np.float64]): An array of shape `(n, 5)` or `(n, 6)`
+             containing the bounding boxes coordinates in format `[x1, y1, x2, y2]`,
+             the confidence scores and class_ids. Omit class_id column to allow
+             detections of different classes to be merged.
+         iou_threshold (float): The intersection-over-union threshold
+             to use for non-maximum suppression. Defaults to 0.5.
+
+     Returns:
+         List[List[int]]: Groups of prediction indices to be merged.
+             Each group may have 1 or more elements.
+     """
+     if predictions.shape[1] == 5:
+         return group_overlapping_boxes(predictions, iou_threshold)
+
+     category_ids = predictions[:, 5]
+     merge_groups = []
+     for category_id in np.unique(category_ids):
+         curr_indices = np.where(category_ids == category_id)[0]
+         merge_class_groups = group_overlapping_boxes(
+             predictions[curr_indices], iou_threshold
+         )
+
+         for merge_class_group in merge_class_groups:
+             merge_groups.append(curr_indices[merge_class_group].tolist())
+
+     for merge_group in merge_groups:
+         if len(merge_group) == 0:
+             raise ValueError(
+                 f"Empty group detected when non-max-merging "
+                 f"detections: {merge_groups}"
+             )
+     return merge_groups
+
+
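Unlike NMS, non-max merging returns index groups rather than a keep mask, so the caller can fuse each group into one detection. A sketch, again with invented boxes:

```python
import numpy as np
from eye.detection.overlap_filter import box_non_max_merge

predictions = np.array([
    [10, 10, 50, 50, 0.90, 0],
    [12, 12, 52, 52, 0.80, 0],   # grouped with the box above (IoU ~0.82)
    [60, 60, 90, 90, 0.70, 1],
])
groups = box_non_max_merge(predictions, iou_threshold=0.5)
print(groups)  # [[0, 1], [2]] -- indices to merge, grouped per class
```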
+ class OverlapFilter(Enum):
+     """
+     Enum specifying the strategy for filtering overlapping detections.
+
+     Attributes:
+         NONE: Do not filter detections based on overlap.
+         NON_MAX_SUPPRESSION: Filter detections using non-max suppression. This means
+             detections that overlap by more than a set threshold will be discarded,
+             except for the one with the highest confidence.
+         NON_MAX_MERGE: Merge detections with non-max merging. This means
+             detections that overlap by more than a set threshold will be merged
+             into a single detection.
+     """
+
+     NONE = "none"
+     NON_MAX_SUPPRESSION = "non_max_suppression"
+     NON_MAX_MERGE = "non_max_merge"
+
+     @classmethod
+     def list(cls):
+         return list(map(lambda c: c.value, cls))
+
+     @classmethod
+     def from_value(cls, value: Union[OverlapFilter, str]) -> OverlapFilter:
+         if isinstance(value, cls):
+             return value
+         if isinstance(value, str):
+             value = value.lower()
+             try:
+                 return cls(value)
+             except ValueError:
+                 raise ValueError(
+                     f"Invalid value: {value}. Must be one of {cls.list()}"
+                 )
+         raise ValueError(
+             f"Invalid value type: {type(value)}. Must be an instance of "
+             f"{cls.__name__} or str."
+         )
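`OverlapFilter.from_value` accepts either an enum member or its lowercase string value, which lets APIs take a plain string for the strategy:

```python
from eye.detection.overlap_filter import OverlapFilter

strategy = OverlapFilter.from_value("non_max_suppression")
print(strategy)              # OverlapFilter.NON_MAX_SUPPRESSION
print(OverlapFilter.list())  # ['none', 'non_max_suppression', 'non_max_merge']

OverlapFilter.from_value("nms")  # raises ValueError listing the valid values
```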
eye/detection/tools/csv_sink.py ADDED
@@ -0,0 +1,181 @@
+ from __future__ import annotations
+
+ import csv
+ import os
+ from typing import Any, Dict, List, Optional
+
+ from eye.detection.core import Detections
+
+ BASE_HEADER = [
+     "x_min",
+     "y_min",
+     "x_max",
+     "y_max",
+     "class_id",
+     "confidence",
+     "tracker_id",
+ ]
+
+
+ class CSVSink:
+     """
+     A utility class for saving detection data to a CSV file. This class is designed to
+     efficiently serialize detection objects into a CSV format, allowing for the
+     inclusion of bounding box coordinates and additional attributes like `confidence`,
+     `class_id`, and `tracker_id`.
+
+     !!! tip
+
+         CSVSink allows custom data to be passed alongside the detection fields,
+         providing flexibility for logging various types of information.
+
+     Args:
+         file_name (str): The name of the CSV file where the detections will be stored.
+             Defaults to 'output.csv'.
+
+     Example:
+         ```python
+         import eye as sv
+         from ultralytics import YOLO
+
+         model = YOLO(<SOURCE_MODEL_PATH>)
+         csv_sink = sv.CSVSink(<RESULT_CSV_FILE_PATH>)
+         frames_generator = sv.get_video_frames_generator(<SOURCE_VIDEO_PATH>)
+
+         with csv_sink as sink:
+             for frame in frames_generator:
+                 result = model(frame)[0]
+                 detections = sv.Detections.from_ultralytics(result)
+                 sink.append(detections, custom_data={'<CUSTOM_LABEL>':'<CUSTOM_DATA>'})
+         ```
+     """
+
+     def __init__(self, file_name: str = "output.csv") -> None:
+         """
+         Initialize the CSVSink instance.
+
+         Args:
+             file_name (str): The name of the CSV file.
+
+         Returns:
+             None
+         """
+         self.file_name = file_name
+         self.file: Optional[open] = None
+         self.writer: Optional[csv.writer] = None
+         self.header_written = False
+         self.field_names = []
+
+     def __enter__(self) -> CSVSink:
+         self.open()
+         return self
+
+     def __exit__(
+         self,
+         exc_type: Optional[type],
+         exc_val: Optional[Exception],
+         exc_tb: Optional[Any],
+     ) -> None:
+         self.close()
+
+     def open(self) -> None:
+         """
+         Open the CSV file for writing.
+
+         Returns:
+             None
+         """
+         parent_directory = os.path.dirname(self.file_name)
+         if parent_directory and not os.path.exists(parent_directory):
+             os.makedirs(parent_directory)
+
+         self.file = open(self.file_name, "w", newline="")
+         self.writer = csv.writer(self.file)
+
+     def close(self) -> None:
+         """
+         Close the CSV file.
+
+         Returns:
+             None
+         """
+         if self.file:
+             self.file.close()
+
+     @staticmethod
+     def parse_detection_data(
+         detections: Detections, custom_data: Optional[Dict[str, Any]] = None
+     ) -> List[Dict[str, Any]]:
+         parsed_rows = []
+         for i in range(len(detections.xyxy)):
+             row = {
+                 "x_min": detections.xyxy[i][0],
+                 "y_min": detections.xyxy[i][1],
+                 "x_max": detections.xyxy[i][2],
+                 "y_max": detections.xyxy[i][3],
+                 "class_id": ""
+                 if detections.class_id is None
+                 else str(detections.class_id[i]),
+                 "confidence": ""
+                 if detections.confidence is None
+                 else str(detections.confidence[i]),
+                 "tracker_id": ""
+                 if detections.tracker_id is None
+                 else str(detections.tracker_id[i]),
+             }
+
+             if hasattr(detections, "data"):
+                 for key, value in detections.data.items():
+                     if value.ndim == 0:
+                         row[key] = value
+                     else:
+                         row[key] = value[i]
+
+             if custom_data:
+                 row.update(custom_data)
+             parsed_rows.append(row)
+         return parsed_rows
+
+     def append(
+         self, detections: Detections, custom_data: Optional[Dict[str, Any]] = None
+     ) -> None:
+         """
+         Append detection data to the CSV file.
+
+         Args:
+             detections (Detections): The detection data.
+             custom_data (Dict[str, Any]): Custom data to include.
+
+         Returns:
+             None
+         """
+         if not self.writer:
+             raise Exception(
+                 f"Cannot append to CSV: The file '{self.file_name}' is not open."
+             )
+         field_names = CSVSink.parse_field_names(detections, custom_data)
+         if not self.header_written:
+             self.field_names = field_names
+             self.writer.writerow(field_names)
+             self.header_written = True
+
+         if field_names != self.field_names:
+             print(
+                 f"Field names do not match the header. "
+                 f"Expected: {self.field_names}, given: {field_names}"
+             )
+
+         parsed_rows = CSVSink.parse_detection_data(detections, custom_data)
+         for row in parsed_rows:
+             self.writer.writerow(
+                 [row.get(field_name, "") for field_name in self.field_names]
+             )
+
+     @staticmethod
+     def parse_field_names(
+         detections: Detections, custom_data: Optional[Dict[str, Any]]
+     ) -> List[str]:
+         # `append` may pass custom_data=None; guard before reading its keys
+         dynamic_header = sorted(
+             set((custom_data or {}).keys())
+             | set(getattr(detections, "data", {}).keys())
+         )
+         return BASE_HEADER + dynamic_header
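A minimal end-to-end sketch of the sink without a model in the loop. It assumes `Detections` can be constructed directly from `xyxy`/`confidence`/`class_id` arrays, mirroring the supervision-style API the docstring example suggests; the values are made up:

```python
import numpy as np
import eye as sv

detections = sv.Detections(
    xyxy=np.array([[10.0, 20.0, 110.0, 220.0]]),
    confidence=np.array([0.9]),
    class_id=np.array([0]),
)

# Context manager opens the file, writes the header on first append, closes on exit.
with sv.CSVSink("detections.csv") as sink:
    sink.append(detections, custom_data={"frame": 1})
# detections.csv now holds the BASE_HEADER columns plus the sorted
# dynamic keys (here: "frame").
```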