eye_cv-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- eye/__init__.py +115 -0
- eye/__init___supervision_original.py +120 -0
- eye/annotators/__init__.py +0 -0
- eye/annotators/base.py +22 -0
- eye/annotators/core.py +2699 -0
- eye/annotators/line.py +107 -0
- eye/annotators/modern.py +529 -0
- eye/annotators/trace.py +142 -0
- eye/annotators/utils.py +177 -0
- eye/assets/__init__.py +2 -0
- eye/assets/downloader.py +95 -0
- eye/assets/list.py +83 -0
- eye/classification/__init__.py +0 -0
- eye/classification/core.py +188 -0
- eye/config.py +2 -0
- eye/core/__init__.py +0 -0
- eye/core/trackers/__init__.py +1 -0
- eye/core/trackers/botsort_tracker.py +336 -0
- eye/core/trackers/bytetrack_tracker.py +284 -0
- eye/core/trackers/sort_tracker.py +200 -0
- eye/core/tracking.py +146 -0
- eye/dataset/__init__.py +0 -0
- eye/dataset/core.py +919 -0
- eye/dataset/formats/__init__.py +0 -0
- eye/dataset/formats/coco.py +258 -0
- eye/dataset/formats/pascal_voc.py +279 -0
- eye/dataset/formats/yolo.py +272 -0
- eye/dataset/utils.py +259 -0
- eye/detection/__init__.py +0 -0
- eye/detection/auto_convert.py +155 -0
- eye/detection/core.py +1529 -0
- eye/detection/detections_enhanced.py +392 -0
- eye/detection/line_zone.py +859 -0
- eye/detection/lmm.py +184 -0
- eye/detection/overlap_filter.py +270 -0
- eye/detection/tools/__init__.py +0 -0
- eye/detection/tools/csv_sink.py +181 -0
- eye/detection/tools/inference_slicer.py +288 -0
- eye/detection/tools/json_sink.py +142 -0
- eye/detection/tools/polygon_zone.py +202 -0
- eye/detection/tools/smoother.py +123 -0
- eye/detection/tools/smoothing.py +179 -0
- eye/detection/tools/smoothing_config.py +202 -0
- eye/detection/tools/transformers.py +247 -0
- eye/detection/utils.py +1175 -0
- eye/draw/__init__.py +0 -0
- eye/draw/color.py +154 -0
- eye/draw/utils.py +374 -0
- eye/filters.py +112 -0
- eye/geometry/__init__.py +0 -0
- eye/geometry/core.py +128 -0
- eye/geometry/utils.py +47 -0
- eye/keypoint/__init__.py +0 -0
- eye/keypoint/annotators.py +442 -0
- eye/keypoint/core.py +687 -0
- eye/keypoint/skeletons.py +2647 -0
- eye/metrics/__init__.py +21 -0
- eye/metrics/core.py +72 -0
- eye/metrics/detection.py +843 -0
- eye/metrics/f1_score.py +648 -0
- eye/metrics/mean_average_precision.py +628 -0
- eye/metrics/mean_average_recall.py +697 -0
- eye/metrics/precision.py +653 -0
- eye/metrics/recall.py +652 -0
- eye/metrics/utils/__init__.py +0 -0
- eye/metrics/utils/object_size.py +158 -0
- eye/metrics/utils/utils.py +9 -0
- eye/py.typed +0 -0
- eye/quick.py +104 -0
- eye/tracker/__init__.py +0 -0
- eye/tracker/byte_tracker/__init__.py +0 -0
- eye/tracker/byte_tracker/core.py +386 -0
- eye/tracker/byte_tracker/kalman_filter.py +205 -0
- eye/tracker/byte_tracker/matching.py +69 -0
- eye/tracker/byte_tracker/single_object_track.py +178 -0
- eye/tracker/byte_tracker/utils.py +18 -0
- eye/utils/__init__.py +0 -0
- eye/utils/conversion.py +132 -0
- eye/utils/file.py +159 -0
- eye/utils/image.py +794 -0
- eye/utils/internal.py +200 -0
- eye/utils/iterables.py +84 -0
- eye/utils/notebook.py +114 -0
- eye/utils/video.py +307 -0
- eye/utils_eye/__init__.py +1 -0
- eye/utils_eye/geometry.py +71 -0
- eye/utils_eye/nms.py +55 -0
- eye/validators/__init__.py +140 -0
- eye/web.py +271 -0
- eye_cv-1.0.0.dist-info/METADATA +319 -0
- eye_cv-1.0.0.dist-info/RECORD +94 -0
- eye_cv-1.0.0.dist-info/WHEEL +5 -0
- eye_cv-1.0.0.dist-info/licenses/LICENSE +21 -0
- eye_cv-1.0.0.dist-info/top_level.txt +1 -0
eye/annotators/trace.py
ADDED
@@ -0,0 +1,142 @@
"""Trace annotator for drawing object paths."""

import cv2
import numpy as np
from typing import Optional, Dict
from collections import deque
from ..core.detections import Detections
from ..core.colors import ColorPalette


class TraceAnnotator:
    """Draw tracking trails/paths for objects.

    Innovation: Fading trails, variable thickness, trail smoothing.
    """

    def __init__(
        self,
        color_palette: Optional[ColorPalette] = None,
        thickness: int = 2,
        trace_length: int = 50,
        fade_trail: bool = True,
        smooth_trail: bool = False
    ):
        """
        Args:
            color_palette: Color palette
            thickness: Line thickness
            trace_length: Maximum number of points in trail
            fade_trail: Gradually fade older trail points
            smooth_trail: Apply smoothing to trail
        """
        from ..core.colors import PredefinedPalettes
        self.color_palette = color_palette or PredefinedPalettes.bright()
        self.thickness = thickness
        self.trace_length = trace_length
        self.fade_trail = fade_trail
        self.smooth_trail = smooth_trail

        # Track history for each object
        self.history: Dict[int, deque] = {}
        # Position smoothing: base alpha (0-1), lower = more smoothing
        self.position_smoothing_alpha = 0.6
        # If movement speed (pixels/frame) exceeds this, use stronger smoothing
        self.position_speed_threshold = 8.0
        # Alpha to use when speed > threshold
        self.fast_position_smoothing_alpha = 0.3

    def annotate(
        self,
        image: np.ndarray,
        detections: Detections
    ) -> np.ndarray:
        """Draw trails on image."""
        annotated = image.copy()

        if detections.tracker_id is None:
            return annotated

        # Update history
        active_ids = set()
        for i in range(len(detections)):
            tracker_id = detections.tracker_id[i]
            active_ids.add(tracker_id)
            center = detections.center[i]

            if tracker_id not in self.history:
                self.history[tracker_id] = deque(maxlen=self.trace_length)
            # Apply optional temporal smoothing when appending to history
            if len(self.history[tracker_id]) > 0 and self.position_smoothing_alpha is not None:
                last = np.array(self.history[tracker_id][-1], dtype=float)
                newc = np.array(center, dtype=float)
                dist = float(np.linalg.norm(newc - last))
                if self.position_speed_threshold is not None and dist > float(self.position_speed_threshold):
                    alpha = float(self.fast_position_smoothing_alpha)
                else:
                    alpha = float(self.position_smoothing_alpha)
                center = alpha * newc + (1.0 - alpha) * last

            self.history[tracker_id].append(center)

        # Remove old tracks
        inactive_ids = set(self.history.keys()) - active_ids
        for tid in inactive_ids:
            if len(self.history[tid]) > 0:
                # Keep trail visible for a bit after object leaves
                pass
            else:
                del self.history[tid]

        # Draw trails
        for i in range(len(detections)):
            tracker_id = detections.tracker_id[i]
            if tracker_id not in self.history or len(self.history[tracker_id]) < 2:
                continue

            points = list(self.history[tracker_id])

            # Smooth if requested
            if self.smooth_trail and len(points) >= 3:
                points = self._smooth_points(points)

            # Get color
            if detections.class_id is not None:
                color = self.color_palette.by_id(detections.class_id[i])
            else:
                color = self.color_palette.by_id(tracker_id)

            # Draw trail
            for j in range(len(points) - 1):
                if self.fade_trail:
                    # Calculate alpha based on position in trail
                    alpha = (j + 1) / len(points)
                    faded_color = tuple(int(c * alpha) for c in color.as_bgr())
                else:
                    faded_color = color.as_bgr()

                pt1 = tuple(points[j].astype(int))
                pt2 = tuple(points[j + 1].astype(int))

                cv2.line(annotated, pt1, pt2, faded_color, self.thickness)

        return annotated

    @staticmethod
    def _smooth_points(points, window=3):
        """Apply moving average smoothing to points."""
        points = np.array(points)
        smoothed = []
        for i in range(len(points)):
            start = max(0, i - window // 2)
            end = min(len(points), i + window // 2 + 1)
            smoothed.append(np.mean(points[start:end], axis=0))
        return smoothed

    def reset(self):
        """Clear all tracking history."""
        self.history.clear()


# Alias for easier understanding
PathAnnotator = TraceAnnotator
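Since this annotator is new in the release, a minimal, hypothetical usage sketch follows. It assumes `TraceAnnotator` is importable from `eye.annotators.trace` as packaged above; because the `Detections` class it expects is not shown in this diff, the sketch substitutes a stand-in that only exposes the attributes the annotator actually reads (`tracker_id`, `class_id`, `center`, `__len__`).

```python
# Hypothetical sketch; _StubDetections is NOT part of eye-cv, it only mimics
# the attributes TraceAnnotator reads: tracker_id, class_id, center, __len__.
import numpy as np
from eye.annotators.trace import TraceAnnotator


class _StubDetections:
    def __init__(self, centers, tracker_ids, class_ids):
        self.center = np.asarray(centers, dtype=float)        # (n, 2) anchor points
        self.tracker_id = np.asarray(tracker_ids, dtype=int)
        self.class_id = np.asarray(class_ids, dtype=int)

    def __len__(self) -> int:
        return len(self.tracker_id)


annotator = TraceAnnotator(trace_length=30, fade_trail=True)
frame = np.zeros((480, 640, 3), dtype=np.uint8)

# Feed several frames of one object drifting right; each call extends its trail.
for x in range(100, 200, 10):
    detections = _StubDetections(centers=[[x, 240]], tracker_ids=[1], class_ids=[0])
    frame = annotator.annotate(frame, detections)

annotator.reset()  # drop all trails, e.g. between videos
```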
eye/annotators/utils.py
ADDED
@@ -0,0 +1,177 @@
from enum import Enum
from typing import Optional, Tuple, Union

import numpy as np

from eye.detection.core import Detections
from eye.draw.color import Color, ColorPalette
from eye.geometry.core import Position


class ColorLookup(Enum):
    """
    Enumeration class to define strategies for mapping colors to annotations.

    This enum supports three different lookup strategies:
    - `INDEX`: Colors are determined by the index of the detection within the scene.
    - `CLASS`: Colors are determined by the class label of the detected object.
    - `TRACK`: Colors are determined by the tracking identifier of the object.
    """

    INDEX = "index"
    CLASS = "class"
    TRACK = "track"

    @classmethod
    def list(cls):
        return list(map(lambda c: c.value, cls))


def resolve_color_idx(
    detections: Detections,
    detection_idx: int,
    color_lookup: Union[ColorLookup, np.ndarray] = ColorLookup.CLASS,
) -> int:
    if detection_idx >= len(detections):
        raise ValueError(
            f"Detection index {detection_idx} "
            f"is out of bounds for detections of length {len(detections)}"
        )

    if isinstance(color_lookup, np.ndarray):
        if len(color_lookup) != len(detections):
            raise ValueError(
                f"Length of color lookup {len(color_lookup)} "
                f"does not match length of detections {len(detections)}"
            )
        return color_lookup[detection_idx]
    elif color_lookup == ColorLookup.INDEX:
        return detection_idx
    elif color_lookup == ColorLookup.CLASS:
        if detections.class_id is None:
            raise ValueError(
                "Could not resolve color by class because "
                "Detections do not have class_id. If using an annotator, "
                "try setting color_lookup to sv.ColorLookup.INDEX or "
                "sv.ColorLookup.TRACK."
            )
        return detections.class_id[detection_idx]
    elif color_lookup == ColorLookup.TRACK:
        if detections.tracker_id is None:
            raise ValueError(
                "Could not resolve color by track because "
                "Detections do not have tracker_id. Did you call "
                "tracker.update_with_detections(...) before annotating?"
            )
        return detections.tracker_id[detection_idx]


def resolve_text_background_xyxy(
    center_coordinates: Tuple[int, int],
    text_wh: Tuple[int, int],
    position: Position,
) -> Tuple[int, int, int, int]:
    center_x, center_y = center_coordinates
    text_w, text_h = text_wh

    if position == Position.TOP_LEFT:
        return center_x, center_y - text_h, center_x + text_w, center_y
    elif position == Position.TOP_RIGHT:
        return center_x - text_w, center_y - text_h, center_x, center_y
    elif position == Position.TOP_CENTER:
        return (
            center_x - text_w // 2,
            center_y - text_h,
            center_x + text_w // 2,
            center_y,
        )
    elif position == Position.CENTER or position == Position.CENTER_OF_MASS:
        return (
            center_x - text_w // 2,
            center_y - text_h // 2,
            center_x + text_w // 2,
            center_y + text_h // 2,
        )
    elif position == Position.BOTTOM_LEFT:
        return center_x, center_y, center_x + text_w, center_y + text_h
    elif position == Position.BOTTOM_RIGHT:
        return center_x - text_w, center_y, center_x, center_y + text_h
    elif position == Position.BOTTOM_CENTER:
        return (
            center_x - text_w // 2,
            center_y,
            center_x + text_w // 2,
            center_y + text_h,
        )
    elif position == Position.CENTER_LEFT:
        return (
            center_x - text_w,
            center_y - text_h // 2,
            center_x,
            center_y + text_h // 2,
        )
    elif position == Position.CENTER_RIGHT:
        return (
            center_x,
            center_y - text_h // 2,
            center_x + text_w,
            center_y + text_h // 2,
        )


def get_color_by_index(color: Union[Color, ColorPalette], idx: int) -> Color:
    if isinstance(color, ColorPalette):
        return color.by_id(idx)
    return color


def resolve_color(
    color: Union[Color, ColorPalette],
    detections: Detections,
    detection_idx: int,
    color_lookup: Union[ColorLookup, np.ndarray] = ColorLookup.CLASS,
) -> Color:
    idx = resolve_color_idx(
        detections=detections,
        detection_idx=detection_idx,
        color_lookup=color_lookup,
    )
    return get_color_by_index(color=color, idx=idx)


class Trace:
    def __init__(
        self,
        max_size: Optional[int] = None,
        start_frame_id: int = 0,
        anchor: Position = Position.CENTER,
    ) -> None:
        self.current_frame_id = start_frame_id
        self.max_size = max_size
        self.anchor = anchor

        self.frame_id = np.array([], dtype=int)
        self.xy = np.empty((0, 2), dtype=np.float32)
        self.tracker_id = np.array([], dtype=int)

    def put(self, detections: Detections) -> None:
        frame_id = np.full(len(detections), self.current_frame_id, dtype=int)
        self.frame_id = np.concatenate([self.frame_id, frame_id])
        self.xy = np.concatenate(
            [self.xy, detections.get_anchors_coordinates(self.anchor)]
        )
        self.tracker_id = np.concatenate([self.tracker_id, detections.tracker_id])

        unique_frame_id = np.unique(self.frame_id)

        if 0 < self.max_size < len(unique_frame_id):
            max_allowed_frame_id = self.current_frame_id - self.max_size + 1
            filtering_mask = self.frame_id >= max_allowed_frame_id
            self.frame_id = self.frame_id[filtering_mask]
            self.xy = self.xy[filtering_mask]
            self.tracker_id = self.tracker_id[filtering_mask]

        self.current_frame_id += 1

    def get(self, tracker_id: int) -> np.ndarray:
        return self.xy[self.tracker_id == tracker_id]
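The `Trace` class above keeps a rolling window of anchor points per `tracker_id`, evicting frames older than `max_size`. A small, hypothetical sketch of that bookkeeping follows; `_StubDetections` is not part of the package and only provides the two members `Trace.put` actually touches (`get_anchors_coordinates` and `tracker_id`).

```python
# Hypothetical sketch of the Trace rolling buffer defined above.
import numpy as np
from eye.annotators.utils import Trace
from eye.geometry.core import Position


class _StubDetections:
    def __init__(self, xy, tracker_id):
        self._xy = np.asarray(xy, dtype=np.float32)
        self.tracker_id = np.asarray(tracker_id, dtype=int)

    def __len__(self) -> int:
        return len(self.tracker_id)

    def get_anchors_coordinates(self, anchor):
        # The real Detections resolves the anchor point; the stub ignores it.
        return self._xy


trace = Trace(max_size=2, anchor=Position.CENTER)
trace.put(_StubDetections([[10, 10], [50, 50]], [1, 2]))  # frame 0
trace.put(_StubDetections([[12, 11], [52, 49]], [1, 2]))  # frame 1
trace.put(_StubDetections([[14, 12]], [1]))               # frame 2 evicts frame 0

print(trace.get(tracker_id=1))  # track-1 points from the last two frames
```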
eye/assets/__init__.py
ADDED
eye/assets/downloader.py
ADDED
@@ -0,0 +1,95 @@
import os
from hashlib import new as hash_new
from pathlib import Path
from shutil import copyfileobj
from typing import Union

from eye.assets.list import VIDEO_ASSETS, VideoAssets

try:
    from requests import get
    from tqdm.auto import tqdm
except ImportError:
    raise ValueError(
        "\n"
        "Please install requests and tqdm to download assets \n"
        "or install eye with assets \n"
        "pip install eye[assets] \n"
        "\n"
    )


def is_md5_hash_matching(filename: str, original_md5_hash: str) -> bool:
    """
    Check if the MD5 hash of a file matches the original hash.

    Parameters:
        filename (str): The path to the file to be checked as a string.
        original_md5_hash (str): The original MD5 hash to compare against.

    Returns:
        bool: True if the hashes match, False otherwise.
    """
    if not os.path.exists(filename):
        return False

    with open(filename, "rb") as file:
        file_contents = file.read()
        computed_md5_hash = hash_new(name="MD5")
        computed_md5_hash.update(file_contents)

    return computed_md5_hash.hexdigest() == original_md5_hash


def download_assets(asset_name: Union[VideoAssets, str]) -> str:
    """
    Download a specified asset if it doesn't already exist or is corrupted.

    Parameters:
        asset_name (Union[VideoAssets, str]): The name or type of the asset to be
            downloaded.

    Returns:
        str: The filename of the downloaded asset.

    Example:
        ```python
        from eye.assets import download_assets, VideoAssets

        download_assets(VideoAssets.VEHICLES)
        "vehicles.mp4"
        ```
    """

    filename = asset_name.value if isinstance(asset_name, VideoAssets) else asset_name

    if not Path(filename).exists() and filename in VIDEO_ASSETS:
        print(f"Downloading {filename} assets \n")
        response = get(VIDEO_ASSETS[filename][0], stream=True, allow_redirects=True)
        response.raise_for_status()

        file_size = int(response.headers.get("Content-Length", 0))
        folder_path = Path(filename).expanduser().resolve()
        folder_path.parent.mkdir(parents=True, exist_ok=True)

        with tqdm.wrapattr(
            response.raw, "read", total=file_size, desc="", colour="#a351fb"
        ) as raw_resp:
            with folder_path.open("wb") as file:
                copyfileobj(raw_resp, file)

    elif Path(filename).exists():
        if not is_md5_hash_matching(filename, VIDEO_ASSETS[filename][1]):
            print("File corrupted. Re-downloading... \n")
            os.remove(filename)
            return download_assets(filename)

        print(f"{filename} asset download complete. \n")

    else:
        valid_assets = ", ".join(asset.value for asset in VideoAssets)
        raise ValueError(
            f"Invalid asset. It should be one of the following: {valid_assets}."
        )

    return filename
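A short usage sketch, assuming the optional `requests`/`tqdm` dependencies are installed and network access is available; it uses only names defined in this wheel (`download_assets`, `is_md5_hash_matching`, `VIDEO_ASSETS`, `VideoAssets`).

```python
from eye.assets.downloader import download_assets, is_md5_hash_matching
from eye.assets.list import VIDEO_ASSETS, VideoAssets

# Fetch the clip (skipped if an intact copy already exists in the working dir).
path = download_assets(VideoAssets.PEOPLE_WALKING)  # -> "people-walking.mp4"

# VIDEO_ASSETS maps filename -> (download URL, expected MD5 hash).
_, expected_md5 = VIDEO_ASSETS[VideoAssets.PEOPLE_WALKING.value]
print(is_md5_hash_matching(path, expected_md5))  # True for an intact download
```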
eye/assets/list.py
ADDED
@@ -0,0 +1,83 @@
from enum import Enum
from typing import Dict, Tuple

BASE_VIDEO_URL = "https://media.roboflow.com/eye/video-examples/"


class VideoAssets(Enum):
    """
    Each member of this enum represents a video asset. The value associated with each
    member is the filename of the video.

    | Enum Member            | Video Filename             | Video URL                                                                               |
    |------------------------|----------------------------|---------------------------------------------------------------------------------------|
    | `VEHICLES`             | `vehicles.mp4`             | [Link](https://media.roboflow.com/eye/video-examples/vehicles.mp4)                     |
    | `MILK_BOTTLING_PLANT`  | `milk-bottling-plant.mp4`  | [Link](https://media.roboflow.com/eye/video-examples/milk-bottling-plant.mp4)          |
    | `VEHICLES_2`           | `vehicles-2.mp4`           | [Link](https://media.roboflow.com/eye/video-examples/vehicles-2.mp4)                   |
    | `GROCERY_STORE`        | `grocery-store.mp4`        | [Link](https://media.roboflow.com/eye/video-examples/grocery-store.mp4)                |
    | `SUBWAY`               | `subway.mp4`               | [Link](https://media.roboflow.com/eye/video-examples/subway.mp4)                       |
    | `MARKET_SQUARE`        | `market-square.mp4`        | [Link](https://media.roboflow.com/eye/video-examples/market-square.mp4)                |
    | `PEOPLE_WALKING`       | `people-walking.mp4`       | [Link](https://media.roboflow.com/eye/video-examples/people-walking.mp4)               |
    | `BEACH`                | `beach-1.mp4`              | [Link](https://media.roboflow.com/eye/video-examples/beach-1.mp4)                      |
    | `BASKETBALL`           | `basketball-1.mp4`         | [Link](https://media.roboflow.com/eye/video-examples/basketball-1.mp4)                 |
    | `SKIING`               | `skiing.mp4`               | [Link](https://media.roboflow.com/eye/video-examples/skiing.mp4)                       |
    """  # noqa: E501 // docs

    VEHICLES = "vehicles.mp4"
    MILK_BOTTLING_PLANT = "milk-bottling-plant.mp4"
    VEHICLES_2 = "vehicles-2.mp4"
    GROCERY_STORE = "grocery-store.mp4"
    SUBWAY = "subway.mp4"
    MARKET_SQUARE = "market-square.mp4"
    PEOPLE_WALKING = "people-walking.mp4"
    BEACH = "beach-1.mp4"
    BASKETBALL = "basketball-1.mp4"
    SKIING = "skiing.mp4"

    @classmethod
    def list(cls):
        return list(map(lambda c: c.value, cls))


VIDEO_ASSETS: Dict[str, Tuple[str, str]] = {
    VideoAssets.VEHICLES.value: (
        f"{BASE_VIDEO_URL}{VideoAssets.VEHICLES.value}",
        "8155ff4e4de08cfa25f39de96483f918",
    ),
    VideoAssets.VEHICLES_2.value: (
        f"{BASE_VIDEO_URL}{VideoAssets.VEHICLES_2.value}",
        "830af6fba21ffbf14867a7fea595937b",
    ),
    VideoAssets.MILK_BOTTLING_PLANT.value: (
        f"{BASE_VIDEO_URL}{VideoAssets.MILK_BOTTLING_PLANT.value}",
        "9e8fb6e883f842a38b3d34267290bdc7",
    ),
    VideoAssets.GROCERY_STORE.value: (
        f"{BASE_VIDEO_URL}{VideoAssets.GROCERY_STORE.value}",
        "11402e7b861c1980527d3d74cbe3b366",
    ),
    VideoAssets.SUBWAY.value: (
        f"{BASE_VIDEO_URL}{VideoAssets.SUBWAY.value}",
        "453475750691fb23c56a0cffef089194",
    ),
    VideoAssets.MARKET_SQUARE.value: (
        f"{BASE_VIDEO_URL}{VideoAssets.MARKET_SQUARE.value}",
        "859179bf4a21f80a8baabfdb2ed716dc",
    ),
    VideoAssets.PEOPLE_WALKING.value: (
        f"{BASE_VIDEO_URL}{VideoAssets.PEOPLE_WALKING.value}",
        "0574c053c8686c3f1dc0aa3743e45cb9",
    ),
    VideoAssets.BEACH.value: (
        f"{BASE_VIDEO_URL}{VideoAssets.BEACH.value}",
        "4175d42fec4d450ed081523fd39e0cf8",
    ),
    VideoAssets.BASKETBALL.value: (
        f"{BASE_VIDEO_URL}{VideoAssets.BASKETBALL.value}",
        "60d94a3c7c47d16f09d342b088012ecc",
    ),
    VideoAssets.SKIING.value: (
        f"{BASE_VIDEO_URL}{VideoAssets.SKIING.value}",
        "d30987cbab1bbc5934199cdd1b293119",
    ),
}
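For reference, the registry above can be walked programmatically; this small sketch uses only the names defined in `eye/assets/list.py`.

```python
from eye.assets.list import BASE_VIDEO_URL, VIDEO_ASSETS, VideoAssets

print(VideoAssets.list())  # ['vehicles.mp4', 'milk-bottling-plant.mp4', ...]

# Every entry pairs a filename with its download URL and expected MD5 hash.
for filename, (url, md5) in VIDEO_ASSETS.items():
    assert url == f"{BASE_VIDEO_URL}{filename}"
    print(f"{filename:25} {md5}")
```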
eye/classification/__init__.py
ADDED
File without changes
eye/classification/core.py
ADDED
@@ -0,0 +1,188 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Optional, Tuple

import numpy as np


def _validate_class_ids(class_id: Any, n: int) -> None:
    """
    Ensure that class_id is a 1d np.ndarray with (n, ) shape.
    """
    is_valid = isinstance(class_id, np.ndarray) and class_id.shape == (n,)
    if not is_valid:
        raise ValueError("class_id must be 1d np.ndarray with (n, ) shape")


def _validate_confidence(confidence: Any, n: int) -> None:
    """
    Ensure that confidence is a 1d np.ndarray with (n, ) shape.
    """
    if confidence is not None:
        is_valid = isinstance(confidence, np.ndarray) and confidence.shape == (n,)
        if not is_valid:
            raise ValueError("confidence must be 1d np.ndarray with (n, ) shape")


@dataclass
class Classifications:
    class_id: np.ndarray
    confidence: Optional[np.ndarray] = None

    def __post_init__(self) -> None:
        """
        Validate the classification inputs.
        """
        n = len(self.class_id)

        _validate_class_ids(self.class_id, n)
        _validate_confidence(self.confidence, n)

    def __len__(self) -> int:
        """
        Returns the number of classifications.
        """
        return len(self.class_id)

    @classmethod
    def from_clip(cls, clip_results) -> Classifications:
        """
        Creates a Classifications instance from a
        [clip](https://github.com/openai/clip) inference result.

        Args:
            clip_results (np.ndarray): The inference result from clip model.

        Returns:
            Classifications: A new Classifications object.

        Example:
            ```python
            from PIL import Image
            import clip
            import eye as sv

            model, preprocess = clip.load('ViT-B/32')

            image = cv2.imread(SOURCE_IMAGE_PATH)
            image = preprocess(image).unsqueeze(0)

            text = clip.tokenize(["a diagram", "a dog", "a cat"])
            output, _ = model(image, text)
            classifications = sv.Classifications.from_clip(output)
            ```
        """

        confidence = clip_results.softmax(dim=-1).cpu().detach().numpy()[0]

        if len(confidence) == 0:
            return cls(class_id=np.array([]), confidence=np.array([]))

        class_ids = np.arange(len(confidence))
        return cls(class_id=class_ids, confidence=confidence)

    @classmethod
    def from_ultralytics(cls, ultralytics_results) -> Classifications:
        """
        Creates a Classifications instance from a
        [ultralytics](https://github.com/ultralytics/ultralytics) inference result.

        Args:
            ultralytics_results (ultralytics.engine.results.Results):
                The inference result from ultralytics model.

        Returns:
            Classifications: A new Classifications object.

        Example:
            ```python
            import cv2
            from ultralytics import YOLO
            import eye as sv

            image = cv2.imread(SOURCE_IMAGE_PATH)
            model = YOLO('yolov8n-cls.pt')

            output = model(image)[0]
            classifications = sv.Classifications.from_ultralytics(output)
            ```
        """
        confidence = ultralytics_results.probs.data.cpu().numpy()
        return cls(class_id=np.arange(confidence.shape[0]), confidence=confidence)

    @classmethod
    def from_timm(cls, timm_results) -> Classifications:
        """
        Creates a Classifications instance from a
        [timm](https://huggingface.co/docs/hub/timm) inference result.

        Args:
            timm_results (torch.Tensor): The inference result from timm model.

        Returns:
            Classifications: A new Classifications object.

        Example:
            ```python
            from PIL import Image
            import timm
            from timm.data import resolve_data_config, create_transform
            import eye as sv

            model = timm.create_model(
                model_name='hf-hub:nateraw/resnet50-oxford-iiit-pet',
                pretrained=True
            ).eval()

            config = resolve_data_config({}, model=model)
            transform = create_transform(**config)

            image = Image.open(SOURCE_IMAGE_PATH).convert('RGB')
            x = transform(image).unsqueeze(0)

            output = model(x)

            classifications = sv.Classifications.from_timm(output)
            ```
        """
        confidence = timm_results.cpu().detach().numpy()[0]

        if len(confidence) == 0:
            return cls(class_id=np.array([]), confidence=np.array([]))

        class_id = np.arange(len(confidence))
        return cls(class_id=class_id, confidence=confidence)

    def get_top_k(self, k: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        Retrieve the top k class IDs and confidences,
        ordered in descending order by confidence.

        Args:
            k (int): The number of top class IDs and confidences to retrieve.

        Returns:
            Tuple[np.ndarray, np.ndarray]: A tuple containing
                the top k class IDs and confidences.

        Example:
            ```python
            import eye as sv

            classifications = sv.Classifications(...)

            classifications.get_top_k(1)

            (array([1]), array([0.9]))
            ```
        """
        if self.confidence is None:
            raise ValueError("top_k could not be calculated, confidence is None")

        order = np.argsort(self.confidence)[::-1]
        top_k_order = order[:k]
        top_k_class_id = self.class_id[top_k_order]
        top_k_confidence = self.confidence[top_k_order]

        return top_k_class_id, top_k_confidence
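To make the `get_top_k` contract concrete, here is a small self-contained sketch built only from the constructor and method defined above; the confidence values are illustrative, and the import path assumes this file installs as `eye/classification/core.py`, as the record suggests.

```python
import numpy as np
from eye.classification.core import Classifications

classifications = Classifications(
    class_id=np.array([0, 1, 2]),
    confidence=np.array([0.10, 0.65, 0.25]),
)

# Confidences are sorted in descending order, then the first k entries are kept.
top_class_id, top_confidence = classifications.get_top_k(k=2)
print(top_class_id)    # [1 2]
print(top_confidence)  # [0.65 0.25]
```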
eye/config.py
ADDED
eye/core/__init__.py
ADDED
File without changes