eye-cv 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eye/__init__.py +115 -0
- eye/__init___supervision_original.py +120 -0
- eye/annotators/__init__.py +0 -0
- eye/annotators/base.py +22 -0
- eye/annotators/core.py +2699 -0
- eye/annotators/line.py +107 -0
- eye/annotators/modern.py +529 -0
- eye/annotators/trace.py +142 -0
- eye/annotators/utils.py +177 -0
- eye/assets/__init__.py +2 -0
- eye/assets/downloader.py +95 -0
- eye/assets/list.py +83 -0
- eye/classification/__init__.py +0 -0
- eye/classification/core.py +188 -0
- eye/config.py +2 -0
- eye/core/__init__.py +0 -0
- eye/core/trackers/__init__.py +1 -0
- eye/core/trackers/botsort_tracker.py +336 -0
- eye/core/trackers/bytetrack_tracker.py +284 -0
- eye/core/trackers/sort_tracker.py +200 -0
- eye/core/tracking.py +146 -0
- eye/dataset/__init__.py +0 -0
- eye/dataset/core.py +919 -0
- eye/dataset/formats/__init__.py +0 -0
- eye/dataset/formats/coco.py +258 -0
- eye/dataset/formats/pascal_voc.py +279 -0
- eye/dataset/formats/yolo.py +272 -0
- eye/dataset/utils.py +259 -0
- eye/detection/__init__.py +0 -0
- eye/detection/auto_convert.py +155 -0
- eye/detection/core.py +1529 -0
- eye/detection/detections_enhanced.py +392 -0
- eye/detection/line_zone.py +859 -0
- eye/detection/lmm.py +184 -0
- eye/detection/overlap_filter.py +270 -0
- eye/detection/tools/__init__.py +0 -0
- eye/detection/tools/csv_sink.py +181 -0
- eye/detection/tools/inference_slicer.py +288 -0
- eye/detection/tools/json_sink.py +142 -0
- eye/detection/tools/polygon_zone.py +202 -0
- eye/detection/tools/smoother.py +123 -0
- eye/detection/tools/smoothing.py +179 -0
- eye/detection/tools/smoothing_config.py +202 -0
- eye/detection/tools/transformers.py +247 -0
- eye/detection/utils.py +1175 -0
- eye/draw/__init__.py +0 -0
- eye/draw/color.py +154 -0
- eye/draw/utils.py +374 -0
- eye/filters.py +112 -0
- eye/geometry/__init__.py +0 -0
- eye/geometry/core.py +128 -0
- eye/geometry/utils.py +47 -0
- eye/keypoint/__init__.py +0 -0
- eye/keypoint/annotators.py +442 -0
- eye/keypoint/core.py +687 -0
- eye/keypoint/skeletons.py +2647 -0
- eye/metrics/__init__.py +21 -0
- eye/metrics/core.py +72 -0
- eye/metrics/detection.py +843 -0
- eye/metrics/f1_score.py +648 -0
- eye/metrics/mean_average_precision.py +628 -0
- eye/metrics/mean_average_recall.py +697 -0
- eye/metrics/precision.py +653 -0
- eye/metrics/recall.py +652 -0
- eye/metrics/utils/__init__.py +0 -0
- eye/metrics/utils/object_size.py +158 -0
- eye/metrics/utils/utils.py +9 -0
- eye/py.typed +0 -0
- eye/quick.py +104 -0
- eye/tracker/__init__.py +0 -0
- eye/tracker/byte_tracker/__init__.py +0 -0
- eye/tracker/byte_tracker/core.py +386 -0
- eye/tracker/byte_tracker/kalman_filter.py +205 -0
- eye/tracker/byte_tracker/matching.py +69 -0
- eye/tracker/byte_tracker/single_object_track.py +178 -0
- eye/tracker/byte_tracker/utils.py +18 -0
- eye/utils/__init__.py +0 -0
- eye/utils/conversion.py +132 -0
- eye/utils/file.py +159 -0
- eye/utils/image.py +794 -0
- eye/utils/internal.py +200 -0
- eye/utils/iterables.py +84 -0
- eye/utils/notebook.py +114 -0
- eye/utils/video.py +307 -0
- eye/utils_eye/__init__.py +1 -0
- eye/utils_eye/geometry.py +71 -0
- eye/utils_eye/nms.py +55 -0
- eye/validators/__init__.py +140 -0
- eye/web.py +271 -0
- eye_cv-1.0.0.dist-info/METADATA +319 -0
- eye_cv-1.0.0.dist-info/RECORD +94 -0
- eye_cv-1.0.0.dist-info/WHEEL +5 -0
- eye_cv-1.0.0.dist-info/licenses/LICENSE +21 -0
- eye_cv-1.0.0.dist-info/top_level.txt +1 -0
eye/detection/lmm.py
ADDED
@@ -0,0 +1,184 @@
import re
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

from eye.detection.utils import polygon_to_mask, polygon_to_xyxy


class LMM(Enum):
    PALIGEMMA = "paligemma"
    FLORENCE_2 = "florence_2"


RESULT_TYPES: Dict[LMM, type] = {LMM.PALIGEMMA: str, LMM.FLORENCE_2: dict}

REQUIRED_ARGUMENTS: Dict[LMM, List[str]] = {
    LMM.PALIGEMMA: ["resolution_wh"],
    LMM.FLORENCE_2: ["resolution_wh"],
}

ALLOWED_ARGUMENTS: Dict[LMM, List[str]] = {
    LMM.PALIGEMMA: ["resolution_wh", "classes"],
    LMM.FLORENCE_2: ["resolution_wh"],
}

SUPPORTED_TASKS_FLORENCE_2 = [
    "<OD>",
    "<CAPTION_TO_PHRASE_GROUNDING>",
    "<DENSE_REGION_CAPTION>",
    "<REGION_PROPOSAL>",
    "<OCR_WITH_REGION>",
    "<REFERRING_EXPRESSION_SEGMENTATION>",
    "<REGION_TO_SEGMENTATION>",
    "<OPEN_VOCABULARY_DETECTION>",
    "<REGION_TO_CATEGORY>",
    "<REGION_TO_DESCRIPTION>",
]


def validate_lmm_parameters(
    lmm: Union[LMM, str], result: Any, kwargs: Dict[str, Any]
) -> LMM:
    if isinstance(lmm, str):
        try:
            lmm = LMM(lmm.lower())
        except ValueError:
            raise ValueError(
                f"Invalid lmm value: {lmm}. Must be one of {[e.value for e in LMM]}"
            )

    if not isinstance(result, RESULT_TYPES[lmm]):
        raise ValueError(
            f"Invalid LMM result type: {type(result)}. Must be {RESULT_TYPES[lmm]}"
        )

    required_args = REQUIRED_ARGUMENTS.get(lmm, [])
    for arg in required_args:
        if arg not in kwargs:
            raise ValueError(f"Missing required argument: {arg}")

    allowed_args = ALLOWED_ARGUMENTS.get(lmm, [])
    for arg in kwargs:
        if arg not in allowed_args:
            raise ValueError(f"Argument {arg} is not allowed for {lmm.name}")

    return lmm


def from_paligemma(
    result: str, resolution_wh: Tuple[int, int], classes: Optional[List[str]] = None
) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray]:
    w, h = resolution_wh
    pattern = re.compile(
        r"(?<!<loc\d{4}>)<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})> ([\w\s\-]+)"
    )
    matches = pattern.findall(result)
    matches = np.array(matches) if matches else np.empty((0, 5))

    xyxy, class_name = matches[:, [1, 0, 3, 2]], matches[:, 4]
    xyxy = xyxy.astype(int) / 1024 * np.array([w, h, w, h])
    class_name = np.char.strip(class_name.astype(str))
    class_id = None

    if classes is not None:
        mask = np.array([name in classes for name in class_name]).astype(bool)
        xyxy, class_name = xyxy[mask], class_name[mask]
        class_id = np.array([classes.index(name) for name in class_name])

    return xyxy, class_id, class_name


def from_florence_2(
    result: dict, resolution_wh: Tuple[int, int]
) -> Tuple[
    np.ndarray, Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]
]:
    """
    Parse results from the Florence 2 multi-model model.
    https://huggingface.co/microsoft/Florence-2-large

    Parameters:
        result: dict containing the model output

    Returns:
        xyxy (np.ndarray): An array of shape `(n, 4)` containing
            the bounding boxes coordinates in format `[x1, y1, x2, y2]`
        labels: (Optional[np.ndarray]): An array of shape `(n,)` containing
            the class labels for each bounding box
        masks: (Optional[np.ndarray]): An array of shape `(n, h, w)` containing
            the segmentation masks for each bounding box
        obb_boxes: (Optional[np.ndarray]): An array of shape `(n, 4, 2)` containing
            oriented bounding boxes.
    """
    assert len(result) == 1, f"Expected result with a single element. Got: {result}"
    task = next(iter(result.keys()))
    if task not in SUPPORTED_TASKS_FLORENCE_2:
        raise ValueError(
            f"{task} not supported. Supported tasks are: {SUPPORTED_TASKS_FLORENCE_2}"
        )
    result = result[task]

    if task in ["<OD>", "<CAPTION_TO_PHRASE_GROUNDING>", "<DENSE_REGION_CAPTION>"]:
        xyxy = np.array(result["bboxes"], dtype=np.float32)
        labels = np.array(result["labels"])
        return xyxy, labels, None, None

    if task == "<REGION_PROPOSAL>":
        xyxy = np.array(result["bboxes"], dtype=np.float32)
        # provides labels, but they are ["", "", "", ...]
        return xyxy, None, None, None

    if task == "<OCR_WITH_REGION>":
        xyxyxyxy = np.array(result["quad_boxes"], dtype=np.float32)
        xyxyxyxy = xyxyxyxy.reshape(-1, 4, 2)
        xyxy = np.array([polygon_to_xyxy(polygon) for polygon in xyxyxyxy])
        labels = np.array(result["labels"])
        return xyxy, labels, None, xyxyxyxy

    if task in ["<REFERRING_EXPRESSION_SEGMENTATION>", "<REGION_TO_SEGMENTATION>"]:
        xyxy_list = []
        masks_list = []
        for polygons_of_same_class in result["polygons"]:
            for polygon in polygons_of_same_class:
                polygon = np.reshape(polygon, (-1, 2)).astype(np.int32)
                mask = polygon_to_mask(polygon, resolution_wh).astype(bool)
                masks_list.append(mask)
                xyxy = polygon_to_xyxy(polygon)
                xyxy_list.append(xyxy)
            # per-class labels also provided, but they are ["", "", "", ...]
            # when we figure out how to set class names, we can do
            # zip(result["labels"], result["polygons"])
        xyxy = np.array(xyxy_list, dtype=np.float32)
        masks = np.array(masks_list)
        return xyxy, None, masks, None

    if task == "<OPEN_VOCABULARY_DETECTION>":
        xyxy = np.array(result["bboxes"], dtype=np.float32)
        labels = np.array(result["bboxes_labels"])
        # Also has "polygons" and "polygons_labels", but they don't seem to be used
        return xyxy, labels, None, None

    if task in ["<REGION_TO_CATEGORY>", "<REGION_TO_DESCRIPTION>"]:
        assert isinstance(
            result, str
        ), f"Expected string as <REGION_TO_CATEGORY> result, got {type(result)}"

        if result == "No object detected.":
            return np.empty((0, 4), dtype=np.float32), np.array([]), None, None

        pattern = re.compile(r"<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>")
        match = pattern.search(result)
        assert (
            match is not None
        ), f"Expected string to end in location tags, but got {result}"

        w, h = resolution_wh
        xyxy = np.array([match.groups()], dtype=np.float32)
        xyxy *= np.array([w, h, w, h]) / 1000
        result_string = result[: match.start()]
        labels = np.array([result_string])
        return xyxy, labels, None, None

    assert False, f"Unimplemented task: {task}"
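For context, a minimal sketch of how the PaliGemma parser above behaves on a synthetic location-token string. The string, resolution, and class list below are invented for illustration; only the `from_paligemma` signature and regex come from the file itself.

```python
# Illustrative only: exercises the parser above with a made-up PaliGemma-style
# detection string (<locYYYY><locXXXX><locYYYY><locXXXX> label).
from eye.detection.lmm import from_paligemma

result = "<loc0256><loc0256><loc0768><loc0768> cat"
xyxy, class_id, class_name = from_paligemma(
    result, resolution_wh=(640, 480), classes=["cat", "dog"]
)
print(xyxy)        # [[160. 120. 480. 360.]] - pixel coords scaled from the 1024 grid
print(class_id)    # [0] - "cat" is index 0 in `classes`
print(class_name)  # ['cat']
```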
eye/detection/overlap_filter.py
ADDED
@@ -0,0 +1,270 @@
from __future__ import annotations

from enum import Enum
from typing import List, Union

import numpy as np
import numpy.typing as npt

from eye.detection.utils import box_iou_batch, mask_iou_batch


def resize_masks(masks: np.ndarray, max_dimension: int = 640) -> np.ndarray:
    """
    Resize all masks in the array to have a maximum dimension of max_dimension,
    maintaining aspect ratio.

    Args:
        masks (np.ndarray): 3D array of binary masks with shape (N, H, W).
        max_dimension (int): The maximum dimension for the resized masks.

    Returns:
        np.ndarray: Array of resized masks.
    """
    max_height = np.max(masks.shape[1])
    max_width = np.max(masks.shape[2])
    scale = min(max_dimension / max_height, max_dimension / max_width)

    new_height = int(scale * max_height)
    new_width = int(scale * max_width)

    x = np.linspace(0, max_width - 1, new_width).astype(int)
    y = np.linspace(0, max_height - 1, new_height).astype(int)
    xv, yv = np.meshgrid(x, y)

    resized_masks = masks[:, yv, xv]

    resized_masks = resized_masks.reshape(masks.shape[0], new_height, new_width)
    return resized_masks


def mask_non_max_suppression(
    predictions: np.ndarray,
    masks: np.ndarray,
    iou_threshold: float = 0.5,
    mask_dimension: int = 640,
) -> np.ndarray:
    """
    Perform Non-Maximum Suppression (NMS) on segmentation predictions.

    Args:
        predictions (np.ndarray): A 2D array of object detection predictions in
            the format of `(x_min, y_min, x_max, y_max, score)`
            or `(x_min, y_min, x_max, y_max, score, class)`. Shape: `(N, 5)` or
            `(N, 6)`, where N is the number of predictions.
        masks (np.ndarray): A 3D array of binary masks corresponding to the predictions.
            Shape: `(N, H, W)`, where N is the number of predictions, and H, W are the
            dimensions of each mask.
        iou_threshold (float): The intersection-over-union threshold
            to use for non-maximum suppression.
        mask_dimension (int): The dimension to which the masks should be
            resized before computing IOU values. Defaults to 640.

    Returns:
        np.ndarray: A boolean array indicating which predictions to keep after
            non-maximum suppression.

    Raises:
        AssertionError: If `iou_threshold` is not within the closed
            range from `0` to `1`.
    """
    assert 0 <= iou_threshold <= 1, (
        "Value of `iou_threshold` must be in the closed range from 0 to 1, "
        f"{iou_threshold} given."
    )
    rows, columns = predictions.shape

    if columns == 5:
        predictions = np.c_[predictions, np.zeros(rows)]

    sort_index = predictions[:, 4].argsort()[::-1]
    predictions = predictions[sort_index]
    masks = masks[sort_index]
    masks_resized = resize_masks(masks, mask_dimension)
    ious = mask_iou_batch(masks_resized, masks_resized)
    categories = predictions[:, 5]

    keep = np.ones(rows, dtype=bool)
    for i in range(rows):
        if keep[i]:
            condition = (ious[i] > iou_threshold) & (categories[i] == categories)
            keep[i + 1 :] = np.where(condition[i + 1 :], False, keep[i + 1 :])

    return keep[sort_index.argsort()]


def box_non_max_suppression(
    predictions: np.ndarray, iou_threshold: float = 0.5
) -> np.ndarray:
    """
    Perform Non-Maximum Suppression (NMS) on object detection predictions.

    Args:
        predictions (np.ndarray): An array of object detection predictions in
            the format of `(x_min, y_min, x_max, y_max, score)`
            or `(x_min, y_min, x_max, y_max, score, class)`.
        iou_threshold (float): The intersection-over-union threshold
            to use for non-maximum suppression.

    Returns:
        np.ndarray: A boolean array indicating which predictions to keep after
            non-maximum suppression.

    Raises:
        AssertionError: If `iou_threshold` is not within the
            closed range from `0` to `1`.
    """
    assert 0 <= iou_threshold <= 1, (
        "Value of `iou_threshold` must be in the closed range from 0 to 1, "
        f"{iou_threshold} given."
    )
    rows, columns = predictions.shape

    # add column #5 - category filled with zeros for agnostic nms
    if columns == 5:
        predictions = np.c_[predictions, np.zeros(rows)]

    # sort predictions column #4 - score
    sort_index = np.flip(predictions[:, 4].argsort())
    predictions = predictions[sort_index]

    boxes = predictions[:, :4]
    categories = predictions[:, 5]
    ious = box_iou_batch(boxes, boxes)
    ious = ious - np.eye(rows)

    keep = np.ones(rows, dtype=bool)

    for index, (iou, category) in enumerate(zip(ious, categories)):
        if not keep[index]:
            continue

        # drop detections with iou > iou_threshold and
        # same category as current detections
        condition = (iou > iou_threshold) & (categories == category)
        keep = keep & ~condition

    return keep[sort_index.argsort()]


def group_overlapping_boxes(
    predictions: npt.NDArray[np.float64], iou_threshold: float = 0.5
) -> List[List[int]]:
    """
    Apply greedy version of non-maximum merging to avoid detecting too many
    overlapping bounding boxes for a given object.

    Args:
        predictions (npt.NDArray[np.float64]): An array of shape `(n, 5)` containing
            the bounding boxes coordinates in format `[x1, y1, x2, y2]`
            and the confidence scores.
        iou_threshold (float): The intersection-over-union threshold
            to use for non-maximum suppression. Defaults to 0.5.

    Returns:
        List[List[int]]: Groups of prediction indices to be merged.
            Each group may have 1 or more elements.
    """
    merge_groups: List[List[int]] = []

    scores = predictions[:, 4]
    order = scores.argsort()

    while len(order) > 0:
        idx = int(order[-1])

        order = order[:-1]
        if len(order) == 0:
            merge_groups.append([idx])
            break

        merge_candidate = np.expand_dims(predictions[idx], axis=0)
        ious = box_iou_batch(predictions[order][:, :4], merge_candidate[:, :4])
        ious = ious.flatten()

        above_threshold = ious >= iou_threshold
        merge_group = [idx, *np.flip(order[above_threshold]).tolist()]
        merge_groups.append(merge_group)
        order = order[~above_threshold]
    return merge_groups


def box_non_max_merge(
    predictions: npt.NDArray[np.float64],
    iou_threshold: float = 0.5,
) -> List[List[int]]:
    """
    Apply greedy version of non-maximum merging per category to avoid detecting
    too many overlapping bounding boxes for a given object.

    Args:
        predictions (npt.NDArray[np.float64]): An array of shape `(n, 5)` or `(n, 6)`
            containing the bounding boxes coordinates in format `[x1, y1, x2, y2]`,
            the confidence scores and class_ids. Omit class_id column to allow
            detections of different classes to be merged.
        iou_threshold (float): The intersection-over-union threshold
            to use for non-maximum suppression. Defaults to 0.5.

    Returns:
        List[List[int]]: Groups of prediction indices to be merged.
            Each group may have 1 or more elements.
    """
    if predictions.shape[1] == 5:
        return group_overlapping_boxes(predictions, iou_threshold)

    category_ids = predictions[:, 5]
    merge_groups = []
    for category_id in np.unique(category_ids):
        curr_indices = np.where(category_ids == category_id)[0]
        merge_class_groups = group_overlapping_boxes(
            predictions[curr_indices], iou_threshold
        )

        for merge_class_group in merge_class_groups:
            merge_groups.append(curr_indices[merge_class_group].tolist())

    for merge_group in merge_groups:
        if len(merge_group) == 0:
            raise ValueError(
                f"Empty group detected when non-max-merging "
                f"detections: {merge_groups}"
            )
    return merge_groups


class OverlapFilter(Enum):
    """
    Enum specifying the strategy for filtering overlapping detections.

    Attributes:
        NONE: Do not filter detections based on overlap.
        NON_MAX_SUPPRESSION: Filter detections using non-max suppression. This means,
            detections that overlap by more than a set threshold will be discarded,
            except for the one with the highest confidence.
        NON_MAX_MERGE: Merge detections with non-max merging. This means,
            detections that overlap by more than a set threshold will be merged
            into a single detection.
    """

    NONE = "none"
    NON_MAX_SUPPRESSION = "non_max_suppression"
    NON_MAX_MERGE = "non_max_merge"

    @classmethod
    def list(cls):
        return list(map(lambda c: c.value, cls))

    @classmethod
    def from_value(cls, value: Union[OverlapFilter, str]) -> OverlapFilter:
        if isinstance(value, cls):
            return value
        if isinstance(value, str):
            value = value.lower()
            try:
                return cls(value)
            except ValueError:
                raise ValueError(f"Invalid value: {value}. Must be one of {cls.list()}")
        raise ValueError(
            f"Invalid value type: {type(value)}. Must be an instance of "
            f"{cls.__name__} or str."
        )
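A quick, hedged sketch of the box NMS helper above (this hunk is attributed to `eye/detection/overlap_filter.py`, the only 270-line file under `eye/detection` in the listing). The box coordinates and scores are invented for the example.

```python
# Illustrative only: two heavily overlapping boxes of the same class plus one
# separate box; values are made up.
import numpy as np

from eye.detection.overlap_filter import OverlapFilter, box_non_max_suppression

predictions = np.array(
    [
        [10, 10, 100, 100, 0.90, 0],    # kept: highest score in its cluster
        [12, 12, 102, 102, 0.60, 0],    # suppressed: IoU with the box above > 0.5
        [200, 200, 300, 300, 0.80, 1],  # kept: no overlap with the others
    ],
    dtype=float,
)
keep = box_non_max_suppression(predictions, iou_threshold=0.5)
print(keep)  # -> [ True False  True], in the original prediction order

# String values are resolved through the enum helper defined above.
print(OverlapFilter.from_value("non_max_suppression"))  # OverlapFilter.NON_MAX_SUPPRESSION
```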
eye/detection/tools/__init__.py
File without changes
eye/detection/tools/csv_sink.py
ADDED
@@ -0,0 +1,181 @@
from __future__ import annotations

import csv
import os
from typing import Any, Dict, List, Optional

from eye.detection.core import Detections

BASE_HEADER = [
    "x_min",
    "y_min",
    "x_max",
    "y_max",
    "class_id",
    "confidence",
    "tracker_id",
]


class CSVSink:
    """
    A utility class for saving detection data to a CSV file. This class is designed to
    efficiently serialize detection objects into a CSV format, allowing for the
    inclusion of bounding box coordinates and additional attributes like `confidence`,
    `class_id`, and `tracker_id`.

    !!! tip

        CSVSink allows passing custom data alongside the detection fields, providing
        flexibility for logging various types of information.

    Args:
        file_name (str): The name of the CSV file where the detections will be stored.
            Defaults to 'output.csv'.

    Example:
        ```python
        import eye as sv
        from ultralytics import YOLO

        model = YOLO(<SOURCE_MODEL_PATH>)
        csv_sink = sv.CSVSink(<RESULT_CSV_FILE_PATH>)
        frames_generator = sv.get_video_frames_generator(<SOURCE_VIDEO_PATH>)

        with csv_sink as sink:
            for frame in frames_generator:
                result = model(frame)[0]
                detections = sv.Detections.from_ultralytics(result)
                sink.append(detections, custom_data={'<CUSTOM_LABEL>':'<CUSTOM_DATA>'})
        ```
    """

    def __init__(self, file_name: str = "output.csv") -> None:
        """
        Initialize the CSVSink instance.

        Args:
            file_name (str): The name of the CSV file.

        Returns:
            None
        """
        self.file_name = file_name
        self.file: Optional[open] = None
        self.writer: Optional[csv.writer] = None
        self.header_written = False
        self.field_names = []

    def __enter__(self) -> CSVSink:
        self.open()
        return self

    def __exit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[Exception],
        exc_tb: Optional[Any],
    ) -> None:
        self.close()

    def open(self) -> None:
        """
        Open the CSV file for writing.

        Returns:
            None
        """
        parent_directory = os.path.dirname(self.file_name)
        if parent_directory and not os.path.exists(parent_directory):
            os.makedirs(parent_directory)

        self.file = open(self.file_name, "w", newline="")
        self.writer = csv.writer(self.file)

    def close(self) -> None:
        """
        Close the CSV file.

        Returns:
            None
        """
        if self.file:
            self.file.close()

    @staticmethod
    def parse_detection_data(
        detections: Detections, custom_data: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        parsed_rows = []
        for i in range(len(detections.xyxy)):
            row = {
                "x_min": detections.xyxy[i][0],
                "y_min": detections.xyxy[i][1],
                "x_max": detections.xyxy[i][2],
                "y_max": detections.xyxy[i][3],
                "class_id": ""
                if detections.class_id is None
                else str(detections.class_id[i]),
                "confidence": ""
                if detections.confidence is None
                else str(detections.confidence[i]),
                "tracker_id": ""
                if detections.tracker_id is None
                else str(detections.tracker_id[i]),
            }

            if hasattr(detections, "data"):
                for key, value in detections.data.items():
                    if value.ndim == 0:
                        row[key] = value
                    else:
                        row[key] = value[i]

            if custom_data:
                row.update(custom_data)
            parsed_rows.append(row)
        return parsed_rows

    def append(
        self, detections: Detections, custom_data: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Append detection data to the CSV file.

        Args:
            detections (Detections): The detection data.
            custom_data (Dict[str, Any]): Custom data to include.

        Returns:
            None
        """
        if not self.writer:
            raise Exception(
                f"Cannot append to CSV: The file '{self.file_name}' is not open."
            )
        field_names = CSVSink.parse_field_names(detections, custom_data)
        if not self.header_written:
            self.field_names = field_names
            self.writer.writerow(field_names)
            self.header_written = True

        if field_names != self.field_names:
            print(
                f"Field names do not match the header. "
                f"Expected: {self.field_names}, given: {field_names}"
            )

        parsed_rows = CSVSink.parse_detection_data(detections, custom_data)
        for row in parsed_rows:
            self.writer.writerow(
                [row.get(field_name, "") for field_name in self.field_names]
            )

    @staticmethod
    def parse_field_names(
        detections: Detections, custom_data: Dict[str, Any]
    ) -> List[str]:
        dynamic_header = sorted(
            set(custom_data.keys()) | set(getattr(detections, "data", {}).keys())
        )
        return BASE_HEADER + dynamic_header
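A smaller, dependency-free variant of the docstring example above. It assumes `eye.Detections` can be constructed directly from `xyxy`, `confidence`, and `class_id` keyword arguments (mirroring the supervision-style API this package appears to vendor); the file name and values are invented for illustration.

```python
# Illustrative only: writes one frame's detections plus a custom column.
import numpy as np
import eye as sv

# Assumption: Detections accepts these keyword arguments, as in supervision.
detections = sv.Detections(
    xyxy=np.array([[10.0, 10.0, 100.0, 100.0]]),
    confidence=np.array([0.9]),
    class_id=np.array([0]),
)

with sv.CSVSink("detections.csv") as sink:
    # "frame_index" becomes an extra column appended after BASE_HEADER.
    sink.append(detections, custom_data={"frame_index": 0})
```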