eye_cv-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. eye/__init__.py +115 -0
  2. eye/__init___supervision_original.py +120 -0
  3. eye/annotators/__init__.py +0 -0
  4. eye/annotators/base.py +22 -0
  5. eye/annotators/core.py +2699 -0
  6. eye/annotators/line.py +107 -0
  7. eye/annotators/modern.py +529 -0
  8. eye/annotators/trace.py +142 -0
  9. eye/annotators/utils.py +177 -0
  10. eye/assets/__init__.py +2 -0
  11. eye/assets/downloader.py +95 -0
  12. eye/assets/list.py +83 -0
  13. eye/classification/__init__.py +0 -0
  14. eye/classification/core.py +188 -0
  15. eye/config.py +2 -0
  16. eye/core/__init__.py +0 -0
  17. eye/core/trackers/__init__.py +1 -0
  18. eye/core/trackers/botsort_tracker.py +336 -0
  19. eye/core/trackers/bytetrack_tracker.py +284 -0
  20. eye/core/trackers/sort_tracker.py +200 -0
  21. eye/core/tracking.py +146 -0
  22. eye/dataset/__init__.py +0 -0
  23. eye/dataset/core.py +919 -0
  24. eye/dataset/formats/__init__.py +0 -0
  25. eye/dataset/formats/coco.py +258 -0
  26. eye/dataset/formats/pascal_voc.py +279 -0
  27. eye/dataset/formats/yolo.py +272 -0
  28. eye/dataset/utils.py +259 -0
  29. eye/detection/__init__.py +0 -0
  30. eye/detection/auto_convert.py +155 -0
  31. eye/detection/core.py +1529 -0
  32. eye/detection/detections_enhanced.py +392 -0
  33. eye/detection/line_zone.py +859 -0
  34. eye/detection/lmm.py +184 -0
  35. eye/detection/overlap_filter.py +270 -0
  36. eye/detection/tools/__init__.py +0 -0
  37. eye/detection/tools/csv_sink.py +181 -0
  38. eye/detection/tools/inference_slicer.py +288 -0
  39. eye/detection/tools/json_sink.py +142 -0
  40. eye/detection/tools/polygon_zone.py +202 -0
  41. eye/detection/tools/smoother.py +123 -0
  42. eye/detection/tools/smoothing.py +179 -0
  43. eye/detection/tools/smoothing_config.py +202 -0
  44. eye/detection/tools/transformers.py +247 -0
  45. eye/detection/utils.py +1175 -0
  46. eye/draw/__init__.py +0 -0
  47. eye/draw/color.py +154 -0
  48. eye/draw/utils.py +374 -0
  49. eye/filters.py +112 -0
  50. eye/geometry/__init__.py +0 -0
  51. eye/geometry/core.py +128 -0
  52. eye/geometry/utils.py +47 -0
  53. eye/keypoint/__init__.py +0 -0
  54. eye/keypoint/annotators.py +442 -0
  55. eye/keypoint/core.py +687 -0
  56. eye/keypoint/skeletons.py +2647 -0
  57. eye/metrics/__init__.py +21 -0
  58. eye/metrics/core.py +72 -0
  59. eye/metrics/detection.py +843 -0
  60. eye/metrics/f1_score.py +648 -0
  61. eye/metrics/mean_average_precision.py +628 -0
  62. eye/metrics/mean_average_recall.py +697 -0
  63. eye/metrics/precision.py +653 -0
  64. eye/metrics/recall.py +652 -0
  65. eye/metrics/utils/__init__.py +0 -0
  66. eye/metrics/utils/object_size.py +158 -0
  67. eye/metrics/utils/utils.py +9 -0
  68. eye/py.typed +0 -0
  69. eye/quick.py +104 -0
  70. eye/tracker/__init__.py +0 -0
  71. eye/tracker/byte_tracker/__init__.py +0 -0
  72. eye/tracker/byte_tracker/core.py +386 -0
  73. eye/tracker/byte_tracker/kalman_filter.py +205 -0
  74. eye/tracker/byte_tracker/matching.py +69 -0
  75. eye/tracker/byte_tracker/single_object_track.py +178 -0
  76. eye/tracker/byte_tracker/utils.py +18 -0
  77. eye/utils/__init__.py +0 -0
  78. eye/utils/conversion.py +132 -0
  79. eye/utils/file.py +159 -0
  80. eye/utils/image.py +794 -0
  81. eye/utils/internal.py +200 -0
  82. eye/utils/iterables.py +84 -0
  83. eye/utils/notebook.py +114 -0
  84. eye/utils/video.py +307 -0
  85. eye/utils_eye/__init__.py +1 -0
  86. eye/utils_eye/geometry.py +71 -0
  87. eye/utils_eye/nms.py +55 -0
  88. eye/validators/__init__.py +140 -0
  89. eye/web.py +271 -0
  90. eye_cv-1.0.0.dist-info/METADATA +319 -0
  91. eye_cv-1.0.0.dist-info/RECORD +94 -0
  92. eye_cv-1.0.0.dist-info/WHEEL +5 -0
  93. eye_cv-1.0.0.dist-info/licenses/LICENSE +21 -0
  94. eye_cv-1.0.0.dist-info/top_level.txt +1 -0
eye/annotators/trace.py ADDED
@@ -0,0 +1,142 @@
+ """Trace annotator for drawing object paths."""
+
+ import cv2
+ import numpy as np
+ from typing import Optional, Dict
+ from collections import deque
+ from ..core.detections import Detections
+ from ..core.colors import ColorPalette
+
+
+ class TraceAnnotator:
+     """Draw tracking trails/paths for objects.
+
+     Innovation: Fading trails, variable thickness, trail smoothing.
+     """
+
+     def __init__(
+         self,
+         color_palette: Optional[ColorPalette] = None,
+         thickness: int = 2,
+         trace_length: int = 50,
+         fade_trail: bool = True,
+         smooth_trail: bool = False
+     ):
+         """
+         Args:
+             color_palette: Color palette
+             thickness: Line thickness
+             trace_length: Maximum number of points in trail
+             fade_trail: Gradually fade older trail points
+             smooth_trail: Apply smoothing to trail
+         """
+         from ..core.colors import PredefinedPalettes
+         self.color_palette = color_palette or PredefinedPalettes.bright()
+         self.thickness = thickness
+         self.trace_length = trace_length
+         self.fade_trail = fade_trail
+         self.smooth_trail = smooth_trail
+
+         # Track history for each object
+         self.history: Dict[int, deque] = {}
+         # Position smoothing: base alpha (0-1), lower = more smoothing
+         self.position_smoothing_alpha = 0.6
+         # If movement speed (pixels/frame) exceeds this, use stronger smoothing
+         self.position_speed_threshold = 8.0
+         # Alpha to use when speed > threshold
+         self.fast_position_smoothing_alpha = 0.3
+
+     def annotate(
+         self,
+         image: np.ndarray,
+         detections: Detections
+     ) -> np.ndarray:
+         """Draw trails on image."""
+         annotated = image.copy()
+
+         if detections.tracker_id is None:
+             return annotated
+
+         # Update history
+         active_ids = set()
+         for i in range(len(detections)):
+             tracker_id = detections.tracker_id[i]
+             active_ids.add(tracker_id)
+             center = detections.center[i]
+
+             if tracker_id not in self.history:
+                 self.history[tracker_id] = deque(maxlen=self.trace_length)
+             # Apply optional temporal smoothing when appending to history
+             if len(self.history[tracker_id]) > 0 and self.position_smoothing_alpha is not None:
+                 last = np.array(self.history[tracker_id][-1], dtype=float)
+                 newc = np.array(center, dtype=float)
+                 dist = float(np.linalg.norm(newc - last))
+                 if self.position_speed_threshold is not None and dist > float(self.position_speed_threshold):
+                     alpha = float(self.fast_position_smoothing_alpha)
+                 else:
+                     alpha = float(self.position_smoothing_alpha)
+                 center = alpha * newc + (1.0 - alpha) * last
+
+             self.history[tracker_id].append(center)
+
+         # Remove old tracks: shrink inactive trails one point per frame so
+         # the trail stays visible for a bit after the object leaves, then
+         # drop the entry once it is empty
+         inactive_ids = set(self.history.keys()) - active_ids
+         for tid in inactive_ids:
+             if len(self.history[tid]) > 0:
+                 self.history[tid].popleft()
+             if len(self.history[tid]) == 0:
+                 del self.history[tid]
+
+         # Draw trails
+         for i in range(len(detections)):
+             tracker_id = detections.tracker_id[i]
+             if tracker_id not in self.history or len(self.history[tracker_id]) < 2:
+                 continue
+
+             points = list(self.history[tracker_id])
+
+             # Smooth if requested
+             if self.smooth_trail and len(points) >= 3:
+                 points = self._smooth_points(points)
+
+             # Get color
+             if detections.class_id is not None:
+                 color = self.color_palette.by_id(detections.class_id[i])
+             else:
+                 color = self.color_palette.by_id(tracker_id)
+
+             # Draw trail
+             for j in range(len(points) - 1):
+                 if self.fade_trail:
+                     # Calculate alpha based on position in trail
+                     alpha = (j + 1) / len(points)
+                     faded_color = tuple(int(c * alpha) for c in color.as_bgr())
+                 else:
+                     faded_color = color.as_bgr()
+
+                 pt1 = tuple(points[j].astype(int))
+                 pt2 = tuple(points[j + 1].astype(int))
+
+                 cv2.line(annotated, pt1, pt2, faded_color, self.thickness)
+
+         return annotated
+
+     @staticmethod
+     def _smooth_points(points, window=3):
+         """Apply moving average smoothing to points."""
+         points = np.array(points)
+         smoothed = []
+         for i in range(len(points)):
+             start = max(0, i - window // 2)
+             end = min(len(points), i + window // 2 + 1)
+             smoothed.append(np.mean(points[start:end], axis=0))
+         return smoothed
+
+     def reset(self):
+         """Clear all tracking history."""
+         self.history.clear()
+
+
+ # Alias for easier understanding
+ PathAnnotator = TraceAnnotator
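
A minimal usage sketch for `TraceAnnotator` (not part of the diff): it assumes, as the code above implies, that `Detections` exposes `tracker_id`, `class_id`, and per-detection `center` values; the `track` helper is a hypothetical stand-in for any detector-plus-tracker pipeline.

```python
import cv2

from eye.annotators.trace import TraceAnnotator

annotator = TraceAnnotator(trace_length=30, fade_trail=True, smooth_trail=True)

cap = cv2.VideoCapture("input.mp4")
while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    # `track` is hypothetical: any detector + tracker that fills tracker_id
    detections = track(frame)
    frame = annotator.annotate(frame, detections)
    cv2.imshow("trails", frame)
    if cv2.waitKey(1) & 0xFF == ord("q"):
        break
cap.release()
cv2.destroyAllWindows()
```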
eye/annotators/utils.py ADDED
@@ -0,0 +1,177 @@
+ from enum import Enum
+ from typing import Optional, Tuple, Union
+
+ import numpy as np
+
+ from eye.detection.core import Detections
+ from eye.draw.color import Color, ColorPalette
+ from eye.geometry.core import Position
+
+
+ class ColorLookup(Enum):
+     """
+     Enumeration class to define strategies for mapping colors to annotations.
+
+     This enum supports three different lookup strategies:
+     - `INDEX`: Colors are determined by the index of the detection within the scene.
+     - `CLASS`: Colors are determined by the class label of the detected object.
+     - `TRACK`: Colors are determined by the tracking identifier of the object.
+     """
+
+     INDEX = "index"
+     CLASS = "class"
+     TRACK = "track"
+
+     @classmethod
+     def list(cls):
+         return list(map(lambda c: c.value, cls))
+
+
+ def resolve_color_idx(
+     detections: Detections,
+     detection_idx: int,
+     color_lookup: Union[ColorLookup, np.ndarray] = ColorLookup.CLASS,
+ ) -> int:
+     if detection_idx >= len(detections):
+         raise ValueError(
+             f"Detection index {detection_idx} "
+             f"is out of bounds for detections of length {len(detections)}"
+         )
+
+     if isinstance(color_lookup, np.ndarray):
+         if len(color_lookup) != len(detections):
+             raise ValueError(
+                 f"Length of color lookup {len(color_lookup)} "
+                 f"does not match length of detections {len(detections)}"
+             )
+         return color_lookup[detection_idx]
+     elif color_lookup == ColorLookup.INDEX:
+         return detection_idx
+     elif color_lookup == ColorLookup.CLASS:
+         if detections.class_id is None:
+             raise ValueError(
+                 "Could not resolve color by class because "
+                 "Detections do not have class_id. If using an annotator, "
+                 "try setting color_lookup to sv.ColorLookup.INDEX or "
+                 "sv.ColorLookup.TRACK."
+             )
+         return detections.class_id[detection_idx]
+     elif color_lookup == ColorLookup.TRACK:
+         if detections.tracker_id is None:
+             raise ValueError(
+                 "Could not resolve color by track because "
+                 "Detections do not have tracker_id. Did you call "
+                 "tracker.update_with_detections(...) before annotating?"
+             )
+         return detections.tracker_id[detection_idx]
+
+
+ def resolve_text_background_xyxy(
+     center_coordinates: Tuple[int, int],
+     text_wh: Tuple[int, int],
+     position: Position,
+ ) -> Tuple[int, int, int, int]:
+     center_x, center_y = center_coordinates
+     text_w, text_h = text_wh
+
+     if position == Position.TOP_LEFT:
+         return center_x, center_y - text_h, center_x + text_w, center_y
+     elif position == Position.TOP_RIGHT:
+         return center_x - text_w, center_y - text_h, center_x, center_y
+     elif position == Position.TOP_CENTER:
+         return (
+             center_x - text_w // 2,
+             center_y - text_h,
+             center_x + text_w // 2,
+             center_y,
+         )
+     elif position == Position.CENTER or position == Position.CENTER_OF_MASS:
+         return (
+             center_x - text_w // 2,
+             center_y - text_h // 2,
+             center_x + text_w // 2,
+             center_y + text_h // 2,
+         )
+     elif position == Position.BOTTOM_LEFT:
+         return center_x, center_y, center_x + text_w, center_y + text_h
+     elif position == Position.BOTTOM_RIGHT:
+         return center_x - text_w, center_y, center_x, center_y + text_h
+     elif position == Position.BOTTOM_CENTER:
+         return (
+             center_x - text_w // 2,
+             center_y,
+             center_x + text_w // 2,
+             center_y + text_h,
+         )
+     elif position == Position.CENTER_LEFT:
+         return (
+             center_x - text_w,
+             center_y - text_h // 2,
+             center_x,
+             center_y + text_h // 2,
+         )
+     elif position == Position.CENTER_RIGHT:
+         return (
+             center_x,
+             center_y - text_h // 2,
+             center_x + text_w,
+             center_y + text_h // 2,
+         )
+
+
+ def get_color_by_index(color: Union[Color, ColorPalette], idx: int) -> Color:
+     if isinstance(color, ColorPalette):
+         return color.by_id(idx)
+     return color
+
+
+ def resolve_color(
+     color: Union[Color, ColorPalette],
+     detections: Detections,
+     detection_idx: int,
+     color_lookup: Union[ColorLookup, np.ndarray] = ColorLookup.CLASS,
+ ) -> Color:
+     idx = resolve_color_idx(
+         detections=detections,
+         detection_idx=detection_idx,
+         color_lookup=color_lookup,
+     )
+     return get_color_by_index(color=color, idx=idx)
+
+
+ class Trace:
+     def __init__(
+         self,
+         max_size: Optional[int] = None,
+         start_frame_id: int = 0,
+         anchor: Position = Position.CENTER,
+     ) -> None:
+         self.current_frame_id = start_frame_id
+         self.max_size = max_size
+         self.anchor = anchor
+
+         self.frame_id = np.array([], dtype=int)
+         self.xy = np.empty((0, 2), dtype=np.float32)
+         self.tracker_id = np.array([], dtype=int)
+
+     def put(self, detections: Detections) -> None:
+         frame_id = np.full(len(detections), self.current_frame_id, dtype=int)
+         self.frame_id = np.concatenate([self.frame_id, frame_id])
+         self.xy = np.concatenate(
+             [self.xy, detections.get_anchors_coordinates(self.anchor)]
+         )
+         self.tracker_id = np.concatenate([self.tracker_id, detections.tracker_id])
+
+         unique_frame_id = np.unique(self.frame_id)
+
+         # max_size defaults to None; only trim history when a limit was set
+         if self.max_size is not None and 0 < self.max_size < len(unique_frame_id):
+             max_allowed_frame_id = self.current_frame_id - self.max_size + 1
+             filtering_mask = self.frame_id >= max_allowed_frame_id
+             self.frame_id = self.frame_id[filtering_mask]
+             self.xy = self.xy[filtering_mask]
+             self.tracker_id = self.tracker_id[filtering_mask]
+
+         self.current_frame_id += 1
+
+     def get(self, tracker_id: int) -> np.ndarray:
+         return self.xy[self.tracker_id == tracker_id]
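
To make the three lookup strategies concrete, a small sketch (the keyword arguments to `Detections` are assumptions mirroring the upstream supervision dataclass this package appears to vendor; only `class_id`, `tracker_id`, and `len()` are exercised here):

```python
import numpy as np

from eye.annotators.utils import ColorLookup, resolve_color_idx
from eye.detection.core import Detections

# Hypothetical detections: two boxes with class ids and tracker ids assigned.
detections = Detections(
    xyxy=np.array([[10, 10, 50, 50], [60, 60, 90, 90]], dtype=float),
    class_id=np.array([0, 2]),
    tracker_id=np.array([7, 8]),
)

print(resolve_color_idx(detections, 1, ColorLookup.CLASS))  # 2 (class id)
print(resolve_color_idx(detections, 1, ColorLookup.TRACK))  # 8 (tracker id)
print(resolve_color_idx(detections, 0, np.array([3, 5])))   # 3 (explicit index)
# resolve_color() then maps the resolved index through ColorPalette.by_id()
# to pick the final Color.
```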
eye/assets/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from eye.assets.downloader import download_assets
+ from eye.assets.list import VideoAssets
eye/assets/downloader.py ADDED
@@ -0,0 +1,95 @@
+ import os
+ from hashlib import new as hash_new
+ from pathlib import Path
+ from shutil import copyfileobj
+ from typing import Union
+
+ from eye.assets.list import VIDEO_ASSETS, VideoAssets
+
+ try:
+     from requests import get
+     from tqdm.auto import tqdm
+ except ImportError:
+     raise ValueError(
+         "\n"
+         "Please install requests and tqdm to download assets \n"
+         "or install eye with assets \n"
+         "pip install eye[assets] \n"
+         "\n"
+     )
+
+
+ def is_md5_hash_matching(filename: str, original_md5_hash: str) -> bool:
+     """
+     Check if the MD5 hash of a file matches the original hash.
+
+     Parameters:
+         filename (str): The path to the file to be checked as a string.
+         original_md5_hash (str): The original MD5 hash to compare against.
+
+     Returns:
+         bool: True if the hashes match, False otherwise.
+     """
+     if not os.path.exists(filename):
+         return False
+
+     with open(filename, "rb") as file:
+         file_contents = file.read()
+         computed_md5_hash = hash_new(name="MD5")
+         computed_md5_hash.update(file_contents)
+
+     return computed_md5_hash.hexdigest() == original_md5_hash
+
+
+ def download_assets(asset_name: Union[VideoAssets, str]) -> str:
+     """
+     Download a specified asset if it doesn't already exist or is corrupted.
+
+     Parameters:
+         asset_name (Union[VideoAssets, str]): The name or type of the asset to be
+             downloaded.
+
+     Returns:
+         str: The filename of the downloaded asset.
+
+     Example:
+         ```python
+         from eye.assets import download_assets, VideoAssets
+
+         download_assets(VideoAssets.VEHICLES)
+         "vehicles.mp4"
+         ```
+     """
+
+     filename = asset_name.value if isinstance(asset_name, VideoAssets) else asset_name
+
+     if not Path(filename).exists() and filename in VIDEO_ASSETS:
+         print(f"Downloading {filename} assets \n")
+         response = get(VIDEO_ASSETS[filename][0], stream=True, allow_redirects=True)
+         response.raise_for_status()
+
+         file_size = int(response.headers.get("Content-Length", 0))
+         folder_path = Path(filename).expanduser().resolve()
+         folder_path.parent.mkdir(parents=True, exist_ok=True)
+
+         with tqdm.wrapattr(
+             response.raw, "read", total=file_size, desc="", colour="#a351fb"
+         ) as raw_resp:
+             with folder_path.open("wb") as file:
+                 copyfileobj(raw_resp, file)
+
+     # For existing files, verify the checksum only for known assets
+     elif filename in VIDEO_ASSETS:
+         if not is_md5_hash_matching(filename, VIDEO_ASSETS[filename][1]):
+             print("File corrupted. Re-downloading... \n")
+             os.remove(filename)
+             return download_assets(filename)
+
+         print(f"{filename} asset download complete. \n")
+
+     else:
+         valid_assets = ", ".join(asset.value for asset in VideoAssets)
+         raise ValueError(
+             f"Invalid asset. It should be one of the following: {valid_assets}."
+         )
+
+     return filename
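
Usage is straightforward; both the enum member and the bare filename are accepted, and a second call only re-verifies the MD5 checksum (network access assumed):

```python
from eye.assets import VideoAssets, download_assets
from eye.assets.downloader import is_md5_hash_matching
from eye.assets.list import VIDEO_ASSETS

path = download_assets(VideoAssets.VEHICLES)  # fetches vehicles.mp4 if missing
path = download_assets("vehicles.mp4")        # already present: integrity check only

# The second tuple element in VIDEO_ASSETS is the expected MD5 hash.
assert is_md5_hash_matching(path, VIDEO_ASSETS[path][1])
```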
eye/assets/list.py ADDED
@@ -0,0 +1,83 @@
+ from enum import Enum
+ from typing import Dict, Tuple
+
+ BASE_VIDEO_URL = "https://media.roboflow.com/eye/video-examples/"
+
+
+ class VideoAssets(Enum):
+     """
+     Each member of this enum represents a video asset. The value associated with each
+     member is the filename of the video.
+
+     | Enum Member | Video Filename | Video URL |
+     |------------------------|----------------------------|---------------------------------------------------------------------------------------|
+     | `VEHICLES` | `vehicles.mp4` | [Link](https://media.roboflow.com/eye/video-examples/vehicles.mp4) |
+     | `MILK_BOTTLING_PLANT` | `milk-bottling-plant.mp4` | [Link](https://media.roboflow.com/eye/video-examples/milk-bottling-plant.mp4) |
+     | `VEHICLES_2` | `vehicles-2.mp4` | [Link](https://media.roboflow.com/eye/video-examples/vehicles-2.mp4) |
+     | `GROCERY_STORE` | `grocery-store.mp4` | [Link](https://media.roboflow.com/eye/video-examples/grocery-store.mp4) |
+     | `SUBWAY` | `subway.mp4` | [Link](https://media.roboflow.com/eye/video-examples/subway.mp4) |
+     | `MARKET_SQUARE` | `market-square.mp4` | [Link](https://media.roboflow.com/eye/video-examples/market-square.mp4) |
+     | `PEOPLE_WALKING` | `people-walking.mp4` | [Link](https://media.roboflow.com/eye/video-examples/people-walking.mp4) |
+     | `BEACH` | `beach-1.mp4` | [Link](https://media.roboflow.com/eye/video-examples/beach-1.mp4) |
+     | `BASKETBALL` | `basketball-1.mp4` | [Link](https://media.roboflow.com/eye/video-examples/basketball-1.mp4) |
+     | `SKIING` | `skiing.mp4` | [Link](https://media.roboflow.com/eye/video-examples/skiing.mp4) |
+     """  # noqa: E501 // docs
+
+     VEHICLES = "vehicles.mp4"
+     MILK_BOTTLING_PLANT = "milk-bottling-plant.mp4"
+     VEHICLES_2 = "vehicles-2.mp4"
+     GROCERY_STORE = "grocery-store.mp4"
+     SUBWAY = "subway.mp4"
+     MARKET_SQUARE = "market-square.mp4"
+     PEOPLE_WALKING = "people-walking.mp4"
+     BEACH = "beach-1.mp4"
+     BASKETBALL = "basketball-1.mp4"
+     SKIING = "skiing.mp4"
+
+     @classmethod
+     def list(cls):
+         return list(map(lambda c: c.value, cls))
+
+
+ VIDEO_ASSETS: Dict[str, Tuple[str, str]] = {
+     VideoAssets.VEHICLES.value: (
+         f"{BASE_VIDEO_URL}{VideoAssets.VEHICLES.value}",
+         "8155ff4e4de08cfa25f39de96483f918",
+     ),
+     VideoAssets.VEHICLES_2.value: (
+         f"{BASE_VIDEO_URL}{VideoAssets.VEHICLES_2.value}",
+         "830af6fba21ffbf14867a7fea595937b",
+     ),
+     VideoAssets.MILK_BOTTLING_PLANT.value: (
+         f"{BASE_VIDEO_URL}{VideoAssets.MILK_BOTTLING_PLANT.value}",
+         "9e8fb6e883f842a38b3d34267290bdc7",
+     ),
+     VideoAssets.GROCERY_STORE.value: (
+         f"{BASE_VIDEO_URL}{VideoAssets.GROCERY_STORE.value}",
+         "11402e7b861c1980527d3d74cbe3b366",
+     ),
+     VideoAssets.SUBWAY.value: (
+         f"{BASE_VIDEO_URL}{VideoAssets.SUBWAY.value}",
+         "453475750691fb23c56a0cffef089194",
+     ),
+     VideoAssets.MARKET_SQUARE.value: (
+         f"{BASE_VIDEO_URL}{VideoAssets.MARKET_SQUARE.value}",
+         "859179bf4a21f80a8baabfdb2ed716dc",
+     ),
+     VideoAssets.PEOPLE_WALKING.value: (
+         f"{BASE_VIDEO_URL}{VideoAssets.PEOPLE_WALKING.value}",
+         "0574c053c8686c3f1dc0aa3743e45cb9",
+     ),
+     VideoAssets.BEACH.value: (
+         f"{BASE_VIDEO_URL}{VideoAssets.BEACH.value}",
+         "4175d42fec4d450ed081523fd39e0cf8",
+     ),
+     VideoAssets.BASKETBALL.value: (
+         f"{BASE_VIDEO_URL}{VideoAssets.BASKETBALL.value}",
+         "60d94a3c7c47d16f09d342b088012ecc",
+     ),
+     VideoAssets.SKIING.value: (
+         f"{BASE_VIDEO_URL}{VideoAssets.SKIING.value}",
+         "d30987cbab1bbc5934199cdd1b293119",
+     ),
+ }
eye/classification/__init__.py ADDED
File without changes
eye/classification/core.py ADDED
@@ -0,0 +1,188 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Any, Optional, Tuple
+
+ import numpy as np
+
+
+ def _validate_class_ids(class_id: Any, n: int) -> None:
+     """
+     Ensure that class_id is a 1d np.ndarray with (n, ) shape.
+     """
+     is_valid = isinstance(class_id, np.ndarray) and class_id.shape == (n,)
+     if not is_valid:
+         raise ValueError("class_id must be 1d np.ndarray with (n, ) shape")
+
+
+ def _validate_confidence(confidence: Any, n: int) -> None:
+     """
+     Ensure that confidence is a 1d np.ndarray with (n, ) shape.
+     """
+     if confidence is not None:
+         is_valid = isinstance(confidence, np.ndarray) and confidence.shape == (n,)
+         if not is_valid:
+             raise ValueError("confidence must be 1d np.ndarray with (n, ) shape")
+
+
+ @dataclass
+ class Classifications:
+     class_id: np.ndarray
+     confidence: Optional[np.ndarray] = None
+
+     def __post_init__(self) -> None:
+         """
+         Validate the classification inputs.
+         """
+         n = len(self.class_id)
+
+         _validate_class_ids(self.class_id, n)
+         _validate_confidence(self.confidence, n)
+
+     def __len__(self) -> int:
+         """
+         Returns the number of classifications.
+         """
+         return len(self.class_id)
+
+     @classmethod
+     def from_clip(cls, clip_results) -> Classifications:
+         """
+         Creates a Classifications instance from a
+         [clip](https://github.com/openai/clip) inference result.
+
+         Args:
+             clip_results (torch.Tensor): The inference result from the clip model.
+
+         Returns:
+             Classifications: A new Classifications object.
+
+         Example:
+             ```python
+             from PIL import Image
+             import clip
+             import eye as sv
+
+             model, preprocess = clip.load('ViT-B/32')
+
+             image = Image.open(SOURCE_IMAGE_PATH)
+             image = preprocess(image).unsqueeze(0)
+
+             text = clip.tokenize(["a diagram", "a dog", "a cat"])
+             output, _ = model(image, text)
+             classifications = sv.Classifications.from_clip(output)
+             ```
+         """
+
+         confidence = clip_results.softmax(dim=-1).cpu().detach().numpy()[0]
+
+         if len(confidence) == 0:
+             return cls(class_id=np.array([]), confidence=np.array([]))
+
+         class_ids = np.arange(len(confidence))
+         return cls(class_id=class_ids, confidence=confidence)
+
+     @classmethod
+     def from_ultralytics(cls, ultralytics_results) -> Classifications:
+         """
+         Creates a Classifications instance from a
+         [ultralytics](https://github.com/ultralytics/ultralytics) inference result.
+
+         Args:
+             ultralytics_results (ultralytics.engine.results.Results):
+                 The inference result from ultralytics model.
+
+         Returns:
+             Classifications: A new Classifications object.
+
+         Example:
+             ```python
+             import cv2
+             from ultralytics import YOLO
+             import eye as sv
+
+             image = cv2.imread(SOURCE_IMAGE_PATH)
+             model = YOLO('yolov8n-cls.pt')
+
+             output = model(image)[0]
+             classifications = sv.Classifications.from_ultralytics(output)
+             ```
+         """
+         confidence = ultralytics_results.probs.data.cpu().numpy()
+         return cls(class_id=np.arange(confidence.shape[0]), confidence=confidence)
+
+     @classmethod
+     def from_timm(cls, timm_results) -> Classifications:
+         """
+         Creates a Classifications instance from a
+         [timm](https://huggingface.co/docs/hub/timm) inference result.
+
+         Args:
+             timm_results (torch.Tensor): The inference result from timm model.
+
+         Returns:
+             Classifications: A new Classifications object.
+
+         Example:
+             ```python
+             from PIL import Image
+             import timm
+             from timm.data import resolve_data_config, create_transform
+             import eye as sv
+
+             model = timm.create_model(
+                 model_name='hf-hub:nateraw/resnet50-oxford-iiit-pet',
+                 pretrained=True
+             ).eval()
+
+             config = resolve_data_config({}, model=model)
+             transform = create_transform(**config)
+
+             image = Image.open(SOURCE_IMAGE_PATH).convert('RGB')
+             x = transform(image).unsqueeze(0)
+
+             output = model(x)
+
+             classifications = sv.Classifications.from_timm(output)
+             ```
+         """
+         confidence = timm_results.cpu().detach().numpy()[0]
+
+         if len(confidence) == 0:
+             return cls(class_id=np.array([]), confidence=np.array([]))
+
+         class_id = np.arange(len(confidence))
+         return cls(class_id=class_id, confidence=confidence)
+
+     def get_top_k(self, k: int) -> Tuple[np.ndarray, np.ndarray]:
+         """
+         Retrieve the top k class IDs and confidences,
+         ordered in descending order by confidence.
+
+         Args:
+             k (int): The number of top class IDs and confidences to retrieve.
+
+         Returns:
+             Tuple[np.ndarray, np.ndarray]: A tuple containing
+                 the top k class IDs and confidences.
+
+         Example:
+             ```python
+             import eye as sv
+
+             classifications = sv.Classifications(...)
+
+             classifications.get_top_k(1)
+
+             (array([1]), array([0.9]))
+             ```
+         """
+         if self.confidence is None:
+             raise ValueError("top_k could not be calculated, confidence is None")
+
+         order = np.argsort(self.confidence)[::-1]
+         top_k_order = order[:k]
+         top_k_class_id = self.class_id[top_k_order]
+         top_k_confidence = self.confidence[top_k_order]
+
+         return top_k_class_id, top_k_confidence
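
A quick, self-contained illustration of the validation and `get_top_k` ordering above (values are made up):

```python
import numpy as np

from eye.classification.core import Classifications

c = Classifications(
    class_id=np.array([0, 1, 2]),
    confidence=np.array([0.1, 0.7, 0.2]),
)
print(len(c))          # 3
print(c.get_top_k(2))  # (array([1, 2]), array([0.7, 0.2]))

# Validation rejects plain lists:
# Classifications(class_id=[0, 1]) raises
# ValueError: class_id must be 1d np.ndarray with (n, ) shape
```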
eye/config.py ADDED
@@ -0,0 +1,2 @@
+ CLASS_NAME_DATA_FIELD = "class_name"
+ ORIENTED_BOX_COORDINATES = "xyxyxyxy"
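
These constants appear to match the `Detections.data` keys used by upstream supervision (which this package vendors, per `eye/__init___supervision_original.py`); the sketch below rests on that assumption, including `Detections` accepting a `data` dict:

```python
import numpy as np

from eye.config import CLASS_NAME_DATA_FIELD, ORIENTED_BOX_COORDINATES
from eye.detection.core import Detections

# Assumed, mirroring upstream supervision: Detections carries a free-form
# `data` dict keyed by these constants.
detections = Detections(
    xyxy=np.array([[10, 10, 50, 50]], dtype=float),
    class_id=np.array([0]),
    data={CLASS_NAME_DATA_FIELD: np.array(["person"])},
)
print(detections.data[CLASS_NAME_DATA_FIELD])  # ['person']

# ORIENTED_BOX_COORDINATES ("xyxyxyxy") is the data key under which
# 4-corner oriented-bounding-box points would be stored.
```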
eye/core/__init__.py ADDED
File without changes